# 好的,以下是改进后的完整代码,包括数据预处理、模型构建、损失函数定义、训练与评估以及参数调整与实验:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, optimizers
import matplotlib.pyplot as plt
# 数据预处理
def load_preprocess_mnist():
    """Load MNIST and prepare it for training.

    Images are cast to float32, divided by 255 and given a trailing
    channel axis. Labels are one-hot encoded with a fixed width of 10.

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` with the preprocessed
        images and one-hot labels of the train and test splits.
    """
    # MNIST has the ten digit classes 0-9. Passing the count explicitly
    # keeps the one-hot width identical for both splits instead of letting
    # each call infer it from the largest label it happens to see.
    num_classes = 10

    def _prepare_images(images):
        # Scale pixel values by 1/255, then append the channel axis
        # (expected result: [num_samples, 28, 28, 1]).
        scaled = images.astype('float32') / 255.
        return np.expand_dims(scaled, -1)

    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    x_train = _prepare_images(x_train)
    x_test = _prepare_images(x_test)

    # Labels are passed as loaded: to_categorical needs integer class
    # indices, so casting them to float first serves no purpose.
    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_test = keras.utils.to_categorical(y_test, num_classes)

    return (x_train, y_train), (x_test, y_test)
# 激活函数:squash
def squash(vectors, axis=-1):
    """Apply the squash non-linearity to ``vectors`` along ``axis``.

    With ``n2`` the sum of squares along ``axis``, every vector is
    multiplied by ``n2 / (1 + n2) / sqrt(n2 + 1e-9)``.

    Args:
        vectors: Tensor holding the vectors to squash.
        axis: Axis along which the squared norm is taken (default: last).

    Returns:
        A tensor of the same shape as ``vectors``, rescaled as above.
    """
    sq_norm = tf.reduce_sum(tf.square(vectors), axis=axis, keepdims=True)
    shrink = sq_norm / (1 + sq_norm)
    # The 1e-9 term keeps the square root strictly positive when the
    # squared norm is zero.
    factor = shrink / tf.sqrt(sq_norm + 1e-9)
    return factor * vectors
# 胶囊层
class CapsuleLayer(layers.Layer):
"""
胶囊层,包含动态路由算法。
"""
def __init__(self, num_capsule, dim_capsule, routings=3, kernel_initializer='glorot_uniform', **kwargs):
    """Store the layer configuration.

    Args:
        num_capsule: Number of output capsules (e.g. 10).
        dim_capsule: Dimension of each output capsule (e.g. 16).
        routings: Number of dynamic-routing iterations.
        kernel_initializer: Initializer identifier, resolved through
            ``keras.initializers.get``.
        **kwargs: Forwarded unchanged to the base layer constructor.
    """
    super(CapsuleLayer, self).__init__(**kwargs)
    # Resolve the initializer once so later code can use it directly.
    self.kernel_initializer = keras.initializers.get(kernel_initializer)
    self.routings = routings
    self.dim_capsule = dim_capsule
    self.num_capsule = num_capsule
def build(self, input_shape):
    """Record the input capsule layout and create the weight ``W``.

    Args:
        input_shape: Shape of the layer input, expected to be
            ``[batch_size, input_num_capsule, input_dim_capsule]``.
    """
    in_capsules = input_shape[1]
    in_dim = input_shape[2]
    self.input_num_capsule = in_capsules
    self.input_dim_capsule = in_dim

    # One (dim_capsule x input_dim_capsule) matrix per pair of output and
    # input capsule; the leading 1 is a placeholder for the batch axis.
    weight_shape = [
        1,
        self.num_capsule,
        self.input_num_capsule,
        self.dim_capsule,
        self.input_dim_capsule,
    ]
    self.W = self.add_weight(
        name='W',
        shape=weight_shape,
        initializer=self.kernel_initializer,
    )
    super(CapsuleLayer, self).build(input_shape)
def call(self, inputs):
# inputs.shape=[batch_size, input_num_capsule, input_dim_capsule]
batch_size = tf.shape(inputs)[0]
# Expand W to [batch_size, num_capsule, input_num_capsule, dim_capsule, input_dim_capsule]
W_tiled = tf.tile(self.W, [batch_size, 1, 1, 1, 1])
# Expand inputs to [batch_size, 1, input_num_capsule, input_dim_capsule, 1]
inputs_expand = tf.expand_dims(inputs, 1)
inputs_expand = tf.expand_dims(inputs_expand, -1)
# Compute u_hat = W * u, using batch matrix multiplication
# [batch_size, num_capsule, input_num_capsule, dim_capsule, 1]
u_hat = tf.matmul(W_tiled, inputs_expand)
u_hat = tf.squeeze(u_hat, -1) # [batch_size, num_capsule, input_num_capsule, dim_capsule]
# 初始化路由权重 b
b = tf.zeros(shape=[batch_size, self.num_capsule, self.input_num_capsule], dtype=tf.float32)
# 动态路由算法
for i in range(self.routings):
# c = softmax(b) along num_capsule axis
c = tf.nn.softmax(b, axis=1) # [batch_size, num_capsule, input_num_capsule]
# Expand c for multiplication
c_expanded = tf.expand_dims(c, -1) # [batch_size, num_capsule, input_num_capsule, 1]
# s = sum(c * u_hat, axis=2)
s = tf.reduce_sum(c_expanded * u_hat, axis=2, keepdims=True) # [batch_size, num_capsule, 1, dim_capsule]
# v = squash(s)
v = squash(s, axis=-1) # [batch_size, num_capsule, 1, dim_capsule]
if i < self.routings - 1