#-*- coding: utf-8 -*-
import os
import time
import math
import random
import argparse
import numpy as np

# Command-line configuration. Parsed at import time because several
# definitions below (find_neighbours, SubGraphLayer, run) read `args` directly.
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cpu', help='Device')
parser.add_argument('--dataset_name', type=str, default='house1to10')
parser.add_argument('--emb_size', type=int, default=32, help='Embeding Size')
parser.add_argument('--weight_decay', type=float, default=1e-5, help='Weight Decay')
parser.add_argument('--lr', type=float, default=0.002, help='Learning Rate')
parser.add_argument('--seed', type=int, default=222, help='Random seed')
parser.add_argument('--epoch', type=int, default=80000, help='Epoch')
parser.add_argument('--batch_size', type=int, default=64, help='Batch Size')
parser.add_argument('--neighbours', type=int, default=400, help='Max Neighbours')
parser.add_argument('--val_epoch', type=int, default=200, help='Epoch Interval For Validation')
parser.add_argument('--zero_shot', action='store_true', help='Zero Shot Learning')
parser.add_argument('--ckpt', type=str, help='Checkpoint path')
parser.add_argument('--eval', action='store_true', help='Evaluation')
parser.add_argument('--comment', type=str, default='')
args = parser.parse_args()

# Directory holding the <dataset>_{training,validation,testing}.txt edge lists.
data_prefix = 'experiments-data'

# Must run before `import torch` below so the restriction takes effect.
# '--device cuda:N' (or 'cuda:N,M') exposes only the listed GPU(s). A bare
# 'cuda' leaves visibility untouched; taking only the last character would
# turn it into 'a' (hiding every GPU) and truncate 'cuda:10' to '0'.
if 'cuda' in args.device:
    gpu_index = args.device.partition(':')[2]
    if gpu_index:
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_index

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import Edge, sub_edge, GELU
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import f1_score, roc_auc_score

# from torch.utils.tensorboard import SummaryWriter

