
resnet18.py

from keras.callbacks import EarlyStopping
from keras.layers import Dense, Conv2D, MaxPool2D, Flatten, GlobalAveragePooling2D, BatchNormalization, Layer, Add
from keras.models import Sequential
from keras.models import Model
import tensorflow as tf


class ResnetBlock(Model):
    """
    A standard ResNet block.
    """

    def __init__(self, channels: int, down_sample=False):
        """
        channels: same as the number of convolution kernels
        """
        super().__init__()
        self.__channels = channels
        self.__down_sample = down_sample
        self.__strides = [2, 1] if down_sample else [1, 1]

        KERNEL_SIZE = (3, 3)
        # use He initialization instead of Xavier (a.k.a. 'glorot_uniform' in Keras), as suggested in [2]
        INIT_SCHEME = "he_normal"

        self.conv_1 = Conv2D(self.__channels, strides=self.__strides[0],
                             kernel_size=KERNEL_SIZE, padding="same", kernel_initializer=INIT_SCHEME)
        self.bn_1 = BatchNormalization()
        self.conv_2 = Conv2D(self.__channels, strides=self.__strides[1],
                             kernel_size=KERNEL_SIZE, padding="same", kernel_initializer=INIT_SCHEME)
        self.bn_2 = BatchNormalization()
        self.merge = Add()

        if self.__down_sample:
            # perform down-sampling on the shortcut using a 1x1 convolution with stride 2, according to [1]
            self.res_conv = Conv2D(
                self.__channels, strides=2, kernel_size=(1, 1), kernel_initializer=INIT_SCHEME, padding="same")
            self.res_bn = BatchNormalization()

    def call(self, inputs):
        res = inputs

        x = self.conv_1(inputs)
        x = self.bn_1(x)
        x = x + tf.sin(x)**2  # custom activation in place of tf.nn.relu(x)
        x = self.conv_2(x)
        x = self.bn_2(x)

        if self.__down_sample:
            res = self.res_conv(res)
            res = self.res_bn(res)

        # if no down-sampling is performed, the shortcut is added directly
        x = self.merge([x, res])
        out = x + tf.sin(x)**2  # custom activation in place of tf.nn.relu(x)
        return out


class ResNet18(Model):

    def __init__(self, num_classes, **kwargs):
        """
        num_classes: number of classes in the specific classification task.
        """
        super().__init__(**kwargs)
        self.conv_1 = Conv2D(64, (7, 7), strides=2,
                             padding="same", kernel_initializer="he_normal")
        self.init_bn = BatchNormalization()
        self.pool_2 = MaxPool2D(pool_size=(2, 2), strides=2, padding="same")
        self.res_1_1 = ResnetBlock(64)
        self.res_1_2 = ResnetBlock(64)
        self.res_2_1 = ResnetBlock(128, down_sample=True)
        self.res_2_2 = ResnetBlock(128)
        self.res_3_1 = ResnetBlock(256, down_sample=True)
        self.res_3_2 = ResnetBlock(256)
        self.res_4_1 = ResnetBlock(512, down_sample=True)
        self.res_4_2 = ResnetBlock(512)
        self.avg_pool = GlobalAveragePooling2D()
        self.flat = Flatten()
        self.fc = Dense(num_classes, activation="softmax")

    def call(self, inputs):
        out = self.conv_1(inputs)
        out = self.init_bn(out)
        out = out + tf.sin(out)**2  # custom activation in place of tf.nn.relu(out)
        out = self.pool_2(out)
        for res_block in [self.res_1_1, self.res_1_2, self.res_2_1, self.res_2_2,
                          self.res_3_1, self.res_3_2, self.res_4_1, self.res_4_2]:
            out = res_block(out)
        out = self.avg_pool(out)
        out = self.flat(out)
        out = self.fc(out)
        return out
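
A minimal usage sketch, not part of the original file: it assumes 32x32 RGB inputs (CIFAR-10-like) with integer class labels, and uses sparse categorical cross-entropy since the final Dense layer already applies softmax. The x_train/y_train/x_val/y_val names are placeholders for your own dataset.

# Usage sketch (assumptions: 10 classes, 32x32x3 inputs, integer labels).
model = ResNet18(num_classes=10)
model.build(input_shape=(None, 32, 32, 3))
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.summary()

# x_train, y_train, x_val, y_val are placeholders for your data.
# model.fit(x_train, y_train,
#           validation_data=(x_val, y_val),
#           epochs=10, batch_size=128,
#           callbacks=[EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True)])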