import torch
from sklearn.metrics import mean_absolute_error
import torch.nn as nn

from networks import HGNN_classifier, GCN, MLP, GAT
import torch.nn.functional as F
import random
import numpy as np
import time
import datetime
import os.path as osp
import torch.cuda

def format_time(elapsed):
    """Format a duration given in seconds as ``[D day(s), ]H:MM:SS``.

    The value is rounded to the nearest whole second with Python's ``round``
    (ties go to the even second) before formatting.
    """
    seconds = int(round(elapsed))
    delta = datetime.timedelta(seconds=seconds)
    return str(delta)


def contrast_loss(H_raw, mask, labels):
    """Label-consistency loss over learned incidence representations.

    Intent: nodes that share a label should belong to the same hyperedges.
    For every tensor ``h`` in ``H_raw``, the rows ``h[mask]`` are grouped by
    ``labels``. Each group is compared (MSE) against a random permutation of
    itself, which pulls rows with the same label towards each other.

    Args:
        H_raw: sequence of tensors indexed row-wise; ``h[mask]`` must have one
            row per entry of ``labels``.
        mask: index or boolean mask selecting the labelled rows.
        labels: integer class labels aligned with ``h[mask]``.

    Returns:
        Scalar tensor on ``H_raw[0].device``. It is 0 when no class
        contributes, e.g. when ``labels`` is empty.
    """
    # Start from a tensor on the right device rather than a Python float.
    total_loss = torch.tensor(0.0, device=H_raw[0].device)
    if labels.numel() == 0:
        # labels.max() would raise on an empty tensor.
        return total_loss

    lbl_num = labels.max().item() + 1
    # Per-class masks do not depend on h, so build them once.
    class_masks = [labels == lbl for lbl in range(lbl_num)]

    for h in H_raw:
        h_masked = h[mask]  # hoisted: previously recomputed for every class
        for lbl, lbl_mask in enumerate(class_masks):
            src = h_masked[lbl_mask]
            if src.shape[0] == 0:
                continue  # skip classes absent from this split
            target_idx = list(range(src.shape[0]))
            random.shuffle(target_idx)
            loss = F.mse_loss(src, src[target_idx])
            if torch.isnan(loss):
                print(f"[Warning] NaN loss detected at label {lbl}")
                continue  # do not let a NaN poison the total
            total_loss = total_loss + loss
    return total_loss



def contrast_loss2(H, x):
    """Hyperedge feature-consistency loss: nodes on the same edge should look alike.

    Each matrix ``h`` in ``H`` is binarised as ``h.ceil().abs() == 1``. Every
    column is then treated as a hyperedge whose members are the rows equal
    to 1. The member features ``x[members]`` are compared (MSE) against a
    random permutation of themselves, and positive losses are summed.

    Edges with fewer than two members are skipped explicitly. The previous
    version relied on ``mse_loss`` returning NaN for empty input, and it
    added a spurious ``1e-8`` to every contributing edge.

    Args:
        H: iterable of 2-D tensors whose rows align with the rows of ``x``
            (presumably square n x n learned incidence matrices; TODO confirm).
        x: node feature tensor.

    Returns:
        The int ``0`` when no edge contributes, otherwise a scalar tensor.
    """
    total_loss = 0
    for h in H:
        membership = h.ceil().abs() == 1
        for edge in range(membership.shape[1]):
            src = x[membership[:, edge]]  # features of nodes on this edge
            n_members = src.shape[0]
            if n_members < 2:
                # Empty or singleton edge: nothing to compare. Shuffling a
                # list this short consumes no randomness, so skipping it
                # leaves the RNG stream unchanged.
                continue
            perm = list(range(n_members))
            random.shuffle(perm)
            loss = F.mse_loss(src, src[perm])
            # Zero (identity shuffle) and NaN losses are not accumulated.
            if loss > 0:
                total_loss = total_loss + loss
    return total_loss





