import torch
import torch.nn.functional as F
from torch.nn import Linear, ReLU, Sequential
import numpy as np
import pickle

class AugClassifier(torch.nn.Module):
    """Two-layer MLP head returning log-probabilities over ``num_classes``.

    The input (and hidden) width is ``(num_classes + 1) * hidden // 2``. Presumably this is one
    block per class-specific STNet embedding plus the original graph embedding, as concatenated
    in ``train_aug``; confirm against the callers.
    """

    def __init__(self, hidden, num_classes) -> None:
        super().__init__()

        width = (num_classes + 1) * hidden // 2
        self.lin1 = Linear(width, width)
        self.lin2 = Linear(width, num_classes)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both linear layers in place."""
        for layer in (self.lin1, self.lin2):
            layer.reset_parameters()

    def forward(self, x):
        """Apply Linear -> ReLU -> dropout(p=0.5, train only) -> Linear -> log_softmax(dim=-1)."""
        hidden_act = F.relu(self.lin1(x))
        hidden_act = F.dropout(hidden_act, p=0.5, training=self.training)
        logits = self.lin2(hidden_act)
        return F.log_softmax(logits, dim=-1)

class STNet(torch.nn.Module):
    """Learned weighting of cluster/subtree embeddings fed through a fixed classifier stack.

    ``weight`` is a ``(1, n_subtree)`` parameter. A ``mask`` is multiplied elementwise into it.
    Callers in this file pass ``mask.view(1, -1)``. The masked weights then mix the rows of
    ``clus_emb`` with a matrix product before the classifier is applied.

    ``classifier`` must be indexable and sliceable (``[:-1]`` / ``[-1]``), and each element must
    be callable with an ``.eval()`` method; every element is put in eval mode here. All layers
    except the last are followed by ReLU, and the output is a log-softmax over the last dim.

    NOTE(review): if ``classifier`` is a plain Python list, its parameters are not registered on
    this module, so ``parameters()`` yields only ``weight``. With a ``ModuleList`` they would be
    registered. Confirm which is intended.
    """

    def __init__(self, n_subtree, classifier) -> None:
        super().__init__()

        self.classifier = classifier
        for clf in self.classifier:
            clf.eval()

        # reset_parameters() immediately overwrites this value. The draw still consumes the
        # global RNG, so removing it would change subsequent random numbers.
        self.weight = torch.nn.Parameter(torch.rand(1,n_subtree)-0.5)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform initialise ``weight`` in place (non-deprecated ``xavier_uniform_``)."""
        torch.nn.init.xavier_uniform_(self.weight)

    def _classify(self, out):
        """Run ``out`` through the classifier stack (ReLU between layers) and return log-probs."""
        for clf in self.classifier[:-1]:
            out = F.relu(clf(out))
        out = self.classifier[-1](out)
        return F.log_softmax(out, dim=-1)

    def get_weighted_emb(self, clus_emb, mask):
        """Return the masked weights zero-padded to ``clus_emb``'s width.

        The result has shape ``(mask.shape[0], clus_emb.shape[1])``. Its first
        ``weight.shape[1]`` columns hold ``weight * mask`` and the remaining columns are zero.
        Only ``clus_emb.shape[1]`` is used, not the values of ``clus_emb``.
        ``clus_emb.shape[1]`` must be at least ``n_subtree``.

        NOTE(review): unlike ``forward``, this does not mix the rows of ``clus_emb``; confirm
        that embedding the raw weights is the intended feature.
        """
        masked_weight = self.weight*mask
        out = torch.zeros(mask.shape[0], clus_emb.shape[1], device=masked_weight.device)
        out[:, :masked_weight.shape[1]] = masked_weight
        return out

    def _forward(self, clus_emb):
        """Classify ``weight @ clus_emb`` without any mask."""
        return self._classify(torch.mm(self.weight, clus_emb))

    def abs_forward(self, clus_emb, mask):
        """Classify ``(weight * mask) @ clus_emb``.

        Currently identical to ``forward``. The name suggests an ``abs(weight)`` variant, which
        is not applied.
        """
        masked_weight = self.weight*mask
        return self._classify(torch.mm(masked_weight, clus_emb))

    def forward(self, clus_emb, mask):
        """Classify ``(weight * mask) @ clus_emb`` and return log-probabilities."""
        masked_weight = self.weight*mask
        return self._classify(torch.mm(masked_weight, clus_emb))

