import argparse
import json
import pathlib
import time
import traceback

import numpy as np
import torch
import tqdm
from torch import nn, optim
from torch.utils.data import DataLoader

from CHANGELOG import MODEL_VERSION, CHECK_VERSION
from tigprompt.data.data_loader import GraphCollator, load_jodie_data_for_node_task
from tigprompt.data.graph import Graph
from tigprompt.eval_utils import eval_node_classification
from tigprompt.model.basic_modules import MLP
from tigprompt.model.feature_getter import NumericalFeature
from tigprompt.utils import BackgroundThreadGenerator

from init_utils import init_model
from train_utils import EarlyStopMonitor, get_logger, hash_args, seed_all
from tigprompt.model.temporal_prompt_generator import TransformerTProG, VanillaTProG, ProjectionTProG


def run(*, prefix, gpu, seed, lr, n_epochs,
        patience, force, ckpt_path, prompter_ckpt_path, root, bs,
        use_valid, dropout, optimizing_prompter, lr_prompter, pretrained_prompter,
        # above are new parameters
        data, dim, feature_as_buffer, num_workers,
        hit_type, restarter_type, hist_len,
        n_neighbors, n_layers, n_heads,
        strategy, msg_src, upd_src,
        mem_update_type, msg_tsfm_type,
        embedding_type, dyrep, no_memory,
        and_prompt, prompter_type, prompt_start, prompt_end, prompt_dim,
        **kwargs
        ):
    """Fine-tune a node-classification decoder on top of a pretrained encoder.

    The encoder built by ``init_model`` is loaded from ``ckpt_path``, put in
    eval mode and only run under ``torch.no_grad()``; it has no optimizer. A
    temporal prompter (``prompter_type``: 'transformer', 'vanilla' or
    'projection') maps the encoder embeddings to prompted embeddings, and an
    MLP decoder is trained on them with ``BCEWithLogitsLoss``. The prompter
    gets its own Adam optimizer only when ``optimizing_prompter`` is set. Its
    weights are loaded from ``prompter_ckpt_path`` when ``pretrained_prompter``
    is set, and also whenever it is not being optimized.

    The first parameter group (``prefix`` ... ``pretrained_prompter``)
    controls this fine-tuning run. The second group (``data`` ...
    ``prompt_dim``) describes the data and the model. These values are passed
    to the data loader, ``init_model`` and the prompter constructors, so they
    should match the checkpoints being loaded. Extra keyword arguments
    (``**kwargs``) are accepted and ignored; they do not affect the run hash.

    ``gpu == "-1"`` selects the CPU; any other value selects ``cuda:<gpu>``.
    With ``use_valid``, per-epoch validation AUC drives early stopping. The
    best epoch's decoder (and prompter, if optimized) is then restored and
    evaluated on the test loader. Without it, the last epoch's AUC on the
    validation loader is reported as ``test_auc``. (NOTE(review): presumably
    that loader holds the test split when ``use_validation=False`` — confirm.)

    The hashed arguments plus ``prefix``, ``VERSION`` and ``test_auc`` are
    written as JSON to ``results/<prefix>.<hash>.json``.

    Returns:
        False if that result file already exists and ``force`` is not set;
        otherwise None.

    Raises:
        ValueError: if ``prefix`` is empty or ``prompter_type`` is unknown.
        Any exception raised after the duplicate check is logged and re-raised.
    """
    # Must be the first statement: at this point ``locals()`` holds exactly
    # the call's parameters, which define the run's identity (hash).
    args = {k: v for k, v in locals().items()
            if k not in {'gpu', 'force', 'kwargs'}}
    # Validate before appending the hash suffix; afterwards the prefix can
    # never be empty and the check would be dead code.
    if not prefix:
        raise ValueError('Prefix should be given explicitly.')
    HASH = hash_args(**args, MODEL_VERSION=MODEL_VERSION)
    prefix += '.' + HASH
    device = torch.device('cpu') if gpu == "-1" else torch.device(f'cuda:{gpu}')

    result_save_path = pathlib.Path(f"results/{prefix}.json")

    ckpts_dir = pathlib.Path(f"./saved_checkpoints/{prefix}")
    ckpts_dir.mkdir(parents=True, exist_ok=True)
    prompter_ckpts_dir = pathlib.Path(f"./saved_prompter_ckpts/{prefix}")
    prompter_ckpts_dir.mkdir(parents=True, exist_ok=True)

    def get_checkpoint_path(epoch):
        """Return the path of the decoder checkpoint saved after ``epoch``."""
        return ckpts_dir / f'{epoch}.pth'

    def get_prompter_path(epoch):
        """Return the path of the prompter checkpoint saved after ``epoch``."""
        return prompter_ckpts_dir / f'{epoch}.pth'

    logger = get_logger(prefix)
    logger.info(f'[START {prefix}]')
    logger.info(f'Model version: {MODEL_VERSION}')
    logger.info(", ".join([f"{k}={v}" for k, v in args.items()]))

    if result_save_path.exists() and not force:
        logger.info('Duplicate task! Abort!')
        return False

    try:
        # Fail fast (before loading data) instead of hitting an undefined
        # ``prompter`` later.
        if prompter_type not in ('transformer', 'vanilla', 'projection'):
            raise ValueError(f'Unknown prompter_type: {prompter_type!r}')

        seed_all(seed)
        # ============= Load Data ===========
        (
            nfeats, efeats, full_data, train_data, val_data, test_data, prompt_data, train_and_prompt_data
        ) = load_jodie_data_for_node_task(data, train_seed=seed, root=root,
                                          use_validation=use_valid, prompt_start=prompt_start, prompt_end=prompt_end)

        train_graph = Graph.from_data(train_data, strategy=strategy, seed=seed)
        train_and_prompt_graph = Graph.from_data(train_and_prompt_data, strategy=strategy, seed=seed)
        full_graph = Graph.from_data(full_data, strategy=strategy, seed=seed)

        # Use the configured restarter. When none is configured, fall back to
        # 'seq' for the transformer prompter and 'static' for the others.
        if restarter_type != 'none':
            collator_restarter = restarter_type
        elif prompter_type == 'transformer':
            collator_restarter = 'seq'
        else:  # 'vanilla' or 'projection'
            collator_restarter = 'static'

        train_collator = GraphCollator(train_graph, n_neighbors, n_layers,
                                       restarter=collator_restarter, hist_len=hist_len)
        train_and_prompt_collator = GraphCollator(train_and_prompt_graph, n_neighbors, n_layers,
                                                  restarter=collator_restarter, hist_len=hist_len)
        eval_collator = GraphCollator(full_graph, n_neighbors, n_layers,
                                      restarter=collator_restarter, hist_len=hist_len)

        train_dl = DataLoader(train_data, batch_size=bs, collate_fn=train_collator, pin_memory=True, num_workers=num_workers)
        val_dl = DataLoader(val_data, batch_size=bs, collate_fn=eval_collator)
        test_dl = DataLoader(test_data, batch_size=bs, collate_fn=eval_collator)
        prompt_dl = DataLoader(prompt_data, batch_size=bs, collate_fn=train_and_prompt_collator)
        train_and_prompt_dl = DataLoader(train_and_prompt_data, batch_size=bs, collate_fn=train_and_prompt_collator)

        # ============= Init Model ===========
        encoder = init_model(
            nfeats, efeats, train_and_prompt_graph, full_graph, full_data, device,
            feature_as_buffer=feature_as_buffer, dim=dim,
            n_layers=n_layers, n_heads=n_heads, n_neighbors=n_neighbors,
            hit_type=hit_type, dropout=dropout,
            restarter_type=restarter_type, hist_len=hist_len,
            msg_src=msg_src, upd_src=upd_src,
            msg_tsfm_type=msg_tsfm_type, mem_update_type=mem_update_type,
            embedding_type=embedding_type, dyrep=dyrep, no_memory=no_memory,
            and_prompt=and_prompt, prompter_type=prompter_type
        )

        # The prompter reads raw features; when ``dim`` is not given it is
        # inferred from the feature width (edge features take precedence).
        if nfeats is not None:
            nfeats = torch.from_numpy(nfeats).float()
            dim = nfeats.shape[1] if dim is None else dim
        if efeats is not None:
            efeats = torch.from_numpy(efeats).float()
            dim = efeats.shape[1] if dim is None else dim

        raw_feat_getter = NumericalFeature(
            nfeats, efeats, dim=dim, register_buffer=feature_as_buffer, device=device
        )
        raw_feat_getter.n_nodes = full_graph.num_node
        raw_feat_getter.n_edges = len(full_data)
        ts_delta_mean, ts_delta_std, *_ = full_data.get_stats()

        if prompter_type == "transformer":
            prompter = TransformerTProG(
                raw_feat_getter=raw_feat_getter,
                graph=full_graph,
                hist_len=hist_len,
                n_head=n_heads, dropout=dropout,
                dyrep=dyrep, prompt_dim=prompt_dim
            ).to(device)
        elif prompter_type == "vanilla":
            prompter = VanillaTProG(
                raw_feat_getter=raw_feat_getter,
                graph=full_graph,
                dyrep=dyrep, prompt_dim=prompt_dim
            ).to(device)
        else:  # "projection" (validated above)
            prompter = ProjectionTProG(
                raw_feat_getter=raw_feat_getter,
                graph=full_graph,
                dyrep=dyrep, prompt_dim=prompt_dim
            ).to(device)

        # load model ckpt
        encoder.load_state_dict(torch.load(ckpt_path, map_location=device))
        # A prompter that is not being optimized must come from a checkpoint.
        if not optimizing_prompter and not pretrained_prompter:
            pretrained_prompter = True
        if pretrained_prompter:
            prompter.load_state_dict(torch.load(prompter_ckpt_path, map_location=device))

        encoder.eval()

        decoder = MLP(encoder.nfeat_dim, dropout=dropout).to(device)
        loss_fn = nn.BCEWithLogitsLoss()
        optimizer = optim.Adam(decoder.parameters(), lr=lr)

        if optimizing_prompter:
            optimizer_prompter = optim.Adam(prompter.parameters(), lr=lr_prompter)

        val_aucs = []
        if use_valid:
            early_stopper = EarlyStopMonitor(max_round=patience)
        for epoch in range(n_epochs):
            start_epoch_t0 = time.time()
            logger.info('Start {} epoch'.format(epoch))

            m_loss = []
            it = BackgroundThreadGenerator(train_and_prompt_dl)
            it = tqdm.tqdm(it, total=len(train_and_prompt_dl), ncols=50)

            encoder.reset()
            decoder.train()
            if optimizing_prompter:
                prompter.train()
            else:
                prompter.eval()
            for src_ids, dst_ids, neg_dst_ids, ts, eids, labels, comp_graph in it:
                # Events in this batch; a separate name so the ``bs``
                # argument is not shadowed.
                batch_size = len(src_ids)
                src_ids = src_ids.long().to(device)
                dst_ids = dst_ids.long().to(device)
                neg_dst_ids = neg_dst_ids.long().to(device)
                ts = ts.float().to(device)
                eids = eids.long().to(device)
                labels = labels.float().to(device)
                comp_graph.to(device)
                # Frozen encoder: only its second output ``h`` is used.
                with torch.no_grad():
                    _, h, _, _, _, _ = encoder.contrast_learning(
                        src_ids, dst_ids, neg_dst_ids, ts, eids, comp_graph)
                optimizer.zero_grad()
                if optimizing_prompter:
                    optimizer_prompter.zero_grad()

                # Every event yields a (src, ts) and a (dst, ts) pair;
                # ``restart_data.index`` selects among them (presumably the
                # unique nodes, per the variable names — confirm).
                index = comp_graph.restart_data.index
                unique_nids = torch.cat([src_ids, dst_ids])[index].to(device)
                unique_ts = ts.repeat(2)[index].to(device)

                # First ``batch_size`` rows: presumably the source ("positive")
                # node embeddings.
                pos_h = h[:batch_size]
                _, _, prompted_h = prompter(unique_nids, unique_ts, pos_h, src_ids)

                pred_y = decoder(prompted_h)  # only positive nodes

                loss = loss_fn(pred_y, labels)
                loss.backward()
                optimizer.step()
                if optimizing_prompter:
                    optimizer_prompter.step()
                m_loss.append(loss.item())

            epoch_time = time.time() - start_epoch_t0

            val_auc = eval_node_classification(encoder, decoder, prompter, val_dl, device)
            val_aucs.append(val_auc)

            logger.info('Epoch {:4d} training took  {:.2f}s'.format(epoch, epoch_time))
            logger.info(f'Epoch mean loss: {np.mean(m_loss):.4f}')
            logger.info(f'Epoch validation auc: {val_auc:.4f}')

            if use_valid:
                if early_stopper.early_stop_check(val_auc):
                    logger.info('No improvement over {} epochs, stop training'.format(
                        early_stopper.max_round))
                    break
                torch.save(decoder.state_dict(), get_checkpoint_path(epoch))
                if optimizing_prompter:
                    torch.save(prompter.state_dict(), get_prompter_path(epoch))

        if use_valid:
            logger.info(f'Loading the best model at epoch {early_stopper.best_epoch}')
            best_model_path = get_checkpoint_path(early_stopper.best_epoch)
            decoder.load_state_dict(torch.load(best_model_path, map_location=device))

            if optimizing_prompter:
                best_prompter_path = get_prompter_path(early_stopper.best_epoch)
                prompter.load_state_dict(torch.load(best_prompter_path, map_location=device))

            best_val_auc = val_aucs[early_stopper.best_epoch]
            logger.info(f'[ Val] auc: {best_val_auc:.4f}')

            test_auc = eval_node_classification(encoder, decoder, prompter, test_dl, device)
        else:
            logger.info('No validation set. Use the last epoch result.')
            test_auc = val_aucs[-1]

        logger.info(f'[Test] auc: {test_auc:.4f}')

        results = args.copy()
        results.update(
            prefix=prefix,
            VERSION=MODEL_VERSION,
            test_auc=test_auc,
        )
        # ``results/`` may not exist yet; close the file deterministically.
        result_save_path.parent.mkdir(parents=True, exist_ok=True)
        with open(result_save_path, 'w') as fh:
            json.dump(results, fh)

    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(e)
        raise


