#-*- coding: utf-8 -*-
import os
import time
import random
import argparse
import numpy as np

# Command-line configuration. Parsing happens at import time, before the
# `import torch` further down, so the CUDA device mask set at the end of this
# block is already in the environment when torch is loaded.
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cpu', help='Device')
parser.add_argument('--dataset_name', type=str, default='house1to10')
parser.add_argument('--emb_size', type=int, default=32, help='Embedding Size')
parser.add_argument('--weight_decay', type=float, default=0, help='Weight Decay')
parser.add_argument('--lr', type=float, default=0.02, help='Learning Rate')
parser.add_argument('--seed', type=int, default=222, help='Random seed')
parser.add_argument('--epoch', type=int, default=80000, help='Epoch')
parser.add_argument('--batch_size', type=int, default=2048, help='Batch Size')
parser.add_argument('--neighbours', type=int, default=400, help='Max Neighbours')
parser.add_argument('--val_epoch', type=int, default=50, help='Epoch Interval For Validation')
parser.add_argument('--zero_shot', action='store_true', help='Zero Shot Learning')
parser.add_argument('--ckpt', type=str, help='Checkpoint path')
parser.add_argument('--eval', action='store_true', help='Evaluation')
parser.add_argument('--comment', type=str, default='')
args = parser.parse_args()

# Directory that holds the `<dataset>_{training,validation,testing}.txt` files.
data_prefix = 'experiments-data'

if 'cuda' in args.device:
    # Take the whole trailing run of digits as the GPU index ("cuda:12" -> "12").
    # Taking only the last character would select GPU 2 for "cuda:12" and
    # would export the letter "a" for a bare "cuda", hiding every GPU.
    _gpu_index = args.device[len(args.device.rstrip('0123456789')):]
    if _gpu_index:
        os.environ['CUDA_VISIBLE_DEVICES'] = _gpu_index

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import Edge, sub_edge, GELU
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import f1_score, roc_auc_score

# from torch.utils.tensorboard import SummaryWriter

def setup_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's `random`, NumPy and torch (CPU and all CUDA devices) and
    configures cuDNN for repeatable kernel selection.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may pick a different convolution algorithm on each run, which
    # defeats `deterministic = True`; switch it off alongside.
    torch.backends.cudnn.benchmark = False

from common import DATA_EMB_DIC

