import os
from tqdm import tqdm
import torch
from utils import save_model, save_pred, get_pred_prefix, get_model_prefix, detach_and_clone, collate_list
from configs.supported import process_outputs_functions
#import nni
import globalvariable
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

def run_epoch(algorithm, dataset, general_logger, epoch, config, train, val_dataset=None, 
              test_datasets=None, test=False, best_e=False):
    """Run one pass over ``dataset['loader']`` in training or evaluation mode.

    In training mode every batch goes through ``algorithm.update`` (the epoch
    is passed as well when ``config.algorithm`` contains 'UMIX' or 'BMIX').
    In evaluation mode every batch goes through ``algorithm.evaluate``; when
    ``config.algorithm == "BMIX"`` the per-example values returned by
    ``algorithm.evaluate_uncertainty`` are collected too and plotted as a
    density plot next to the predictions.

    While training, if ``config.evaluate_steps`` is set, a validation pass
    (and a pass over every dataset in ``test_datasets``) is run every
    ``config.evaluate_steps`` batches. The validation pass updates
    ``globalvariable.best_val_metric`` and saves models/predictions.

    Args:
        algorithm: Object exposing train/eval, update/evaluate,
            step_schedulers and the logging hooks used by ``log_results``.
        dataset: Dict describing the split. Keys read here: 'verbose',
            'name', 'loader', 'dataset', 'split', 'eval_logger'.
        general_logger: Object with a ``write(str)`` method.
        epoch: Index of the current epoch.
        config: Run configuration.
        train: True to update the model, False to only evaluate it.
        val_dataset: Dataset dict used for mid-epoch validation (training only).
        test_datasets: Iterable of dataset dicts evaluated after each
            mid-epoch validation (training only).
        test: Forwarded to ``algorithm.evaluate`` as ``test=``.
        best_e: Accepted for backward compatibility; not read by this function.

    Returns:
        Tuple ``(results, epoch_y_pred)``: the metrics returned by
        ``dataset['dataset'].eval`` (with an added 'epoch' entry) and the
        collated predictions of the whole pass.
    """
    if dataset['verbose']:
        general_logger.write(f"\n{dataset['name']}:\n")

    if train:
        algorithm.train()
    else:
        algorithm.eval()
    # NOTE(review): gradients are left enabled for evaluation passes as well,
    # unlike evaluate(), which disables them. Confirm whether some algorithm
    # relies on this before switching it off to save memory.
    torch.set_grad_enabled(True)

    # Not preallocating memory is slower
    # but makes it easier to handle different types of data loaders
    # (which might not return exactly the same number of examples per epoch)
    epoch_y_true = []
    epoch_y_pred = []
    epoch_metadata = []
    epoch_y_uncertainty = []
    epoch_y_group = []

    # Using enumerate(iterator) can sometimes leak memory in some environments (!)
    # so we manually increment batch_idx
    batch_idx = 0
    iterator = tqdm(dataset['loader']) if config.progress_bar else dataset['loader']

    for batch in iterator:
        if train:
            if 'UMIX' in config.algorithm or 'BMIX' in config.algorithm:
                batch_results = algorithm.update(batch, epoch)
            else:
                batch_results = algorithm.update(batch)
        else:
            batch_results = algorithm.evaluate(batch, test=test)
            if config.algorithm == "BMIX":
                batch_weights = algorithm.evaluate_uncertainty(batch)
                epoch_y_uncertainty.append(batch_weights['weight'])
                epoch_y_group.append(batch_weights['g'])
        # These tensors are already detached, but we need to clone them again
        # Otherwise they don't get garbage collected properly in some versions
        # The extra detach is just for safety
        # (they should already be detached in batch_results)
        epoch_y_true.append(detach_and_clone(batch_results['y_true']))
        y_pred = detach_and_clone(batch_results['y_pred'])
        if config.process_outputs_function is not None:
            y_pred = process_outputs_functions[config.process_outputs_function](y_pred)
        epoch_y_pred.append(y_pred)
        epoch_metadata.append(detach_and_clone(batch_results['metadata']))

        if train and (batch_idx + 1) % config.log_every == 0:
            log_results(algorithm, dataset, general_logger, epoch, batch_idx)

        batch_idx += 1

        if train and config.evaluate_steps is not None and batch_idx % config.evaluate_steps == 0:
            # Default used when there is no validation set, so that the test
            # passes below never read an unbound name.
            is_best = False
            if val_dataset is not None:
                # Then run val
                val_results, val_y_pred = run_epoch(algorithm, val_dataset, general_logger, epoch, config, train=False)

                curr_val_metric = val_results[config.val_metric]
                general_logger.write(f'Validation {config.val_metric}: {curr_val_metric:.3f}\n')

                if globalvariable.best_val_metric is None:
                    is_best = True
                elif config.val_metric_decreasing:
                    is_best = curr_val_metric < globalvariable.best_val_metric
                else:
                    is_best = curr_val_metric > globalvariable.best_val_metric
                if is_best:
                    globalvariable.best_val_metric = curr_val_metric
                    general_logger.write(f'Epoch {epoch} has the best validation performance so far.\n')

                save_model_if_needed(algorithm, val_dataset, epoch, config, is_best, globalvariable.best_val_metric)
                save_pred_if_needed(val_y_pred, val_dataset, epoch, config, is_best)
            if test_datasets is not None:
                for test_dataset in test_datasets:
                    run_epoch(algorithm, test_dataset, general_logger, epoch, config,
                              train=False, test=True, best_e=is_best)
            # The nested evaluation passes switched the algorithm to eval mode;
            # restore the training state before the next batch.
            algorithm.train()
            torch.set_grad_enabled(True)

    epoch_y_pred = collate_list(epoch_y_pred)
    epoch_y_true = collate_list(epoch_y_true)
    epoch_metadata = collate_list(epoch_metadata)
    if epoch_y_group:
        # Only populated for BMIX evaluation passes: plot the min-max
        # normalised weights as one density curve per group.
        group_flatten = collate_list(epoch_y_group).detach().cpu().numpy()
        group_weight_flatten = collate_list(epoch_y_uncertainty).detach().cpu().numpy()
        group_weight = (group_weight_flatten-group_weight_flatten.min())/(group_weight_flatten.max() - group_weight_flatten.min())
        df = pd.DataFrame({'Group':group_flatten, 'Weights':group_weight})

        sns.kdeplot(data=df, x="Weights", hue='Group', fill=True, alpha=0.2, palette="tab10")
        plt.savefig(get_pred_prefix(dataset, config)+"_weights.pdf", bbox_inches="tight")
        plt.close()

    results, results_str = dataset['dataset'].eval(
        epoch_y_pred,
        epoch_y_true,
        epoch_metadata)

    if config.scheduler_metric_split == dataset['split']:
        algorithm.step_schedulers(
            is_epoch=True,
            metrics=results,
            log_access=(not train))

    # log after updating the scheduler in case it needs to access the internal logs
    log_results(algorithm, dataset, general_logger, epoch, batch_idx)

    results['epoch'] = epoch
    dataset['eval_logger'].log(results)
    if dataset['verbose']:
        general_logger.write('Epoch eval:\n')
        general_logger.write(results_str)

    return results, epoch_y_pred