def _eval_acc_split(split_masks, data, model, split_y, split_logits, with_fid, name):
    """Score one split: return (per-sample 0/1 hits, per-sample fidelities) and print per-class fidelity."""
    hits, fids = [], []
    class_fids = ([], [])  # index 0: samples labelled 0; index 1: any other label
    with torch.no_grad():
        for i, mask in enumerate(split_masks):
            label = int(split_y[i])
            out = model.abs_forward(data, mask.view(1, -1))[0]
            if with_fid:
                fid = float(split_logits[i][label] - out[label])
                fids.append(fid)
                # Bucket by THIS split's label (previously val/test used the train labels).
                class_fids[0 if label == 0 else 1].append(fid)
            hits.append(1 * (int(torch.argmax(out)) == label))
    if with_fid:
        c0, c1 = class_fids
        # The 1e-5 terms keep the ratio defined for an empty class bucket (it then prints 1.0).
        print(f"{name} set: Class 0 ProbFidelity={(1e-5+sum(c0))/(1e-5+len(c0))}, Class 1 ProbFidelity={(1e-5+sum(c1))/(1e-5+len(c1))}")
    return hits, fids


def eval_acc(masks, data, model, y, orig_logits=None):
    """Score ``model.abs_forward`` on the train / val / test mask splits.

    ``masks``, ``y`` and (optionally) ``orig_logits`` are 3-element sequences indexed
    ``[split][sample]`` in the order train, val, test. For each sample, the mask is reshaped
    to ``(1, -1)``. The output row is ``model.abs_forward(data, mask)[0]``, and the prediction
    is its argmax. The model is left in eval mode.

    When ``orig_logits`` is given, each sample's "probability fidelity" is
    ``orig_logits[split][i][label] - out[label]``. Presumably both terms are log-probabilities;
    confirm for the caller's model. Per-class averages (label 0 vs. any other label) are
    printed for each split.

    Returns:
        Without ``orig_logits``: ``(train_acc, val_acc, test_acc)``.
        With ``orig_logits``: ``([train_acc, val_acc, test_acc], [train_fid, val_fid, test_fid])``,
        each fidelity being the mean over its split.

    Raises:
        ZeroDivisionError: if any split is empty.
    """
    model.eval()
    flag = orig_logits is not None
    results = [
        _eval_acc_split(masks[k], data, model, y[k], orig_logits[k] if flag else None, flag, name)
        for k, name in enumerate(("Train", "Val", "Test"))
    ]
    accs = [sum(hits) / len(hits) for hits, _ in results]
    if flag:
        return accs, [sum(fids) / len(fids) for _, fids in results]
    return tuple(accs)

def eval_acc2(masks, data, st_model, model, y, ori_graph_embs):
    """Accuracy of the augmented classifier ``model`` on the train / val / test splits.

    For every mask in a split, the feature vector is the concatenation of each
    ``st_model`` member's flattened ``get_weighted_emb(data, mask.view(1, -1))``, followed by
    the sample's flattened original graph embedding ``ori_graph_embs[split][i]``. The
    prediction is the argmax of ``model(features)``.

    Returns:
        ``(train_acc, val_acc, test_acc)`` as fractions. Raises ZeroDivisionError if a split is
        empty.
    """
    model.eval()

    def split_hits(split):
        hits = []
        for idx, raw_mask in enumerate(masks[split]):
            row_mask = raw_mask.view(1, -1)
            parts = [member.get_weighted_emb(data, row_mask).view(-1) for member in st_model]
            parts.append(ori_graph_embs[split][idx].view(-1))
            scores = model(torch.cat(parts, dim=-1))
            hits.append(1 * (int(torch.argmax(scores)) == int(y[split][idx])))
        return hits

    # Evaluate all three splits first, then divide, so evaluation order is train, val, test.
    per_split = [split_hits(split) for split in range(3)]
    return tuple(sum(hits) / len(hits) for hits in per_split)


def eval_reload(train_loader, val_loader, test_loader, data, model, args, t_cls, gnn_model): 
    """Unfinished evaluation helper that re-runs ``gnn_model`` over ``train_loader``.

    As written, it puts ``model`` in eval mode and iterates ``train_loader``. Each batch is
    moved to ``device`` and samples are skipped when the GNN's argmax differs from ``d.y`` or
    when ``d.y`` is not ``t_cls``. Nothing is accumulated and the function returns ``None``.

    NOTE(review): ``val_loader``, ``test_loader``, ``data`` and ``args`` are unused, and
    ``train_acc`` is never filled; the body looks truncated. TODO: complete or remove.
    NOTE(review): ``device`` is not defined in the visible part of this file. It must be a
    module-level name defined elsewhere, otherwise the first iteration raises NameError.
    """
    model.eval()
    train_acc=[]
    for i, d in enumerate(train_loader): 
        # Relies on a global ``device``; confirm it exists at module level.
        d = d.to(device)
        # Presumably the first element of the GNN output is the logits tensor; confirm.
        logits = gnn_model(d)[0]
        # Skip samples the GNN misclassifies and samples outside the target class t_cls.
        if torch.argmax(logits) != int(d.y) or int(d.y)!=t_cls: continue 


