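Recently I needed to subclass Keras's Model class. The code logic looked fine, but running it raised a NotImplementedError, and the message showed that self.compute_output_shape was not implemented in the subclass. The code is as follows: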

import tensorflow as tf
from keras import Model
from keras.layers import Dense  # assumed import; the original snippet omitted it

class ACModel(Model):
    '''
    Common model class for the actor and critic networks
    '''
    def __init__(self, state_size, action_size):
        super().__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.dense_1 = Dense(100, activation='relu')
        # actor head: probabilities of actions in a discrete space
        self.policy_logits = Dense(self.action_size, activation='softmax')
        self.dense_2 = Dense(100, activation='relu')
        self.value = Dense(1)  # critic head: scalar value estimate

    def call(self, inputs, mask=None):
        '''Override Model.call(); must handle inputs.'''
        x = self.dense_1(inputs)  # inputs are states
        # dense_1 is the layer shared by the actor and critic networks
        logits = self.policy_logits(x)  # action probabilities
        x = self.dense_2(x)
        value = self.value(x)  # value produced by the critic
        return logits, value

model = ACModel(4, 1)
res = model(tf.constant([[1, 1, 1, 1]], dtype=tf.float32))
# this last call raises NotImplementedError
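
As a rough illustration of how this kind of error can arise, here is a toy sketch in plain Python. It is an analogy only, not Keras source code: the classes ToyLayer, Doubler and DoublerWithShape and the helper try_call are hypothetical names invented for this example. The idea is that a base class's __call__ wrapper asks for the output shape after running call(), and the base class cannot answer that for arbitrary user code unless the subclass overrides the method.

```python
class ToyLayer:
    """Hypothetical stand-in for a framework base class (not Keras code)."""

    def compute_output_shape(self, input_shape):
        # The base class cannot infer the shape produced by user-written call().
        raise NotImplementedError("compute_output_shape must be overridden")

    def call(self, inputs):
        raise NotImplementedError("call must be overridden")

    def __call__(self, inputs):
        outputs = self.call(inputs)
        # The wrapper asks for the output shape after running call().
        self.output_shape = self.compute_output_shape((len(inputs),))
        return outputs


class Doubler(ToyLayer):
    """Overrides call() only, like the ACModel subclass in this post."""

    def call(self, inputs):
        return [2 * x for x in inputs]


class DoublerWithShape(Doubler):
    """Also overrides compute_output_shape(), so __call__ succeeds."""

    def compute_output_shape(self, input_shape):
        return input_shape


def try_call(layer, inputs):
    """Call the layer and report a NotImplementedError instead of crashing."""
    try:
        return layer(inputs)
    except NotImplementedError as exc:
        return f"NotImplementedError: {exc}"


if __name__ == "__main__":
    print(try_call(Doubler(), [1, 2]))
    print(try_call(DoublerWithShape(), [1, 2]))
```

In this toy, the first subclass fails even though its call() is correct, which mirrors the symptom above; whether standalone Keras takes exactly this path is something I have not verified.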

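After looking at Keras and other people's implementations, I found that many of them inherit from tf.keras.Model instead. In principle this should not matter, since Keras is tightly integrated into TensorFlow, but after experimenting I found that the two really do behave differently. The changed code is as follows: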

import tensorflow as tf
from tensorflow.keras.layers import Dense  # assumed import; the original snippet omitted it

class ACModel(tf.keras.Model):
    '''
    Common model class for the actor and critic networks
    '''
    def __init__(self, state_size, action_size):
        super().__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.dense_1 = Dense(100, activation='relu')
        # actor head: probabilities of actions in a discrete space
        self.policy_logits = Dense(self.action_size, activation='softmax')
        self.dense_2 = Dense(100, activation='relu')
        self.value = Dense(1)  # critic head: scalar value estimate

    def call(self, inputs):
        '''Override Model.call(); must handle inputs.'''
        x = self.dense_1(inputs)  # inputs are states
        # dense_1 is the layer shared by the actor and critic networks
        logits = self.policy_logits(x)  # action probabilities
        x = self.dense_2(x)
        value = self.value(x)  # value produced by the critic
        return logits, value

model = ACModel(4, 1)
res = model(tf.constant([[1, 1, 1, 1]], dtype=tf.float32))

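This time no error is raised. Looking closely at the documentation, a difference does show up: the call method in keras.Model has the signature call(self, inputs, mask=None), whereas tf.keras.Model documents call(inputs, training=None, mask=None), and a subclass can get by with just call(self, inputs). The cause may be that the TensorFlow and Keras versions are not aligned with each other.
PS. My TensorFlow version is 2.1.0 and my Keras version is 2.3.1.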
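
Since a version mismatch is the suspected cause, it helps to record which versions are actually installed. The following is a minimal sketch using only the standard library (Python 3.8+); the helper name installed_versions is made up for this example, and it assumes the packages are published under the distribution names "tensorflow" and "keras".

```python
from importlib import metadata


def installed_versions(packages=("tensorflow", "keras")):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Comparing this output against the release notes of both packages is a reasonable first step before mixing `keras` and `tf.keras` imports in one program.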
