
import os
import copy
import random
import numpy as np
import pandas as pd

from tqdm import tqdm

from train import DATASET_DIR
from utils.data.load_graph import load_graph

import torch
import torch.nn.functional as F
from torch_geometric.nn import GAT
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import negative_sampling
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn.utils import clip_grad_norm_


def decode_dot(z, edge_index_uv):
    """Score edges by the dot product of their endpoint embeddings.

    Rows of ``z`` are gathered with ``edge_index_uv[0]`` and
    ``edge_index_uv[1]``, multiplied element-wise and summed over the last
    dimension, giving one score per indexed pair.
    """
    head_emb = z[edge_index_uv[0]]
    tail_emb = z[edge_index_uv[1]]
    return torch.mul(head_emb, tail_emb).sum(dim=-1)


def set_global_seed(seed: int):
    """Seed the ``random``, NumPy and PyTorch (CPU and CUDA) global RNGs with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# Fallbacks for optional ``config["training"]`` keys; they reproduce the values
# that used to be hard-coded in the training loop.
_DEFAULT_NUM_TRAIN_GRAPHS = 3
_DEFAULT_MAX_VALID_BATCHES = 31


def _load_train_graph(index):
    """Load the training graph stored under ``index`` (undirected, no edge features)."""
    return load_graph(folder=DATASET_DIR, tp="train", index=index, undirected=True,
                      use_edge_features=False)


def _make_loader(graph, training_cfg, generator, shuffle):
    """Build a ``NeighborLoader`` over ``graph`` using the configured fan-out and batch size."""
    return NeighborLoader(
        graph,
        num_neighbors=training_cfg["num_neighbors_loader"],
        batch_size=training_cfg["batch_size"],
        shuffle=shuffle, generator=generator,
    )


def _link_prediction_loss(model, batch):
    """Compute the link-prediction loss of ``model`` on one sampled sub-graph.

    The non-self-loop edges of ``batch.edge_index`` are the positive examples;
    the same number of negatives is requested from ``negative_sampling``.
    Edges are scored with ``decode_dot`` and compared to 1/0 labels with
    binary cross-entropy on logits.

    Returns:
        The scalar loss tensor, or ``None`` when the batch has no positive
        edge (in which case the model is not evaluated at all).
    """
    edge_index = batch.edge_index
    # Self-loops are not used as positive examples.
    pos_edge_index = edge_index[:, edge_index[0] != edge_index[1]]
    num_pos = pos_edge_index.size(1)
    if num_pos == 0:
        return None

    z = model(batch.x, edge_index)
    neg_edge_index = negative_sampling(pos_edge_index, num_nodes=batch.x.size(0),
                                       num_neg_samples=num_pos, method="sparse")

    pos_logits = decode_dot(z, pos_edge_index)
    neg_logits = decode_dot(z, neg_edge_index)
    logits = torch.cat([pos_logits, neg_logits], dim=0)
    # Build the labels from the logits so they always share device and dtype,
    # and so their length follows the number of negatives actually returned.
    labels = torch.cat([torch.ones_like(pos_logits), torch.zeros_like(neg_logits)], dim=0)
    return F.binary_cross_entropy_with_logits(logits, labels)


def train(config: dict, verbose: bool = True):
    """Train a GAT link-prediction model and optionally save the best weights.

    Each epoch runs one pass over a ``NeighborLoader`` of the current training
    graph, then evaluates on a capped number of validation batches. The
    validation loss drives a ``ReduceLROnPlateau`` scheduler and, when
    ``config["save"]`` is truthy, the selection of the best model state.
    A ``KeyboardInterrupt`` stops training early but still reaches the saving
    step.

    Args:
        config: Nested configuration. Keys read here:
            ``validation.seed``; ``model.params``; ``optimizer.params``;
            ``scheduler.params``; ``save``; ``save_folder``;
            ``training.num_epochs``, ``training.freq_rotation_graph``,
            ``training.num_neighbors_loader``, ``training.batch_size``,
            ``training.use_edge_features``; and the optional
            ``training.num_train_graphs`` (default 3, number of training
            graphs to rotate through) and ``training.max_valid_batches``
            (default 31, validation batches evaluated per epoch).
        verbose: When true, print the metrics of every epoch.

    Returns:
        None. With ``config["save"]`` truthy, ``metrics.csv`` and ``model.pt``
        are written to ``config["save_folder"]``.
    """
    training_cfg = config["training"]
    num_train_graphs = training_cfg.get("num_train_graphs", _DEFAULT_NUM_TRAIN_GRAPHS)
    max_valid_batches = training_cfg.get("max_valid_batches", _DEFAULT_MAX_VALID_BATCHES)
    freq_rotation = training_cfg["freq_rotation_graph"]

    seed = int(config["validation"]["seed"])
    set_global_seed(seed)
    gen = torch.Generator()
    gen.manual_seed(seed)

    # graph for the training
    train_index = 0
    train_graph = _load_train_graph(train_index)

    # samples for validation
    valid_graph = load_graph(folder=DATASET_DIR, tp="valid", undirected=True,
                             use_edge_features=False)

    # NOTE(review): the graphs are always loaded with use_edge_features=False and
    # the model is never given edge attributes, yet edge_dim follows the config
    # flag. Enabling the flag looks unsupported here -- confirm against load_graph.
    model = GAT(in_channels=train_graph.x.shape[1],
                edge_dim=train_graph.edge_attr.size(-1) if training_cfg["use_edge_features"] else None,
                **config["model"]["params"])
    optimizer = torch.optim.Adam(model.parameters(), **config["optimizer"]["params"])
    scheduler = ReduceLROnPlateau(optimizer, mode="min", **config["scheduler"]["params"])

    metrics: dict[int, dict] = {}
    best_valid_loss = float("inf")
    best_model_state = copy.deepcopy(model.state_dict())
    os.makedirs(config["save_folder"], exist_ok=True)

    # The loaders only depend on their graph, so they are built once and the
    # training one is rebuilt only when the training graph rotates.
    train_loader = _make_loader(train_graph, training_cfg, gen, shuffle=True)
    valid_loader = _make_loader(valid_graph, training_cfg, gen, shuffle=False)

    try:

        # iterate over the epochs
        for epoch in tqdm(range(training_cfg["num_epochs"])):

            # Rotate the training graph after every `freq_rotation` completed
            # epochs, so the initial graph gets a full share of epochs too.
            if epoch > 0 and epoch % freq_rotation == 0:
                train_index = (train_index + 1) % num_train_graphs
                train_graph = _load_train_graph(train_index)
                train_loader = _make_loader(train_graph, training_cfg, gen, shuffle=True)

            # back-propagation of the training loss
            model.train()
            train_loss_sum, n_train_batches = 0.0, 0
            for batch in train_loader:
                loss = _link_prediction_loss(model, batch)
                if loss is None:
                    continue

                optimizer.zero_grad()
                loss.backward()
                clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                train_loss_sum += loss.item()
                n_train_batches += 1

            train_loss = train_loss_sum / max(1, n_train_batches)

            # compute the valid loss on at most `max_valid_batches` batches
            model.eval()
            valid_loss_sum, n_valid_batches = 0.0, 0
            with torch.no_grad():
                for i, batch in enumerate(valid_loader):
                    if i >= max_valid_batches:
                        break

                    batch_loss = _link_prediction_loss(model, batch)
                    if batch_loss is None:
                        continue

                    valid_loss_sum += batch_loss.item()
                    n_valid_batches += 1

            valid_loss = valid_loss_sum / max(1, n_valid_batches)

            # Without any usable validation batch the reported 0.0 is a
            # placeholder: it must neither drive the scheduler nor be
            # recorded as the best model.
            if n_valid_batches > 0:
                scheduler.step(valid_loss)

                # store the best model
                if config["save"] and valid_loss < best_valid_loss:
                    best_valid_loss = valid_loss
                    best_model_state = copy.deepcopy(model.state_dict())

            epoch_metrics = {"train_loss": train_loss, "valid_loss": valid_loss,
                             "current_lr": optimizer.param_groups[0]["lr"]}
            metrics[epoch] = epoch_metrics

            if verbose:
                parts = [f"{k}={v:.6f}" if isinstance(v, float) else f"{k}={v}"
                         for k, v in epoch_metrics.items()]
                print(f"Epoch {epoch}: " + " | ".join(parts))

    except KeyboardInterrupt:
        # Deliberate best-effort: stop training but still reach the saving step.
        if verbose:
            print("Training interrupted by user.")

    if config["save"]:
        save_folder = config["save_folder"]
        pd.DataFrame(metrics).T.to_csv(os.path.join(save_folder, "metrics.csv"))
        torch.save(best_model_state, os.path.join(save_folder, "model.pt"))
