#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Script to train the multi-label classifier
"""

import torch
import pytorch_lightning as pl
from pytorch_lightning import loggers,seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, StochasticWeightAveraging
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import os
from pathlib import Path
import traceback
import numpy as np

import sys
sys.path.append("../")

from datasets.fastmri_annotated_multi import FastMRIMultiLabel
from datasets.fastmri_multicoil import FastMRIDataModule
from util import helper
import variables


# Parse the command-line flags once, at module import time
args = helper.flags()

# Checkpoint options read from the flags:
#   load_ckpt_dir  - run directory to resume from (compared against the string 'None' below)
#   load_last_ckpt - truthy to resume from 'last.ckpt' rather than 'best_val_loss.ckpt'
load_ckpt_dir, load_last_ckpt = args.load_ckpt_dir, args.load_last_ckpt


# Model identifier passed to helper.select_config / helper.load_model and used
# as the TensorBoard run name.
MODEL_TYPE = 'MultiLabelClassifier'

# Value written to config['train_args']['epochs'] (the Trainer's max_epochs)
# whenever training is resumed from a checkpoint.
RESUME_EPOCHS = 150


def resolve_ckpt_path(ckpt_dir, use_last_ckpt):
    """Return the checkpoint file to resume from, or None for a fresh run.

    Args:
        ckpt_dir: Run directory of a previous training, or ``None`` / the
            string ``'None'`` when no checkpoint should be loaded.
        use_last_ckpt: Truthy to pick ``last.ckpt``; otherwise
            ``best_val_loss.ckpt`` is picked.

    Returns:
        ``<ckpt_dir>/checkpoints/<name>.ckpt`` or ``None``.
    """
    # Accept a real None as well as the 'None' string sentinel so that a
    # missing flag never ends up in os.path.join (which would raise TypeError).
    if ckpt_dir is None or ckpt_dir == 'None':
        return None

    ckpt_name = 'last.ckpt' if use_last_ckpt else 'best_val_loss.ckpt'
    return os.path.join(ckpt_dir, 'checkpoints', ckpt_name)


def load_config(ckpt_dir, ckpt):
    """Return the configuration dictionary for this run.

    Args:
        ckpt_dir: Run directory holding ``configs.pkl``; only read when
            ``ckpt`` is not ``None``.
        ckpt: Checkpoint path as returned by ``resolve_ckpt_path``.

    Returns:
        A fresh configuration from ``helper.select_config`` when ``ckpt`` is
        ``None``; otherwise the configuration stored next to the checkpoint,
        with its epoch count overridden by ``RESUME_EPOCHS``.
    """
    if ckpt is None:
        return helper.select_config(MODEL_TYPE).config

    # NOTE(review): read_pickle presumably unpickles the file - only point
    # the checkpoint flag at run directories you trust.
    config = helper.read_pickle(os.path.join(ckpt_dir, 'configs.pkl'))
    config['train_args']['epochs'] = RESUME_EPOCHS
    return config


def build_data_module(config, base_dir):
    """Instantiate the data module selected by ``config['data_args']['challenge']``.

    Args:
        config: Run configuration; ``data_args`` is forwarded as keyword
            arguments and ``train_args['batch_size']`` sets the batch size.
        base_dir: Root directory of the dataset.

    Returns:
        A ``FastMRIDataModule`` (4 loader workers) for the ``'multicoil'``
        challenge, otherwise a ``FastMRIMultiLabel`` (8 loader workers).
    """
    data_args = config['data_args']
    batch_size = config['train_args']['batch_size']

    if data_args['challenge'] == 'multicoil':
        return FastMRIDataModule(base_dir,
                                 batch_size=batch_size,
                                 num_data_loader_workers=4,
                                 **data_args)

    return FastMRIMultiLabel(base_dir,
                             batch_size=batch_size,
                             num_data_loader_workers=8,
                             **data_args)


def build_callbacks(config):
    """Create the Trainer callbacks.

    Args:
        config: Run configuration; ``train_args['lr']`` sets the SWA
            learning rate (half of it).

    Returns:
        A list with two ``ModelCheckpoint`` callbacks (best ``val_loss`` and
        best ``Val AUROC``) followed by a ``StochasticWeightAveraging``
        callback.
    """
    # Keep the single best checkpoint according to the validation loss
    best_loss_ckpt = ModelCheckpoint(
        save_top_k=1,
        monitor='val_loss',
        mode='min',
        filename='best_val_loss',
    )

    # Keep the single best checkpoint according to the validation AUROC
    best_auroc_ckpt = ModelCheckpoint(
        save_top_k=1,
        monitor='Val AUROC',
        mode='max',
        filename='best_val_auroc',
    )

    # Stochastic weight averaging
    swa = StochasticWeightAveraging(
        swa_lrs=config['train_args']['lr'] / 2,
        swa_epoch_start=10,
        annealing_epochs=10,
    )

    # NOTE(review): an EarlyStopping(monitor='val_loss', mode='min',
    # patience=10) callback used to be constructed here but was never handed
    # to the Trainer, so early stopping has not been active. The effective
    # behaviour (SWA on, no early stopping) is kept; append the callback to
    # the returned list if early stopping is actually wanted.
    return [best_loss_ckpt, best_auroc_ckpt, swa]


def main(ckpt_dir, use_last_ckpt):
    """Train the multi-label classifier, optionally resuming from a checkpoint.

    Args:
        ckpt_dir: Previous run directory to resume from, or ``None`` /
            ``'None'`` to start a new training.
        use_last_ckpt: Truthy to resume from ``last.ckpt`` instead of
            ``best_val_loss.ckpt``.

    Any exception is allowed to propagate so that the interpreter prints the
    traceback and the process ends with a non-zero exit status.
    """
    ckpt = resolve_ckpt_path(ckpt_dir, use_last_ckpt)
    config = load_config(ckpt_dir, ckpt)

    # Get the directory of the dataset
    base_dir = variables.fastmri_paths[config['data_args']['mri_type']]

    # Get the data
    data = build_data_module(config, base_dir)
    data.prepare_data()
    data.setup()

    # Change the BCE weight based on the imbalance in the data:
    # ratio of negative to positive samples.
    # NOTE(review): a label without any positive training sample would make
    # this a division by zero - confirm the dataset rules that out.
    num_pos_samples = data.train.total_labels
    num_neg_samples = len(data.train) - num_pos_samples
    config['net_args']['bce_weight'] = num_neg_samples / num_pos_samples

    # Add the labels to the configuration (after the data module was built,
    # so they are not forwarded to its constructor on a fresh run)
    config['data_args']['labels'] = data.train.label_names

    # Load the model. torch.compile is intentionally not applied: it doesn't
    # work if there's complex numbers like in fft2c.
    model = helper.load_model(MODEL_TYPE, config, ckpt)

    # Create the tensorboard logger
    Path(variables.log_dir).mkdir(parents=True, exist_ok=True)
    logger = loggers.TensorBoardLogger(variables.log_dir, name=MODEL_TYPE)

    # Create the trainer (limit_train_batches / limit_val_batches can be
    # added here for quick debugging runs)
    trainer = pl.Trainer(
        max_epochs=config['train_args']['epochs'],
        accelerator='gpu',
        logger=logger,
        check_val_every_n_epoch=1,
        callbacks=build_callbacks(config),
        strategy='ddp_find_unused_parameters_true',
    )

    # Save the configurations next to the logs of this run
    # NOTE(review): with the DDP strategy every process executes this
    # script, so each rank may write its own configs.pkl - verify whether
    # this should be restricted to the global-zero rank.
    model_path = trainer.logger.log_dir
    Path(model_path).mkdir(parents=True, exist_ok=True)
    helper.write_pickle(config, os.path.join(model_path, 'configs.pkl'))

    # Train the model; ckpt_path=None is the Trainer.fit default and starts
    # from scratch, so one call covers both the fresh and the resumed run.
    print("Starting Training" if ckpt is None else "Resuming Training")
    trainer.fit(model,
                data.train_dataloader(),
                data.val_dataloader(),
                ckpt_path=ckpt)
    trainer.save_checkpoint(os.path.join(model_path, 'checkpoints', 'last.ckpt'))


if __name__ == "__main__":
    # Failures are deliberately not caught: the default handler prints the
    # traceback and the script exits with a non-zero status, so job
    # schedulers and the DDP launcher can detect a failed run.
    main(load_ckpt_dir, load_last_ckpt)
       
        

