import numpy as np
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint#, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

from model.pl_prox_model import plCovNN
from model.custom_callbacks import LoggingCallback

import os, sys
from argparse import ArgumentParser
from datetime import datetime

from data.pl_data import RealDataModule, SyntheticDataModule, PsuedoSyntheticDataModule
from utils.util_funcs import print_subnet_perf_dict, best_subnetwork_at_best_metric


def parse_monitor(monitor):
    """Classify a monitored-metric name into an optimization direction.

    Keyword precedence (first match wins):
      * 'loss' or 'mse'  -> ('minimize', monitor)  -- name returned unchanged
      * 'f1' or 'F1'     -> ('maximize', 'macro_F1')
      * 'mcc'            -> ('maximize', 'mcc')
      * 'acc'            -> ('maximize', 'acc')

    Returns:
        A ``(mode, monitor)`` tuple.

    Raises:
        ValueError: if none of the keywords occur in ``monitor``.
    """
    if any(tag in monitor for tag in ('loss', 'mse')):
        # several loss flavours exist (train_epoch, train_batch, val_epoch, ...); keep the name as given
        return 'minimize', monitor
    if any(tag in monitor for tag in ('f1', 'F1')):
        return 'maximize', 'macro_F1'
    if 'mcc' in monitor:
        return 'maximize', 'mcc'
    if 'acc' in monitor:
        return 'maximize', 'acc'
    raise ValueError(f'Unrecognized Monitor {monitor}')

def construct_model_trainer(cl_parsed_args, dict_args=None, fast_dev_run=False, checkpoint_callback_args=None, load_model_checkpoint_path=None, load_trainer_checkpoint=False, split=None):
    """Build the plCovNN model and the PyTorch Lightning Trainer used to fit it.

    Args:
        cl_parsed_args: parsed CLI namespace. Only relevant when ``dict_args`` is None, which is
            not implemented yet.
        dict_args: ``{'model': kwargs for plCovNN, 'trainer': kwargs for pl.Trainer}``. The
            caller's dicts are not modified.
        fast_dev_run: forwarded to ``pl.Trainer``.
        checkpoint_callback_args: kwargs for ``ModelCheckpoint`` (must contain 'dirpath' and
            'filename'); None disables checkpointing.
        load_model_checkpoint_path: if given, the model is restored with
            ``plCovNN.load_from_checkpoint`` instead of being built from ``dict_args['model']``.
        load_trainer_checkpoint: if True, the trainer resumes from ``load_model_checkpoint_path``.
        split: cross-validation split index; appended to the checkpoint filename so each split
            writes its own checkpoint files.

    Returns:
        ``(model, trainer)``.

    Raises:
        NotImplementedError: if ``dict_args`` is None (construction from CLI args is not implemented).
        ValueError: if ``load_trainer_checkpoint`` is True but ``load_model_checkpoint_path`` is None.
    """
    trainer_callbacks = []

    # construct checkpoint callback
    if checkpoint_callback_args is not None:
        print(f"Saving Checkpoints to {checkpoint_callback_args['dirpath']}")
        if split is not None:
            # Copy so the caller's dict (reused for every split) keeps its base filename.
            checkpoint_callback_args = checkpoint_callback_args.copy()
            checkpoint_callback_args['filename'] = checkpoint_callback_args['filename'] + f"__split={split}"
        trainer_callbacks.append(ModelCheckpoint(**checkpoint_callback_args))

    if dict_args is None:
        # Raise explicitly instead of `assert False`: asserts are stripped under `python -O`,
        # which would let the unfinished CLI construction run.
        raise NotImplementedError('need to impliment functionality for CLI')

    if load_trainer_checkpoint and load_model_checkpoint_path is None:
        raise ValueError('load_trainer_checkpoint=True requires load_model_checkpoint_path')

    model_dict_args = dict_args['model']
    # Shallow copy: the 'logger' entry is replaced below, and callers reuse dict_args across splits.
    trainer_dict_args = dict(dict_args['trainer'])
    # only include a logger if logging is enabled in the trainer args
    if trainer_dict_args.get('logger', False) is not False:
        path2logging = os.getcwd() + '/lightning_logs'
        trainer_dict_args['logger'] = TensorBoardLogger(path2logging, log_graph=False)

    if load_model_checkpoint_path is not None:
        print(f'\n\nLoading model from checkpoint: {load_model_checkpoint_path}\n\n')
        model = plCovNN.load_from_checkpoint(checkpoint_path=load_model_checkpoint_path)
    else:
        model = plCovNN(**model_dict_args)

    trainer_kwargs = dict(callbacks=trainer_callbacks,
                          fast_dev_run=fast_dev_run,
                          num_sanity_val_steps=1)
    if load_trainer_checkpoint:
        print(f'\n\nLoading trainer from checkpoint: {load_model_checkpoint_path}\n\n')
        # Resume with the configured options and callbacks as well. A bare
        # pl.Trainer(resume_from_checkpoint=...) falls back to Lightning defaults and drops
        # max_epochs, gpus, the logger and checkpointing.
        trainer_kwargs['resume_from_checkpoint'] = load_model_checkpoint_path
    trainer = pl.Trainer(**trainer_dict_args, **trainer_kwargs)

    return model, trainer


