Model transfer learning classifier

ModelTransferLearningClassifier

Bases: ModelClassifierMixin, ModelKeras

CNN model with transfer learning for classification
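
A minimal usage sketch (hypothetical paths and classes; it assumes the fit() entry point exposed by ModelClassifierMixin and a train DataFrame with the file_path & file_class columns expected by _fit_classifier):

import pandas as pd
from template_vision.models_training.classifiers.model_transfer_learning_classifier import ModelTransferLearningClassifier

# Hypothetical dataset: one row per image, with its path and its class
df_train = pd.DataFrame({
    'file_path': ['data/images/cat_001.jpg', 'data/images/dog_001.jpg'],
    'file_class': ['cat', 'dog'],
})

model = ModelTransferLearningClassifier(
    with_fine_tune=True,  # unfreeze & refit after the first training
    second_epochs=10,     # epochs for the fine-tuning step
    second_lr=1e-5,       # lower LR so the unfrozen weights move gently
    second_patience=5,    # early-stopping patience for the second fit
)
model.fit(df_train)  # assumption: fit() comes from ModelClassifierMixin
model.save()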

Source code in template_vision/models_training/classifiers/model_transfer_learning_classifier.py
class ModelTransferLearningClassifier(ModelClassifierMixin, ModelKeras):
    '''CNN model with transfer learning for classification'''

    _default_name = 'model_transfer_learning_classifier'

    def __init__(self, with_fine_tune: bool = True, second_epochs: int = 10, second_lr: float = 1e-5, second_patience: int = 5, **kwargs) -> None:
        '''Initialization of the class (see ModelClass, ModelKeras & ModelClassifierMixin for more arguments)

        Kwargs:
            with_fine_tune (bool): If a fine-tuning step should be performed after first training
            second_epochs (int): Number of epochs for the fine-tuning step
            second_lr (float): Learning rate for the fine-tuning step
            second_patience (int): Patience for the fine-tuning step
        '''
        # Init.
        super().__init__(**kwargs)

        # Get logger (must be done after super init)
        self.logger = logging.getLogger(__name__)

        # Params
        self.with_fine_tune = with_fine_tune
        self.second_epochs = second_epochs
        self.second_lr = second_lr
        self.second_patience = second_patience

    def _fit_classifier(self, df_train: pd.DataFrame, df_valid: pd.DataFrame = None, with_shuffle: bool = True, **kwargs) -> dict:
        '''Fits the model - overrides parent function

        Args:
            df_train (pd.DataFrame): Train dataset
                Must contain file_path & file_class columns
        Kwargs:
            df_valid (pd.DataFrame): Validation dataset
                Must contain file_path & file_class columns
            with_shuffle (boolean): If the train dataset must be shuffled
                This should be used when the input dataset is not shuffled and no validation set is provided, as the validation split takes rows in order.
                Otherwise, the validation set might contain classes that are absent from the train set.
        Returns:
            dict: fit arguments
        '''
        # First fit
        self.logger.info("Transfer Learning - Premier entrainement")
        fit_arguments = super()._fit_classifier(df_train, df_valid=df_valid, with_shuffle=with_shuffle, **kwargs)

        # Fine tune if wanted
        if self.with_fine_tune and self.second_epochs > 0:
            # Unfreeze all layers
            for layer in self.model.layers:  # type: ignore
                layer.trainable = True

            # /!\ Recompile, otherwise unfreeze is not taken into account ! /!\
            # Cf. https://keras.io/guides/transfer_learning/#fine-tuning
            self._compile_model(self.model, lr=self.second_lr)  # Use new LR !
            if self.logger.getEffectiveLevel() < logging.ERROR:
                self.model.summary()  # type: ignore

            # Get new callbacks
            new_callbacks = self._get_second_callbacks()

            # Second fit
            self.logger.info("Transfer Learning - Fine-tuning")
            fit_history = self.model.fit(  # type: ignore
                epochs=self.second_epochs,
                callbacks=new_callbacks,
                verbose=1,
                workers=8,  # TODO: check if this is OK when there are fewer CPUs
                **fit_arguments,
            )

            # Print accuracy & loss if level_save > 'LOW'
            if self.level_save in ['MEDIUM', 'HIGH']:
                # Rename first fit plots dir
                original_plots_path = os.path.join(self.model_dir, 'plots')
                renamed_plots_path = os.path.join(self.model_dir, 'plots_initial_fit')
                shutil.move(original_plots_path, renamed_plots_path)
                # Plot new fit graphs
                self._plot_metrics_and_loss(fit_history)
                # Reload best model
                self.model = load_model_keras(
                    os.path.join(self.model_dir, 'best.hdf5'),
                    custom_objects=self.custom_objects
                )

        return fit_arguments

    def _get_model(self) -> Model:
        '''Gets a model structure - returns the instance model instead if already defined

        Returns:
            (Model): a Keras model
        '''
        # Return model if already set
        if self.model is not None:
            return self.model

        # The base model will be loaded by Keras's internal functions
        # Keras uses the `get_file` function to load all files from a cache directory (or from the internet)
        # By default, all Keras applications try to load model files from Keras's cache directory (.keras)
        # However, these applications do not have a parameter to change the default directory, and we want all data
        # inside the project's data directory (especially as we do not always have access to the internet).
        # To do so, we patch Keras's internal `get_file` function to use a directory inside our package as the cache dir.
        # This allows us to pre-download a model from any source.
        # IMPORTANT: we need to reset the `get_file` function at the end of this function

        # Monkey patching : https://stackoverflow.com/questions/5626193/what-is-monkey-patching
        cache_dir = os.path.join(utils.get_data_path(), 'cache_keras')
        if not os.path.exists(cache_dir):
            os.makedirs(cache_dir)
        old_get_file = data_utils.get_file
        data_utils.get_file = partial(data_utils.get_file, cache_dir=cache_dir)

        # Check if the base model exists, otherwise try to download it
        # VGG 16
        # base_model_name = 'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5'
        # base_model_path = os.path.join(utils.get_data_path(), 'cache_keras', 'models', base_model_name)
        # base_model_backup_urls = [
            # 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5',
        # ]
        # EfficientNetB6
        base_model_name = 'efficientnetb6_notop.h5'
        base_model_path = os.path.join(utils.get_data_path(), 'cache_keras', 'models', base_model_name)
        base_model_backup_urls = [
            'https://storage.googleapis.com/keras-applications/efficientnetb6_notop.h5',
        ]
        # Check model presence
        if not os.path.exists(base_model_path):
            try:
                utils.download_url(base_model_backup_urls, base_model_path)
            except Exception:
                # If the download fails, we only log a warning; loading the base model will fail on its own later
                self.logger.warning("Can't find / download the base model for transfer learning application.")

        # Get input/output dimensions
        input_shape = (self.width, self.height, self.depth)
        num_classes = len(self.list_classes)

        # Process
        input_layer = Input(shape=input_shape)

        # Example VGG16 - to be used with tensorflow.keras.applications.vgg16.preprocess_input - cf _get_preprocess_input
        # We must use training=False to use the batch norm layers in inference mode
        # (cf. https://keras.io/guides/transfer_learning/)
        # base_model = VGG16(weights='imagenet', include_top=False, input_shape=input_shape)
        # base_model.trainable = False  # We disable the first layers
        # x = base_model(input_layer, training=False)
        # x = Flatten()(x)

        # Example EfficientNetB6 - to be used with tensorflow.keras.applications.efficientnet.preprocess_input - cf _get_preprocess_input
        # We must use training=False to use the batch norm layers in inference mode
        # (cf. https://keras.io/guides/transfer_learning/)
        base_model = EfficientNetB6(weights='imagenet', include_top=False, input_shape=input_shape)
        base_model.trainable = False  # We disable the first layers
        x = base_model(input_layer, training=False)
        x = GlobalAveragePooling2D()(x)

        # Last layer
        activation = 'softmax'
        out = Dense(num_classes, activation=activation, kernel_initializer='glorot_uniform')(x)

        # Set model
        model = Model(inputs=input_layer, outputs=[out])

        # Get lr & compile
        lr = self.keras_params['learning_rate'] if 'learning_rate' in self.keras_params.keys() else 0.001
        self._compile_model(model, lr=lr)

        # Display model
        if self.logger.getEffectiveLevel() < logging.ERROR:
            model.summary()

        # Try to save model as png if level_save > 'LOW'
        if self.level_save in ['MEDIUM', 'HIGH']:
            self._save_model_png(model)

        # We reset the `get_file` function (cf. explanations)
        data_utils.get_file = old_get_file

        # Return
        return model

    def _compile_model(self, model: Model, lr: float) -> None:
        '''Compiles the model. This is usually done in _get_model, but adding a function here
        helps to simplify the fine-tuning step.

        Args:
            model (Model): Model to be compiled
            lr (float): Learning rate to be used
        '''
        # Set optimizer
        decay = self.keras_params['decay'] if 'decay' in self.keras_params.keys() else 0.0
        self.logger.info(f"Learning rate: {lr}")
        self.logger.info(f"Decay: {decay}")
        optimizer = Adam(lr=lr, decay=decay)

        # Set loss & metrics
        loss = 'categorical_crossentropy'
        metrics: List[Union[str, Callable]] = ['accuracy']

        # Compile model
        model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

    def _get_preprocess_input(self) -> Union[Callable, None]:
        '''Gets the preprocessing to be used before feeding images to the NN

        Returns:
            (Callable | None): Preprocessing function
        '''
        # Preprocessing VGG 16
        # return vgg16_preprocess_input
        # Preprocessing efficient net
        return enet_preprocess_input

    def _get_second_callbacks(self) -> list:
        '''Gets model callbacks - second fit

        Returns:
            list: List of callbacks
        '''
        # We start by renaming 'best.hdf5' & 'logger.csv'
        if os.path.exists(os.path.join(self.model_dir, 'best.hdf5')):
            os.rename(os.path.join(self.model_dir, 'best.hdf5'), os.path.join(self.model_dir, 'best_initial_fit.hdf5'))
        if os.path.exists(os.path.join(self.model_dir, 'logger.csv')):
            os.rename(os.path.join(self.model_dir, 'logger.csv'), os.path.join(self.model_dir, 'logger_initial_fit.csv'))

        # Get classic callbacks
        callbacks = [EarlyStopping(monitor='val_loss', patience=self.second_patience, restore_best_weights=True)]
        if self.level_save in ['MEDIUM', 'HIGH']:
            callbacks.append(
                ModelCheckpoint(
                    filepath=os.path.join(self.model_dir, 'best.hdf5'), monitor='val_loss', save_best_only=True, mode='auto'
                )
            )
        callbacks.append(CSVLogger(filename=os.path.join(self.model_dir, 'logger.csv'), separator=';', append=False))
        callbacks.append(TerminateOnNaN())

        # Get LearningRateScheduler
        scheduler = self._get_second_learning_rate_scheduler()
        if scheduler is not None:
            callbacks.append(LearningRateScheduler(scheduler))

        # Manage tensorboard
        if self.level_save in ['HIGH']:
            # Get log directory
            models_path = utils.get_models_path()
            tensorboard_dir = os.path.join(models_path, 'tensorboard_logs_second_fit')
            # We add a prefix so that the function load_model works correctly (it looks for a sub-folder with the model name)
            log_dir = os.path.join(tensorboard_dir, f"tensorboard_{ntpath.basename(self.model_dir)}")
            if not os.path.exists(log_dir):
                os.makedirs(log_dir)

            # TODO: check whether this class slows down training
            # -> For now: kept commented out
            # Custom class to monitor LR changes
            # https://stackoverflow.com/questions/49127214/keras-how-to-output-learning-rate-onto-tensorboard
            # class LRTensorBoard(TensorBoard):
            #     def __init__(self, log_dir, **kwargs) -> None:  # add other arguments to __init__ if you need
            #         super().__init__(log_dir=log_dir, **kwargs)
            #
            #     def on_epoch_end(self, epoch, logs=None):
            #         logs.update({'lr': K.eval(self.model.optimizer.lr)})
            #         super().on_epoch_end(epoch, logs)

            callbacks.append(TensorBoard(log_dir=log_dir, write_grads=False, write_images=False))
            self.logger.info(f"To start tensorboard: python -m tensorboard.main --logdir {tensorboard_dir} --samples_per_plugin images=10")
            # We use samples_per_plugin to avoid a rare issue between matplotlib and tensorboard
            # https://stackoverflow.com/questions/27147300/matplotlib-tcl-asyncdelete-async-handler-deleted-by-the-wrong-thread

        return callbacks

    def _get_second_learning_rate_scheduler(self) -> Union[Callable, None]:
        '''Function to define a learning rate scheduler - second fit
           -> If it returns None, no scheduler will be used (default behaviour)
           -> This function will be saved directly in the model configuration file
           -> This can be overridden at running time

        Returns:
            (Callable | None): A learning rate Scheduler
        '''
        # e.g.
        # def scheduler(epoch):
        #     lim_epoch = 75
        #     if epoch < lim_epoch:
        #         return 0.01
        #     else:
        #         return max(0.001, 0.01 * math.exp(0.01 * (lim_epoch - epoch)))
        scheduler = None
        return scheduler

    def save(self, json_data: Union[dict, None] = None) -> None:
        '''Saves the model

        Kwargs:
            json_data (dict): Additional configurations to be saved
        '''
        # Save configuration JSON
        if json_data is None:
            json_data = {}

        json_data['with_fine_tune'] = self.with_fine_tune
        json_data['second_epochs'] = self.second_epochs
        json_data['second_lr'] = self.second_lr
        json_data['second_patience'] = self.second_patience
        # Add _compile_model code if not in json_data
        if '_compile_model' not in json_data.keys():
            json_data['_compile_model'] = pickle.source.getsourcelines(self._compile_model)[0]
        # Add _get_second_learning_rate_scheduler code if not in json_data
        if '_get_second_learning_rate_scheduler' not in json_data.keys():
            json_data['_get_second_learning_rate_scheduler'] = pickle.source.getsourcelines(self._get_second_learning_rate_scheduler)[0]

        # Save
        super().save(json_data=json_data)

    def reload_from_standalone(self, **kwargs) -> None:
        '''Reloads a model from its configuration and "standalone" files
        - /!\\ Experimental /!\\ -

        Kwargs:
            configuration_path (str): Path to configuration file
            hdf5_path (str): Path to hdf5 file
            preprocess_input_path (str): Path to preprocess input file
        Raises:
            ValueError: If configuration_path is None
            ValueError: If hdf5_path is None
            ValueError: If preprocess_input_path is None
            FileNotFoundError: If the object configuration_path is not an existing file
            FileNotFoundError: If the object hdf5_path is not an existing file
            FileNotFoundError: If the object preprocess_input_path is not an existing file
        '''
        # Retrieve args
        configuration_path = kwargs.get('configuration_path', None)
        hdf5_path = kwargs.get('hdf5_path', None)
        preprocess_input_path = kwargs.get('preprocess_input_path', None)

        # Checks
        if configuration_path is None:
            raise ValueError("The argument configuration_path can't be None")
        if hdf5_path is None:
            raise ValueError("The argument hdf5_path can't be None")
        if preprocess_input_path is None:
            raise ValueError("The argument preprocess_input_path can't be None")
        if not os.path.exists(configuration_path):
            raise FileNotFoundError(f"The file {configuration_path} does not exist")
        if not os.path.exists(hdf5_path):
            raise FileNotFoundError(f"The file {hdf5_path} does not exist")
        if not os.path.exists(preprocess_input_path):
            raise FileNotFoundError(f"The file {preprocess_input_path} does not exist")

        # Load confs
        with open(configuration_path, 'r', encoding='utf-8') as f:
            configs = json.load(f)
        # JSON keys can't be ints, so we need to cast them back after reloading
        # dict_classes keys are always ints
        if 'dict_classes' in configs.keys():
            configs['dict_classes'] = {int(k): v for k, v in configs['dict_classes'].items()}
        elif 'list_classes' in configs.keys():
            configs['dict_classes'] = {i: col for i, col in enumerate(configs['list_classes'])}

        # Set class vars
        # self.model_name = # Keep the created name
        # self.model_dir = # Keep the created folder
        self.nb_fit = configs.get('nb_fit', 1)  # Consider one unique fit by default
        self.trained = configs.get('trained', True)  # Consider trained by default

        # Try to read the following attributes from configs; if absent, keep the current values
        for attribute in ['model_type', 'list_classes', 'dict_classes', 'level_save',
                          'batch_size', 'epochs', 'validation_split', 'patience',
                          'width', 'height', 'depth', 'color_mode', 'in_memory',
                          'data_augmentation_params', 'nb_train_generator_images_to_save',
                          'with_fine_tune', 'second_epochs', 'second_lr', 'second_patience',
                          'keras_params']:
            setattr(self, attribute, configs.get(attribute, getattr(self, attribute)))

        # Reload model
        self.model = load_model_keras(hdf5_path, custom_objects=self.custom_objects)

        # Reload preprocess_input
        with open(preprocess_input_path, 'rb') as f:
            self.preprocess_input = pickle.load(f)

        # Save best hdf5 in new folder
        new_hdf5_path = os.path.join(self.model_dir, 'best.hdf5')
        shutil.copyfile(hdf5_path, new_hdf5_path)
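
As the docstring of _get_second_learning_rate_scheduler notes, the scheduler hook can be overridden; a minimal sketch of a subclass that enables the commented-out example schedule from the source above (illustrative values only):

import math

class MyTransferLearningClassifier(ModelTransferLearningClassifier):
    def _get_second_learning_rate_scheduler(self):
        # Constant LR for the first epochs, then a bounded exponential decay (illustrative values)
        def scheduler(epoch):
            lim_epoch = 75
            if epoch < lim_epoch:
                return 0.01
            return max(0.001, 0.01 * math.exp(0.01 * (lim_epoch - epoch)))
        return scheduler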

__init__(with_fine_tune=True, second_epochs=10, second_lr=1e-05, second_patience=5, **kwargs)

Initialization of the class (see ModelClass, ModelKeras & ModelClassifierMixin for more arguments)

Kwargs

with_fine_tune (bool): If a fine-tuning step should be performed after first training
second_epochs (int): Number of epochs for the fine-tuning step
second_lr (float): Learning rate for the fine-tuning step
second_patience (int): Patience for the fine-tuning step
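
The fine-tuning step only runs when both with_fine_tune is True and second_epochs > 0 (see _fit_classifier); a short sketch of the two ways to disable it:

# Both of these skip the second fit (cf. `if self.with_fine_tune and self.second_epochs > 0`)
model_a = ModelTransferLearningClassifier(with_fine_tune=False)
model_b = ModelTransferLearningClassifier(with_fine_tune=True, second_epochs=0)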

Source code in template_vision/models_training/classifiers/model_transfer_learning_classifier.py
def __init__(self, with_fine_tune: bool = True, second_epochs: int = 10, second_lr: float = 1e-5, second_patience: int = 5, **kwargs) -> None:
    '''Initialization of the class (see ModelClass, ModelKeras & ModelClassifierMixin for more arguments)

    Kwargs:
        with_fine_tune (bool): If a fine-tuning step should be performed after first training
        second_epochs (int): Number of epochs for the fine-tuning step
        second_lr (float): Learning rate for the fine-tuning step
        second_patience (int): Patience for the fine-tuning step
    '''
    # Init.
    super().__init__(**kwargs)

    # Get logger (must be done after super init)
    self.logger = logging.getLogger(__name__)

    # Params
    self.with_fine_tune = with_fine_tune
    self.second_epochs = second_epochs
    self.second_lr = second_lr
    self.second_patience = second_patience

reload_from_standalone(**kwargs)

Reloads a model from its configuration and "standalone" files - /!\ Experimental /!\ -

Kwargs

configuration_path (str): Path to configuration file
hdf5_path (str): Path to hdf5 file
preprocess_input_path (str): Path to preprocess input file

Raises

ValueError: If configuration_path is None
ValueError: If hdf5_path is None
ValueError: If preprocess_input_path is None
FileNotFoundError: If the object configuration_path is not an existing file
FileNotFoundError: If the object hdf5_path is not an existing file
FileNotFoundError: If the object preprocess_input_path is not an existing file
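
A hedged sketch of the expected call, with hypothetical file locations:

model = ModelTransferLearningClassifier()
model.reload_from_standalone(
    configuration_path='models/my_model/configurations.json',      # hypothetical path
    hdf5_path='models/my_model/best.hdf5',                         # hypothetical path
    preprocess_input_path='models/my_model/preprocess_input.pkl',  # hypothetical path
)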

Source code in template_vision/models_training/classifiers/model_transfer_learning_classifier.py
def reload_from_standalone(self, **kwargs) -> None:
    '''Reloads a model from its configuration and "standalone" files
    - /!\\ Experimental /!\\ -

    Kwargs:
        configuration_path (str): Path to configuration file
        hdf5_path (str): Path to hdf5 file
        preprocess_input_path (str): Path to preprocess input file
    Raises:
        ValueError: If configuration_path is None
        ValueError: If hdf5_path is None
        ValueError: If preprocess_input_path is None
        FileNotFoundError: If the object configuration_path is not an existing file
        FileNotFoundError: If the object hdf5_path is not an existing file
        FileNotFoundError: If the object preprocess_input_path is not an existing file
    '''
    # Retrieve args
    configuration_path = kwargs.get('configuration_path', None)
    hdf5_path = kwargs.get('hdf5_path', None)
    preprocess_input_path = kwargs.get('preprocess_input_path', None)

    # Checks
    if configuration_path is None:
        raise ValueError("The argument configuration_path can't be None")
    if hdf5_path is None:
        raise ValueError("The argument hdf5_path can't be None")
    if preprocess_input_path is None:
        raise ValueError("The argument preprocess_input_path can't be None")
    if not os.path.exists(configuration_path):
        raise FileNotFoundError(f"The file {configuration_path} does not exist")
    if not os.path.exists(hdf5_path):
        raise FileNotFoundError(f"The file {hdf5_path} does not exist")
    if not os.path.exists(preprocess_input_path):
        raise FileNotFoundError(f"The file {preprocess_input_path} does not exist")

    # Load confs
    with open(configuration_path, 'r', encoding='utf-8') as f:
        configs = json.load(f)
    # JSON keys can't be ints, so we need to cast them back after reloading
    # dict_classes keys are always ints
    if 'dict_classes' in configs.keys():
        configs['dict_classes'] = {int(k): v for k, v in configs['dict_classes'].items()}
    elif 'list_classes' in configs.keys():
        configs['dict_classes'] = {i: col for i, col in enumerate(configs['list_classes'])}

    # Set class vars
    # self.model_name = # Keep the created name
    # self.model_dir = # Keep the created folder
    self.nb_fit = configs.get('nb_fit', 1)  # Consider one unique fit by default
    self.trained = configs.get('trained', True)  # Consider trained by default

    # Try to read the following attributes from configs; if absent, keep the current values
    for attribute in ['model_type', 'list_classes', 'dict_classes', 'level_save',
                      'batch_size', 'epochs', 'validation_split', 'patience',
                      'width', 'height', 'depth', 'color_mode', 'in_memory',
                      'data_augmentation_params', 'nb_train_generator_images_to_save',
                      'with_fine_tune', 'second_epochs', 'second_lr', 'second_patience',
                      'keras_params']:
        setattr(self, attribute, configs.get(attribute, getattr(self, attribute)))

    # Reload model
    self.model = load_model_keras(hdf5_path, custom_objects=self.custom_objects)

    # Reload preprocess_input
    with open(preprocess_input_path, 'rb') as f:
        self.preprocess_input = pickle.load(f)

    # Save best hdf5 in new folder
    new_hdf5_path = os.path.join(self.model_dir, 'best.hdf5')
    shutil.copyfile(hdf5_path, new_hdf5_path)

save(json_data=None)

Saves the model

Kwargs

json_data (dict): Additional configurations to be saved
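
A short sketch: any extra keys passed here are saved alongside the fine-tuning parameters that this class adds automatically (the extra key shown is hypothetical):

# with_fine_tune, second_epochs, second_lr & second_patience are added by save() itself
model.save(json_data={'dataset_version': 'v2'})  # 'dataset_version' is a hypothetical extra key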

Source code in template_vision/models_training/classifiers/model_transfer_learning_classifier.py
def save(self, json_data: Union[dict, None] = None) -> None:
    '''Saves the model

    Kwargs:
        json_data (dict): Additional configurations to be saved
    '''
    # Save configuration JSON
    if json_data is None:
        json_data = {}

    json_data['with_fine_tune'] = self.with_fine_tune
    json_data['second_epochs'] = self.second_epochs
    json_data['second_lr'] = self.second_lr
    json_data['second_patience'] = self.second_patience
    # Add _compile_model code if not in json_data
    if '_compile_model' not in json_data.keys():
        json_data['_compile_model'] = pickle.source.getsourcelines(self._compile_model)[0]
    # Add _get_second_learning_rate_scheduler code if not in json_data
    if '_get_second_learning_rate_scheduler' not in json_data.keys():
        json_data['_get_second_learning_rate_scheduler'] = pickle.source.getsourcelines(self._get_second_learning_rate_scheduler)[0]

    # Save
    super().save(json_data=json_data)