from functools import partial

# from pathlib import Path
from typing import Any, Dict, Union
from copy import deepcopy

import ignite
import torch
import pdb
import csv
from jarvis.core.atoms import Atoms

from ignite.contrib.handlers import TensorboardLogger
try:
    from ignite.contrib.handlers.stores import EpochOutputStore
except Exception as exp:
    from ignite.handlers.stores import EpochOutputStore

    pass
from ignite.handlers import EarlyStopping
from ignite.contrib.handlers.tensorboard_logger import (
    global_step_from_engine,
)
from ignite.contrib.handlers.tqdm_logger import ProgressBar
from ignite.engine import (
    Events,
    create_supervised_evaluator,
    create_supervised_trainer,
)
from ignite.contrib.metrics import ROC_AUC, RocCurve
from ignite.metrics import (
    Accuracy,
    Precision,
    Recall,
    ConfusionMatrix,
)
import pickle as pk
import numpy as np
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
from ignite.engine import Engine
from ignite.metrics import Loss, MeanAbsoluteError
from torch import nn
import models
from data import get_train_val_loaders
from config import TrainingConfig
from models.pyg_att import Matformer
from models.property_predictor import PropertyPredictor


import torch.nn.utils as nn_utils


from jarvis.db.jsonutils import dumpjson
import json
import pprint
from tqdm import tqdm
import os


# torch config: make float32 the default dtype for newly created tensors
torch.set_default_dtype(torch.float32)

# torch.use_deterministic_algorithms(True)

# Use the CUDA device when one is available; otherwise keep the plain "cpu"
# string (both forms are accepted by ``.to(...)``).
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"


def activated_output_transform(output):
    """Exponentiate the predictions and keep only column 1.

    ``output`` is a ``(y_pred, y)`` pair. ``y_pred`` is exponentiated
    element-wise and its second column (index 1) is returned together with
    the untouched ``y``.
    """
    raw_scores, labels = output
    return torch.exp(raw_scores)[:, 1], labels


def make_standard_scalar_and_pca(output):
    """Apply the pickled scaler to both predictions and targets.

    The scaler is read from ``sc.pkl`` inside the module-level
    ``tmp_output_dir`` (assigned by ``train_dgl``). Despite the function
    name, only ``sc.transform`` is applied here; no separate PCA step is
    performed.

    Args:
        output: ``(y_pred, y)`` pair of tensors.

    Returns:
        ``(y_pred, y)`` after ``sc.transform``, as new tensors placed on the
        module-level ``device``.
    """
    scaler_path = os.path.join(tmp_output_dir, "sc.pkl")
    # NOTE: unpickling can execute arbitrary code; only load scaler files
    # that this pipeline produced itself.
    with open(scaler_path, "rb") as scaler_file:
        sc = pk.load(scaler_file)
    y_pred, y = output
    y_pred = torch.tensor(sc.transform(y_pred.cpu().numpy()), device=device)
    y = torch.tensor(sc.transform(y.cpu().numpy()), device=device)
    return y_pred, y


def thresholded_output_transform(output):
    """Exponentiate the predictions and round them to the nearest integer.

    ``output`` is a ``(y_pred, y)`` pair; ``y`` is passed through unchanged.
    """
    raw_scores, labels = output
    rounded = torch.exp(raw_scores).round()
    return rounded, labels


def group_decay(model):
    """Split parameters into weight-decayed and decay-exempt groups.

    A parameter is exempt from weight decay when its name contains
    ``"bias"``, ``"bn"`` or ``"norm"``. The result is a two-element list of
    optimizer param groups: the first uses the optimizer's default weight
    decay, the second forces ``weight_decay`` to 0.
    """
    exempt_markers = ("bias", "bn", "norm")
    with_decay = []
    without_decay = []

    for param_name, param in model.named_parameters():
        is_exempt = any(marker in param_name for marker in exempt_markers)
        bucket = without_decay if is_exempt else with_decay
        bucket.append(param)

    decayed_group = dict(params=with_decay)
    exempt_group = dict(params=without_decay, weight_decay=0)
    return [decayed_group, exempt_group]


