from torch.utils.data import DataLoader
from utils import dataset_class, save_data
from Models.model import Encoder_factory, count_parameters
from thop import profile, clever_format
from Models.loss import get_loss_module
from Models.utils import load_model, get_representation
from trainer import *
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Module-wide logger; it is looked up under the fixed name '__main__' rather than this
# module's own name.
# NOTE(review): `logging` (like `os` and `torch` used further down) is not imported
# explicitly in this file; presumably it arrives through `from trainer import *` - TODO confirm.
logger = logging.getLogger('__main__')


# Problems whose splits arrive in ``Data`` as ready-made dataset objects; every other
# problem is wrapped with ``dataset_class`` from raw arrays.
_PREBUILT_DATASET_PROBLEMS = ('TUAB', 'TUEV', 'CHB-MIT')
# Parameter-name substrings that are exempt from weight decay (plain substring match).
_NO_DECAY_NAME_PARTS = ('bias', 'layer_norm', 'ln')
# Parameter-name substrings that are the only ones given a non-zero learning rate during
# linear probing and that receive the un-scaled ``lr2`` during fine-tuning.
_HEAD_NAME_PARTS = ('pred_head', 'CLS', 'attention_pool')


def _name_has_any(name, parts):
    """Return True if any string in *parts* occurs as a substring of *name*."""
    return any(part in name for part in parts)


def _make_loader(dataset, config, shuffle):
    """Wrap *dataset* in a DataLoader using ``batch_size``/``num_workers`` from *config*."""
    return DataLoader(dataset=dataset, batch_size=config['batch_size'], shuffle=shuffle,
                      pin_memory=True, num_workers=config['num_workers'])


def _pretrain_param_groups(model, config):
    """One optimizer group per parameter; names matching the no-decay list get weight_decay 0."""
    return [
        {'params': [param],
         'weight_decay': 0 if _name_has_any(name, _NO_DECAY_NAME_PARTS) else config['weight_decay']}
        for name, param in model.named_parameters()
    ]


def _linear_probe_param_groups(model, config):
    """Head parameters train at ``lr2``; all other parameters are held still with lr 0."""
    groups = []
    for name, param in model.named_parameters():
        if _name_has_any(name, _HEAD_NAME_PARTS):
            groups.append({'params': param, 'lr': config['lr2'], 'weight_decay': config['weight_decay']})
        else:
            # lr 0 and no decay: the parameter stays in the optimizer but is never moved.
            groups.append({'params': param, 'lr': 0., 'weight_decay': 0.})
    return groups


def _fine_tune_param_groups(model, config):
    """Head parameters train at ``lr2``, the rest at ``lr2 * lr_ratio``; frozen ones are skipped."""
    groups = []
    for name, param in model.named_parameters():
        # only update the parameters that are not frozen
        if not param.requires_grad:
            continue
        # not using decay in bias & ln
        weight_decay = 0.0 if _name_has_any(name, _NO_DECAY_NAME_PARTS) else config['weight_decay']
        if _name_has_any(name, _HEAD_NAME_PARTS):
            lr = config['lr2']
        else:
            lr = config['lr2'] * config['lr_ratio']
        groups.append({'params': param, 'lr': lr, 'weight_decay': weight_decay})
    return groups


def _fit_and_test(model, config, train_loader, val_loader, test_loader, checkpoint_path):
    """Run ``Strain_runner`` on *model*, reload *checkpoint_path*, and evaluate on the test loader.

    Uses ``config['optimizer']`` as it is set by the caller. Returns the two values produced by
    ``SupervisedTrainer.evaluate(keep_all=True)`` on the reloaded model.
    """
    trainer = SupervisedTrainer(model, None, train_loader, None, config, print_conf_mat=False)
    val_evaluator = SupervisedTrainer(model, None, val_loader, None, config, print_conf_mat=False)
    Strain_runner(config, model, trainer, val_evaluator, checkpoint_path)

    best_model, _, _ = load_model(model, checkpoint_path, config['optimizer'])
    best_model.to(config['device'])

    test_evaluator = SupervisedTrainer(best_model, None, test_loader, None, config, print_conf_mat=True)
    aggr_metrics, all_metrics = test_evaluator.evaluate(keep_all=True)
    return aggr_metrics, all_metrics


