#-*- coding: utf-8 -*-
import os
import time
import math
import random
import argparse
import numpy as np

# Command-line configuration. Parsed at import time because module-level code
# below (CUDA device selection) and several functions/classes in this file read
# the global `args`.
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cpu', help='Device')
parser.add_argument('--dataset_name', type=str, default='house1to10')
parser.add_argument('--emb_size', type=int, default=32, help='Embedding Size')
parser.add_argument('--weight_decay', type=float, default=1e-5, help='Weight Decay')
parser.add_argument('--lr', type=float, default=0.002, help='Learning Rate')
parser.add_argument('--seed', type=int, default=222, help='Random seed')
parser.add_argument('--epoch', type=int, default=80000, help='Epoch')
parser.add_argument('--batch_size', type=int, default=64, help='Batch Size')
parser.add_argument('--neighbours', type=int, default=400, help='Max Neighbours')
parser.add_argument('--val_epoch', type=int, default=50, help='Epoch Interval For Validation')
parser.add_argument('--zero_shot', action='store_true', help='Zero Shot Learning')
# NOTE(review): --ckpt is parsed but not read in the visible code; the checkpoint
# path used by run() is hard-coded.
parser.add_argument('--ckpt', type=str, help='Checkpoint path')
parser.add_argument('--eval', action='store_true', help='Evaluation')
parser.add_argument('--comment', type=str, default='')
args = parser.parse_args()

# Directory holding the `<dataset>_{training,validation,testing}.txt` edge lists.
data_prefix = 'experiments-data'

# Pin the requested GPU via CUDA_VISIBLE_DEVICES; torch is imported only after
# this, so the setting is in place before CUDA is initialised. The index is
# everything after "cuda" (an optional ':' is allowed): "cuda:1", "cuda1" and
# "cuda:10" give "1", "1" and "10". Plain "cuda" leaves the variable untouched.
# Previously only the last character was used, so "cuda" produced 'a' (hiding
# every GPU) and "cuda:10" produced '0'.
if 'cuda' in args.device:
    _device_index = args.device.split('cuda', 1)[1].lstrip(':')
    if _device_index.isdigit():
        os.environ['CUDA_VISIBLE_DEVICES'] = _device_index

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import Edge, sub_edge, GELU
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import f1_score, roc_auc_score

# from torch.utils.tensorboard import SummaryWriter