def train_val_test(datamodule, cl_parsed_args, subnetworks, dict_args=None, checkpoint_callback_args=None, load_model_checkpoint_path=None, load_trainer_checkpoint=False, extra_trainer_callbacks=[], split=None):
    """Fit one model on the datamodule's current train/val split and report its metrics.

    Builds the model and trainer with ``construct_model_trainer``, attaches the datamodule's
    ``subnetwork_masks`` to the model, and calls ``trainer.fit`` on the datamodule's train and
    validation dataloaders.

    NOTE(review): right after fitting, the function raises ``ValueError`` unconditionally. This
    guard was left because the model's prior-metrics attribute changed to ``prior_metrics_val``.
    Everything after the fit is therefore unreachable, and the function never returns normally.

    Intended behaviour of the unreachable part:
      * print the prior metrics;
      * print the metrics of the epochs that are best by 'acc', 'mse' (minimized) and 'mcc' on
        the 'full' subnetwork;
      * print the final epoch's metrics;
      * return ``(prior channel name -> its 'full' metrics, final-epoch 'full' metrics,
        {'acc'|'mse'|'mcc': 'full' metrics of the best epoch for that metric})``.

    Args:
        datamodule: provides ``subnetwork_masks``, ``train_dataloader()`` and ``val_dataloader()``.
        cl_parsed_args, dict_args, checkpoint_callback_args, load_model_checkpoint_path,
        load_trainer_checkpoint, split: forwarded to ``construct_model_trainer``.
        subnetworks: forwarded to ``best_subnetwork_at_best_metric`` (unreachable part).
        extra_trainer_callbacks: accepted but not used by this function.
    """
    model, trainer = construct_model_trainer(cl_parsed_args=cl_parsed_args, dict_args=dict_args,
                                             checkpoint_callback_args=checkpoint_callback_args,
                                             load_model_checkpoint_path=load_model_checkpoint_path,
                                             load_trainer_checkpoint=load_trainer_checkpoint,
                                             split=split)
    model.subnetwork_masks = datamodule.subnetwork_masks
    # NOTE(review): the `train_dataloader=` keyword matches older Lightning 1.x Trainer.fit;
    # confirm it against the installed version.
    trainer.fit(model, train_dataloader=datamodule.train_dataloader(), val_dataloaders=datamodule.val_dataloader())
    # THIS WILL BREAK -> prior_metrics_val now
    # Deliberate stop: all code below this raise is currently unreachable.
    raise ValueError(f'need to check this! changed prior metrics')
    prior_metrics = model.prior_metrics_val
    final_epoch_metrics = model.list_of_metrics[-1]
    sort_subnetwork = 'full'
    best_acc_epoch_metrics = model.best_metrics('acc', sort_subnetwork=sort_subnetwork, top_k=1)[0]
    best_mse_epoch_metrics = model.best_metrics('mse', sort_subnetwork=sort_subnetwork, top_k=1, maximize=False)[0]
    best_mcc_epoch_metrics = model.best_metrics('mcc', sort_subnetwork=sort_subnetwork, top_k=1)[0]

    print(f'\nPrior metrics: {model.prior_construction}')  # prior metrics is a dict over priors (single, single-batched, multi, zeros)
    for i, (prior_channel_name, prior_metrics_) in enumerate(prior_metrics.items()):
        # assumes training_prior_threshold is indexed in the same order as prior_metrics — TODO confirm
        print(f'\tprior: {prior_channel_name} > {model.training_prior_threshold[i].item():.4f}')
        print_subnet_perf_dict(prior_metrics_, indents=2, convert_to_percent=['acc', 'error'])
    print(f'Metrics of epoch which optimize metric _ on subnetwork {sort_subnetwork}')
    print('\tACC'); print_subnet_perf_dict(best_acc_epoch_metrics, indents=2, convert_to_percent=['acc', 'error'])
    print('\tMSE'); print_subnet_perf_dict(best_mse_epoch_metrics, indents=2, convert_to_percent=['acc', 'error'])
    print('\tMCC'); print_subnet_perf_dict(best_mcc_epoch_metrics, indents=2, convert_to_percent=['acc', 'error'])
    print('Final epoch metrics')
    print('\tFinal'); print_subnet_perf_dict(final_epoch_metrics, indents=2, convert_to_percent=['acc', 'error'])

    best_subnetwork_at_best_metric(model, prior_metrics, subnetworks=subnetworks, metrics=['mse', 'mae', 'error', 'mcc'], indents=1)

    # Keep only the 'full' subnetwork's metrics in everything returned.
    return_prior_metrics = {prior_channel_name: subnet_metrics_dict['full'] for prior_channel_name, subnet_metrics_dict in prior_metrics.items()}
    best_epochs_by_metrics = {'acc': best_acc_epoch_metrics['full'], 'mse': best_mse_epoch_metrics['full'], 'mcc': best_mcc_epoch_metrics['full']}
    return return_prior_metrics, final_epoch_metrics['full'],  best_epochs_by_metrics


