Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. from keras.callbacks import EarlyStopping
  2. from keras.layers import Dense, Conv2D, MaxPool2D, Flatten, GlobalAveragePooling2D, BatchNormalization, Layer, Add
  3. from keras.models import Sequential
  4. from keras.models import Model
  5. import tensorflow as tf
  6. class ResnetBlock(Model):
  7. """
  8. A standard resnet block.
  9. """
  10. def __init__(self, channels: int, down_sample=False):
  11. """
  12. channels: same as number of convolution kernels
  13. """
  14. super().__init__()
  15. self.__channels = channels
  16. self.__down_sample = down_sample
  17. self.__strides = [2, 1] if down_sample else [1, 1]
  18. KERNEL_SIZE = (3, 3)
  19. INIT_SCHEME = "he_normal"
  20. self.conv_1 = Conv2D(self.__channels, strides=self.__strides[0],
  21. kernel_size=KERNEL_SIZE, padding="same", kernel_initializer=INIT_SCHEME)
  22. self.bn_1 = BatchNormalization()
  23. self.conv_2 = Conv2D(self.__channels, strides=self.__strides[1],
  24. kernel_size=KERNEL_SIZE, padding="same", kernel_initializer=INIT_SCHEME)
  25. self.bn_2 = BatchNormalization()
  26. self.merge = Add()
  27. if self.__down_sample:
  28. self.res_conv = Conv2D(
  29. self.__channels, strides=2, kernel_size=(1, 1), kernel_initializer=INIT_SCHEME, padding="same")
  30. self.res_bn = BatchNormalization()
  31. def call(self, inputs):
  32. res = inputs
  33. x = self.conv_1(inputs)
  34. x = self.bn_1(x)
  35. x = x + tf.sin(x)**2 #tf.nn.relu(x)
  36. x = self.conv_2(x)
  37. x = self.bn_2(x)
  38. if self.__down_sample:
  39. res = self.res_conv(res)
  40. res = self.res_bn(res)
  41. x = self.merge([x, res])
  42. out = x + tf.sin(x)**2 #tf.nn.relu(x)
  43. return out
  44. class ResNet18(Model):
  45. def __init__(self, num_classes, **kwargs):
  46. """
  47. num_classes: number of classes in specific classification task.
  48. """
  49. super().__init__(**kwargs)
  50. self.conv_1 = Conv2D(64, (7, 7), strides=2,
  51. padding="same", kernel_initializer="he_normal")
  52. self.init_bn = BatchNormalization()
  53. self.pool_2 = MaxPool2D(pool_size=(2, 2), strides=2, padding="same")
  54. self.res_1_1 = ResnetBlock(64)
  55. self.res_1_2 = ResnetBlock(64)
  56. self.res_2_1 = ResnetBlock(128, down_sample=True)
  57. self.res_2_2 = ResnetBlock(128)
  58. self.res_3_1 = ResnetBlock(256, down_sample=True)
  59. self.res_3_2 = ResnetBlock(256)
  60. self.res_4_1 = ResnetBlock(512, down_sample=True)
  61. self.res_4_2 = ResnetBlock(512)
  62. self.avg_pool = GlobalAveragePooling2D()
  63. self.flat = Flatten()
  64. self.fc = Dense(num_classes, activation="sigmoid")
  65. def call(self, inputs):
  66. out = self.conv_1(inputs)
  67. out = self.init_bn(out)
  68. out += tf.sin(out)**2 #tf.nn.relu(out)
  69. out = self.pool_2(out)
  70. for res_block in [self.res_1_1, self.res_1_2, self.res_2_1, self.res_2_2, self.res_3_1, self.res_3_2, self.res_4_1, self.res_4_2]:
  71. out = res_block(out)
  72. out = self.avg_pool(out)
  73. out = self.flat(out)
  74. out = self.fc(out)
  75. return out