import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


import hydra
import logging
import numpy as np
from copy import deepcopy
from omegaconf import DictConfig, OmegaConf
from utils import load_data, fix_iso_v, ho_topology_score , sinkhorn , score
from utils import visualize , draw_HeatMap , append_number , ho_topology_scores_my, hyperedge_divergence,visualize_rho_vs_divergence
from utils import load_data, fix_iso_v, ho_topology_score , sinkhorn , score, pot

from my_model import MyGCN, MyHGNN, MyMLPs, MyHGNNL3, MyMLPsL3

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import time
import tracemalloc  # stdlib, good for Python memory usage

from dhg.nn import MLP
from dhg import Hypergraph, Graph
from dhg.random import set_seed
from dhg.utils import split_by_num
from dhg.models import HGNNP, HGNN, HNHN, UniGCN, UniGAT, GCN, GAT
from dhg.metrics import HypergraphVertexClassificationEvaluator as Evaluatorbase
import dhg
from utils import MultiExpMetric  , build_subhypergraph
import threading
import psutil

def train(net, X, G, lbls, train_mask, optimizer):
    """Run one full-batch optimisation step of ``net`` on structure ``G``.

    Only the rows selected by ``train_mask`` contribute to the negative
    log-likelihood loss.

    Returns:
        The loss of this step as a Python float.
    """
    net.train()
    optimizer.zero_grad()

    logits = net(X, G)
    log_probs = F.log_softmax(logits[train_mask], dim=1)
    targets = lbls[train_mask]
    step_loss = F.nll_loss(log_probs, targets)

    step_loss.backward()
    optimizer.step()
    return step_loss.item()

@torch.no_grad()
def valid(net, X, G, lbls, mask, evaluator):
    """Score ``net`` on the rows picked by ``mask`` using ``evaluator.validate``.

    The network is switched to eval mode and run without gradient tracking.
    """
    net.eval()
    predictions = net(X, G)
    return evaluator.validate(lbls[mask], predictions[mask])


@torch.no_grad()
def test(net, X, G, lbls, mask, evaluator, ft_noise_level=0):
    """Evaluate ``net`` on the rows picked by ``mask`` using ``evaluator.test``.

    When ``ft_noise_level`` is positive the features are blended with
    standard-normal noise as ``(1 - level) * X + level * noise`` before the
    forward pass.
    """
    net.eval()
    features = X
    if ft_noise_level > 0:
        noise = torch.randn_like(X)
        features = (1 - ft_noise_level) * X + ft_noise_level * noise
    predictions = net(features, G)
    return evaluator.test(lbls[mask], predictions[mask])