def train(algorithm, datasets, general_logger, config, epoch_offset):
    """Train ``algorithm`` from ``epoch_offset`` up to ``config.n_epochs``.

    Each epoch runs a training pass, a validation pass and one evaluation
    pass per additional split. The validation metric ``config.val_metric``
    is compared against ``globalvariable.best_val_metric`` (lower is better
    when ``config.val_metric_decreasing`` is true) to decide whether the
    epoch is the best so far; models and predictions are saved accordingly.

    When ``config.NNI`` is true, the validation metrics (merged with the
    metrics of the last additional split) are reported to NNI after every
    epoch, and the metrics of the best epoch are reported as final result.

    Args:
        algorithm: Algorithm object forwarded to ``run_epoch``.
        datasets: Dict mapping split name to dataset dict; must contain
            'train' and 'val'.
        general_logger: Object with a ``write(str)`` method.
        config: Run configuration.
        epoch_offset: Index of the first epoch to run.
    """
    # The set of extra splits does not change between epochs: resolve it once.
    if config.evaluate_all_splits:
        additional_splits = [split for split in datasets if split not in ('train', 'val')]
    else:
        additional_splits = config.eval_splits

    # Stays None when no epoch runs or none improves on the stored best metric.
    best_nni_metric = None

    for epoch in range(epoch_offset, config.n_epochs):
        general_logger.write('\nEpoch [%d]:\n' % epoch)

        # First run training
        run_epoch(algorithm, datasets['train'], general_logger, epoch, config, train=True,
                  val_dataset=datasets['val'],
                  test_datasets=[datasets[split] for split in additional_splits])

        # Then run val
        val_results, y_pred = run_epoch(algorithm, datasets['val'], general_logger, epoch, config, train=False)
        curr_val_metric = val_results[config.val_metric]
        general_logger.write(f'Validation {config.val_metric}: {curr_val_metric:.3f}\n')

        if globalvariable.best_val_metric is None:
            is_best = True
        elif config.val_metric_decreasing:
            is_best = curr_val_metric < globalvariable.best_val_metric
        else:
            is_best = curr_val_metric > globalvariable.best_val_metric
        if is_best:
            globalvariable.best_val_metric = curr_val_metric
            general_logger.write(f'Epoch {epoch} has the best validation performance so far.\n')

        save_model_if_needed(algorithm, datasets['val'], epoch, config, is_best, globalvariable.best_val_metric)
        save_pred_if_needed(y_pred, datasets['val'], epoch, config, is_best)

        # Metrics of the last additional split; empty when there is none so
        # that the NNI report below never reads an unbound name.
        metric_split = {}
        for split in additional_splits:
            metric_split, y_pred = run_epoch(algorithm, datasets[split], general_logger, epoch,
                                             config, train=False, test=(split == 'test'))
            save_pred_if_needed(y_pred, datasets[split], epoch, config, is_best)

        general_logger.write('\n')

        if config.NNI:
            # Imported lazily: the module-level import is commented out, so
            # nni is only required when NNI reporting is actually requested.
            import nni
            metric = {'default': val_results[config.val_metric]}
            metric.update(val_results)
            metric.update(metric_split)
            nni.report_intermediate_result(metric)
            if is_best:
                best_nni_metric = metric

    if config.NNI and best_nni_metric is not None:
        import nni
        nni.report_final_result(best_nni_metric)


