import argparse
import json
import pathlib
import time
import traceback

import numpy as np
import torch
import tqdm
from torch import nn, optim
from torch.utils.data import DataLoader

from CHANGELOG import MODEL_VERSION, CHECK_VERSION
from tigprompt.data.data_loader import GraphCollator, load_jodie_data_for_node_task
from tigprompt.data.graph import Graph
from tigprompt.eval_utils import eval_node_classification
from tigprompt.model.basic_modules import MLP
from tigprompt.model.feature_getter import NumericalFeature
from tigprompt.utils import BackgroundThreadGenerator

from init_utils import init_model
from train_utils import EarlyStopMonitor, get_logger, hash_args, seed_all
from tigprompt.model.temporal_prompt_generator import TransformerTProG, VanillaTProG, ProjectionTProG


def run(*, prefix, gpu, seed, lr, n_epochs,
        patience, force, ckpt_path, root, bs,
        use_valid, dropout,
        # above are new parameters
        prompter_type,
        teaching_prompt, prompt_teaching_coef,
        lr_prompter, prompt_start, prompt_end, prompt_dim,
        # above are prompting parameters
        data, dim, feature_as_buffer, num_workers,
        hit_type, restarter_type, hist_len,
        n_neighbors, n_layers, n_heads,
        strategy, msg_src, upd_src,
        mem_update_type, msg_tsfm_type,
        embedding_type, dyrep, no_memory,
        **kwargs
        ):
    args = {k: v for k, v in locals().items()
            if not k in {'gpu', 'force', 'kwargs'}}
    HASH = hash_args(**args, MODEL_VERSION=MODEL_VERSION)
    prefix += '.' + HASH
    if not prefix:
        raise ValueError('Prefix should be given explicitly.')
    if gpu == "-1":
        device = torch.device('cpu')
    else:
        device = torch.device(f'cuda:{gpu}')

    RESULT_SAVE_PATH = f"results/{prefix}.json"
    PICKLE_SAVE_PATH = "results/{}.pkl".format(prefix)

    ckpts_dir = pathlib.Path(f"./saved_checkpoints/{prefix}")
    ckpts_dir.mkdir(parents=True, exist_ok=True)
    get_checkpoint_path = lambda epoch: ckpts_dir / f'{epoch}.pth'

    prompter_ckpts_dir = pathlib.Path(f"./saved_prompter/{prefix}")
    prompter_ckpts_dir.mkdir(parents=True, exist_ok=True)
    get_prompter_path = lambda epoch: prompter_ckpts_dir / f'{epoch}.pth'

    logger = get_logger(prefix)
    logger.info(f'[START {prefix}]')
    logger.info(f'Model version: {MODEL_VERSION}')
    logger.info(", ".join([f"{k}={v}" for k, v in args.items()]))

    if pathlib.Path(RESULT_SAVE_PATH).exists() and not force:
        logger.info('Duplicate task! Abort!')
        return False

    try:
        seed_all(seed)
        # ============= Load Data ===========
        (
            nfeats, efeats, full_data, train_data, val_data, test_data, prompt_data, train_and_prompt_data
        ) = load_jodie_data_for_node_task(data, train_seed=seed, root=root,
                                          use_validation=use_valid, prompt_start=prompt_start, prompt_end=prompt_end)

        train_graph = Graph.from_data(train_data, strategy=strategy, seed=seed)
        train_and_prompt_graph = Graph.from_data(train_and_prompt_data, strategy=strategy, seed=seed)
        full_graph = Graph.from_data(full_data, strategy=strategy, seed=seed)

        if restarter_type != 'none':
            train_collator = GraphCollator(train_graph, n_neighbors, n_layers,
                                           restarter=restarter_type, hist_len=hist_len)
            train_and_prompt_collator = GraphCollator(train_and_prompt_graph, n_neighbors, n_layers,
                                                      restarter=restarter_type, hist_len=hist_len)
            eval_collator = GraphCollator(full_graph, n_neighbors, n_layers,
                                          restarter=restarter_type, hist_len=hist_len)
        elif restarter_type == 'none':
            if prompter_type == 'transformer':
                train_collator = GraphCollator(train_graph, n_neighbors, n_layers,
                                               restarter='seq', hist_len=hist_len)
                train_and_prompt_collator = GraphCollator(train_and_prompt_graph, n_neighbors, n_layers,
                                                          restarter='seq', hist_len=hist_len)
                eval_collator = GraphCollator(full_graph, n_neighbors, n_layers,
                                              restarter='seq', hist_len=hist_len)
            elif prompter_type == 'projection' or 'vanilla':
                train_collator = GraphCollator(train_graph, n_neighbors, n_layers,
                                               restarter='static', hist_len=hist_len)
                train_and_prompt_collator = GraphCollator(train_and_prompt_graph, n_neighbors, n_layers,
                                                          restarter='static', hist_len=hist_len)
                eval_collator = GraphCollator(full_graph, n_neighbors, n_layers,
                                              restarter='static', hist_len=hist_len)

        train_dl = DataLoader(train_data, batch_size=bs, collate_fn=train_collator, pin_memory=True, num_workers=num_workers)
        val_dl = DataLoader(val_data, batch_size=bs, collate_fn=eval_collator)
        test_dl = DataLoader(test_data, batch_size=bs, collate_fn=eval_collator)
        prompt_dl = DataLoader(prompt_data, batch_size=bs, collate_fn=train_and_prompt_collator)

        # ============= Init Model ===========
        encoder = init_model(
            nfeats, efeats, train_and_prompt_graph, full_graph, full_data, device,
            feature_as_buffer=feature_as_buffer, dim=dim,
            n_layers=n_layers, n_heads=n_heads, n_neighbors=n_neighbors,
            hit_type=hit_type, dropout=dropout,
            restarter_type=restarter_type, hist_len=hist_len,
            msg_src=msg_src, upd_src=upd_src,
            msg_tsfm_type=msg_tsfm_type, mem_update_type=mem_update_type,
            embedding_type=embedding_type, dyrep=dyrep, no_memory=no_memory,
            and_prompt=False
        )

        if nfeats is not None:
            nfeats = torch.from_numpy(nfeats).float()
            dim = nfeats.shape[1] if dim is None else dim
        if efeats is not None:
            efeats = torch.from_numpy(efeats).float()
            dim = efeats.shape[1] if dim is None else dim

        raw_feat_getter = NumericalFeature(
            nfeats, efeats, dim=dim, register_buffer=feature_as_buffer, device=device
        )
        raw_feat_getter.n_nodes = full_graph.num_node
        raw_feat_getter.n_edges = len(full_data)
        ts_delta_mean, ts_delta_std, *_ = full_data.get_stats()

        if prompter_type == "transformer":
            prompter = TransformerTProG(
                raw_feat_getter=raw_feat_getter,
                graph=full_graph,
                hist_len=hist_len,
                n_head=n_heads, dropout=dropout,
                dyrep=dyrep, prompt_dim=prompt_dim
            ).to(device)
        elif prompter_type == "vanilla":
            prompter = VanillaTProG(
                raw_feat_getter=raw_feat_getter,
                graph=full_graph,
                dyrep=dyrep, prompt_dim=prompt_dim
                ).to(device)
        elif prompter_type == "projection":
            prompter = ProjectionTProG(
                raw_feat_getter=raw_feat_getter,
                graph=full_graph,
                dyrep=dyrep, prompt_dim=prompt_dim
                ).to(device)


        # load model ckpt
        encoder.load_state_dict(torch.load(ckpt_path, map_location=device))
        encoder.eval()

        decoder = MLP(encoder.nfeat_dim, dropout=dropout).to(device)
        loss_fn = nn.BCEWithLogitsLoss()
        prompt_loss_fn = nn.MSELoss()

        optimizer = optim.Adam(decoder.parameters(), lr=lr)
        optimizer_prompter = optim.Adam(prompter.parameters(), lr=lr_prompter)

        val_aucs = []
        if use_valid:
            early_stopper = EarlyStopMonitor(max_round=patience)
        for epoch in range(n_epochs):
            start_epoch_t0 = time.time()
            logger.info('Start {} epoch'.format(epoch))

            m_loss = []
            m_node_class_loss = []
            m_prompt_loss = []

            it = BackgroundThreadGenerator(prompt_dl)
            it = tqdm.tqdm(it, total=len(prompt_dl), ncols=50)

            encoder.reset()
            decoder.train()
            prompter.train()
            for i_batch, (src_ids, dst_ids, neg_dst_ids, ts, eids, labels, comp_graph) in enumerate(it):
                bs = len(src_ids)
                src_ids = src_ids.long().to(device)
                dst_ids = dst_ids.long().to(device)
                neg_dst_ids = neg_dst_ids.long().to(device)
                ts = ts.float().to(device)
                eids = eids.long().to(device)
                labels = labels.float().to(device)
                comp_graph.to(device)
                with torch.no_grad():
                    _1, h, _2, _3, h_prev_left, h_prev_right = encoder.contrast_learning(src_ids, dst_ids, neg_dst_ids,
                                                                                         ts, eids, comp_graph)
                optimizer.zero_grad()
                optimizer_prompter.zero_grad()

                index = comp_graph.restart_data.index
                unique_nids = torch.cat([src_ids, dst_ids])[index]
                unique_ts = ts.repeat(2)[index]

                # todo: to delete
                index_neg = comp_graph.restart_data_neg.index
                unique_nids_neg = torch.cat([src_ids, neg_dst_ids])[index_neg]
                unique_ts_neg = ts.repeat(2)[index_neg]

                index = index.to(device)
                unique_nids = unique_nids.to(device)
                unique_ts = unique_ts.to(device)

                # todo: to delete
                index_neg = index_neg.to(device)
                unique_nids_neg = unique_nids_neg.to(device)
                unique_ts_neg = unique_ts_neg.to(device)

                pos_h = h[:bs]
                surrogate_h_prev_left, surrogate_h_prev_right, prompted_h = prompter(unique_nids, unique_ts, pos_h, src_ids)

                pred_y = decoder(prompted_h)  # only positive nodes
                node_class_loss = loss_fn(pred_y, labels)

                # if no_memory, cannot teaching prompt
                if not no_memory:
                    if teaching_prompt:
                        targets = torch.cat([h_prev_left[index], h_prev_right[index]], 0)  # [2n, d]
                        preds = torch.cat([surrogate_h_prev_left, surrogate_h_prev_right], 0)  # [2n, d]
                        valid_rows = torch.where(~(targets == 0).all(1))[0]

                        if len(valid_rows):
                            prompt_loss = prompt_loss_fn(preds[valid_rows], targets[valid_rows].detach())
                        else:
                            prompt_loss = torch.tensor(0, device=node_class_loss.device)
                    else:
                        prompt_loss = torch.tensor(0, device=node_class_loss.device)
                else:
                    prompt_loss = torch.tensor(0, device=node_class_loss.device)


                loss = node_class_loss + prompt_teaching_coef * prompt_loss
                loss.backward()

                optimizer.step()
                optimizer_prompter.step()

                m_node_class_loss.append(node_class_loss.item())
                m_prompt_loss.append(prompt_loss.item())
                m_loss.append(loss.item())

            epoch_time = time.time() - start_epoch_t0

            val_auc = eval_node_classification(encoder, decoder, prompter, val_dl, device)
            val_aucs.append(val_auc)

            logger.info('Epoch {:4d} training took  {:.2f}s'.format(epoch, epoch_time))
            logger.info(f'Epoch mean classification loss: {np.mean(m_node_class_loss):.4f}')
            logger.info(f'Epoch mean prompter loss: {np.mean(m_prompt_loss):.4f}')
            logger.info(f'Epoch mean loss: {np.mean(m_loss):.4f}')
            logger.info(f'Epoch validation auc: {val_auc:.4f}')

            if use_valid:
                if early_stopper.early_stop_check(val_auc):
                    logger.info('No improvement over {} epochs, stop training'.format(
                        early_stopper.max_round))
                    break
                else:
                    torch.save(decoder.state_dict(), get_checkpoint_path(epoch))
                    torch.save(prompter.state_dict(), get_prompter_path(epoch))

        if use_valid:
            logger.info(f'Loading the best model at epoch {early_stopper.best_epoch}')
            best_model_path = get_checkpoint_path(early_stopper.best_epoch)
            best_prompter_path = get_prompter_path(early_stopper.best_epoch)

            model_state = torch.load(best_model_path)
            decoder.load_state_dict(model_state)

            prompter_state = torch.load(best_prompter_path)
            prompter.load_state_dict(prompter_state)

            best_val_auc = val_aucs[early_stopper.best_epoch]
            logger.info(f'[ Val] auc: {best_val_auc:.4f}')

            test_auc = eval_node_classification(encoder, decoder, prompter, test_dl, device)
        else:
            logger.info('No validation set. Use the last epoch result.')
            test_auc = val_aucs[-1]

        logger.info(f'[Test] auc: {test_auc:.4f}')

        results = args.copy()
        results.update(
            prefix=prefix,
            VERSION=MODEL_VERSION,
            test_auc=test_auc,
        )
        json.dump(results, open(RESULT_SAVE_PATH, 'w'))

    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(e)
        raise


