import warnings
import copy
import datetime
import os
import shutil
import sys
import pathlib
import pandas as pd
from argparse import ArgumentParser

import random
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import yaml
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, StochasticWeightAveraging
from pytorch_lightning.loggers import TensorBoardLogger
from torch.optim.lr_scheduler import CosineAnnealingLR

from lib import fillers, datasets, config
from lib.data.datamodule import SpatioTemporalDataModule
from lib.data.imputation_dataset import ImputationDataset, GraphImputationDataset
from lib.nn import models

from lib.nn.utils.metric_base import MaskedMetric
from lib.nn.utils.metrics import MaskedMAE, MaskedMAPE, MaskedMSE, MaskedMRE
from lib.utils import parser_utils, numpy_metrics, ensure_list, prediction_dataframe, plot_weights
from lib.utils.parser_utils import str_to_bool


def has_graph_support(model_cls):
    """Tell whether ``model_cls`` consumes a graph.

    Graph-aware models are fed a ``GraphImputationDataset`` instead of a plain
    ``ImputationDataset`` by ``run_experiment``.
    """
    graph_models = (models.DRIK,)
    return model_cls in graph_models


def not_nn_model(model_name):
    """Return True for the non-neural baselines ('mean' and 'okriging').

    These models are queried through ``predict`` directly instead of being
    trained with a Lightning ``Trainer``.
    """
    baselines = ('mean', 'okriging')
    return model_name in baselines


def get_model_classes(model_str):
    """Map a model name to its ``(model class, filler class)`` pair.

    Raises:
        ValueError: if ``model_str`` is not a supported model name.
    """
    if model_str != 'drik':
        raise ValueError(f'Model {model_str} not available.')
    return models.DRIK, fillers.GCNDecFiller