def redetermine_pred(st_model,data,y,args,masks,verbose=False):
    """Combine per-class rule models into one prediction per sample and report accuracy.

    Each member of ``st_model`` is called as ``member(data, mask.view(1, -1))``. Its output is
    exponentiated; presumably log-probabilities, since ``STNet.forward`` ends in log_softmax.
    Its argmax is that member's vote, and the value at the vote is its confidence. The final
    prediction is the vote of the most confident member; when all votes agree, this is simply
    the shared vote. Confidence ties go to the earliest member, because ``torch.argmax``
    returns the first maximal index. For two members this matches the previous
    agree/disagree rule exactly, and it now also works for any number of members.

    Args:
        st_model: non-empty sequence of rule models.
        data: input passed unchanged to every rule model.
        y: ``[y_train, y_val, y_test]``, labels indexed per sample.
        args: unused; kept for interface compatibility.
        masks: ``[train_masks, val_masks, test_masks]``.
        verbose: if true, print ``(votes, confidences, label)`` for every training sample.

    Returns:
        ``(train_acc, val_acc, test_acc)`` as fractions. Raises ZeroDivisionError if a split is
        empty.
    """
    def _split_hits(split_masks, split_y, log):
        hits = []
        for i, mask in enumerate(split_masks):
            mask = mask.view(1, -1)
            votes, confs = [], []
            for st_cls_model in st_model:
                probs = torch.exp(st_cls_model(data, mask)).view(-1)
                vote = int(torch.argmax(probs))
                votes.append(vote)
                confs.append(probs[vote])
            confs = torch.stack(confs)
            label = int(split_y[i])
            if log:
                print(votes, confs, label)
            pred = votes[int(torch.argmax(confs))]
            hits.append(1 * (pred == label))
        return hits

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        train_acc = _split_hits(masks[0], y[0], verbose)
        val_acc = _split_hits(masks[1], y[1], False)
        test_acc = _split_hits(masks[2], y[2], False)
    return sum(train_acc)/len(train_acc), sum(val_acc)/len(val_acc), sum(test_acc)/len(test_acc)


def train_aug(model,st_model,data,y,args,masks,ori_graph_embs):
    """Train the augmented classifier on STNet-weighted features plus graph embeddings.

    For every training mask, each ``st_model`` member contributes its flattened
    ``get_weighted_emb(data, mask.view(1, -1))``, and the sample's original graph embedding
    is appended. The concatenated vectors are grouped into mini-batches of ``args.batch_size``,
    and any final partial batch is also used. ``model`` is trained on them with NLL loss and
    Adam (``args.lr``).

    After every epoch, ``eval_acc2`` scores train/val/test. Whenever validation accuracy ties
    or beats the best so far, the state dict is saved to ``saved_models/aug_cls_...model``.
    Training then stops early if the epoch loss is within 1e-4 of the previous best, or within
    1e-1 with perfect train and validation accuracy. Up to ``5 * args.epochs - 1`` epochs run.

    Args:
        model: classifier taking a ``(batch, features)`` tensor and returning log-probs.
        st_model: sequence of per-class STNet models.
        data: cluster/subtree embeddings passed to ``get_weighted_emb``.
        y: ``[y_train, y_val, y_test]``; labels must be tensors, since they are stacked.
        args: needs ``lr``, ``epochs``, ``batch_size``, ``dataset``, ``gnn``, ``clusters``,
            ``local_cluster``.
        masks: ``[train_masks, val_masks, test_masks]``.
        ori_graph_embs: ``[train, val, test]`` original graph embeddings.
    """
    model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    path = "saved_models/"+f'aug_cls_{args.dataset}-{args.gnn}-clus{args.clusters}-stnum{len(data)}-{args.local_cluster}'+".model"

    train_g_embs, _, _ = ori_graph_embs
    y_train, _, _ = y
    train_masks, _, _ = masks

    def _step(batch_x, batch_y):
        """Take one optimizer step on a stacked mini-batch and return its loss value."""
        optimizer.zero_grad()
        out = model(torch.stack(batch_x))
        loss = F.nll_loss(out, torch.stack(batch_y).view(-1))
        # retain_graph kept from the original. The inputs may share autograd history
        # (e.g. through ``data`` or the graph embeddings) that is reused on later steps.
        loss.backward(retain_graph=True)
        optimizer.step()
        return loss.item()

    best_val_acc = 0
    best_loss = 9999
    for epoch in range(1, 5*args.epochs):
        # eval_acc2 switches the model to eval mode, so re-enable training every epoch.
        model.train()
        tot_loss = 0
        queue, y_q = [], []
        for i, mask in enumerate(train_masks):
            mask = mask.view(1,-1)
            x = [st_cls_model.get_weighted_emb(data, mask).view(-1) for st_cls_model in st_model]
            x.append(train_g_embs[i].view(-1))
            queue.append(torch.cat(x, dim=-1).view(-1))
            y_q.append(y_train[i])
            if len(queue) >= args.batch_size:
                tot_loss += _step(queue, y_q)
                queue, y_q = [], []
        # Train on the leftover samples too. Otherwise they are silently dropped, and nothing
        # is trained at all when there are fewer samples than batch_size.
        if queue:
            tot_loss += _step(queue, y_q)
        train_acc, val_acc, test_acc = eval_acc2(masks, data, st_model, model, y, ori_graph_embs)
        print(f"*Aug, Epoch: {epoch} - traing loss: {tot_loss} - train_acc: {train_acc} - val_acc: {val_acc} - test_acc: {test_acc}")
        if val_acc >= best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), path)
            # early stop
            if abs(tot_loss-best_loss)<1e-4 or (abs(tot_loss-best_loss)<1e-1 and train_acc==1.0 and val_acc==1.0): break
            best_loss = tot_loss
            print(" ^ NEW!!")

