
import os
import copy
import random
import numpy as np
import pandas as pd

from tqdm import tqdm

from train import DATASET_DIR
from utils.data.load_graph import load_graph

import torch
import torch.nn as nn
from torch_geometric.nn import GAT
from torch_geometric.loader import NeighborLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn.utils import clip_grad_norm_


class DGIDiscriminator(nn.Module):
    """DGI discriminator scoring node embeddings against a graph summary.

    For node embeddings ``z`` and a summary ``s`` it computes
    ``((z @ W) * (s @ W)).sum(-1)``, i.e. one score per row of ``z``.
    No sigmoid is applied: the scores are logits, intended to be fed to
    ``nn.BCEWithLogitsLoss``.

    Args:
        dim: Size of the feature axis of ``z`` and ``s``; ``W`` is ``dim x dim``
            and Xavier-uniform initialised.
    """

    def __init__(self, dim):
        super().__init__()
        weight = torch.empty(dim, dim)
        nn.init.xavier_uniform_(weight)
        self.W = nn.Parameter(weight)

    def forward(self, z, s):
        """Score every row of ``z`` against the summary ``s``.

        Args:
            z: Node embeddings, expected shape ``[N, dim]``.
            s: Summary vector, expected shape ``[dim]``; it is broadcast over
                the rows of ``z``.

        Returns:
            Tensor of shape ``[N]`` holding unnormalised scores (logits).
        """
        projected_nodes = torch.matmul(z, self.W)
        projected_summary = torch.matmul(s, self.W)
        return torch.sum(projected_nodes * projected_summary, dim=-1)


