Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

resnet_snake.ipynb 4.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. {
  2. "nbformat": 4,
  3. "nbformat_minor": 0,
  4. "metadata": {
  5. "colab": {
  6. "name": "Untitled1.ipynb",
  7. "provenance": [],
  8. "collapsed_sections": []
  9. },
  10. "kernelspec": {
  11. "name": "python3",
  12. "display_name": "Python 3"
  13. },
  14. "language_info": {
  15. "name": "python"
  16. }
  17. },
  18. "cells": [
  19. {
  20. "cell_type": "code",
  21. "execution_count": null,
  22. "metadata": {
  23. "id": "hUYO09lse0Ew"
  24. },
  25. "outputs": [],
  26. "source": [
  27. "# NOTE(review): tensor_scatter_min_eager_fallback is not used anywhere in this notebook\n",
  28. "# and comes from TensorFlow's internal tensorflow.python package rather than the public\n",
  29. "# tf.* API (probably an autocomplete artifact) - consider deleting this import.\n",
  30. "from tensorflow.python.ops.gen_array_ops import tensor_scatter_min_eager_fallback\n",
  31. "from resnet18 import ResNet18\n",
  32. "import tensorflow as tf\n",
  33. "import numpy as np\n",
  34. "import matplotlib.pyplot as plt\n",
  35. "\n",
  36. "\n",
  37. "# NOTE(review): snake is not referenced elsewhere in this notebook; presumably it is\n",
  38. "# meant to be used as the activation inside the ResNet18 model - confirm.\n",
  39. "def snake(x, a=1.0):\n",
  40. "    \"\"\"Snake activation: x + sin(a * x)**2 / a.\n",
  41. "\n",
  42. "    With the default a=1.0 this is exactly x + sin(x)**2.\n",
  43. "\n",
  44. "    Args:\n",
  45. "        x: floating-point tensor (anything tf.sin accepts).\n",
  46. "        a: frequency of the periodic term; must be non-zero.\n",
  47. "\n",
  48. "    Returns:\n",
  49. "        Element-wise result with the same shape as x.\n",
  50. "    \"\"\"\n",
  51. "    return x + tf.sin(a * x) ** 2 / a\n"
  52. ]
  53. },
  38. {
  39. "cell_type": "code",
  40. "source": [
  41. "## Load and normalise the CIFAR-10 data\n",
  42. "from sklearn.model_selection import train_test_split\n",
  43. "\n",
  44. "SEED = 42  # fixed seed so the train/validation split is reproducible between runs\n",
  45. "\n",
  46. "cifar10 = tf.keras.datasets.cifar10\n",
  47. "(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()\n",
  48. "\n",
  49. "# Hold out 20% of the training set for validation. stratify keeps the class\n",
  50. "# proportions the same in both parts; random_state makes the split repeatable.\n",
  51. "train_images, val_images, train_labels, val_labels = train_test_split(\n",
  52. "    train_images, train_labels, test_size=0.2, shuffle=True,\n",
  53. "    stratify=train_labels.ravel(), random_state=SEED)\n",
  54. "\n",
  55. "# Scale pixel values from [0, 255] to [0, 1] as float32.\n",
  56. "train_images = train_images.astype('float32') / 255.0\n",
  57. "val_images = val_images.astype('float32') / 255.0\n",
  58. "test_images = test_images.astype('float32') / 255.0\n",
  59. "\n",
  60. "# CNNs take rank-4 input (batch, height, width, channels) for colour images.\n",
  61. "# -1 lets NumPy infer the batch size, so this works for any number of images.\n",
  62. "train_images = train_images.reshape(-1, 32, 32, 3)\n",
  63. "val_images = val_images.reshape(-1, 32, 32, 3)\n",
  64. "test_images = test_images.reshape(-1, 32, 32, 3)\n"
  65. ],
  66. "metadata": {
  67. "id": "-iczeI91fEND"
  68. },
  69. "execution_count": null,
  70. "outputs": []
  71. },
  67. {
  68. "cell_type": "code",
  69. "source": [
  70. "# One-hot encode the integer labels (the model is trained with categorical_crossentropy).\n",
  71. "# num_classes is explicit so every split gets the same width even if a class is missing.\n",
  72. "NUM_CLASSES = 10  # CIFAR-10\n",
  73. "train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=NUM_CLASSES)\n",
  74. "val_labels = tf.keras.utils.to_categorical(val_labels, num_classes=NUM_CLASSES)\n",
  75. "test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=NUM_CLASSES)\n",
  76. "\n",
  77. "# Build the ResNet-18 network for 32x32 RGB inputs (batch size left open).\n",
  78. "model = ResNet18(NUM_CLASSES)\n",
  79. "model.build(input_shape=(None, 32, 32, 3))\n",
  80. "# summary() prints the table itself and returns None, so it is not wrapped in print().\n",
  81. "model.summary()\n"
  82. ],
  83. "metadata": {
  84. "id": "9UGNJFoRfGTz"
  85. },
  86. "execution_count": null,
  87. "outputs": []
  88. },
  90. {
  91. "cell_type": "code",
  92. "source": [
  93. "# Adam optimiser with categorical cross-entropy as the loss (labels are one-hot).\n",
  94. "optimizer = tf.keras.optimizers.Adam()\n",
  95. "model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])\n",
  96. "\n",
  97. "# Train for a fixed 100 epochs, tracking loss/accuracy on the validation split.\n",
  98. "history = model.fit(\n",
  99. "    train_images,\n",
  100. "    train_labels,\n",
  101. "    batch_size=64,\n",
  102. "    epochs=100,\n",
  103. "    validation_data=(val_images, val_labels),\n",
  104. ")\n"
  105. ],
  106. "metadata": {
  107. "id": "PtY0q1p5fM-p"
  108. },
  109. "execution_count": null,
  110. "outputs": []
  111. },
  112. {
  113. "cell_type": "code",
  114. "source": [
  115. "## Model evaluation on the held-out test set\n",
  116. "evaluation = model.evaluate(test_images, test_labels)\n",
  117. "test_loss, test_acc = evaluation  # compiled with one metric, so [loss, accuracy]\n",
  118. "print('Test accuracy:', test_acc)\n"
  119. ],
  120. "metadata": {
  121. "id": "FcehZmotfPCx"
  122. },
  123. "execution_count": null,
  124. "outputs": []
  125. },
  126. {
  127. "cell_type": "code",
  128. "source": [
  129. "## Plot the training curves (loss on top, accuracy below) and save the figure\n",
  130. "fig, axs = plt.subplots(2, 1, figsize=(15, 15))\n",
  131. "\n",
  132. "panels = [\n",
  133. "    ('loss', 'Training Loss vs Validation Loss'),\n",
  134. "    ('accuracy', 'Training Accuracy vs Validation Accuracy'),\n",
  135. "]\n",
  136. "for ax, (metric, title) in zip(axs, panels):\n",
  137. "    # Keras stores validation metrics under the same key prefixed with 'val_'.\n",
  138. "    ax.plot(history.history[metric])\n",
  139. "    ax.plot(history.history['val_' + metric])\n",
  140. "    ax.title.set_text(title)\n",
  141. "    ax.legend(['Train', 'Val'])\n",
  142. "\n",
  143. "plt.savefig('./resnet18snake.png')"
  144. ],
  145. "metadata": {
  146. "id": "8USwL2YTfRGN"
  147. },
  148. "execution_count": null,
  149. "outputs": []
  150. }
  150. ]
  151. }