When Shakespeare Meets Google Flax: Writing "Shakespeare-Style" Sentences with a Character-Level Language Model and Recurrent Neural Networks
"Some are born great, some achieve greatness, and some have greatness thrust upon them." (William Shakespeare, Twelfth Night)
Recurrent Neural Networks
One-to-one is the typical CNN or multilayer perceptron: a single input vector maps to a single output vector.
One-to-many is the RNN architecture used for image captioning: the input is an image and the output is a sequence of words describing it.
Many-to-many comes in two forms: the first maps an input sequence to an output sequence, as in machine translation (e.g., German to English); the second suits frame-level video captioning. A minimal sketch of the many-to-many case follows this list.
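To make the many-to-many case concrete, here is a minimal sketch in plain JAX, not the article's Flax model: a vanilla RNN cell scanned over a short one-hot character sequence, emitting one logit vector per input step. All sizes, weights, and character indices below are illustrative.

import jax
import jax.numpy as jnp

hidden, vocab = 16, 65  # illustrative sizes
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
Wxh = 0.01 * jax.random.normal(k1, (vocab, hidden))   # input -> hidden
Whh = 0.01 * jax.random.normal(k2, (hidden, hidden))  # hidden -> hidden
Why = 0.01 * jax.random.normal(k3, (hidden, vocab))   # hidden -> logits

def step(h, x_t):
  # One recurrent step: fold the current character into the hidden state,
  # then emit logits over the next character.
  h = jnp.tanh(x_t @ Wxh + h @ Whh)
  return h, h @ Why

xs = jnp.eye(vocab)[jnp.array([7, 3, 42])]        # three one-hot input chars
_, logits = jax.lax.scan(step, jnp.zeros(hidden), xs)
print(logits.shape)                               # (3, 65): one output per step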
This is the kind of text we want the model to produce:
EDWARD: Tis even so; yet you are Warwick still. GLOUCESTER: Come, Warwick, take the time; kneel down, kneel down: Nay, when? strike now, or else the iron cools.
pip install -q git+https://github.com/google/flax.git@master
Training is computationally demanding, so you should use a runtime with GPU support. You can check whether a GPU backend is available like this:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)  # prints 'gpu' when a GPU is active
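The snippets below also refer to a shared params dictionary, a vocab mapping, a training dataset ds, a sampling seed test_ds, and a compute_metrics helper that this excerpt never defines. Here is a minimal sketch of the assumed setup; every concrete value in it is an assumption, not the article's actual configuration:

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_datasets as tfds
from flax import nn, optim, jax_utils            # pre-Linen Flax API
from flax.training import common_utils

# Hypothetical hyperparameters; the excerpt references `params` but never
# shows it, so these values are placeholders.
params = {
  'vocab_length': 65,          # distinct characters in the corpus
  'seq_length': 50,            # characters per training example
  'batch_size': 32,
  'learning_rate': 2e-3,
  'learning_rate_decay': 0.97,
  'step_decay': 500,           # optimizer steps between learning-rate drops
}

# `vocab` (index -> character), the dataset `ds` of one-hot (input, target)
# pairs, and the seed iterator `test_ds` come from a text pipeline over the
# Shakespeare corpus that the excerpt does not include.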
"""LSTM"""
def apply(self, carry, inputs):
carry1, outputs = jax_utils.scan_in_dim(
nn.LSTMCell.partial(name='lstm1'), carry[0], inputs, axis=1)
carry2, outputs = jax_utils.scan_in_dim(
nn.LSTMCell.partial(name='lstm2'), carry[1], outputs, axis=1)
carry3, outputs = jax_utils.scan_in_dim(
nn.LSTMCell.partial(name='lstm3'), carry[2], outputs, axis=1)
x = nn.Dense(outputs, features=params['vocab_length'], name='dense')
return [carry1, carry2, carry3], x
"""Char Generator"""
def apply(self, inputs, carry_pred=None, train=True):
batch_size = params['batch_size']
vocab_size = params['vocab_length']
hidden_size = 512
if train:
carry1 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,),hidden_size)
carry2 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,),hidden_size)
carry3 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,),hidden_size)
carry = [carry1, carry2, carry3]
_, x = RNN(carry, inputs)
return x
else:
carry, x = RNN(carry_pred, inputs)
return carry, x
In train mode, the model learns to predict the next character from the current one. In inference mode, we reuse the trained model to sample new text, one character at a time.
def create_model(rng):
  """Creates a model."""
  vocab_size = params['vocab_length']
  # Initialize parameters by tracing the model on a dummy input shape.
  _, initial_params = charRNN.init_by_shape(
      rng, [((1, params['seq_length'], vocab_size), jnp.float32)])
  model = nn.Model(charRNN, initial_params)
  return model
def create_optimizer(model, learning_rate):
  """Creates an Adam optimizer for model."""
  optimizer_def = optim.Adam(learning_rate=learning_rate, weight_decay=1e-1)
  # `create` wraps the model; `optimizer.target` holds the current parameters.
  optimizer = optimizer_def.create(model)
  return optimizer
@jax.jit
def train_step(optimizer, batch):
  """Train one step."""

  def loss_fn(model):
    """Compute cross-entropy loss and predict logits of the current batch."""
    logits = model(batch[0])
    loss = jnp.mean(cross_entropy_loss(logits, batch[1])) / params['batch_size']
    return loss, logits

  def exponential_decay(steps):
    """Decay the learning rate every `step_decay` steps (roughly every 5 epochs)."""
    x_decay = (steps / params['step_decay']).astype('int32')
    ret = params['learning_rate'] * jax.lax.pow(
        params['learning_rate_decay'], x_decay.astype('float32'))
    return jnp.asarray(ret, dtype=jnp.float32)

  current_step = optimizer.state.step
  new_lr = exponential_decay(current_step)
  # Calculate and apply the gradient.
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grad = grad_fn(optimizer.target)
  new_optimizer = optimizer.apply_gradient(grad, learning_rate=new_lr)
  metrics = compute_metrics(logits, batch[1])
  metrics['learning_rate'] = new_lr
  return new_optimizer, metrics
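The nested exponential_decay implements a staircase schedule: because of the integer division, the base rate is multiplied by learning_rate_decay once per step_decay optimizer steps. With the placeholder values from the setup sketch above:

# Staircase learning-rate decay, illustrative values only.
base_lr, decay, interval = 2e-3, 0.97, 500
for step in (0, 499, 500, 1000, 2500):
  print(step, base_lr * decay ** (step // interval))
# 0 and 499 -> 0.002, 500 -> 0.00194, 1000 -> 0.0018818, 2500 -> ~0.0017175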
@jax.vmap
def cross_entropy_loss(logits, labels):
  """Returns the cross-entropy loss for one example (vmapped over the batch)."""
  # The sum runs over time and vocabulary; the one-hot labels select the
  # log-probability of the target character at each position.
  return -jnp.sum(nn.log_softmax(logits) * labels)
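Because the labels are one-hot, the product inside the sum picks out the log-probability of the correct character at each position, so each example's loss is the negative log-likelihood of its target sequence. train_step also calls a compute_metrics helper the excerpt omits; here is a plausible sketch, with metric definitions that are assumptions:

@jax.jit
def compute_metrics(logits, labels):
  """Hypothetical helper (not shown in the excerpt): batch loss and accuracy."""
  loss = jnp.mean(cross_entropy_loss(logits, labels)) / params['batch_size']
  # Accuracy: fraction of positions where the argmax matches the one-hot target.
  accuracy = jnp.mean(jnp.argmax(logits, -1) == jnp.argmax(labels, -1))
  return {'loss': loss, 'accuracy': accuracy}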
@jax.jit
def sample(inputs, optimizer):
  """Greedily sample 200 characters, feeding each prediction back as input."""
  next_inputs = inputs
  output = []
  batch_size = 1
  # Zero-initialized carries for a single sampling stream.
  carry1 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,), 512)
  carry2 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,), 512)
  carry3 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,), 512)
  carry = [carry1, carry2, carry3]

  def inference(model, carry):
    carry, rnn_output = model(inputs=next_inputs, train=False, carry_pred=carry)
    return carry, rnn_output

  for i in range(200):
    carry, rnn_output = inference(optimizer.target, carry)
    output.append(jnp.argmax(rnn_output, axis=-1))
    # Select the argmax as the next input.
    next_inputs = jnp.expand_dims(
        common_utils.onehot(jnp.argmax(rnn_output), params['vocab_length']), axis=0)
  return output
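Once training has produced an optimizer, sampling just needs a one-hot seed character. A hedged usage sketch; the seed index is arbitrary and vocab comes from the assumed pipeline:

# Hypothetical usage: seed with an arbitrary character index and decode.
seed = jnp.expand_dims(
    common_utils.onehot(jnp.array(20), params['vocab_length']), axis=0)
generated = sample(seed, optimizer)   # list of 200 predicted character indices
print(''.join(vocab[int(i)] for i in generated))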
def train_model():
  """Train and inference."""
  rng = jax.random.PRNGKey(0)
  model = create_model(rng)
  optimizer = create_optimizer(model, params['learning_rate'])
  del model  # the optimizer now owns the parameters via `optimizer.target`
  for epoch in range(100):
    for text in tfds.as_numpy(ds):
      optimizer, metrics = train_step(optimizer, text)
    print('epoch: %d, loss: %.4f, accuracy: %.2f, LR: %.8f' % (
        epoch + 1, metrics['loss'], metrics['accuracy'] * 100,
        metrics['learning_rate']))
    # Every 10 epochs, seed the sampler with a test character and print text.
    test = test_ds(params['vocab_length'])
    sampled_text = ""
    if (epoch + 1) % 10 == 0:
      for seed in test:
        sampled_text += vocab[int(jnp.argmax(seed.numpy(), -1))]
        start = np.expand_dims(seed, axis=0)
        generated = sample(start, optimizer)
        for idx in generated:
          sampled_text += vocab[int(idx)]
      print(sampled_text)
A sample drawn during training; the greedy argmax decoder tends to get stuck in loops early on:
peak the mariners all the merchant of the meaning of the meaning of the meaning of the meaning of the meaning of the meaning…
Of the moon, why,...