import numpy as np

import timeit

import os
import os.path as osp
from pathlib import Path
import logging
import torch

# internal imports
from tgb.utils.utils import set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from models.tgnw import WeightedGraphAttentionEmbedding, WeightedTGN
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator, MeanAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import WeightedTGNMemory
from modules.early_stopping import EarlyStopMonitorModular
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
from tgb.utils.utils import load_pkl
from seml.experiment import Experiment

# from seml.observers import add_to_file_storage_observer
from sacred.observers import FileStorageObserver
from utils.tgnw_linkpred import train_one_epoch, test

from permissive_dict import PermissiveDict as edict

# ==========
# ==========
# ==========

# NOTE(review): forces synchronous CUDA kernel launches (easier-to-read error
# traces, slower GPU execution); looks like a debugging leftover — confirm it
# should stay enabled for real runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Wall-clock reference point; run() reads it for the "Overall Elapsed Time" log.
start_overall = timeit.default_timer()

# Sacred/seml experiment. Each run is recorded under ./tgb_runs by a
# FileStorageObserver, which also keeps a copy of the source files.
ex = Experiment()
ex.observers.extend([FileStorageObserver("./tgb_runs", copy_sources=True)])


def _dump_train_val_edges(data, train_mask, val_mask, dataset_name):
    """Dump the train and validation events of ``data`` to a CSV file (debug aid).

    One row per event with columns ``src``, ``dst``, ``t`` and ``type``
    (``"train"`` or ``"val"``); train rows come first. The file is written to
    ``edge_lists/train_val_edges__tgnw-<dataset_name>.csv``; the ``edge_lists``
    directory must already exist.
    """
    import pandas as pd

    rows = []
    for split_mask, split_name in ((train_mask, "train"), (val_mask, "val")):
        edges = data[split_mask]
        # tolist() yields the same Python scalars as per-element .item() calls,
        # with one device->host transfer per column instead of one per event.
        rows.extend(
            (src, dst, t, split_name)
            for src, dst, t in zip(
                edges.src.tolist(), edges.dst.tolist(), edges.t.tolist()
            )
        )
    df = pd.DataFrame(rows, columns=["src", "dst", "t", "type"])
    df.to_csv(
        f"edge_lists/train_val_edges__tgnw-{dataset_name}.csv",
        index=None,
    )


def _build_attack_params(attack_params, data, train_mask, val_mask, dataset_name, bs):
    """Wrap ``attack_params`` in an ``edict`` and add the data the attack needs.

    ``train_time_sd`` is the standard deviation of consecutive timestamp
    differences over the training events. It is stored as
    ``"train_time_rel_sd"`` for ``negatt*`` / ``hist_init`` and for
    ``prbcd*`` / ``grbcd*`` attack types. For ``negatt*`` attacks (or when
    ``hist_init`` is truthy) the following are also added:

    * ``"adv_test_edges"``: ``{(src, dst, t): msg}`` over all train and
      validation events (validation entries win on duplicate keys);
    * ``"neg_set"``: the object loaded from
      ``artefacts/<dataset_name>_recent-neg_B<bs>.pkl``.

    Raises:
        FileNotFoundError: if the negative-edge set is required but missing.
    """
    params = edict(attack_params)
    attack_type = params.attack_type

    # reshape(-1) flattens scalar or 1-element per-event timestamps alike.
    train_times = data[train_mask].t.reshape(-1).cpu().numpy()
    train_time_sd = np.std(np.diff(train_times))

    if attack_type.startswith("negatt") or (
        "hist_init" in params and params.hist_init
    ):
        train_edges = {
            (s.src.item(), s.dst.item(), s.t.item()): s.msg
            for s in data[train_mask]
        }
        val_edges = {
            (s.src.item(), s.dst.item(), s.t.item()): s.msg
            for s in data[val_mask]
        }

        # load testNegEdgeSet
        filename = f"artefacts/{dataset_name}_recent-neg_B{bs}.pkl"
        logging.info(f"Loading negset from {filename}")
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not os.path.exists(filename):
            raise FileNotFoundError(f"Negative-edge set not found: {filename}")

        params["train_time_rel_sd"] = train_time_sd
        params["neg_set"] = load_pkl(filename)
        params["adv_test_edges"] = train_edges | val_edges
    if attack_type.startswith(("prbcd", "grbcd")):
        params["train_time_rel_sd"] = train_time_sd
    return params


