import os
import logging
import torch
import numpy as np
import pandas as pd
from collections import OrderedDict
import time
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torch import nn
from itertools import product
import matplotlib.pyplot as plt
import torch.nn.functional as F
from eval import fit_lr, fit_svm, make_representation
from sklearn.metrics import accuracy_score, precision_recall_curve, roc_curve, auc
from Models.loss import l2_reg_loss
from Models import utils, analysis
from Models.optimizers import get_optimizer
import torch.distributed as dist
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import sklearn
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from utils import EarlyStopping
from pyhealth.metrics import binary_metrics_fn, multiclass_metrics_fn
from Models.connectivity import get_pseudolabels
from Models.loss import get_loss_module

# Module logger. It is looked up under the name '__main__', presumably so that
# messages go through the logger configured by the entry-point script.
logger = logging.getLogger('__main__')

# Names of metrics where a smaller value is an improvement; any metric not
# listed here is compared as "higher is better" by validate().
NEG_METRICS = {'loss'}  # metrics for which "better" is less


class BaseTrainer(object):
    """Common state and helpers shared by the trainers in this module.

    Stores the model, the data loaders and the hyper-parameters read from
    ``config``, and builds the learning-rate scheduler. Subclasses must
    override :meth:`train_epoch` and :meth:`evaluate`.

    Args:
        model: network to train / evaluate.
        pre_train_loader: stored as ``self.pre_train_loader``.
        train_loader: stored as ``self.train_loader``.
        test_loader: stored as ``self.test_loader``.
        config: mapping; the keys read here are ``problem``, ``device``,
            ``epochs``, ``label_type``, ``mixed``, ``temperature``,
            ``patch_size``, ``sampling_rate``, ``mask_ratio``,
            ``chunk_size``, ``warmup_epochs``, ``optimizer``,
            ``loss_module`` and ``output_dir``.
        optimizer: accepted for interface compatibility but not used; the
            optimizer is always taken from ``config['optimizer']``.
        l2_reg: stored as ``self.l2_reg``; not used by this class.
        print_interval: stored as ``self.print_interval``.
        console: forwarded to ``utils.Printer``.
        print_conf_mat: stored as ``self.print_conf_mat``.
    """

    def __init__(self, model, pre_train_loader, train_loader, test_loader, config, optimizer=None, l2_reg=None, print_interval=10,
                 console=True, print_conf_mat=False):
        self.model = model
        self.pre_train_loader = pre_train_loader
        self.train_loader = train_loader
        self.test_loader = test_loader
        # These two problems are scored with binary metrics; all others with
        # multi-class metrics.
        self.use_binary_metric = config['problem'] in ('TUAB', 'CHB-MIT')
        self.device = config['device']
        self.epochs = config['epochs']
        self.label_type = config['label_type']
        self.mixed = config['mixed']
        self.temperature = config['temperature']
        self.patch_size = config['patch_size']
        self.sampling_rate = config['sampling_rate']
        self.mask_ratio = config['mask_ratio']
        self.chunk_size = config['chunk_size']
        self.warmup_epochs = config['warmup_epochs']
        self.optimizer = config['optimizer']
        # Must come after self.optimizer / self.epochs: lr_scheduler reads both.
        self.scheduler = self.lr_scheduler()
        self.loss_module = config['loss_module']
        self.l2_reg = l2_reg
        self.print_interval = print_interval
        self.printer = utils.Printer(console=console)
        self.print_conf_mat = print_conf_mat
        self.epoch_metrics = OrderedDict()
        self.save_path = config['output_dir']

    def lr_scheduler(self):
        """Return a cosine-annealing LR scheduler spanning ``self.epochs`` steps."""
        return CosineAnnealingLR(self.optimizer, T_max=self.epochs)

    def train_epoch(self, epoch_num=None):
        """Run one training epoch; must be implemented by subclasses."""
        raise NotImplementedError('Please override in child class')

    def evaluate(self, epoch_num=None, keep_all=True):
        """Run one evaluation pass; must be implemented by subclasses."""
        raise NotImplementedError('Please override in child class')

    def print_callback(self, i_batch, metrics, prefix=''):
        """Print a one-line progress report for the current batch.

        Args:
            i_batch: index of the batch just processed.
            metrics: mapping of metric name to numeric value, appended to the
                line in iteration order.
            prefix: text prepended to the line.
        """
        # No attribute named `dataloader` is assigned in __init__, so honour it
        # only if a subclass provides one and otherwise use the training loader.
        loader = getattr(self, 'dataloader', None)
        if loader is None:
            loader = self.train_loader
        total_batches = len(loader)
        # Avoid ZeroDivisionError when the loader is empty.
        percent = 100 * (i_batch / total_batches) if total_batches else 0.0

        template = "{:5.1f}% | batch: {:9d} of {:9d}"
        content = [percent, i_batch, total_batches]
        for met_name, met_value in metrics.items():
            template += "\t|\t{}".format(met_name) + ": {:g}"
            content.append(met_value)

        self.printer.print(prefix + template.format(*content))


