You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

resnet18_snake.py 4.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. # Réseau inspiré de http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf
  2. from tensorflow.python.ops.gen_array_ops import tensor_scatter_min_eager_fallback
  3. from resnet18 import ResNet18
  4. import tensorflow as tf
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. def displayConvFilers(model, layer_name, num_filter=4, filter_size=(3,3), num_channel=0, fig_size=(2,2)):
  8. layer_dict = dict([(layer.name, layer) for layer in model.layers])
  9. weight, biais = layer_dict[layer_name].get_weights()
  10. print(weight.shape)
  11. plt.figure(figsize=fig_size)
  12. for i in range(num_filter):
  13. plt.subplot(fig_size[0],fig_size[1],i+1)
  14. plt.xticks([])
  15. plt.yticks([])
  16. plt.grid(False)
  17. vis = np.reshape(weight[:,:,num_channel,i],filter_size)
  18. plt.imshow(vis, cmap=plt.cm.binary)
  19. plt.show()
  20. def snake(x):
  21. return x + tf.sin(x)**2
  22. ## Chargement et normalisation des données
  23. resnet18 = tf.keras.datasets.cifar10
  24. (train_images, train_labels), (test_images, test_labels) = resnet18.load_data()
  25. from sklearn.model_selection import train_test_split
  26. train_images, val_images, train_labels, val_labels = train_test_split(train_images,train_labels, test_size = 0.2,shuffle = True)
  27. '''
  28. val_images = train_images[40000:]
  29. val_labels = train_labels[40000:]
  30. train_images = train_images[:40000]
  31. train_labels = train_labels[:40000]
  32. '''
  33. train_images = train_images / 255.0
  34. val_images = val_images /255.0
  35. test_images = test_images / 255.0
  36. # POUR LES CNN : On rajoute une dimension pour spécifier qu'il s'agit d'imgages en NdG
  37. train_images = train_images.reshape(max(np.shape(train_images)),32,32,3)
  38. val_images = val_images.reshape(max(np.shape(val_images)),32,32,3)
  39. test_images = test_images.reshape(max(np.shape(test_images)),32,32,3)
  40. # One hot encoding
  41. train_labels = tf.keras.utils.to_categorical(train_labels)
  42. val_labels = tf.keras.utils.to_categorical(val_labels)
  43. test_labels = tf.keras.utils.to_categorical(test_labels)
  44. filter_size_conv1 = (3,3)
  45. model = ResNet18(10)
  46. model.build(input_shape = (None,32,32,3))
  47. '''
  48. filter_size_conv1 = (5,5)
  49. ## Définition de l'architecture du modèle
  50. model = tf.keras.models.Sequential()
  51. # Expliquez à quoi correspondent les valeurs numériques qui définissent les couches du réseau
  52. model.add(tf.keras.layers.Conv2D(filters=6,kernel_size=filter_size_conv1,padding="same", activation=snake, input_shape=(32, 32, 3)))
  53. model.add(tf.keras.layers.AveragePooling2D())
  54. model.add(tf.keras.layers.Conv2D(filters=16,kernel_size=(5,5),padding="valid", activation=snake))
  55. model.add(tf.keras.layers.AveragePooling2D())
  56. model.add(tf.keras.layers.Flatten())
  57. model.add(tf.keras.layers.Dense(120 , activation=snake))
  58. #mnist_model.add(tf.keras.layers.Dropout(0.5))
  59. model.add(tf.keras.layers.Dense(84 , activation=snake))
  60. model.add(tf.keras.layers.Dense(10 , activation='softmax'))
  61. '''
  62. # expliquer le nombre de paramètre de ce réseau
  63. print(model.summary())
  64. sgd = tf.keras.optimizers.Adam()
  65. model.compile(sgd, loss='categorical_crossentropy', metrics=['accuracy'])
  66. # On visualise avant l'entrainement
  67. history = model.fit(train_images,
  68. train_labels,
  69. batch_size=64,
  70. epochs=4,
  71. validation_data=(val_images, val_labels),
  72. )
  73. ## Evaluation du modèle
  74. test_loss, test_acc = model.evaluate(test_images, test_labels)
  75. print('Test accuracy:', test_acc)
  76. fig, axs = plt.subplots(2, 1, figsize=(15,15))
  77. axs[0].plot(history.history['loss'])
  78. axs[0].plot(history.history['val_loss'])
  79. axs[0].title.set_text('Training Loss vs Validation Loss')
  80. axs[0].legend(['Train', 'Val'])
  81. axs[1].plot(history.history['accuracy'])
  82. axs[1].plot(history.history['val_accuracy'])
  83. axs[1].title.set_text('Training Accuracy vs Validation Accuracy')
  84. axs[1].legend(['Train', 'Val'])
  85. plt.savefig('./resnet18snake.png')
  86. '''
  87. displayConvFilers(mnist_model,
  88. 'conv2d',
  89. num_filter=6,
  90. filter_size=filter_size_conv1,
  91. num_channel=0,
  92. fig_size=(2,3)
  93. )
  94. '''