def evaluate(algorithm, datasets, epoch, general_logger, config, is_best):
    """Evaluate ``algorithm`` on the requested splits with gradients disabled.

    A split is evaluated when ``config.evaluate_all_splits`` is true or when
    its name is listed in ``config.eval_splits``. For each evaluated split
    the metrics are logged and, except for the 'train' split, the
    predictions are saved. When ``is_best`` is true and per-example weights
    were collected (``config.algorithm == "BMIX"``), a density plot of the
    min-max normalised weights per group is saved as well.

    Args:
        algorithm: Object exposing ``eval()``, ``evaluate(batch)`` and, for
            BMIX, ``evaluate_uncertainty(batch)``.
        datasets: Dict mapping split name to dataset dict. Keys read per
            split: 'loader', 'dataset', 'eval_logger'.
        epoch: Epoch index stored in the results and used in file names.
        general_logger: Object with a ``write(str)`` method.
        config: Run configuration.
        is_best: Whether the evaluated model is the best one so far.
    """
    algorithm.eval()
    torch.set_grad_enabled(False)
    for split, dataset in datasets.items():
        if (not config.evaluate_all_splits) and (split not in config.eval_splits):
            continue
        epoch_y_true = []
        epoch_y_uncertainty = []
        epoch_y_group = []
        epoch_y_pred = []
        epoch_metadata = []
        iterator = tqdm(dataset['loader']) if config.progress_bar else dataset['loader']
        for batch in iterator:
            batch_results = algorithm.evaluate(batch)
            if config.algorithm == "BMIX":
                batch_weights = algorithm.evaluate_uncertainty(batch)
                epoch_y_uncertainty.append(batch_weights['weight'])
                epoch_y_group.append(batch_weights['g'])
            epoch_y_true.append(detach_and_clone(batch_results['y_true']))
            y_pred = detach_and_clone(batch_results['y_pred'])
            if config.process_outputs_function is not None:
                y_pred = process_outputs_functions[config.process_outputs_function](y_pred)
            epoch_y_pred.append(y_pred)
            epoch_metadata.append(detach_and_clone(batch_results['metadata']))

        epoch_y_pred = collate_list(epoch_y_pred)
        epoch_y_true = collate_list(epoch_y_true)
        epoch_metadata = collate_list(epoch_metadata)
        results, results_str = dataset['dataset'].eval(
            epoch_y_pred,
            epoch_y_true,
            epoch_metadata)

        results['epoch'] = epoch
        dataset['eval_logger'].log(results)
        general_logger.write(f'Eval split {split} at epoch {epoch}:\n')
        general_logger.write(results_str)

        # Skip saving train preds, since the train loader generally shuffles the data
        if split != 'train':
            save_pred_if_needed(epoch_y_pred, dataset, epoch, config, is_best, force_save=True)
        # The weight lists are only filled for BMIX; without the emptiness
        # check the min/max below would fail on a zero-size array.
        if split != 'train' and is_best and epoch_y_group:
            # Collate the same way run_epoch does, so batches of different
            # sizes and tensors living on an accelerator are handled.
            group_flatten = collate_list(epoch_y_group).detach().cpu().numpy().reshape(-1)
            group_weight_flatten = collate_list(epoch_y_uncertainty).detach().cpu().numpy().reshape(-1)
            group_weight = (group_weight_flatten-group_weight_flatten.min())/(group_weight_flatten.max() - group_weight_flatten.min())
            df = pd.DataFrame({'Group':group_flatten, 'Weights':group_weight})

            sns.kdeplot(data=df, x="Weights", hue='Group')
            plt.savefig(get_pred_prefix(dataset, config)+"best_weights.pdf", bbox_inches="tight")
            # Close the figure so the next split starts from empty axes
            # instead of drawing on top of this plot.
            plt.close()
            
