from utils import *
import matplotlib.pyplot as plt
from models_all import wpm_net
import torch
import collections as cls
import numpy
from torch.utils.data import DataLoader
import collections
#hyperparameter
# Script entry point. Trains one `wpm_net` sequentially on 10 tasks loaded from
# "../mnist_permutations.pt" and, after every task:
#   1. samples `percent` training examples and records, per layer, the
#      activations and a masked "weight basis" for those examples,
#   2. extends `Pool` (per-layer basis directions) via SVD of those matrices,
#   3. re-compresses every `Pool` entry with another SVD,
#   4. reports accuracy on clean test data and on the samples returned by
#      `make_advers_samples_normal_model` for every task seen so far.
# For every task after the first, the gradients of the first two parameter
# tensors are projected away from the stored `Pool` directions before each SGD
# step. Backward-transfer summaries are printed once all tasks are finished.
if __name__=="__main__":
    # --- Hyperparameters ---------------------------------------------------
    # epsilon: forwarded to make_advers_samples_normal_model as `epsilon`.
    epsilon = 0.1
    # lamda: weight of the input-gradient penalty (loss_2) in the training loss.
    lamda = 50
    # lr: learning rate of the per-task SGD optimizer.
    lr = 0.05
    # epoches: number of passes over each task's training loader.
    epoches = 10
    # percent: number of training samples drawn per task for the basis update
    # (it is used as `size=` below, i.e. a count rather than a percentage).
    percent = 300
    batch_size = 32
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # --- Bookkeeping containers ----------------------------------------------
    # task id -> clean accuracy on that task, appended after each trained task.
    accuracies_all_task_for_dataset = cls.defaultdict(list)
    # task id -> accuracy on the generated adversarial samples, same schedule.
    accuracies_all_task_for_advers_samples_dataset = cls.defaultdict(list)
    # accuracy on task k measured right after training on task k.
    accuracies_all_model = []
    # task id -> test DataLoader of every task seen so far.
    test_data = {}
    # mean adversarial-sample accuracy over seen tasks, one entry per task.
    series_average_robustness_acc = []
    # 1-based layer id -> basis matrix whose rows are the stored directions.
    Pool = {}
    # mean clean accuracy over seen tasks, one entry per trained task.
    ACC_mean = []
    # task id -> indices of the training samples used for the basis update.
    sample_index = {}
    # Energy thresholds indexed by the 1-based layer id: `a` for the activation
    # (GPM) step, `b` for the weight-basis (DGP) step, `c` for the final
    # compression of each Pool entry. Layer ids start at 1, so index 0 of each
    # list is never read by the loops below.
    a = [0.95, 0.95, 0.99, 0.99]
    b = [0.95, 0.999, 0.999, 0.999]
    c = [0.996, 0.996, 0.996, 0.996]
    # Seeds numpy only. NOTE(review): torch is never seeded in this block, so
    # DataLoader shuffling (and presumably model initialisation) is not fixed
    # by this call.
    numpy.random.seed(123)
    model =wpm_net(hidden_unit=256).to(device=device)
    # The file is expected to hold a (train, test) pair; below, entry [k][1] is
    # used as the inputs and entry [k][2] as the labels of task k.
    s_train_data, s_test_data = torch.load("../mnist_permutations.pt")
    # sample_grad, s_grad, grad_similarity and data0_grade_loader are not read
    # again anywhere in this block.
    sample_grad = collections.defaultdict(list)
    s_grad = []
    grad_similarity = []
    grad_similarity.append(1)
    # NOTE(review): `indexs__` only feeds the unused loader below, but the draw
    # itself advances numpy's global RNG, so deleting it would change the
    # `indexs` sampled for every task further down.
    indexs__ = numpy.random.choice(a=60000, size=percent, replace=False)
    data0_grade_loader = DataLoader(dataset=list(zip(s_train_data[0][1][indexs__].to(torch.float), s_train_data[0][2][indexs__].long())), batch_size=1, shuffle=False)
    # --- Sequential training over the tasks -----------------------------------
    for k in range(10):
        print(k)
        # A fresh plain-SGD optimizer and loss are created for every task.
        optimizer = torch.optim.SGD(lr=lr, params=model.parameters())
        criterion = torch.nn.CrossEntropyLoss()
        train_data_ = list(zip(s_train_data[k][1], s_train_data[k][2]))
        test_data_ = list(zip(s_test_data[k][1], s_test_data[k][2]))
        train_loader = DataLoader(dataset=train_data_, batch_size=batch_size,
                                  shuffle=True)
        # Test batches hold a single sample each. The adversarial accuracy
        # below divides by len(advers_samples), which presumably relies on
        # this one-sample-per-entry layout -- TODO confirm against the helper.
        test_loader = DataLoader(dataset=test_data_, batch_size=1, shuffle=True
                                 )
        test_data[k] = test_loader
        # train_losses_epoch and running_loss are assigned but never read.
        train_losses_epoch = []
        model.train()
        if k == 0:
            # First task: regularised training without any gradient projection.
            for epoch in range(epoches):
                running_loss = 0
                for i, data in enumerate(train_loader):
                    inputs, labels = data
                    inputs, labels = inputs.to(device), labels.to(device)
                    optimizer.zero_grad()
                    # Track gradients w.r.t. the inputs for the penalty term.
                    inputs.requires_grad = True
                    # The model returns a sequence; its last entry is used as logits.
                    output = model(inputs, k)[-1]
                    loss_1 = criterion(output, labels)
                    # create_graph=True keeps the graph of this gradient so the
                    # penalty can itself be back-propagated to the parameters.
                    inputs_grad = torch.autograd.grad(loss_1, inputs, create_graph=True)[0]
                    # print(inputs_grad)
                    # torch.sum without `dim` already reduces over every
                    # element (the whole batch), so the outer mean over dim 0
                    # is effectively a no-op on a scalar.
                    # NOTE(review): the commented-out variant in the else
                    # branch sums per sample and then averages over the batch;
                    # confirm which reduction is intended.
                    loss_2 =torch.mean(torch.sum(inputs_grad**2),dim=0)
                    # print(loss_2)
                    loss = loss_1 + lamda * loss_2
                    # print("loss_1:{},loss_2:{},loss:{}".format(loss_1,loss_2,loss))
                    loss.backward()
                    # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2, norm_type="inf")
                    # if k > 0:
                    #     for m, param in enumerate(model.parameters()):
                    #         param.grad -= torch.matmul(torch.matmul(param.grad, Pool[m].T), Pool[m])
                    optimizer.step()

        else:
            # Later tasks: same loss as above, plus projection of the first two
            # parameter gradients before the optimizer step.
            for epoch in range(epoches):
                running_loss = 0
                for i, data in enumerate(train_loader):
                    inputs, labels = data
                    inputs, labels = inputs.to(device), labels.to(device)
                    # print(torch.max(inputs))
                    optimizer.zero_grad()
                    inputs.requires_grad = True
                    output = model(inputs, k)[-1]
                    loss_1 = criterion(output, labels)
                    inputs_grad = torch.autograd.grad(loss_1, inputs, create_graph=True)[0]
                    # Alternative reduction (per-sample sum, then batch mean):
                    # loss_2 = torch.mean(torch.sum(inputs_grad ** 2, dim=1, keepdim=False), dim=0)
                    loss_2 =torch.mean(torch.sum(inputs_grad**2), dim=0)
                    # loss_2 = torch.sum(inputs_grad[0] ** 2)
                    loss = loss_1 + lamda * loss_2
                    # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2, norm_type="inf")
                    loss.backward()

                    # grad -= grad @ P^T @ P removes the part of each gradient
                    # row that lies in the row space of P. The rows of each
                    # Pool entry come from the SVD-based compression at the end
                    # of the previous task.
                    # NOTE(review): which layers parameters()[0] and [1] are
                    # depends on wpm_net's parameter order, which is not
                    # visible here -- confirm they match Pool[1] and Pool[2].
                    param = list(model.parameters())[0]
                    param.grad -= torch.matmul(torch.matmul(param.grad,Pool[1].T), Pool[1])
                    param = list(model.parameters())[1]
                    param.grad -= torch.matmul(torch.matmul(param.grad, Pool[2].T), Pool[2])
                    optimizer.step()
        # --- Update the basis pool (GPM) from a sample of this task's data ------
        model.eval()
        # Draw `percent` distinct indices from range(60000).
        # NOTE(review): 60000 is hard-coded; it presumably equals the size of
        # each task's training set -- confirm.
        indexs = numpy.random.choice(a=60000, size=percent, replace=False)
        sample_index[k] = indexs
        print(indexs)
        grade_dataloader = DataLoader(dataset=list(zip(s_train_data[k][1][indexs].to(torch.float), s_train_data[k][2][indexs].long())),
                                  batch_size=batch_size, shuffle=True)

        # X is never used after this assignment.
        X = torch.Tensor([])
        # layer index j (0-based) -> activations of that layer, stacked along
        # dim 0 over all sampled batches. defaultdict(torch.Tensor) supplies an
        # empty tensor for the first concatenation.
        L = cls.defaultdict(torch.Tensor)
        # layer index j (0-based) -> masked weight rows, stacked the same way.
        Weights_basis = cls.defaultdict(torch.Tensor)
        # NOTE(review): these forward passes run with autograd enabled although
        # no backward pass is taken inside this loop.
        for v, data in enumerate(grade_dataloader):
            sample, label = data
            sample = sample.to(device)  # [bs, features]
            label = label.to(device)
            optimizer.zero_grad()
            # l1_logit, l2_logit, l3_logit, l4_logit = model( sample ) #[bs, features]
            l_logit = model(sample, k)  # sequence of per-layer outputs; the last entry is the logits
            for j, temp in enumerate(l_logit[:-1]):  # every returned output except the last one
                # torch.cat builds a new tensor, so the in-place binarisation
                # of `temp` further down does not alter what is stored in L[j].
                L[j] = torch.cat([L[j], temp.detach().cpu()], dim=0)
                if j == 0:
                    w = model.dense_first[k].state_dict()['weight'].T.sum(dim=0).tile((sample.shape[0], 1))
                    # Per the original author's note: the weight is
                    # [out_dim, in_dim]; transposing and summing over dim 0
                    # leaves one value per output unit ([out_dim]), which is
                    # then tiled to [bs, out_dim] -- assumes a 2-D weight.
                else:  # j >= 1
                    # Push the masked rows of the previous layer through model.l2.
                    w = model.l2(w)  # no activation
                # Turn the activation into a 0/1 mask of its positive entries
                # (in place) and apply that mask to w.
                temp[temp <= 0] = 0
                temp[temp > 0] = 1
                w = w * temp
                Weights_basis[j] = torch.cat([Weights_basis[j], w.detach().cpu()], dim=0)
        # GPM: extend Pool with directions taken from the activation matrices.
        # `layer` is the 1-based id used as the Pool key and threshold index.
        for layer, temp in zip(range(1, len(L) + 1), list(L.values())):
            print(f'GPM temp.shape: {temp.shape}')
            # Total energy: squared Frobenius norm of the activation matrix.
            whole = torch.square(torch.linalg.matrix_norm(temp))
            if layer in Pool:
                pc = Pool[layer].cpu()  # [n_directions, feature]; rows are the stored directions
                # Part of temp already explained by the stored directions.
                temp_ = torch.matmul(torch.matmul(temp, pc.T), pc)  # same shape as temp
                pr = torch.square(torch.linalg.matrix_norm(temp_))
                # Only the residual is decomposed for new directions.
                temp2 = temp - temp_
                print(temp2)
                U, S, Vh = torch.linalg.svd(temp2, full_matrices=True)
                V = Vh.T
            else:
                U, S, Vh = torch.linalg.svd(temp, full_matrices=True)
                V = Vh.T
                pr = 0
            re = 0
            if layer in Pool:
                temp_re=torch.zeros_like(temp2)
            else:
                temp_re=torch.zeros_like(temp)
            # Add rank-1 terms in the order returned by the SVD until the
            # energy of the partial reconstruction plus the energy already
            # covered by Pool reaches the a[layer] share of the total. If the
            # threshold is never met, every singular vector is kept.
            # NOTE(review): k1 would be unbound after the loop if S were empty.
            for k1 in range(S.shape[0]):
                u, s, v = U[:, k1:k1 + 1], S[k1], V[:, k1:k1 + 1]
                temp_re += s * torch.matmul(u, v.T)
                re = torch.square(torch.linalg.matrix_norm(temp_re))

                if re + pr >= whole * a[layer]:
                    print(k1)
                    break
            # Leading k1 + 1 right-singular vectors, stored as rows.
            new_pc = V[:, :k1 + 1].T  # [feature, n] transposed to [n, feature]
            # new_pc = U[:k + 1, :]
            if layer in Pool:
                Pool[layer] = torch.cat([pc, new_pc], dim=0).to(device)  # rows: old directions followed by new ones
            else:
                Pool[layer] = new_pc.to(device)
        # DGP: the same procedure applied to the masked weight rows, with the
        # thresholds taken from `b`. The GPM loop above has already created a
        # Pool entry for every layer id it visited, so for matching layer ids
        # the `else` branches below are not expected to run.
        for layer, temp in zip(range(1, len(Weights_basis) + 1), list(Weights_basis.values())):
            print(f'Robust temp.shape: {temp.shape}')
            # Sum of squared entries, i.e. the squared Frobenius norm again.
            whole = torch.sum(torch.abs(temp) ** 2)
            if layer in Pool:
                pc = Pool[layer].cpu()  # [n_directions, feature]; rows are the stored directions
                temp_ = torch.matmul(torch.matmul(temp, pc.T), pc)  # same shape as temp
                pr = torch.square(torch.linalg.matrix_norm(temp_))
                temp2 = temp - temp_
                print(temp2)
                U, S, Vh = torch.linalg.svd(temp2, full_matrices=True)
                V = Vh.T
            else:
                U, S, Vh = torch.linalg.svd(temp, full_matrices=True)
                V = Vh.T
                pr = 0
            re = 0
            if layer in Pool:
                temp_re = torch.zeros_like(temp2)
            else:
                temp_re = torch.zeros_like(temp)
            for k1 in range(S.shape[0]):
                u, s, v = U[:, k1:k1 + 1], S[k1], V[:, k1:k1 + 1]
                temp_re += s * torch.matmul(u, v.T)
                re = torch.square(torch.linalg.matrix_norm(temp_re))
                # re=torch.square(torch.linalg.matrix_norm(temp_re))
                if re + pr >= whole * b[layer]:
                    print(k1)
                    break
            new_pc = V[:, :k1 + 1].T  # [feature, n] transposed to [n, feature]
            # new_pc = U[:k + 1, :]
            if layer in Pool:
                Pool[layer] = torch.cat([pc, new_pc], dim=0).to(device)  # rows: old directions followed by new ones
            else:
                Pool[layer] = new_pc.to(device)
        print(f'current task {k} is over, number of directions in pool is {[pc.shape[0] for pc in Pool.values()]} ')
        # Compression: replace every Pool entry by the leading right-singular
        # vectors of that entry which capture the c[n_n] share of its squared
        # Frobenius norm. The rows written back come straight from the SVD.
        # NOTE(review): k_K would be unbound after the loop if S were empty.
        for n_n, pool_dim in Pool.items():
            print(n_n)
            pool_dim = pool_dim.cpu()
            # if pool_dim.shape[0]>=max(list(model.parameters())[n_n+9].shape[0],list(model.parameters())[n_n+9].shape[1]):
            whole_ = torch.sum(torch.abs(pool_dim) ** 2)
            U, S, Vh = torch.linalg.svd(pool_dim, full_matrices=True)
            V = Vh.T
            temp_re = torch.zeros_like(pool_dim)
            for k_K in range(S.shape[0]):
                u, s, v = U[:, k_K:k_K + 1], S[k_K], V[:, k_K:k_K + 1]
                temp_re += s * torch.matmul(u, v.T)
                re = torch.square(torch.linalg.matrix_norm(temp_re))
                # re=torch.square(torch.linalg.matrix_norm(temp_re))
                if re >= whole_ * c[n_n]:
                    print(k_K)
                    break
            pool_dim = V[:, :k_K + 1].T
            Pool[n_n] = pool_dim.to(device)
        print(f'current task {k} is over, number of directions in pool is {[pc.shape[0] for pc in Pool.values()]} ')
        # --- Evaluation -----------------------------------------------------------
        # Accuracy on the task that was just trained.
        correct_numbers = 0
        for v, data in enumerate(test_loader):
            x, y = data
            x, y = x.to(device), y.long().to(device)
            with torch.no_grad():
                logits = model(x, k)[-1]
            correct_numbers += torch.eq(logits.argmax(dim=1), y).sum().float().item()
        accuracy = correct_numbers / len(test_loader.dataset)
        accuracies_all_model.append(accuracy)
        print("accuracy for data {} in task{} is :{}".format(k, k,accuracy))
        # ACC is rebuilt for every task, so after the outer loop it holds the
        # accuracies measured after the final task only.
        ACC = []
        # Clean accuracy on every task seen so far (including the current one),
        # passing each task's own id q to the model.
        for q, test_loadr in test_data.items():
            correct_numbers = 0
            for v, data in enumerate(test_loadr):
                x, y = data
                x, y = x.to(device), y.long().to(device)
                with torch.no_grad():
                    logits = model(x, q)[-1]
                correct_numbers += torch.eq(logits.argmax(dim=1), y).sum().float().item()
            accuracy = correct_numbers / len(test_loadr.dataset)
            ACC.append(accuracy)
            accuracies_all_task_for_dataset[q].append(accuracy)
            print("accuracy for data {} in task{} is :{}".format(q, k, accuracy))
        acc_mean = numpy.mean(ACC)
        ACC_mean.append(acc_mean)
        print("ACC_mean:{}".format(ACC_mean))
        # Accuracy on the samples produced by make_advers_samples_normal_model
        # (imported from utils, not shown) for every task seen so far. `ss` is
        # rebuilt for every task, like ACC above.
        ss = []
        for p, test_loadr in test_data.items():
            advers_samples = make_advers_samples_normal_model(model=model, test_dataloader=test_loadr, device=device,
                                                              epsilon=epsilon, task_number_=p)
            correct_numbers = 0
            for data in advers_samples:
                x, y = data
                x, y = x.to(device), y.to(device)
                with torch.no_grad():
                    logits_adver = model(x, p)[-1]
                correct_numbers += torch.eq(logits_adver.argmax(dim=1), y).sum().float().item()
            # NOTE(review): dividing by the number of entries is a per-sample
            # accuracy only if each entry holds exactly one sample -- confirm
            # against the helper's return value.
            accuracy = correct_numbers / len(advers_samples)
            ss.append(accuracy)
            accuracies_all_task_for_advers_samples_dataset[p].append(accuracy)
            print(
                "accuracy for adver_samples of data{} in task{} is :{}".format(p, k, accuracy))
        series_average_robustness_acc.append(numpy.mean(ss))

        print("accuracy_for_normal_sample：{}".format(accuracies_all_task_for_dataset))
        print("accuracy_for_advers_sample：{}".format(accuracies_all_task_for_advers_samples_dataset))
        print("accuracy_for_models_themseves：{}".format(accuracies_all_model))
    # --- Summary metrics ----------------------------------------------------------
    # bwt: mean final accuracy over all tasks but the last, minus the mean of
    # the accuracies those tasks had right after they were trained.
    bwt = (numpy.sum(ACC) - ACC[-1]) / (len(ACC) - 1) - (numpy.sum(accuracies_all_model) - accuracies_all_model[-1]) / (
            len(accuracies_all_model) - 1)
    # bwt_robustness: the same difference on the adversarial-sample accuracies;
    # the baseline is each task's first recorded value.
    # NOTE(review): range(9) hard-codes "all but the last of 10 tasks" and must
    # be kept in sync with the range(10) of the task loop.
    bwt_robustness = (numpy.sum(ss) - ss[-1]) / (len(ss) - 1) - numpy.mean(
        [accuracies_all_task_for_advers_samples_dataset[zz][0] for zz in range(9)])
    print("ACC:{},BWT:{}".format(ACC_mean[-1], bwt))
    print("ACC_MEAN_CURVE:{}".format(ACC_mean))
    print("ACC_MEAN_CURVE_robustness:{}".format(series_average_robustness_acc))
    # NOTE(review): this line prints the robustness backward transfer although
    # its label is plain "BWT".
    print("BWT:{}".format(bwt_robustness))