def edge_shuffle_contrast_loss(e: torch.Tensor, mode='mse') -> torch.Tensor:
    """Shuffle-based contrastive loss over hyperedge embeddings.

    The rows of ``e`` are compared with a random permutation of themselves:

    * ``'mse'``: mean squared error between ``e`` and the permuted rows.
    * ``'cosine'``: mean of ``1 - cos(e_i, e_perm(i))``.

    Args:
        e: hyperedge feature tensor, one row per hyperedge.
        mode: ``'mse'`` or ``'cosine'``.

    Returns:
        Scalar tensor. It is 0 when there are fewer than two rows.

    Raises:
        ValueError: for any other ``mode``, checked only after the shuffle
            and only when ``e`` has at least two rows.
    """
    num_edges = e.size(0)
    if num_edges < 2:
        # Nothing to contrast against.
        return torch.tensor(0.0, device=e.device)

    order = list(range(num_edges))
    random.shuffle(order)
    permuted = e[order]

    if mode == 'cosine':
        cos = (F.normalize(e, dim=1) * F.normalize(permuted, dim=1)).sum(dim=1)
        return (1 - cos).mean()
    if mode == 'mse':
        return F.mse_loss(e, permuted)
    raise ValueError("Unsupported mode. Use 'mse' or 'cosine'.")


def laplacian_rank(H_list, device):
    """Print and return the rank of the normalised hypergraph Laplacian of each matrix.

    For every incidence matrix ``H`` (nodes x hyperedges), empty hyperedges
    are dropped first. The function then computes

        L = I - Dv^(-1/2) H W De^(-1) H^T Dv^(-1/2)

    with unit hyperedge weights ``W`` and reports ``torch.linalg.matrix_rank(L)``.

    Args:
        H_list: iterable of 2-D floating-point incidence tensors.
        device: device on which ``W`` and ``I`` are allocated; it must match
            the device of the tensors in ``H_list``.

    Returns:
        List of 0-d rank tensors, one per input matrix. The original version
        returned ``None``, so callers that ignore the result are unaffected.
    """
    rank_list = []
    for tmpH in H_list:
        # Drop empty hyperedges: De = 0 would make De^(-1) infinite.
        # Boolean indexing returns a copy, so the caller's tensor is untouched.
        H = tmpH[:, tmpH.sum(dim=0) != 0]
        n_node, n_edge = H.shape

        # Hyperedge weights must enter the product as a diagonal matrix. A
        # 1-D vector collapses the chain and breaks whenever n_node != n_edge.
        W = torch.ones(n_edge, device=device)
        DV = torch.sum(H * W, dim=1)  # node degrees
        DE = torch.sum(H, dim=0)      # hyperedge degrees

        invDE = torch.diag(torch.pow(DE, -1))
        dv_inv_sqrt = torch.pow(DV, -0.5)
        # Isolated nodes (DV == 0) give inf here; treat their term as 0
        # instead of propagating NaN into L.
        dv_inv_sqrt = torch.where(torch.isinf(dv_inv_sqrt),
                                  torch.zeros_like(dv_inv_sqrt), dv_inv_sqrt)
        DV2 = torch.diag(dv_inv_sqrt)

        eye = torch.eye(n_node, n_node, device=device)
        L = eye - DV2 @ H @ torch.diag(W) @ invDE @ H.T @ DV2

        rank_list.append(torch.linalg.matrix_rank(L))

    print("===========================> Rank of L is: ", rank_list)
    return rank_list