def Rep_Learning(config, Data):
    """Self-supervised pre-training followed by linear probing and full fine-tuning.

    Stages, in order:
      1. Pre-train a fresh encoder with ``Self_Supervised_Trainer`` / ``SS_train_runner``.
      2. Rebuild the encoder with ``config['dropout2']`` as dropout, load the pre-trained
         checkpoint, and train with every non-head parameter at learning rate 0
         (linear probing); the test summary of this stage is printed.
      3. Reload the pre-trained checkpoint and fine-tune all trainable parameters.

    Args:
        config: dict of settings. Read here: 'problem', 'patch_size', 'batch_size',
            'num_workers', 'optimizer_class', 'lr', 'lr2', 'lr_ratio', 'weight_decay',
            'dropout2', 'save_dir', 'device'. Written here (the dict is mutated in place):
            'Data_shape', 'num_labels', 'optimizer', 'problem_type', 'loss_module', 'dropout'.
        Data: dict of data splits. For 'TUAB', 'TUEV' and 'CHB-MIT' the keys
            'pretrain_loader', 'train_loader', 'eval_loader' and 'test_loader' are used
            (despite the names they are wrapped in a DataLoader here, so they are treated
            as datasets). For every other problem 'All_train_data', 'All_train_label',
            'coherence_labels', 'train_data', 'train_label', 'val_data', 'val_label',
            'test_data' and 'test_label' are used.

    Returns:
        ``(best_aggr_metrics_test, all_metrics)`` from evaluating the fine-tuned model's
        reloaded checkpoint on the test split.
    """
    prebuilt = config['problem'] in _PREBUILT_DATASET_PROBLEMS

    # ---------------------------------------- Self Supervised Data -------------------------------------
    if not prebuilt:
        pre_train_dataset = dataset_class(Data['All_train_data'], Data['All_train_label'], config['patch_size'], Data['coherence_labels'])
        train_dataset = dataset_class(Data['All_train_data'], Data['All_train_label'], config['patch_size'], Data['coherence_labels'])
        test_dataset = dataset_class(Data['test_data'], Data['test_label'], config['patch_size'])
        config['Data_shape'] = Data['All_train_data'].shape
        config['num_labels'] = int(max(Data['All_train_label'])) + 1
    else:
        pre_train_dataset = Data['pretrain_loader']
        train_dataset = Data['train_loader']
        test_dataset = Data['test_loader']
        # The first element of a sample is taken to be the (channels, time_steps) signal.
        channel_size, time_steps = pre_train_dataset[0][0].shape
        config['Data_shape'] = (len(pre_train_dataset), channel_size, time_steps)
        if config['problem'] == 'TUAB' or config['problem'] == 'CHB-MIT':
            config['num_labels'] = 1
        elif config['problem'] == 'TUEV':
            config['num_labels'] = 6
        logger.info("{} samples will be used for self-supervised training".format(config['Data_shape'][0]))
        logger.info("{} samples will be used for fine tuning ".format(len(train_dataset)))
        logger.info("{} samples will be used for test".format(len(test_dataset)))
        logger.info("Each sample has {} channels, {} time steps ".format(channel_size, time_steps))

    pre_train_loader = _make_loader(pre_train_dataset, config, shuffle=True)
    train_loader = _make_loader(train_dataset, config, shuffle=True)
    # For Linear Probing During the Pre-Training.
    # Evaluation data is not shuffled: ordering has no effect on training and a fixed
    # order keeps per-sample evaluation output reproducible.
    test_loader = _make_loader(test_dataset, config, shuffle=False)

    # -------------------------------------------- Build Model -----------------------------------------------------
    logger.info("Pre-Training Self Supervised model ...")

    Encoder = Encoder_factory(config)
    dummy_input = torch.randn(1, config['Data_shape'][1], config['Data_shape'][2])
    macs, profiled_params = profile(Encoder, inputs=(dummy_input,))
    macs, profiled_params = clever_format([macs, profiled_params], "%.3f")
    logger.info("Model:\n{}".format(Encoder))
    logger.info("Total number of profile macs: {}, profile parameters: {}, parameters: {}".format(macs, profiled_params, count_parameters(Encoder)))

    # ---------------------------------------------- Model Initialization ----------------------------------------------
    optim_class = get_optimizer(config['optimizer_class'])
    config['optimizer'] = optim_class(_pretrain_param_groups(Encoder, config), lr=config['lr'])
    config['problem_type'] = 'Self-Supervised'
    config['loss_module'] = get_loss_module(config['problem'])

    # NOTE(review): unlike the two later checkpoints this file name has no '_' between the
    # problem name and 'model'; it is kept unchanged so existing checkpoints still resolve.
    save_path = os.path.join(config['save_dir'], config['problem'] +'model_{}.pth'.format('ss'))
    Encoder.to(config['device'])

    # ------------------------------------------------- Training The Model ---------------------------------------------
    logger.info('Self-Supervised training...')
    SS_trainer = Self_Supervised_Trainer(Encoder, pre_train_loader, train_loader, test_loader, config, l2_reg=0, print_conf_mat=False)
    SS_train_runner(config, Encoder, SS_trainer, save_path)

    # **************************************************************************************************************** #
    # --------------------------------------------- Downstream Task (classification)   ---------------------------------
    # A fresh encoder is built with the downstream dropout and filled from the pre-training checkpoint.
    config['dropout'] = config['dropout2']
    Encoder = Encoder_factory(config)
    Encoder.to(config['device'])
    SS_Encoder, _, _ = load_model(Encoder, save_path, config['optimizer'])  # Loading the model
    SS_Encoder.to(config['device'])

    # --------------------------------- Load Data -------------------------------------------------------------
    if not prebuilt:
        train_dataset = dataset_class(Data['train_data'], Data['train_label'], config['patch_size'])
        val_dataset = dataset_class(Data['val_data'], Data['val_label'], config['patch_size'])
        test_dataset = dataset_class(Data['test_data'], Data['test_label'], config['patch_size'])
    else:
        train_dataset = Data['train_loader']
        val_dataset = Data['eval_loader']
        test_dataset = Data['test_loader']

    train_loader = _make_loader(train_dataset, config, shuffle=True)
    # Validation / test loaders keep a fixed order (see the note on the pre-training test loader).
    val_loader = _make_loader(val_dataset, config, shuffle=False)
    test_loader = _make_loader(test_dataset, config, shuffle=False)

    # ------------------------------------------------ Linear Probing --------------------------------------------------
    logger.info('Starting Linear Probing...')
    # The groups are taken from the model that is actually handed to the trainer, so the
    # optimizer and the trained model can never refer to different parameter sets.
    config['optimizer'] = optim_class(_linear_probe_param_groups(SS_Encoder, config))
    save_path2 = os.path.join(config['save_dir'], config['problem'] + '_model_{}.pth'.format('lp'))
    lp_metrics, _ = _fit_and_test(SS_Encoder, config, train_loader, val_loader, test_loader, save_path2)

    print_str = 'Linear Probe Best Model Test Summary: ' + ''.join(
        '{}: {} | '.format(k, v) for k, v in lp_metrics.items())
    print(print_str)

    # ------------------------------------------------- Fine Tuning ----------------------------------------------------
    # Start again from the pre-trained weights rather than from the linear-probe result.
    SS_Encoder, _, _ = load_model(Encoder, save_path, config['optimizer'])  # Loading the model
    SS_Encoder.to(config['device'])

    logger.info('Starting Fine_Tuning...')
    config['optimizer'] = optim_class(_fine_tune_param_groups(SS_Encoder, config))
    save_path3 = os.path.join(config['save_dir'], config['problem'] + '_model_{}.pth'.format('s'))
    best_aggr_metrics_test, all_metrics = _fit_and_test(
        SS_Encoder, config, train_loader, val_loader, test_loader, save_path3)
    return best_aggr_metrics_test, all_metrics


