from keras.callbacks import EarlyStopping
from keras.layers import (
    Dense,
    Conv2D,
    MaxPool2D,
    Flatten,
    GlobalAveragePooling2D,
    BatchNormalization,
    Layer,
    Add,
)
from keras.models import Sequential
from keras.models import Model
import tensorflow as tf


class ResnetBlock(Model):
    """
    A standard resnet block.

    Main path: conv(3x3) -> batch norm -> ReLU -> conv(3x3) -> batch norm.
    The block input (the shortcut) is then added to the main path and a final
    ReLU is applied.

    When ``down_sample`` is True the first convolution uses a stride of 2 and
    the shortcut goes through a strided 1x1 convolution followed by batch
    normalization before the addition. When it is False the input is added
    unchanged, so it must already have the same shape as the main-path output.
    """

    def __init__(self, channels: int, down_sample=False, **kwargs):
        """
        channels: same as number of convolution kernels
        down_sample: if True, use a stride of 2 in the first convolution and
            project the shortcut with a strided 1x1 convolution.
        **kwargs: forwarded unchanged to ``keras.Model.__init__``
            (for example ``name``).
        """
        super().__init__(**kwargs)

        self.__channels = channels
        self.__down_sample = down_sample
        # Strides of the first and second 3x3 convolution respectively.
        self.__strides = [2, 1] if down_sample else [1, 1]

        KERNEL_SIZE = (3, 3)
        # use He initialization, instead of Xavier (a.k.a 'glorot_uniform' in Keras), as suggested in [2]
        INIT_SCHEME = "he_normal"

        self.conv_1 = Conv2D(
            self.__channels,
            strides=self.__strides[0],
            kernel_size=KERNEL_SIZE,
            padding="same",
            kernel_initializer=INIT_SCHEME,
        )
        self.bn_1 = BatchNormalization()
        self.conv_2 = Conv2D(
            self.__channels,
            strides=self.__strides[1],
            kernel_size=KERNEL_SIZE,
            padding="same",
            kernel_initializer=INIT_SCHEME,
        )
        self.bn_2 = BatchNormalization()
        self.merge = Add()

        if self.__down_sample:
            # perform down sampling using stride of 2, according to [1].
            self.res_conv = Conv2D(
                self.__channels,
                strides=2,
                kernel_size=(1, 1),
                kernel_initializer=INIT_SCHEME,
                padding="same",
            )
            self.res_bn = BatchNormalization()

    def call(self, inputs, training=None):
        """
        Apply the residual block to ``inputs``.

        inputs: input tensor for the block's Conv2D layers.
        training: forwarded to every BatchNormalization layer so that it uses
            batch statistics while training and moving averages at inference.
            ``None`` leaves the decision to Keras.

        Returns the ReLU-activated sum of the main path and the shortcut.
        """
        res = inputs

        x = self.conv_1(inputs)
        # Pass the flag explicitly instead of relying on implicit propagation
        # from the caller, which is not guaranteed when call() has no
        # ``training`` parameter.
        x = self.bn_1(x, training=training)
        x = tf.nn.relu(x)
        x = self.conv_2(x)
        x = self.bn_2(x, training=training)

        if self.__down_sample:
            res = self.res_conv(res)
            res = self.res_bn(res, training=training)

        # if not perform down sample, then add a shortcut directly
        x = self.merge([x, res])
        return tf.nn.relu(x)


class ResNet18(Model):
    """
    Image classifier assembled from eight ``ResnetBlock`` instances.

    Layer order: 7x7 stride-2 convolution -> batch norm -> ReLU -> 2x2
    stride-2 max pooling -> two blocks each at 64, 128, 256 and 512 channels
    (the first block of the 128/256/512 stages down-samples) -> global average
    pooling -> flatten -> softmax dense layer with ``num_classes`` units.
    """

    def __init__(self, num_classes, **kwargs):
        """
        num_classes: number of classes in specific classification task.
        **kwargs: forwarded unchanged to ``keras.Model.__init__``.
        """
        super().__init__(**kwargs)
        self.conv_1 = Conv2D(64, (7, 7), strides=2,
                             padding="same", kernel_initializer="he_normal")
        self.init_bn = BatchNormalization()
        self.pool_2 = MaxPool2D(pool_size=(2, 2), strides=2, padding="same")
        self.res_1_1 = ResnetBlock(64)
        self.res_1_2 = ResnetBlock(64)
        self.res_2_1 = ResnetBlock(128, down_sample=True)
        self.res_2_2 = ResnetBlock(128)
        self.res_3_1 = ResnetBlock(256, down_sample=True)
        self.res_3_2 = ResnetBlock(256)
        self.res_4_1 = ResnetBlock(512, down_sample=True)
        self.res_4_2 = ResnetBlock(512)
        self.avg_pool = GlobalAveragePooling2D()
        # NOTE(review): the global average pooling output should already be
        # one vector per sample, so this Flatten is presumably a no-op; it is
        # kept so the model's layer stack stays unchanged.
        self.flat = Flatten()
        self.fc = Dense(num_classes, activation="softmax")

    def call(self, inputs, training=None):
        """
        Run the network on ``inputs`` and return the softmax output of the
        final dense layer.

        inputs: input tensor for the initial Conv2D layer.
        training: forwarded to the stem BatchNormalization layer and to every
            residual block. ``None`` leaves the decision to Keras.
        """
        out = self.conv_1(inputs)
        out = self.init_bn(out, training=training)
        out = tf.nn.relu(out)
        out = self.pool_2(out)

        residual_blocks = (
            self.res_1_1, self.res_1_2,
            self.res_2_1, self.res_2_2,
            self.res_3_1, self.res_3_2,
            self.res_4_1, self.res_4_2,
        )
        for res_block in residual_blocks:
            # Each block forwards the flag to its own batch-norm layers.
            out = res_block(out, training=training)

        out = self.avg_pool(out)
        out = self.flat(out)
        out = self.fc(out)
        return out