def sep_train_stnet(model, data, y, orig_logits, args, masks, tcls, f_tcls_ys=None, false_pred_mask=None, tcls_ys=None, tcls_orig_logits=None, tcls_masks=None):
    """Train one STNet's subtree weights to explain class ``tcls`` and report its fidelity.

    Training samples labelled ``tcls`` are fitted to their own label. Every other sample is
    fitted to ``(label + 1) % 2``, which assumes binary 0/1 labels (TODO: confirm for
    multi-class data). The penalty ``args.lmda * ||model.weight||_2`` is added to each
    per-sample NLL loss, and Adam (``args.lr``) updates after every sample.

    After each epoch the model is scored with ``eval_acc``: on ``masks`` / ``y`` /
    ``orig_logits``, or on the ``tcls_*`` subsets when ``tcls_ys`` is given. The state dict with
    the best (ties allowed) validation accuracy is saved to ``saved_models/sep_st_...model``.
    Training stops early when the epoch loss is within 5e-3 of the previous best or all three
    accuracies reach 1.0.

    After reloading the best checkpoint, the function:
      * prints the final metrics;
      * optionally re-scores the falsely predicted samples (``false_pred_mask`` /
        ``f_tcls_ys``);
      * pickles the per-training-sample ``|mask * weight|`` matrix to
        ``saved_data/...all_weights.pkl``.

    Returns:
        ``model.weight`` scaled so that its largest absolute entry is 1.
    """
    model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    path = "saved_models/"+f'sep_st_{args.dataset}-{args.gnn}-cls{tcls}-clus{args.clusters}-stnum{len(data)}-{args.local_cluster}-{args.lmda}'+".model"

    [y_train, y_val, y_test]=y
    [train_masks, val_masks, test_masks] = masks

    def _evaluate():
        # Score on the full splits, or on the tcls-only subsets when they were supplied.
        if tcls_ys is None:
            return eval_acc(masks, data, model, y, orig_logits)
        return eval_acc(tcls_masks, data, model, tcls_ys, tcls_orig_logits)

    best_val_acc = 0
    best_loss = 9999
    tot_loss = 0  # defined even if the epoch loop below does not run
    for epoch in range(1, args.epochs):
        # eval_acc switches the model to eval mode, so re-enable training every epoch.
        model.train()
        tot_loss = 0
        for i, mask in enumerate(train_masks):
            optimizer.zero_grad()
            out = model(data, mask.view(1,-1))
            label = y_train[i]
            # Keep the label for the target class; flip the (binary) label otherwise.
            target = label if int(label)==tcls else (label+1)%2
            loss = F.nll_loss(out, target.view(-1))+float(args.lmda)*torch.norm(model.weight,p=2)
            # retain_graph kept: ``data`` may carry autograd history reused across iterations.
            loss.backward(retain_graph=True)
            optimizer.step()
            tot_loss+=loss.item()

        [train_acc, val_acc, test_acc], [train_prob, val_prob, test_prob] = _evaluate()

        print(f"Epoch: {epoch} - traing loss: {tot_loss} - train_acc: {train_acc} - val_acc: {val_acc} - test_acc: {test_acc}")
        print(f"train_prob: {train_prob} - val_prob: {val_prob} - test_prob: {test_prob}")
        if val_acc >= best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), path)
            # early stop
            if abs(tot_loss-best_loss)<5e-3 or (train_acc==1.0 and val_acc==1.0 and test_acc==1.0): break
            best_loss = tot_loss

    model.load_state_dict(torch.load(path))
    model.eval()
    [train_acc, val_acc, test_acc], [train_prob, val_prob, test_prob] = _evaluate()

    print(f"\n-------------------------\nFinally, traing loss: {tot_loss} - train_acc: {train_acc} - val_acc: {val_acc} - test_acc: {test_acc}\nnumber of train={len(masks[0])}, val={len(masks[1])}, test={len(masks[2])}")
    print(f"train_prob: {train_prob} - val_prob: {val_prob} - test_prob: {test_prob}")

    if false_pred_mask is not None:
        f_train_acc, f_val_acc, f_test_acc = eval_acc(false_pred_mask, data, model, f_tcls_ys, None)
        print(f'Falsely predicted data: train={len(f_tcls_ys[0])}, val={len(f_tcls_ys[1])}, test={len(f_tcls_ys[2])} new prediction acc: train_acc: {f_train_acc} - val_acc: {f_val_acc} - test_acc: {f_test_acc}')

    # One row per training mask: |mask * weight|, plus each row's total.
    weights = model.weight
    with torch.no_grad():
        flat_weights = weights.detach().view(1, -1)
        stacked_masks = torch.stack([m.view(-1).detach() for m in train_masks])
        all_weights = torch.abs(stacked_masks * flat_weights)
        tot_weights = all_weights.sum(dim=1)
    print(f'max-weight: {torch.max(tot_weights)}, min_weights: {torch.min(tot_weights)}, mean_weights: {torch.mean(tot_weights)}')

    # NOTE(review): the file name uses int(max(y[0])) (the largest training label) rather than
    # tcls. With both classes present in y_train, runs for different tcls would write the same
    # file. Confirm against whatever reads these pickles before changing it.
    path2 = "saved_data/"+f'{args.dataset}-{args.gnn}-cls{int(max(y[0]))}-clus{args.clusters}-stnum{len(data)}-{args.local_cluster}-{args.lmda}-'
    with open(path2+'all_weights'+".pkl", 'wb') as handle:
        pickle.dump(all_weights, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return weights/torch.max(torch.abs(weights))

def train_stnet(model, data, y, args):
    """Fit ``model.weight`` to a single target label ``y`` without masks.

    Runs up to ``args.epochs - 1`` full-batch Adam (``args.lr``) steps on
    ``nll_loss(model(data), y) + args.lmda * ||model.weight||_2``. Whenever the loss beats the
    best seen so far, the state dict is saved to ``saved_models/st_...model``. Training stops
    once such an improved loss drops below 5e-3. The best checkpoint is then reloaded, the
    model is put in eval mode, and ``model.weight`` is returned.

    ``int(y)`` is used in the checkpoint name, so ``y`` must hold a single label.

    NOTE(review): ``STNet.forward`` in this file takes ``(clus_emb, mask)``. Calling
    ``model(data)`` with one argument only works for a model whose forward takes a single
    input (``STNet._forward`` has that signature). Confirm which model this is used with.
    """
    model.cuda()
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)

    ckpt_path = (
        "saved_models/"
        + f'st_{args.dataset}-{args.gnn}-cls{int(y)}-clus{args.clusters}-stnum{len(data)}-{args.local_cluster}-{args.lmda}'
        + ".model"
    )

    lowest = 9999
    for _ in range(1, args.epochs):
        model.train()
        opt.zero_grad()
        objective = F.nll_loss(model(data), y) + float(args.lmda) * torch.norm(model.weight, p=2)
        objective.backward(retain_graph=True)
        opt.step()
        current = objective.item()

        # `not <` (rather than `>=`) so a NaN loss never counts as an improvement.
        if not current < lowest:
            continue
        lowest = current
        torch.save(model.state_dict(), ckpt_path)
        if current < 5e-3:  # early stop
            break

    model.load_state_dict(torch.load(ckpt_path))
    model.eval()
    return model.weight
    