
import os
import sys
import json
import yaml
import copy
import pandas as pd

from utils.data.load_graph import load_graph, get_clusters_from_graph
from utils.data.positional_encoding import landmark_spd_features
from utils.training.steps import train_step, valid_step
from utils.training.sampling_functions import get_validation_pairs
from utils.training.loss import ContrastiveLoss

from evaluate import evaluate

import torch
from torch_geometric.nn import GAT, GCN, GraphSAGE
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Directory containing this file; anchor for the project-relative paths below.
ROOT_DIR = os.path.dirname(__file__)
# Make the project root importable. This statement runs after the project
# imports at the top of the file, so it only affects imports performed later.
sys.path.append(ROOT_DIR)
# Folder handed to load_graph as the location of the train/valid graphs.
DATASET_DIR = os.path.join(ROOT_DIR, "dataset")


def _add_positional_encoding(config: dict, graph):
    """Return ``graph`` with landmark features added when the config enables them.

    ``config["positional_encoding"]["num_landmarks"]`` is only read when
    ``config["positional_encoding"]["use"]`` is truthy; otherwise the graph is
    returned untouched.
    """
    encoding_cfg = config["positional_encoding"]
    if not encoding_cfg["use"]:
        return graph
    return landmark_spd_features(data=graph, num_landmarks=encoding_cfg["num_landmarks"])


def _load_training_graph(config: dict, index: int):
    """Load the training graph number ``index`` together with its clusters.

    The clusters are extracted before the optional positional encoding is
    applied, i.e. from the graph exactly as returned by ``load_graph``.

    Returns:
        A ``(graph, clusters)`` tuple, where ``clusters`` is the second element
        returned by ``get_clusters_from_graph``.
    """
    graph = load_graph(folder=DATASET_DIR, tp="train", index=index,
                       undirected=config["training"]["undirected"],
                       use_edge_features=config["training"]["use_edge_features"])
    _, clusters = get_clusters_from_graph(graph=graph)
    return _add_positional_encoding(config, graph), clusters


def _build_model(config: dict, graph):
    """Instantiate the network selected by ``config["model"]["name"]``.

    The input width is taken from ``graph.x``; for GAT with edge features the
    edge dimension is taken from ``graph.edge_attr``.

    Raises:
        ValueError: if the model name is not one of "GAT", "GCN", "GraphSAGE",
            or if edge features are requested for GCN / GraphSAGE.
    """
    model_name = config["model"]["name"]
    use_edge_features = config["training"]["use_edge_features"]

    if model_name == "GAT":
        return GAT(in_channels=graph.x.shape[1],
                   edge_dim=graph.edge_attr.size(-1) if use_edge_features else None,
                   **config["model"]["params"])

    if model_name == "GCN":
        model_cls = GCN
    elif model_name == "GraphSAGE":
        model_cls = GraphSAGE
    else:
        raise ValueError(f"Unknown model name {model_name!r}; expected 'GAT', 'GCN' or 'GraphSAGE'.")

    # Explicit raise instead of `assert`: asserts vanish under `python -O`.
    if use_edge_features:
        raise ValueError(f"Model {model_name!r} cannot be used with "
                         "config['training']['use_edge_features'] enabled; use 'GAT' instead.")
    return model_cls(in_channels=graph.x.shape[1], **config["model"]["params"])