def cross_validation(datamodule, cl_parsed_args, subnetworks, dict_args=None, num_splits=26, which_splits_to_train=None, rand_seeds=None, checkpoint_callback_args=None, load_model_checkpoint_path=None, load_trainer_checkpoint=False, extra_trainer_callbacks=None, extra_model_callbacks=None):
    """Train on several cross-validation splits and summarize the results by their median.

    For each ``(split, seed)`` pair, the function:
      * switches the datamodule to that split;
      * prints the split's patient and FCS counts;
      * runs ``train_val_test``.

    Per-split results are then reduced with ``np.median`` via ``summary`` / ``prior_summary``.

    Args:
        datamodule: must provide ``set_split(split)`` and ``train_val_splits[split]``. Each entry
            needs the keys 'patients_in_train_fold', 'patients_in_val_fold', 'train_fold_idxs'
            and 'val_fold_idxs'.
        cl_parsed_args, subnetworks, dict_args, checkpoint_callback_args,
        load_model_checkpoint_path, load_trainer_checkpoint: forwarded to ``train_val_test``.
        num_splits: total number of available splits; used only for validation and reporting.
        which_splits_to_train: split indices to train. None means ``[1]``.
        rand_seeds: one seed per trained split. None means ``[50]``.
            NOTE(review): the seeds are paired with splits but are never applied in this
            function. Confirm whether seeding was intended.
        extra_trainer_callbacks: forwarded to ``train_val_test``. None means ``[]``.
        extra_model_callbacks: accepted for API compatibility; unused.

    Returns:
        ``(median prior-metric summaries per prior channel, median final-epoch summary,
        {'acc'|'mse'|'mcc': median summary of the best epochs for that metric})``.

    Raises:
        ValueError: if more splits are requested than exist, or the number of seeds differs
            from the number of requested splits.
    """
    # None sentinels avoid shared mutable default arguments; effective defaults are unchanged.
    if which_splits_to_train is None:
        which_splits_to_train = [1]
    if rand_seeds is None:
        rand_seeds = [50]
    if extra_trainer_callbacks is None:
        extra_trainer_callbacks = []

    if len(which_splits_to_train) > num_splits:
        raise ValueError(f'Requested {len(which_splits_to_train)} splits but only {num_splits} exist')
    if len(which_splits_to_train) != len(rand_seeds):
        raise ValueError(f'Got {len(rand_seeds)} seeds for {len(which_splits_to_train)} splits')

    prior_metrics_list, final_epoch_metrics_list = [], []
    best_epoch_metrics_lists = {'acc': [], 'mse': [], 'mcc': []}
    for split, _rand_seed in zip(which_splits_to_train, rand_seeds):  # seed currently unused
        print(f'========  SPLIT {split}  ============')
        datamodule.set_split(split)
        split_info = datamodule.train_val_splits[split]
        print(f"Patients in Train/Val: {split_info['patients_in_train_fold']}, {split_info['patients_in_val_fold']}")
        print(f"FCS      in Train/Val: {len(split_info['train_fold_idxs'])}, {len(split_info['val_fold_idxs'])}")

        prior_metrics, final_epoch_metrics, best_epochs_by_metrics = \
            train_val_test(datamodule, cl_parsed_args, subnetworks=subnetworks, dict_args=dict_args,
                           checkpoint_callback_args=checkpoint_callback_args,
                           load_model_checkpoint_path=load_model_checkpoint_path,
                           load_trainer_checkpoint=load_trainer_checkpoint,
                           split=split,
                           extra_trainer_callbacks=extra_trainer_callbacks)
        prior_metrics_list.append(prior_metrics)
        final_epoch_metrics_list.append(final_epoch_metrics)
        for metric_name, per_split_metrics in best_epoch_metrics_lists.items():
            per_split_metrics.append(best_epochs_by_metrics[metric_name])
        print(f'========^^^  SPLIT {split}  ^^^============\n\n')

    print(f'========  MEDIAN OVER SPLITS  ============')
    prior_summaries = prior_summary(prior_metrics_list=prior_metrics_list, func=np.median)
    final_epoch_summaries = summary(metrics_list=final_epoch_metrics_list, func=np.median)
    best_epochs_by_metrics_summaries = {
        metric_name: summary(metrics_list=per_split_metrics, func=np.median)
        for metric_name, per_split_metrics in best_epoch_metrics_lists.items()
    }

    print(f'Returning mean/median across {len(which_splits_to_train)} splits, of {num_splits} total')
    return prior_summaries, final_epoch_summaries, best_epochs_by_metrics_summaries