class HighOrderConstraint(nn.Module):
    """Hyperedge-level distillation loss gated by a noise-stability score.

    At construction time the (teacher) ``model`` is evaluated once on the
    clean features and once on multiplicatively perturbed features. The
    absolute change in prediction entropy, normalised by its maximum and
    subtracted from one, is stored per hyperedge (``delta_e_``) and per
    vertex (``delta_x_``). These scores lie in ``[0, 1]`` and are used as
    Bernoulli keep-probabilities when sampling masks.

    Args:
        model: module called as ``model(X, G)`` and returning per-vertex logits.
        X: vertex feature tensor.
        G: structure object exposing ``v2e(x, aggr="mean")``.
        noise_level: scale applied to the perturbed features.
        tau: divisor applied to the hyperedge predictions inside the loss.
    """

    def __init__(self, model, X, G, noise_level=1.0, tau=1.0):
        super().__init__()
        model.eval()
        self.tau = tau

        entropy_x = self._prediction_entropy(model, X, G)
        entropy_e = G.v2e(entropy_x, aggr="mean")

        X_noise = X.clone() * (torch.randn_like(X) + 1) * noise_level
        entropy_x_ = self._prediction_entropy(model, X_noise, G)
        entropy_e_ = G.v2e(entropy_x_, aggr="mean")

        self.delta_e_ = self._stability(entropy_e_, entropy_e)
        self.delta_x_ = self._stability(entropy_x_, entropy_x)

    @staticmethod
    def _prediction_entropy(model, X, G):
        """Return the per-row entropy (column vector) of ``softmax(model(X, G))``."""
        pred = model(X, G).softmax(dim=-1).detach()
        entropy = -(pred * pred.log()).sum(1, keepdim=True)
        # 0 * log(0) yields NaN; such rows are treated as zero entropy.
        entropy[entropy.isnan()] = 0
        return entropy

    @staticmethod
    def _stability(noisy, clean, eps=1e-12):
        """Map an entropy change to a keep-probability in ``[0, 1]``.

        The maximum is clamped away from zero so that an all-zero change
        (noise had no effect) yields probability 1 everywhere instead of
        0/0 = NaN, which ``torch.bernoulli`` cannot sample from.
        """
        delta = (noisy - clean).abs()
        return (1 - delta / delta.max().clamp_min(eps)).squeeze()

    def forward(self, pred_s, pred_t, G):
        """KL divergence between teacher and student hyperedge predictions.

        Logits are soft-maxed per vertex, mean-aggregated to hyperedges, and
        only the hyperedges drawn by :meth:`getmask` contribute.

        Returns:
            A scalar tensor; zero when the sampled mask selects no hyperedge.
        """
        pred_s, pred_t = F.softmax(pred_s, dim=1), F.softmax(pred_t, dim=1)
        e_mask = self.getmask()
        pred_s_e = G.v2e(pred_s, aggr="mean")[e_mask]
        pred_t_e = G.v2e(pred_t, aggr="mean")[e_mask]
        if pred_s_e.shape[0] == 0:
            # "batchmean" would divide by zero; return a graph-connected zero.
            return pred_s_e.sum()
        # The target is a probability (not a log-probability), so it must be
        # passed with log_target=False; the result is KL(teacher || student) / tau.
        return F.kl_div(
            torch.log(pred_s_e / self.tau),
            pred_t_e / self.tau,
            reduction="batchmean",
            log_target=False,
        )

    def getmask(self):
        """Sample a boolean hyperedge mask with keep-probability ``delta_e_``."""
        return torch.bernoulli(self.delta_e_).bool()

    def getmask_node(self):
        """Sample a boolean vertex mask with keep-probability ``delta_x_``."""
        return torch.bernoulli(self.delta_x_).bool()

def js_divergence(p, q, eps=1e-8):
    """Jensen-Shannon divergence between two batches of distributions.

    Computes ``0.5 * KL(p || m) + 0.5 * KL(q || m)`` with ``m = (p + q) / 2``,
    summed over the last dimension and averaged over the batch dimension
    (``reduction="batchmean"``).

    Args:
        p: probabilities, one distribution per row.
        q: probabilities with the same shape as ``p``.
        eps: added to the mixture before the logarithm to avoid ``log(0)``.

    Returns:
        A scalar tensor; zero when ``p == q`` and symmetric in its arguments.
    """
    m = 0.5 * (p + q)
    # F.kl_div(input, target) evaluates KL(target || exp(input)), so the
    # mixture goes in as the log-space input and p / q as probability targets.
    log_m = (m + eps).log()
    kl_pm = F.kl_div(log_m, p, reduction="batchmean")
    kl_qm = F.kl_div(log_m, q, reduction="batchmean")
    return 0.5 * (kl_pm + kl_qm)