def setup_seed(seed):
    """Seed every random number generator used by the script.

    Seeds torch (CPU, then all CUDA devices), NumPy's global RNG and Python's
    ``random`` with the same value, and asks cuDNN to choose deterministic
    kernels.
    """
    seeders = (
        torch.manual_seed,
        np.random.seed,
        random.seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True

from common import DATA_EMB_DIC

def find_neighbours(edges, node_in, node_out, left=1, signal=0, max_neighbours=None):
    """Sample up to ``max_neighbours`` distinct neighbours of the nodes in ``node_in``.

    Nodes of ``node_in`` are visited in random order. For each node, its
    neighbours from the two adjacency tables are ranked according to ``signal``,
    and the top ``per_node = max(max_neighbours // len(node_in), 4)`` are kept.
    Ids in ``node_out`` are then removed. Sampling stops once ``max_neighbours``
    neighbours have been gathered.

    Args:
        edges: graph object exposing adjacency tables ``edge_abp``/``edge_abn``
            (side-a node -> list of side-b ids), ``edge_bap``/``edge_ban``
            (side-b node -> list of side-a ids) and degree arrays
            ``dega``/``degb`` indexable by node id. (The p/n suffixes presumably
            mean positive/negative edges — confirm against utils.Edge.)
        node_in: nodes whose neighbours are sampled. The caller's sequence is
            not modified.
        node_out: ids to exclude from the result.
        left: truthy if ``node_in`` are side-a nodes (neighbours come from
            ``edge_ab*`` and are ranked by ``degb``); falsy for side-b nodes.
        signal: per-node ranking. 0 = random, 1 = lowest degree first,
            anything else = highest degree first.
        max_neighbours: overall cap; defaults to ``args.neighbours``.

    Returns:
        List of at most ``max_neighbours`` distinct neighbour ids. Empty if
        ``node_in`` is empty.
    """
    if max_neighbours is None:
        max_neighbours = args.neighbours
    # Work on a copy so the shuffle below does not reorder the caller's list.
    node_in = list(node_in)
    if not node_in:
        return []
    per_node = max(max_neighbours // len(node_in), 4)
    random.shuffle(node_in)
    excluded = set(node_out)
    neighbour = set()
    for node in node_in:
        if left:
            candidates = edges.edge_abp[node] + edges.edge_abn[node]
        else:
            candidates = edges.edge_bap[node] + edges.edge_ban[node]
        if not len(candidates):
            # A node with no neighbours contributes nothing. Keep what the other
            # nodes already produced (this used to `return []` outright).
            continue
        samp = np.array(candidates)
        deg = edges.degb[samp] if left else edges.dega[samp]
        if signal == 0:
            # Shuffling the degree values makes the "top" pick a random subset.
            deg = np.random.permutation(deg)
        elif signal == 1:
            deg = -deg
        top = np.argsort(deg)[::-1][:per_node]
        neighbour.update(set(samp[top].tolist()) - excluded)
        if len(neighbour) >= max_neighbours:
            break
    return list(neighbour)[:max_neighbours]

class GraphDataset(Dataset):
    """One sample per edge: its two endpoints, their sampled neighbourhoods and
    the connectivity tensors between those node groups.

    Args:
        val: edge list wrapped in ``Edge``. Each ``Edge(val).edge_list`` entry is
            indexed as ``(a, b, sign)``.
        signal: integer whose ``// 3`` and ``% 3`` parts select the neighbour
            ranking strategy (see ``find_neighbours``) for the left and right
            endpoint respectively.
        test: stored on the instance; not read by this class.
        edge: graph used for neighbour lookup. Falls back to the module-level
            ``edges`` global when ``None``.
    """

    def __init__(self, val, signal, test=False, edge=None):
        self.val = Edge(val)
        # Debug output: size of Edge(val).edge_set.
        print(len(self.val.edge_set))
        self.test = test
        self.signal = signal
        self.edges = edges if edge is None else edge

    def __len__(self):
        """Number of edges in ``val``."""
        return len(self.val.edge_list)

    def __getitem__(self, index):
        """Build the sample for edge ``index``.

        Returns a dict with:
            left / right: single-element lists holding the edge's endpoints.
            left_n / right_n: sampled neighbour ids of each endpoint.
            sub_0 / sub_1 / sub_2: ``sub_edge`` results for the node-group pairs
                (left, left_n), (right_n, left_n) and (right_n, right).
            edge_s: 1 if the edge's sign equals 1, else 0.
            left_ns / right_ns: neighbour counts (before batching pads them).
        """
        edge = self.val.edge_list[index]
        left, right = [edge[0]], [edge[1]]
        # signal packs two strategies: //3 for the left node, %3 for the right.
        left_n = find_neighbours(self.edges, left, right, left=1, signal=self.signal//3)
        right_n = find_neighbours(self.edges, right, left, left=0, signal=self.signal%3)
        return {
            'left': left,
            'right': right,
            'left_n': left_n,
            'right_n': right_n,
            'sub_0': sub_edge(left, left_n, self.edges),
            'sub_1': sub_edge(right_n, left_n, self.edges),
            'sub_2': sub_edge(right_n, right, self.edges),
            'edge_s': int(edge[2] == 1),
            'left_ns': len(left_n),
            'right_ns': len(right_n),
        }

def pad_tensor(batch):
    """Stack same-rank tensors, zero-padding every dimension to the batch maximum.

    The result is a new tensor of shape ``(len(batch), *max_sizes)`` with the
    dtype of ``batch[0]``, created with ``torch.zeros`` on the default device.
    Each input fills the leading corner of its slot.
    """
    rank = batch[0].dim()
    max_size = [max(t.size(axis) for t in batch) for axis in range(rank)]
    padded = torch.zeros([len(batch)] + max_size, dtype=batch[0].dtype)
    for slot, t in enumerate(batch):
        region = (slot,) + tuple(slice(0, extent) for extent in t.size())
        padded[region] = t
    return padded

def collate_fn(batch):
    """Merge a list of sample dicts into one dict of batched tensors.

    List values are converted with ``torch.tensor``. For each key:
    - tensors of identical shape are stacked;
    - tensors of differing shape are zero-padded via ``pad_tensor``;
    - anything else (e.g. Python ints) becomes ``torch.tensor(values)``.
    Keys follow the order of the first sample.
    """
    columns = {key: [] for key in batch[0]}
    for sample in batch:
        for key, value in sample.items():
            if isinstance(value, list):
                value = torch.tensor(value)
            columns[key].append(value)
    for key, values in columns.items():
        head = values[0]
        if not isinstance(head, torch.Tensor):
            columns[key] = torch.tensor(values)
        elif all(v.shape == head.shape for v in values):
            columns[key] = torch.stack(values)
        else:
            columns[key] = pad_tensor(values)
    return columns

class Attention(nn.Module):
    """Multi-head, edge-masked attention from "previous" nodes onto "current" nodes.

    For every (prev, curr) pair, a small MLP scores the pair from projections of
    both endpoints, enriched with per-curr mean/max statistics over prev. The
    scores become ``head`` attention maps; head ``i`` aggregates the
    ``i``-th ``input_dim // head`` slice of ``prev``. A mean/max summary of the
    pairwise features is added as a residual.

    Attribute names are state_dict keys and must not change.
    """
    def __init__(self, input_dim, head=4):
        # Hard-coded hidden widths of the scoring MLP:
        #   c: width of each endpoint projection (bt_pre / bt_cur)
        #   a: width after fusing the two projections (fuse)
        #   b: width of the pairwise feature after mean/max enrichment (ffn)
        a=4
        b=8
        c=12
        super(Attention, self).__init__()
        self.bt_pre = nn.Linear(input_dim, c)
        self.bt_cur = nn.Linear(input_dim, c)
        self.fcc = nn.Sequential(
            GELU(),
            nn.Linear(b, head),
        )
        self.fcg = nn.Sequential(
            GELU(),
            nn.Linear(b, input_dim),
        )
        self.fcm = nn.Sequential(
            GELU(),
            nn.Linear(b, input_dim),
        )
        self.head = head
        self.dim = input_dim
        self.fuse = nn.Sequential(
            GELU(),
            nn.Linear(c*2, a),
        )
        self.ffn = nn.Sequential(
            GELU(),
            nn.Linear(a*3, b),
        )

    def forward(self, prev, curr, edges):
        """Aggregate ``prev`` features onto ``curr`` nodes along ``edges``.

        Args:
            prev: (B, P, input_dim) source-node features.
            curr: (B, C, input_dim) target-node features.
            edges: (B, P, C) mask; zero entries mean "no edge". Callers in this
                file pass boolean masks such as ``edges == 1``.

        Returns:
            (B, C, input_dim) aggregated features. NOTE(review): this assumes
            input_dim is divisible by ``head``; otherwise the per-head slices
            do not cover the full width and the residual add fails.
        """
        bt_pre = self.bt_pre(prev)
        bt_cur = self.bt_cur(curr)
        # Broadcast both projections to a (B, P, C, c) pair grid.
        shape = (bt_pre.shape[0],bt_pre.shape[1],bt_cur.shape[1],bt_pre.shape[2])
        bt_pre = bt_pre.unsqueeze(2).expand(shape)
        bt_cur = bt_cur.unsqueeze(1).expand(shape)
        c = torch.cat((bt_pre, bt_cur), dim=-1)
        c = self.fuse(c)
        c[edges==0] = 0  # drop pair features where there is no edge
        # Per-curr (edges1) and per-prev (edges2) neighbour counts, plus one.
        # For a 0/1 or boolean mask the +1 already keeps the counts >= 1, so the
        # two guards below have no effect. edges2 is not used afterwards.
        edges1, edges2 = torch.sum(edges, dim=1)+1, torch.sum(edges, dim=2)+1
        edges1[edges1==0] = 1
        edges2[edges2==0] = 1
        # d: per-curr sum over prev divided by (count + 1), broadcast back over P.
        # e: per-curr max over prev (zeroed non-edges included), broadcast back.
        d = torch.sum(c, dim=1).unsqueeze(1).expand(c.shape) / edges1.unsqueeze(1).unsqueeze(-1)
        e = torch.max(c, dim=1)[0].unsqueeze(1).expand(c.shape)
        fused = self.ffn(torch.cat((c,d,e),dim=-1))
        # One score per head, reordered to (B, C, P, head) for a bmm against prev.
        c = self.fcc(fused).transpose(-2,-3)
        mask = edges.transpose(-1,-2)  # (B, C, P)
        res_list = []
        for i in range(self.head):
            c_true = c[:,:,:,i].clone()
            if c_true.shape[2] != 1:
                # Softmax over prev nodes; -9e15 gives masked entries ~0 weight.
                # NOTE(review): a curr node with no edges at all gets uniform
                # weights over every prev node here.
                c_true[mask==0] = -9e15
                c_true = F.softmax(c_true, dim=2).clone()
            else:
                # Single prev node: the raw (un-normalised) score is used as the
                # weight, zeroed where there is no edge.
                c_true[mask==0] = 0
            # Head i aggregates its own slice of prev's feature dimension.
            res = torch.bmm(c_true, prev[:,:,self.dim//self.head*i:self.dim//self.head*(i+1)])
            res_list.append(res)
        res = torch.cat(res_list, dim=-1)
        # Residual from the pairwise features: (count+1)-normalised sum and
        # halved max over prev, with non-edges zeroed first.
        fused = fused.clone()
        fused[edges==0] = 0
        res = res + self.fcg(torch.sum(fused, dim=1) / edges1.unsqueeze(-1)) + self.fcm(torch.max(fused, dim=1)[0]/2)
        return res

class SignFlowLayer(nn.Module):
    """One signed message-passing step from ``prev`` nodes into ``current`` nodes.

    Messages along positive edges (``edges == 1``) and negative edges
    (``edges == -1``) are aggregated by a shared :class:`Attention`, projected
    separately and combined as ``fcp(pos) - fcn(neg)``. The result is added to a
    linear map of the current features, layer-normalised, and passed through a
    residual feed-forward block.

    Args:
        input_dim: feature width of the ``prev`` and ``current`` inputs.
        output_dim: width of the returned features; defaults to ``input_dim``.
    """

    def __init__(self, input_dim, output_dim=None):
        super(SignFlowLayer, self).__init__()
        if output_dim is None:
            output_dim = input_dim
        self.weight_curr = nn.Parameter(torch.Tensor(input_dim, output_dim))
        nn.init.normal_(self.weight_curr, mean=0, std=0.1)
        self.fcp = nn.Linear(input_dim, output_dim)
        self.fcn = nn.Linear(input_dim, output_dim)
        # The FFN and LayerNorm act on the combined features, which are
        # output_dim wide. They were sized with input_dim, which crashed for any
        # output_dim != input_dim. This is unchanged when the two are equal.
        self.ffn = nn.Sequential(
            GELU(),
            nn.Linear(output_dim, output_dim//2),
            GELU(),
            nn.Linear(output_dim//2, output_dim),
        )
        self.attn = Attention(input_dim)
        # The name `lnrom` is a state_dict key. Checkpoints are loaded with
        # strict=False, so renaming it would silently skip its weights.
        self.lnrom = nn.LayerNorm(output_dim)

    def forward(self, prev_layer_features, current_layer_features, edges):
        """Update ``current_layer_features`` with signed messages from ``prev_layer_features``.

        ``edges`` holds +1 / -1 entries (anything else is treated as no edge),
        indexed (batch, prev node, current node) as :class:`Attention` expects.
        """
        positive_features = self.attn(prev_layer_features, current_layer_features, edges==1)
        negative_features = self.attn(prev_layer_features, current_layer_features, edges==-1)
        transformed_agg_features = self.fcp(positive_features) - self.fcn(negative_features)
        current_layer_features = torch.matmul(current_layer_features, self.weight_curr) + transformed_agg_features
        current_layer_features = self.lnrom(current_layer_features)

        return self.ffn(current_layer_features) + current_layer_features
    
class SubGraphLayer(nn.Module):
    """Signed propagation along the node-group path used by SBSN.

    Three chained :class:`SignFlowLayer` hops carry the source features
    ``emb_a`` through ``emb_bn`` and ``emb_an`` onto ``emb_b``. A shortcut hop
    (``attnq``) sends the first-hop result straight to ``emb_b``.

    Args:
        input_dim: feature width.
        num_positions: rows of the fixed positional table ``pe``; defaults to
            ``args.neighbours``.
    """

    def __init__(self, input_dim, num_positions=None):
        super(SubGraphLayer, self).__init__()
        self.attn1 = SignFlowLayer(input_dim)
        self.attn2 = SignFlowLayer(input_dim)
        self.attn3 = SignFlowLayer(input_dim)
        self.attnq = SignFlowLayer(input_dim)
        self.fc = nn.Linear(input_dim, input_dim)
        if num_positions is None:
            num_positions = args.neighbours
        # Non-trainable Parameter, so it is part of the state_dict and checkpoint
        # keys stay unchanged. forward() does not currently use it.
        pe = self.positional_encoding(num_positions, input_dim) * 0.01
        self.pe = nn.Parameter(pe, requires_grad=False)

    def positional_encoding(self, num_nodes, embedding_dim):
        """Return a Transformer-style sinusoidal table of shape ``(num_nodes, embedding_dim)``.

        Even columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``.
        An odd ``embedding_dim`` is supported; it previously raised a
        shape-mismatch error.
        """
        pe = torch.zeros(num_nodes, embedding_dim)
        position = torch.arange(0, num_nodes).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * -(math.log(10000.0) / embedding_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd width there is one fewer cos column than frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[:embedding_dim // 2])
        return pe

    def forward(self, emb_a, emb_an, emb_bn, emb_b, sub_0, sub_1, sub_2):
        """Propagate ``emb_a`` -> ``emb_bn`` -> ``emb_an`` -> ``emb_b``.

        Each mask is consumed by a SignFlowLayer, which expects
        (batch, prev nodes, curr nodes). ``sub_1`` is transposed before use.
        The shortcut mask is ``sub_0`` transposed with every non-zero entry set
        to +1. Shapes are not checked here — assumes ``sub_edge`` produces them
        in that orientation.

        Returns the features of the ``emb_b`` nodes after propagation.
        """
        sub_1 = sub_1.transpose(1,2)
        # Cloned so the caller's sub_0 is not modified by the binarisation.
        sub_22 = sub_0.transpose(1,2).clone()
        sub_22[sub_22!=0] = 1
        # (A per-call torch.randperm over `pe` and an unused binarised copy of
        # sub_2 were removed: neither result was read, and the randperm
        # consumed the global RNG on every forward.)
        x1 = self.attn1(emb_a, emb_bn, sub_0)
        x2 = self.attn2(x1, emb_an, sub_1)
        x3 = self.attn3(x2, emb_b, sub_2) + self.fc(self.attnq(x1, emb_b, sub_22))
        return x3

class SBSN(nn.Module):
    """Edge-sign model over a bipartite graph with sides a (``n`` nodes) and b (``m`` nodes).

    The endpoints and their sampled neighbourhoods are embedded from learnable
    per-node tables. They are propagated through one shared
    :class:`SubGraphLayer` in both directions (a->b and b->a), then fused by an
    MLP into one logit per edge. ``forward`` returns ``sigmoid(logit)``, trained
    against targets in [0, 1] (``edge_s`` in this file).

    Args:
        n: number of side-a nodes (rows of ``features_a``).
        m: number of side-b nodes (rows of ``features_b``).
        emb_size: embedding width.
    """

    # Node-ID embeddings are multiplied by 0 before WARMUP_START epochs, then
    # ramped linearly to full strength over WARMUP_STEPS epochs.
    WARMUP_START = 40000
    WARMUP_STEPS = 10000

    def __init__(self, n, m, emb_size=32):
        super(SBSN, self).__init__()

        self.features_a = nn.Parameter(torch.randn((n, emb_size)), requires_grad=True)
        self.features_b = nn.Parameter(torch.randn((m, emb_size)), requires_grad=True)
        self.sub = SubGraphLayer(emb_size)
        # Index 1 projects the edge's own endpoints, index 0 the neighbourhoods.
        self.fcx = nn.ModuleList([nn.Linear(emb_size, emb_size),nn.Linear(emb_size, emb_size)])
        self.fcy = nn.ModuleList([nn.Linear(emb_size, emb_size),nn.Linear(emb_size, emb_size)])
        self.B = nn.Linear(emb_size, emb_size)
        self.CX = nn.Sequential(
            GELU(),
            nn.Dropout(0.2),
            nn.Linear(emb_size*5, emb_size)
        )
        self.C = nn.Sequential(
            GELU(),
            nn.Dropout(0.2),
            nn.Linear(emb_size, emb_size//4),
            GELU(),
            nn.Dropout(0.2),
            nn.Linear(emb_size//4, 1),
        )
        self.emb_size = emb_size

    def get_embeddings(self, a, b, epoch, end=0):
        """Look up, warm-up-scale and project embeddings for ids ``a`` (side a) and ``b`` (side b).

        Args:
            a, b: index tensors (cast to long) into ``features_a`` / ``features_b``.
            epoch: training epoch; sets the warm-up scale (see ``WARMUP_*``).
            end: 1 for the edge's own endpoints (no gradient flows into the
                feature tables through this lookup; projected by
                ``fcx[1]``/``fcy[1]``), 0 for neighbourhoods.

        Returns:
            ``(emb_a, emb_b)``.
        """
        emb_a = self.features_a[a.long()]
        emb_b = self.features_b[b.long()]
        if end:
            emb_a, emb_b = emb_a.detach(), emb_b.detach()
        if epoch < self.WARMUP_START:
            scale = 0.0
        else:
            scale = min(1.0, (epoch - self.WARMUP_START) / self.WARMUP_STEPS)
        emb_a = self.fcx[end](emb_a * scale)
        emb_b = self.fcy[end](emb_b * scale)
        return emb_a, emb_b

    def _encode(self, left, left_n, right_n, right, sub_0, sub_1, sub_2, epoch):
        """Shared body of ``forward``/``forward_f``: returns ``(logits, x, y)``."""
        embed_a, embed_b = self.get_embeddings(left, right, epoch, end=1)
        # right_n are side-a ids and left_n side-b ids, hence the argument order.
        embed_an, embed_bn = self.get_embeddings(right_n, left_n, epoch)
        x = self.sub(embed_a, embed_an, embed_bn, embed_b, sub_0, sub_1, sub_2)
        # Same propagation in the b -> a direction with transposed masks.
        y = self.sub(embed_b, embed_bn, embed_an, embed_a, sub_2.transpose(1,2), sub_1.transpose(1,2), sub_0.transpose(1,2))
        fuse = torch.cat((self.B(x)*y,x,y,embed_a,embed_b),dim=-1)
        fuse = self.CX(fuse)
        logits = self.C(fuse).squeeze(-1).squeeze(-1)
        return logits, x, y

    def forward(self, left, left_n, right_n, right, sub_0, sub_1, sub_2, epoch, **kwargs):
        """Return the per-edge probability ``sigmoid(logit)``; extra batch keys are ignored."""
        logits, _, _ = self._encode(left, left_n, right_n, right, sub_0, sub_1, sub_2, epoch)
        return torch.sigmoid(logits)

    def loss(self, pred_y, y):
        """Class-balanced binary cross-entropy plus accuracy at a 0.5 threshold.

        Each sample is weighted by the inverse frequency of its class in the
        batch.

        Returns:
            ``(loss, acc)`` as 0-dim tensors.
        """
        assert y.min() >= 0, 'must 0~1'
        assert pred_y.size() == y.size(), 'must be same length'
        pos_ratio = y.sum() /  y.size()[0]
        weight = torch.where(y > 0.5, 1./pos_ratio, 1./(1-pos_ratio))
        # Vectorised; the previous Python-level sum() iterated element by element.
        acc = ((y > 0.5) == (pred_y > 0.5)).float().mean()
        return F.binary_cross_entropy(pred_y, y, weight=weight), acc

    def forward_f(self, left, left_n, right_n, right, sub_0, sub_1, sub_2, mode, **kwargs):
        """Return ``(logits, x, y)`` with embeddings at full (post-warm-up) scale.

        ``mode`` is accepted for caller compatibility and not used.
        """
        # 100000 is past the end of warm-up, so the embedding scale is 1.0.
        return self._encode(left, left_n, right_n, right, sub_0, sub_1, sub_2, 100000)

def _read_edgelist(path):
    """Read a tab-separated ``a<TAB>b<TAB>sign`` file, skipping its header line.

    Blank lines (e.g. a trailing empty line) are ignored. Returns an integer
    array of shape ``(num_edges, 3)``, which is ``(0, 3)`` for a header-only file.
    """
    rows = []
    with open(path, encoding='utf-8') as f:
        next(f, None)  # header line
        for line in f:
            if not line.strip():
                continue
            a, b, s = map(int, line.split('\t'))
            rows.append((a, b, s))
    return np.array(rows, dtype=int).reshape(-1, 3)


def load_data(dataset_name, prefix=None):
    """Load the training, validation and testing edge lists of ``dataset_name``.

    Files are ``<prefix>/<dataset_name>_{training,validation,testing}.txt``.
    ``prefix`` defaults to the module-level ``data_prefix``.

    Returns:
        ``(train, val, test)`` integer arrays of shape ``(num_edges, 3)``.
    """
    if prefix is None:
        prefix = data_prefix
    return tuple(
        _read_edgelist(os.path.join(prefix, f'{dataset_name}_{split_name}.txt'))
        for split_name in ('training', 'validation', 'testing')
    )

def power(tensor, exponent):
    """Signed power: raise ``|tensor|`` to ``exponent`` and restore each element's sign."""
    magnitude = tensor.abs().pow(exponent)
    return tensor.sign() * magnitude

@torch.no_grad()
def test_and_val(dataloader, model, mode='val', epoch=0):
    """Evaluate ``model`` on ``dataloader`` and return AUC and F1 metrics.

    AUC is computed from the raw predicted probabilities. Binary, macro and
    micro F1 use predictions thresholded at 0.5. Previously AUC was computed on
    the thresholded 0/1 labels.

    Args:
        dataloader: yields dicts of tensors accepted by ``model(**batch)``; must
            contain ``edge_s`` (the 0/1 target).
        model: module whose forward takes the batch fields plus ``epoch`` and
            returns one probability per sample (``SBSN.forward`` requires
            ``epoch``, which was not passed before).
        mode: prefix for the returned metric keys.
        epoch: forwarded to the model.

    Returns:
        dict with keys ``{mode}_auc``, ``{mode}_f1``, ``{mode}_mac``, ``{mode}_mic``.
    """
    was_training = model.training
    model.eval()
    # Move batches to wherever the model lives instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    preds_ys, ys = [], []
    try:
        for data in dataloader:
            batch = {key: value.to(device) for key, value in data.items()}
            preds_ys.append(model(**batch, epoch=epoch))
            ys.append(batch['edge_s'].float())
    finally:
        # Restore the caller's train/eval mode.
        model.train(was_training)
    scores = torch.cat(preds_ys, dim=-1).cpu().numpy()
    y = torch.cat(ys, dim=-1).cpu().numpy()
    labels = (scores >= 0.5).astype(scores.dtype)
    return {
        f'{mode}_auc': roc_auc_score(y, scores),
        f'{mode}_f1' : f1_score(y, labels),
        f'{mode}_mac' : f1_score(y, labels, average='macro'),
        f'{mode}_mic' : f1_score(y, labels, average='micro'),
    }

def format_res(data, p=1):
    """Render metric fractions as percentages: ``"key: 85.3, key2: 50.0"``.

    Each value is multiplied by 100 and shown with ``p`` decimal places. The
    text comes from the dict's ``str()`` with braces and quotes stripped.
    """
    as_percent = dict((key, f"{value * 100:.{p}f}") for key, value in data.items())
    return str(as_percent).strip('{}').replace("'", '')

def _rank_in(sorted_ids, values):
    """Position of each value in ascending ``sorted_ids``; raise IndexError if any is absent."""
    idx = np.searchsorted(sorted_ids, values)
    found = idx < len(sorted_ids)
    found[found] = sorted_ids[idx[found]] == values[found]
    if not found.all():
        raise IndexError('split: endpoint id not present in the id set it is re-indexed against')
    return idx


def split(edgelist, u, v, n, m):
    """Split ``edgelist`` into the ``u x v`` block and its complement block, re-indexed.

    Args:
        edgelist: integer array of rows ``(a, b, sign, ...)``. It is not modified.
        u: side-a ids of the first block.
        v: side-b ids of the first block.
        n: number of side-a nodes (ids ``0..n-1``).
        m: number of side-b nodes (ids ``0..m-1``).

    Returns:
        ``(inner, outer)``: rows with ``a in u and b in v``, and rows with
        ``a not in u and b not in v``. Endpoints are replaced by their rank
        within ``sorted(u)``/``sorted(v)`` and within the sorted complements
        respectively. Rows crossing the two blocks are dropped.

    Raises:
        IndexError: if an endpoint of the complement block lies outside
            ``range(n)`` / ``range(m)``.
    """
    in_u = np.isin(edgelist[:, 0], u)
    in_v = np.isin(edgelist[:, 1], v)
    # Boolean indexing copies, so the in-place re-indexing below is safe.
    inner = edgelist[in_u & in_v]
    outer = edgelist[~in_u & ~in_v]

    # searchsorted (side='left') finds the first match, which is what the
    # previous per-row np.where(...)[0][0] scan returned, in O(log |u|).
    sorted_u = np.sort(u)
    sorted_v = np.sort(v)
    inner[:, 0] = _rank_in(sorted_u, inner[:, 0])
    inner[:, 1] = _rank_in(sorted_v, inner[:, 1])

    # setdiff1d already returns sorted unique values.
    u_complement = np.setdiff1d(np.arange(n), u)
    v_complement = np.setdiff1d(np.arange(m), v)
    outer[:, 0] = _rank_in(u_complement, outer[:, 0])
    outer[:, 1] = _rank_in(v_complement, outer[:, 1])
    return inner, outer

def run(slice=1):
    """Compute per-edge pair embeddings for the training edges of ``args.dataset_name``.

    Loads the edge lists and builds the module-level ``edges`` graph. Restores
    the model from ``./res/{dataset_name}_150000.pth`` and runs
    ``SBSN.forward_f`` over every training edge, concatenating each edge's
    ``x`` and ``y`` outputs. The stacked result, plus the raw node feature
    tables, is saved to ``./res/{dataset_name}_emb.pth``.

    Args:
        slice: not read by the body. NOTE(review): also shadows the builtin ``slice``.
    """
    # GraphDataset falls back to this global when constructed without `edge=`.
    global edges
    train_edgelist, val_edgelist, test_edgelist = load_data(args.dataset_name)
    # train_edgelist = train_edgelist[:10000]

    set_a_num, set_b_num = DATA_EMB_DIC[args.dataset_name]
    print(set_a_num, set_b_num)
    edges = Edge(train_edgelist, set_a_num, set_b_num)

    model = SBSN(n=set_a_num, m=set_b_num, emb_size=args.emb_size)
    # Re-seed after model construction so the sampling below starts from a fixed RNG state.
    setup_seed(args.seed)
    # NOTE(review): the checkpoint path is hard-coded (the --ckpt argument is not
    # used). strict=False silently ignores missing/unexpected keys; consider
    # checking the keys returned by load_state_dict.
    checkpoint = torch.load(f'./res/{args.dataset_name}_150000.pth', weights_only=True)
    model.load_state_dict(checkpoint, strict=False)
    # NOTE(review): .cuda() is hard-coded regardless of --device.
    model = model.cuda()
    # all_data = torch.zeros((set_a_num, set_b_num), dtype=torch.float)
    # all_data_xs = torch.zeros((set_a_num, set_b_num, args.emb_size), dtype=torch.float)
    # all_data_ys = torch.zeros((set_a_num, set_b_num, args.emb_size), dtype=torch.float)
    # One list of per-edge embeddings for each of the 9 neighbour-sampling signals.
    all_emb = [[] for _ in range(9)]
    # all_dens = torch.zeros((set_a_num, set_b_num), dtype=torch.float)
    model.eval()
    with torch.no_grad():
        # `_` is used as a real variable: the signal index 0..8. GraphDataset
        # splits it into left (//3) and right (%3) ranking strategies.
        for _ in range(9):
            dataset = GraphDataset(train_edgelist, edge=edges, signal=_)
            dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate_fn, num_workers=16, pin_memory=True, shuffle=False)
            tm = time.time()
            cc = 0
            for data in dataloader:
                data = {key: value.cuda() for key, value in data.items()}
                # forward_f accepts `mode` but does not use it.
                emb, xs, ys = model.forward_f(**data, mode=_)
                for l,r,e,x,y in zip(data['left'], data['right'], emb, xs, ys):
                    # print(l,r)
                    # if edges.dega[l] >= 30 and edges.degb[r] >= 30:
                    # all_data[l,r] = e.squeeze().cpu()
                    # all_data_xs[l,r] = x.squeeze().cpu()
                    # all_data_ys[l,r] = y.squeeze().cpu()
                    # Per-edge [x ; y] vector, moved to CPU one edge at a time.
                    all_emb[_].append(torch.cat((x.squeeze().cpu(),y.squeeze().cpu()),dim=-1))
            print(time.time()-tm)
            all_emb[_] = torch.stack(all_emb[_])
            # NOTE(review): this writes to the same path on every signal.
            torch.save({"all_emb":all_emb[_].detach(), "feature_a":model.features_a.cpu().detach(), "feature_b":model.features_b.cpu().detach()}, f'./res/{args.dataset_name}_emb.pth')
            # NOTE(review): exit() terminates the whole process after the first
            # signal (_ == 0). Signals 1..8 never run, the two lines after this
            # loop are unreachable, and callers looping over folds never reach
            # the next fold. Removing it alone would also let the save above
            # overwrite the file on each signal — confirm the intended output
            # before changing.
            exit()
    all_emb = torch.stack(all_emb)
    torch.save({"all_emb":all_emb.detach(), "feature_a":model.features_a.cpu().detach(), "feature_b":model.features_b.cpu().detach()}, f'./res/{args.dataset_name}_emb.pth')

if __name__ == "__main__":
    # Base dataset name; folds are addressed as "<name>-1", "<name>-2", ...
    dataset_name = args.dataset_name
    # --eval: process folds 1 and 2, then stop.
    if args.eval:
        for i in range(0,2):
            args.dataset_name = dataset_name+'-'+str(i+1)
            run()
        exit()
    # Run tag used to name the source snapshot and log file. Defaults to the
    # current Unix time modulo 1e8.
    if args.comment:
        comment = args.comment
    else:
        comment = str(int(time.time())%100000000)
    print(comment)
    # Snapshot this script's source for the run (append mode: reusing a tag
    # accumulates copies). Assumes ./code already exists.
    with open(f'./code/{dataset_name}_{comment}.py','a') as fh:
        with open(__file__,'r') as f:
            fh.write(f.read())
            fh.write('\n\n\n\n\n\n')
    # Five folds. As currently written, run() calls exit() inside its loop, so
    # in practice only the first fold is processed.
    for i in range(0,5):
        # "key-value~key-value..." summary of every CLI arg except --device.
        # It is only written for i == 0, when args.dataset_name is still the
        # base name (it is switched to the fold name further down).
        hyper_params = dict(vars(args))
        del hyper_params['device']
        hyper_params = "~".join([f"{k}-{v}" for k,v in hyper_params.items()])
        # tb_writer = SummaryWriter(log_dir=f'./logs/{hyper_params}')
        # Append-only run log; assumes ./txt already exists.
        with open(f'./txt/{dataset_name}_{comment}.txt','a') as fh:
            if i==0:
                fh.write(hyper_params)
            fh.write(f'\n{i}\n')
        # NOTE(review): the visible run() extracts embeddings from a checkpoint
        # rather than training; the message may be stale.
        print(f'training: {i}')

        setup_seed(args.seed)
        args.dataset_name = dataset_name+'-'+str(i+1)
        run(i)