import time
import torch.cuda as cuda
def train(model, optimizer, data, args):
    """Full-batch node-classification training loop with early stopping.

    Each epoch runs one optimisation step on ``data['train_idx']``. The loss
    is NLL plus, when the model returns ``H_raw``, three weighted auxiliary
    terms:

    * ``contrast_loss`` scaled by ``args.namuda``;
    * ``contrast_loss2`` scaled by ``args.namuda2 / 1000``;
    * ``edge_shuffle_contrast_loss`` with weight 1.

    After the step the loop prints epoch time and peak CUDA memory, then
    evaluates:

    * For ``args.dataset == '20news'``, test accuracy drives model selection
      directly.
    * Otherwise a validation improvement triggers a test evaluation.

    The best ``model.state_dict()`` and ``H`` are written to ``model.pth`` and
    ``H.pt``. Training stops once ``patience`` exceeds ``args.patience``.

    Args:
        model: called as ``model(data, args)``; must return
            ``(out, x, H, H_raw, edges)``.
        optimizer: optimizer over ``model``'s parameters.
        data: dict with ``'lbls'``, ``'train_idx'`` and the ``'<stage>_idx'``
            entries read by ``evaluate``.
        args: namespace. Reads ``device``, ``epoch``, ``dataset``,
            ``patience``, ``namuda`` and ``namuda2``; writes ``stage`` and
            ``num_edges``.

    Returns:
        The best test accuracy observed.
    """
    device = torch.device(args.device)
    model.to(device)
    # torch.cuda memory-stat calls reject non-CUDA devices.
    use_cuda = device.type == 'cuda'
    train_mask = data['train_idx']
    labels = data['lbls'][train_mask]
    best_val_acc = 0
    best_test_acc = 0
    patience = 0
    best_epoch = 0
    # Pre-bound so the per-epoch log line cannot hit an unbound name before
    # validation accuracy has improved at least once.
    test_acc, test_loss = 0.0, float('nan')

    for epoch in range(args.epoch):
        start_time = time.time()
        if use_cuda:
            cuda.reset_peak_memory_stats(device)

        model.train()
        optimizer.zero_grad()

        args.stage = 'train'
        out, x, H, H_raw, edges = model(data, args)

        log_probs = F.log_softmax(out, dim=1)
        # Auxiliary terms default to 0 so models without a learned hypergraph
        # (H_raw is None) still get a valid loss. loss_edge_shuffle used to be
        # unbound in that case.
        contra_ls = 0
        contra_ls2 = 0
        loss_edge_shuffle = 0
        if H_raw is not None:
            contra_ls = contrast_loss(H_raw, train_mask, labels)
            contra_ls2 = contrast_loss2(H, x)
            loss_edge_shuffle = edge_shuffle_contrast_loss(edges, mode='mse')  # or 'cosine'

        loss = (F.nll_loss(log_probs[train_mask], labels)
                + contra_ls * args.namuda
                + contra_ls2 * (args.namuda2 / 1000)
                + loss_edge_shuffle)

        loss.backward()
        optimizer.step()
        if use_cuda:
            # CUDA kernels run asynchronously; wait so the timing covers them.
            cuda.synchronize(device)
        epoch_time_ms = (time.time() - start_time) * 1000  # seconds -> ms
        max_mem_gb = cuda.max_memory_allocated(device) / (1024 ** 3) if use_cuda else 0.0  # bytes -> GB
        print(f"Epoch {epoch} time: {epoch_time_ms:.2f} ms, Peak memory usage: {max_mem_gb:.4f} GB")

        _, train_pred = log_probs[train_mask].max(dim=1)
        correct = int(train_pred.eq(labels).sum().item())
        acc = correct / len(labels)

        if args.dataset in ['20news']:
            # No validation split is used here: select directly on test accuracy.
            args.stage = 'test'
            test_acc, test_loss = evaluate(model, data, args)
            if test_acc > best_test_acc:
                patience = 0
                best_test_acc = test_acc
                best_epoch = epoch
                torch.save(model.state_dict(), 'model.pth')
                torch.save(H, 'H.pt')
                if H is not None:
                    args.num_edges = H[0].shape[1]
            else:
                patience = patience + 1
                print("patience now: ======>", patience)

            if patience > args.patience:
                break

            print("========================> ", epoch)
            print("Train acc: {}, loss: {}".format(acc, loss))
            print("Test acc: {}, loss: {}".format(test_acc, test_loss))

        else:
            args.stage = 'val'
            val_acc, val_loss = evaluate(model, data, args)

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                args.stage = 'test'
                test_acc, test_loss = evaluate(model, data, args)
                if test_acc > best_test_acc:
                    best_epoch = epoch
                    best_test_acc = test_acc
                    torch.save(model.state_dict(), 'model.pth')
                    torch.save(H, 'H.pt')
                patience = 0
                if H is not None:
                    args.num_edges = H[0].shape[1]
            else:
                patience = patience + 1
                print("patience now: ======>", patience)
            if patience > args.patience:
                break
            print("========================> ", epoch)
            print("Train acc: {}, loss: {}".format(acc, loss))
            print("val acc: {}, loss: {}".format(val_acc, val_loss))
            print("Test acc: {}, loss: {}".format(test_acc, test_loss))

    if args.dataset in ['40', 'NTU', '20news']:
        print("Best epch: {}, best test acc: {}".format(best_epoch, best_test_acc))
        return best_test_acc
    else:
        print("Best epch: {}, best tset acc: {}".format(best_epoch, best_test_acc))
        return best_test_acc


