| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151 |
- {
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "name": "Untitled1.ipynb",
- "provenance": [],
- "collapsed_sections": []
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- },
- "language_info": {
- "name": "python"
- }
- },
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "hUYO09lse0Ew"
- },
- "outputs": [],
- "source": [
- "from tensorflow.python.ops.gen_array_ops import tensor_scatter_min_eager_fallback\n",
- "from resnet18 import ResNet18\n",
- "import tensorflow as tf\n",
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "\n",
- "def snake(x):\n",
- " return x + tf.sin(x)**2\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "source": [
- "## Load the data, split off a validation set, and normalize pixel values\n",
- "resnet18 = tf.keras.datasets.cifar10\n",
- "(train_images, train_labels), (test_images, test_labels) = resnet18.load_data()\n",
- "train_images = train_images.astype('float32')\n",
- "test_images = test_images.astype('float32')\n",
- "\n",
- "from sklearn.model_selection import train_test_split\n",
- "train_images, val_images, train_labels, val_labels = train_test_split(train_images,train_labels, test_size = 0.2,shuffle = True)\n",
- "\n",
- "train_images = train_images / 255.0\n",
- "val_images = val_images /255.0\n",
- "test_images = test_images / 255.0\n",
- "\n",
- "# For the CNN: each colour image is a rank-3 tensor (32, 32, 3); the arrays are reshaped to (N, 32, 32, 3), where N is taken as the largest dimension of the array\n",
- "train_images = train_images.reshape(max(np.shape(train_images)),32,32,3)\n",
- "val_images = val_images.reshape(max(np.shape(val_images)),32,32,3)\n",
- "test_images = test_images.reshape(max(np.shape(test_images)),32,32,3)\n",
- "\n",
- "\n"
- ],
- "metadata": {
- "id": "-iczeI91fEND"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "# One hot encoding\n",
- "train_labels = tf.keras.utils.to_categorical(train_labels)\n",
- "val_labels = tf.keras.utils.to_categorical(val_labels)\n",
- "test_labels = tf.keras.utils.to_categorical(test_labels)\n",
- "\n",
- "filter_size_conv1 = (3,3)\n",
- "\n",
- "\n",
- "# Build the ResNet18 network\n",
- "\n",
- "model = ResNet18(10)\n",
- "model.build(input_shape = (None,32,32,3))\n",
- "print(model.summary())\n"
- ],
- "metadata": {
- "id": "9UGNJFoRfGTz"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "# Adam as the optimizer and categorical cross-entropy as the loss function\n",
- "sgd = tf.keras.optimizers.Adam()\n",
- "model.compile(sgd, loss='categorical_crossentropy', metrics=['accuracy'])\n",
- "\n",
- "\n",
- "\n",
- "history = model.fit(train_images,\n",
- " train_labels,\n",
- " batch_size=64,\n",
- " epochs=100,\n",
- " validation_data=(val_images, val_labels),\n",
- " )\n"
- ],
- "metadata": {
- "id": "PtY0q1p5fM-p"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "## Evaluate the model on the test set\n",
- "test_loss, test_acc = model.evaluate(test_images, test_labels)\n",
- "print('Test accuracy:', test_acc)\n",
- "\n"
- ],
- "metadata": {
- "id": "FcehZmotfPCx"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "## Plot the training/validation loss and accuracy curves, then save the figure\n",
- "\n",
- "fig, axs = plt.subplots(2, 1, figsize=(15,15))\n",
- "\n",
- "axs[0].plot(history.history['loss'])\n",
- "axs[0].plot(history.history['val_loss'])\n",
- "axs[0].title.set_text('Training Loss vs Validation Loss')\n",
- "axs[0].legend(['Train', 'Val'])\n",
- "\n",
- "axs[1].plot(history.history['accuracy'])\n",
- "axs[1].plot(history.history['val_accuracy'])\n",
- "axs[1].title.set_text('Training Accuracy vs Validation Accuracy')\n",
- "axs[1].legend(['Train', 'Val'])\n",
- "plt.savefig('./resnet18snake.png')"
- ],
- "metadata": {
- "id": "8USwL2YTfRGN"
- },
- "execution_count": null,
- "outputs": []
- }
- ]
- }
|