You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

resnet18_snake.py 3.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. # Réseau inspiré de http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf
  2. from keras.callbacks import History
  3. from resnet18 import ResNet18
  4. import tensorflow as tf
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. # Pour les utilisateurs de MacOS (pour utiliser plt & keras en même temps)
  8. import os
  9. #os.environ['KMP_DUPLICATE_LIB_OK']='True'
  10. def displayConvFilers(model, layer_name, num_filter=4, filter_size=(3,3), num_channel=0, fig_size=(2,2)):
  11. layer_dict = dict([(layer.name, layer) for layer in model.layers])
  12. weight, biais = layer_dict[layer_name].get_weights()
  13. print(weight.shape)
  14. plt.figure(figsize=fig_size)
  15. for i in range(num_filter):
  16. plt.subplot(fig_size[0],fig_size[1],i+1)
  17. plt.xticks([])
  18. plt.yticks([])
  19. plt.grid(False)
  20. vis = np.reshape(weight[:,:,num_channel,i],filter_size)
  21. plt.imshow(vis, cmap=plt.cm.binary)
  22. plt.show()
  23. def snake(x):
  24. return x + tf.sin(x)**2
  25. ## Chargement et normalisation des données
  26. resnet18 = tf.keras.datasets.cifar10
  27. (train_images, train_labels), (test_images, test_labels) = resnet18.load_data()
  28. val_images = train_images[40000:]
  29. val_labels = train_labels[40000:]
  30. train_images = train_images[:40000]
  31. train_labels = train_labels[:40000]
  32. train_images = train_images / 255.0
  33. val_images = val_images /255.0
  34. test_images = test_images / 255.0
  35. # POUR LES CNN : On rajoute une dimension pour spécifier qu'il s'agit d'imgages en NdG
  36. train_images = train_images.reshape(40000,32,32,3)
  37. val_images = val_images.reshape(10000,32,32,3)
  38. test_images = test_images.reshape(10000,32,32,3)
  39. # One hot encoding
  40. train_labels = tf.keras.utils.to_categorical(train_labels)
  41. val_labels = tf.keras.utils.to_categorical(val_labels)
  42. test_labels = tf.keras.utils.to_categorical(test_labels)
  43. filter_size_conv1 = (3,3)
  44. model = ResNet18(10)
  45. model.build(input_shape = (None,32,32,3))
  46. filter_size_conv1 = (5,5)
  47. '''
  48. ## Définition de l'architecture du modèle
  49. model = tf.keras.models.Sequential()
  50. # Expliquez à quoi correspondent les valeurs numériques qui définissent les couches du réseau
  51. model.add(tf.keras.layers.Conv2D(filters=6,kernel_size=filter_size_conv1,padding="same", activation=snake, input_shape=(32, 32, 3)))
  52. model.add(tf.keras.layers.AveragePooling2D())
  53. model.add(tf.keras.layers.Conv2D(filters=16,kernel_size=(5,5),padding="valid", activation=snake))
  54. model.add(tf.keras.layers.AveragePooling2D())
  55. model.add(tf.keras.layers.Flatten())
  56. model.add(tf.keras.layers.Dense(120 , activation=snake))
  57. #mnist_model.add(tf.keras.layers.Dropout(0.5))
  58. model.add(tf.keras.layers.Dense(84 , activation=snake))
  59. model.add(tf.keras.layers.Dense(10 , activation='softmax'))
  60. '''
  61. # expliquer le nombre de paramètre de ce réseau
  62. print(model.summary())
  63. sgd = tf.keras.optimizers.Adam()
  64. model.compile(sgd, loss='categorical_crossentropy', metrics=['accuracy'])
  65. # On visualise avant l'entrainement
  66. history = model.fit(train_images,
  67. train_labels,
  68. batch_size=64,
  69. epochs=4,
  70. validation_data=(val_images, val_labels),
  71. )
  72. ## Evaluation du modèle
  73. test_loss, test_acc = model.evaluate(test_images, test_labels)
  74. print('Test accuracy:', test_acc)
  75. fig, axs = plt.subplots(2, 1, figsize=(15,15))
  76. axs[0].plot(history.history['loss'])
  77. axs[0].plot(history.history['val_loss'])
  78. axs[0].title.set_text('Training Loss vs Validation Loss')
  79. axs[0].legend(['Train', 'Val'])
  80. axs[1].plot(history.history['accuracy'])
  81. axs[1].plot(history.history['val_accuracy'])
  82. axs[1].title.set_text('Training Accuracy vs Validation Accuracy')
  83. axs[1].legend(['Train', 'Val'])
  84. plt.savefig('./resnet18snake.png')
  85. '''
  86. displayConvFilers(mnist_model,
  87. 'conv2d',
  88. num_filter=6,
  89. filter_size=filter_size_conv1,
  90. num_channel=0,
  91. fig_size=(2,3)
  92. )
  93. '''