class Self_Supervised_Trainer(BaseTrainer):
    """Trainer for the self-supervised pre-training stage.

    Each batch is passed through ``model.pretrain_forward`` and the returned
    predictions for the masked positions are trained with cross-entropy
    against pseudo-labels produced by ``get_pseudolabels``. Every fifth epoch
    a logistic-regression probe is fitted on frozen representations and its
    test accuracy is appended to ``<output_dir>/linear_result.txt``.
    """

    def __init__(self, *args, **kwargs):
        super(Self_Supervised_Trainer, self).__init__(*args, **kwargs)
        # Read the flag resolved by BaseTrainer.__init__ rather than
        # kwargs['print_conf_mat'], which raises KeyError when the flag is
        # left at its default or passed positionally.
        if self.print_conf_mat:
            self.analyzer = analysis.Analyzer(print_conf_mat=True)
        self.mse = nn.MSELoss(reduction='none')
        self.gap = nn.AdaptiveAvgPool1d(1)
        self.criterion = nn.CrossEntropyLoss(reduction='mean', label_smoothing=0.0)
        # self.scheduler is already created by BaseTrainer.__init__; building a
        # second scheduler on the same optimizer here would be redundant.

    def calculate_rec_loss(self, rec, target):
        """Return the mean cosine distance (1 - cosine similarity) along the last dim.

        ``F.normalize`` clamps the norm at a small epsilon, so all-zero
        vectors give a finite loss instead of NaN.
        """
        target = F.normalize(target, dim=-1)
        rec = F.normalize(rec, dim=-1)
        return (1 - (target * rec).sum(-1)).mean()

    def train_epoch(self, epoch_num=None):
        """Run one pre-training epoch over ``self.pre_train_loader``.

        Args:
            epoch_num: epoch index, recorded in the metrics. The linear probe
                runs when ``(epoch_num + 1) % 5 == 0``; it is skipped when
                ``epoch_num`` is None.

        Returns:
            ``(epoch_metrics, model)`` where ``epoch_metrics`` holds ``epoch``,
            ``loss`` (mean batch loss, NaN if the loader yielded nothing) and
            ``coherence`` (loss of the final batch, as a float).
        """
        self.model = self.model.train()
        epoch_loss = 0.0  # running sum of per-batch losses
        total_batches = 0
        last_coherence = float('nan')
        for X, _targets, _ids in self.pre_train_loader:
            # assumes X is (batch, channels, time) -- TODO confirm
            X = X.to(self.device)
            rep_mask_prediction, mask = self.model.pretrain_forward(X)
            label = get_pseudolabels(X, self.patch_size, is_mixed=self.mixed, label_type=self.label_type,
                                     sampling_rate=self.sampling_rate, temperature=self.temperature)
            # Only the masked positions contribute to the loss.
            coherence_loss = self.criterion(rep_mask_prediction, label[mask])

            self.optimizer.zero_grad()
            coherence_loss.backward()
            self.optimizer.step()

            # Keep a plain float: storing the tensor would retain the autograd
            # graph and write a tensor repr into the result file.
            last_coherence = coherence_loss.item()
            total_batches += 1
            epoch_loss += last_coherence

        self.scheduler.step()
        # NaN (rather than 0) for an empty loader so it can never look like a best loss.
        epoch_loss = epoch_loss / total_batches if total_batches else float('nan')
        self.epoch_metrics['epoch'] = epoch_num
        self.epoch_metrics['loss'] = epoch_loss
        self.epoch_metrics['coherence'] = last_coherence
        if epoch_num is not None and (epoch_num + 1) % 5 == 0:
            self._linear_probe(epoch_num, last_coherence)

        return self.epoch_metrics, self.model

    def _linear_probe(self, epoch_num, coherence):
        """Fit a logistic-regression probe on frozen features and log its test accuracy."""
        self.model.eval()
        with torch.no_grad():  # representations are only read, never back-propagated
            train_repr, train_labels = make_representation(self.model, self.train_loader)
            test_repr, test_labels = make_representation(self.model, self.test_loader)
        clf = fit_lr(train_repr.cpu().detach().numpy(), train_labels.cpu().detach().numpy(), MAX_SAMPLES=500000)
        y_hat = clf.predict(test_repr.cpu().detach().numpy())
        acc_test = accuracy_score(test_labels.cpu().detach().numpy(), y_hat)
        print('Test_acc:', acc_test)
        # One "epoch, accuracy, coherence" line is appended per probe run.
        with open(os.path.join(self.save_path, 'linear_result.txt'), 'a+') as result_file:
            print('{0}, {1}, {2}'.format(int(epoch_num), acc_test, coherence), file=result_file)

