Files
Weather-Project/Project Final/Evaluation/Evalutation.ipynb
nmelone 270798f23e Scaling and Naive Case added
Implemented scaling to the data so that the final MSE calculation worked properly since the predictions are already scaled between [0,1].

Added the Naive Case of averaging the previous 5 frames together to get a naive prediction.
2018-10-22 17:21:33 -05:00

355 lines
111 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"import os\n",
"import numpy as np\n",
"from six.moves import cPickle\n",
"import matplotlib\n",
"matplotlib.use('Agg')\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.gridspec as gridspec\n",
"%matplotlib inline\n",
"from keras import backend as K\n",
"from keras.models import Model, model_from_json\n",
"from keras.layers import Input, Dense, Flatten\n",
"\n",
"from prednet import PredNet\n",
"from data_utils import SequenceGenerator\n",
"\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Upper bound on how many test sequences get a prediction plot\n",
"n_plot = 40\n",
"# Batch size handed to test_model.predict\n",
"batch_size = 10\n",
"# Number of timesteps per sequence (first dimension of the test model's input shape)\n",
"nt = 24"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Directories, relative to this notebook: trained weights/model JSON, test data, and where results are written\n",
"WEIGHTS_DIR = '../Training/weights/'\n",
"DATA_DIR = '../data/'\n",
"RESULTS_SAVE_DIR = './weather_results/'"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Trained model files (weights + architecture JSON) read from WEIGHTS_DIR\n",
"weights_file = os.path.join(WEIGHTS_DIR, 'prednet_weather_weights.hdf5')\n",
"json_file = os.path.join(WEIGHTS_DIR, 'prednet_weather_model.json')\n",
"# Test data files passed to SequenceGenerator\n",
"test_file = os.path.join(DATA_DIR, 'x_test.hkl')\n",
"test_sources = os.path.join(DATA_DIR, 'sources_test.hkl')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Load trained model\n",
"f = open(json_file, 'r')\n",
"json_string = f.read()\n",
"f.close()\n",
"train_model = model_from_json(json_string, custom_objects = {'PredNet': PredNet})\n",
"train_model.load_weights(weights_file)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Create testing model (to output predictions)\n",
"layer_config = train_model.layers[1].get_config()\n",
"layer_config['output_mode'] = 'prediction'\n",
"data_format = layer_config['data_format'] if 'data_format' in layer_config else layer_config['dim_ordering']\n",
"test_prednet = PredNet(weights=train_model.layers[1].get_weights(), **layer_config)\n",
"input_shape = list(train_model.layers[0].batch_input_shape[1:])\n",
"input_shape[0] = nt\n",
"inputs = Input(shape=tuple(input_shape))\n",
"predictions = test_prednet(inputs)\n",
"test_model = Model(inputs=inputs, outputs=predictions)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# NOTE(review): create_all() presumably returns every test sequence as one array -- confirm in data_utils.\n",
"test_generator = SequenceGenerator(test_file, test_sources, nt, sequence_start_mode='unique', data_format=data_format)\n",
"X_test = test_generator.create_all()\n",
"X_hat = test_model.predict(X_test, batch_size)\n",
"# For channels_first data, move axis 2 to the last position; later cells select a channel on the last axis.\n",
"if data_format == 'channels_first':\n",
" X_test = np.transpose(X_test, (0, 1, 3, 4, 2))\n",
" X_hat = np.transpose(X_hat, (0, 1, 3, 4, 2))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Compare MSE of PredNet predictions vs. using last frame. Write results to prediction_scores.txt\n",
"mse_model = np.nanmean( (X_test[:, 1:] - X_hat[:, 1:])**2 ) # look at all timesteps except the first\n",
"mse_prev = np.nanmean( (X_test[:, :-1] - X_test[:, 1:])**2 )\n",
"if not os.path.exists(RESULTS_SAVE_DIR): os.mkdir(RESULTS_SAVE_DIR)\n",
"f = open(RESULTS_SAVE_DIR + 'prediction_scores.txt', 'w')\n",
"f.write(\"Model MSE: %f\\n\" % mse_model)\n",
"f.write(\"Previous Frame MSE: %f\" % mse_prev)\n",
"f.close()"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"def getAverage(test_frames, window_size=5):\n",
"    \"\"\"Naive baseline: predict each frame as the mean of the frames before it.\n",
"\n",
"    The history starts out as `window_size` all-zero frames, so the earliest\n",
"    predictions are averaged together with zero padding. The history is not\n",
"    reset between sequences: the first frames of one sequence are averaged\n",
"    with the last frames of the sequence before it.\n",
"\n",
"    Args:\n",
"        test_frames: array indexed as test_frames[day, hour, ...]; every\n",
"            trailing dimension is treated as part of a single frame.\n",
"        window_size: number of preceding frames to average (default 5).\n",
"\n",
"    Returns:\n",
"        Float array with the same shape as test_frames, where entry\n",
"        [day, hour] is the mean of the window_size frames that precede it.\n",
"\n",
"    Raises:\n",
"        ValueError: if window_size is smaller than 1.\n",
"    \"\"\"\n",
"    from collections import deque\n",
"\n",
"    if window_size < 1:\n",
"        raise ValueError('window_size must be at least 1')\n",
"    test_frames = np.asarray(test_frames)\n",
"    # Take the frame shape from the data instead of hard-coding (20, 40, 7).\n",
"    blank = np.zeros(test_frames.shape[2:])\n",
"    # [blank] * n repeats the frame n times; blank * n would only scale the zeros.\n",
"    # maxlen makes the deque drop the oldest frame by itself, so memory stays bounded.\n",
"    window = deque([blank] * window_size, maxlen=window_size)\n",
"    average = np.empty(test_frames.shape, dtype=float)\n",
"    for day in range(test_frames.shape[0]):\n",
"        for hour in range(test_frames.shape[1]):\n",
"            # Predict first, then add the true frame to the history.\n",
"            average[day, hour] = np.mean(list(window), axis=0)\n",
"            window.append(test_frames[day, hour])\n",
"    return average"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"#Naive Case 5 frame average\n",
"X_test_naive = getAverage(X_test)\n",
"mse_naive = np.nanmean((X_test[:, 1:] - X_test_naive[:,1:])**2)\n"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model MSE:\t 2.3059108400502737e-07\n",
"Prev Frame MSE:\t 1.0880363277010474e-08\n",
"Naive MSE:\t 4.413529927867206e-08\n"
]
}
],
"source": [
"# Print the model, previous-frame and naive MSE values together for comparison.\n",
"print(\"Model MSE:\\t {}\\nPrev Frame MSE:\\t {}\\nNaive MSE:\\t {}\".format(mse_model,mse_prev, mse_naive))"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"# Row labels for the per-channel comparison plot: entry i is used as the y-label for channel i (last array axis).\n",
"y_labels = [\n",
" 'VISIBILITY',\n",
" 'DB TEMP C',\n",
" 'WB TEMP C',\n",
" 'Dew Point',\n",
" 'Humidity',\n",
" 'WindSpeed',\n",
" 'Pressure',\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x24a800314a8>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Actual vs. predicted vs. naive frame for every channel, at the last timestep\n",
"# of the first test sequence. One row per channel, one column per source.\n",
"n_channels = len(y_labels)\n",
"timestep = X_test.shape[1] - 1  # last frame of the sequence\n",
"panels = [('Actual', X_test), ('Predicted', X_hat), ('Naive', X_test_naive)]\n",
"fig = plt.figure(figsize=(15,10))\n",
"grid = gridspec.GridSpec(n_channels, len(panels))\n",
"grid.update(wspace=0., hspace=0.2)\n",
"for channel in range(n_channels):\n",
"    for col, (title, frames) in enumerate(panels):\n",
"        ax = fig.add_subplot(grid[channel, col])\n",
"        ax.imshow(frames[0, timestep, :, :, channel])\n",
"        ax.set_ylabel(y_labels[channel])\n",
"        ax.set_xlabel(title, fontsize=10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 63%|███████████████████████████████████████████████████▉ | 19/30 [09:38<05:35, 30.47s/it]"
]
}
],
"source": [
"# Plot some predictions\n",
"aspect_ratio = float(X_hat.shape[3]) / X_hat.shape[2]\n",
"plt.figure(figsize = (nt, 7*2*aspect_ratio))\n",
"gs = gridspec.GridSpec(2*7, nt)\n",
"gs.update(wspace=0., hspace=0.2)\n",
"plot_save_dir = os.path.join(RESULTS_SAVE_DIR, 'prediction_plots/')\n",
"if not os.path.exists(plot_save_dir): os.mkdir(plot_save_dir)\n",
"plot_idx = np.random.permutation(X_test.shape[0])[:n_plot]\n",
"for i in tqdm(plot_idx):\n",
" for t in range(nt):\n",
" for c in range(7):\n",
" plt.subplot(gs[t + c*2*nt])\n",
" plt.imshow(X_test[i,t,:,:,c], interpolation='none')\n",
" plt.tick_params(axis='both', which='both', bottom='off', top='off', left='off', right='off', labelbottom='off', labelleft='off')\n",
" if t==0: plt.ylabel('Actual', fontsize=10)\n",
"\n",
" plt.subplot(gs[t + (c*2+1)*nt])\n",
" plt.imshow(X_hat[i,t,:,:,c], interpolation='none')\n",
" plt.tick_params(axis='both', which='both', bottom='off', top='off', left='off', right='off', labelbottom='off', labelleft='off')\n",
" if t==0: plt.ylabel('Predicted', fontsize=10)\n",
"\n",
" plt.savefig(plot_save_dir + 'plot_' + str(i) + '.png')\n",
" plt.clf()"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "unhashable type: 'numpy.ndarray'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-36-cf5f6662ea82>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mcolumns\u001b[0m\u001b[1;33m+\u001b[0m\u001b[0mrows\u001b[0m \u001b[1;33m+\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[0mfig\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0madd_subplot\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrows\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mcolumns\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 6\u001b[1;33m \u001b[0mplt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mimshow\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_test\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mX_hat\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[1;32mc:\\users\\dasputer\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\matplotlib\\pyplot.py\u001b[0m in \u001b[0;36mimshow\u001b[1;34m(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, hold, data, **kwargs)\u001b[0m\n\u001b[0;32m 3099\u001b[0m \u001b[0mfilternorm\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mfilternorm\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfilterrad\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mfilterrad\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3100\u001b[0m \u001b[0mimlim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mimlim\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresample\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mresample\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0murl\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0murl\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 3101\u001b[1;33m **kwargs)\n\u001b[0m\u001b[0;32m 3102\u001b[0m \u001b[1;32mfinally\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3103\u001b[0m \u001b[0max\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_hold\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mwashold\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mc:\\users\\dasputer\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\matplotlib\\__init__.py\u001b[0m in \u001b[0;36minner\u001b[1;34m(ax, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1715\u001b[0m warnings.warn(msg % (label_namer, func.__name__),\n\u001b[0;32m 1716\u001b[0m RuntimeWarning, stacklevel=2)\n\u001b[1;32m-> 1717\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0max\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1718\u001b[0m \u001b[0mpre_doc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minner\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__doc__\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1719\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mpre_doc\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mc:\\users\\dasputer\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\matplotlib\\axes\\_axes.py\u001b[0m in \u001b[0;36mimshow\u001b[1;34m(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, **kwargs)\u001b[0m\n\u001b[0;32m 5123\u001b[0m im = mimage.AxesImage(self, cmap, norm, interpolation, origin, extent,\n\u001b[0;32m 5124\u001b[0m \u001b[0mfilternorm\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mfilternorm\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfilterrad\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mfilterrad\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 5125\u001b[1;33m resample=resample, **kwargs)\n\u001b[0m\u001b[0;32m 5126\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5127\u001b[0m \u001b[0mim\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mset_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mc:\\users\\dasputer\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\matplotlib\\image.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, ax, cmap, norm, interpolation, origin, extent, filternorm, filterrad, resample, **kwargs)\u001b[0m\n\u001b[0;32m 763\u001b[0m \u001b[0mfilterrad\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mfilterrad\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 764\u001b[0m \u001b[0mresample\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mresample\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 765\u001b[1;33m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 766\u001b[0m )\n\u001b[0;32m 767\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mc:\\users\\dasputer\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\matplotlib\\image.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, ax, cmap, norm, interpolation, origin, filternorm, filterrad, resample, **kwargs)\u001b[0m\n\u001b[0;32m 223\u001b[0m \"\"\"\n\u001b[0;32m 224\u001b[0m \u001b[0mmartist\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mArtist\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 225\u001b[1;33m \u001b[0mcm\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mScalarMappable\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnorm\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcmap\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 226\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_mouseover\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 227\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0morigin\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mc:\\users\\dasputer\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\matplotlib\\cm.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, norm, cmap)\u001b[0m\n\u001b[0;32m 202\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnorm\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnorm\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 203\u001b[0m \u001b[1;31m#: The Colormap instance of this ScalarMappable.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 204\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcmap\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mget_cmap\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcmap\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 205\u001b[0m \u001b[1;31m#: The last colorbar associated with this ScalarMappable. May be None.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 206\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcolorbar\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mc:\\users\\dasputer\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\matplotlib\\cm.py\u001b[0m in \u001b[0;36mget_cmap\u001b[1;34m(name, lut)\u001b[0m\n\u001b[0;32m 159\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 160\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 161\u001b[1;33m \u001b[1;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mcmap_d\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 162\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mlut\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 163\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mcmap_d\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mTypeError\u001b[0m: unhashable type: 'numpy.ndarray'"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAKMAAACZCAYAAABDjYpLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAB7JJREFUeJzt3V+IXHcZxvHvY2MtxNqCG6GYxFhMjaEITQcJFLTSCjUXyYVFEig1ErsU/10oglJRqReivRCK1bpiqC0Y2+ZCV4kU1EhF3NgJbWMSqWzrv8VA0jTmplhbeL04Z+1mOztzdvY32bdzng8szMw5c+Y9ycM5e/Y37/kpIjDL4A2rXYDZPIfR0nAYLQ2H0dJwGC0Nh9HSGBhGSfslnZZ0fInlknSvpFlJxyRtK1+mtUGTI+MDwC19ln8Y2Fz/TALfW3lZ1kYDwxgRjwMv9FllF/BgVGaAKyVdVapAa48SvzO+Hfjngudz9Wtmy7KmwDbU47WeY4ySJqlO5axdu/b6LVu2FPh4y+To0aPPR8S6Yd5bIoxzwIYFz9cD/+q1YkRMAVMAnU4nut1ugY+3TCT9fdj3ljhNTwO311fV24HzEXGqwHatZQYeGSUdAG4EJiTNAV8F3ggQEfcDh4AdwCzwIvDxURVr421gGCNiz4DlAXyqWEXWWh6BsTQcRkvDYbQ0HEZLw2G0NBxGS8NhtDQcRkvDYbQ0HEZLw2G0NBxGS8NhtDQcRkvDYbQ0GoVR0i2Snql7o7/YY/lGSYclPVn3Tu8oX6qNuyZN/JcA91H1R28F9kjaumi1LwOPRMR1wG7gu6ULtfHX5Mj4PmA2Ip6LiP8CP6HqlV4ogLfUj69giYYss36ahLFJX/TXgNvqHplDwGd6bUjSpKSupO6ZM2eGKNfGWZMwNumL3gM8EBHrqZqzHpL0mm1HxFREdCKis27dUK21NsaahLFJX/Q+4BGAiPgDcBkwUaJAa48mYXwC2CzpnZIupbpAmV60zj+AmwAkvYcqjD4P27I0ufHTK8CngceAP1NdNZ+QdLeknfVqnwfukPQ0cADYG55GwZap0e1NIuIQ1YXJwte+suDxSeCGsqVZ23gExtJwGC0Nh9HScBgtDYfR0nAYLQ2H0dJwGC0Nh9HScBgtDYfR0nAYLQ2H0dJwGC2NIq2q9ToflXRS0glJPy5bprVBk0mJ5ltVP0TVgvCEpOn6O4zz62wGvgTcEBHnJL1tVAXb+CrVqnoHcF9EnAOIiNNly7Q2KNWqeg1wjaTfS5qR1G+ydLOemrQdNGlVXQNspppjcD3wO0nXRsS/L9jQgil+N27cuOxibbyValWdA34WES9HxF+BZ6jCeQH3TVs/pVpVfwp8EEDSBNVp+7mShdr4K9Wq+hhwVtJJ4DDwhYg4O6qibTxptdqbO51OdLvdVflsGx1JRyOiM8x7PQJjaTiMlobDaGk4jJaGw2hpOIyWhsNoaTiMlobDaGk4jJaGw2hpOIyWhsNoaTiMlobDaGkU65uu17tVUkga6vts1m6lpvhF0uXAZ4EjpYu0dijVNw3wdeBbwH8K1mctUqRvWtJ1wIaI+EXB2qxlVjzFbz2V77ep5g/svyHPN219lOibvhy4FvitpL8B24HpXhcx7pu2flbcNx0R5yNiIiI2RcQmYAbYGRFu/bNlKdU3bbZiRab4XfT6jSsvy9rIIzCWhsNoaTiMlobDaGk4jJaGw2hpOIyWhsNoaTiMlobDaGk4jJaGw2hpOIyWhsNoaTiMlkaRvmlJn6vnmj4m6deS3lG+VBt3pfqmnwQ6EfFe4CBVy6rZshTpm46IwxHxYv10hqppy2xZSs03vdA+4Je9FrhV1fpZcd/0BStKtwEd4J5ey92qav00achqMt80km4G7gI+EBEvlSnP2qTIfNP17U2+T9Uvfbp8mdYGpfqm7wHeDDwq6SlJiydHNxuoSN90RNxcuC5rIY/AWB
oOo6XhMFoaDqOl4TBaGg6jpeEwWhoOo6XhMFoaDqOl4TBaGg6jpeEwWhoOo6VRqlX1TZIerpcfkbSpdKE2/kq1qu4DzkXEu6jmEfxm6UJt/JWa4ncX8KP68UHgJkm9GrnMllSqVfX/69RtCueBt5Yo0NqjSdtBk1bVRu2skiaByfrpS5KON/j8cTABPL/aRVwk7x72jaVaVefXmZO0BrgCeGHxhiJiCpgCkNSNiNdMAzyO2ravw763SKtq/fxj9eNbgd9ERM9Gf7OlDDwyRsQrkuZbVS8B9s+3qgLdiJgGfgg8JGmW6oi4e5RF23jSah3AJE3Wp+2x531t+F6fTS0LDwdaGiMPY5uGEhvs615JZ+pbwDwl6ROrUedKSdov6fRSf5pT5d763+GYpG2NNhwRI/uhuuB5FrgauBR4Gti6aJ1PAvfXj3cDD4+yplXe173Ad1a71gL7+n5gG3B8ieU7qO7RKWA7cKTJdkd9ZGzTUGKTfR0LEfE4Pf6OvMAu4MGozABXSrpq0HZHHcY2DSU2vcPvR+pT10FJG3osHwfLvdsxMPowFhtKfB1osh8/BzZFdSP+X/HqGWHcDPV/OuowLmcokX5Dia8DA/c1Is7Gq3f1/QFw/UWq7WJrdLfjxUYdxjYNJTa5w+/C35t2Ut18dRxNA7fXV9XbgfMRcWrguy7CldcO4C9UV5p31a/dTXXLZYDLgEeBWeCPwNWrfbU4wn39BnCC6kr7MLBltWsecj8PAKeAl6mOgvuAO4E76+Wi+kL2s8CfqOYIGrhdj8BYGh6BsTQcRkvDYbQ0HEZLw2G0NBxGS8NhtDQcRkvjf947yezihsagAAAAAElFTkSuQmCC\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x249d9e7ecc0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# First frame of the first test sequence, one subplot per channel, with the\n",
"# actual frame in the left half and the predicted frame in the right half.\n",
"# imshow draws a single image and its second positional argument is cmap, so\n",
"# passing X_hat there raised TypeError: unhashable type: 'numpy.ndarray'.\n",
"fig=plt.figure(figsize=(15,10))\n",
"n_channels = X_test.shape[-1]\n",
"columns = 3\n",
"rows = -(-n_channels // columns)  # ceiling division: enough rows for every channel\n",
"for i in range(1, n_channels + 1):\n",
"    ax = fig.add_subplot(rows,columns,i)\n",
"    side_by_side = np.concatenate([X_test[0,0,:,:,i-1], X_hat[0,0,:,:,i-1]], axis=1)\n",
"    ax.imshow(side_by_side)\n",
"    ax.set_title('channel %d: actual | predicted' % (i-1), fontsize=10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Spot-check one predicted value: first sequence, first timestep, position (0, 0), third entry of the last axis.\n",
"X_hat[0, 0, 0, 0, 2]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}