@ex.automain
def run(
    dataset_name="tgbl-wiki",
    lr=1e-4,
    bs=200,
    num_epoch=50,
    seed=1,
    mem_dim=100,
    time_dim=100,
    emb_dim=100,
    tolerance=1e-6,
    num_runs=1,
    patience=5,
    num_neighbors=10,
    adv_attack=False,
    check_every=3,
    attack_params=None,
    memory_update_position: str = "front",
    detect_anomalies=False,
):
    """Train and evaluate a weighted TGN (TGNW) on a TGB link-prediction dataset.

    For each of ``num_runs`` runs the model is re-initialised and trained with
    periodic validation and early stopping. Training is skipped when the early
    stopper's best-model file already exists. The best checkpoint is then
    reloaded and evaluated on the test split, optionally under an adversarial
    attack. Per-run results are passed to ``save_results`` with the file
    ``saved_results/TGNW_<dataset_name>_results.json``.

    NOTE: ``@ex.automain`` runs this function as soon as it is decorated when
    the file is executed as ``__main__``, so helpers must be defined above it.

    Args:
        dataset_name: TGB dataset name (e.g. ``"tgbl-wiki"``).
        lr: Adam learning rate.
        bs: batch size for train/val/test. It is also part of the checkpoint id
            and of the negative-set file name used by attacks.
        num_epoch: maximum number of training epochs per run.
        seed: base random seed; run ``r`` is seeded with ``seed + r``.
        mem_dim: node-memory dimension.
        time_dim: time-encoding dimension.
        emb_dim: GNN output / link-predictor input dimension.
        tolerance: passed to ``EarlyStopMonitorModular``.
        num_runs: number of independent runs.
        patience: passed to ``EarlyStopMonitorModular``.
        num_neighbors: ``size`` of the ``LastNeighborLoader``.
        adv_attack: if True, build attack parameters and pass them to ``test``.
        check_every: validate (and check early stopping) every this many
            epochs. The final epoch is always validated.
        attack_params: attack configuration. It must contain ``attack_type``
            when ``adv_attack`` is True. ``None`` means an empty dict.
        memory_update_position: NOTE(review): not used anywhere in this
            function — confirm whether it should be forwarded to the model.
        detect_anomalies: forwarded to ``test`` for the test split.
    """
    if attack_params is None:
        attack_params = {}

    MODEL_NAME = "TGNW"
    # ==========

    # set the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # data loading
    dataset = PyGLinkPropPredDataset(name=dataset_name, root="datasets")
    train_mask = dataset.train_mask
    val_mask = dataset.val_mask
    test_mask = dataset.test_mask
    data = dataset.get_TemporalData().to(device)
    metric = dataset.eval_metric

    # neighborhood sampler
    neighbor_loader = LastNeighborLoader(
        data.num_nodes, size=num_neighbors, device=device
    )

    # define the model end-to-end
    msg_dim = data.msg.size(-1)
    memory = WeightedTGNMemory(
        data.num_nodes,
        msg_dim,
        mem_dim,
        time_dim,
        message_module=IdentityMessage(msg_dim, mem_dim, time_dim),
        aggregator_module=MeanAggregator(),
    ).to(device)

    gnn = WeightedGraphAttentionEmbedding(
        in_channels=mem_dim,
        out_channels=emb_dim,
        msg_dim=msg_dim,
        time_enc=memory.time_enc,  # time encoder shared with the memory module
    ).to(device)

    link_pred = LinkPredictor(in_channels=emb_dim).to(device)

    tgn_model = WeightedTGN(
        memory, gnn, link_pred, neighbor_loader=neighbor_loader, device=device
    )

    criterion = torch.nn.BCEWithLogitsLoss()

    # Helper vector to map global node indices to local ones.
    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

    logging.info("==========================================================")
    logging.info(
        f"=================*** {MODEL_NAME}: LinkPropPred: {dataset_name} ***============="
    )
    logging.info("==========================================================")

    evaluator = Evaluator(name=dataset_name)
    neg_sampler = dataset.negative_sampler

    # for saving the results...
    results_path = "saved_results"
    save_model_dir = "./tgb_runs/saved_models"
    if not osp.exists(results_path):
        Path(results_path).mkdir(parents=True, exist_ok=True)
        logging.info("INFO: Create directory {}".format(results_path))
    # parents=True: works even if ./tgb_runs has not been created yet.
    Path(save_model_dir).mkdir(parents=True, exist_ok=True)

    results_filename = f"{results_path}/{MODEL_NAME}_{dataset_name}_results.json"
    test_metric_list = []
    for run_idx in range(num_runs):
        logging.info(
            "-------------------------------------------------------------------------------"
        )
        logging.info(f"INFO: >>>>> Run: {run_idx} <<<<<")
        start_run = timeit.default_timer()

        # Seed BEFORE re-initialising the weights and building the optimizer so
        # that the initial parameters depend only on the seed (reproducible runs).
        run_seed = run_idx + seed
        torch.manual_seed(run_seed)
        torch.cuda.manual_seed_all(run_seed)
        torch.backends.cudnn.deterministic = True
        set_random_seed(run_seed)

        tgn_model.reset_parameters()
        # The time encoder is shared by memory and gnn, so parameters must be
        # de-duplicated. dict.fromkeys keeps first-seen order, whereas a set
        # would make the parameter order depend on object ids.
        unique_params = dict.fromkeys(
            [
                *tgn_model.memory.parameters(),
                *tgn_model.gnn.parameters(),
                *tgn_model.link_pred.parameters(),
            ]
        )
        optimizer = torch.optim.Adam(list(unique_params), lr=lr)

        # define an early stopper
        save_model_id = f"{MODEL_NAME}_{dataset_name}_S{seed}_R{run_idx}_BS{bs}"
        early_stopper = EarlyStopMonitorModular(
            save_model_dir=save_model_dir,
            save_model_id=save_model_id,
            tolerance=tolerance,
            patience=patience,
        )

        # ==================================================== Train & Validation
        # loading the validation negative samples
        dataset.load_val_ns()

        val_perf_list = []
        start_train_val = timeit.default_timer()
        train_val_time = 0

        # Debug switch: dump the train/val edges to CSV and stop without training.
        log_edges = False
        if log_edges:
            _dump_train_val_edges(
                data, train_mask, val_mask, neg_sampler.dataset_name
            )
            return

        if not osp.exists(early_stopper.get_best_model_path()):
            # model doesn't exist, train it
            for epoch in range(1, num_epoch + 1):
                # training
                start_epoch_train = timeit.default_timer()
                loss = train_one_epoch(
                    tgn_model, data, train_mask, bs, optimizer, criterion, assoc, device
                )
                logging.info(
                    f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
                )

                # Also validate the final epoch. Otherwise the trailing epochs
                # are never evaluated, and no checkpoint is ever considered when
                # num_epoch < check_every.
                if epoch % check_every == 0 or epoch == num_epoch:
                    # validation
                    start_val = timeit.default_timer()
                    perf_metric_val = test(
                        tgn_model,
                        data,
                        val_mask,
                        bs,
                        neg_sampler,
                        "val",
                        assoc,
                        metric,
                        evaluator,
                        device,
                    )
                    logging.info(f"\tValidation {metric}: {perf_metric_val: .4f}")
                    logging.info(
                        f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}"
                    )
                    val_perf_list.append(perf_metric_val)

                    # check for early stopping
                    if early_stopper.step_check(
                        perf_metric_val, tgn_model, neighbor_loader
                    ):
                        break

            train_val_time = timeit.default_timer() - start_train_val
            logging.info(f"Train & Validation: Elapsed Time (s): {train_val_time: .4f}")
        else:
            # Checkpoint already present: training is skipped.
            tgn_model.memory.reset_parameters()

        # ==================================================== Test
        # first, load the best model
        # NOTE(review): this rebinds the local `neighbor_loader`, which later
        # runs then pass to `early_stopper.step_check`; confirm that is intended.
        neighbor_loader = early_stopper.load_checkpoint(tgn_model)

        # loading the test negative samples
        dataset.load_test_ns()

        test_attack_params = attack_params
        if adv_attack:
            test_attack_params = _build_attack_params(
                attack_params, data, train_mask, val_mask, dataset_name, bs
            )

        start_test = timeit.default_timer()
        perf_metric_test = test(
            tgn_model,
            data,
            test_mask,
            bs,
            neg_sampler,
            "test",
            assoc,
            metric,
            evaluator,
            device,
            attack_params=test_attack_params,
            detect_anomalies=detect_anomalies,
        )

        logging.info("INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
        logging.info(f"\tTest: {metric}: {perf_metric_test: .4f}")
        test_time = timeit.default_timer() - start_test
        logging.info(f"\tTest: Elapsed Time (s): {test_time: .4f}")
        test_metric_list.append(perf_metric_test)
        save_results(
            {
                "model": MODEL_NAME,
                "data": dataset_name,
                "run": run_idx,
                "seed": seed,
                f"val {metric}": val_perf_list,
                f"test {metric}": perf_metric_test,
                "test_time": test_time,
                "tot_train_val_time": train_val_time,
            },
            results_filename,
        )

        logging.info(
            f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<"
        )
        logging.info(
            "-------------------------------------------------------------------------------"
        )

    # Only register artifacts if at least one run completed; with num_runs == 0
    # neither the results file nor `early_stopper` would exist.
    if test_metric_list:
        ex.add_artifact(results_filename)
        # Only the last run's best checkpoint is archived.
        ex.add_artifact(early_stopper.get_best_model_path())

    logging.info(
        f"Aggregate Test {metric}: {np.mean(test_metric_list):.4f} ± {np.std(test_metric_list):.4f}"
    )
    logging.info(
        f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}"
    )
    logging.info("==============================================================")
