#-*- coding: utf-8 -*-
import os
import time
import random
import argparse
import numpy as np

# Command-line configuration. It is parsed at module level so that the CUDA
# device selection below is already in the environment when torch is imported
# further down the file.
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cpu', help='Device')
parser.add_argument('--dataset_name', type=str, default='house1to10')
parser.add_argument('--emb_size', type=int, default=32, help='Embeding Size')
parser.add_argument('--weight_decay', type=float, default=1e-5, help='Weight Decay')
parser.add_argument('--lr', type=float, default=0.002, help='Learning Rate')
parser.add_argument('--seed', type=int, default=222, help='Random seed')
# NOTE: `--epoch` is compared against a counter that advances once per
# training batch in run(), so it is a number of optimisation steps.
parser.add_argument('--epoch', type=int, default=80000, help='Epoch')
parser.add_argument('--batch_size', type=int, default=64, help='Batch Size')
parser.add_argument('--neighbours', type=int, default=200, help='Max Neighbours')
parser.add_argument('--val_epoch', type=int, default=50, help='Epoch Interval For Validation')
parser.add_argument('--zero_shot', action='store_true', help='Zero Shot Learning')
parser.add_argument('--ckpt', type=str, help='Checkpoint path')
parser.add_argument('--eval', action='store_true', help='Evaluation')
parser.add_argument('--comment', type=str, default='')
args = parser.parse_args()

# Directory holding the `<dataset>_training/validation/testing.txt` edge lists.
data_prefix = 'experiments-data'

if 'cuda' in args.device:
    # Use the complete run of trailing digits as the device index
    # ('cuda:0' -> '0', 'cuda:12' -> '12'). Taking only the last character
    # truncated multi-digit indices and turned a bare 'cuda' into the invalid
    # value 'a'; when no index is given the variable is now left untouched.
    _device_index = args.device[len(args.device.rstrip('0123456789')):]
    if _device_index:
        os.environ['CUDA_VISIBLE_DEVICES'] = _device_index

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import Edge, sub_edge, GELU
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import f1_score, roc_auc_score

# from torch.utils.tensorboard import SummaryWriter

def setup_seed(seed):
    """Seed the torch, NumPy and Python RNGs with `seed` and request deterministic cuDNN."""
    seeders = (
        torch.manual_seed,
        np.random.seed,
        random.seed,
        torch.cuda.manual_seed_all,
    )
    for seed_rng in seeders:
        seed_rng(seed)
    torch.backends.cudnn.deterministic = True

from common import DATA_EMB_DIC

def find_neighbours(edges, node_in, node_out, neighbours, left=1):
    """Collect up to `neighbours` nodes adjacent to the nodes in `node_in`.

    Args:
        edges: graph object. For each node, `edge_abp[node] + edge_abn[node]`
            (when `left` is truthy) or `edge_bap[node] + edge_ban[node]`
            (otherwise) lists its adjacent nodes; `degb` / `dega` are arrays
            indexed by those adjacent node ids (presumably node degrees --
            confirm against `utils.Edge`).
        node_in: seed nodes. They are visited in random order; the caller's
            list is NOT modified.
        node_out: nodes that must not appear in the result.
        neighbours: maximum number of nodes to return. Budgets of 5 or less
            return an empty list.
        left: selects which side of the bipartite structure the seeds are on.

    Returns:
        A list of at most `neighbours` distinct node ids (order unspecified).
    """
    if neighbours <= 5:
        return []
    excluded = set(node_out)
    # Shuffle a private copy: shuffling `node_in` itself silently reordered
    # lists the caller keeps using. The random stream consumed is unchanged.
    visit_order = list(node_in)
    random.shuffle(visit_order)
    found = set()
    for node in visit_order:
        if left:
            adjacent = edges.edge_abp[node] + edges.edge_abn[node]
        else:
            adjacent = edges.edge_bap[node] + edges.edge_ban[node]
        if not len(adjacent):
            continue
        candidates = np.array(adjacent)
        ranking = edges.degb[candidates] if left else edges.dega[candidates]
        # Keep the `neighbours` candidates with the largest ranking value.
        top = np.argsort(ranking)[::-1][:neighbours]
        found |= set(candidates[top].tolist()) - excluded
        if len(found) >= neighbours:
            break
    return list(found)[:neighbours]