def train(config: dict, verbose: bool = True, num_train_graphs: int = 3):
    """Train a GNN with the contrastive loss and optionally save the best weights.

    Config entries read here: ``training`` (``undirected``, ``use_edge_features``,
    ``num_epochs``, ``freq_rotation_graph``; the whole section is also forwarded
    as keyword arguments to ``get_validation_pairs`` and ``train_step``),
    ``validation`` (forwarded to ``get_validation_pairs``),
    ``positional_encoding`` (``use``, ``num_landmarks``), ``model`` (``name``,
    ``params``), ``optimizer.params``, ``scheduler.params``, ``save`` and
    ``save_folder``.

    Side effects: ``config["save_folder"]`` is always created. When
    ``config["save"]`` is truthy, ``metrics.csv`` (one row per epoch) and
    ``model.pt`` (state dict with the lowest validation loss) are written there.
    A ``KeyboardInterrupt`` during the epoch loop stops training early but the
    results gathered so far are still saved.

    Args:
        config: nested configuration dictionary (see above).
        verbose: print the metrics of every epoch when True.
        num_train_graphs: number of training graphs to cycle through; graph
            indices are taken modulo this value. Defaults to 3.

    Raises:
        ValueError: if ``num_train_graphs`` is smaller than 1 or the model
            configuration is invalid.
    """
    if num_train_graphs < 1:
        raise ValueError(f"num_train_graphs must be at least 1, got {num_train_graphs}")

    # graph for the training
    train_index = 0
    train_graph, train_clusters = _load_training_graph(config, index=train_index)

    # samples for validation
    valid_graph = load_graph(folder=DATASET_DIR, tp="valid", undirected=config["training"]["undirected"],
                             use_edge_features=config["training"]["use_edge_features"])
    valid_graph = _add_positional_encoding(config, valid_graph)
    valid_batch, valid_pos_pairs, valid_neg_pairs = get_validation_pairs(graph=valid_graph, **config["training"],
                                                                         **config["validation"])

    # initiate the model, optimizer and scheduler
    model = _build_model(config, train_graph)
    optimizer = torch.optim.Adam(model.parameters(), **config["optimizer"]["params"])
    scheduler = ReduceLROnPlateau(optimizer, mode="min", **config["scheduler"]["params"])

    metrics: dict[int, dict] = {}
    best_valid_loss = float("inf")
    # Start from the untrained weights so there is always a state to save.
    best_model_state = copy.deepcopy(model.state_dict())
    os.makedirs(config["save_folder"], exist_ok=True)

    contrastive_loss_fn = ContrastiveLoss(temperature=0.07)

    try:

        # iterate over the epochs
        for epoch in range(config["training"]["num_epochs"]):

            # rotate the training graph
            # NOTE(review): the test uses (epoch + 1), so the switch happens at the
            # start of the epoch whose 1-based number is a multiple of the frequency;
            # the initially loaded graph is used for one epoch fewer - confirm intent.
            if (epoch + 1) % config["training"]["freq_rotation_graph"] == 0:
                train_index += 1
                train_graph, train_clusters = _load_training_graph(config, index=train_index % num_train_graphs)

            # back-propagation of the training loss
            train_loss = train_step(training_graph=train_graph, model=model, optimizer=optimizer,
                                    cluster_to_nodes=train_clusters, **config["training"],
                                    contrastive_loss_fn=contrastive_loss_fn
                                    )

            # compute the valid metrics
            valid_dict = valid_step(valid_batch=valid_batch, model=model,
                                    valid_pos_pairs=valid_pos_pairs, valid_neg_pairs=valid_neg_pairs,
                                    contrastive_loss_fn=contrastive_loss_fn
                                    )
            # the scheduler watches the validation loss (mode="min")
            scheduler.step(valid_dict["loss"])

            # store the best model (only tracked when the run is going to be saved)
            if (valid_dict["loss"] < best_valid_loss) and config["save"]:
                best_valid_loss = valid_dict["loss"]
                best_model_state = copy.deepcopy(model.state_dict())

            epoch_metrics = ({"train_loss": train_loss.item()} | valid_dict
                             | {"current_lr": optimizer.param_groups[0]["lr"]})
            metrics[epoch] = epoch_metrics

            if verbose:
                parts = [f"{k}={v:.6f}" if isinstance(v, float) else f"{k}={v}"
                         for k, v in epoch_metrics.items()]
                print(f"Epoch {epoch}: " + " | ".join(parts))

    except KeyboardInterrupt:
        # Ctrl-C is a deliberate "stop early": fall through and save what we have.
        pass

    if config["save"]:
        # metrics is keyed by epoch, so transpose to get one row per epoch
        pd.DataFrame(metrics).T.to_csv(os.path.join(config["save_folder"], "metrics.csv"))
        torch.save(best_model_state, os.path.join(config["save_folder"], "model.pt"))


if __name__ == "__main__":
    # Script entry point: train with ./config.yaml, evaluate, then store the
    # evaluation results as results.json inside the configured save folder.

    # load the config file (looked up relative to the current working directory)
    if not os.path.exists("config.yaml"):
        raise FileNotFoundError("config.yaml not found, please create one following 'example_config.yaml'")
    # Context manager so the file handle is closed once the config is parsed.
    # NOTE(review): yaml.FullLoader is intended for trusted files; prefer
    # yaml.safe_load if the config can ever come from an untrusted source.
    with open("config.yaml", "r") as config_file:
        conf = yaml.load(config_file, Loader=yaml.FullLoader)
    # train
    train(conf)
    # evaluate
    res = evaluate(conf)
    # save the results
    with open(os.path.join(conf["save_folder"], "results.json"), "w") as f:
        json.dump(res, f, indent=4)