def summary(metrics_list, func):
    """Reduce each tracked metric across runs with ``func`` (e.g. ``np.median``).

    Args:
        metrics_list: list of per-run dicts; each must contain the keys 'mse', 'macro_F1',
            'acc' and 'mcc'. Extra keys are ignored.
        func: reducer applied to the 1-D float array of one metric's values across runs.

    Returns:
        ``{'mse': ..., 'macro_F1': ..., 'acc': ..., 'mcc': ...}`` with the reduced values.
    """
    tracked = ('mse', 'macro_F1', 'acc', 'mcc')
    columns = {name: np.zeros(len(metrics_list)) for name in tracked}

    for run_idx, run_metrics in enumerate(metrics_list):
        for name in tracked:
            columns[name][run_idx] = run_metrics[name]

    return {name: func(values) for name, values in columns.items()}


def prior_summary(prior_metrics_list, func):
    """Summarize metrics per prior channel across runs.

    Args:
        prior_metrics_list: list of per-run dicts mapping prior channel name -> metrics dict.
            The channel names are taken from the first run, so the list must be non-empty.
        func: reducer forwarded to ``summary``.

    Returns:
        dict mapping each channel name -> ``summary`` of that channel's metrics over all runs.
    """
    channel_names = list(prior_metrics_list[0].keys())
    return {
        channel: summary([run[channel] for run in prior_metrics_list], func)
        for channel in channel_names
    }


