# Network inspired by http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf
from resnet18 import ResNet18
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
def displayConvFilers(model, layer_name, num_filter=4, filter_size=(3,3), num_channel=0, fig_size=(2,2)):
    """Display the first `num_filter` kernels of a convolutional layer for one input channel."""
    layer_dict = dict([(layer.name, layer) for layer in model.layers])
    weights, biases = layer_dict[layer_name].get_weights()
    print(weights.shape)
    plt.figure(figsize=fig_size)
    for i in range(num_filter):
        plt.subplot(fig_size[0], fig_size[1], i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        # Conv2D kernels have shape (kernel_h, kernel_w, in_channels, out_channels)
        vis = np.reshape(weights[:, :, num_channel, i], filter_size)
        plt.imshow(vis, cmap=plt.cm.binary)
    plt.show()
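# To find valid values for `layer_name`, list the model's layer names, e.g.:
#   print([layer.name for layer in model.layers])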
def snake(x):
    # Snake activation: x + sin^2(x)
    return x + tf.sin(x)**2
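# "Snake" is the periodic-friendly activation of Ziyin et al. (2020),
# snake_a(x) = x + (1/a) * sin^2(a*x); here the frequency parameter a is fixed to 1.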
## Loading and normalising the data
cifar10 = tf.keras.datasets.cifar10
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()
train_images = train_images.astype('float32')
test_images = test_images.astype('float32')
train_images, val_images, train_labels, val_labels = train_test_split(train_images, train_labels, test_size=0.2, shuffle=True)
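# CIFAR-10 ships 50,000 training and 10,000 test images (32x32 RGB, 10 classes),
# so the 0.2 split above yields roughly 40,000 training / 10,000 validation images.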
'''
# Alternative: manual split without shuffling
val_images = train_images[40000:]
val_labels = train_labels[40000:]
train_images = train_images[:40000]
train_labels = train_labels[:40000]
'''
train_images = train_images / 255.0
val_images = val_images / 255.0
test_images = test_images / 255.0
# For the CNN: make the channel dimension explicit. CIFAR-10 images are already
# (N, 32, 32, 3) RGB, so this reshape only enforces the expected shape.
train_images = train_images.reshape(-1, 32, 32, 3)
val_images = val_images.reshape(-1, 32, 32, 3)
test_images = test_images.reshape(-1, 32, 32, 3)
# One-hot encoding of the integer class labels
train_labels = tf.keras.utils.to_categorical(train_labels)
val_labels = tf.keras.utils.to_categorical(val_labels)
test_labels = tf.keras.utils.to_categorical(test_labels)
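# After to_categorical each label is a length-10 one-hot vector,
# e.g. class 3 becomes [0, 0, 0, 1, 0, 0, 0, 0, 0, 0].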
filter_size_conv1 = (3,3)
model = ResNet18(10)
model.build(input_shape=(None, 32, 32, 3))
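# ResNet18 comes from the local resnet18.py module (not a Keras built-in); it is
# assumed to take the number of output classes as its constructor argument.
# The leading None in input_shape is the (variable) batch dimension.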
'''
filter_size_conv1 = (5,5)
## Definition of the model architecture
model = tf.keras.models.Sequential()
# Explain what the numerical values defining the network layers correspond to
model.add(tf.keras.layers.Conv2D(filters=6, kernel_size=filter_size_conv1, padding="same", activation=snake, input_shape=(32, 32, 3)))
model.add(tf.keras.layers.AveragePooling2D())
model.add(tf.keras.layers.Conv2D(filters=16, kernel_size=(5,5), padding="valid", activation=snake))
model.add(tf.keras.layers.AveragePooling2D())
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(120, activation=snake))
#model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(84, activation=snake))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
'''
# Explain the number of parameters of this network
model.summary()
optimizer = tf.keras.optimizers.Adam()
model.compile(optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
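# categorical_crossentropy expects one-hot labels (prepared above with to_categorical);
# with integer labels one would use sparse_categorical_crossentropy instead.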
# Visualise before training (see the commented-out displayConvFilers call at the end)
history = model.fit(train_images,
                    train_labels,
                    batch_size=256,
                    epochs=50,
                    validation_data=(val_images, val_labels),
                    )
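# history.history holds the per-epoch curves under the keys
# 'loss', 'val_loss', 'accuracy' and 'val_accuracy', used for the plots below.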
## Model evaluation
test_loss, test_acc = model.evaluate(test_images, test_labels)
print('Test accuracy:', test_acc)
fig, axs = plt.subplots(2, 1, figsize=(15,15))
axs[0].plot(history.history['loss'])
axs[0].plot(history.history['val_loss'])
axs[0].title.set_text('Training Loss vs Validation Loss')
axs[0].legend(['Train', 'Val'])
axs[1].plot(history.history['accuracy'])
axs[1].plot(history.history['val_accuracy'])
axs[1].title.set_text('Training Accuracy vs Validation Accuracy')
axs[1].legend(['Train', 'Val'])
plt.savefig('./resnet18snake.png')
'''
displayConvFilers(model,
                  'conv2d',
                  num_filter=6,
                  filter_size=filter_size_conv1,
                  num_channel=0,
                  fig_size=(2,3)
                  )
'''