def setup_seed(seed):
    """Seed every RNG the script relies on and request deterministic cuDNN.

    Args:
        seed: integer applied to Python's ``random``, NumPy's global RNG and
            torch (CPU generator plus all CUDA devices).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

from common import DATA_EMB_DIC

def find_neighbours(edges, node_in, node_out, left=1, signal=0, max_neighbours=None):
    """Sample up to ``max_neighbours`` one-hop neighbours of the seed nodes ``node_in``.

    Args:
        edges: graph object exposing adjacency lists ``edge_abp``/``edge_abn``
            and ``edge_bap``/``edge_ban`` (indexed by node id) and degree
            arrays ``dega``/``degb`` -- presumably a ``utils.Edge``; confirm
            against that class.
        node_in: seed node ids. Expanded through ``edge_ab*`` when ``left`` is
            truthy, otherwise through ``edge_ba*``. Not modified.
        node_out: node ids excluded from the result.
        left: selects the adjacency/degree tables as described above.
        signal: ordering of each seed's neighbours before truncation:
            0 = random, 1 = lowest degree first, anything else = highest
            degree first.
        max_neighbours: total neighbour budget; defaults to ``args.neighbours``.

    Returns:
        List of at most ``max_neighbours`` distinct neighbour ids, in set order.
        Each seed contributes at most ``max(max_neighbours // len(node_in), 4)``.
    """
    if max_neighbours is None:
        max_neighbours = args.neighbours
    if len(node_in) == 0:
        return []
    # Per-seed quota; floor of 4 so large seed sets still sample something.
    ave = max(max_neighbours // len(node_in), 4)
    seeds = list(node_in)  # shuffle a copy; never reorder the caller's list
    random.shuffle(seeds)
    excluded = set(node_out)
    neighbour = set()
    for node in seeds:
        if left:
            st = edges.edge_abp[node] + edges.edge_abn[node]
        else:
            st = edges.edge_bap[node] + edges.edge_ban[node]
        if not len(st):
            # Isolated seed: skip it rather than discarding neighbours that
            # earlier seeds already contributed.
            continue
        samp = np.array(st)
        deg = edges.degb[samp] if left else edges.dega[samp]
        if signal == 0:
            # Shuffled degrees decouple the sort order from samp -> random pick.
            deg = np.random.permutation(deg)
        elif signal == 1:
            deg = -deg
        deg_idx = np.argsort(deg)[::-1][:ave]
        neighbour.update(set(samp[deg_idx].tolist()) - excluded)
        if len(neighbour) >= max_neighbours:
            break
    return list(neighbour)[:max_neighbours]

def find_neighbours_new(edges, node_in, node_out, left=1, max_neighbours=None):
    """Randomly sample neighbours of ``node_in``, drawing negatives and positives separately.

    For every seed, up to ``max_neighbours // 2`` ids are drawn at random from
    the negative list (``edge_abn``/``edge_ban``). Then up to
    ``max_neighbours // 2 * 2`` ids are drawn from the positive list
    (``edge_abp``/``edge_bap``). Sampling stops once the budget is reached.

    Args:
        edges: graph object with the adjacency lists above and degree arrays
            ``dega``/``degb`` (presumably a ``utils.Edge`` -- confirm).
        node_in: seed node ids (not modified).
        node_out: node ids excluded from the result.
        left: truthy to use the ``ab`` tables and ``degb``, else ``ba``/``dega``.
        max_neighbours: total budget; defaults to ``args.neighbours``.

    Returns:
        List of at most ``max_neighbours`` distinct neighbour ids, in set order.
    """
    if max_neighbours is None:
        max_neighbours = args.neighbours
    if len(node_in) == 0:
        return []
    seeds = list(node_in)  # shuffle a copy; never reorder the caller's list
    random.shuffle(seeds)
    excluded = set(node_out)
    neighbour = set()
    for node in seeds:
        if left:
            positive, negative = edges.edge_abp[node], edges.edge_abn[node]
        else:
            positive, negative = edges.edge_bap[node], edges.edge_ban[node]
        # Negatives first with half the budget, then positives with the full one.
        for rank, st in enumerate((negative, positive)):
            if not len(st):
                # Empty list (or isolated seed): skip without dropping what
                # was already collected.
                continue
            samp = np.array(st)
            deg = edges.degb[samp] if left else edges.dega[samp]
            # Shuffled degrees decouple the sort order from samp -> random pick.
            deg = np.random.permutation(deg)
            deg_idx = np.argsort(deg)[::-1][:max_neighbours // 2 * (rank + 1)]
            neighbour.update(set(samp[deg_idx].tolist()) - excluded)
        if len(neighbour) >= max_neighbours:
            break
    return list(neighbour)[:max_neighbours]

class GraphDataset(Dataset):
    """Map-style dataset yielding one labelled edge plus its sampled subgraph.

    Each item holds the edge endpoints, neighbours sampled around each
    endpoint with ``find_neighbours``, three ``sub_edge`` blocks linking
    them, and the label ``edge_s`` (1 iff the edge's third field equals 1).
    """

    def __init__(self, val, test=False, edge=None):
        """
        Args:
            val: edge list wrapped in ``Edge``; its edges are the dataset items.
            test: kept for API compatibility; train and eval items are
                currently sampled identically.
            edge: graph used for neighbour lookup and ``sub_edge``. Defaults
                to the module-level ``edges`` (assigned in ``run``).
        """
        self.val = Edge(val)
        print(len(self.val.edge_set))
        self.test = test
        self.edges = edges if edge is None else edge

    def __len__(self):
        return len(self.val.edge_list)

    def __getitem__(self, index):
        edge = self.val.edge_list[index]
        left, right = [edge[0]], [edge[1]]
        edge_s = int(edge[2] == 1)
        # Both sides use random neighbour ordering (signal=0); degree-ordered
        # sampling (signal 1/2) is currently disabled for train and eval alike.
        left_n = find_neighbours(self.edges, left, right, left=1, signal=0)
        right_n = find_neighbours(self.edges, right, left, left=0, signal=0)
        sub_0 = sub_edge(left, left_n, self.edges)
        sub_1 = sub_edge(right_n, left_n, self.edges)
        sub_2 = sub_edge(right_n, right, self.edges)
        return {
            'left': left,
            'right': right,
            'left_n': left_n,
            'right_n': right_n,
            'sub_0': sub_0,
            'sub_1': sub_1,
            'sub_2': sub_2,
            'edge_s': edge_s,
            'left_ns': len(left_n),
            'right_ns': len(right_n)
        }

def pad_tensor(batch):
    """Zero-pad equal-rank tensors to a common shape and stack them.

    The result has shape ``(len(batch), *max_shape)``, where ``max_shape`` is
    the per-dimension maximum over ``batch``; its dtype is ``batch[0].dtype``.
    Each tensor is written into the leading corner of its slot.
    """
    rank = batch[0].dim()
    target = [max(t.size(axis) for t in batch) for axis in range(rank)]
    padded = torch.zeros([len(batch), *target], dtype=batch[0].dtype)
    for row, t in enumerate(batch):
        region = (row,) + tuple(slice(None, extent) for extent in t.size())
        padded[region] = t
    return padded

def collate_fn(batch):
    """Merge a list of sample dicts into one dict of batched tensors.

    List values are converted with ``torch.tensor`` first. Per key, tensors
    of identical shape are stacked, otherwise zero-padded via ``pad_tensor``.
    Non-tensor values (e.g. ints) are gathered into a single ``torch.tensor``.
    Keys follow the order of ``batch[0]``.
    """
    columns = {key: [] for key in batch[0]}
    for sample in batch:
        for key, value in sample.items():
            columns[key].append(torch.tensor(value) if isinstance(value, list) else value)

    merged = {}
    for key, values in columns.items():
        head = values[0]
        if not isinstance(head, torch.Tensor):
            merged[key] = torch.tensor(values)
        elif all(v.shape == head.shape for v in values):
            merged[key] = torch.stack(values)
        else:
            merged[key] = pad_tensor(values)
    return merged

class Attention(nn.Module):
    """Edge-masked multi-head attention that aggregates ``prev`` features onto ``curr`` nodes.

    A small score vector is built for every (prev node, curr node) pair from
    projections of both endpoints. Scores of non-edges are zeroed, then
    enriched with pooled statistics over each curr node's prev neighbours.
    Per head, the scores become softmax weights over the connected prev
    nodes and gather a slice of ``prev``. Pooled edge features are added back
    through ``fcg`` (sum / (count + 1)) and ``fcm`` (max / 2).
    """
    def __init__(self, input_dim, head=4):
        # Hidden widths: a = fused pair width, b = FFN width,
        # c = per-endpoint projection width.
        a=4
        b=8
        c=12
        super(Attention, self).__init__()
        self.bt_pre = nn.Linear(input_dim, c)
        self.bt_cur = nn.Linear(input_dim, c)
        # Per-head attention logits from the fused pair features.
        self.fcc = nn.Sequential(
            GELU(),
            nn.Linear(b, head),
        )
        # Projects the count-normalised sum of pair features back to input_dim.
        self.fcg = nn.Sequential(
            GELU(),
            nn.Linear(b, input_dim),
        )
        # Projects the (halved) max of pair features back to input_dim.
        self.fcm = nn.Sequential(
            GELU(),
            nn.Linear(b, input_dim),
        )
        self.head = head
        self.dim = input_dim
        self.fuse = nn.Sequential(
            GELU(),
            nn.Linear(c*2, a),
        )
        # Input is [pair, pooled-sum, pooled-max] concatenated -> a*3.
        self.ffn = nn.Sequential(
            GELU(),
            nn.Linear(a*3, b),
        )

    def forward(self, prev, curr, edges):
        """Aggregate ``prev`` onto ``curr`` along ``edges``.

        Args:
            prev: source features; the indexing below requires a 3-D tensor,
                presumably (batch, num_prev, input_dim) -- TODO confirm.
            curr: target features, presumably (batch, num_curr, input_dim).
            edges: mask indexed as (batch, num_prev, num_curr); entries equal
                to 0 are non-edges (callers in this file pass boolean masks).

        Returns:
            Tensor of shape (batch, num_curr, input_dim). Assumes input_dim is
            divisible by ``head``; otherwise the per-head concat is narrower
            than the ``fcg``/``fcm`` terms and the addition fails.
        """
        bt_pre = self.bt_pre(prev)
        bt_cur = self.bt_cur(curr)
        # Broadcast both projections to every (prev, curr) pair: (B, P, C, c).
        shape = (bt_pre.shape[0],bt_pre.shape[1],bt_cur.shape[1],bt_pre.shape[2])
        bt_pre = bt_pre.unsqueeze(2).expand(shape)
        bt_cur = bt_cur.unsqueeze(1).expand(shape)
        c = torch.cat((bt_pre, bt_cur), dim=-1)
        c = self.fuse(c)
        c[edges==0] = 0
        # edges1: per-curr-node neighbour count + 1 (B, C).
        # NOTE(review): edges2 is never used, and the `== 0` guards below
        # cannot trigger because of the +1 (for non-negative sums).
        edges1, edges2 = torch.sum(edges, dim=1)+1, torch.sum(edges, dim=2)+1
        edges1[edges1==0] = 1
        edges2[edges2==0] = 1
        # d: sum over prev neighbours / (count + 1); e: max over prev
        # (masked entries are 0). Both broadcast back to every pair.
        d = torch.sum(c, dim=1).unsqueeze(1).expand(c.shape) / edges1.unsqueeze(1).unsqueeze(-1)
        e = torch.max(c, dim=1)[0].unsqueeze(1).expand(c.shape)
        fused = self.ffn(torch.cat((c,d,e),dim=-1))
        # Per-head logits reordered to (B, C, P, head).
        c = self.fcc(fused).transpose(-2,-3)
        mask = edges.transpose(-1,-2)
        res_list = []
        for i in range(self.head):
            c_true = c[:,:,:,i].clone()
            if c_true.shape[2] != 1:
                # Softmax over prev nodes; non-edges get a huge negative
                # logit (a row with no edges ends up uniform).
                c_true[mask==0] = -9e15
                c_true = F.softmax(c_true, dim=2).clone()
            else:
                # Single prev node: raw (masked) score used as the weight.
                c_true[mask==0] = 0
            # Head i gathers its own slice of the prev feature channels.
            res = torch.bmm(c_true, prev[:,:,self.dim//self.head*i:self.dim//self.head*(i+1)])
            res_list.append(res)
        res = torch.cat(res_list, dim=-1)
        fused = fused.clone()
        fused[edges==0] = 0
        res = res + self.fcg(torch.sum(fused, dim=1) / edges1.unsqueeze(-1)) + self.fcm(torch.max(fused, dim=1)[0]/2)
        return res

class GCNLayer(nn.Module):
    """Simple message-passing layer: ``relu(linear(curr + edge^T @ prev))``."""

    def __init__(self, input_dim, normalize=False):
        """
        Args:
            input_dim: feature width of both ``prev`` and ``curr``.
            normalize: when True, aggregated messages are divided by each
                target node's summed edge weight (clamped to at least 1).
                Off by default, matching the layer's previous behaviour.
        """
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(input_dim, input_dim)
        self.normalize = normalize

    def forward(self, prev, curr, edge):
        """
        Args:
            prev: (batch, num_prev, input_dim) source features.
            curr: (batch, num_curr, input_dim) target features.
            edge: (batch, num_prev, num_curr) adjacency; cast to float.

        Returns:
            (batch, num_curr, input_dim) updated target features.
        """
        edge_t = edge.transpose(1, 2).float()
        aggregated_messages = torch.bmm(edge_t, prev)
        if self.normalize:
            degree = edge_t.sum(dim=-1, keepdim=True).clamp(min=1)
            aggregated_messages = aggregated_messages / degree
        return F.relu(self.linear(curr + aggregated_messages))

class SignFlowLayer(nn.Module):
    """Signed message passing from ``prev`` nodes onto ``curr`` nodes.

    Positive (edge == 1) and negative (edge == -1) neighbourhoods are
    aggregated by one shared ``Attention`` module and projected separately.
    They are combined as ``curr @ weight_curr + fcp(pos) - fcn(neg)``; the
    result is layer-normalised and passed through a residual GELU
    feed-forward block.
    """
    def __init__(self, input_dim, output_dim=None):
        super(SignFlowLayer, self).__init__()
        if output_dim is None:
            output_dim = input_dim
        # Module creation order is kept stable so seeded initialisation does
        # not change.
        self.weight_curr = nn.Parameter(torch.Tensor(input_dim, output_dim))
        nn.init.normal_(self.weight_curr, mean=0, std=0.1)
        self.fcp = nn.Linear(input_dim, output_dim)
        self.fcn = nn.Linear(input_dim, output_dim)
        # The combined features have output_dim channels, so the FFN input
        # and the LayerNorm are sized by output_dim (identical to the old
        # sizes when output_dim == input_dim).
        self.ffn = nn.Sequential(
            GELU(),
            nn.Linear(output_dim, input_dim//2),
            GELU(),
            nn.Linear(input_dim//2, output_dim),
        )
        self.attn = Attention(input_dim)
        # Attribute name (sic) kept: renaming would change state_dict keys and
        # break existing checkpoints.
        self.lnrom = nn.LayerNorm(output_dim)

    def forward(self, prev_layer_features, current_layer_features, edges):
        """
        Args:
            prev_layer_features: source features with input_dim channels.
            current_layer_features: target features with input_dim channels.
            edges: signed adjacency laid out as ``Attention`` expects
                (batch, num_prev, num_curr); 1 = positive, -1 = negative.

        Returns:
            Target features with output_dim channels.
        """
        positive_features = self.attn(prev_layer_features, current_layer_features, edges==1)
        negative_features = self.attn(prev_layer_features, current_layer_features, edges==-1)
        transformed_agg_features = self.fcp(positive_features) - self.fcn(negative_features)
        current_layer_features = torch.matmul(current_layer_features, self.weight_curr) + transformed_agg_features
        current_layer_features = self.lnrom(current_layer_features)

        return self.ffn(current_layer_features) + current_layer_features
    
class SubGraphLayer(nn.Module):
    """Encode one direction of an edge's sampled subgraph with ``SignFlowLayer`` hops.

    Messages flow ``emb_a`` -> ``emb_bn`` nodes (``attn1`` over ``sub_0``)
    -> ``emb_an`` nodes (``attn2`` over the transposed ``sub_1``) -> ``emb_b``
    (``attn3`` over ``sub_2``). A shortcut goes from the first hop straight
    to ``emb_b`` (``attnq`` over the transposed, binarised ``sub_0``).
    """
    def __init__(self, input_dim):
        super(SubGraphLayer, self).__init__()
        self.attn1 = SignFlowLayer(input_dim)
        self.attn2 = SignFlowLayer(input_dim)
        self.attn3 = SignFlowLayer(input_dim)
        self.attnq = SignFlowLayer(input_dim)
        self.fc = nn.Linear(input_dim, input_dim)
        # Sinusoidal table (scaled by 0.01) for args.neighbours positions.
        # It is registered as a frozen Parameter, so it is part of the
        # state_dict; its only use in forward() is currently commented out.
        self.pe = self.positional_encoding(args.neighbours, input_dim) * 0.01
        self.pe = nn.Parameter(self.pe, requires_grad=False)

    def positional_encoding(self, num_nodes, embedding_dim):
        """Return a (num_nodes, embedding_dim) sinusoidal position table.

        Even columns hold sin, odd columns cos. Requires an even
        embedding_dim: for odd widths the cos slice is one column narrower
        than ``div_term`` and the assignment fails.
        """
        pe = torch.zeros(num_nodes, embedding_dim)
        position = torch.arange(0, num_nodes).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * -(math.log(10000.0) / embedding_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe

    def forward(self, emb_a, emb_an, emb_bn, emb_b, sub_0, sub_1, sub_2):
        """Run the four signed hops and return features for the ``emb_b`` nodes.

        NOTE(review): the ``sub_*`` blocks come from ``sub_edge`` (not visible
        here); their layout must match the (batch, num_prev, num_curr) masks
        that ``Attention`` indexes.
        """
        sub_1 = sub_1.transpose(1,2)
        # NOTE(review): sub_11 is never used below.
        sub_11 = sub_2.transpose(1,2).clone()
        sub_11[sub_11!=0] = 1
        # Binarised transpose of sub_0: the attnq shortcut treats every edge
        # as positive.
        sub_22 = sub_0.transpose(1,2).clone()
        sub_22[sub_22!=0] = 1
        x0 = emb_a
        # NOTE(review): indices/shuffled_tensor only feed the commented-out
        # positional term below, but torch.randperm still consumes the CPU
        # RNG, so removing these lines would change the random stream.
        indices = torch.randperm(self.pe.size(0))
        shuffled_tensor = self.pe[indices]
        x1 = self.attn1(x0, emb_bn, sub_0)# + shuffled_tensor[:emb_a.shape[1]].unsqueeze(0).expand(emb_a.shape[0],-1,-1)
        x2 = self.attn2(x1, emb_an, sub_1)
        x3 = self.attn3(x2, emb_b, sub_2) + self.fc(self.attnq(x1, emb_b, sub_22))
        return x3

class SBSN(nn.Module):
    """Edge-sign classifier over a sampled bipartite subgraph.

    Node embeddings are free parameters (``features_a`` for the n a-side
    nodes, ``features_b`` for the m b-side nodes). Both directions of the
    subgraph are encoded by one shared ``SubGraphLayer``, fused with the
    endpoint embeddings, and mapped to the probability that the edge label
    is positive.
    """

    def __init__(self, n, m, emb_size=32, warmup_start=40000, warmup_steps=10000):
        """
        Args:
            n: number of a-side nodes (rows of ``features_a``).
            m: number of b-side nodes (rows of ``features_b``).
            emb_size: embedding width.
            warmup_start: step before which node embeddings are scaled to 0.
            warmup_steps: steps over which that scale then ramps linearly to 1.
        """
        super(SBSN, self).__init__()

        # Creation order is kept stable so seeded initialisation is unchanged.
        self.features_a = nn.Parameter(torch.randn((n, emb_size)), requires_grad=True)
        self.features_b = nn.Parameter(torch.randn((m, emb_size)), requires_grad=True)
        self.sub = SubGraphLayer(emb_size)
        # Index 0 projects sampled-neighbour embeddings, index 1 the (detached)
        # edge-endpoint embeddings; see get_embeddings(end=...).
        self.fcx = nn.ModuleList([nn.Linear(emb_size, emb_size),nn.Linear(emb_size, emb_size)])
        self.fcy = nn.ModuleList([nn.Linear(emb_size, emb_size),nn.Linear(emb_size, emb_size)])
        self.B = nn.Linear(emb_size, emb_size)
        # Fuses [B(x)*y, x, y, embed_a, embed_b] (5 * emb_size) to emb_size.
        self.CX = nn.Sequential(
            GELU(),
            nn.Dropout(0.2),
            nn.Linear(emb_size*5, emb_size)
        )
        # Classifier head: emb_size -> one logit.
        self.C = nn.Sequential(
            GELU(),
            nn.Dropout(0.2),
            nn.Linear(emb_size, emb_size//4),
            GELU(),
            nn.Dropout(0.2),
            nn.Linear(emb_size//4, 1),
        )
        # Only used by forward_node.
        self.fc = nn.Sequential(
            nn.Linear(emb_size, emb_size*2, bias=True)
        )
        self.emb_size = emb_size
        self.warmup_start = warmup_start
        self.warmup_steps = warmup_steps

    def _warmup_scale(self, epoch):
        """Embedding scale: 0 before ``warmup_start``, then linear ramp to 1."""
        if epoch < self.warmup_start:
            return 0.0
        if self.warmup_steps <= 0:
            return 1.0
        return min(1.0, (epoch - self.warmup_start) / self.warmup_steps)

    def get_embeddings(self, a, b, epoch, end=0):
        """Look up, warm-up-scale and project embeddings for index tensors ``a``/``b``.

        Args:
            a: indices into ``features_a`` (cast to long).
            b: indices into ``features_b`` (cast to long).
            epoch: training step driving the warm-up scale.
            end: 1 for edge endpoints, 0 for sampled neighbours. Endpoint
                lookups are detached (no gradient into ``features_*`` through
                this path) and use the second projection of ``fcx``/``fcy``.

        Returns:
            (emb_a, emb_b) projected by ``fcx[end]`` and ``fcy[end]``.
        """
        emb_a = self.features_a[a.long()]
        emb_b = self.features_b[b.long()]
        if end:
            emb_a, emb_b = emb_a.detach(), emb_b.detach()
        scale = self._warmup_scale(epoch)
        emb_a = self.fcx[end](emb_a * scale)
        emb_b = self.fcy[end](emb_b * scale)
        return emb_a, emb_b

    def forward(self, left, left_n, right_n, right, sub_0, sub_1, sub_2, epoch, **kwargs):
        """Return the sigmoid probability of a positive label for each edge in the batch.

        Extra batch keys (e.g. ``edge_s``) are accepted and ignored via ``**kwargs``.
        The trailing singleton dims of the head output are squeezed away.
        """
        embed_a, embed_b = self.get_embeddings(left, right, epoch, end=1)
        embed_an, embed_bn = self.get_embeddings(right_n, left_n, epoch)
        x = self.sub(embed_a, embed_an, embed_bn, embed_b, sub_0, sub_1, sub_2)
        # Same encoder run in the reverse direction with transposed blocks.
        y = self.sub(embed_b, embed_bn, embed_an, embed_a, sub_2.transpose(1,2), sub_1.transpose(1,2), sub_0.transpose(1,2))
        fuse = torch.cat((self.B(x)*y,x,y,embed_a,embed_b),dim=-1)
        fuse = self.CX(fuse)
        fuse = self.C(fuse).squeeze(-1).squeeze(-1)
        return torch.sigmoid(fuse)

    def loss(self, pred_y, y):
        """Class-balanced binary cross-entropy.

        Positives are weighted by 1 / pos_ratio and negatives by
        1 / (1 - pos_ratio), where pos_ratio = y.sum() / len(y).

        Returns:
            (loss, 0): the second element is a placeholder for accuracy.

        Raises:
            ValueError: if ``y`` has negative entries or its size differs
                from ``pred_y``'s.
        """
        if not (y.min() >= 0):
            raise ValueError('must 0~1')
        if pred_y.size() != y.size():
            raise ValueError('must be same length')
        pos_ratio = y.sum() / y.size()[0]
        weight = torch.where(y > 0.5, 1./pos_ratio, 1./(1-pos_ratio))
        return F.binary_cross_entropy(pred_y, y, weight=weight), 0

    def forward_sub(self, left, left_n, right_n, right, sub_0, sub_1, sub_2, epoch, **kwargs):
        """Return the two directional subgraph encodings (x, y) without the head."""
        embed_a, embed_b = self.get_embeddings(left, right, epoch, end=1)
        embed_an, embed_bn = self.get_embeddings(right_n, left_n, epoch)
        x = self.sub(embed_a, embed_an, embed_bn, embed_b, sub_0, sub_1, sub_2)
        y = self.sub(embed_b, embed_bn, embed_an, embed_a, sub_2.transpose(1,2), sub_1.transpose(1,2), sub_0.transpose(1,2))
        return x, y

    def forward_node(self, left, right):
        """Project the element-wise product of raw endpoint embeddings to 2 * emb_size."""
        emb_a = self.features_a[left.long()]
        emb_b = self.features_b[right.long()]
        return self.fc(emb_a * emb_b)

def _read_edgelist(path):
    """Read a tab-separated ``a<TAB>b<TAB>sign`` file, skipping its header line.

    Blank lines are ignored. Always returns an int array of shape
    (num_edges, 3), even for an empty file, so callers can slice columns.
    """
    rows = []
    with open(path) as f:
        next(f, None)  # header line
        for line in f:
            if not line.strip():
                continue
            a, b, s = map(int, line.split('\t'))
            rows.append((a, b, s))
    return np.array(rows, dtype=int).reshape(-1, 3)


def load_data(dataset_name, n, m, prefix=None):
    """Load the train/validation/test signed edge lists of ``dataset_name``.

    Args:
        dataset_name: file stem; files are ``<prefix>/<name>_{training,validation,testing}.txt``.
        n, m: node counts; unused, kept for call compatibility.
        prefix: data directory; defaults to the module-level ``data_prefix``.

    Returns:
        (train, val, test) int arrays of shape (num_edges, 3).

    Side effects:
        Legacy ``<name>_val.txt`` / ``<name>_test.txt`` files are renamed on
        disk to the ``_validation`` / ``_testing`` names when those are missing.
    """
    if prefix is None:
        prefix = data_prefix
    train_file_path = os.path.join(prefix, f'{dataset_name}_training.txt')
    val_file_path = os.path.join(prefix, f'{dataset_name}_validation.txt')
    test_file_path = os.path.join(prefix, f'{dataset_name}_testing.txt')
    if not os.path.exists(val_file_path):
        os.rename(os.path.join(prefix, f'{dataset_name}_val.txt'), val_file_path)
    if not os.path.exists(test_file_path):
        os.rename(os.path.join(prefix, f'{dataset_name}_test.txt'), test_file_path)
    return (
        _read_edgelist(train_file_path),
        _read_edgelist(val_file_path),
        _read_edgelist(test_file_path),
    )

def power(tensor,exponent):
    """Signed element-wise power: ``sign(x) * |x| ** exponent``.

    Preserves each element's sign, so fractional exponents are defined for
    negative inputs.
    """
    magnitude = tensor.abs().pow(exponent)
    return magnitude * tensor.sign()

@torch.no_grad()
def test_and_val(dataloader, model, mode='val', epoch=0):
    """Evaluate ``model`` over every batch of ``dataloader``.

    Args:
        dataloader: yields dicts of tensors accepted by ``model(**batch)`` and
            containing the binary labels under ``'edge_s'``.
        model: module returning per-edge positive probabilities.
        mode: prefix of the result keys (e.g. 'val', 'test').
        epoch: forwarded to the model (drives its embedding warm-up).

    Returns:
        ``{mode_auc, mode_f1, mode_mac, mode_mic}``. ROC-AUC is computed from
        the predicted probabilities; the F1 scores use predictions
        thresholded at 0.5.
    """
    model.eval()
    device = next(model.parameters()).device
    scores, labels = [], []
    for data in dataloader:
        batch = {key: value.to(device) for key, value in data.items()}
        scores.append(model(**batch, epoch=epoch))
        labels.append(batch['edge_s'].float())
    probs = torch.cat(scores, dim=-1).cpu().numpy()
    y = torch.cat(labels, dim=-1).cpu().numpy()
    preds = (probs >= 0.5).astype(probs.dtype)
    return {
        f'{mode}_auc': roc_auc_score(y, probs),
        f'{mode}_f1' : f1_score(y, preds),
        f'{mode}_mac' : f1_score(y, preds, average='macro'),
        f'{mode}_mic' : f1_score(y, preds, average='micro'),
    }

def format_res(data, p=1):
    """Render a metrics dict as ``"key: value, ..."`` with values as percentages.

    Each value is multiplied by 100 and shown with ``p`` decimals. The text
    is the dict's repr with the outer braces and single quotes stripped.
    """
    percentages = {name: f"{value * 100:.{p}f}" for name, value in data.items()}
    return repr(percentages).strip('{}').replace("'", '')

def split(edgelist, u, v, n, m):
    """Partition a bipartite edge list into two node-disjoint, re-indexed subgraphs.

    Args:
        edgelist: int array (num_edges, k>=2); column 0 is the a-side id,
            column 1 the b-side id, further columns are carried through.
        u, v: a-side / b-side node ids of the first subgraph.
        n, m: total a-side / b-side node counts; the second subgraph uses
            the complements of ``u`` in range(n) and ``v`` in range(m).

    Returns:
        (edgelist1, edgelist2): edges with both endpoints inside (u, v), and
        edges with both endpoints outside. Node ids are remapped to their
        rank in the corresponding sorted node set. Edges crossing the
        partition are dropped. The input array is not modified.
    """
    in_u = np.isin(edgelist[:, 0], u)
    in_v = np.isin(edgelist[:, 1], v)
    edgelist1 = edgelist[in_u & in_v]
    edgelist2 = edgelist[~in_u & ~in_v]

    # Binary search gives each id's rank in the sorted node set in
    # O(E log N), replacing a linear np.where scan per edge (O(E * N)).
    sorted_u = np.sort(u)
    sorted_v = np.sort(v)
    edgelist1[:, 0] = np.searchsorted(sorted_u, edgelist1[:, 0])
    edgelist1[:, 1] = np.searchsorted(sorted_v, edgelist1[:, 1])

    # setdiff1d already returns sorted, unique ids.
    u_complement = np.setdiff1d(np.arange(n), u)
    v_complement = np.setdiff1d(np.arange(m), v)
    edgelist2[:, 0] = np.searchsorted(u_complement, edgelist2[:, 0])
    edgelist2[:, 1] = np.searchsorted(v_complement, edgelist2[:, 1])
    return edgelist1, edgelist2

def run():
    """Train SBSN on ``args.dataset_name``, or evaluate a checkpoint with ``--eval``.

    Loads the train/validation/test edge lists and builds the training graph
    in the module-level ``edges`` (which ``GraphDataset`` falls back to).
    Then either:
      * ``--eval``: loads ``args.ckpt`` (skipping ``feature*`` weights under
        ``--zero_shot``) and prints test-split metrics; or
      * trains for ``args.epoch`` optimizer steps. Validation/test metrics
        are computed every ``args.val_epoch`` steps once step >= 80000. The
        state dict with the best validation AUC is saved under ``./res/``
        and logs are appended under ``./txt/``. Both directories are assumed
        to exist.

    NOTE(review): the log path inside the validation branch reads the globals
    ``dataset_name`` and ``comment``, which are only defined in the
    ``__main__`` block; calling ``run()`` from elsewhere raises NameError there.
    """
    global edges
    set_a_num, set_b_num = DATA_EMB_DIC[args.dataset_name]
    train_edgelist, val_edgelist, test_edgelist = load_data(args.dataset_name, set_a_num, set_b_num)

    print(set_a_num, set_b_num)
    edges = Edge(train_edgelist, set_a_num, set_b_num)

    model = SBSN(n=set_a_num, m=set_b_num, emb_size=args.emb_size)
    if args.eval:
        setup_seed(args.seed)
        checkpoint = torch.load(args.ckpt, weights_only=True)
        if args.zero_shot:
            # Drop the node-embedding tables so only the shared layers load.
            filtered_checkpoint = {k: v for k, v in checkpoint.items() if 'feature' not in k}
            model.load_state_dict(filtered_checkpoint, strict=False)
        else:
            model.load_state_dict(checkpoint)
        model = model.cuda()
        dataset_test = GraphDataset(test_edgelist, test=1, edge=edges)
        dataloader_test = DataLoader(dataset_test, batch_size=args.batch_size, collate_fn=collate_fn, num_workers=16, pin_memory=True)
        res = test_and_val(dataloader_test, model, mode='test', epoch=args.epoch)
        print(format_res(res))
        return
    model = model.cuda()

    # Two parameter groups: node-embedding tables ('features*') and the rest.
    # Both currently use the same learning rate.
    params = [[], []]
    for name, param in model.named_parameters():
        if name.startswith('features'):
            params[0].append(param)
        else:
            params[1].append(param)
    param_groups = [
        {'params': params[0], 'lr': args.lr*1},
        {'params': params[1], 'lr': args.lr}
    ]
    optimizer = torch.optim.Adam(param_groups, weight_decay=args.weight_decay)
    # optimizer_node = torch.optim.Adam([param_groups[0]], weight_decay=args.weight_decay)
    # Stepped manually at steps 20000 and 40000 below (lr *= 0.3 each time).
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.3)
    # NOTE(review): scheduler_slow is never stepped in this function.
    scheduler_slow = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.998)
    dataset = GraphDataset(train_edgelist, edge=edges)
    dataset_val = GraphDataset(val_edgelist, test=1, edge=edges)
    dataset_test = GraphDataset(test_edgelist, test=1, edge=edges)
    # NOTE(review): dataset_train / dataloader_train / dataloader_node are unused
    # (their uses are commented out), but np.random.choice below still
    # consumes NumPy RNG state.
    dataset_train = GraphDataset(train_edgelist[np.random.choice(train_edgelist.shape[0], size=len(test_edgelist), replace=False)], test=1, edge=edges)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate_fn, num_workers=16, shuffle=True, pin_memory=True)
    dataloader_node = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate_fn, num_workers=16, shuffle=True, pin_memory=True)
    dataloader_val = DataLoader(dataset_val, batch_size=args.batch_size, collate_fn=collate_fn, num_workers=16, pin_memory=True)
    dataloader_test = DataLoader(dataset_test, batch_size=args.batch_size, collate_fn=collate_fn, num_workers=16, pin_memory=True)
    dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size, collate_fn=collate_fn, num_workers=16, pin_memory=True)
    
    # 'epoch' counts optimizer steps (mini-batches), not passes over the data.
    epoch = 0
    res_best = dict()
    best_auc = 0.0
    # Training loss summed since the last validation block.
    all_loss = 0.0
    # left = torch.tensor(train_edgelist[:, 0]).cuda().unsqueeze(-1)
    # right = torch.tensor(train_edgelist[:, 1]).cuda().unsqueeze(-1)
    # train_yy = torch.tensor(train_edgelist[:, 2]).cuda().float()
    # train_yy[train_yy==-1] = 0
    # data_cat = torch.zeros((set_a_num, set_b_num, args.emb_size*2), dtype=torch.float).cuda()
    tm = time.time()
    al, re = 0, 0
    while epoch < args.epoch:
        for data in dataloader:
            data = {key: value.cuda() for key, value in data.items()}
            model.train()
            optimizer.zero_grad()
            with torch.set_grad_enabled(True):
                pred_y = model(**data, epoch = epoch)
                train_y = data['edge_s'].float()
                loss = model.loss(pred_y, train_y)[0]
                all_loss += loss.item()
                # device = torch.device("cuda")
                # allocated_memory = torch.cuda.memory_allocated(device) / 1024**2  # convert to MB
                # reserved_memory = torch.cuda.memory_reserved(device) / 1024**2   # convert to MB
                # al, re = max(al, allocated_memory), max(re, reserved_memory)
                # if epoch in [150,300,450]:
                #     print(f"Allocated memory: {al:.2f} MB")
                #     print(f"Reserved memory: {re:.2f} MB")

                loss.backward()
                optimizer.step()
            
            epoch += 1
            # if epoch%100 == 0:
            #     print(epoch, (time.time()-tm)/3600)
            # NOTE(review): the hard-coded 80000 gate means that with the
            # default --epoch 80000 evaluation happens only at the final step,
            # and with --epoch < 80000 nothing is ever evaluated or saved.
            if epoch%args.val_epoch == 0 and epoch >= 80000:
                model.eval()
                with torch.set_grad_enabled(False):
                    val_res = test_and_val(dataloader_val, model, mode='val', epoch=epoch)
                    # train_res = test_and_val(dataloader_train, model, mode='train', epoch=epoch)
                    test_res = test_and_val(dataloader_test, model, mode='test', epoch=epoch)
                    # Model selection on validation AUC; the test metrics of the
                    # selected step are stored alongside.
                    if val_res['val_auc']>best_auc:
                        best_auc = val_res['val_auc']
                        res_best.update(val_res)
                        # if comment == 'base':
                        print('saved',epoch)
                        torch.save(model.state_dict(), f'./res/{args.dataset_name}_{epoch}.pth')

                        res_best.update(test_res)
                        print(f'\r\033[Kepoch: {epoch}\t{format_res(test_res)}')
                    # `dataset_name` and `comment` are globals from the __main__ block.
                    with open(f'./txt/{dataset_name}_{comment}.txt','a') as fh:
                        fh.write(f'{epoch}\t{format_res(res_best,p=4)}\tloss: {all_loss:.4f}\n')
                        fh.write(f'{epoch}\t{format_res(val_res,p=4)}\t{format_res(test_res,p=4)}\n') # \t{format_res(train_res,p=4)}   train_auc: {train_res["train_auc"]*100:.1f}
                    print(f'\r\033[Kepoch: {epoch}\tval_auc: {val_res["val_auc"]*100:.1f}\tbest_auc: {best_auc*100:.1f}\tloss: {all_loss:.4f}\t', end='')
                all_loss = 0.0

            # if epoch in [80000, 100000]:
            #     model.eval()
            #     with torch.set_grad_enabled(False):
            #         for d in dataloader_node:
            #             d = {key: value.cuda() for key, value in d.items()}
            #             xs, ys = model.forward_sub(**d, epoch=epoch)
            #             for l,r,x,y in zip(data['left'], data['right'], xs, ys):
            #                 data_cat[l, r] = torch.cat((x,y),dim=-1).detach()
            #     model.train()
            #     for _ in range(400):
            #         optimizer_node.zero_grad()
            #         with torch.set_grad_enabled(True):
            #             x = model.forward_node(left, right)
            #             loss = torch.nn.CosineSimilarity(dim=-1)(x, data_cat[left, right]).mean()
            #             loss.backward()
            #             optimizer_node.step()
                
            # Step-based lr decay (gamma=0.3).
            if epoch in [20000, 40000]:
                scheduler.step()
            
            if epoch >= args.epoch:
                break
    print()

if __name__ == "__main__":
    # Base dataset name; each run appends a split suffix ('-1', '-2', ...).
    # `dataset_name` and `comment` are module globals that run() reads when
    # writing its ./txt/ log file.
    dataset_name = args.dataset_name
    if args.eval:
        # Evaluate the checkpoint on splits -1 and -2, then stop.
        # NOTE(review): exit() relies on the `site` builtin; sys.exit is more robust.
        for i in range(0,2):
            args.dataset_name = dataset_name+'-'+str(i+1)
            run()
        exit()
    # Session tag: --comment if given, else the Unix time modulo 10**8.
    if args.comment:
        comment = args.comment
    else:
        comment = str(int(time.time())%100000000)
    print(comment)
    # Append a copy of this script's source under ./code/ (directory assumed to exist).
    with open(f'./code/{dataset_name}_{comment}.py','a') as fh:
        with open(__file__,'r') as f:
            fh.write(f.read())
            fh.write('\n\n\n\n\n\n')
    # Five trainings on splits -1 .. -5, each reseeded with the same --seed.
    for i in range(0,5):
        # "key-value" pairs joined by '~' (device omitted); written to the log
        # header only on the first iteration.
        hyper_params = dict(vars(args))
        del hyper_params['device']
        hyper_params = "~".join([f"{k}-{v}" for k,v in hyper_params.items()])
        # tb_writer = SummaryWriter(log_dir=f'./logs/{hyper_params}')
        with open(f'./txt/{dataset_name}_{comment}.txt','a') as fh:
            if i==0:
                fh.write(hyper_params)
            fh.write(f'\n{i}\n')
        print(f'training: {i}')
        
        setup_seed(args.seed)
        args.dataset_name = dataset_name+'-'+str(i+1)
        run()