def get_args(argv=None):
    """Parse the command-line options for prompt fine-tuning.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with the parsed options. The checkpoint path is
        stored as ``ckpt`` (from ``--ckpt``).
    """
    parser = argparse.ArgumentParser()
    # Exp Setting
    parser.add_argument('--code', type=str, default='', help='Name of the saved result and model')
    parser.add_argument('--json', type=str, default='', help='Path to model result (json file)')
    parser.add_argument('--ckpt', type=str, default='', help='Path to model checkpoint')
    parser.add_argument('--root', type=str, default='.', help='Dataset root')
    parser.add_argument('--seed', type=int, default=0, help='Random seed')
    parser.add_argument('--gpu', type=str, default='0', help='Cuda index')

    parser.add_argument('--n_epochs', type=int, default=50, help='Number of epochs')
    parser.add_argument('--bs', type=int, default=100, help='Batch size')
    parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout probability')
    parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
    parser.add_argument('--use_valid', action='store_true', help='Use validation set')

    parser.add_argument('--force', action='store_true', help='Overwrite the existing task')

    # Prompting options
    parser.add_argument('--prompter_type', type=str, default="transformer",
                        choices=["transformer", "vanilla", "projection"], help='TProG type')
    parser.add_argument('--teaching_prompt', action='store_true', help='Teach the prompter during fine-tuning')
    parser.add_argument('--prompt_teaching_coef', type=float, default=1.0, help='Weight of the prompt teaching loss')
    parser.add_argument('--lr_prompter', type=float, default=0.0001, help='Learning rate for prompter')

    parser.add_argument('--prompt_end', type=float, default=0.7,
                        help='Quantile of the prompting end time; must be >= prompt_start')

    parser.add_argument('--prompt_dim', type=int, default=None, help='Dimension of prompts')

    return parser.parse_args(argv)