def which_dm(w: str):
    """Return the datamodule class for an experiment name.

    Args:
        w: one of 'synthetics', 'pseudo-synthetics' or 'real'.

    Returns:
        The datamodule class itself, not an instance.

    Raises:
        AssertionError: if ``w`` is not one of the names above. If asserts are disabled
            (``python -O``), ValueError is raised instead.
    """
    assert w in ['synthetics', 'pseudo-synthetics', 'real']
    candidates = (
        ('synthetics', SyntheticDataModule),
        ('pseudo-synthetics', PsuedoSyntheticDataModule),
        ('real', RealDataModule),
    )
    for name, dm_class in candidates:
        if w == name:
            return dm_class
    raise ValueError(f'unrecognized argument which {w}')


# must specify num_splits/num_splits_train or num_patients_val/test, and logging as CLI
def run_experiment_cli(which='real', subnetworks=None):
    """Parse command-line arguments, then run cross-validation or a single train/val fit.

    Args:
        which: experiment name passed to ``which_dm`` ('synthetics', 'pseudo-synthetics', 'real').
        subnetworks: subnetwork names forwarded to the training routines. None means ``['full']``,
            matching ``run_experiment_manual``.

    Returns:
        Whatever ``cross_validation`` (when ``--num_splits`` is given) or ``train_val_test`` returns.

    NOTE(review): ``construct_model_trainer`` does not implement construction from CLI args
    (``dict_args=None``), so this entry point currently stops there.
    """
    if subnetworks is None:
        subnetworks = ['full']

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = plCovNN.add_model_specific_args(parser)

    dm_class = which_dm(which)
    # The argument hook lives on the datamodule class; `which_dm` is only a lookup function.
    # Assumes each datamodule class defines add_module_specific_args — TODO confirm.
    parser = dm_class.add_module_specific_args(parser)
    args = parser.parse_args()
    # Remove every argument whose value is None so the model's own defaults are used for
    # anything not supplied on the CLI. Iterate over a copy because attributes are deleted.
    for name, value in dict(vars(args)).items():
        if value is None:
            delattr(args, name)

    # construct datamodule
    dm = dm_class.from_argparse_args(args)
    dm.setup('fit')

    # None-valued args were deleted above, so `num_splits` may be absent entirely.
    num_splits = getattr(args, 'num_splits', None)
    if num_splits is not None:
        return cross_validation(datamodule=dm,
                                cl_parsed_args=args,
                                subnetworks=subnetworks,
                                num_splits=num_splits,
                                which_splits_to_train=list(range(args.num_splits_train)),
                                rand_seeds=np.random.randint(low=-10000, high=10000, size=args.num_splits_train)
                                )
    return train_val_test(datamodule=dm,
                          cl_parsed_args=args,
                          subnetworks=subnetworks,
                          dict_args=None
                          )