class HighOrderConstraint_OT(nn.Module):
    """Hyperedge-level Jensen-Shannon alignment between student and teacher.

    Predictions for a subset of vertices are scattered into a zero-filled
    all-vertex tensor, mean-aggregated to hyperedges with ``G.v2e`` and
    compared with :func:`js_divergence`.
    """

    def __init__(self, class_mask_data, selected_class, model, X, G, tau=1.0, alpha=1.0, beta=1.0, noise_level=1.0):
        """
        :param class_mask_data: accepted for interface compatibility; not used here
        :param selected_class: accepted for interface compatibility; not used here
        :param model: teacher model, called as ``model(X, G)``
        :param X: input features; its first dimension is taken as the vertex count
        :param G: graph/hypergraph object with v2e (vertex-to-edge aggregation)
        :param tau: softmax temperature applied in ``forward``
        :param alpha: stored on the instance; not applied by ``forward``
        :param beta: weight of the JSD term returned by ``forward``
        :param noise_level: accepted for interface compatibility; not used here
        """
        super().__init__()
        model.eval()
        self.tau = tau
        self.alpha = alpha
        self.beta = beta
        self.num_nodes = X.shape[0]
        with torch.no_grad():
            # Cached teacher output and its hyperedge aggregation. ``forward``
            # does not read them; they are kept as instance attributes.
            self.emb_x = model(X, G)
            self.emb_e = G.v2e(self.emb_x, aggr="mean")

    def forward(self, mask_teacher_outout, pred_s, pred_t, G):
        """Return ``beta * JSD`` between hyperedge-aggregated predictions.

        :param mask_teacher_outout: boolean vertex mask whose selected rows
            correspond, in order, to the rows of ``pred_s`` / ``pred_t``
        :param pred_s: student scores for the selected vertices
        :param pred_t: teacher scores for the selected vertices
        :param G: structure object providing ``v2e``
        """
        pred_s = F.softmax(pred_s / self.tau, dim=-1)
        pred_t = F.softmax(pred_t / self.tau, dim=-1)

        # Unselected vertices contribute zero rows to the hyperedge mean.
        # new_zeros keeps the buffers on the predictions' device and dtype.
        full_pred_s = pred_s.new_zeros(self.num_nodes, pred_s.size(-1))
        full_pred_t = pred_t.new_zeros(self.num_nodes, pred_t.size(-1))
        full_pred_s[mask_teacher_outout] = pred_s
        full_pred_t[mask_teacher_outout] = pred_t

        pred_s_e = G.v2e(full_pred_s, aggr="mean")
        pred_t_e = G.v2e(full_pred_t, aggr="mean")
        return self.beta * js_divergence(pred_s_e, pred_t_e)

def train_stu(teachersubtask,class_mask_data, selected_class,net, X, G, lbls, out_t, train_mask, optimizer, hc=None, lamb=0):
    """Run one distillation step for a structure-free student ``net(X)``.

    The loss is ``lamb * supervised + (1 - lamb) * distillation`` where

    * the supervised term is the NLL on ``train_mask`` rows, and
    * the distillation term is KL(teacher || student) over the teacher's
      ``selected_class`` columns, plus ``hc(outs, out_t, G)`` when a
      high-order constraint is supplied.

    Distillation rows are ``class_mask_data`` when ``teachersubtask`` is
    truthy; otherwise they are the vertices whose teacher argmax falls in
    ``selected_class``.

    Returns:
        The combined loss of this step as a Python float.
    """
    net.train()
    optimizer.zero_grad()
    outs = net(X)

    if teachersubtask:
        distill_mask = class_mask_data
    else:
        distill_mask = torch.isin(out_t.argmax(dim=1), selected_class)

    loss_x = F.nll_loss(F.log_softmax(outs[train_mask], dim=1), lbls[train_mask])

    student_log_probs = F.log_softmax(outs[distill_mask], dim=1)
    if student_log_probs.shape[0] == 0:
        # No vertex to distil on: "batchmean" would divide by zero.
        loss_k = student_log_probs.sum()
    else:
        teacher_probs = F.softmax(out_t[distill_mask][:, selected_class], dim=1)
        # The teacher target is in probability space, hence log_target=False.
        loss_k = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean", log_target=False)

    if hc is not None:
        loss_k = hc(outs, out_t, G) + loss_k

    loss = loss_x * lamb + loss_k * (1 - lamb)
    loss.backward()
    optimizer.step()
    return loss.item()