def log_results(algorithm, dataset, general_logger, epoch, batch_idx):
    """Flush the algorithm's pending log entry, tagged with epoch and batch.

    Nothing happens when ``algorithm.has_log`` is falsy. Otherwise the entry
    returned by ``algorithm.get_log()`` gets 'epoch' and 'batch' keys, is
    sent to ``dataset['algo_logger']``, is optionally pretty-printed to
    ``general_logger`` (when ``dataset['verbose']`` is truthy), and the
    algorithm's log is reset.
    """
    if not algorithm.has_log:
        return

    entry = algorithm.get_log()
    entry['epoch'] = epoch
    entry['batch'] = batch_idx
    dataset['algo_logger'].log(entry)

    if dataset['verbose']:
        pretty = algorithm.get_pretty_log_str()
        general_logger.write(pretty)

    algorithm.reset_log()


def save_pred_if_needed(y_pred, dataset, epoch, config, is_best, force_save=False):
    """Write ``y_pred`` to disk according to the saving options in ``config``.

    Nothing is written unless ``config.save_pred`` is truthy. Otherwise up
    to three files sharing the prefix from ``get_pred_prefix`` are written:
    a per-epoch file (when ``force_save`` is set or the epoch falls on the
    ``config.save_step`` schedule), a 'last' file (only when ``force_save``
    is not set and ``config.save_last`` is truthy) and a 'best' file (when
    ``config.save_best`` is truthy and ``is_best`` is true).
    """
    if not config.save_pred:
        return

    prefix = get_pred_prefix(dataset, config)

    def _dump(tag):
        # All output files share the same "epoch:<tag>_pred" naming scheme.
        save_pred(y_pred, prefix + f'epoch:{tag}_pred')

    # The schedule is only consulted when the save is not forced.
    if force_save or (config.save_step is not None and (epoch + 1) % config.save_step == 0):
        _dump(epoch)
    if not force_save and config.save_last:
        _dump('last')
    if config.save_best and is_best:
        _dump('best')


def save_model_if_needed(algorithm, dataset, epoch, config, is_best, best_val_metric):
    """Checkpoint ``algorithm`` according to the saving options in ``config``.

    Up to three checkpoints sharing the prefix from ``get_model_prefix`` are
    written: a per-epoch one when the epoch falls on the ``config.save_step``
    schedule, a 'last' one when ``config.save_last`` is truthy, and a 'best'
    one when ``config.save_best`` is truthy and ``is_best`` is true.
    """
    prefix = get_model_prefix(dataset, config)

    def _checkpoint(tag):
        # All checkpoints share the same "epoch:<tag>_model.pth" naming scheme.
        save_model(algorithm, epoch, best_val_metric, prefix + f'epoch:{tag}_model.pth')

    step = config.save_step
    if step is not None and (epoch + 1) % step == 0:
        _checkpoint(epoch)
    if config.save_last:
        _checkpoint('last')
    if config.save_best and is_best:
        _checkpoint('best')
