Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

resnet18_snake.py 2.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
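# Trains and evaluates a ResNet18 classifier on CIFAR-10.
# The original header cited LeNet-5 (http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf);
# that reference appears to be left over from an earlier MNIST version of this script.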
  2. from resnet18 import ResNet18
  3. import tensorflow as tf
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. # Pour les utilisateurs de MacOS (pour utiliser plt & keras en même temps)
  7. import os
  8. #os.environ['KMP_DUPLICATE_LIB_OK']='True'
  9. def displayConvFilers(model, layer_name, num_filter=4, filter_size=(3,3), num_channel=0, fig_size=(2,2)):
  10. layer_dict = dict([(layer.name, layer) for layer in mnist_model.layers])
  11. weight, biais = layer_dict[layer_name].get_weights()
  12. print(weight.shape)
  13. plt.figure(figsize=fig_size)
  14. for i in range(num_filter):
  15. plt.subplot(fig_size[0],fig_size[1],i+1)
  16. plt.xticks([])
  17. plt.yticks([])
  18. plt.grid(False)
  19. vis = np.reshape(weight[:,:,num_channel,i],filter_size)
  20. plt.imshow(vis, cmap=plt.cm.binary)
  21. plt.show()
  22. def snake(x):
  23. return x + tf.sin(x)**2
  24. ## Chargement et normalisation des données
  25. resnet18 = tf.keras.datasets.cifar10
  26. (train_images, train_labels), (test_images, test_labels) = resnet18.load_data()
  27. train_images = train_images / 255.0
  28. test_images = test_images / 255.0
  29. # POUR LES CNN : On rajoute une dimension pour spécifier qu'il s'agit d'imgages en NdG
  30. train_images = train_images.reshape(50000,32,32,3)
  31. test_images = test_images.reshape(10000,32,32,3)
  32. # One hot encoding
  33. train_labels = tf.keras.utils.to_categorical(train_labels)
  34. test_labels = tf.keras.utils.to_categorical(test_labels)
  35. filter_size_conv1 = (3,3)
  36. model = ResNet18(10)
  37. model.build(input_shape = (None,32,32,3))
  38. # expliquer le nombre de paramètre de ce réseau
  39. print(model.summary())
  40. sgd = tf.keras.optimizers.Adam()
  41. model.compile(sgd, loss='categorical_crossentropy', metrics=['accuracy'])
  42. # On visualise avant l'entrainement
  43. '''
  44. displayConvFilers(mnist_model, 'conv2d',
  45. num_filter=6,
  46. filter_size=filter_size_conv1,
  47. num_channel=0,
  48. fig_size=(2,3)
  49. )
  50. '''
  51. model.fit(train_images,
  52. train_labels,
  53. batch_size=64,
  54. epochs=4
  55. )
  56. ## Evaluation du modèle
  57. test_loss, test_acc = model.evaluate(test_images, test_labels)
  58. print('Test accuracy:', test_acc)
  59. '''
  60. displayConvFilers(mnist_model,
  61. 'conv2d',
  62. num_filter=6,
  63. filter_size=filter_size_conv1,
  64. num_channel=0,
  65. fig_size=(2,3)
  66. )
  67. '''
  68. '''
  69. displayConvFilers(mnist_model, 'conv2d_1',
  70. num_filter=16,
  71. filter_size=(5,5),
  72. num_channel=1,
  73. fig_size=(4,4)
  74. )
  75. '''