def set_seed(seed: int):
    """Seed every random number generator used by this script.

    Seeds, in order: Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG
    and all CUDA device RNGs. ``PYTHONHASHSEED`` is exported as well; note that
    it only influences Python processes started afterwards, since the running
    interpreter fixes its hash seed at startup.

    Args:
        seed: The integer seed applied to every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


def train(config: dict, verbose: bool = True):
    """Train a GAT encoder with the Deep Graph Infomax (DGI) objective.

    Mini-batches are drawn with ``NeighborLoader`` from one of several training
    graphs, which are rotated every ``freq_rotation_graph`` epochs. Validation
    runs on a capped number of batches from the validation graph. The encoder
    state with the lowest validation loss is kept. When ``config["save"]`` is
    true, that state is written to ``<save_folder>/model.pt`` and the per-epoch
    metrics to ``<save_folder>/metrics.csv``.

    Args:
        config: Nested configuration dict. Keys read here:
            ``validation.seed``;
            ``training.{use_edge_features, num_neighbors_loader, batch_size,
            num_epochs, freq_rotation_graph}`` (``freq_rotation_graph == 0``
            disables rotation);
            optional ``training.num_train_graphs`` (default 3),
            ``training.max_grad_norm`` (default 1.0) and
            ``training.max_valid_batches`` (default 31);
            ``model.params`` (forwarded to ``GAT``);
            ``optimizer.params`` (forwarded to ``Adam``);
            ``scheduler.params`` (forwarded to ``ReduceLROnPlateau``);
            ``save`` and ``save_folder``.
        verbose: If true, print the metrics of every epoch.

    Raises:
        ValueError: If the training graph yields no mini-batch, so the latent
            size cannot be inferred.

    Training can be stopped early with Ctrl-C; the results collected so far are
    still saved.
    """
    training_cfg = config["training"]
    use_edge_features = bool(training_cfg["use_edge_features"])
    freq_rotation = training_cfg["freq_rotation_graph"]
    # Optional knobs; the defaults reproduce the previously hard-coded values.
    num_train_graphs = training_cfg.get("num_train_graphs", 3)
    max_grad_norm = training_cfg.get("max_grad_norm", 1.0)
    max_valid_batches = training_cfg.get("max_valid_batches", 31)

    seed = config["validation"]["seed"]
    set_seed(seed)
    gen_train = torch.Generator().manual_seed(seed + 1)
    gen_valid = torch.Generator().manual_seed(seed + 2)
    gen_noise = torch.Generator().manual_seed(seed + 3)
    # Validation corruptions get their own seed, re-applied every epoch. They
    # therefore neither consume the training noise stream nor vary between
    # epochs, which keeps the validation loss comparable for the scheduler and
    # for model selection.
    valid_noise_seed = seed + 4

    def load(tp, **kwargs):
        """Load one dataset split with the options shared by every call."""
        return load_graph(folder=DATASET_DIR, tp=tp, undirected=True,
                          use_edge_features=use_edge_features, **kwargs)

    def make_loader(graph, shuffle, generator):
        """Build a NeighborLoader over ``graph`` with the configured sampling."""
        return NeighborLoader(
            graph,
            num_neighbors=training_cfg["num_neighbors_loader"],
            batch_size=training_cfg["batch_size"],
            shuffle=shuffle, generator=generator)

    train_index = 0
    train_graph = load("train", index=train_index)
    valid_graph = load("valid")

    # encoder
    model = GAT(in_channels=train_graph.x.shape[1],
                edge_dim=train_graph.edge_attr.size(-1) if use_edge_features else None,
                **config["model"]["params"])

    def encode(x, batch):
        """Run the encoder on node features ``x`` over the edges of ``batch``."""
        if use_edge_features:
            return model(x, batch.edge_index, edge_attr=batch.edge_attr)
        return model(x, batch.edge_index)

    # Probe one mini-batch to discover the embedding size the discriminator
    # must match.
    model.eval()
    probe = next(iter(make_loader(train_graph, shuffle=False, generator=gen_train)), None)
    if probe is None:
        raise ValueError("the training graph yielded no mini-batch; cannot infer the latent size")
    with torch.no_grad():
        d_latent = encode(probe.x, probe).size(-1)
    discriminator = DGIDiscriminator(d_latent)

    params = list(model.parameters()) + list(discriminator.parameters())
    optimizer = torch.optim.Adam(params, **config["optimizer"]["params"])
    scheduler = ReduceLROnPlateau(optimizer, mode="min", **config["scheduler"]["params"])

    metrics: dict[int, dict] = {}
    best_valid_loss = float("inf")
    best_model_state = copy.deepcopy(model.state_dict())
    os.makedirs(config["save_folder"], exist_ok=True)

    bce = nn.BCEWithLogitsLoss()

    def dgi_batch_loss(batch, noise_generator):
        """DGI loss of one mini-batch: real vs. feature-shuffled embeddings."""
        # NeighborLoader places the seed nodes first. The remaining nodes are
        # sampled neighbours with a truncated receptive field, so only
        # seed-node embeddings enter the summary and the loss.
        n_seed = getattr(batch, "batch_size", batch.x.size(0))
        z = encode(batch.x, batch)[:n_seed]  # [B, d]
        s = torch.sigmoid(z.mean(dim=0))  # [d]
        # Corruption: permute node features across the whole sampled subgraph.
        perm = torch.randperm(batch.x.size(0), generator=noise_generator)
        z_tilde = encode(batch.x[perm], batch)[:n_seed]  # [B, d]
        pos_scores = discriminator(z, s)
        neg_scores = discriminator(z_tilde, s)
        logits = torch.cat([pos_scores, neg_scores], dim=0)
        # *_like keeps the labels on the logits' device and dtype.
        labels = torch.cat([torch.ones_like(pos_scores), torch.zeros_like(neg_scores)], dim=0)
        return bce(logits, labels)

    try:
        for epoch in tqdm(range(training_cfg["num_epochs"])):

            # rotate the training graph (0 disables rotation)
            if freq_rotation and (epoch + 1) % freq_rotation == 0:
                train_index += 1
                train_graph = load("train", index=train_index % num_train_graphs)

            # back-propagation of the training loss
            model.train()
            discriminator.train()
            train_total, n_train = 0.0, 0
            for batch in make_loader(train_graph, shuffle=True, generator=gen_train):
                optimizer.zero_grad()
                loss = dgi_batch_loss(batch, gen_noise)
                loss.backward()
                clip_grad_norm_(params, max_norm=max_grad_norm)
                optimizer.step()
                train_total += loss.item()
                n_train += 1
            train_loss = train_total / max(1, n_train)

            # validation loss on at most `max_valid_batches` batches
            model.eval()
            discriminator.eval()
            valid_noise = torch.Generator().manual_seed(valid_noise_seed)
            valid_total, n_valid = 0.0, 0
            valid_loader = make_loader(valid_graph, shuffle=False, generator=gen_valid)
            with torch.no_grad():
                for i, batch in enumerate(valid_loader):
                    if i >= max_valid_batches:
                        break
                    valid_total += dgi_batch_loss(batch, valid_noise).item()
                    n_valid += 1
            valid_loss = valid_total / max(1, n_valid)
            scheduler.step(valid_loss)

            # keep the best encoder state
            if config["save"] and valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                best_model_state = copy.deepcopy(model.state_dict())

            epoch_metrics = {"train_loss": train_loss, "valid_loss": valid_loss,
                             "current_lr": optimizer.param_groups[0]["lr"]}
            metrics[epoch] = epoch_metrics

            if verbose:
                parts = [f"{k}={v:.6f}" if isinstance(v, float) else f"{k}={v}"
                         for k, v in epoch_metrics.items()]
                print(f"Epoch {epoch}: " + " | ".join(parts))

    except KeyboardInterrupt:
        # Deliberate: Ctrl-C ends training early; collected results are still saved.
        if verbose:
            print("Training interrupted; saving the results collected so far.")

    if config["save"]:
        pd.DataFrame(metrics).T.to_csv(os.path.join(config["save_folder"], "metrics.csv"))
        torch.save(best_model_state, os.path.join(config["save_folder"], "model.pt"))
