|
|
|
@@ -0,0 +1,151 @@ |
|
|
|
{ |
|
|
|
"nbformat": 4, |
|
|
|
"nbformat_minor": 0, |
|
|
|
"metadata": { |
|
|
|
"colab": { |
|
|
|
"name": "Untitled1.ipynb", |
|
|
|
"provenance": [], |
|
|
|
"collapsed_sections": [] |
|
|
|
}, |
|
|
|
"kernelspec": { |
|
|
|
"name": "python3", |
|
|
|
"display_name": "Python 3" |
|
|
|
}, |
|
|
|
"language_info": { |
|
|
|
"name": "python" |
|
|
|
} |
|
|
|
}, |
|
|
|
"cells": [ |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": { |
|
|
|
"id": "hUYO09lse0Ew" |
|
|
|
}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"from tensorflow.python.ops.gen_array_ops import tensor_scatter_min_eager_fallback\n", |
|
|
|
"from resnet18 import ResNet18\n", |
|
|
|
"import tensorflow as tf\n", |
|
|
|
"import numpy as np\n", |
|
|
|
"import matplotlib.pyplot as plt\n", |
|
|
|
"\n", |
|
|
|
"def snake(x):\n", |
|
|
|
" return x + tf.sin(x)**2\n", |
|
|
|
"\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"source": [ |
|
|
|
"## Chargement et normalisation des données\n", |
|
|
|
"resnet18 = tf.keras.datasets.cifar10\n", |
|
|
|
"(train_images, train_labels), (test_images, test_labels) = resnet18.load_data()\n", |
|
|
|
"train_images = train_images.astype('float32')\n", |
|
|
|
"test_images = test_images.astype('float32')\n", |
|
|
|
"\n", |
|
|
|
"from sklearn.model_selection import train_test_split\n", |
|
|
|
"train_images, val_images, train_labels, val_labels = train_test_split(train_images,train_labels, test_size = 0.2,shuffle = True)\n", |
|
|
|
"\n", |
|
|
|
"train_images = train_images / 255.0\n", |
|
|
|
"val_images = val_images /255.0\n", |
|
|
|
"test_images = test_images / 255.0\n", |
|
|
|
"\n", |
|
|
|
"# POUR LES CNN : un tenseur d'ordre 3 pour les images en couleurs\n", |
|
|
|
"train_images = train_images.reshape(max(np.shape(train_images)),32,32,3)\n", |
|
|
|
"val_images = val_images.reshape(max(np.shape(val_images)),32,32,3)\n", |
|
|
|
"test_images = test_images.reshape(max(np.shape(test_images)),32,32,3)\n", |
|
|
|
"\n", |
|
|
|
"\n" |
|
|
|
], |
|
|
|
"metadata": { |
|
|
|
"id": "-iczeI91fEND" |
|
|
|
}, |
|
|
|
"execution_count": null, |
|
|
|
"outputs": [] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"source": [ |
|
|
|
"# One hot encoding\n", |
|
|
|
"train_labels = tf.keras.utils.to_categorical(train_labels)\n", |
|
|
|
"val_labels = tf.keras.utils.to_categorical(val_labels)\n", |
|
|
|
"test_labels = tf.keras.utils.to_categorical(test_labels)\n", |
|
|
|
"\n", |
|
|
|
"filter_size_conv1 = (3,3)\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"#création du réseau ResNet18\n", |
|
|
|
"\n", |
|
|
|
"model = ResNet18(10)\n", |
|
|
|
"model.build(input_shape = (None,32,32,3))\n", |
|
|
|
"print(model.summary())\n" |
|
|
|
], |
|
|
|
"metadata": { |
|
|
|
"id": "9UGNJFoRfGTz" |
|
|
|
}, |
|
|
|
"execution_count": null, |
|
|
|
"outputs": [] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"source": [ |
|
|
|
"#Adam comme optimizer et categorical-crossentropy comme norme\n", |
|
|
|
"sgd = tf.keras.optimizers.Adam()\n", |
|
|
|
"model.compile(sgd, loss='categorical_crossentropy', metrics=['accuracy'])\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"history = model.fit(train_images,\n", |
|
|
|
" train_labels,\n", |
|
|
|
" batch_size=64,\n", |
|
|
|
" epochs=100,\n", |
|
|
|
" validation_data=(val_images, val_labels),\n", |
|
|
|
" )\n" |
|
|
|
], |
|
|
|
"metadata": { |
|
|
|
"id": "PtY0q1p5fM-p" |
|
|
|
}, |
|
|
|
"execution_count": null, |
|
|
|
"outputs": [] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"source": [ |
|
|
|
"## Evaluation du modèle \n", |
|
|
|
"test_loss, test_acc = model.evaluate(test_images, test_labels)\n", |
|
|
|
"print('Test accuracy:', test_acc)\n", |
|
|
|
"\n" |
|
|
|
], |
|
|
|
"metadata": { |
|
|
|
"id": "FcehZmotfPCx" |
|
|
|
}, |
|
|
|
"execution_count": null, |
|
|
|
"outputs": [] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"source": [ |
|
|
|
"## on affiche et on sauvegarde les images\n", |
|
|
|
"\n", |
|
|
|
"fig, axs = plt.subplots(2, 1, figsize=(15,15))\n", |
|
|
|
"\n", |
|
|
|
"axs[0].plot(history.history['loss'])\n", |
|
|
|
"axs[0].plot(history.history['val_loss'])\n", |
|
|
|
"axs[0].title.set_text('Training Loss vs Validation Loss')\n", |
|
|
|
"axs[0].legend(['Train', 'Val'])\n", |
|
|
|
"\n", |
|
|
|
"axs[1].plot(history.history['accuracy'])\n", |
|
|
|
"axs[1].plot(history.history['val_accuracy'])\n", |
|
|
|
"axs[1].title.set_text('Training Accuracy vs Validation Accuracy')\n", |
|
|
|
"axs[1].legend(['Train', 'Val'])\n", |
|
|
|
"plt.savefig('./resnet18snake.png')" |
|
|
|
], |
|
|
|
"metadata": { |
|
|
|
"id": "8USwL2YTfRGN" |
|
|
|
}, |
|
|
|
"execution_count": null, |
|
|
|
"outputs": [] |
|
|
|
} |
|
|
|
] |
|
|
|
} |