Pure Soul

  1. 首页
  2. 算法
  3. 正文

关于实现keras.Model时,trainable_weights为空的解决办法

2020年10月6日 1498点热度 0人点赞 0条评论

在搭建自己的Model的时候,我们有时候需要自己从稍微低层的部分进行搭建,而不是直接用Sequential搭建模型,或者是使用Model(inputs,outputs)的方式搭建,例如下面这个简单的例子:

# from tensorflow.python import keras
# from tensorflow.keras.layers import Dense,Input
# 直接使用keras或者是从tensorflow当中导入keras,两种方式二选一
import keras
from keras.layers import Dense,Input
class ActorCriticSharedModel(keras.Model):
    '''
    Comment model class 
    '''

    def __init__(self, state_size, action_size):
        super().__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.dense_1 = Dense(100, activation='relu')
        self.policy_logits = Dense(self.action_size)
        self.dense_2 = Dense(100, activation='relu')
        self.value = Dense(1)  # output of value by critic-net according action

    def call(self, inputs, training=None, mask=None):
        '''over write call() method in Model class.
            must handle inputs'''
        x = self.dense_1(inputs)  # input is states,
        # this layer(dense_1) is shared layer in action and critic network
        logits = self.policy_logits(x)  # softmax(logits) is probabilities of output actions
        x = self.dense_2(inputs)
        value = self.value(x)  # value produced by critic
        return logits, value  # logits--without softmax activation

对__init__以及call的方法进行重载之后,就可以很简单的进行搭建模型,如下所示:
model=ActorCriticSharedModel(5,1)
搭建一个简单的5输入,1输出的网络,但是查看网络的可训练参数model.trainableweights是会显示空list的。其中的原因是因为在搭建网络时,没有执行参数初始化的操作。此时只需要简单的使用随机数初始化一下就可以,例如:model(tf.convert_to_tensor(np.random.random((1, 5),dtype=tf.float32))。此时再去查看可训练参数model.trainable_weights就是显示正常。
换而言之,如果自己采用的是Sequential或者说是Model(inputs,outputs)的方式搭建的模型,那么模型就会自动的执行参数的初始化,直接查看model.trainable_weights就可以。

标签: 暂无
最后更新:2020年12月28日

ycq

这个人很懒,什么都没留下

点赞
< 上一篇
下一篇 >

COPYRIGHT © 2021 oo2ee.com. ALL RIGHTS RESERVED.

THEME KRATOS MADE BY VTROIS