def get_args(argv=None):
    """Parse the command-line options for node-classification fine-tuning.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, the same as a plain ``parse_args()`` call.

    Returns:
        argparse.Namespace with the parsed options.

    Exits via ``parser.error`` (SystemExit, status 2) when the pretraining run
    cannot be located. That happens when neither ``--code`` nor ``--json`` is
    given, or when ``--json`` is used without ``--ckpt``.
    """
    parser = argparse.ArgumentParser()
    # Exp Setting
    parser.add_argument('--code', type=str, default='', help='Name of the saved result and model')
    parser.add_argument('--json', type=str, default='', help='Path to model result (json file)')
    parser.add_argument('--ckpt', type=str, default='', help='Path to model checkpoint')
    parser.add_argument('--prompter_ckpt', type=str, default='', help='Path to prompter checkpoint')
    parser.add_argument('--root', type=str, default='.', help='Dataset root')
    parser.add_argument('--seed', type=int, default=0, help='Random seed')
    parser.add_argument('--gpu', type=str, default='0', help='Cuda index')

    parser.add_argument('--n_epochs', type=int, default=50, help='Number of epochs')
    parser.add_argument('--bs', type=int, default=100, help='Batch size')
    parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout probability')
    parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
    parser.add_argument('--use_valid', action='store_true', help='Use validation set')

    parser.add_argument('--force', action='store_true', help='Overwrite the existing task')

    # prompt
    parser.add_argument('--lr_prompter', type=float, default=3e-4, help='Learning rate for prompter')
    parser.add_argument('--optimizing_prompter', action='store_true', help='also optimize prompter')
    parser.add_argument('--pretrained_prompter', action='store_true', help='using pretrained prompter to init prompter')

    parser.add_argument('--prompt_end', type=float, default=0.7,
                        help='quantile of prompting end time; should be >= prompt_start')

    parser.add_argument('--prompt_dim', type=int, default=None, help='dim of prompts')

    args = parser.parse_args(argv)

    # Without --code the result JSON and encoder checkpoint are read from
    # --json / --ckpt. Reject empty paths here instead of failing later on
    # open('') / torch.load('') after data loading.
    if not args.code:
        if not args.json:
            parser.error('either --code or --json must be given')
        if not args.ckpt:
            parser.error('--ckpt is required when --json is used')
    return args


