| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091 |
- from keras.callbacks import EarlyStopping
- from keras.layers import Dense, Conv2D, MaxPool2D, Flatten, GlobalAveragePooling2D, BatchNormalization, Layer, Add
- from keras.models import Sequential
- from keras.models import Model
- import tensorflow as tf
-
class ResnetBlock(Model):
    """
    A standard two-convolution ResNet basic block.

    Main path: 3x3 conv -> BN -> activation -> 3x3 conv -> BN.
    The result is added to the shortcut (identity, or a strided 1x1 conv + BN
    projection when ``down_sample`` is True), then the activation is applied.

    The activation is ``x + sin(x)**2`` instead of the usual ReLU.
    """

    def __init__(self, channels: int, down_sample=False, **kwargs):
        """
        channels: number of convolution kernels (output channels) of each conv.
        down_sample: if True, the first conv uses stride 2 and the shortcut is
            projected with a stride-2 1x1 conv + BN so both branches match.
            If False, the shortcut is the identity, so the input must already
            have ``channels`` channels for the add to succeed.
        **kwargs: forwarded to ``keras.Model.__init__`` (e.g. ``name``).
        """
        super().__init__(**kwargs)

        self.__channels = channels
        self.__down_sample = down_sample
        # Only the first conv downsamples; the second always keeps resolution.
        self.__strides = [2, 1] if down_sample else [1, 1]

        KERNEL_SIZE = (3, 3)
        INIT_SCHEME = "he_normal"

        self.conv_1 = Conv2D(self.__channels, strides=self.__strides[0],
                             kernel_size=KERNEL_SIZE, padding="same",
                             kernel_initializer=INIT_SCHEME)
        self.bn_1 = BatchNormalization()
        self.conv_2 = Conv2D(self.__channels, strides=self.__strides[1],
                             kernel_size=KERNEL_SIZE, padding="same",
                             kernel_initializer=INIT_SCHEME)
        self.bn_2 = BatchNormalization()
        self.merge = Add()

        if self.__down_sample:
            # Strided 1x1 projection so the shortcut matches the main path's
            # spatial size and channel count.
            self.res_conv = Conv2D(
                self.__channels, strides=2, kernel_size=(1, 1),
                kernel_initializer=INIT_SCHEME, padding="same")
            self.res_bn = BatchNormalization()

    @staticmethod
    def _activation(x):
        """Return ``x + sin(x)**2`` (Snake activation with a = 1), used in place of ReLU."""
        return x + tf.sin(x) ** 2

    def call(self, inputs, training=None):
        """
        Apply the residual block to ``inputs``.

        training: forwarded explicitly to every BatchNormalization layer so
            batch-statistics vs. moving-average mode is unambiguous.
        """
        shortcut = inputs

        x = self.conv_1(inputs)
        x = self.bn_1(x, training=training)
        x = self._activation(x)
        x = self.conv_2(x)
        x = self.bn_2(x, training=training)

        if self.__down_sample:
            shortcut = self.res_conv(shortcut)
            shortcut = self.res_bn(shortcut, training=training)

        x = self.merge([x, shortcut])
        return self._activation(x)
-
-
class ResNet18(Model):
    """
    ResNet-18-style image classifier built from ResnetBlock units.

    Stem: 7x7 stride-2 conv -> BN -> activation -> 2x2 stride-2 max-pool.
    Body: four stages of two ResnetBlocks each (64, 128, 256, 512 channels);
    the first block of stages 2-4 downsamples.
    Head: global average pooling -> Dense(num_classes).

    The activation used in the stem is ``x + sin(x)**2`` instead of ReLU.
    """

    def __init__(self, num_classes, *, activation="sigmoid", **kwargs):
        """
        num_classes: number of classes in the specific classification task
            (units of the final Dense layer).
        activation: activation of the final Dense layer. The default
            "sigmoid" produces independent per-class scores (multi-label
            style); pass "softmax" for mutually exclusive classes.
        **kwargs: forwarded to ``keras.Model.__init__`` (e.g. ``name``).
        """
        super().__init__(**kwargs)
        self.conv_1 = Conv2D(64, (7, 7), strides=2,
                             padding="same", kernel_initializer="he_normal")
        self.init_bn = BatchNormalization()
        self.pool_2 = MaxPool2D(pool_size=(2, 2), strides=2, padding="same")
        self.res_1_1 = ResnetBlock(64)
        self.res_1_2 = ResnetBlock(64)
        self.res_2_1 = ResnetBlock(128, down_sample=True)
        self.res_2_2 = ResnetBlock(128)
        self.res_3_1 = ResnetBlock(256, down_sample=True)
        self.res_3_2 = ResnetBlock(256)
        self.res_4_1 = ResnetBlock(512, down_sample=True)
        self.res_4_2 = ResnetBlock(512)
        self.avg_pool = GlobalAveragePooling2D()
        # GlobalAveragePooling2D already yields (batch, channels), so this
        # Flatten is effectively a no-op; kept to leave the layer structure
        # unchanged.
        self.flat = Flatten()
        self.fc = Dense(num_classes, activation=activation)

    @staticmethod
    def _activation(x):
        """Return ``x + sin(x)**2`` (Snake activation with a = 1), used in place of ReLU."""
        return x + tf.sin(x) ** 2

    def call(self, inputs, training=None):
        """
        Run the forward pass on ``inputs``.

        training: forwarded explicitly to the stem BatchNormalization layer and
            to every ResnetBlock so batch-norm mode is unambiguous.
        """
        out = self.conv_1(inputs)
        out = self.init_bn(out, training=training)
        out = self._activation(out)
        out = self.pool_2(out)
        for res_block in (self.res_1_1, self.res_1_2,
                          self.res_2_1, self.res_2_2,
                          self.res_3_1, self.res_3_2,
                          self.res_4_1, self.res_4_2):
            out = res_block(out, training=training)
        out = self.avg_pool(out)
        out = self.flat(out)
        return self.fc(out)
|