def train_stu_OT(teachersubtask,class_mask_data, selected_class, net, X, G, lbls, out_t, train_mask, optimizer, hc=None, lamb=0,alpha=1, beta=1, gama=1):
    """Run one student step using ``sinkhorn``-transformed predictions.

    Student and teacher outputs on the ``class_mask_data`` vertices are
    passed through ``sinkhorn`` and combined as

        ``alpha * contrastive + beta * hc(...) + gama * KL``

    where the ``hc`` term is dropped when no constraint is supplied.

    ``teachersubtask``, ``lbls``, ``train_mask`` and ``lamb`` are accepted for
    signature parity with ``train_stu`` but do not influence the loss.

    Returns:
        The combined loss of this step as a Python float.
    """
    net.train()
    optimizer.zero_grad()
    outs = net(X,get_emb=False)
    num_class_student = len(outs[0])
    label = torch.arange(num_class_student)

    # NOTE(review): the previous code derived a mask from the teacher's
    # argmax and then unconditionally replaced it with class_mask_data; only
    # the effective behaviour is kept. Confirm the override is intended.
    mask_teacher_outout = class_mask_data

    p_s1 = sinkhorn(torch.softmax(outs[mask_teacher_outout], dim=1), label, num_class_student, selected_class, outs[mask_teacher_outout])
    p_t1 = sinkhorn(torch.softmax(out_t[mask_teacher_outout], dim=1), label, num_class_student, selected_class, out_t[mask_teacher_outout])

    loss_ot = F.kl_div(F.log_softmax(p_s1, dim=1), F.softmax(p_t1, dim=1), reduction="batchmean")
    con_loss2 = ContrastiveKD()(p_s1, p_t1, G, False)

    # Without a constraint the high-order term is simply absent; previously
    # this path referenced an unbound local.
    loss_h_ot = hc(mask_teacher_outout, p_s1, p_t1, G) if hc is not None else 0.0

    loss = alpha * con_loss2 + beta * loss_h_ot + gama * loss_ot
    loss.backward()
    optimizer.step()
    return loss.item()