if __name__ == '__main__':
    args = get_args()

    # Locate the pretraining result file and checkpoints. They are either
    # derived from a run code using the conventional directories, or given
    # explicitly on the command line.
    if args.code:
        saved_results_path = f'./results/{args.code}.json'
        ckpt_path = f'./saved_models/{args.code}.pth'
        prompter_ckpt_path = f'./saved_prompters/{args.code}.pth'
    else:
        saved_results_path = args.json
        ckpt_path = args.ckpt
        prompter_ckpt_path = args.prompter_ckpt
    with open(saved_results_path) as fh:
        saved_results = json.load(fh)

    if not CHECK_VERSION(saved_results['VERSION'], MODEL_VERSION):
        raise ValueError('version not match: {} != {}'.format(
            saved_results['VERSION'], MODEL_VERSION))

    prefix = saved_results['HASH'] if saved_results['prefix'] == '' else saved_results['prefix']
    prefix += '-node-fine_tune_mode'

    # Settings chosen for this fine-tuning run; they take precedence over any
    # same-named key in the saved pretraining results.
    run_overrides = dict(
        prefix=prefix, gpu=args.gpu, seed=args.seed,
        lr=args.lr, dropout=args.dropout, bs=args.bs, n_epochs=args.n_epochs,
        patience=args.patience, force=args.force,
        use_valid=args.use_valid, root=args.root, ckpt_path=ckpt_path, prompter_ckpt_path=prompter_ckpt_path,
        and_prompt=False,
        optimizing_prompter=args.optimizing_prompter, lr_prompter=args.lr_prompter,
        pretrained_prompter=args.pretrained_prompter, prompt_start=saved_results['prompt_start'],
        prompt_end=args.prompt_end,
        prompt_dim=args.prompt_dim,
    )
    # Everything else (data set, architecture, ...) is inherited from the saved
    # run. Every overridden key is dropped, so run() never receives the same
    # keyword twice (e.g. a saved ``and_prompt``), which would raise TypeError.
    kwargs = {k: v for k, v in saved_results.items() if k not in run_overrides}

    run(**run_overrides, **kwargs)