# customize inputs to this function
def run_experiment_manual(model_args=None, dm_args=None, trainer_args=None,
                          datamodule=None,
                          extra_trainer_callbacks=None,
                          checkpoint_callback_args=None,
                          load_model_checkpoint_path=None,
                          load_trainer_checkpoint=False,
                          num_splits_train=None,
                          which_exp='real',
                          subnetworks=None):
    """Run an experiment from explicit argument dicts instead of the CLI.

    Args:
        model_args: kwargs for plCovNN. None means ``{}``.
        dm_args: kwargs for the datamodule class; used only when ``datamodule`` is None.
            None means ``{}``.
        trainer_args: kwargs for pl.Trainer. None means ``{}``.
        datamodule: an already-built datamodule; if None, one is built from
            ``which_dm(which_exp)(**dm_args)``. ``setup('fit')`` is called on it either way.
        extra_trainer_callbacks: forwarded to the training routine. None means ``[]``.
        checkpoint_callback_args, load_model_checkpoint_path, load_trainer_checkpoint:
            forwarded to the training routine.
        num_splits_train: number of splits (0..num_splits_train-1) to train in cross-validation.
            Required when the datamodule defines ``num_splits``.
        which_exp: experiment name ('synthetics', 'pseudo-synthetics', 'real').
        subnetworks: subnetwork names forwarded to the training routine. None means ``['full']``.

    Returns:
        The result of ``cross_validation`` if the datamodule defines splits (and is not
        synthetic), otherwise the result of ``train_val_test``.

    Raises:
        ValueError: if cross-validation is selected but ``num_splits_train`` is None.
    """
    # None sentinels avoid shared mutable default arguments.
    model_args = {} if model_args is None else model_args
    dm_args = {} if dm_args is None else dm_args
    trainer_args = {} if trainer_args is None else trainer_args
    extra_trainer_callbacks = [] if extra_trainer_callbacks is None else extra_trainer_callbacks
    subnetworks = ['full'] if subnetworks is None else subnetworks

    dict_args = {'model': model_args, 'trainer': trainer_args}

    # create dm with dict
    if datamodule is None:
        datamodule = which_dm(which_exp)(**dm_args)
    datamodule.setup('fit')

    # Synthetics datamodule does not have cross_validation supported
    if (which_exp != 'synthetics') and datamodule.num_splits is not None:
        if num_splits_train is None:
            raise ValueError('num_splits_train must be given when the datamodule defines num_splits')
        # note checkpoints need to be integrated into this. Currently only support for one split??
        return cross_validation(datamodule=datamodule,
                                cl_parsed_args=None,
                                dict_args=dict_args,
                                checkpoint_callback_args=checkpoint_callback_args,
                                load_model_checkpoint_path=load_model_checkpoint_path,
                                load_trainer_checkpoint=load_trainer_checkpoint,
                                num_splits=datamodule.num_splits,
                                which_splits_to_train=list(range(num_splits_train)),
                                rand_seeds=np.random.randint(low=-10000, high=10000, size=num_splits_train),
                                extra_trainer_callbacks=extra_trainer_callbacks,
                                subnetworks=subnetworks
                                )

    # Single train/val run. These keyword arguments match train_val_test's signature.
    return train_val_test(datamodule=datamodule,
                          cl_parsed_args=None,
                          dict_args=dict_args,
                          checkpoint_callback_args=checkpoint_callback_args,
                          load_model_checkpoint_path=load_model_checkpoint_path,
                          load_trainer_checkpoint=load_trainer_checkpoint,
                          extra_trainer_callbacks=extra_trainer_callbacks,
                          subnetworks=subnetworks
                          )


## checkpoint stuff
def make_checkpoint_run_folder(path2checkpoints, path2run):
    """Create the checkpoints root folder and this run's folder if they do not exist.

    Each folder is created with a single ``os.mkdir`` (parents are not created). A folder that
    already exists is reported and left untouched, so repeated calls are harmless.

    Args:
        path2checkpoints: folder holding all runs' checkpoints.
        path2run: this run's folder; its parent must exist (normally ``path2checkpoints``).

    Raises:
        FileNotFoundError: if a parent directory of either path does not exist.
    """
    try:
        os.mkdir(path=path2checkpoints)
        print("Creating checkpoints folder")
    except FileExistsError:
        print(f"Checkpoints folder {path2checkpoints} already exists")

    try:
        os.mkdir(path=path2run)
        print(f"Creating folder for this run: {path2run}", flush=True)
    except FileExistsError:
        print(f"This run {path2run} already exists")