def Supervised(config, Data):
    """Train the encoder with labels only, then evaluate the best checkpoint on the test split.

    If ``config['pretrained_model']`` is a non-empty string, the encoder is first filled from
    that checkpoint and parameters outside the head get the reduced rate
    ``lr2 * lr_ratio``; otherwise every trainable parameter uses ``lr2``.

    Args:
        config: dict of settings. Read here: 'problem', 'patch_size', 'batch_size',
            'num_workers', 'optimizer_class', 'lr2', 'lr_ratio' (only with a pretrained
            model), 'weight_decay', 'dropout2', 'pretrained_model', 'save_dir', 'device'.
            Written here (the dict is mutated in place): 'problem_type', 'loss_module',
            'Data_shape', 'num_labels', 'dropout', 'optimizer'.
        Data: dict of data splits. For 'TUAB', 'TUEV' and 'CHB-MIT' the keys 'train_loader',
            'eval_loader' and 'test_loader' are used (they are wrapped in a DataLoader here,
            so they are treated as datasets). For every other problem 'train_data',
            'train_label', 'val_data', 'val_label', 'test_data' and 'test_label' are used.

    Returns:
        ``(best_aggr_metrics_test, all_metrics)`` from ``SupervisedTrainer.evaluate(keep_all=True)``
        run on the reloaded best checkpoint with the test split.
    """
    config['problem_type'] = 'Supervised'
    config['loss_module'] = get_loss_module(config['problem'])

    # --------------------------------- Load Data -------------------------------------------------------------
    if config['problem'] not in ('TUAB', 'TUEV', 'CHB-MIT'):
        train_dataset = dataset_class(Data['train_data'], Data['train_label'], config['patch_size'])
        val_dataset = dataset_class(Data['val_data'], Data['val_label'], config['patch_size'])
        test_dataset = dataset_class(Data['test_data'], Data['test_label'], config['patch_size'])
        config['Data_shape'] = Data['train_data'].shape
        config['num_labels'] = int(max(Data['train_label'])) + 1
    else:
        train_dataset = Data['train_loader']
        val_dataset = Data['eval_loader']
        test_dataset = Data['test_loader']
        # The first element of a sample is taken to be the (channels, time_steps) signal.
        channel_size, time_steps = train_dataset[0][0].shape
        config['Data_shape'] = (len(train_dataset), channel_size, time_steps)
        if config['problem'] == 'TUAB' or config['problem'] == 'CHB-MIT':
            config['num_labels'] = 1
        elif config['problem'] == 'TUEV':
            config['num_labels'] = 6

    # -------------------------------------------- Build Model -----------------------------------------------------
    config['dropout'] = config['dropout2']
    Encoder = Encoder_factory(config)
    dummy_input = torch.randn(1, config['Data_shape'][1], config['Data_shape'][2])
    macs, profiled_params = profile(Encoder, inputs=(dummy_input,))
    macs, profiled_params = clever_format([macs, profiled_params], "%.3f")
    logger.info("Model:\n{}".format(Encoder))
    logger.info("Total number of profile macs: {}, profile parameters: {}, parameters: {}".format(macs, profiled_params, count_parameters(Encoder)))
    Encoder.to(config['device'])

    optim_class = get_optimizer(config['optimizer_class'])
    use_pretrained = config['pretrained_model'] != ''
    if use_pretrained:
        print('Found pretrained model for Supervised Learning.')
        # NOTE(review): this two-argument call uses the result directly as the model, while
        # the three-argument call further down unpacks (model, optimizer, start_epoch) -
        # confirm load_model returns a bare model when no optimizer is passed.
        Encoder = load_model(Encoder, config['pretrained_model'])
        Encoder.to(config['device'])

    # ---------------------------------------------- Optimizer groups --------------------------------------------------
    no_decay_parts = ('bias', 'layer_norm', 'ln')
    head_parts = ('pred_head', 'CLS', 'attention_pool')
    param_groups = []
    for name, param in Encoder.named_parameters():
        # only update the parameters that are not frozen
        if not param.requires_grad:
            continue
        # not using decay in bias & ln
        if any(part in name for part in no_decay_parts):
            weight_decay = 0.0
        else:
            weight_decay = config['weight_decay']
        # With a pretrained encoder, everything outside the head is trained at a scaled rate;
        # when training from scratch all parameters share lr2.
        if use_pretrained and not any(part in name for part in head_parts):
            lr = config['lr2'] * config['lr_ratio']
        else:
            lr = config['lr2']
        param_groups.append({'params': param, 'lr': lr, 'weight_decay': weight_decay})
    config['optimizer'] = optim_class(param_groups)

    # ------------------------------------------------- Training The Model ---------------------------------------------
    train_loader = DataLoader(dataset=train_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True, num_workers=config['num_workers'])
    # Evaluation data is not shuffled: ordering has no effect on training and a fixed order
    # keeps per-sample evaluation output reproducible.
    val_loader = DataLoader(dataset=val_dataset, batch_size=config['batch_size'], shuffle=False, pin_memory=True, num_workers=config['num_workers'])
    test_loader = DataLoader(dataset=test_dataset, batch_size=config['batch_size'], shuffle=False, pin_memory=True, num_workers=config['num_workers'])

    S_trainer = SupervisedTrainer(Encoder, None, train_loader, None, config, print_conf_mat=False)
    S_val_evaluator = SupervisedTrainer(Encoder, None, val_loader, None, config, print_conf_mat=False)

    save_path = os.path.join(config['save_dir'], config['problem'] + '_2_model_{}.pth'.format('supervised'))
    Strain_runner(config, Encoder, S_trainer, S_val_evaluator, save_path)

    # Evaluate the checkpoint written to save_path, not the weights left in memory after the last epoch.
    best_Encoder, _, _ = load_model(Encoder, save_path, config['optimizer'])
    best_Encoder.to(config['device'])

    best_test_evaluator = SupervisedTrainer(best_Encoder, None, test_loader, None, config, print_conf_mat=True)
    best_aggr_metrics_test, all_metrics = best_test_evaluator.evaluate(keep_all=True)

    return best_aggr_metrics_test, all_metrics