def SS_train_runner(config, model, trainer, path):
    """Drive self-supervised pre-training for ``config['epochs']`` epochs.

    After every epoch the training metrics are written to TensorBoard and to
    the module logger, and ``utils.SaveBestModel`` is called with the epoch
    loss (subject to the checkpoint rule below).

    Args:
        config: mapping; reads ``epochs``, ``optimizer``, ``loss_module`` and
            ``tensorboard_dir``.
        model: initial model; replaced each epoch by the model returned from
            ``trainer.train_epoch``.
        trainer: object whose ``train_epoch(epoch)`` returns
            ``(metrics_dict, model)``.
        path: forwarded to the ``SaveBestModel`` callable.
    """
    epochs = config['epochs']
    optimizer = config['optimizer']
    loss_module = config['loss_module']
    start_epoch = 0
    # Runs longer than this many epochs do not checkpoint during their first
    # `checkpoint_after` epochs; shorter runs checkpoint from the first epoch.
    checkpoint_after = 50
    total_start_time = time.time()
    tensorboard_writer = SummaryWriter(config['tensorboard_dir'])
    save_best_model = utils.SaveBestModel()
    try:
        for epoch in tqdm(range(start_epoch + 1, epochs + 1), desc='Training Epoch', leave=False):
            aggr_metrics_train, model = trainer.train_epoch(epoch)  # dictionary of aggregate epoch metrics
            print_str = 'Epoch {} Training Summary: '.format(epoch)
            for k, v in aggr_metrics_train.items():
                tensorboard_writer.add_scalar('{}/train'.format(k), v, epoch)
                print_str += '{}: {:8f} | '.format(k, v)
            logger.info(print_str)
            if epoch > checkpoint_after or epochs <= checkpoint_after:
                save_best_model(aggr_metrics_train['loss'], epoch, model, optimizer, loss_module, path)
    finally:
        # Flush and release the event file even if training is interrupted.
        tensorboard_writer.close()
    total_runtime = time.time() - total_start_time
    logger.info("Train Time: {} hours, {} minutes, {} seconds\n".format(*utils.readable_time(total_runtime)))