class GraphDataset(Dataset):
    """Dataset yielding one sampled neighbourhood per edge of `val`.

    `val` is wrapped in an `Edge` object and provides the items; neighbour
    sampling is done on `edge`, or on the module-level `edges` graph when
    `edge` is None.
    """

    def __init__(self, val, test=False, edge=None):
        self.val = Edge(val)
        print(len(self.val.edge_set))
        self.test = test
        # The global `edges` is only looked up when no graph was passed in.
        self.edges = edges if edge is None else edge

    def __len__(self):
        """Number of edges to iterate over."""
        return len(self.val.edge_list)

    def calc_neighbours(self, A, D, p=0.005):
        """Search every split (a, b, c, d) with a <= A, d <= D for the best score.

        The total budget is `2 * args.neighbours`. Returns `(score, (a, b, c, d))`
        for the first split with the strictly largest score, or
        `(0, (A, s, s, D))` with `s = (budget - A - D) // 2` if no split
        scores above zero. Not called by `__getitem__`.
        """
        budget = args.neighbours*2
        spare = budget - A - D
        best_score = 0
        best_split = (A, spare//2, spare//2, D)
        for a in range(A + 1):
            for d in range(D + 1):
                half = (budget - a - d)//2
                if half < A - a or half < D - d:
                    # Not enough room for an even split: give each middle
                    # layer the leftover of the opposite side and hand the
                    # spare budget to the smaller one.
                    b, c = D - d, A - a
                    if b < c:
                        b += spare
                    else:
                        c += spare
                else:
                    b = half
                    c = budget - a - b - d
                score = (a*d*p+a*(D-d)*p+(A-a)*d*p)+a*b*c*d*p*p*p
                if score > best_score:
                    best_score, best_split = score, (a, b, c, d)
        return best_score, best_split

    def __getitem__(self, index):
        """Build the six node layers and eight sub-adjacency blocks for one edge.

        Layers: `l0`/`l5` are the edge endpoints, `l1`/`l4` their sampled
        neighbours, `l2`/`l3` neighbours of those. `edge_s` is 1 when the
        stored sign equals 1 and 0 otherwise.
        """
        record = self.val.edge_list[index]
        left, right = [record[0]], [record[1]]
        graph = self.edges
        budget = args.neighbours

        # The order of these four calls fixes how the random stream is consumed.
        left_n = find_neighbours(graph, left, right, budget, left=1)
        right_n = find_neighbours(graph, right, left, budget, left=0)
        remain = (2*budget - len(left_n) - len(right_n))//2
        l2 = find_neighbours(graph, left_n, right_n + left, remain, left=0)
        l3 = find_neighbours(graph, right_n, left_n + right, remain, left=1)

        sample = {
            'l0': left,
            'l5': right,
            'l1': left_n,
            'l4': right_n,
            'l2': l2,
            'l3': l3,
        }
        sample['sub_01'] = sub_edge(left, left_n, graph)
        sample['sub_41'] = sub_edge(right_n, left_n, graph)
        sample['sub_45'] = sub_edge(right_n, right, graph)
        sample['sub_21'] = sub_edge(l2, left_n, graph)
        sample['sub_23'] = sub_edge(l2, l3, graph)
        sample['sub_03'] = sub_edge(left, l3, graph)
        sample['sub_25'] = sub_edge(l2, right, graph)
        sample['sub_43'] = sub_edge(right_n, l3, graph)
        sample['edge_s'] = int(record[2] == 1)
        return sample

def pad_tensor(batch):
    """Stack same-rank tensors into one zero-padded tensor.

    The result has shape `[len(batch)] + per-dimension maximum size` and the
    dtype of the first tensor; each input is copied into the leading corner
    of its slot.
    """
    first = batch[0]
    target_shape = [max(t.size(axis) for t in batch) for axis in range(first.dim())]
    padded = torch.zeros([len(batch)] + target_shape, dtype=first.dtype)
    for slot, tensor in enumerate(batch):
        region = tuple(slice(0, extent) for extent in tensor.size())
        padded[slot][region] = tensor
    return padded

def collate_fn(batch):
    """Merge a list of sample dicts into one dict of batched tensors.

    List values are converted to tensors first. For each key, tensors of
    identical shape are stacked, tensors of differing shape are zero-padded
    via `pad_tensor`, and non-tensor values are packed with `torch.tensor`.
    The keys (and their order) come from the first sample.
    """
    columns = {key: [] for key in batch[0]}
    for sample in batch:
        for key, value in sample.items():
            columns[key].append(torch.tensor(value) if isinstance(value, list) else value)

    collated = {}
    for key, values in columns.items():
        head = values[0]
        if not isinstance(head, torch.Tensor):
            collated[key] = torch.tensor(values)
        elif all(item.shape == head.shape for item in values):
            collated[key] = torch.stack(values)
        else:
            collated[key] = pad_tensor(values)
    return collated

class Attention(nn.Module):
    """Multi-head, edge-masked aggregation of `prev` node features onto `curr` nodes.

    As used in `forward`: `prev` is (batch, P, input_dim), `curr` is
    (batch, C, input_dim) and `edges` is a (batch, P, C) mask whose zero
    entries mark absent edges. The output is (batch, C, input_dim).
    `input_dim` is sliced into `head` chunks of `input_dim // head` features.
    """

    def __init__(self, input_dim, head=4):
        # Internal widths: fused pair feature, hidden pair feature, projection.
        fuse_dim, hidden_dim, proj_dim = 4, 8, 12
        super(Attention, self).__init__()
        # Attribute names and creation order are kept so existing state dicts
        # and seeded initialisation stay compatible.
        self.bt_pre = nn.Linear(input_dim, proj_dim)
        self.bt_cur = nn.Linear(input_dim, proj_dim)
        self.fcc = nn.Sequential(
            GELU(),
            nn.Linear(hidden_dim, head),
        )
        self.fcg = nn.Sequential(
            GELU(),
            nn.Linear(hidden_dim, input_dim),
        )
        self.fcm = nn.Sequential(
            GELU(),
            nn.Linear(hidden_dim, input_dim),
        )
        self.head = head
        self.dim = input_dim
        self.fuse = nn.Sequential(
            GELU(),
            nn.Linear(proj_dim*2, fuse_dim),
        )
        self.ffn = nn.Sequential(
            GELU(),
            nn.Linear(fuse_dim*3, hidden_dim),
        )

    def forward(self, prev, curr, edges):
        proj_prev = self.bt_pre(prev)
        proj_curr = self.bt_cur(curr)
        # Broadcast both projections to one feature vector per (prev, curr) pair.
        pair_shape = (proj_prev.shape[0], proj_prev.shape[1], proj_curr.shape[1], proj_prev.shape[2])
        pair = torch.cat(
            (proj_prev.unsqueeze(2).expand(pair_shape), proj_curr.unsqueeze(1).expand(pair_shape)),
            dim=-1,
        )
        pair = self.fuse(pair)
        pair[edges==0] = 0
        # Per-curr-node edge count, +1 so the division below never hits zero.
        # (The unused per-prev-node count `edges2` was removed.)
        in_count = torch.sum(edges, dim=1)+1
        in_count[in_count==0] = 1
        mean_pool = torch.sum(pair, dim=1).unsqueeze(1).expand(pair.shape) / in_count.unsqueeze(1).unsqueeze(-1)
        max_pool = torch.max(pair, dim=1)[0].unsqueeze(1).expand(pair.shape)
        fused = self.ffn(torch.cat((pair, mean_pool, max_pool), dim=-1))

        scores = self.fcc(fused).transpose(-2,-3)
        mask = edges.transpose(-1,-2)
        head_dim = self.dim//self.head
        per_head = []
        for i in range(self.head):
            weights = scores[:,:,:,i].clone()
            if weights.shape[2] != 1:
                # Absent edges get a huge negative score before the softmax.
                weights[mask==0] = -9e15
                weights = F.softmax(weights, dim=2).clone()
            else:
                # With a single prev node the raw score is used, no softmax.
                weights[mask==0] = 0
            per_head.append(torch.bmm(weights, prev[:,:,head_dim*i:head_dim*(i+1)]))
        res = torch.cat(per_head, dim=-1)

        # Add mean- and max-pooled summaries of the masked pair features.
        fused = fused.clone()
        fused[edges==0] = 0
        res = res + self.fcg(torch.sum(fused, dim=1) / in_count.unsqueeze(-1)) + self.fcm(torch.max(fused, dim=1)[0]/2)
        return res

class SignFlowLayer(nn.Module):
    """Update one node layer from the previous layer through signed edges.

    Edges equal to 1 and edges equal to -1 are aggregated separately by the
    shared `Attention` module; the negative aggregate is subtracted from the
    positive one.

    NOTE(review): the layer norm and the residual feed-forward are sized with
    `input_dim`, so the layer appears to require `output_dim == input_dim`
    (the default) -- confirm before passing a different `output_dim`.
    """

    def __init__(self, input_dim, output_dim=None):
        super(SignFlowLayer, self).__init__()
        output_dim = input_dim if output_dim is None else output_dim
        half_dim = input_dim//2
        self.weight_curr = nn.Parameter(torch.Tensor(input_dim, output_dim))
        nn.init.normal_(self.weight_curr, mean=0, std=0.1)
        self.fcp = nn.Linear(input_dim, output_dim)
        self.fcn = nn.Linear(input_dim, output_dim)
        self.ffn = nn.Sequential(
            GELU(),
            nn.Linear(input_dim, half_dim),
            GELU(),
            nn.Linear(half_dim, output_dim),
        )
        self.attn = Attention(input_dim)
        # Attribute name (sic) is kept: it is part of the state-dict keys.
        self.lnrom = nn.LayerNorm(input_dim)

    def forward(self, prev_layer_features, current_layer_features, edges):
        # An empty previous or current layer leaves the features untouched.
        if 0 in (edges.shape[1], edges.shape[2]):
            return current_layer_features

        from_positive = self.attn(prev_layer_features, current_layer_features, edges==1)
        from_negative = self.attn(prev_layer_features, current_layer_features, edges==-1)
        signed_message = self.fcp(from_positive) - self.fcn(from_negative)
        updated = torch.matmul(current_layer_features, self.weight_curr) + signed_message
        updated = self.lnrom(updated)
        # Residual feed-forward on the normalised features.
        return self.ffn(updated) + updated
    
class SubGraphLayer(nn.Module):
    """Propagate features through the six node layers l0 -> l5 of a sample.

    `embed_k` are the features of layer k; `sub_xy` is the signed adjacency
    block passed in for layers x and y (transposed below where the opposite
    direction is needed). Returns the resulting features of layer 5.
    """

    def __init__(self, input_dim):
        super(SubGraphLayer, self).__init__()
        self.attns = SignFlowLayer(input_dim)
        self.attnt = SignFlowLayer(input_dim)
        self.attnm1 = SignFlowLayer(input_dim)
        self.attnm2 = SignFlowLayer(input_dim)
        self.attnm3 = SignFlowLayer(input_dim)
        self.attnsc = SignFlowLayer(input_dim)
        # Not used in forward(); kept so existing state dicts still load.
        self.fc = nn.Linear(input_dim, input_dim)

    def forward(self, embed_0, embed_1, embed_2, embed_3, embed_4, embed_5, sub_01, sub_41, sub_45, sub_21, sub_23, sub_03, sub_25, sub_43):
        sub_14 = sub_41.transpose(1,2)
        sub_12 = sub_21.transpose(1,2)
        sub_34 = sub_43.transpose(1,2)
        # Shortcut connections into layer 5 reuse the transposed l0 blocks with
        # every non-zero entry set to +1 (cloned so the inputs are not edited).
        # The previously computed-but-unused `sub_250` copy has been dropped.
        sub_15 = sub_01.transpose(1,2).clone()
        sub_35 = sub_03.transpose(1,2).clone()
        sub_15[sub_15!=0] = 1
        sub_35[sub_35!=0] = 1

        x0 = embed_0
        x1 = self.attns(x0, embed_1, sub_01)
        x2 = self.attnm1(x1, embed_2, sub_12)
        x3 = self.attnm2(x2, embed_3, sub_23) + self.attns(x0, embed_3, sub_03)
        x4 = self.attnm3(x3, embed_4, sub_34) + self.attnm1(x1, embed_4, sub_14)
        x5 = (
            self.attnt(x4, embed_5, sub_45)
            + self.attnt(x2, embed_5, sub_25)
            + self.attnsc(x1, embed_5, sub_15)
            + self.attnsc(x3, embed_5, sub_35)
        )
        return x5

class SBSN(nn.Module):
    """Edge-sign predictor built on `SubGraphLayer`.

    `forward` returns the sigmoid score for each sampled edge; `loss` is a
    class-balanced binary cross entropy.
    """

    def __init__(self, n, m, emb_size=32):
        super(SBSN, self).__init__()

        # Per-node parameter tables are only created outside zero-shot mode.
        # NOTE(review): nothing in this class reads them -- get_embeddings()
        # starts from zeros -- confirm this is intended.
        if not args.zero_shot:
            self.features_a = nn.Parameter(torch.randn((n, emb_size)), requires_grad=True)
            self.features_b = nn.Parameter(torch.randn((m, emb_size)), requires_grad=True)
        self.sub = SubGraphLayer(emb_size)
        self.fcx = nn.ModuleList([nn.Linear(emb_size, emb_size) for _ in range(3)])
        self.fcy = nn.ModuleList([nn.Linear(emb_size, emb_size) for _ in range(3)])
        self.B = nn.Linear(emb_size, emb_size)
        self.CX = nn.Sequential(
            GELU(),
            nn.Dropout(0.2),
            nn.Linear(emb_size*5, emb_size)
        )
        self.C = nn.Sequential(
            GELU(),
            nn.Dropout(0.2),
            nn.Linear(emb_size, emb_size//4),
            GELU(),
            nn.Dropout(0.2),
            nn.Linear(emb_size//4, 1),
        )
        self.emb_size = emb_size

    def get_embeddings(self, a, b, end=0):
        """Return initial features for the index tensors `a` and `b`.

        Only the shapes and devices of `a`/`b` are used: an all-zero tensor of
        shape `(*shape, emb_size)` is passed through `fcx[end]` / `fcy[end]`,
        so every node of a layer starts from that linear layer's bias.
        """
        emb_a = torch.zeros((*a.shape, self.emb_size), device=a.device)
        emb_b = torch.zeros((*b.shape, self.emb_size), device=b.device)
        emb_a = self.fcx[end](emb_a)
        emb_b = self.fcy[end](emb_b)
        return emb_a, emb_b

    def forward(self, l0, l1, l2, l3, l4, l5, sub_01, sub_41, sub_45, sub_21, sub_23, sub_03, sub_25, sub_43, **kwargs):
        """Score each sample; extra batch keys (e.g. `edge_s`) are ignored."""
        embed_0, embed_1 = self.get_embeddings(l0, l1, end=0)
        embed_2, embed_3 = self.get_embeddings(l2, l3, end=1)
        embed_4, embed_5 = self.get_embeddings(l4, l5, end=2)
        # Forward pass l0 -> l5, then the mirrored pass l5 -> l0 with the layer
        # order reversed and every adjacency block transposed.
        x = self.sub(embed_0, embed_1, embed_2, embed_3, embed_4, embed_5, sub_01, sub_41, sub_45, sub_21, sub_23, sub_03, sub_25, sub_43)
        y = self.sub(embed_5, embed_4, embed_3, embed_2, embed_1, embed_0, sub_45.transpose(1,2), sub_41.transpose(1,2), sub_01.transpose(1,2), sub_43.transpose(1,2), sub_23.transpose(1,2), sub_25.transpose(1,2), sub_03.transpose(1,2), sub_21.transpose(1,2))
        fuse = torch.cat((self.B(x)*y,x,y,embed_0,embed_5),dim=-1)
        fuse = self.CX(fuse)
        # L2-normalise, then scale by 4 before the classifier head.
        fuse = F.normalize(fuse, p=2, dim=-1)*4
        fuse = self.C(fuse).squeeze(-1).squeeze(-1)
        return torch.sigmoid(fuse)

    def loss(self, pred_y, y):
        """Return `(weighted BCE, accuracy)` for probabilities `pred_y` and targets `y`.

        Each sample is weighted by the inverse frequency of its class within
        the batch. Accuracy thresholds both tensors at 0.5.
        """
        assert y.min() >= 0, 'must 0~1'
        assert pred_y.size() == y.size(), 'must be same length'
        pos_ratio = y.sum() /  y.size()[0]
        weight = torch.where(y > 0.5, 1./pos_ratio, 1./(1-pos_ratio))
        # Tensor reduction over the batch dimension instead of Python's
        # builtin sum(), which iterated the tensor element by element.
        acc = ((y>0.5) == (pred_y>0.5)).sum(dim=0)/y.shape[0]
        return F.binary_cross_entropy(pred_y, y, weight=weight), acc

def _read_edgelist(path):
    """Parse a tab-separated `a<TAB>b<TAB>sign` file into a list of int triples.

    The first line is treated as a header and skipped; blank lines (such as a
    trailing newline at the end of the file) are ignored.
    """
    triples = []
    with open(path) as f:
        for ind, line in enumerate(f):
            if ind == 0 or not line.strip():
                continue
            a, b, s = map(int, line.split('\t'))
            triples.append((a, b, s))
    return triples


def load_data(dataset_name):
    """Load the training, validation and testing edge lists of `dataset_name`.

    Files are read from `data_prefix` as `<dataset_name>_training.txt`,
    `<dataset_name>_validation.txt` and `<dataset_name>_testing.txt`.

    Side effect: if the validation/testing file is missing, the legacy
    `<dataset_name>_val.txt` / `<dataset_name>_test.txt` is renamed on disk to
    the expected name (raising if that file does not exist either).

    Returns:
        Three numpy arrays (train, val, test), one `(a, b, sign)` row per edge.
    """
    train_file_path = os.path.join(data_prefix, f'{dataset_name}_training.txt')
    val_file_path = os.path.join(data_prefix, f'{dataset_name}_validation.txt')
    test_file_path = os.path.join(data_prefix, f'{dataset_name}_testing.txt')
    if not os.path.exists(val_file_path):
        os.rename(os.path.join(data_prefix, f'{dataset_name}_val.txt'),val_file_path)
    if not os.path.exists(test_file_path):
        os.rename(os.path.join(data_prefix, f'{dataset_name}_test.txt'),test_file_path)

    return tuple(
        np.array(_read_edgelist(path))
        for path in (train_file_path, val_file_path, test_file_path)
    )

def power(tensor,exponent):
    """Sign-preserving power: raise |tensor| to `exponent`, then restore each sign."""
    magnitude = torch.pow(torch.abs(tensor), exponent)
    return torch.sign(tensor) * magnitude

@torch.no_grad()
def test_and_val(dataloader, model, mode='val', epoch=0):
    """Evaluate `model` on every batch of `dataloader` and return its metrics.

    Args:
        dataloader: yields dicts of tensors; each must contain `edge_s`.
        model: called as `model(**batch)` after every tensor is moved to CUDA.
        mode: prefix for the metric names (e.g. 'val' or 'test').
        epoch: currently unused; accepted for call-site compatibility.

    Returns:
        Dict with `{mode}_auc`, `{mode}_f1`, `{mode}_mac` and `{mode}_mic`.

    NOTE(review): predictions are thresholded at 0.5 before all metrics,
    including `roc_auc_score`, so the AUC is computed on hard labels rather
    than on probabilities -- confirm this is the intended protocol.
    """
    model.eval()
    preds_ys, ys = [], []
    for data in dataloader:
        d = {key: value.cuda() for key, value in data.items()}
        preds_ys.append(model(**d))
        ys.append(d['edge_s'].float())
    preds = torch.cat(preds_ys, dim=-1).cpu().numpy()
    y = torch.cat(ys, dim=-1).cpu().numpy()
    preds[preds >= 0.5]  = 1
    preds[preds < 0.5] = 0
    # The former trailing loop over the results only re-bound `mode` and had
    # no effect, so it has been removed.
    return {
        f'{mode}_auc': roc_auc_score(y, preds),
        f'{mode}_f1' : f1_score(y, preds),
        f'{mode}_mac' : f1_score(y, preds, average='macro'),
        f'{mode}_mic' : f1_score(y, preds, average='micro'),
    }

def format_res(data, p=1):
    """Render a metric dict as `key: value, ...` with values as percentages to `p` decimals."""
    percentages = {}
    for key, value in data.items():
        percentages[key] = f"{value*100:.{p}f}"
    # Reuse the dict's own text form, minus the braces and the quote marks.
    rendered = str(percentages)
    return rendered.strip('{}').replace("'",'')

def _relabel_ids(ids, sorted_universe):
    """Map each id to its (first) position in the ascending `sorted_universe`.

    Raises IndexError when an id is not present, as the element-wise lookup
    this replaces did.
    """
    positions = np.searchsorted(sorted_universe, ids)
    in_range = positions < sorted_universe.size
    if not in_range.all() or np.any(sorted_universe[positions] != ids):
        raise IndexError('edge list contains a node id outside the expected node set')
    return positions


def split(edgelist, u, v, n, m):
    """Split `edgelist` into the sub-graph on (u, v) and the one on their complements.

    Args:
        edgelist: integer array with one `(a, b, sign)` row per edge.
        u, v: node ids selected on side a / side b.
        n, m: number of nodes on side a / side b; the complements are taken
            within `range(n)` and `range(m)`.

    Returns:
        `(edgelist1, edgelist2)`: copies of the rows with both endpoints inside
        (u, v), respectively both endpoints outside, with node ids relabelled
        to their rank within the sorted node set. Rows with exactly one
        endpoint inside are dropped.

    Raises:
        IndexError: if an outside endpoint is not in `range(n)` / `range(m)`.
    """
    in_u = np.isin(edgelist[:, 0], u)
    in_v = np.isin(edgelist[:, 1], v)
    # Boolean indexing copies, so the relabelling never touches `edgelist`.
    edgelist1 = edgelist[in_u & in_v]
    edgelist2 = edgelist[~in_u & ~in_v]

    # One binary search per column instead of a full scan per edge.
    edgelist1[:, 0] = _relabel_ids(edgelist1[:, 0], np.sort(u))
    edgelist1[:, 1] = _relabel_ids(edgelist1[:, 1], np.sort(v))

    # np.setdiff1d already returns sorted, unique ids.
    u_complement = np.setdiff1d(np.arange(n), u)
    v_complement = np.setdiff1d(np.arange(m), v)
    edgelist2[:, 0] = _relabel_ids(edgelist2[:, 0], u_complement)
    edgelist2[:, 1] = _relabel_ids(edgelist2[:, 1], v_complement)
    return edgelist1, edgelist2

def run(j):
    """Train (or, with --eval, only evaluate) the model on `args.dataset_name`.

    Args:
        j: training-fraction index; only the first `(j+1)/10` of the training
            edge list is used.

    Side effects: rebinds the module-level `edges` graph, moves the model and
    every batch to CUDA unconditionally, prints progress, and (in training
    mode) appends results to `./txt/{dataset_name}_{comment}.txt`, where
    `dataset_name` and `comment` are module globals assigned in the
    `__main__` block.
    """
    global edges
    train_edgelist, val_edgelist, test_edgelist = load_data(args.dataset_name)
    # Keep only the leading (j+1)/10 share of the training edges.
    train_edgelist = train_edgelist[:int(len(train_edgelist)*(j+1)/10)]

    set_a_num, set_b_num = DATA_EMB_DIC[args.dataset_name]
    print(set_a_num, set_b_num)
    # The graph used for neighbour sampling is built from training edges only.
    edges = Edge(train_edgelist, set_a_num, set_b_num)

    model = SBSN(n=set_a_num, m=set_b_num, emb_size=args.emb_size)
    if args.eval:
        # Evaluation-only path: load the checkpoint, score the test split, return.
        setup_seed(args.seed)
        checkpoint = torch.load(args.ckpt, weights_only=True)
        if args.zero_shot:
            # Drop the per-node 'feature' tensors, which the zero-shot model does not define.
            filtered_checkpoint = {k: v for k, v in checkpoint.items() if 'feature' not in k}
            model.load_state_dict(filtered_checkpoint, strict=False)
        else:
            model.load_state_dict(checkpoint)
        model = model.cuda()
        dataset_test = GraphDataset(test_edgelist, test=1, edge=edges)
        dataloader_test = DataLoader(dataset_test, batch_size=args.batch_size, collate_fn=collate_fn, num_workers=16, pin_memory=True)
        res = test_and_val(dataloader_test, model, mode='test', epoch=args.epoch)
        print(format_res(res))
        return
    model = model.cuda()

    # Two parameter groups ('features*' vs. everything else); both currently
    # use the same learning rate.
    params = [[], []]
    for name, param in model.named_parameters():
        if name.startswith('features'):
            params[0].append(param)
        else:
            params[1].append(param)
    param_groups = [
        {'params': params[0], 'lr': args.lr*1},
        {'params': params[1], 'lr': args.lr}
    ]
    optimizer = torch.optim.Adam(param_groups, weight_decay=args.weight_decay)
    # `scheduler` is stepped only at the two milestones near the end of the loop.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.3)
    # NOTE(review): `scheduler_slow` is created but never stepped in this function.
    scheduler_slow = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.998)
    dataset = GraphDataset(train_edgelist, edge=edges)
    dataset_val = GraphDataset(val_edgelist, test=1, edge=edges)
    dataset_test = GraphDataset(test_edgelist, test=1, edge=edges)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate_fn, num_workers=16, shuffle=True, pin_memory=True)
    dataloader_val = DataLoader(dataset_val, batch_size=args.batch_size, collate_fn=collate_fn, num_workers=16, pin_memory=True)
    dataloader_test = DataLoader(dataset_test, batch_size=args.batch_size, collate_fn=collate_fn, num_workers=16, pin_memory=True)

    # `epoch` advances once per batch, i.e. it counts optimisation steps.
    epoch = 0
    res_best = dict()
    best_auc = 0.0
    # Loss accumulated since the last validation; reset after each one.
    all_loss = 0.0
    tm = time.time()
    while epoch < args.epoch:
        for data in dataloader:
            data = {key: value.cuda() for key, value in data.items()}
            # Validation switches the model to eval mode, so re-enable training every step.
            model.train()
            optimizer.zero_grad()
            with torch.set_grad_enabled(True):
                pred_y = model(**data)
                train_y = data['edge_s'].float()
                # loss() returns (loss, accuracy); only the loss is used here.
                loss = model.loss(pred_y, train_y)[0]
                all_loss += loss.item()
                # tb_writer.add_scalar('loss0/train', float(loss), epoch)

                loss.backward()
                optimizer.step()

            epoch += 1
            # if epoch%100 == 0:
            #     print(epoch, (time.time()-tm)/3600)
            if epoch%args.val_epoch == 0:
                model.eval()
                with torch.set_grad_enabled(False):
                    val_res = test_and_val(dataloader_val, model, mode='val', epoch=epoch)
                    # The test split is scored only when validation AUC improves.
                    if val_res['val_auc']>best_auc:
                        best_auc = val_res['val_auc']
                        res_best.update(val_res)
                        # torch.save(model.state_dict(), f'./res/{args.dataset_name}.pth')

                        test_res = test_and_val(dataloader_test, model, mode='test', epoch=epoch)
                        res_best.update(test_res)
                        print(f'\r\033[Kepoch: {epoch}\t{format_res(test_res)}')
                        # `dataset_name` and `comment` are module globals set in the __main__ block.
                        with open(f'./txt/{dataset_name}_{comment}.txt','a') as fh:
                            fh.write(f'{epoch}\t{format_res(res_best,p=4)}\tloss: {all_loss:.4f}\n')
                    print(f'\r\033[Kepoch: {epoch}\tval_auc: {val_res["val_auc"]*100:.1f}\tbest_auc: {best_auc*100:.1f}\tloss: {all_loss:.4f}', end='')
                all_loss = 0.0

            # Multiply the learning rate by 0.3 at steps 20000 and 40000.
            if epoch in [20000, 40000]:
                scheduler.step()

            # Stop mid-pass once the step budget is exhausted.
            if epoch >= args.epoch:
                break
    print()

if __name__ == "__main__":
    # Base dataset name; each run appends '-<k>' to select one data split.
    # `dataset_name` and `comment` stay module globals because run() reads them.
    dataset_name = args.dataset_name
    if args.eval:
        for i in range(0,2):
            args.dataset_name = dataset_name+'-'+str(i+1)
            # run() requires the training-fraction index; calling it without
            # one raised TypeError. Index 9 keeps the whole training edge list.
            # NOTE(review): confirm the fraction the checkpoint was trained with.
            run(9)
        exit()
    if args.comment:
        comment = args.comment
    else:
        # Default run tag derived from the current time.
        comment = str(int(time.time())%100000000)
    print(comment)
    # The output directories are not created anywhere else in this script.
    os.makedirs('./code', exist_ok=True)
    os.makedirs('./txt', exist_ok=True)
    # Archive a copy of this script next to the results of the run.
    with open(f'./code/{dataset_name}_{comment}.py','a') as fh:
        with open(__file__,'r') as f:
            fh.write(f.read())
            fh.write('\n\n\n\n\n\n')
    # Outer loop: data split i+1; inner loop: training fraction (j+1)/10.
    for i in range(0,5):
        for j in range(10):
            hyper_params = dict(vars(args))
            del hyper_params['device']
            hyper_params = "~".join([f"{k}-{v}" for k,v in hyper_params.items()])
            # tb_writer = SummaryWriter(log_dir=f'./logs/{hyper_params}')
            with open(f'./txt/{dataset_name}_{comment}.txt','a') as fh:
                # The hyper-parameter header is written once, before the first run.
                if i==0 and j==0:
                    fh.write(hyper_params)
                fh.write(f'\n{i}\n')
            print(f'training: {i}')

            # Re-seed before every run so each run starts from the same RNG state.
            setup_seed(args.seed)
            args.dataset_name = dataset_name+'-'+str(i+1)
            run(j)