def evaluate(model, data, args):
    """Evaluate ``model`` on the split named by ``args.stage``.

    Runs a forward pass without gradient tracking. It computes the same
    objective as training (NLL on the split, plus the contrastive terms when
    the model returns ``H_raw``) together with the split accuracy.

    Args:
        model: called as ``model(data, args)``; must return
            ``(out, x, H, H_raw, edges)``.
        data: dict holding ``'lbls'`` and ``'<stage>_idx'``.
        args: namespace; reads ``stage``, ``namuda`` and ``namuda2``.

    Returns:
        ``(acc, loss)``: accuracy as a float in [0, 1] and the loss tensor,
        which is not attached to an autograd graph.
    """
    stage = args.stage

    model.eval()
    mask = data[stage + '_idx']
    labels = data['lbls'][mask]

    # Evaluation only: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        out, x, H, H_raw, edges = model(data, args)
        log_probs = F.log_softmax(out, dim=1)

        # Default the auxiliary terms to 0. Otherwise models without a learned
        # hypergraph (H_raw is None) hit an unbound loss_edge_shuffle.
        contra_ls = 0
        contra_ls2 = 0
        loss_edge_shuffle = 0
        if H_raw is not None:
            contra_ls = contrast_loss(H_raw, mask, labels)
            contra_ls2 = contrast_loss2(H, x)
            loss_edge_shuffle = edge_shuffle_contrast_loss(edges, mode='mse')  # or 'cosine'
        loss = (F.nll_loss(log_probs[mask], labels)
                + contra_ls * args.namuda
                + contra_ls2 * (args.namuda2 / 1000)
                + loss_edge_shuffle)

    _, pred = log_probs[mask].max(dim=1)
    correct = int(pred.eq(labels).sum().item())
    acc = correct / len(labels)
    return acc, loss


def train_dhl(data, args):
    """Build an ``HGNN_classifier`` from ``args``, train it and return the best test accuracy.

    The optimiser is Adam with ``args.lrate`` and ``args.wdecay``. Training
    and model selection are delegated to ``train``.
    """
    # The unused in_dim/hid_dim/out_dim/num_edges locals were removed; the
    # model reads its dimensions from ``args`` itself.
    model = HGNN_classifier(args)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lrate, weight_decay=args.wdecay)

    best_acc = train(model, optimizer, data, args)

    return best_acc


def train_gcn(data, args):
    """Build a ``GCN`` from ``args``, train it and return the best test accuracy.

    The optimiser is Adam with ``args.lrate`` and ``args.wdecay``. Training
    and model selection are delegated to ``train``.
    """
    # The unused in_dim/hid_dim/out_dim locals were removed; the model reads
    # its dimensions from ``args`` itself.
    model = GCN(args)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lrate, weight_decay=args.wdecay)

    best_acc = train(model, optimizer, data, args)

    return best_acc


def train_gat(data, args):
    """Fit a ``GAT`` built from ``args`` and report its best test accuracy.

    Optimisation uses Adam (``args.lrate``, ``args.wdecay``). The epoch loop
    and early stopping live in ``train``.
    """
    gat = GAT(args)
    adam = torch.optim.Adam(gat.parameters(), lr=args.lrate, weight_decay=args.wdecay)
    return train(gat, adam, data, args)


def train_mlp(data, args):
    """Fit the ``MLP`` baseline built from ``args`` and report its best test accuracy.

    Optimisation uses Adam (``args.lrate``, ``args.wdecay``). The epoch loop
    and early stopping live in ``train``.
    """
    mlp = MLP(args)
    adam = torch.optim.Adam(mlp.parameters(), lr=args.lrate, weight_decay=args.wdecay)
    return train(mlp, adam, data, args)




def train_dhl_regression(data, args):
    """Build an ``HGNN_classifier`` from ``args`` and train it as a regressor.

    The optimiser is Adam with ``args.lrate`` and ``args.wdecay``.

    Returns:
        Whatever ``train_regression`` returns: the tuple
        ``(best_test_mae, best_test_mse)``.
    """
    # The unused in_dim/hid_dim/out_dim/num_edges locals were removed.
    model = HGNN_classifier(args)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lrate, weight_decay=args.wdecay)
    best_mae = train_regression(model, optimizer, data, args)

    return best_mae