def setup_optimizer(params, config: Union["TrainingConfig", Dict[str, Any]]):
    """Build the optimizer selected by ``config['optimizer']``.

    Args:
        params: Parameters or param groups handed to the torch optimizer.
        config: Mapping-style configuration; the keys ``'optimizer'``,
            ``'learning_rate'`` and ``'weight_decay'`` are read.

    Returns:
        ``torch.optim.AdamW`` for ``"adamw"`` or ``torch.optim.SGD``
        (momentum 0.9) for ``"sgd"``.

    Raises:
        ValueError: If ``config['optimizer']`` is neither ``"adamw"`` nor
            ``"sgd"``. Previously this fell through and failed with an
            ``UnboundLocalError`` on the return statement.
    """
    name = config['optimizer']
    if name == "adamw":
        return torch.optim.AdamW(
            params,
            lr=config['learning_rate'],
            weight_decay=config['weight_decay'],
        )
    if name == "sgd":
        return torch.optim.SGD(
            params,
            lr=config['learning_rate'],
            momentum=0.9,
            weight_decay=config['weight_decay'],
        )
    raise ValueError(
        f"Unsupported optimizer {name!r}; expected 'adamw' or 'sgd'."
    )

# Gradient-norm clipping threshold; same value as the cap applied in the
# update step of train_dgl.
max_grad_norm = 2.0          # try 1.0 if still unstable

# def grad_clipping(engine, max_norm):
#     nn_utils.clip_grad_norm_(net.parameters(), max_norm)