def _checkpoint_mode(monitor):
    """Return the ModelCheckpoint ``mode`` ('min' or 'max') suited to ``monitor``.

    Keyword precedence matches ``parse_monitor``:
      * names containing 'loss' or 'mse' are minimized;
      * otherwise names containing 'f1'/'F1'/'mcc'/'acc' are maximized;
      * anything else (e.g. error, mae), or a None monitor, is minimized.
    """
    if monitor is None or 'loss' in monitor or 'mse' in monitor:
        return 'min'
    if any(key in monitor for key in ('f1', 'F1', 'mcc', 'acc')):
        return 'max'
    return 'min'


def _checkpoint_filename(task, loss, which_exp, rand_seed, timestamp=None):
    """Build the ModelCheckpoint filename template for one run.

    The template embeds ``{epoch}`` and the logged validation metrics
    ('val/full/{error,mcc,mse,mae}'). For regression tasks it also embeds the training-loss
    metric 'train/{loss}_epoch'. Lightning fills these placeholders when saving.
    The seed and a ``MM-DD_HH:MM:SS`` timestamp (``datetime.now()`` if not given) are appended.
    'real' and 'pseudo' experiments are prefixed with ``which_exp``.
    """
    if timestamp is None:
        timestamp = datetime.now()
    filename = f"{task}_loss_{loss}_" + "epoch{epoch:05d}"
    for metric_name in ('error', 'mcc', 'mse', 'mae'):
        filename += f"_{metric_name}" + "{" + f"val/full/{metric_name}" + ":.7f}"
    if 'regress' in task:
        filename += f"_Train{loss}" + "{" + f"train/{loss}_epoch" + ":.7f}"
    # strftime always gives MM-DD_HH:MM:SS. Slicing str(datetime.now()) truncated the time
    # whenever microsecond == 0, because str() then omits the fractional part.
    filename += f"_seed{rand_seed}_" + 'date&time' + timestamp.strftime("%m-%d_%H:%M:%S")

    if 'real' in which_exp or 'pseudo' in which_exp:
        filename = f"{which_exp}-" + filename
    return filename


def make_checkpoint_callback_dict(path2currDir, monitor, task, loss, which_exp, rand_seed, trainer_args, run_directory, mode=None):
    """Create the checkpoint folders for a run and return ``ModelCheckpoint`` kwargs.

    Folders ``<path2currDir>checkpoints/`` and ``<path2currDir>checkpoints/<run_directory>`` are
    created if missing. ``path2currDir`` is concatenated as-is, so it must end with a path separator.

    Args:
        path2currDir: base directory (must not contain '=').
        monitor: logged metric name the checkpoint callback tracks.
        task, loss, which_exp, rand_seed: used to build the checkpoint filename.
        trainer_args: accepted for API compatibility; unused.
        run_directory: sub-folder name for this run (must not contain '=').
        mode: 'min' or 'max'. None infers it from ``monitor``: loss/mse-like names are
            minimized, f1/mcc/acc-like names are maximized, anything else is minimized.

    Returns:
        dict of ``ModelCheckpoint`` keyword arguments. Callers append a per-split suffix to
        'filename'.
    """
    # '=' in paths/filenames breaks checkpoint loading
    assert ("=" not in path2currDir), f'= in path2CurrDir {path2currDir} not allowed!'
    assert ("=" not in run_directory), f'= in run_directory {run_directory} not allowed!'

    path2Allcheckpoints = path2currDir + 'checkpoints/'
    path2run = path2Allcheckpoints + run_directory

    # make this directory if doesnt exist yet
    make_checkpoint_run_folder(path2checkpoints=path2Allcheckpoints, path2run=path2run)

    filename = _checkpoint_filename(task, loss, which_exp, rand_seed)
    assert '=' not in filename, f" '=' seems to mess up loading"

    checkpoint_callback_args = {'monitor': monitor,
                                'dirpath': path2run,
                                'verbose': True,
                                'save_last': False,
                                'save_top_k': 1,
                                'auto_insert_metric_name': False,
                                'filename': filename,  # <= this will be changed for each split!
                                # Previously always 'min', which kept the worst checkpoint for
                                # maximized metrics such as MCC.
                                'mode': _checkpoint_mode(monitor) if mode is None else mode,
                                'save_on_train_epoch_end': False}

    return checkpoint_callback_args