def find_neighbours(edges, node_in, node_out, left=1):
    """Sample up to `args.neighbours` neighbours of the nodes in `node_in`.

    Args:
        edges: object exposing the adjacency maps `edge_abp`/`edge_abn`
            (used when `left` is truthy) or `edge_bap`/`edge_ban` (otherwise),
            plus the degree arrays `degb`/`dega`.
        node_in: sequence of seed nodes. It is shuffled IN PLACE.
        node_out: nodes that must not appear in the result.
        left: selects which side's adjacency maps and degree array are read.

    Returns:
        A list of at most `args.neighbours` distinct neighbour ids; an empty
        list when `node_in` is empty or a visited seed has no edges.
    """
    # Guard: the per-seed quota below divides by len(node_in).
    if len(node_in) == 0:
        return []

    neighbour = set()
    # Per-seed quota, never below 4.
    ave = max(4, args.neighbours // len(node_in))
    random.shuffle(node_in)
    node_out = set(node_out)
    for node in node_in:
        if left:
            st = edges.edge_abp[node] + edges.edge_abn[node]
        else:
            st = edges.edge_bap[node] + edges.edge_ban[node]
        # NOTE(review): a seed without edges aborts the whole search and
        # discards anything collected so far -- confirm `continue` was not
        # intended here.
        if not len(st):
            return []
        samp = np.array(st)
        if left:
            deg = edges.degb[samp]
        else:
            deg = edges.dega[samp]
        # NOTE(review): the degree values are permuted before ranking, so the
        # "top-`ave` by degree" pick is effectively a random pick of `ave`
        # positions -- confirm this is the intended sampling scheme.
        deg = np.random.permutation(deg)
        deg_idx = np.argsort(deg)[::-1][:ave]
        samp = samp[deg_idx].tolist()
        neighbour.update(set(samp) - node_out)
        if len(neighbour) >= args.neighbours:
            break
    return list(neighbour)[:args.neighbours]

class GraphDataset(Dataset):
    """Dataset yielding one signed edge per item.

    Args:
        val: raw edge data; wrapped in `Edge` and stored as `self.val`.
        test: flag stored on the instance (not read inside this class).
        edge: graph to attach as `self.edges`; when omitted, the module-level
            `edges` global is used.
    """

    def __init__(self, val, test=False, edge=None):
        self.val = Edge(val)
        print(len(self.val.edge_set))
        self.test = test
        self.edges = edges if edge is None else edge

    def __len__(self):
        """Number of edges in the wrapped edge list."""
        return len(self.val.edge_list)

    def __getitem__(self, index):
        """Return the endpoints (as single-element lists) and a 0/1 sign flag."""
        record = self.val.edge_list[index]
        # The flag is 1 only when the stored sign equals 1.
        is_positive = record[2] == 1
        return {
            'left': [record[0]],
            'right': [record[1]],
            'edge_s': int(is_positive),
        }

def pad_tensor(batch):
    """Stack same-rank tensors into one zero-padded tensor.

    Args:
        batch: non-empty sequence of tensors with the same number of
            dimensions; the dtype of the first one is used for the result.

    Returns:
        Tensor of shape `[len(batch), *per-dimension maxima]` in which each
        input occupies the leading corner of its slot and the rest is zero.
    """
    first = batch[0]
    padded_shape = [len(batch)]
    for axis in range(first.dim()):
        padded_shape.append(max(item.size(axis) for item in batch))

    out = torch.zeros(padded_shape, dtype=first.dtype)
    for slot, item in enumerate(batch):
        corner = tuple(slice(0, extent) for extent in item.size())
        out[slot][corner] = item
    return out

def collate_fn(batch):
    """Collate a list of sample dicts into a dict of batched tensors.

    List values are converted to tensors first. Per key, tensor values are
    stacked when all shapes match and zero-padded with `pad_tensor` when they
    do not; non-tensor values are gathered into a single tensor.

    Args:
        batch: non-empty list of dicts; the keys of the first dict define the
            keys of the result.

    Returns:
        Dict mapping each key to one batched tensor.
    """
    columns = {key: [] for key in batch[0]}
    for sample in batch:
        for key, value in sample.items():
            columns[key].append(torch.tensor(value) if isinstance(value, list) else value)

    for key, values in columns.items():
        head = values[0]
        if not isinstance(head, torch.Tensor):
            columns[key] = torch.tensor(values)
        elif any(entry.shape != head.shape for entry in values):
            columns[key] = pad_tensor(values)
        else:
            columns[key] = torch.stack(values)

    return columns

class Attention(nn.Module):
    """Edge-masked multi-head aggregation from `prev` nodes onto `curr` nodes.

    Pairwise features are built for every (prev, curr) combination, masked by
    `edges`, enriched with masked sum/max statistics, and turned into per-head
    attention scores that weight slices of `prev`.

    Args:
        input_dim: feature size of the node embeddings; also the output size.
        head: number of attention heads; each head reads a contiguous
            `input_dim // head` slice of `prev`.
    """
    def __init__(self, input_dim, head=4):
        # Internal widths: `a` = fused pair feature, `b` = enriched pair
        # feature, `c` = per-node projection.
        a=4
        b=8
        c=12
        super(Attention, self).__init__()
        self.bt_pre = nn.Linear(input_dim, c)
        self.bt_cur = nn.Linear(input_dim, c)
        # Enriched pair feature -> one score per head.
        self.fcc = nn.Sequential(
            GELU(),
            nn.Linear(b, head),
        )
        # Heads that map the masked sum / max of pair features back to input_dim.
        self.fcg = nn.Sequential(
            GELU(),
            nn.Linear(b, input_dim),
        )
        self.fcm = nn.Sequential(
            GELU(),
            nn.Linear(b, input_dim),
        )
        self.head = head
        self.dim = input_dim
        # Concatenated (prev, curr) projections -> fused pair feature.
        self.fuse = nn.Sequential(
            GELU(),
            nn.Linear(c*2, a),
        )
        # (pair feature, masked sum statistic, max) -> enriched pair feature.
        self.ffn = nn.Sequential(
            GELU(),
            nn.Linear(a*3, b),
        )

    def forward(self, prev, curr, edges):
        """Aggregate `prev` features onto `curr` nodes along `edges`.

        Args:
            prev: indexed as a rank-3 tensor; assumed (batch, P, input_dim)
                -- TODO confirm against callers.
            curr: indexed as a rank-3 tensor; assumed (batch, C, input_dim)
                -- TODO confirm against callers.
            edges: mask used to index the first three axes of the pair tensor;
                assumed (batch, P, C) with zero meaning "no edge".

        Returns:
            Tensor with one `input_dim`-sized row per `curr` node.
        """
        bt_pre = self.bt_pre(prev)
        bt_cur = self.bt_cur(curr)
        # Broadcast both projections to a common (batch, P, C, c) grid.
        shape = (bt_pre.shape[0],bt_pre.shape[1],bt_cur.shape[1],bt_pre.shape[2])
        bt_pre = bt_pre.unsqueeze(2).expand(shape)
        bt_cur = bt_cur.unsqueeze(1).expand(shape)
        c = torch.cat((bt_pre, bt_cur), dim=-1)
        c = self.fuse(c)
        # Zero out pair features where there is no edge.
        c[edges==0] = 0
        # Edge counts along each axis, offset by one. `edges2` is not used below.
        edges1, edges2 = torch.sum(edges, dim=1)+1, torch.sum(edges, dim=2)+1
        # With the boolean masks SignFlowLayer passes in, the +1 above already
        # keeps these >= 1, so the two lines below would only matter for
        # signed inputs.
        edges1[edges1==0] = 1
        edges2[edges2==0] = 1
        # Masked sum over axis 1 divided by (count + 1), and max over axis 1,
        # both broadcast back onto every pair.
        d = torch.sum(c, dim=1).unsqueeze(1).expand(c.shape) / edges1.unsqueeze(1).unsqueeze(-1)
        e = torch.max(c, dim=1)[0].unsqueeze(1).expand(c.shape)
        fused = self.ffn(torch.cat((c,d,e),dim=-1))
        # Per-head scores, transposed so the curr axis precedes the prev axis.
        c = self.fcc(fused).transpose(-2,-3)
        mask = edges.transpose(-1,-2)
        res_list = []
        for i in range(self.head):
            c_true = c[:,:,:,i].clone()
            if c_true.shape[2] != 1:
                # More than one prev node: masked softmax over the prev axis.
                c_true[mask==0] = -9e15
                c_true = F.softmax(c_true, dim=2).clone()
            else:
                # Single prev node: use the raw score, zeroed where no edge.
                c_true[mask==0] = 0
            # Each head weights its own contiguous slice of the prev features.
            res = torch.bmm(c_true, prev[:,:,self.dim//self.head*i:self.dim//self.head*(i+1)])
            res_list.append(res)
        res = torch.cat(res_list, dim=-1)
        # Clone before masking so the in-place write does not touch the tensor
        # that produced the scores above.
        fused = fused.clone()
        fused[edges==0] = 0
        # Add two residual terms built from the masked sum and max of the
        # enriched pair features.
        res = res + self.fcg(torch.sum(fused, dim=1) / edges1.unsqueeze(-1)) + self.fcm(torch.max(fused, dim=1)[0]/2)
        return res

class SignFlowLayer(nn.Module):
    """One propagation step that aggregates positive and negative edges separately.

    The same `Attention` module is applied to the `edges == 1` and
    `edges == -1` masks; the two results go through separate linear maps and
    are subtracted, added to a linear transform of the current features,
    layer-normalised, and passed through a residual feed-forward block.

    Args:
        input_dim: feature size of the incoming embeddings.
        output_dim: feature size of the projections; defaults to `input_dim`.
            The LayerNorm and feed-forward block are sized with `input_dim`,
            so the two sizes are presumably expected to match.
    """

    def __init__(self, input_dim, output_dim=None):
        super().__init__()
        output_dim = input_dim if output_dim is None else output_dim
        half_dim = input_dim // 2

        self.weight_curr = nn.Parameter(torch.Tensor(input_dim, output_dim))
        nn.init.normal_(self.weight_curr, mean=0, std=0.1)
        self.fcp = nn.Linear(input_dim, output_dim)
        self.fcn = nn.Linear(input_dim, output_dim)
        self.ffn = nn.Sequential(
            GELU(),
            nn.Linear(input_dim, half_dim),
            GELU(),
            nn.Linear(half_dim, output_dim),
        )
        self.attn = Attention(input_dim)
        self.lnrom = nn.LayerNorm(input_dim)

    def forward(self, prev_layer_features, current_layer_features, edges):
        """Return updated features for the current-layer nodes."""
        pos_msg = self.attn(prev_layer_features, current_layer_features, edges==1)
        neg_msg = self.attn(prev_layer_features, current_layer_features, edges==-1)
        signed_msg = self.fcp(pos_msg) - self.fcn(neg_msg)

        hidden = torch.matmul(current_layer_features, self.weight_curr) + signed_msg
        hidden = self.lnrom(hidden)
        return self.ffn(hidden) + hidden
    
class SubGraphLayer(nn.Module):
    """Three-hop signed message passing along a -> bn -> an -> b, plus a shortcut.

    Args:
        input_dim: feature size shared by all node embeddings.
    """

    def __init__(self, input_dim):
        super(SubGraphLayer, self).__init__()
        self.attn1 = SignFlowLayer(input_dim)
        self.attn2 = SignFlowLayer(input_dim)
        self.attn3 = SignFlowLayer(input_dim)
        self.attnq = SignFlowLayer(input_dim)
        self.fc = nn.Linear(input_dim, input_dim)

    def forward(self, emb_a, emb_an, emb_bn, emb_b, sub_0, sub_1, sub_2):
        """Propagate `emb_a` through the sub-graph and return features for `emb_b`.

        Args:
            emb_a, emb_an, emb_bn, emb_b: node embeddings for the four node groups.
            sub_0: signed adjacency used for the a -> bn hop.
            sub_1: signed adjacency for the bn -> an hop; transposed on axes
                (1, 2) before use.
            sub_2: signed adjacency used for the an -> b hop.

        Returns:
            Output of the third hop plus a linear map of a shortcut hop from
            the first-hop features straight to `emb_b`.
        """
        sub_1 = sub_1.transpose(1,2)
        # Shortcut mask: transpose of `sub_0` with every non-zero entry set to
        # 1, so the shortcut treats all of those edges as positive.
        # (A similar mask was previously built from `sub_2` but never read;
        # it has been dropped to save a clone per forward pass.)
        sub_22 = sub_0.transpose(1,2).clone()
        sub_22[sub_22!=0] = 1
        x0 = emb_a
        x1 = self.attn1(x0, emb_bn, sub_0)
        x2 = self.attn2(x1, emb_an, sub_1)
        x3 = self.attn3(x2, emb_b, sub_2) + self.fc(self.attnq(x1, emb_b, sub_22))
        return x3

class SBSN(nn.Module):
    """Edge-sign predictor built on two mirrored `SubGraphLayer` passes.

    Args:
        n: number of rows in the side-A embedding table.
        m: number of rows in the side-B embedding table.
        emb_size: embedding width.
    """

    def __init__(self, n, m, emb_size=32):
        super().__init__()

        # The learnable node tables only exist outside zero-shot mode.
        if not args.zero_shot:
            self.features_a = nn.Parameter(torch.randn((n, emb_size)), requires_grad=True)
            self.features_b = nn.Parameter(torch.randn((m, emb_size)), requires_grad=True)
        self.sub = SubGraphLayer(emb_size)
        # Index 0 is used for neighbour nodes, index 1 for the edge endpoints.
        self.fcx = nn.ModuleList([nn.Linear(emb_size, emb_size) for _ in range(2)])
        self.fcy = nn.ModuleList([nn.Linear(emb_size, emb_size) for _ in range(2)])
        self.B = nn.Linear(emb_size, emb_size)
        self.CX = nn.Sequential(
            GELU(),
            nn.Dropout(0.2),
            nn.Linear(emb_size*5, emb_size),
        )
        bottleneck = emb_size // 4
        self.C = nn.Sequential(
            GELU(),
            nn.Dropout(0.2),
            nn.Linear(emb_size, bottleneck),
            GELU(),
            nn.Dropout(0.2),
            nn.Linear(bottleneck, 1),
        )
        self.emb_size = emb_size

    def get_embeddings(self, a, b, end=0):
        """Look up (or sample) embeddings for index tensors `a` and `b`.

        Args:
            a, b: index tensors for side A and side B.
            end: 0 or 1; selects the projection pair and, when truthy, detaches
                the looked-up table rows.

        Returns:
            `(emb_a, emb_b)` after the selected linear projections.
        """
        if args.zero_shot:
            raw_a = torch.randn((*a.shape, self.emb_size), device=a.device)
            raw_b = torch.randn((*b.shape, self.emb_size), device=b.device)
        else:
            raw_a = self.features_a[a.long()]
            raw_b = self.features_b[b.long()]
            if end:
                raw_a = raw_a.detach()
                raw_b = raw_b.detach()
        # NOTE(review): the inputs are multiplied by zero, so each projection
        # only emits its bias -- presumably a deliberate ablation; confirm.
        proj_a = self.fcx[end](raw_a*0)
        proj_b = self.fcy[end](raw_b*0)
        return proj_a, proj_b

    def forward(self, left, left_n, right_n, right, sub_0, sub_1, sub_2, **kwargs):
        """Return the sigmoid-activated score for each (left, right) pair."""
        embed_a, embed_b = self.get_embeddings(left, right, end=1)
        embed_an, embed_bn = self.get_embeddings(right_n, left_n)
        # Forward pass a -> b, then the mirrored pass b -> a on transposed
        # adjacency matrices.
        x = self.sub(embed_a, embed_an, embed_bn, embed_b, sub_0, sub_1, sub_2)
        y = self.sub(
            embed_b, embed_bn, embed_an, embed_a,
            sub_2.transpose(1,2), sub_1.transpose(1,2), sub_0.transpose(1,2),
        )
        joint = torch.cat((self.B(x)*y, x, y, embed_a, embed_b), dim=-1)
        logits = self.C(self.CX(joint)).squeeze(-1).squeeze(-1)
        return torch.sigmoid(logits)

    def loss(self, pred_y, y):
        """Class-balanced binary cross-entropy and thresholded accuracy.

        Args:
            pred_y: predicted probabilities.
            y: targets in [0, 1], same shape as `pred_y`.

        Returns:
            `(loss, acc)` where samples are weighted by the inverse frequency
            of their class in the batch and `acc` compares both sides at 0.5.
        """
        assert y.min() >= 0, 'must 0~1'
        assert pred_y.size() == y.size(), 'must be same length'
        pos_ratio = y.sum() /  y.size()[0]
        is_pos = y > 0.5
        weight = torch.where(is_pos, 1./pos_ratio, 1./(1-pos_ratio))
        hits = is_pos == (pred_y > 0.5)
        acc = hits.sum(dim=0)/y.shape[0]
        return F.binary_cross_entropy(pred_y, y, weight=weight), acc

def get_info(a, anchors, left=1):
    """Compare the signed neighbourhood of node `a` with that of each anchor.

    Reads the module-level `edges` graph.

    Args:
        a: node id.
        anchors: iterable of anchor node ids on the same side as `a`.
        left: truthy to read `edge_abp`/`edge_abn`, falsy to read
            `edge_bap`/`edge_ban`.

    Returns:
        `(res, length)`. `length` is the neighbour count of `a` (at least 1).
        `res` has one 5-tuple per anchor, all divided by `length`: overlaps
        pos/pos, neg/neg, anchor-neg/a-pos, anchor-pos/a-neg, and the remainder.
    """
    global edges
    if left:
        pos_map, neg_map = edges.edge_abp, edges.edge_abn
    else:
        pos_map, neg_map = edges.edge_bap, edges.edge_ban

    pos = set(pos_map[a])
    neg = set(neg_map[a])
    length = max(1, len(pos) + len(neg))

    res = []
    for anchor in anchors:
        anchor_pos = set(pos_map[anchor])
        anchor_neg = set(neg_map[anchor])
        both_pos = len(anchor_pos & pos)
        both_neg = len(anchor_neg & neg)
        neg_on_pos = len(anchor_neg & pos)
        pos_on_neg = len(anchor_pos & neg)
        rest = length - both_pos - both_neg - neg_on_pos - pos_on_neg
        res.append((
            both_pos/length,
            both_neg/length,
            neg_on_pos/length,
            pos_on_neg/length,
            rest/length,
        ))
    return res, length

class FinalLayer(nn.Module):
    """Edge classifier over learnable node tables seeded from given features.

    Args:
        n: number of side-A nodes.
        m: number of side-B nodes.
        emb_size: width of the node tables; the given features are padded
            with random columns up to this width.
        feat_a: initial side-A features with `n` rows.
        feat_b: initial side-B features with `m` rows.
    """

    def __init__(self, n, m, emb_size, feat_a, feat_b):
        super().__init__()
        base_a = feat_a.detach()
        base_b = feat_b.detach()
        # Random columns fill the gap between the given width and emb_size.
        pad_a = torch.randn((n, emb_size-base_a.shape[1]), device=feat_a.device)
        pad_b = torch.randn((m, emb_size-base_b.shape[1]), device=feat_b.device)
        self.features_a = nn.Parameter(torch.cat((base_a, pad_a), dim=-1), requires_grad=True)
        self.features_b = nn.Parameter(torch.cat((base_b, pad_b), dim=-1), requires_grad=True)
        # Head whose output is stored as `self.x` in forward().
        self.fc = nn.Sequential(
            GELU(),
            nn.Linear(64, 64, bias=True),
        )
        # Classification head; it is fed a detached edge feature in forward().
        self.fcy = nn.Sequential(
            GELU(),
            nn.Linear(64, 64, bias=True),
            GELU(),
            nn.Linear(64, 1, bias=True),
        )
        self.fcc = nn.Linear(emb_size, 64, bias=True)

    def forward(self, left, right, epoch=0, **kwargs):
        """Score each (left, right) pair and stash the `fc` output in `self.x`.

        Args:
            left: side-A indices.
            right: side-B indices.
            epoch: accepted for the caller's convenience; not read here.

        Returns:
            Sigmoid-activated scores with the last two axes squeezed.
        """
        pair = self.features_a[left] * self.features_b[right]
        edge_feat = self.fcc(pair)
        self.x = self.fc(edge_feat)
        # The classifier sees a detached feature, so its loss does not reach
        # `fcc` or the node tables.
        score = self.fcy(edge_feat.detach())
        return torch.sigmoid(score.squeeze(-1).squeeze(-1))

    def forward_old(self, left, right, edges, emb, anchors, accs, **kwargs):
        """Earlier anchor-based scoring path.

        NOTE(review): this reads `self.fcl` and `self.fcr`, which `__init__`
        does not create, and moves tensors with `.cuda()` unconditionally --
        it looks unused; confirm before calling.
        """
        left = left.squeeze(-1)
        right = right.squeeze(-1)
        ans = emb[left, right]
        info_list_l, info_list_r = [], []
        emb_list_l, emb_list_r = [], []
        deg_list_l, deg_list_r = [], []
        for l, r in zip(left, right):
            info, length = get_info(l, anchors[0], left=1)
            info_list_l.append(info)
            deg_list_l.append(length)
            info, length = get_info(r, anchors[1], left=0)
            info_list_r.append(info)
            deg_list_r.append(length)
            emb_list_l.append(emb[anchors[0], r])
            emb_list_r.append(emb[l, anchors[1]])
        info_list_l = torch.tensor(info_list_l).cuda()
        info_list_r = torch.tensor(info_list_r).cuda()
        emb_list_l = torch.stack(emb_list_l)
        emb_list_r = torch.stack(emb_list_r)
        deg_list_l = torch.tensor(deg_list_l).cuda()
        deg_list_r = torch.tensor(deg_list_r).cuda()
        # Append the anchor accuracies and the scaled degree to each info row.
        info_list_l = torch.cat((
            info_list_l,
            accs.expand(len(left),-1).unsqueeze(-1).float(),
            deg_list_l.unsqueeze(-1).expand(-1,info_list_l.shape[1]).unsqueeze(-1).float()/100,
        ), dim=-1)
        info_list_r = torch.cat((
            info_list_r,
            accs.expand(len(right),-1).unsqueeze(-1).float(),
            deg_list_r.unsqueeze(-1).expand(-1,info_list_r.shape[1]).unsqueeze(-1).float()/100,
        ), dim=-1)
        coeff_l = self.fcl(info_list_l).squeeze(-1)
        coeff_r = self.fcr(info_list_r).squeeze(-1)
        emb_list_l = emb_list_l * torch.exp(coeff_l)
        emb_list_r = emb_list_r * torch.exp(coeff_r)
        ans = torch.sum(emb_list_l, dim=1)/deg_list_l + torch.sum(emb_list_r, dim=1)/deg_list_r
        return torch.sigmoid(ans)

    def loss(self, pred_y, y):
        """Class-balanced binary cross-entropy; the accuracy slot is always 0.

        Args:
            pred_y: predicted probabilities.
            y: targets in [0, 1], same shape as `pred_y`.

        Returns:
            `(loss, 0)`.
        """
        assert y.min() >= 0, 'must 0~1'
        assert pred_y.size() == y.size(), 'must be same length'
        pos_ratio = y.sum() /  y.size()[0]
        pos_weight = 1./pos_ratio
        neg_weight = 1./(1-pos_ratio)
        weight = torch.where(y > 0.5, pos_weight, neg_weight)
        acc = 0
        return F.binary_cross_entropy(pred_y, y, weight=weight), acc
        
def _read_edgelist(path):
    """Parse one tab-separated `a<TAB>b<TAB>sign` file into an (N, 3) int array."""
    rows = []
    with open(path) as f:
        next(f, None)  # the first line is skipped, not parsed as an edge
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            a, b, s = map(int, line.split('\t'))
            rows.append((a, b, s))
    # reshape keeps the array two-dimensional even when no edge was read.
    return np.array(rows, dtype=int).reshape(-1, 3)


def load_data(dataset_name):
    """Load the training, validation and testing edge lists of a dataset.

    Files are read from the module-level `data_prefix` directory as
    `<dataset_name>_{training,validation,testing}.txt`. If the validation or
    testing file is missing, the short-named `<dataset_name>_val.txt` /
    `<dataset_name>_test.txt` is renamed on disk to the long name first.

    Args:
        dataset_name: file-name stem of the dataset.

    Returns:
        Tuple `(train, val, test)` of integer arrays with one
        `(a, b, sign)` row per edge.

    Raises:
        FileNotFoundError: if a required file (or its short-named fallback)
            does not exist.
        ValueError: if a non-blank line is not three tab-separated integers.
    """
    splits = ('training', 'validation', 'testing')
    paths = {
        name: os.path.join(data_prefix, f'{dataset_name}_{name}.txt')
        for name in splits
    }
    for name, short in (('validation', 'val'), ('testing', 'test')):
        if not os.path.exists(paths[name]):
            os.rename(os.path.join(data_prefix, f'{dataset_name}_{short}.txt'), paths[name])

    return tuple(_read_edgelist(paths[name]) for name in splits)

def power(tensor,exponent):
    """Sign-preserving power: raise |tensor| to `exponent` and restore the sign."""
    magnitude = tensor.abs().pow(exponent)
    return tensor.sign() * magnitude

@torch.no_grad()
def test_and_val(left, right, edge_list, model, mode='val'):
    """Evaluate `model` on an edge list and return classification metrics.

    Args:
        left: first positional input for `model`.
        right: second positional input for `model`.
        edge_list: array whose third column holds the labels; -1 is mapped to 0.
        model: module called as `model(left, right)`; put into eval mode here.
        mode: prefix for the metric keys (e.g. 'val', 'test').

    Returns:
        Dict with keys `{mode}_auc`, `{mode}_f1`, `{mode}_mac`, `{mode}_mic`.
    """
    model.eval()
    preds = model(left, right).cpu().numpy()
    y = np.copy(edge_list[:, 2])
    y[y==-1] = 0
    # Binarise at 0.5 before computing every metric.
    # NOTE(review): AUC is therefore computed on hard labels rather than on
    # scores -- presumably to match a baseline protocol; confirm.
    preds[preds >= 0.5]  = 1
    preds[preds < 0.5] = 0
    return {
        f'{mode}_auc': roc_auc_score(y, preds),
        f'{mode}_f1' : f1_score(y, preds),
        f'{mode}_mac' : f1_score(y, preds, average='macro'),
        f'{mode}_mic' : f1_score(y, preds, average='micro'),
    }

def format_res(data, p=1):
    """Render a metrics dict as `key: value, ...` with values as percentages.

    Args:
        data: mapping from metric name to a fraction.
        p: number of decimals kept after scaling by 100.

    Returns:
        The dict's string form with the outer braces and all single quotes
        removed.
    """
    percentages = {}
    for name, fraction in data.items():
        percentages[name] = f"{fraction*100:.{p}f}"
    text = str(percentages)
    return text.strip('{}').replace("'", '')

def split(edgelist, u, v, n, m):
    """Split an edge list into the sub-graph on (u, v) and on their complements.

    Edges with one endpoint inside and one outside the selection are dropped.
    Node ids in each part are re-indexed to their rank within the sorted
    node set of that part.

    Args:
        edgelist: integer array with rows `(a, b, sign)`; not modified.
        u: selected side-A node ids.
        v: selected side-B node ids.
        n: number of side-A nodes; the complement of `u` is taken in `range(n)`.
        m: number of side-B nodes; the complement of `v` is taken in `range(m)`.

    Returns:
        `(edgelist1, edgelist2)`: edges with both endpoints selected, and
        edges with neither endpoint selected, both re-indexed.
    """
    in_u = np.isin(edgelist[:, 0], u)
    in_v = np.isin(edgelist[:, 1], v)
    # Boolean indexing copies, so the in-place re-indexing below never touches
    # the caller's array.
    edgelist1 = edgelist[in_u & in_v]
    edgelist2 = edgelist[~in_u & ~in_v]

    # Every id written below is a member of the sorted array being searched,
    # so searchsorted returns exactly its rank -- one vectorised call instead
    # of a full scan per row.
    edgelist1[:, 0] = np.searchsorted(np.sort(u), edgelist1[:, 0])
    edgelist1[:, 1] = np.searchsorted(np.sort(v), edgelist1[:, 1])

    # setdiff1d already returns sorted unique values.
    # Assumes every id is below n (resp. m); ids outside that range would not
    # be in the complement and would get a meaningless rank.
    u_complement = np.setdiff1d(np.arange(n), u)
    v_complement = np.setdiff1d(np.arange(m), v)
    edgelist2[:, 0] = np.searchsorted(u_complement, edgelist2[:, 0])
    edgelist2[:, 1] = np.searchsorted(v_complement, edgelist2[:, 1])
    return edgelist1, edgelist2

def run():
    """Train and periodically evaluate a `FinalLayer` on `args.dataset_name`.

    Reads `./res/<dataset>_emb.pth` (keys "all_emb", "feature_a",
    "feature_b") and the three edge-list files, builds the module-level
    `edges` graph, then trains for `args.epoch` epochs on the GPU. Every
    `args.val_epoch` epochs the validation and test metrics are printed and
    appended to `./txt/<dataset_name>_<comment>.txt`.

    `dataset_name` and `comment` in that path are module globals assigned in
    the `__main__` block; both must be defined before this function runs.
    All tensors are moved with `.cuda()`, regardless of `args.device`.
    """
    global edges
    data = torch.load(f'./res/{args.dataset_name}_emb.pth', map_location='cpu', weights_only=True)
    # data = data0["all_emb"]
    # data_all = data.cuda().permute(1, 0, 2)
    # data = data0
    
    # data = torch.load(f'./res/{args.dataset_name}_emb.pth', map_location='cpu', weights_only=True)
    # data_all = data["all_emb"].cuda().permute(1, 0, 2)
    
    
    # for _ in range(9):#[1,3,4,5,7]:
    #     data_all[:,_,:] = data_all[:,0,:]
    # Insert a singleton middle axis so the target matches `model.x`, which
    # gets the same axis from the unsqueezed index tensors below.
    # Presumably "all_emb" has one row per training edge, in file order --
    # TODO confirm; it is sliced with the same start/end as the edge list.
    sp = data["all_emb"].shape
    data_all = data["all_emb"].reshape(sp[0], 1, sp[1]).cuda()
    train_edgelist, val_edgelist, test_edgelist = load_data(args.dataset_name)

    set_a_num, set_b_num = DATA_EMB_DIC[args.dataset_name]
    print(set_a_num, set_b_num)
    edges = Edge(train_edgelist, set_a_num, set_b_num)

    model = FinalLayer(n=set_a_num, m=set_b_num, emb_size=args.emb_size, feat_a=data["feature_a"].detach(), feat_b=data["feature_b"].detach())
    model = model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # The learning rate is multiplied by 0.2 every 50 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.2)
    
    epoch = 0
    res_best = dict()
    best_auc = 0.0
    all_loss = 0.0
    # Index tensors get a trailing singleton axis.
    left = torch.tensor(train_edgelist[:, 0]).cuda().unsqueeze(-1)
    right = torch.tensor(train_edgelist[:, 1]).cuda().unsqueeze(-1)
    left_v = torch.tensor(val_edgelist[:, 0]).cuda().unsqueeze(-1)
    right_v = torch.tensor(val_edgelist[:, 1]).cuda().unsqueeze(-1)
    left_t = torch.tensor(test_edgelist[:, 0]).cuda().unsqueeze(-1)
    right_t = torch.tensor(test_edgelist[:, 1]).cuda().unsqueeze(-1)
    # Labels: -1 is mapped to 0.
    train_y = torch.tensor(train_edgelist[:, 2]).cuda().float()
    train_y[train_y==-1] = 0
    # The training set is processed in this many contiguous chunks per epoch.
    slices = 5
    while epoch < args.epoch:
        model.train()
        optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            for i in range(slices):
                # Integer division: when the edge count is not a multiple of
                # `slices`, the trailing remainder is never trained on.
                start, end = train_edgelist.shape[0]//slices*i, train_edgelist.shape[0]//slices*(i+1)
                pred_y = model(left[start:end], right[start:end], epoch=epoch)
                loss = 0
                # Two-phase schedule: MSE against the loaded embeddings for
                # epochs < 150, classification loss for epochs > 120 (both
                # are active in between). At least one branch always fires,
                # so `loss` is a tensor by the time `.item()` is called.
                if epoch > 120:
                    loss += model.loss(pred_y, train_y[start:end])[0]
                if epoch < 150:
                    loss += torch.nn.MSELoss()(model.x,data_all[start:end])
                    # for j in range(2):
                    #     loss += torch.nn.CosineSimilarity(dim=-1)(model.x[...,j*32:(j+1)*32], data_all[start:end,:,j*32:(j+1)*32]).mean()*2
                all_loss += loss.item()

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
        
        epoch += 1
        scheduler.step()
        if epoch%args.val_epoch == 0:
            model.eval()
            with torch.set_grad_enabled(False):
                val_res = test_and_val(left_v, right_v, val_edgelist, model, mode='val')
                test_res = test_and_val(left_t, right_t, test_edgelist, model, mode='test')
                # `or 1` makes this branch unconditional, so "best" always
                # holds the latest validation result.
                if val_res['val_auc']>best_auc or 1:
                    best_auc = val_res['val_auc']
                    res_best.update(val_res)
                    res_best.update(test_res)
                    print(f'\r\033[Kepoch: {epoch}\t{format_res(test_res)}')
                    with open(f'./txt/{dataset_name}_{comment}.txt','a') as fh:
                        fh.write(f'{epoch}\t{format_res(res_best,p=4)}\tloss: {all_loss:.4f}\n')
                print(f'\r\033[Kepoch: {epoch}\tval_auc: {val_res["val_auc"]*100:.1f}\tbest_auc: {best_auc*100:.1f}\tloss: {all_loss:.4f}\ttest: {test_res["test_auc"]*100:.1f}', end='')
            # The reported loss is the sum accumulated since the last validation.
            all_loss = 0.0

        if epoch >= args.epoch:
            break
    print()

if __name__ == "__main__":
    # Script entry point: either a short --eval pass over splits 1-2, or five
    # seeded training runs over splits 1-5 with the source and hyper-parameters
    # archived under ./code and ./txt.
    dataset_name = args.dataset_name
    # Resolve the run tag before anything calls run(): run() reads the
    # module-level `comment` when it appends to ./txt/, so the tag has to
    # exist on the --eval path as well.
    if args.comment:
        comment = args.comment
    else:
        comment = str(int(time.time())%100000000)
    # run() appends to ./txt/ on both paths; make sure the directory exists.
    os.makedirs('./txt', exist_ok=True)
    if args.eval:
        for i in range(0,2):
            args.dataset_name = dataset_name+'-'+str(i+1)
            run()
        exit()
    print(comment)
    # Archive a copy of this script next to the results for reproducibility.
    os.makedirs('./code', exist_ok=True)
    with open(f'./code/{dataset_name}_{comment}.py','a') as fh:
        with open(__file__,'r') as f:
            fh.write(f.read())
            fh.write('\n\n\n\n\n\n')
    for i in range(0,5):
        hyper_params = dict(vars(args))
        del hyper_params['device']
        hyper_params = "~".join([f"{k}-{v}" for k,v in hyper_params.items()])
        with open(f'./txt/{dataset_name}_{comment}.txt','a') as fh:
            # The hyper-parameter line is written once, before the first run.
            if i==0:
                fh.write(hyper_params)
            fh.write(f'\n{i}\n')
        print(f'training: {i}')
        
        # Every split is trained from the same seed.
        setup_seed(args.seed)
        args.dataset_name = dataset_name+'-'+str(i+1)
        run()