def train_dgl(
    config: Union[TrainingConfig, Dict[str, Any]],
    model: nn.Module = None,
    train_val_test_loaders=None,
    test_only=False,
    use_save=True,
    mp_id_list=None,
):
    """Train (or only evaluate) a PropertyPredictor with ignite engines.

    `config` should conform to matformer.conf.TrainingConfig, and
    if passed as a dict with matching keys, pydantic validation is used.
    Note that `config` is subscripted by key throughout (including before
    validation), so in practice it must support ``config['key']`` access.

    Args:
        config: Training configuration (see above).
        model: Model configuration object handed to the network
            constructors. When `config` is a dict that validates, this is
            replaced by the validated config's ``model`` attribute.
        train_val_test_loaders: Optional sequence
            ``(train_loader, val_loader, test_loader, prepare_batch)``.
            When empty or None the loaders are built with
            ``get_train_val_loaders`` from `config`.
        test_only: If true, load the saved checkpoint, run the test loader
            once, pickle targets/predictions into the output directory and
            return the test MAE instead of training.
        use_save: Forwarded to ``get_train_val_loaders``.
        mp_id_list: Forwarded to ``get_train_val_loaders``.

    Returns:
        The test MAE (float) when `test_only` is true; otherwise the
        ``history`` dict with per-epoch ``"train"``, ``"validation"`` and
        ``"test"`` metric lists (plus ``"EOS"``/``"trainEOS"`` when
        ``config['store_outputs']`` is set).

    Raises:
        ValueError: If ``config['scheduler']`` is not one of ``"none"``,
            ``"onecycle"`` or ``"step"``.
    """
    global tmp_output_dir

    print(config)
    print(config['output_dir'])
    # pdb.set_trace()
    if isinstance(config, dict):
        try:
            config = TrainingConfig(**config).dict()
            model = TrainingConfig(**config).model
        except Exception as exp:
            # Best effort: keep going with the raw dict if validation fails.
            print("Check", exp)
            print('error in converting to training config!')

    os.makedirs(config['output_dir'], exist_ok=True)
    checkpoint_dir = os.path.join(config['output_dir'])
    classification = False
    print("config:")
    tmp = deepcopy(config)
    with open(os.path.join(config['output_dir'], "config.json"), "w") as f:
        f.write(json.dumps(tmp, indent=4))
    # Published at module level so make_standard_scalar_and_pca can find sc.pkl.
    tmp_output_dir = config['output_dir']
    pprint.pprint(tmp)
    if config['classification_threshold'] is not None:
        classification = True
    if config['random_seed'] is not None:
        ignite.utils.manual_seed(config['random_seed'])

    line_graph = True

    if not train_val_test_loaders:
        # use input standardization for all real-valued feature sets
        (
            train_loader,
            val_loader,
            test_loader,
            prepare_batch,
            mean_train,
            std_train,
        ) = get_train_val_loaders(
            dataset=config['dataset'],
            # dataset_array=dataset,
            target=config['target'],
            n_train=config['n_train'],
            n_val=config['n_val'],
            n_test=config['n_test'],
            train_ratio=config['train_ratio'],
            val_ratio=config['val_ratio'],
            test_ratio=config['test_ratio'],
            batch_size=config['batch_size'],
            atom_features=config['atom_features'],
            neighbor_strategy=config['neighbor_strategy'],
            standardize=config['atom_features'] != "cgcnn",
            line_graph=line_graph,
            id_tag=config['id_tag'],
            pin_memory=config['pin_memory'],
            workers=config['num_workers'],
            save_dataloader=config['save_dataloader'],
            cutoff=config['cutoff'],
            max_neighbors=config['max_neighbors'],
            classification_threshold=config['classification_threshold'],
            output_dir=config['output_dir'],
            matrix_input=config['matrix_input'],
            pyg_input=config['pyg_input'],
            use_lattice=config['use_lattice'],
            use_angle=config['use_angle'],
            use_save=use_save,
            mp_id_list=mp_id_list,
        )
    else:
        train_loader = train_val_test_loaders[0]
        val_loader = train_val_test_loaders[1]
        test_loader = train_val_test_loaders[2]
        prepare_batch = train_val_test_loaders[3]
        # No normalisation statistics come with caller-supplied loaders;
        # without this the check below hit an unbound local.
        std_train = None
    prepare_batch = partial(prepare_batch, device=device)
    if classification:
        config['model']['classification'] = True
    # define network, optimizer, scheduler
    _model = {
        "matformer": Matformer,
    }
    if std_train is None:
        std_train = 1.0

    # NOTE(review): this instance is replaced by PropertyPredictor just
    # below. It is kept because building it may consume RNG state and so
    # affect the later initialisation — confirm before removing.
    net = _model.get(config['model']['name'])(model)
    print("config:")
    pprint.pprint(config['model'])

    net = PropertyPredictor(config=model)

    # NOTE(review): placeholder path — point this at the real pretrained
    # checkpoint before running.
    net.load_state_dict(torch.load("your pretrained model from Latent"), strict=False)

    net.to(device)

    if config['distributed']:
        import torch.distributed as dist

        def setup(rank, world_size):
            os.environ["MASTER_ADDR"] = "localhost"
            os.environ["MASTER_PORT"] = "12355"

            # initialize the process group
            dist.init_process_group("gloo", rank=rank, world_size=world_size)

        def cleanup():
            dist.destroy_process_group()

        # NOTE(review): rank 2 with world_size 2 looks out of range (ranks
        # normally run 0..world_size-1) — confirm.
        setup(2, 2)
        net = torch.nn.parallel.DistributedDataParallel(
            net
        )
    params = group_decay(net)
    optimizer = setup_optimizer(params, config)

    if config['scheduler'] == "none":
        # Constant learning rate expressed as a scheduler so the rest of the
        # code can treat all cases uniformly.
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lambda epoch: 1.0
        )
    elif config['scheduler'] == "onecycle":
        steps_per_epoch = len(train_loader)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=config['learning_rate'],
            epochs=config['epochs'],
            steps_per_epoch=steps_per_epoch,
            pct_start=0.1,
        )
    elif config['scheduler'] == "step":
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=100000,
            gamma=0.96,
        )
    else:
        # Fail early instead of hitting an unbound `scheduler` later on.
        raise ValueError(
            f"Unsupported scheduler {config['scheduler']!r}; "
            "expected 'none', 'onecycle' or 'step'."
        )

    # select configured loss function
    criteria = {
        "mse": nn.MSELoss(),
        "l1": nn.L1Loss(),
    }
    criterion = criteria["mse"]
    # set up training engine and evaluators
    # Metrics are reported unscaled: the loader-provided std is overridden.
    std_train = 1.0
    # std_train = 0.096403
    metrics = {
        "loss": Loss(criterion),
        "mae": MeanAbsoluteError() * std_train,
        "neg_mae": -1.0 * MeanAbsoluteError() * std_train,
    }

    def update_fn(engine, batch):
        """Run one optimisation step and return the scalar loss."""
        net.train()
        optimizer.zero_grad()
        # pdb.set_trace()
        x, y = prepare_batch(batch, device=device)
        y_pred = net(x)
        loss = criterion(y_pred, y)
        loss.backward()
        # Clip after backward and before the optimizer step.
        nn_utils.clip_grad_norm_(net.parameters(), max_grad_norm)
        optimizer.step()
        return loss.item()

    trainer = Engine(update_fn)
    evaluator = create_supervised_evaluator(
        net,
        metrics=metrics,
        prepare_batch=prepare_batch,
        device=device,
    )
    train_evaluator = create_supervised_evaluator(
        net,
        metrics=metrics,
        prepare_batch=prepare_batch,
        device=device,
    )
    test_evaluator = create_supervised_evaluator(
        net,
        metrics=metrics,
        prepare_batch=prepare_batch,
        device=device,
    )

    if config['load_ckpt']:
        # checkpoint_tmp = torch.load('checkpoint_file_name')
        checkpoint_tmp = torch.load(config['ckpt_file_name'])
        to_load = {
            "model": net,
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "trainer": trainer,
        }
        Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_tmp)

    if test_only:
        # NOTE(review): both paths below are placeholders and must be
        # replaced with real checkpoint files.
        checkpoint_tmp = torch.load('saved_model_path')
        best_m = torch.load("saved_best_model_path")
        checkpoint_tmp['model'] = best_m
        to_load = {
            "model": net,
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "trainer": trainer,
        }
        Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_tmp)
        net.eval()
        targets = []
        predictions = []
        import time
        t1 = time.time()
        with torch.no_grad():
            for dat in tqdm(test_loader):
                g, lg, aux, target = dat
                try:
                    out_data = net([g.to(device), lg.to(device), aux.to(device)])
                except Exception as exp:  # just in case: skip samples the model rejects
                    print('error for this data')
                    print(g)
                    print(exp)
                    continue
                out_data = out_data.cpu().numpy().tolist()
                target = target.cpu().numpy().flatten().tolist()
                if len(target) == 1:
                    target = target[0]
                targets.append(target)
                predictions.append(out_data)
        t2 = time.time()
        from sklearn.metrics import mean_absolute_error
        targets = np.array(targets)  # * std_train
        predictions = np.array(predictions)  # * std_train

        with open(config['output_dir'] + "/target.pkl", "wb") as f:
            pk.dump(targets, f)

        with open(config['output_dir'] + "/prediction.pkl", "wb") as f:
            pk.dump(predictions, f)

        test_mae = mean_absolute_error(targets, predictions)
        print("Test MAE:", test_mae * std_train)
        print("Total test time:", t2 - t1)
        return test_mae
    # ignite event handlers:
    trainer.add_event_handler(Events.EPOCH_COMPLETED, TerminateOnNan())

    # The scheduler is stepped once per iteration, not once per epoch.
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED, lambda engine: scheduler.step()
    )

    best_loss = float('inf')

    if config['write_checkpoint']:
        # model checkpointing
        to_save = {
            "model": net,
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "trainer": trainer,
        }
        handler = Checkpoint(
            to_save,
            DiskSaver(checkpoint_dir, create_dir=True, require_empty=False),
            n_saved=2,
            global_step_transform=lambda *_: trainer.state.epoch,
        )
        trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
        # evaluate save: keep the best models ranked by validation neg_mae
        to_save = {"model": net}
        handler = Checkpoint(
            to_save,
            DiskSaver(checkpoint_dir, create_dir=True, require_empty=False),
            n_saved=5,
            filename_prefix='best',
            score_name="neg_mae",
            global_step_transform=lambda *_: trainer.state.epoch,
        )
        evaluator.add_event_handler(Events.EPOCH_COMPLETED, handler)
    if config['progress']:
        pbar = ProgressBar()
        pbar.attach(trainer, output_transform=lambda x: {"loss": x})
        # pbar.attach(evaluator,output_transform=lambda x: {"mae": x})

    history = {
        "train": {m: [] for m in metrics.keys()},
        "validation": {m: [] for m in metrics.keys()},
        "test": {m: [] for m in metrics.keys()},
    }

    if config['store_outputs']:
        # in history["EOS"]
        eos = EpochOutputStore()
        eos.attach(evaluator)
        train_eos = EpochOutputStore()
        train_eos.attach(train_evaluator)

    # collect evaluation performance

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(engine):
        """Print training and validation metrics to console."""
        nonlocal best_loss

        evaluator.run(val_loader)
        test_evaluator.run(test_loader)

        vmetrics = evaluator.state.metrics
        tstmetrics = test_evaluator.state.metrics

        for metric in metrics.keys():
            vm = vmetrics[metric]
            tstm = tstmetrics[metric]
            t_metric = metric
            if metric == "roccurve":
                vm = [k.tolist() for k in vm]
            if isinstance(vm, torch.Tensor):
                vm = vm.cpu().numpy().tolist()
            if isinstance(tstm, torch.Tensor):
                tstm = tstm.cpu().numpy().tolist()

            history["validation"][metric].append(vm)
            history["test"][metric].append(tstm)

        # The training set is only re-evaluated every 20th epoch; on other
        # epochs a sentinel MAE of -1 is reported.
        epoch_num = len(history["validation"][t_metric])
        if epoch_num % 20 == 0:
            train_evaluator.run(train_loader)
            tmetrics = train_evaluator.state.metrics
            for metric in metrics.keys():
                tm = tmetrics[metric]
                if metric == "roccurve":
                    tm = [k.tolist() for k in tm]
                if isinstance(tm, torch.Tensor):
                    tm = tm.cpu().numpy().tolist()

                history["train"][metric].append(tm)
        else:
            tmetrics = {}
            tmetrics['mae'] = -1

        if config['store_outputs']:
            history["EOS"] = eos.data
            history["trainEOS"] = train_eos.data
            dumpjson(
                filename=os.path.join(config['output_dir'], "history_val.json"),
                data=history["validation"],
            )
            dumpjson(
                filename=os.path.join(config['output_dir'], "history_train.json"),
                data=history["train"],
            )
        if config['progress']:
            pbar = ProgressBar()
            if not classification:
                pbar.log_message(f"Epoch: {engine.state.epoch:.1f}")
                pbar.log_message(f"Train_MAE: {tmetrics['mae']:.4f}")
                pbar.log_message(f"Val_MAE: {vmetrics['mae']:.4f}")
                pbar.log_message(f"Test_MAE: {tstmetrics['mae']:.4f}")
            else:
                pbar.log_message(f"Train ROC AUC: {tmetrics['rocauc']:.4f}")
                pbar.log_message(f"Val ROC AUC: {vmetrics['rocauc']:.4f}")

        # Tracks the lowest *test* MAE seen so far (reporting only).
        if tstmetrics['mae'] < best_loss:
            best_loss = tstmetrics['mae']

        print("Best_mae", best_loss)
        print("\n")

    if config['n_early_stopping'] is not None:
        if classification:
            my_metrics = "accuracy"
        else:
            my_metrics = "neg_mae"

        def default_score_fn(engine):
            score = engine.state.metrics[my_metrics]
            return score

        es_handler = EarlyStopping(
            patience=config['n_early_stopping'],
            score_function=default_score_fn,
            trainer=trainer,
        )
        evaluator.add_event_handler(Events.EPOCH_COMPLETED, es_handler)
    # optionally log results to tensorboard
    if config['log_tensorboard']:

        tb_logger = TensorboardLogger(
            log_dir=os.path.join(config['output_dir'], "tb_logs", "test")
        )
        # A dedicated loop name avoids rebinding `evaluator`, which
        # log_results reads from this scope.
        for tag, tb_engine in [
            ("training", train_evaluator),
            ("validation", evaluator),
        ]:
            tb_logger.attach_output_handler(
                tb_engine,
                event_name=Events.EPOCH_COMPLETED,
                tag=tag,
                metric_names=["loss", "mae"],
                global_step_transform=global_step_from_engine(trainer),
            )

    trainer.run(train_loader, max_epochs=config['epochs'])

    if config['log_tensorboard']:
        test_loss = evaluator.state.metrics["loss"]
        tb_logger.writer.add_hparams(config, {"hparam/test_loss": test_loss})
        tb_logger.close()
    if config['write_predictions'] and classification:
        net.eval()
        targets = []
        predictions = []
        csv_path = os.path.join(
            config['output_dir'], "prediction_results_test_set.csv"
        )
        with open(csv_path, "w") as f, torch.no_grad():
            f.write("id,target,prediction\n")
            ids = test_loader.dataset.ids  # [test_loader.dataset.indices]
            for dat, sample_id in zip(test_loader, ids):
                g, lg, target = dat
                out_data = net([g.to(device), lg.to(device)])
                # out_data = torch.exp(out_data.cpu())
                _, top_class = torch.topk(torch.exp(out_data), k=1)
                target = int(target.cpu().numpy().flatten().tolist()[0])

                f.write("%s, %d, %d\n" % (sample_id, (target), (top_class)))
                targets.append(target)
                predictions.append(
                    top_class.cpu().numpy().flatten().tolist()[0]
                )
        from sklearn.metrics import roc_auc_score

        print("predictions", predictions)
        print("targets", targets)
        print(
            "Test ROCAUC:",
            roc_auc_score(np.array(targets), np.array(predictions)),
        )

    if (
        config['write_predictions']
        and not classification
        and config['model']['output_features'] > 1
    ):
        net.eval()
        mem = []
        sc = None
        if config['standard_scalar_and_pca']:
            # Load the scaler once rather than once per sample.
            # NOTE: unpickling can execute arbitrary code; only load files
            # produced by this pipeline.
            with open("sc.pkl", "rb") as scaler_file:
                sc = pk.load(scaler_file)
        with torch.no_grad():
            ids = test_loader.dataset.ids  # [test_loader.dataset.indices]
            for dat, sample_id in zip(test_loader, ids):
                g, lg, target = dat
                out_data = net([g.to(device), lg.to(device)])
                out_data = out_data.cpu().numpy().tolist()
                if sc is not None:
                    out_data = list(
                        sc.transform(np.array(out_data).reshape(1, -1))[0]
                    )  # [0][0]
                target = target.cpu().numpy().flatten().tolist()
                info = {}
                info["id"] = sample_id
                info["target"] = target
                info["predictions"] = out_data
                mem.append(info)
        dumpjson(
            # config is a dict here, so attribute access would fail.
            filename=os.path.join(
                config['output_dir'], "multi_out_predictions.json"
            ),
            data=mem,
        )
    if (
        config['write_predictions']
        and not classification
        and config['model']['output_features'] == 1
    ):
        net.eval()
        # NOTE(review): only the header is written to this file; per-sample
        # rows are not emitted by the current code.
        csv_path = os.path.join(
            config['output_dir'], "prediction_results_test_set.csv"
        )
        with open(csv_path, "w") as f:
            f.write("id,target,prediction\n")
        targets = []
        predictions = []
        with torch.no_grad():
            for dat in test_loader:
                g, lg, _, target = dat
                # NOTE(review): `lg` is passed twice here, whereas the
                # test_only path passes the third batch element — confirm
                # which one the model expects.
                out_data = net([g.to(device), lg.to(device), lg.to(device)])
                out_data = out_data.cpu().numpy().tolist()
                target = target.cpu().numpy().flatten().tolist()
                if len(target) == 1:
                    target = target[0]
                targets.append(target)
                predictions.append(out_data)
        from sklearn.metrics import mean_absolute_error

        print(
            "Test MAE:",
            mean_absolute_error(np.array(targets), np.array(predictions)) * std_train,
        )
        if config['store_outputs'] and not classification:
            # Flatten batch by batch so a smaller final batch cannot produce
            # a ragged array.
            stored = history.get("EOS", [])
            pred_values = [
                float(v)
                for batch_out in stored
                for v in batch_out[0].cpu().numpy().ravel().tolist()
            ]
            target_values = [
                float(v)
                for batch_out in stored
                for v in batch_out[1].cpu().numpy().ravel().tolist()
            ]
            train_csv_path = os.path.join(
                config['output_dir'], "prediction_results_train_set.csv"
            )
            with open(train_csv_path, "w") as f:
                # TODO: Add IDs
                f.write("target,prediction\n")
                # One row per sample, in header order. The old code also
                # wrote each row a second time with the columns swapped.
                for pred_value, target_value in zip(pred_values, target_values):
                    f.write("%6f, %6f\n" % (target_value, pred_value))
    return history


