551 lines
20 KiB
Plaintext
551 lines
20 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Using TensorFlow backend.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"import os\n",
|
|
"np.random.seed(123)\n",
|
|
"from six.moves import cPickle\n",
|
|
"\n",
|
|
"from keras import backend as K\n",
|
|
"from keras.models import Model\n",
|
|
"from keras.layers import Input, Dense, Flatten\n",
|
|
"from keras.layers import LSTM\n",
|
|
"from keras.layers import TimeDistributed\n",
|
|
"from keras.callbacks import LearningRateScheduler, ModelCheckpoint\n",
|
|
"from keras.optimizers import Adam\n",
|
|
"\n",
|
|
"from prednet import PredNet\n",
|
|
"from data_utils import SequenceGenerator"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"WEIGHTS_DIR = './weights/'\n",
|
|
"DATA_DIR = '../data/'"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"save_model = True # if weights will be saved\n",
|
|
"weights_file = os.path.join(WEIGHTS_DIR, 'prednet_weather_weights.hdf5') # where weights will be saved\n",
|
|
"json_file = os.path.join(WEIGHTS_DIR, 'prednet_weather_model.json')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Data files\n",
|
|
"train_file = os.path.join(DATA_DIR, 'x_train.hkl')\n",
|
|
"train_sources = os.path.join(DATA_DIR, 'sources_train.hkl')\n",
|
|
"val_file = os.path.join(DATA_DIR, 'x_val.hkl')\n",
|
|
"val_sources = os.path.join(DATA_DIR, 'sources_val.hkl')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Training parameters\n",
|
|
"nb_epoch = 150\n",
|
|
"batch_size = 4\n",
|
|
"samples_per_epoch = 100\n",
|
|
"N_seq_val = 50 # number of sequences to use for validation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Model parameters\n",
|
|
"n_channels, im_height, im_width = (7, 20, 40)\n",
|
|
"input_shape = (n_channels, im_height, im_width) if K.image_data_format() == 'channels_first' else (im_height, im_width, n_channels)\n",
|
|
"stack_sizes = (n_channels, 12, 24)\n",
|
|
"R_stack_sizes = stack_sizes\n",
|
|
"A_filt_sizes = (2, 2)\n",
|
|
"Ahat_filt_sizes = (2, 2, 2)\n",
|
|
"R_filt_sizes = (2, 2, 2)\n",
|
|
"layer_loss_weights = np.array([1., 0., 0.]) # weighting for each layer in final loss; \"L_0\" model: [1, 0, 0, 0], \"L_all\": [1, 0.1, 0.1, 0.1]\n",
|
|
"layer_loss_weights = np.expand_dims(layer_loss_weights, 1)\n",
|
|
"nt = 24 # number of timesteps used for sequences in training\n",
|
|
"time_loss_weights = 1./ (nt - 1) * np.ones((nt,1)) # equally weight all timesteps except the first\n",
|
|
"time_loss_weights[0] = 0"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"prednet = PredNet(stack_sizes, R_stack_sizes,\n",
|
|
" A_filt_sizes, Ahat_filt_sizes, R_filt_sizes,\n",
|
|
" output_mode='error', return_sequences=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"inputs = Input(shape=(nt,) + input_shape)\n",
|
|
"errors = prednet(inputs) # errors will be (batch_size, nt, nb_layers)\n",
|
|
"errors_by_time = TimeDistributed(Dense(1, trainable=False), weights=[layer_loss_weights, np.zeros(1)], trainable=False)(errors) # calculate weighted error by layer\n",
|
|
"errors_by_time = Flatten()(errors_by_time) # will be (batch_size, nt)\n",
|
|
"final_errors = Dense(1, weights=[time_loss_weights, np.zeros(1)], trainable=False)(errors_by_time) # weight errors by time\n",
|
|
"model = Model(inputs=inputs, outputs=final_errors)\n",
|
|
"model.compile(loss='mean_absolute_error', optimizer='adam')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"_________________________________________________________________\n",
|
|
"Layer (type) Output Shape Param # \n",
|
|
"=================================================================\n",
|
|
"input_1 (InputLayer) (None, 24, 20, 40, 7) 0 \n",
|
|
"_________________________________________________________________\n",
|
|
"pred_net_1 (PredNet) (None, 24, 3) 49167 \n",
|
|
"_________________________________________________________________\n",
|
|
"time_distributed_1 (TimeDist (None, 24, 1) 4 \n",
|
|
"_________________________________________________________________\n",
|
|
"flatten_1 (Flatten) (None, 24) 0 \n",
|
|
"_________________________________________________________________\n",
|
|
"dense_2 (Dense) (None, 1) 25 \n",
|
|
"=================================================================\n",
|
|
"Total params: 49,196\n",
|
|
"Trainable params: 49,167\n",
|
|
"Non-trainable params: 29\n",
|
|
"_________________________________________________________________\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model.summary()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"train_generator = SequenceGenerator(train_file, train_sources, nt, batch_size=batch_size, shuffle=False)\n",
|
|
"val_generator = SequenceGenerator(val_file, val_sources, nt, batch_size=batch_size, N_seq=N_seq_val)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"lr_schedule = lambda epoch: 0.001 if epoch < 75 else 0.0001 # start with lr of 0.001 and then drop to 0.0001 after 75 epochs\n",
|
|
"callbacks = [LearningRateScheduler(lr_schedule)]\n",
|
|
"if save_model:\n",
|
|
" if not os.path.exists(WEIGHTS_DIR): os.mkdir(WEIGHTS_DIR)\n",
|
|
" callbacks.append(ModelCheckpoint(filepath=weights_file, monitor='val_loss', save_best_only=True))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch 1/150\n",
|
|
" - 18s - loss: 2.2654e-04 - val_loss: 1.6606e-04\n",
|
|
"Epoch 2/150\n",
|
|
" - 13s - loss: 1.6084e-04 - val_loss: 1.6482e-04\n",
|
|
"Epoch 3/150\n",
|
|
" - 13s - loss: 1.6346e-04 - val_loss: 1.5707e-04\n",
|
|
"Epoch 4/150\n",
|
|
" - 14s - loss: 1.5891e-04 - val_loss: 1.5682e-04\n",
|
|
"Epoch 5/150\n",
|
|
" - 15s - loss: 1.5410e-04 - val_loss: 1.4740e-04\n",
|
|
"Epoch 6/150\n",
|
|
" - 13s - loss: 1.3540e-04 - val_loss: 1.4354e-04\n",
|
|
"Epoch 7/150\n",
|
|
" - 14s - loss: 1.4448e-04 - val_loss: 1.2114e-04\n",
|
|
"Epoch 8/150\n",
|
|
" - 13s - loss: 1.3061e-04 - val_loss: 1.5603e-04\n",
|
|
"Epoch 9/150\n",
|
|
" - 13s - loss: 1.2790e-04 - val_loss: 1.6199e-04\n",
|
|
"Epoch 10/150\n",
|
|
" - 13s - loss: 1.2041e-04 - val_loss: 1.2755e-04\n",
|
|
"Epoch 11/150\n",
|
|
" - 13s - loss: 1.2834e-04 - val_loss: 1.3263e-04\n",
|
|
"Epoch 12/150\n",
|
|
" - 13s - loss: 1.3211e-04 - val_loss: 1.3543e-04\n",
|
|
"Epoch 13/150\n",
|
|
" - 13s - loss: 1.2578e-04 - val_loss: 1.2785e-04\n",
|
|
"Epoch 14/150\n",
|
|
" - 14s - loss: 1.2164e-04 - val_loss: 1.1487e-04\n",
|
|
"Epoch 15/150\n",
|
|
" - 13s - loss: 1.1284e-04 - val_loss: 1.2066e-04\n",
|
|
"Epoch 16/150\n",
|
|
" - 13s - loss: 1.2309e-04 - val_loss: 1.1847e-04\n",
|
|
"Epoch 17/150\n",
|
|
" - 15s - loss: 1.2082e-04 - val_loss: 1.2631e-04\n",
|
|
"Epoch 18/150\n",
|
|
" - 15s - loss: 1.1457e-04 - val_loss: 1.2507e-04\n",
|
|
"Epoch 19/150\n",
|
|
" - 13s - loss: 1.0505e-04 - val_loss: 1.1174e-04\n",
|
|
"Epoch 20/150\n",
|
|
" - 15s - loss: 1.1421e-04 - val_loss: 1.1496e-04\n",
|
|
"Epoch 21/150\n",
|
|
" - 14s - loss: 1.1419e-04 - val_loss: 1.0648e-04\n",
|
|
"Epoch 22/150\n",
|
|
" - 14s - loss: 1.0596e-04 - val_loss: 1.0112e-04\n",
|
|
"Epoch 23/150\n",
|
|
" - 13s - loss: 1.0302e-04 - val_loss: 1.0167e-04\n",
|
|
"Epoch 24/150\n",
|
|
" - 13s - loss: 1.0402e-04 - val_loss: 9.9010e-05\n",
|
|
"Epoch 25/150\n",
|
|
" - 13s - loss: 1.0802e-04 - val_loss: 1.0683e-04\n",
|
|
"Epoch 26/150\n",
|
|
" - 13s - loss: 1.0313e-04 - val_loss: 9.9764e-05\n",
|
|
"Epoch 27/150\n",
|
|
" - 13s - loss: 9.8126e-05 - val_loss: 9.9443e-05\n",
|
|
"Epoch 28/150\n",
|
|
" - 14s - loss: 9.4907e-05 - val_loss: 9.8053e-05\n",
|
|
"Epoch 29/150\n",
|
|
" - 14s - loss: 1.0607e-04 - val_loss: 1.0455e-04\n",
|
|
"Epoch 30/150\n",
|
|
" - 13s - loss: 9.9968e-05 - val_loss: 9.9277e-05\n",
|
|
"Epoch 31/150\n",
|
|
" - 13s - loss: 9.3142e-05 - val_loss: 9.3689e-05\n",
|
|
"Epoch 32/150\n",
|
|
" - 13s - loss: 8.9096e-05 - val_loss: 9.0897e-05\n",
|
|
"Epoch 33/150\n",
|
|
" - 13s - loss: 9.4016e-05 - val_loss: 1.1416e-04\n",
|
|
"Epoch 34/150\n",
|
|
" - 13s - loss: 9.2043e-05 - val_loss: 9.6274e-05\n",
|
|
"Epoch 35/150\n",
|
|
" - 13s - loss: 8.8990e-05 - val_loss: 9.2792e-05\n",
|
|
"Epoch 36/150\n",
|
|
" - 13s - loss: 8.6284e-05 - val_loss: 9.1992e-05\n",
|
|
"Epoch 37/150\n",
|
|
" - 13s - loss: 8.4010e-05 - val_loss: 9.2135e-05\n",
|
|
"Epoch 38/150\n",
|
|
" - 13s - loss: 9.1255e-05 - val_loss: 8.4548e-05\n",
|
|
"Epoch 39/150\n",
|
|
" - 13s - loss: 8.5270e-05 - val_loss: 8.9138e-05\n",
|
|
"Epoch 40/150\n",
|
|
" - 13s - loss: 8.2153e-05 - val_loss: 8.4956e-05\n",
|
|
"Epoch 41/150\n",
|
|
" - 13s - loss: 7.9384e-05 - val_loss: 8.2901e-05\n",
|
|
"Epoch 42/150\n",
|
|
" - 13s - loss: 8.6787e-05 - val_loss: 8.0335e-05\n",
|
|
"Epoch 43/150\n",
|
|
" - 13s - loss: 8.2295e-05 - val_loss: 8.4503e-05\n",
|
|
"Epoch 44/150\n",
|
|
" - 13s - loss: 7.8093e-05 - val_loss: 8.0219e-05\n",
|
|
"Epoch 45/150\n",
|
|
" - 13s - loss: 6.7297e-05 - val_loss: 7.7004e-05\n",
|
|
"Epoch 46/150\n",
|
|
" - 14s - loss: 6.9567e-05 - val_loss: 7.4336e-05\n",
|
|
"Epoch 47/150\n",
|
|
" - 14s - loss: 8.1303e-05 - val_loss: 8.2066e-05\n",
|
|
"Epoch 48/150\n",
|
|
" - 15s - loss: 7.7026e-05 - val_loss: 7.9328e-05\n",
|
|
"Epoch 49/150\n",
|
|
" - 14s - loss: 7.3553e-05 - val_loss: 7.5846e-05\n",
|
|
"Epoch 50/150\n",
|
|
" - 14s - loss: 7.3287e-05 - val_loss: 7.5882e-05\n",
|
|
"Epoch 51/150\n",
|
|
" - 14s - loss: 7.7200e-05 - val_loss: 7.3799e-05\n",
|
|
"Epoch 52/150\n",
|
|
" - 14s - loss: 7.3123e-05 - val_loss: 7.6654e-05\n",
|
|
"Epoch 53/150\n",
|
|
" - 13s - loss: 6.9633e-05 - val_loss: 7.5168e-05\n",
|
|
"Epoch 54/150\n",
|
|
" - 14s - loss: 6.8400e-05 - val_loss: 6.9869e-05\n",
|
|
"Epoch 55/150\n",
|
|
" - 15s - loss: 7.3588e-05 - val_loss: 7.9260e-05\n",
|
|
"Epoch 56/150\n",
|
|
" - 13s - loss: 7.2371e-05 - val_loss: 7.3971e-05\n",
|
|
"Epoch 57/150\n",
|
|
" - 13s - loss: 6.7148e-05 - val_loss: 7.0150e-05\n",
|
|
"Epoch 58/150\n",
|
|
" - 13s - loss: 6.5651e-05 - val_loss: 6.9693e-05\n",
|
|
"Epoch 59/150\n",
|
|
" - 13s - loss: 6.5532e-05 - val_loss: 6.9470e-05\n",
|
|
"Epoch 60/150\n",
|
|
" - 13s - loss: 6.9948e-05 - val_loss: 6.5869e-05\n",
|
|
"Epoch 61/150\n",
|
|
" - 14s - loss: 6.3758e-05 - val_loss: 6.8406e-05\n",
|
|
"Epoch 62/150\n",
|
|
" - 13s - loss: 6.1931e-05 - val_loss: 6.5637e-05\n",
|
|
"Epoch 63/150\n",
|
|
" - 13s - loss: 6.1838e-05 - val_loss: 6.7699e-05\n",
|
|
"Epoch 64/150\n",
|
|
" - 13s - loss: 6.8736e-05 - val_loss: 6.3853e-05\n",
|
|
"Epoch 65/150\n",
|
|
" - 13s - loss: 6.2603e-05 - val_loss: 6.6491e-05\n",
|
|
"Epoch 66/150\n",
|
|
" - 13s - loss: 5.8582e-05 - val_loss: 6.2542e-05\n",
|
|
"Epoch 67/150\n",
|
|
" - 13s - loss: 5.8376e-05 - val_loss: 6.1135e-05\n",
|
|
"Epoch 68/150\n",
|
|
" - 13s - loss: 6.2913e-05 - val_loss: 6.5607e-05\n",
|
|
"Epoch 69/150\n",
|
|
" - 13s - loss: 6.3289e-05 - val_loss: 6.1227e-05\n",
|
|
"Epoch 70/150\n",
|
|
" - 13s - loss: 5.7150e-05 - val_loss: 6.2312e-05\n",
|
|
"Epoch 71/150\n",
|
|
" - 13s - loss: 5.5959e-05 - val_loss: 6.1468e-05\n",
|
|
"Epoch 72/150\n",
|
|
" - 13s - loss: 5.9391e-05 - val_loss: 6.1573e-05\n",
|
|
"Epoch 73/150\n",
|
|
" - 14s - loss: 6.1597e-05 - val_loss: 5.8205e-05\n",
|
|
"Epoch 74/150\n",
|
|
" - 14s - loss: 5.6171e-05 - val_loss: 5.8807e-05\n",
|
|
"Epoch 75/150\n",
|
|
" - 13s - loss: 5.3894e-05 - val_loss: 6.0140e-05\n",
|
|
"Epoch 76/150\n",
|
|
" - 13s - loss: 5.4176e-05 - val_loss: 5.6236e-05\n",
|
|
"Epoch 77/150\n",
|
|
" - 14s - loss: 5.5984e-05 - val_loss: 5.5843e-05\n",
|
|
"Epoch 78/150\n",
|
|
" - 15s - loss: 5.5739e-05 - val_loss: 5.7032e-05\n",
|
|
"Epoch 79/150\n",
|
|
" - 14s - loss: 5.2283e-05 - val_loss: 5.6289e-05\n",
|
|
"Epoch 80/150\n",
|
|
" - 13s - loss: 5.2437e-05 - val_loss: 5.5842e-05\n",
|
|
"Epoch 81/150\n",
|
|
" - 14s - loss: 5.3922e-05 - val_loss: 5.5559e-05\n",
|
|
"Epoch 82/150\n",
|
|
" - 13s - loss: 5.6874e-05 - val_loss: 5.5289e-05\n",
|
|
"Epoch 83/150\n",
|
|
" - 13s - loss: 5.3202e-05 - val_loss: 5.6150e-05\n",
|
|
"Epoch 84/150\n",
|
|
" - 13s - loss: 5.1825e-05 - val_loss: 5.5809e-05\n",
|
|
"Epoch 85/150\n",
|
|
" - 14s - loss: 5.2601e-05 - val_loss: 5.5992e-05\n",
|
|
"Epoch 86/150\n",
|
|
" - 15s - loss: 5.6301e-05 - val_loss: 5.5160e-05\n",
|
|
"Epoch 87/150\n",
|
|
" - 14s - loss: 5.4528e-05 - val_loss: 5.6079e-05\n",
|
|
"Epoch 88/150\n",
|
|
" - 15s - loss: 5.1597e-05 - val_loss: 5.5375e-05\n",
|
|
"Epoch 89/150\n",
|
|
" - 14s - loss: 5.1668e-05 - val_loss: 5.5335e-05\n",
|
|
"Epoch 90/150\n",
|
|
" - 13s - loss: 5.4285e-05 - val_loss: 5.4986e-05\n",
|
|
"Epoch 91/150\n",
|
|
" - 14s - loss: 5.5829e-05 - val_loss: 5.5269e-05\n",
|
|
"Epoch 92/150\n",
|
|
" - 15s - loss: 5.2061e-05 - val_loss: 5.5917e-05\n",
|
|
"Epoch 93/150\n",
|
|
" - 14s - loss: 5.1437e-05 - val_loss: 5.5878e-05\n",
|
|
"Epoch 94/150\n",
|
|
" - 13s - loss: 5.2905e-05 - val_loss: 5.5128e-05\n",
|
|
"Epoch 95/150\n",
|
|
" - 14s - loss: 5.5643e-05 - val_loss: 5.5071e-05\n",
|
|
"Epoch 96/150\n",
|
|
" - 14s - loss: 5.3255e-05 - val_loss: 5.6120e-05\n",
|
|
"Epoch 97/150\n",
|
|
" - 14s - loss: 5.0939e-05 - val_loss: 5.7411e-05\n",
|
|
"Epoch 98/150\n",
|
|
" - 13s - loss: 5.1713e-05 - val_loss: 5.4738e-05\n",
|
|
"Epoch 99/150\n",
|
|
" - 13s - loss: 5.4313e-05 - val_loss: 5.4543e-05\n",
|
|
"Epoch 100/150\n",
|
|
" - 13s - loss: 5.4338e-05 - val_loss: 5.5159e-05\n",
|
|
"Epoch 101/150\n",
|
|
" - 13s - loss: 5.1035e-05 - val_loss: 5.4962e-05\n",
|
|
"Epoch 102/150\n",
|
|
" - 14s - loss: 5.1137e-05 - val_loss: 5.4461e-05\n",
|
|
"Epoch 103/150\n",
|
|
" - 13s - loss: 5.2371e-05 - val_loss: 5.4086e-05\n",
|
|
"Epoch 104/150\n",
|
|
" - 13s - loss: 5.5356e-05 - val_loss: 5.3880e-05\n",
|
|
"Epoch 105/150\n",
|
|
" - 13s - loss: 5.1895e-05 - val_loss: 5.4717e-05\n",
|
|
"Epoch 106/150\n",
|
|
" - 13s - loss: 5.0507e-05 - val_loss: 5.4484e-05\n",
|
|
"Epoch 107/150\n",
|
|
" - 13s - loss: 5.1207e-05 - val_loss: 5.4376e-05\n",
|
|
"Epoch 108/150\n",
|
|
" - 13s - loss: 5.4695e-05 - val_loss: 5.3700e-05\n",
|
|
"Epoch 109/150\n",
|
|
" - 13s - loss: 5.3160e-05 - val_loss: 5.4604e-05\n",
|
|
"Epoch 110/150\n",
|
|
" - 13s - loss: 5.0227e-05 - val_loss: 5.3845e-05\n",
|
|
"Epoch 111/150\n",
|
|
" - 13s - loss: 5.0310e-05 - val_loss: 5.3851e-05\n",
|
|
"Epoch 112/150\n",
|
|
" - 13s - loss: 5.2718e-05 - val_loss: 5.3708e-05\n",
|
|
"Epoch 113/150\n",
|
|
" - 13s - loss: 5.4318e-05 - val_loss: 5.4004e-05\n",
|
|
"Epoch 114/150\n",
|
|
" - 13s - loss: 5.0657e-05 - val_loss: 5.4248e-05\n",
|
|
"Epoch 115/150\n",
|
|
" - 13s - loss: 5.0033e-05 - val_loss: 5.4256e-05\n",
|
|
"Epoch 116/150\n",
|
|
" - 13s - loss: 5.1428e-05 - val_loss: 5.3486e-05\n",
|
|
"Epoch 117/150\n",
|
|
" - 13s - loss: 5.4040e-05 - val_loss: 5.3235e-05\n",
|
|
"Epoch 118/150\n",
|
|
" - 13s - loss: 5.1812e-05 - val_loss: 5.4249e-05\n",
|
|
"Epoch 119/150\n",
|
|
" - 13s - loss: 4.9503e-05 - val_loss: 5.5388e-05\n",
|
|
"Epoch 120/150\n",
|
|
" - 13s - loss: 5.0289e-05 - val_loss: 5.3228e-05\n",
|
|
"Epoch 121/150\n",
|
|
" - 13s - loss: 5.2727e-05 - val_loss: 5.2826e-05\n",
|
|
"Epoch 122/150\n",
|
|
" - 13s - loss: 5.2783e-05 - val_loss: 5.3716e-05\n",
|
|
"Epoch 123/150\n",
|
|
" - 13s - loss: 4.9556e-05 - val_loss: 5.3553e-05\n",
|
|
"Epoch 124/150\n",
|
|
" - 14s - loss: 4.9713e-05 - val_loss: 5.2973e-05\n",
|
|
"Epoch 125/150\n",
|
|
" - 13s - loss: 5.0792e-05 - val_loss: 5.2560e-05\n",
|
|
"Epoch 126/150\n",
|
|
" - 13s - loss: 5.3758e-05 - val_loss: 5.2314e-05\n",
|
|
"Epoch 127/150\n",
|
|
" - 13s - loss: 5.0386e-05 - val_loss: 5.3166e-05\n",
|
|
"Epoch 128/150\n",
|
|
" - 13s - loss: 4.9029e-05 - val_loss: 5.3028e-05\n",
|
|
"Epoch 129/150\n",
|
|
" - 13s - loss: 4.9680e-05 - val_loss: 5.2968e-05\n",
|
|
"Epoch 130/150\n",
|
|
" - 13s - loss: 5.3036e-05 - val_loss: 5.2100e-05\n",
|
|
"Epoch 131/150\n",
|
|
" - 13s - loss: 5.1631e-05 - val_loss: 5.2928e-05\n",
|
|
"Epoch 132/150\n",
|
|
" - 13s - loss: 4.8700e-05 - val_loss: 5.2284e-05\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch 133/150\n",
|
|
" - 13s - loss: 4.8813e-05 - val_loss: 5.2158e-05\n",
|
|
"Epoch 134/150\n",
|
|
" - 13s - loss: 5.1009e-05 - val_loss: 5.2366e-05\n",
|
|
"Epoch 135/150\n",
|
|
" - 13s - loss: 5.2787e-05 - val_loss: 5.2123e-05\n",
|
|
"Epoch 136/150\n",
|
|
" - 13s - loss: 4.9104e-05 - val_loss: 5.2410e-05\n",
|
|
"Epoch 137/150\n",
|
|
" - 14s - loss: 4.8515e-05 - val_loss: 5.2213e-05\n",
|
|
"Epoch 138/150\n",
|
|
" - 13s - loss: 4.9778e-05 - val_loss: 5.1791e-05\n",
|
|
"Epoch 139/150\n",
|
|
" - 13s - loss: 5.2401e-05 - val_loss: 5.1701e-05\n",
|
|
"Epoch 140/150\n",
|
|
" - 13s - loss: 5.0269e-05 - val_loss: 5.2384e-05\n",
|
|
"Epoch 141/150\n",
|
|
" - 13s - loss: 4.7956e-05 - val_loss: 5.3089e-05\n",
|
|
"Epoch 142/150\n",
|
|
" - 13s - loss: 4.8717e-05 - val_loss: 5.1456e-05\n",
|
|
"Epoch 143/150\n",
|
|
" - 13s - loss: 5.0973e-05 - val_loss: 5.1221e-05\n",
|
|
"Epoch 144/150\n",
|
|
" - 13s - loss: 5.1292e-05 - val_loss: 5.1897e-05\n",
|
|
"Epoch 145/150\n",
|
|
" - 13s - loss: 4.7986e-05 - val_loss: 5.1949e-05\n",
|
|
"Epoch 146/150\n",
|
|
" - 13s - loss: 4.8139e-05 - val_loss: 5.1388e-05\n",
|
|
"Epoch 147/150\n",
|
|
" - 13s - loss: 4.9055e-05 - val_loss: 5.1023e-05\n",
|
|
"Epoch 148/150\n",
|
|
" - 13s - loss: 5.2175e-05 - val_loss: 5.0815e-05\n",
|
|
"Epoch 149/150\n",
|
|
" - 13s - loss: 4.8831e-05 - val_loss: 5.1613e-05\n",
|
|
"Epoch 150/150\n",
|
|
" - 13s - loss: 4.7456e-05 - val_loss: 5.1646e-05\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"history = model.fit_generator(train_generator, steps_per_epoch=(samples_per_epoch / batch_size), \n",
|
|
" epochs=nb_epoch, callbacks=callbacks,\n",
|
|
" validation_data=val_generator, validation_steps=N_seq_val / batch_size,\n",
|
|
" verbose=2, workers=0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"if save_model:\n",
|
|
" json_string = model.to_json()\n",
|
|
" with open(json_file, \"w\") as f:\n",
|
|
" f.write(json_string)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.6.4"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|