class ContrastiveKD(nn.Module):
    """Row-wise contrastive loss between student and teacher outputs.

    Rows are L2-normalised, a student-by-teacher cosine-similarity matrix is
    formed and divided by ``temperature``, and cross-entropy is applied with
    the matching row index as the target class.
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, logits_s, logits_t,G, onedge=False):
        """Return the contrastive loss; ``G`` and ``onedge`` are not used."""
        student_unit = F.normalize(logits_s, dim=1)
        teacher_unit = F.normalize(logits_t, dim=1)
        scaled_similarity = torch.matmul(student_unit, teacher_unit.T) / self.temperature
        # Row i of the student is the positive for row i of the teacher.
        targets = torch.arange(student_unit.size(0), device=logits_s.device)
        return F.cross_entropy(scaled_similarity, targets)


@torch.no_grad()
def valid_stu(net, X, lbls, mask, evaluator):
    """Score the structure-free student ``net(X)`` on ``mask`` rows via ``evaluator.validate``."""
    net.eval()
    predictions = net(X)
    return evaluator.validate(lbls[mask], predictions[mask])


@torch.no_grad()
def test_stu(net, X, lbls, mask, evaluator, ft_noise_level=0):
    """Evaluate the structure-free student ``net(X)`` via ``evaluator.test``.

    When ``ft_noise_level`` is positive the features are blended with
    standard-normal noise as ``(1 - level) * X + level * noise`` before the
    forward pass.
    """
    net.eval()
    features = X
    if ft_noise_level > 0:
        noise = torch.randn_like(X)
        features = (1 - ft_noise_level) * X + ft_noise_level * noise
    predictions = net(features)
    return evaluator.test(lbls[mask], predictions[mask])


from sklearn.metrics import precision_score, recall_score

class CustomEvaluator(Evaluatorbase):
    """Vertex-classification evaluator extended with precision and recall.

    ``"precision"`` and ``"recall"`` entries are withheld from the base
    evaluator and computed in :meth:`test` with scikit-learn; every other
    metric is forwarded to the base class unchanged.
    """

    def __init__(self, metrics=None):
        if metrics is None:
            metrics = [
                "accuracy",
                "f1_score",
                {"f1_score": {"average": "micro"}},
                "confusion_matrix",
                "precision",
                "recall"
            ]

        def handled_here(metric):
            return metric in ["precision", "recall"]

        super().__init__([metric for metric in metrics if not handled_here(metric)])
        self.extra_metrics = [metric for metric in metrics if handled_here(metric)]

    def validate(self, y_true, y_pred):
        """Delegate validation to the base evaluator unchanged."""
        return super().validate(y_true, y_pred)

    def test(self, y_true, y_pred):
        """Return the base test results plus macro/micro precision and recall.

        Tensor labels are converted to NumPy arrays, and tensor predictions are
        reduced with ``argmax(dim=1)`` first, before scikit-learn scoring.
        """
        results = super().test(y_true, y_pred)

        if torch.is_tensor(y_true):
            y_true = y_true.cpu().numpy()
        if torch.is_tensor(y_pred):
            y_pred = y_pred.argmax(dim=1).cpu().numpy()

        extra_scores = (
            ("precision", precision_score, "precision_macro", "precision_micro"),
            ("recall", recall_score, "recall_macro", "recall_micro"),
        )
        for metric, scorer, macro_key, micro_key in extra_scores:
            if metric not in self.extra_metrics:
                continue
            results[macro_key] = scorer(y_true, y_pred, average="macro", zero_division=0)
            results[micro_key] = scorer(y_true, y_pred, average="micro", zero_division=0)

        return results
def train_stu_OT_reject_aware(teachersubtask,class_mask_data, selected_class, net, X, G, lbls, out_t, train_mask, optimizer, hc=None, lamb=0):
    """Run one student step where the student's last class acts as "reject".

    The teacher logits are folded to the student's class count: the first
    ``C - 1`` columns are copied and all remaining teacher columns are merged
    into the last column with ``logsumexp`` (``C`` = number of student
    classes). Both prediction sets go through ``pot`` and the loss is

        ``contrastive + hc(...) + KL``

    where the ``hc`` term is dropped when no constraint is supplied.

    ``teachersubtask``, ``selected_class``, ``lbls``, ``train_mask`` and
    ``lamb`` are accepted for signature parity but do not influence the loss.

    Returns:
        The combined loss of this step as a Python float.
    """
    net.train()
    optimizer.zero_grad()
    outs = net(X,get_emb=False)
    num_class_student = len(outs[0])
    reject_idx = num_class_student - 1

    # Number of vertices the teacher assigns to a non-reject class.
    m = (out_t.argmax(dim=1) < reject_idx).sum().item()

    out_t_open = torch.zeros(out_t.size(0), num_class_student, device=out_t.device)
    out_t_open[:, :reject_idx] = out_t[:, :reject_idx]  # classes shared with the student
    out_t_open[:, reject_idx] = torch.logsumexp(out_t[:, reject_idx:], dim=1)  # remaining classes -> reject

    label = torch.arange(num_class_student)
    p_s1 = pot(torch.softmax(outs, dim=1), label, m, num_iter=3, num_cls=num_class_student)
    p_t1 = pot(torch.softmax(out_t_open, dim=1), label, m, num_iter=8, num_cls=num_class_student)

    loss_ot = F.kl_div(F.log_softmax(p_s1, dim=1), F.softmax(p_t1, dim=1), reduction="batchmean")
    con_loss2 = ContrastiveKD()(p_s1, p_t1, G, False)

    # Without a constraint the high-order term is simply absent; previously
    # this path referenced an unbound local.
    loss_h_ot = hc(class_mask_data, p_s1, p_t1, G) if hc is not None else 0.0

    loss = con_loss2 + loss_h_ot + loss_ot
    loss.backward()
    optimizer.step()
    return loss.item()

def _exp_build_structure(cfg, data, edge_list):
    """Build the teacher's structure: a graph for GCN/GAT, a hypergraph otherwise."""
    graph_datasets = ['cora', 'pubmed', 'citeseer']
    if cfg.model.teacher in ['gcn', 'gat']:
        if cfg.data.name in graph_datasets:
            G = Graph(data["num_vertices"], edge_list)
        else:
            g = Hypergraph(data["num_vertices"], edge_list)
            G = Graph.from_hypergraph_clique(g)
        G.add_extra_selfloop()
        return G

    if cfg.data.name in graph_datasets:
        print("converting...")
        g = Graph(data["num_vertices"], edge_list)
        G = Hypergraph.from_graph(g)
        G.add_hyperedges_from_graph_kHop(g, 1)
        print("converted!")
    elif cfg.data.name in ["ogbn-products"]:
        G = Hypergraph(data["num_vertices"], edge_list, is_sparse=True)
    else:
        G = Hypergraph(data["num_vertices"], edge_list, )
    return fix_iso_v(G)