def train_regression(model, optimizer, data, args):
    """Train a regression model on an L1 objective with early stopping on validation MAE.

    Each epoch runs one full-batch step on ``data['train_idx']``, using the
    first element of the model's 5-tuple output as predictions. It then
    evaluates the validation split with ``evaluate_regtess``. Whenever
    validation MAE improves, the test split is evaluated and recorded.

    All reported metrics are normalised: MAE is divided by the standard
    deviation of the targets and MSE by their variance. Training stops once
    ``patience`` exceeds ``args.patience``.

    Args:
        model: called as ``model(data, args)``; its first output holds the
            predictions.
        optimizer: optimizer over ``model``'s parameters.
        data: dict with ``'lbls'``, ``'train_idx'``, ``'val_idx'`` and
            ``'test_idx'``.
        args: namespace; reads ``device``, ``epoch`` and ``patience``; writes
            ``stage``.

    Returns:
        ``(best_test_mae, best_test_mse)`` from the epoch with the lowest
        validation MAE. Both are ``inf`` if validation never improved, e.g.
        because the metric was NaN.
    """
    device = torch.device(args.device)
    model.to(device)
    loss_fn_mae = nn.L1Loss()
    loss_fn_mse = nn.MSELoss()

    # The training split and its targets are the same every epoch.
    train_mask = data['train_idx']
    labels = data['lbls'][train_mask].squeeze().float().to(device).reshape(-1, 1)
    label_std = torch.std(labels).item()
    label_var = torch.var(labels).item()

    best_val_mae = float('inf')
    # Pre-bound so the log line and the return value are always defined.
    # Previously best_test_mse was unbound if validation never improved.
    best_test_mae = float('inf')
    best_test_mse = float('inf')
    patience = 0

    for epoch in range(args.epoch):
        model.train()
        optimizer.zero_grad()

        args.stage = 'train'
        preds, _, _, _, _ = model(data, args)
        preds = preds[train_mask].reshape(-1, 1)

        loss_mae = loss_fn_mae(preds, labels)
        loss_mse = loss_fn_mse(preds, labels)

        # Optimise MAE; MSE is only reported.
        loss_mae.backward()
        optimizer.step()

        train_mae = loss_mae.item() / label_std
        train_mse = loss_mse.item() / label_var

        # Validation
        args.stage = 'val'
        val_mae, val_mse = evaluate_regtess(model, data, args, return_mse=True)

        # Early stopping
        if val_mae < best_val_mae:
            best_val_mae = val_mae
            args.stage = 'test'
            best_test_mae, best_test_mse = evaluate_regtess(model, data, args, return_mse=True)
            patience = 0
        else:
            patience += 1
            print("patience now: ======>", patience)
        if patience > args.patience:
            break

        print("========================> Epoch", epoch)
        print("Train MAE (normalized):", train_mae)
        print("Train MSE (normalized):", train_mse)
        print("Val MAE (normalized):", val_mae)
        print("Val MSE (normalized):", val_mse)
        print("Best Test MAE (normalized):", best_test_mae)
        print("Best Test MSE (normalized):", best_test_mse)

    return best_test_mae, best_test_mse



def evaluate_regtess(model, data, args, return_mse=False):
    """Compute normalised regression error on the ``'val'`` or ``'test'`` split.

    The split is chosen by ``args.stage``. The forward pass runs under
    ``torch.no_grad()`` because nothing is back-propagated.

    Args:
        model: called as ``model(data, args)``; its first output holds the
            predictions.
        data: dict with ``'lbls'``, ``'val_idx'`` and ``'test_idx'``.
        args: namespace; reads ``device`` and ``stage``.
        return_mse: if True, also return the normalised MSE.

    Returns:
        ``MAE / std(labels)``, or ``(MAE / std(labels), MSE / var(labels))``
        when ``return_mse`` is True.

    Raises:
        ValueError: if ``args.stage`` is neither ``'val'`` nor ``'test'``.
    """
    device = torch.device(args.device)
    model.eval()

    stage = args.stage
    if stage not in ('val', 'test'):
        raise ValueError("Unsupported stage for evaluation")
    mask = data[stage + '_idx']

    labels = data['lbls'][mask].squeeze().float().to(device)
    with torch.no_grad():
        preds, _, _, _, _ = model(data, args)
    preds = preds[mask].squeeze()

    # Normalise so the metrics are comparable across target scales.
    mae = nn.L1Loss()(preds, labels).item() / torch.std(labels).item()
    if not return_mse:
        return mae
    mse = nn.MSELoss()(preds, labels).item() / torch.var(labels).item()
    return mae, mse