class SupervisedTrainer(BaseTrainer):
    """Trainer / evaluator for the supervised stage.

    ``train_epoch`` optimises ``self.loss_module`` over ``self.train_loader``;
    ``evaluate`` scores the model over the same ``self.train_loader``
    attribute and computes classification metrics.
    """

    def __init__(self, *args, **kwargs):
        super(SupervisedTrainer, self).__init__(*args, **kwargs)
        # Read the flag resolved by BaseTrainer.__init__ rather than
        # kwargs['print_conf_mat'], which raises KeyError when the flag is
        # left at its default or passed positionally.
        self.analyzer = analysis.Analyzer(print_conf_mat=bool(self.print_conf_mat),
                                          use_binary_metric=self.use_binary_metric, device=self.device)

    def train_epoch(self, epoch_num=None):
        """Run one supervised training epoch.

        Args:
            epoch_num: epoch index, recorded in the returned metrics.

        Returns:
            ``self.epoch_metrics`` with ``epoch`` and ``loss`` (mean batch
            loss, NaN if the loader yielded nothing) updated.
        """
        self.model = self.model.train()
        epoch_loss = 0.0  # running sum of per-batch losses
        total_batches = 0
        for X, targets, _ids in self.train_loader:
            X = X.to(self.device)
            targets = targets.to(self.device)
            predictions = self.model(X)
            total_loss = self.loss_module(predictions, targets)

            # Zero gradients, perform a backward pass, and update the weights.
            self.optimizer.zero_grad()
            total_loss.backward()
            # Clip the global gradient norm before the optimizer step.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=4.0)
            self.optimizer.step()

            total_batches += 1
            epoch_loss += total_loss.item()
        self.scheduler.step()

        # NaN (rather than 0) for an empty loader so it can never look like a best loss.
        epoch_loss = epoch_loss / total_batches if total_batches else float('nan')
        self.epoch_metrics['epoch'] = epoch_num
        self.epoch_metrics['loss'] = epoch_loss
        return self.epoch_metrics

    def evaluate(self, epoch_num=None, keep_all=True):
        """Score the model and compute classification metrics.

        NOTE(review): this iterates ``self.train_loader``, not
        ``self.test_loader`` -- presumably evaluator instances are built with
        the validation data in the ``train_loader`` slot; confirm at the call
        site.

        Args:
            epoch_num: epoch index, recorded in the returned metrics.
            keep_all: accepted for interface compatibility; not used.

        Returns:
            ``(epoch_metrics, metrics_dict)`` where ``metrics_dict`` is the
            raw result from the analyzer.

        Raises:
            ValueError: if the loader yields no batches.
        """
        self.model = self.model.eval()

        epoch_loss = 0.0  # running sum of per-batch losses
        total_batches = 0
        all_targets, all_predictions = [], []
        # No gradients are needed here; safe even if the caller already disabled them.
        with torch.no_grad():
            for X, targets, _ids in self.train_loader:
                X = X.to(self.device)
                targets = targets.to(self.device)
                predictions = self.model(X)
                loss = self.loss_module(predictions, targets)

                all_targets.append(targets.cpu().numpy())
                all_predictions.append(predictions.detach().cpu().numpy())
                total_batches += 1
                epoch_loss += loss.item()

        if total_batches == 0:
            raise ValueError('evaluate() received an empty data loader')

        epoch_loss /= total_batches  # mean of the per-batch losses
        self.epoch_metrics['epoch'] = epoch_num
        self.epoch_metrics['loss'] = epoch_loss

        predictions = torch.from_numpy(np.concatenate(all_predictions, axis=0))
        targets = np.concatenate(all_targets, axis=0).flatten()

        if self.use_binary_metric:
            metrics_dict = self.analyzer.analyze_binary(predictions.squeeze(), torch.from_numpy(targets))
            self.epoch_metrics['accuracy'] = metrics_dict['total_accuracy']
            self.epoch_metrics['f1'] = metrics_dict['f1']
            self.epoch_metrics['AUROC'] = metrics_dict['auroc']
            # Model outputs are passed through a sigmoid before the balanced-accuracy call.
            result = binary_metrics_fn(targets, torch.sigmoid(predictions.squeeze()).cpu().numpy(),
                                       metrics=["balanced_accuracy"])
            self.epoch_metrics['B-accuracy'] = result['balanced_accuracy']
        else:
            # (total_samples, num_classes) estimated probability per class and sample
            probs = torch.nn.functional.softmax(predictions, dim=1)
            predictions = torch.argmax(probs, dim=1).cpu().numpy()  # (total_samples,) predicted class index
            probs = probs.cpu().numpy()
            class_names = np.arange(probs.shape[1])
            metrics_dict = self.analyzer.analyze_classification(predictions, targets, class_names)
            self.epoch_metrics['accuracy'] = metrics_dict['total_accuracy']
            self.epoch_metrics['precision'] = metrics_dict['prec_avg']
            # ROC / PR curves need a single score column, so they are computed
            # only when the model has exactly two output classes.
            if probs.shape[1] == 2:
                false_pos_rate, true_pos_rate, _ = roc_curve(targets, probs[:, 1])  # 1D scores needed
                self.epoch_metrics['AUROC'] = auc(false_pos_rate, true_pos_rate)

                prec, rec, _ = precision_recall_curve(targets, probs[:, 1])
                self.epoch_metrics['AUPRC'] = auc(rec, prec)
            result = multiclass_metrics_fn(targets, probs, metrics=['balanced_accuracy', 'cohen_kappa', 'f1_weighted'])
            self.epoch_metrics['B-accuracy'] = result['balanced_accuracy']
            self.epoch_metrics['cohen_kappa'] = result['cohen_kappa']
            self.epoch_metrics['f1_weighted'] = result['f1_weighted']

        return self.epoch_metrics, metrics_dict