def _exp_build_teacher(cfg, in_dim, num_classes):
    """Instantiate the teacher network named by ``cfg.model.teacher``."""
    name = cfg.model.teacher
    if name == "hgnn":
        return MyHGNN(in_dim, cfg.model.t_hid, num_classes, use_bn=False)
    if name == "hgnnp":
        return HGNNP(in_dim, cfg.model.t_hid, num_classes, use_bn=False)
    if name == "hnhn":
        return HNHN(in_dim, cfg.model.t_hid, num_classes, use_bn=False)
    if name == "unigcn":
        return UniGCN(in_dim, cfg.model.t_hid, num_classes, use_bn=False)
    if name == "unigat":
        return UniGAT(in_dim, 8, num_classes, 4, use_bn=False)
    if name == "gcn":
        return MyGCN(in_dim, 32, num_classes, use_bn=False)
    if name == "gat":
        return GAT(in_dim, 8, num_classes, num_heads=4, use_bn=False)
    raise NotImplementedError(f"unknown teacher model: {name!r}")


def exp(seed, cfg: DictConfig):
    """Run one seeded teacher -> student distillation experiment.

    Steps: load the dataset, build the structure, split the vertices, train
    the teacher for a fixed 200 epochs (keeping the best-validation weights),
    then train an ``MyMLPs`` student for ``cfg.model.epoch`` epochs with the
    loss selected by ``cfg.model.student``.

    When ``cfg.data.task_numberClass > 0`` the student is restricted to
    classes ``0 .. task_numberClass - 1``; with ``cfg.data.teachersubtask == 1``
    the teacher is restricted to the same classes.

    Returns:
        ``{"t": teacher_test_result, "s": student_test_result}``.
    """
    set_seed(seed)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    evaluator = CustomEvaluator()

    data, edge_list = load_data(cfg.data.name)
    G = _exp_build_structure(cfg, data, edge_list)
    train_mask, val_mask, test_mask = split_by_num(
        data["num_vertices"], data["labels"], cfg.data.num_train, cfg.data.num_val
    )
    X, lbl = data["features"], data["labels"]
    num_classes_student = data["num_classes"]
    num_classes_teacher = data["num_classes"]
    selected_class = torch.unique(lbl)
    class_mask_data = torch.isin(lbl, selected_class)
    train_mask_sel, val_mask_sel, test_mask_sel = train_mask, val_mask, test_mask
    if cfg.data.task_numberClass > 0:
        # Created on the labels' current device: ``lbl`` has not been moved to
        # ``device`` yet and torch.isin needs both operands on one device.
        selected_class = torch.arange(cfg.data.task_numberClass, device=lbl.device, dtype=torch.long)
        num_classes_student = selected_class.shape[0]
        class_mask_data = torch.isin(lbl, selected_class)
        train_mask_sel = train_mask & class_mask_data
        val_mask_sel = val_mask & class_mask_data
        test_mask_sel = test_mask & class_mask_data
        if cfg.data.teachersubtask == 1:
            num_classes_teacher = num_classes_student
            train_mask = train_mask & class_mask_data
            val_mask = val_mask & class_mask_data
            test_mask = test_mask & class_mask_data
    print("Start...")
    net = _exp_build_teacher(cfg, X.shape[1], num_classes_teacher)

    # train teacher
    optimizer = optim.Adam(net.parameters(), lr=0.01, weight_decay=5e-4)
    X, lbl, G = X.to(device), lbl.to(device), G.to(device)
    # The student steps compare these against tensors that live on ``device``.
    selected_class = selected_class.to(device)
    class_mask_data = class_mask_data.to(device)
    net = net.to(device)

    best_state = None
    best_epoch, best_val = 0, 0
    totalepoch = cfg.model.epoch
    teacher_epochs = 200  # the teacher budget is fixed, independent of cfg.model.epoch
    for epoch in range(teacher_epochs):
        train(net, X, G, lbl, train_mask, optimizer)
        val_res = valid(net, X, G, lbl, val_mask, evaluator)
        if val_res > best_val:
            best_epoch = epoch
            best_val = val_res
            best_state = deepcopy(net.state_dict())
    # If validation never improved on the initial 0, keep the final weights.
    if best_state is not None:
        net.load_state_dict(best_state)
    res_t = test(net, X, G, lbl, test_mask, evaluator, cfg.data.ft_noise_level)
    logging.info(f"teacher test best epoch: {best_epoch}, res: {res_t}")

    # train student
    # NOTE(review): ``get_emb`` is passed here while every other teacher call
    # uses ``net(X, G)``; confirm that all supported teachers accept it.
    out_t = net(X, G,get_emb=False).detach()
    if cfg.model.student == "light_hgnnp" and cfg.data.teachersubtask==0:
        hc = HighOrderConstraint(net, X, G, noise_level=cfg.data.hc_noise_level, tau=cfg.loss.tau)
    else:
        hc = None

    net_s = MyMLPs(X.shape[1], cfg.model.hid, num_classes_student)
    optimizer = optim.Adam(net_s.parameters(), lr=0.01, weight_decay=5e-4)
    net_s = net_s.to(device)

    use_ot = cfg.model.student == "HGSelKD"
    if use_ot:
        # The teacher is frozen, so the constraint is built once rather than
        # once per student epoch.
        hc = HighOrderConstraint_OT(class_mask_data,selected_class, net, X, G, noise_level=cfg.data.hc_noise_level, tau=cfg.loss.tau)

    best_state = None
    best_epoch, best_val = 0, 0
    for epoch in range(totalepoch):
        if use_ot:
            train_stu_OT(cfg.data.teachersubtask,class_mask_data, selected_class, net_s, X, G, lbl, out_t, train_mask_sel, optimizer, hc=hc, lamb=cfg.loss.lamb, alpha=cfg.loss.alpha,beta=cfg.loss.beta,gama=cfg.loss.gama)
        else:
            train_stu(cfg.data.teachersubtask,class_mask_data, selected_class, net_s, X, G, lbl, out_t, train_mask_sel, optimizer, hc=hc, lamb=cfg.loss.lamb)
        val_res = valid_stu(net_s, X, lbl, val_mask_sel, evaluator)
        if val_res > best_val:
            best_epoch = epoch
            best_val = val_res
            best_state = deepcopy(net_s.state_dict())
    if best_state is not None:
        net_s.load_state_dict(best_state)
    res_s = test_stu(net_s, X, lbl, test_mask_sel, evaluator, cfg.data.ft_noise_level)

    logging.info(f"student test best epoch: {best_epoch}, res: {res_s}\n")
    return {"t": res_t, "s": res_s}


 
@hydra.main(config_path=".", config_name="trans_config", version_base="1.1")
def main(cfg: DictConfig):
    """Run ``cfg.data.num_runs`` seeded experiments and log the aggregate.

    Seeds are ``0 .. num_runs - 1``. The resolved configuration is logged
    both before the runs and again next to the final summary.
    """
    logging.info(OmegaConf.to_yaml(cfg))

    aggregated = MultiExpMetric()
    for run_seed in range(cfg.data.num_runs):
        aggregated.update(exp(run_seed, cfg))

    logging.info(OmegaConf.to_yaml(cfg))
    logging.info(aggregated)


# Script entry point: ``main`` is called without arguments; it is wrapped by
# ``hydra.main``, which is configured to load ``trans_config`` from this directory.
if __name__ == "__main__":
    main()