# Private sentinel: distinguishes "argument not given" from an explicit None.
_THRESHOLD_POINTS_UNSET = object()


# common base model args to be used by many trainers
def make_model_dict(channels,
                    poly_fc_orders,
                    optimizer,
                    learning_rate,
                    share_parameters,
                    monitor,
                    loss,
                    include_nonlinearity,
                    rand_seed,
                    hinge_margin,
                    n_train,
                    momentum=0.9,
                    prior_construction='mean',
                    threshold_metric_test_points=_THRESHOLD_POINTS_UNSET,
                    learn_tau=True,
                    where_normalize_slices='after_reduction'):
    """Assemble the plCovNN keyword arguments shared by the experiment scripts.

    Most arguments are copied into the dict unchanged. ``loss`` is stored under 'which_loss'.
    'poly_basis' ('cheb'), 'real_network' ('full') and 'logging' ('only_scalars') are fixed.

    Args:
        threshold_metric_test_points: values stored under 'threshold_metric_test_points'.
            If omitted, a fresh array of 0-0.06 step 0.02, 0.09-0.149 step 0.001 and
            0.15-0.45 step 0.15 is built for each call.

    Returns:
        dict of model keyword arguments.
    """
    if threshold_metric_test_points is _THRESHOLD_POINTS_UNSET:
        # Built per call rather than as a default-argument value, so returned dicts never
        # share (and cannot mutate) one module-level array.
        threshold_metric_test_points = np.concatenate(
            (np.arange(0, .07, .02), np.arange(.09, .15, .001), np.arange(.15, .6, .15)), axis=0)

    model_args = {
        'channels': channels,
        'poly_fc_orders': poly_fc_orders,
        'where_normalize_slices': where_normalize_slices,
        'poly_basis': 'cheb',
        'optimizer': optimizer,
        'learning_rate': learning_rate,  # .1 mse (>.5 produces unstable training), .3 hinge
        'momentum': momentum,
        'share_parameters': share_parameters,
        'real_network': 'full',
        'threshold_metric_test_points': threshold_metric_test_points,
        'monitor': monitor,
        'which_loss': loss,
        'logging': 'only_scalars',
        'learn_tau': learn_tau,
        'hinge_margin': hinge_margin,
        'include_nonlinearity': include_nonlinearity,
        'n_train': n_train,
        'prior_construction': prior_construction,  # 'single' 'block' 'single_grouped' 'multi', None
        'rand_seed': rand_seed}

    return model_args

# common base trainer args to be used by many trainers
def make_trainer_dict(max_epochs=10000, check_val_every_n_epoch=10, logger=False):
    """Return baseline pl.Trainer keyword arguments.

    Uses one GPU when CUDA is available (and then also sets 32-bit precision), otherwise none.
    """
    cuda_available = torch.cuda.is_available()
    trainer_args = dict(max_epochs=max_epochs,
                        gpus=1 if cuda_available else 0,
                        logger=logger,
                        check_val_every_n_epoch=check_val_every_n_epoch)
    if cuda_available:
        trainer_args.update(precision=32)
    return trainer_args

if __name__ == '__main__':
    # Running this module directly only prints a marker; its functions are meant to be imported.
    print('trainer')