if __name__ == '__main__':
    args = get_args()

    # Locate the pretraining result (its saved hyper-parameters) and the
    # encoder checkpoint, either by run code or by explicit paths.
    if args.code:
        with open(f'./results/{args.code}.json') as fh:
            saved_results = json.load(fh)
        ckpt_path = f'./saved_models/{args.code}.pth'
    else:
        if not args.json or not args.ckpt:
            raise ValueError('Either --code, or both --json and --ckpt, must be given.')
        with open(args.json) as fh:
            saved_results = json.load(fh)
        # The option is declared as --ckpt, so the attribute is ``ckpt``.
        ckpt_path = args.ckpt

    if not CHECK_VERSION(saved_results['VERSION'], MODEL_VERSION):
        raise ValueError('version not match: {} != {}'.format(
            saved_results['VERSION'], MODEL_VERSION))

    prefix = saved_results['HASH'] if saved_results['prefix'] == '' else saved_results['prefix']
    prefix += '-fine_tune-node'

    # Settings chosen for this fine-tuning run; they take precedence over the
    # values stored in the pretraining result.
    explicit = dict(
        prefix=prefix, gpu=args.gpu, seed=args.seed,
        lr=args.lr, dropout=args.dropout, bs=args.bs, n_epochs=args.n_epochs,
        patience=args.patience, force=args.force,
        use_valid=args.use_valid, root=args.root, ckpt_path=ckpt_path,
        prompter_type=args.prompter_type,
        teaching_prompt=args.teaching_prompt, prompt_teaching_coef=args.prompt_teaching_coef,
        lr_prompter=args.lr_prompter, prompt_start=saved_results['prompt_start'], prompt_end=args.prompt_end,
        prompt_dim=args.prompt_dim,
    )
    # Everything else (data and model architecture options) is inherited from
    # pretraining. Keys passed explicitly are dropped so run() never receives
    # the same keyword twice.
    kwargs = {k: v for k, v in saved_results.items() if k not in explicit}

    run(**explicit, **kwargs)