def validate(val_evaluator, tensorboard_writer, config, best_metrics, best_value, epoch):
    """Evaluate once, log every metric, and keep track of the best result so far.

    The metric named by ``config['key_metric']`` is compared against
    ``best_value``: lower wins for names in ``NEG_METRICS``, higher wins
    otherwise. On an improvement the evaluator's model is saved as
    ``model_best.pth`` under ``config['save_dir']``.

    Returns:
        ``(aggr_metrics, best_metrics, best_value)`` with the last two updated
        when the key metric improved.
    """
    with torch.no_grad():
        # The second element returned by evaluate() is not used here.
        aggr_metrics, _unused = val_evaluator.evaluate(epoch, keep_all=True)

    print()
    summary_parts = ['Validation Summary: ']
    for name, value in aggr_metrics.items():
        tensorboard_writer.add_scalar('{}/val'.format(name), value, epoch)
        summary_parts.append('{}: {:8f} | '.format(name, value))
    logger.info(''.join(summary_parts))

    key_metric = config['key_metric']
    lower_is_better = key_metric in NEG_METRICS
    current = aggr_metrics[key_metric]
    improved = current < best_value if lower_is_better else current > best_value
    if not improved:
        return aggr_metrics, best_metrics, best_value

    # New best: persist the model and snapshot the metrics (a copy, since the
    # evaluator may reuse the same dict on the next call).
    checkpoint_file = os.path.join(config['save_dir'], 'model_best.pth')
    utils.save_model(checkpoint_file, epoch, val_evaluator.model)
    return aggr_metrics, aggr_metrics.copy(), current


def Strain_runner(config, model, trainer, evaluator, path):
    """Drive supervised training with per-epoch validation and early stopping.

    Each epoch: train, validate (which logs metrics and saves
    ``model_best.pth`` when ``config['key_metric']`` improves), pass the
    validation loss to ``utils.SaveBestModel``, log the training metrics, and
    stop early once ``EarlyStopping`` signals it.

    Args:
        config: mapping; reads ``epochs``, ``optimizer``, ``loss_module``,
            ``tensorboard_dir``, ``patience`` and ``key_metric`` (plus the
            keys ``validate`` reads).
        model: model forwarded to the ``SaveBestModel`` callable.
        trainer: object whose ``train_epoch(epoch)`` returns a metrics dict.
        evaluator: object passed to ``validate`` as ``val_evaluator``.
        path: forwarded to the ``SaveBestModel`` callable.
    """
    epochs = config['epochs']
    optimizer = config['optimizer']
    loss_module = config['loss_module']
    start_epoch = 0
    total_start_time = time.time()
    tensorboard_writer = SummaryWriter(config['tensorboard_dir'])
    # Start from the worst possible value for the metric's direction. A fixed
    # 1e16 would never be exceeded by a higher-is-better metric, so validate()
    # would never record an improvement or save model_best.pth.
    best_value = 1e16 if config['key_metric'] in NEG_METRICS else -1e16
    best_metrics = {}
    save_best_model = utils.SaveBestModel()
    early_stopping = EarlyStopping(patience=config['patience'], verbose=True)

    try:
        for epoch in tqdm(range(start_epoch + 1, epochs + 1), desc='Training Epoch', leave=False):
            aggr_metrics_train = trainer.train_epoch(epoch)  # dictionary of aggregate epoch metrics
            aggr_metrics_val, best_metrics, best_value = validate(evaluator, tensorboard_writer, config,
                                                                  best_metrics, best_value, epoch)
            save_best_model(aggr_metrics_val['loss'], epoch, model, optimizer, loss_module, path)

            print_str = 'Epoch {} Training Summary: '.format(epoch)
            for k, v in aggr_metrics_train.items():
                tensorboard_writer.add_scalar('{}/train'.format(k), v, epoch)
                print_str += '{}: {:8f} | '.format(k, v)
            logger.info(print_str)
            early_stopping(aggr_metrics_val['loss'])
            if early_stopping.early_stop:
                break
    finally:
        # Flush and release the event file even if training is interrupted.
        tensorboard_writer.close()
    total_runtime = time.time() - total_start_time
    logger.info("Train Time: {} hours, {} minutes, {} seconds\n".format(*utils.readable_time(total_runtime)))