def get_dataset(dataset_name, val_rate=0.1, test_rate=0.2, mode="road", test_entries="", eval_sample_strategy='random',
                adj_thr=0.1):
    """Instantiate the benchmark dataset called ``dataset_name``.

    Args:
        dataset_name: ``'aqi36'`` (small AirQuality), ``'aqi437'`` or any other
            ``'aqi*'`` name (full AirQuality), ``'la_point'``, ``'bay_point'``,
            ``'pems07_point'``, ``'nrel_md_point'`` or ``'ushcn'``.
        val_rate: first element of the split list passed as ``p`` / ``p_noise``.
        test_rate: second element of that split list.
        mode: forwarded to the ``MissingValues*`` datasets (not to AirQuality).
        test_entries: only forwarded for ``'la_point'``.
        eval_sample_strategy: only forwarded for ``'la_point'``.
        adj_thr: adjacency threshold forwarded to every dataset.

    Returns:
        The constructed dataset instance.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    splits = [val_rate, test_rate]

    if dataset_name[:3] == 'aqi':
        # Only 'aqi36' selects the small subset; 'aqi437', plain 'aqi' and any
        # other suffix fall back to the full dataset.
        return datasets.AirQuality(impute_nans=True, small=(dataset_name == 'aqi36'), p=splits, adj_thr=adj_thr)

    if dataset_name == 'la_point':
        return datasets.MissingValuesMetrLA(p_fault=0., p_noise=splits, mode=mode, test_entries=test_entries,
                                            eval_sample_strategy=eval_sample_strategy, adj_thr=adj_thr)

    # The remaining datasets share one constructor signature. Classes are looked
    # up by name only after a match, so an unused class is never touched.
    class_names = {
        'bay_point': 'MissingValuesPemsBay',
        'pems07_point': 'MissingValuesPems07',
        'nrel_md_point': 'MissingValuesNrelMd',
        'ushcn': 'MissingValuesUshcn',
    }
    class_name = class_names.get(dataset_name)
    if class_name is None:
        raise ValueError(f"Dataset {dataset_name} not available in this setting.")
    return getattr(datasets, class_name)(p_fault=0., p_noise=splits, mode=mode, adj_thr=adj_thr)


def parse_args():
    """Parse command-line arguments and apply overrides from the YAML config.

    Parsing runs in two passes:

    1. The base options are parsed with ``parse_known_args`` to find ``--model-name``.
    2. The model class, ``SpatioTemporalDataModule`` and ``ImputationDataset`` add
       their own options, then the full command line is parsed.

    If ``--config`` is a non-empty path, every key of that YAML mapping is set on
    the namespace. Config values therefore override both defaults and values
    given on the command line. An empty YAML file means "no overrides".

    Returns:
        argparse.Namespace with all parsed and overridden options.

    Raises:
        ValueError: if the config file does not contain a YAML mapping.
    """
    # Argument parser
    parser = ArgumentParser()
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument("--model-name", type=str, default='drik')
    parser.add_argument("--dataset-name", type=str, default='la_point')
    parser.add_argument("--val-rate", default=0.2, type=float)
    parser.add_argument("--test-rate", default=0.2, type=float)
    parser.add_argument("--mode", default="road", choices=["road"], type=str)
    parser.add_argument('--eval-sample-strategy', type=str, default='random')  # random, region, degree
    parser.add_argument("--test-entries", default="", choices=["", "metr_la_coarse_to_fine.txt", "metr_la_coarse_to_fine_hard.txt", "metr_la_region.txt", "metr_la_region_hard.txt"], type=str)
    parser.add_argument("--config", type=str, default="config/drik/la_point.yaml")
    parser.add_argument("--use-adj-drop", type=str_to_bool, nargs='?', const=True, default=False)
    parser.add_argument("--use-adj-add", type=str_to_bool, nargs='?', const=True, default=False)
    parser.add_argument("--use-init", type=str_to_bool, nargs='?', const=True, default=False)
    parser.add_argument("--pretrained-model", type=str, default="")
    parser.add_argument("--use-subgraph", type=str_to_bool, nargs='?', const=True, default=False)
    # Splitting/aggregation params
    parser.add_argument('--in-sample', type=str_to_bool, nargs='?', const=True, default=False)
    parser.add_argument('--val-len', type=float, default=0.1)
    parser.add_argument('--test-len', type=float, default=0.2)
    parser.add_argument('--aggregate-by', type=str, default='mean')
    # Training params
    parser.add_argument('--lr', type=float, default=0.0002)
    parser.add_argument('--epochs', type=int, default=400)
    parser.add_argument('--patience', type=int, default=80)
    parser.add_argument('--l2-reg', type=float, default=0.)
    parser.add_argument('--scaled-target', type=str_to_bool, nargs='?', const=True, default=True)
    parser.add_argument('--grad-clip-val', type=float, default=1.)
    parser.add_argument('--grad-clip-algorithm', type=str, default='norm')
    parser.add_argument('--loss-fn', type=str, default='l1_loss')
    parser.add_argument('--use-lr-schedule', type=str_to_bool, nargs='?', const=True, default=True)
    parser.add_argument('--whiten-prob', type=float, default=0.05)
    # graph params
    parser.add_argument("--adj-threshold", type=float, default=0.1)
    parser.add_argument("--include-self", type=str_to_bool, nargs='?', const=True, default=True)
    # position encoding params
    parser.add_argument('--pos-enc-mode', type=str, default='none', choices=['none', 'sinusoidal', 'direct', 'fourier'])
    # preemptive
    parser.add_argument('--preemptive', type=str_to_bool, nargs='?', const=True, default=False)
    parser.add_argument('--domain-adaptation', type=str_to_bool, nargs='?', const=True, default=False)
    # meta
    parser.add_argument('--adaptation-steps', type=int, default=1)
    parser.add_argument('--grad-acc-steps', type=int, default=1)
    parser.add_argument('--fast-lr', type=float, default=0.5)
    parser.add_argument('--pretrained-model-meta', type=str, default="")
    parser.add_argument('--semi-maml', type=str_to_bool, nargs='?', const=True, default=False)
    parser.add_argument('--cont-in-batch', type=str_to_bool, nargs='?', const=True, default=False)
    # eerm
    parser.add_argument('--K', type=int, default=5)
    parser.add_argument('--p-remove', type=float, default=0.1)
    parser.add_argument('--p-add', type=float, default=0)
    parser.add_argument('--lr_r', type=float, default=0.001)
    parser.add_argument('--lr_a', type=float, default=0.001)
    # awp
    parser.add_argument('--use-awp', type=str_to_bool, nargs='?', const=True, default=False)
    # swa
    parser.add_argument('--use-swa', type=str_to_bool, nargs='?', const=True, default=False)
    # node perturbation
    parser.add_argument('--use-node-perturbation', type=str_to_bool, nargs='?', const=True, default=False)

    # First pass: only needed to know which model's extra options to register.
    known_args, _ = parser.parse_known_args()
    model_cls, _ = get_model_classes(known_args.model_name)
    parser = model_cls.add_model_specific_args(parser)
    parser = SpatioTemporalDataModule.add_argparse_args(parser)
    parser = ImputationDataset.add_argparse_args(parser)

    args = parser.parse_args()
    # An empty --config (or None) means "no config file".
    if args.config:
        with open(args.config, 'r') as fp:
            # NOTE: FullLoader can build some Python objects; only load config files you trust.
            config_args = yaml.load(fp, Loader=yaml.FullLoader)
        # An empty YAML document loads as None: treat it as no overrides.
        if config_args is None:
            config_args = {}
        if not isinstance(config_args, dict):
            raise ValueError(f"Config file {args.config} must contain a YAML mapping, "
                             f"got {type(config_args).__name__}.")
        for key, value in config_args.items():
            setattr(args, key, value)

    return args


def _build_datamodule(args, dataset, torch_dataset, train_idxs, val_idxs, test_idxs):
    """Create and set up the SpatioTemporalDataModule, picking the scaler per dataset.

    - ``pems07_point``: min-max scaling with fixed bounds ``min=0``, ``max=1500``.
    - ``nrel_al_point`` / ``nrel_md_point``: min-max scaling with ``dataset.min_val``
      as the minimum and ``dataset.capacities`` as the maximum.
    - Every other dataset: the scaling options found in ``args``.
    """
    data_conf = parser_utils.filter_args(args, SpatioTemporalDataModule, return_dict=True)
    setup_kwargs = {}
    if args.dataset_name in ["pems07_point"]:
        data_conf["scaling_type"] = "minmax"
        min_val, max_val = 0, 1500
        print("Min Max Scaler - max: {}".format(max_val))
        setup_kwargs = dict(min=min_val, max=max_val)
    elif args.dataset_name in ["nrel_al_point", "nrel_md_point"]:
        print("Use capacities as Min Max Scaler")
        data_conf["scaling_type"] = "minmax"
        setup_kwargs = dict(min=dataset.min_val, max=dataset.capacities)

    dm = SpatioTemporalDataModule(torch_dataset, train_idxs=train_idxs, val_idxs=val_idxs, test_idxs=test_idxs,
                                  **data_conf)
    dm.setup(**setup_kwargs)
    return dm


def _set_current_adj(filler, state_dict, adj):
    """Replace ``state_dict['model.adj']`` with the current dataset's adjacency.

    This lets a checkpoint trained on another graph be loaded into a model built
    for this dataset. The entry is only written when ``filler`` itself has a
    ``model.adj`` entry, so a strict ``load_state_dict`` never sees an unexpected key.
    Returns the (mutated) ``state_dict``.
    """
    if 'model.adj' in filler.state_dict():
        state_dict['model.adj'] = torch.from_numpy(adj)
    return state_dict


def _load_filler_checkpoint(filler, ckpt):
    """Load ``ckpt['state_dict']`` into ``filler``.

    If the checkpoint also stores ``gl_a_state`` and the filler has a ``gl_a``
    module, that module's state is restored as well.
    """
    filler.load_state_dict(ckpt['state_dict'])
    if 'gl_a_state' in ckpt and filler.gl_a is not None:
        filler.gl_a.load_state_dict(ckpt['gl_a_state'])


def _compute_test_metrics(df_hats, df_true, test_mask):
    """Print and collect masked test metrics for every aggregation method.

    Args:
        df_hats: mapping ``aggregation name -> prediction DataFrame``.
        df_true: ground-truth DataFrame for the test slice.
        test_mask: mask selecting the evaluated entries.

    Returns:
        pd.DataFrame with one row per aggregation method and the columns
        ``aggregate_by``, ``mae``, ``mape``, ``mre``, ``mse``, ``rmse`` and ``r2``.
    """
    metric_fns = {
        'mae': numpy_metrics.masked_mae,
        'mape': numpy_metrics.masked_mape,
        'mre': numpy_metrics.masked_mre,
        'mse': numpy_metrics.masked_mse,
        'r2': numpy_metrics.masked_r2
    }

    results = []
    for aggr_by, df_hat in df_hats.items():
        print(f'- AGGREGATE BY {aggr_by.upper()}')
        result_dict = {"aggregate_by": aggr_by}
        for metric_name, metric_fn in metric_fns.items():
            error = metric_fn(df_hat.values, df_true.values, test_mask).item()
            result_dict[metric_name] = error
            print(f' {metric_name}: {error:.4f}')
            if metric_name == "mse":
                rmse = np.sqrt(error)
                result_dict["rmse"] = rmse
                print(f' rmse: {rmse:.4f}')
        results.append(result_dict)
    return pd.DataFrame(results)


def run_experiment(args):
    """Train (or load) an imputation model and evaluate it on the test split.

    Args:
        args: namespace produced by ``parse_args``. It is deep-copied, so the
            caller's object is never modified (e.g. when ``seed < 0`` a random
            seed is drawn on the copy only).

    Returns:
        Tuple ``(y_true, y_hat, mask)`` for the test split.
        - Neural path: these come from ``filler.predict_loader``, with ``y_hat``
          converted to a NumPy array and its last axis squeezed.
        - Non-NN baselines: they are the raw test values, the model prediction and
          the test mask.

    Side effects:
        - Creates ``<config['logs']>/<dataset>/<model>/<timestamp>_<seed>/``. It
          holds ``config.yaml``, the ``plot_weights`` output, Lightning
          checkpoints, TensorBoard logs and ``metrics.csv``.
        - When ``args.pretrained_model`` is set, that directory is deleted at the
          end and ``metrics.csv`` is written next to the pretrained checkpoint.
    """
    # ---- configuration and seed ----
    args = copy.deepcopy(args)
    if args.seed < 0:
        args.seed = np.random.randint(int(1e9))
    torch.set_num_threads(1)
    pl.seed_everything(args.seed)

    model_cls, filler_cls = get_model_classes(args.model_name)
    dataset = get_dataset(args.dataset_name, args.val_rate, args.test_rate, args.mode, args.test_entries,
                          args.eval_sample_strategy, args.adj_threshold)
    # A non-empty checkpoint path means "evaluate this checkpoint" instead of training.
    use_pretrained = bool(args.pretrained_model)

    # ---- log directory and configuration dump ----
    exp_name = f"{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{args.seed}"
    logdir = os.path.join(config['logs'], args.dataset_name, args.model_name, exp_name)
    # No exist_ok: two runs with the same timestamp and seed must not share a directory.
    pathlib.Path(logdir).mkdir(parents=True)
    with open(os.path.join(logdir, 'config.yaml'), 'w') as fp:
        yaml.dump(parser_utils.config_dict_from_args(args), fp, indent=4, sort_keys=True)

    # ---- data ----
    dataset_cls = GraphImputationDataset if has_graph_support(model_cls) else ImputationDataset
    torch_dataset = dataset_cls(*dataset.numpy(return_idx=True),
                                mask=dataset.training_mask,
                                val_mask=dataset.val_mask,
                                test_mask=dataset.test_mask,
                                window=args.window,
                                stride=args.stride)

    split_conf = parser_utils.filter_function_args(args, dataset.splitter, return_dict=True)
    train_idxs, val_idxs, test_idxs = dataset.splitter(torch_dataset, **split_conf)
    dm = _build_datamodule(args, dataset, torch_dataset, train_idxs, val_idxs, test_idxs)

    adj = dataset.adj

    # Indices of mask columns with at least one non-zero entry (evaluated nodes).
    val_nodes = np.where(dataset.val_mask.sum(axis=0) > 0)[0]
    test_nodes = np.where(dataset.test_mask.sum(axis=0) > 0)[0]

    # Positions are only loaded for these datasets; everything else gets None.
    position = dataset.get_position() if args.dataset_name in ["la_point", "bay_point"] else None

    plot_weights(adj, logdir, pos=position, val_nodes=val_nodes, test_nodes=test_nodes)

    # ---- predictor ----
    if not_nn_model(args.model_name):
        y_true = dataset.df.iloc[dm.test_slice].values
        mask = dataset.test_mask[dm.test_slice]
        # Both baselines ('mean', 'okriging') take the dataset object to access
        # position data and the training mask.
        model = model_cls(args=args, dataset=dataset)
        y_hat = model.predict(y_true * (1 - mask))
        aggr_methods = ensure_list(args.aggregate_by)
        df_hats = [pd.DataFrame(y_hat, index=dataset.df.index[dm.test_slice], columns=dataset.df.columns)]

    else:
        # model's inputs
        additional_model_hparams = dict(adj=adj, d_in=dm.d_in, n_nodes=dm.n_nodes, args=args, position=position)
        model_kwargs = parser_utils.filter_args(args={**vars(args), **additional_model_hparams},
                                                target_cls=model_cls,
                                                return_dict=True)

        # loss and training-time metrics
        loss_fn = MaskedMetric(metric_fn=getattr(F, args.loss_fn),
                               compute_on_step=True,
                               metric_kwargs={'reduction': 'none'})
        train_metrics = {
            'mae': MaskedMAE(compute_on_step=False),
            'mape': MaskedMAPE(compute_on_step=False),
            'mse': MaskedMSE(compute_on_step=False),
            'mre': MaskedMRE(compute_on_step=False)
        }

        # filler's inputs
        scheduler_class = CosineAnnealingLR if args.use_lr_schedule else None
        additional_filler_hparams = dict(model_class=model_cls,
                                         model_kwargs=model_kwargs,
                                         optim_class=torch.optim.Adam,
                                         optim_kwargs={'lr': args.lr,
                                                       'weight_decay': args.l2_reg},
                                         loss_fn=loss_fn,
                                         metrics=train_metrics,
                                         scheduler_class=scheduler_class,
                                         scheduler_kwargs={
                                             'eta_min': vars(args).get('min_lr', 0.0001),
                                             'T_max': args.epochs
                                         })
        filler_kwargs = parser_utils.filter_args(args={**vars(args), **additional_filler_hparams},
                                                 target_cls=filler_cls,
                                                 return_dict=True)
        filler = filler_cls(**filler_kwargs)

        # ---- training ----
        early_stop_callback = EarlyStopping(monitor='val_mae', patience=args.patience, mode='min', verbose=True)
        checkpoint_callback = ModelCheckpoint(dirpath=logdir, save_top_k=1, monitor='val_mae', mode='min')
        callbacks = [checkpoint_callback, early_stop_callback]
        if args.use_swa:
            callbacks.append(StochasticWeightAveraging(
                swa_epoch_start=args.swa_kwargs['start_epoch'],
                swa_lrs=args.swa_kwargs['swa_lr'],
                annealing_epochs=args.swa_kwargs['anneal_epochs']
            ))

        logger = TensorBoardLogger(logdir, name="model")

        trainer_kwargs = dict(
            max_epochs=args.epochs,
            logger=logger,
            default_root_dir=logdir,
            gpus=1 if torch.cuda.is_available() else None,
            callbacks=callbacks,
        )
        # Only pass gradient clipping settings when AWP is not enabled
        if not args.use_awp:
            trainer_kwargs.update({
                "gradient_clip_val": args.grad_clip_val,
                "gradient_clip_algorithm": args.grad_clip_algorithm,
            })
        trainer = pl.Trainer(**trainer_kwargs)

        if not use_pretrained:
            if args.pretrained_model_meta:
                # Warm-start from a meta-learned checkpoint, swapping in this dataset's adjacency.
                meta_state = torch.load(args.pretrained_model_meta, lambda storage, loc: storage)['state_dict']
                filler.load_state_dict(_set_current_adj(filler, meta_state, adj))
            trainer.fit(filler, datamodule=dm)
            if args.use_swa:
                swa_ckpt_path = os.path.join(logdir, 'swa_model.ckpt')
                torch.save({'state_dict': filler.state_dict()}, swa_ckpt_path)
                print(f"SWA model saved to {swa_ckpt_path}")

            trainer.test()
            # Evaluate either the SWA-averaged weights or the best val_mae checkpoint.
            eval_ckpt_path = swa_ckpt_path if args.use_swa else checkpoint_callback.best_model_path
            _load_filler_checkpoint(filler, torch.load(eval_ckpt_path))
        else:
            # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
            ckpt = torch.load(args.pretrained_model, lambda storage, loc: storage)
            # The adjacency must go into the state dict (not the top-level checkpoint
            # dict); otherwise a checkpoint from another dataset keeps its own graph.
            _set_current_adj(filler, ckpt['state_dict'], adj)
            _load_filler_checkpoint(filler, ckpt)

        # ---- testing ----
        filler.freeze()
        filler.eval()
        if torch.cuda.is_available():
            filler.cuda()

        with torch.no_grad():
            y_true, y_hat, mask = filler.predict_loader(dm.test_dataloader(), return_mask=True)
        y_hat = y_hat.detach().squeeze(-1).cpu().numpy()  # reshape to (eventually) squeeze node channels

        # Aggregate overlapping window predictions into one DataFrame per method.
        index = dm.torch_dataset.data_timestamps(dm.testset.indices, flatten=False)['horizon']
        aggr_methods = ensure_list(args.aggregate_by)
        df_hats = prediction_dataframe(y_hat, index, dataset.df.columns, aggregate_by=aggr_methods)

    # ---- evaluation ----
    df_hats = dict(zip(aggr_methods, df_hats))
    test_mask = dataset.test_mask[dm.test_slice]
    df_true = dataset.df.iloc[dm.test_slice]
    results_df = _compute_test_metrics(df_hats, df_true, test_mask)

    if use_pretrained:
        # Evaluation-only run: drop the fresh log directory, store metrics beside the checkpoint.
        shutil.rmtree(logdir)
        results_df.to_csv(os.path.join(os.path.dirname(args.pretrained_model), "metrics.csv"), index=False)
    else:
        results_df.to_csv(os.path.join(logdir, "metrics.csv"), index=False)

    return y_true, y_hat, mask


if __name__ == '__main__':
    # Silence all Python warnings so the training log stays readable.
    warnings.filterwarnings("ignore")
    run_experiment(parse_args())
