import collections
import avalanche
from torch.nn import CrossEntropyLoss

from torch.utils.data import DataLoader
from utils_ import *
from models_alexnet import basic_cnn_net_batchnorm_multheadandwei
import numpy
import torchvision.transforms as transforms
from gpm_plugin import GPMPlugin
# Plugin handed to the training strategy; its direction pool is replaced
# after every experience through ``gpm_plugin.renew_pool``.
gpm_plugin = GPMPlugin(penalty_lambda=1.)

# The same tensor-only transform is used for both training and evaluation.
transform_test = transforms.Compose([transforms.ToTensor()])

BATCHSIZE = 10  # mini-batch size for training, strategy eval and the probe loader
benchmark = avalanche.benchmarks.classic.SplitCIFAR100(
    n_experiences=10,
    return_task_id=True,
    shuffle=True,
    train_transform=transform_test,
    eval_transform=transform_test,
    class_ids_from_zero_in_each_exp=True,
)
test_stream = benchmark.test_stream

# Keep the historical preference for the second GPU, but do not fail on hosts
# that expose fewer devices: fall back to the first GPU, then to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

epsilon = 0.015  # forwarded to make_advers_samples_normal_model_2
lr = 0.05        # SGD learning rate
percent = 300    # number of probe samples drawn per experience
nf = 64          # 'channel_in' recorded in layer_hyperparameter
EPOCHES = 100    # training epochs per experience
# Share of the squared Frobenius norm that the pool must capture:
#   a  - base value for the layer-input update (raised by 0.0025 per experience)
#   b  - per-layer value for the update based on G, indexed by layer id
#   cc - per-layer value for the final pool compression, indexed by layer id
a = 0.97
b = [0.996] * 100
cc = [0.99] * 100

# Geometry of each probed layer, keyed by the layer ids that
# ``model.get_mid_value`` reports.  Only the 'conv' entries have their
# kernel size and stride read by the probe loop.
layer_hyperparameter = {
    layer_id: {'type': kind, 'kernel_size': kernel, 'pad': pad, 'stride': 1, 'channel_in': nf}
    for layer_id, kind, kernel, pad in (
        (10, 'conv', 3, 0),
        (11, 'conv', 2, 0),
        (12, 'fc', 3, 1),
        (13, 'fc', 3, 1),
    )
}
model = basic_cnn_net_batchnorm_multheadandwei().to(device=device)

# Input gradients of the task-0 probe samples, stored after the first
# experience and compared against after every later one.
sample_grad = collections.defaultdict(list)
s_grad = []
# Similarity history; the first entry is fixed to 1.
grad_similarity = [1]

# Random positions of the task-0 probe samples.
indexs = numpy.random.choice(5000, size=percent, replace=False)
grade_list = []
class MyDataset(torch.utils.data.Dataset):
    """Expose an in-memory list of ``(pic, label, t)`` triples as a Dataset.

    On access the label is converted to a ``torch.long`` tensor; ``pic`` and
    ``t`` are handed back unchanged.
    """

    def __init__(self, data_list):
        super().__init__()
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        pic, raw_label, t = self.data_list[index]
        label_array = numpy.array(raw_label)
        return pic, torch.from_numpy(label_array).long(), t
# Materialise the task-0 probe samples selected by ``indexs`` and wrap them in
# a fixed-order, one-sample-per-batch loader.
first_task_dataset = benchmark.train_stream[0].dataset
grade_list.extend(first_task_dataset[n_n] for n_n in indexs)
data0_grade_dataset = MyDataset(grade_list)
data0_grade_loader = DataLoader(
    dataset=data0_grade_dataset,
    batch_size=1,
    shuffle=False,
)

optimizer = torch.optim.SGD(model.parameters(), lr=lr)
cl_strategy = avalanche.training.Naive(
    model,
    optimizer=optimizer,
    criterion=CrossEntropyLoss(),
    train_mb_size=BATCHSIZE,
    train_epochs=EPOCHES,
    eval_mb_size=BATCHSIZE,
    plugins=[gpm_plugin],
    device=device,
)

# TRAINING LOOP
print('Starting experiment...')
results = []
test_data = {}  # experience index -> test DataLoader
Pool = {}       # layer id -> stacked direction rows handed to the plugin
accuracies_all_model = []  # accuracy on each task right after learning it
accuracies_all_task_for_dataset = collections.defaultdict(list)
ACC_mean = []
series_average_robustness_acc = []
accuracies_all_task_for_advers_samples_dataset = collections.defaultdict(list)
# Main continual-learning loop: one iteration per experience in the
# benchmark's training stream.
for index, experience in enumerate(benchmark.train_stream):
    print("Start of experience: ", experience.current_experience)
    print("Current Classes: ", experience.classes_in_this_experience)
    model.train()
    # train returns a dictionary which contains all the metric values
    res = cl_strategy.train(experience)
    print('Training completed')
    print('Computing accuracy on the whole test set')
    # eval also returns a dictionary which contains all the metric values
    # results.append(cl_strategy.eval(benchmark.test_stream, num_workers=0))
    # Keep one shuffled, batch-size-1 test loader per experience seen so far;
    # the evaluation code further down iterates over all of them.
    test_loader = DataLoader(dataset=benchmark.test_stream[index].dataset, batch_size=1, shuffle=True)
    test_data[index] = test_loader
    # Per-layer accumulators that start as empty tensors and grow by
    # concatenation in the probe loop: L collects layer inputs, G collects
    # the entries of ``s_grad``.
    L = collections.defaultdict(torch.Tensor)
    G = collections.defaultdict(torch.Tensor)

    # Everything that follows in this iteration runs in evaluation mode.
    model.eval()
    # Draw a fresh random probe subset from the experience that was just
    # trained on.  The population size is taken from the dataset itself rather
    # than a hard-coded 5000, so experiences of any size are sampled correctly.
    indexs = numpy.random.choice(a=len(experience.dataset), size=percent, replace=False)
    grade_list = [experience.dataset[n_n] for n_n in indexs]
    # The module-level MyDataset is reused; re-declaring an identical class on
    # every iteration added nothing.
    grade_dataset = MyDataset(grade_list)
    grade_loader = DataLoader(dataset=grade_dataset,
                              batch_size=BATCHSIZE, shuffle=True)
    # Probe pass over the sampled subset of the current experience.  For every
    # batch it appends two per-layer matrices (moved to the CPU):
    #   * L[layer]: the layer input reported by ``model.get_mid_value``
    #     (passed through ``maps_reshape_strade`` for 'conv' layers);
    #   * G[layer]: the matching entry of ``s_grad`` built below.
    # NOTE(review): ``s_grad`` looks like a surrogate for the signal flowing
    # through the network, built from first-layer weights and ReLU masks -
    # confirm against the model definition.
    for sample, label, task_ids in grade_loader:
        sample = sample.to(device)
        label = label.to(device)
        task_ids = task_ids.to(device)
        optimizer.zero_grad()
        combs, logit, batch_recoder = model.get_mid_value(sample, task_ids)
        weight = model.c1[experience.current_experience].weight.data  # presumably [cout, cin, k, k]
        # Collapse every filter of the current task's first layer to a single
        # number per output channel.
        weiget_sum = torch.sum(weight, dim=[1, 2, 3], keepdim=False)
        weiget_sum = weiget_sum.unsqueeze(dim=1).unsqueeze(dim=2)
        # Tile that number over a 29x29 map and over the batch.  The batch
        # length is read from the data instead of BATCHSIZE so that a shorter
        # final batch still matches ``x3.shape`` in the reshape below.
        weight_extend = weiget_sum.repeat((1, 29, 29)).flatten()\
            .unsqueeze(dim=0).repeat((sample.shape[0], 1))
        print(weight_extend.shape)
        x3, x7, x10, x14 = batch_recoder
        # Each recorded tensor is binarised in place into a 0/1 mask
        # (positive entries -> 1, everything else -> 0).
        x3[x3 <= 0] = 0
        x3[x3 > 0] = 1
        after_relu = weight_extend.reshape(x3.shape)*(x3)
        grad_1 = model.pool(after_relu)
        grad_1_ = maps_reshape_strade(grad_1,filter_size=3, stride=1)
        grad_mid = model.c2(grad_1)
        x7[x7 <= 0] = 0
        x7[x7 > 0] = 1
        grad_2 = model.pool(grad_mid*x7)
        grad_2_ = maps_reshape_strade(grad_2, filter_size=2, stride=1)
        grad_mid__ = model.c3(grad_2)
        x10[x10 <= 0] = 0
        x10[x10 > 0] = 1
        grad_3 = model.floor(model.pool(grad_mid__*x10))
        grad_mid___ = model.l1(grad_3)
        x14[x14 <= 0] = 0
        x14[x14 > 0] = 1
        grad_4 = grad_mid___*x14

        # One entry per layer, in the iteration order of ``combs``.
        s_grad = [grad_1_, grad_2_, grad_3, grad_4]
        for iiii, (layer, many_) in enumerate(combs.items()):
            G[layer] = torch.cat([G[layer], s_grad[iiii].detach().cpu()], dim=0)
            # GPM
            temp = many_['input']
            if layer_hyperparameter[layer]['type'] == 'conv':
                temp = maps_reshape_strade(temp, filter_size=layer_hyperparameter[layer]['kernel_size'],
                                           stride=layer_hyperparameter[layer]['stride'])
            L[layer] = torch.cat([L[layer], temp.detach().cpu()], dim=0)
    # First pool update, driven by the collected layer inputs.  For each layer
    # the part of ``temp`` already captured by the pool is removed, the rest is
    # decomposed by SVD, and leading right-singular vectors are added until
    # pool + new directions reach the required share of ``whole``.
    for layer, temp in L.items():
        print(f'GPM temp.shape: {temp.shape}')
        whole = torch.sum(torch.abs(temp) ** 2)
        seen_before = layer in Pool
        if seen_before:
            pc = Pool[layer].cpu()  # rows are the stored directions
            temp_ = torch.matmul(torch.matmul(temp, pc.T), pc)
            pr = torch.square(torch.linalg.matrix_norm(temp_))
            temp2 = temp - temp_
            residual = temp2
        else:
            # Nothing stored yet for this layer: decompose the raw matrix.
            pr = 0
            residual = temp
        U, S, Vh = torch.linalg.svd(residual, full_matrices=True)
        V = Vh.T
        re = 0
        temp_re = torch.zeros_like(residual)
        # The required share grows by 0.0025 with every experience.
        energy_target = whole * (a + 0.0025 * experience.current_experience)
        for k1 in range(S.shape[0]):
            u, s, v = U[:, k1:k1 + 1], S[k1], V[:, k1:k1 + 1]
            temp_re += s * torch.matmul(u, v.T)
            re = torch.square(torch.linalg.matrix_norm(temp_re))
            if re + pr >= energy_target:
                print(k1)
                break
        new_pc = V[:, :k1 + 1].T
        merged = torch.cat([pc, new_pc], dim=0) if seen_before else new_pc
        Pool[layer] = merged.to(device)
    print(
        f'current task {experience.current_experience} is over, number of directions in pool is {[pc.shape[0] for pc in Pool.values()]} ')
    # Second pool update, driven by the matrices collected in G.  Every layer
    # already has a pool entry from the previous step, so the residual after
    # projection onto the pool is always what gets decomposed.
    for layer, temp in G.items():
        print(f'Robust input_grad.shape: {temp.shape}')
        pc = Pool[layer].cpu()  # rows are the stored directions
        # Squared Frobenius norm of the full matrix (the earlier, immediately
        # overwritten sum-of-squares computation was a dead store).
        whole = torch.square(torch.linalg.matrix_norm(temp))
        temp_ = torch.matmul(torch.matmul(temp, pc.T), pc)  # part captured by the pool
        pr = torch.square(torch.linalg.matrix_norm(temp_))
        temp2 = temp - temp_
        U, S, Vh = torch.linalg.svd(temp2, full_matrices=True)
        V = Vh.T
        re = 0
        temp_re = torch.zeros_like(temp2)
        # Add singular directions until pool + new directions reach the
        # b[layer] share of ``whole``.
        for k2 in range(S.shape[0]):
            u, s, v = U[:, k2:k2 + 1], S[k2], V[:, k2:k2 + 1]
            temp_re += s * torch.matmul(u, v.T)
            re = torch.square(torch.linalg.matrix_norm(temp_re))
            if re + pr >= whole * b[layer]:
                print(k2)
                break
        new_pc = V[:, :k2 + 1].T
        Pool[layer] = torch.cat([pc, new_pc], dim=0).to(device)
    print(
        f'current task {experience.current_experience} is over, number of directions in pool is {[pc.shape[0] for pc in Pool.values()]} ')
    # Compress every layer's pool: decompose the stacked directions and keep
    # only the leading right-singular vectors needed to reach the cc[layer]
    # share of the pool's squared Frobenius norm.
    for layer_id, basis in Pool.items():
        basis = basis.cpu()
        whole_ = torch.sum(torch.abs(basis) ** 2)
        U, S, Vh = torch.linalg.svd(basis, full_matrices=True)
        V = Vh.T
        temp_re = torch.zeros_like(basis)
        keep_energy = whole_ * cc[layer_id]
        for k_K in range(S.shape[0]):
            u, s, v = U[:, k_K:k_K + 1], S[k_K], V[:, k_K:k_K + 1]
            temp_re += s * torch.matmul(u, v.T)
            re = torch.square(torch.linalg.matrix_norm(temp_re))
            if re >= keep_energy:
                print(k_K)
                break
        # Cap the result at (feature dimension - 1) directions.
        k_K = min(k_K, basis.shape[1] - 2)
        Pool[layer_id] = V[:, :k_K + 1].T.to(device)
    print(f'current task {experience.current_experience} is over, number of directions in pool is {[pc.shape[0] for pc in Pool.values()]} ')
    # Hand the refreshed pool to the training plugin.
    gpm_plugin.renew_pool(Pool=Pool)

    # Accuracy on the test split of the experience that was just learned.
    correct_numbers = 0
    for x, y, t in test_loader:
        x = x.to(device)
        y = y.long().to(device)
        t = t.to(device)
        with torch.no_grad():
            logits = model(x, t)
        predictions = logits.argmax(dim=1)
        correct_numbers += torch.eq(predictions, y).sum().float().item()
    accuracy = correct_numbers / len(test_loader.dataset)
    accuracies_all_model.append(accuracy)
    current_task = experience.current_experience
    print("accuracy for data {} in task{} is :{}".format(current_task, current_task, accuracy))

    # Accuracy of the current model on every experience seen so far
    # (including the one above); ACC is rebuilt from scratch each time.
    ACC = []
    for q, seen_loader in test_data.items():
        correct_numbers = 0
        for x, y, t in seen_loader:
            x = x.to(device)
            y = y.long().to(device)
            t = t.to(device)
            with torch.no_grad():
                logits = model(x, t)
            predictions = logits.argmax(dim=1)
            correct_numbers += torch.eq(predictions, y).sum().float().item()
        accuracy = correct_numbers / len(seen_loader.dataset)
        ACC.append(accuracy)
        accuracies_all_task_for_dataset[q].append(accuracy)
        print("accuracy for data {} in task{} is :{}".format(q, current_task, accuracy))
    acc_mean = numpy.mean(ACC)
    ACC_mean.append(acc_mean)
    print("ACC_mean:{}".format(ACC_mean))
    # Accuracy of the current model on adversarial versions of every test set
    # seen so far; ``ss`` is rebuilt from scratch for each experience.
    ss = []
    for p, seen_loader in test_data.items():
        advers_samples = make_advers_samples_normal_model_2(
            model=model,
            test_dataloader=seen_loader,
            device=device,
            epsilon=epsilon,
        )
        correct_numbers = 0
        for x, y, t in advers_samples:
            x = x.to(device)
            y = y.to(device)
            t = t.to(device)
            with torch.no_grad():
                logits_adver = model(x, t)
            predictions = logits_adver.argmax(dim=1)
            correct_numbers += torch.eq(predictions, y).sum().float().item()
        # NOTE(review): the denominator counts entries of ``advers_samples``;
        # that equals the sample count only if each entry holds one sample -
        # confirm against make_advers_samples_normal_model_2.
        accuracy = correct_numbers / len(advers_samples)
        ss.append(accuracy)
        accuracies_all_task_for_advers_samples_dataset[p].append(accuracy)
        print("accuracy for adver_samples of data{} in task{} is :{}".format(
            p, experience.current_experience, accuracy))
    series_average_robustness_acc.append(numpy.mean(ss))

    # Input-gradient drift on the fixed task-0 probe set: after the first
    # experience the per-sample loss gradients w.r.t. the input are stored;
    # after every later experience their cosine similarity to the stored
    # gradients is averaged and appended to ``grad_similarity``.
    s_grad_similarity = []
    for ii, (samples, targets, task_ids) in enumerate(data0_grade_loader):
        samples = samples.to(device=device)
        targets = targets.to(device=device)
        task_ids = task_ids.to(device=device)
        samples.requires_grad = True
        cl_strategy.model.zero_grad()
        logits = model(samples, task_ids)
        loss = torch.nn.functional.cross_entropy(logits, targets)
        # Flattened gradient of the loss with respect to the input sample.
        samples_grad = torch.autograd.grad(loss, samples)[0].flatten()
        if experience.current_experience == 0:
            sample_grad[ii] = samples_grad
        if experience.current_experience > 0:
            origin_sample_grad = sample_grad[ii]
            print(origin_sample_grad.shape, samples_grad.shape)
            dot_product = torch.dot(samples_grad, origin_sample_grad)
            norm_product = (torch.sqrt(torch.sum(samples_grad ** 2))
                            * torch.sqrt(torch.sum(origin_sample_grad ** 2)))
            s_grad_similarity.append(dot_product / norm_product)
    if experience.current_experience > 0:
        mean_similarity = torch.mean(torch.Tensor(s_grad_similarity))
        grad_similarity.append(mean_similarity.item())
    print('grad_similarity:{}'.format(grad_similarity))
    print("accuracy_for_normal_sample：{}".format(accuracies_all_task_for_dataset))
    print("accuracy_for_advers_sample：{}".format(accuracies_all_task_for_advers_samples_dataset))
    print("accuracy_for_models_themseves：{}".format(accuracies_all_model))
# Backward transfer on clean data: mean final accuracy on all tasks but the
# last, minus the mean accuracy those tasks had right after being learned.
bwt = (numpy.sum(ACC) - ACC[-1]) / (len(ACC) - 1) - (numpy.sum(accuracies_all_model) - accuracies_all_model[-1]) / (
        len(accuracies_all_model) - 1)
# Same quantity on adversarial samples.  The number of "earlier" tasks is
# derived from the results instead of the hard-coded 9, so the formula stays
# correct when the benchmark is built with a different number of experiences.
bwt_robustness = (numpy.sum(ss) - ss[-1]) / (len(ss) - 1) - numpy.mean(
    [accuracies_all_task_for_advers_samples_dataset[zz][0] for zz in range(len(ss) - 1)])
print("ACC:{},BWT:{}".format(ACC_mean[-1], bwt))
print("ACC_MEAN_CURVE:{}".format(ACC_mean))
print("ACC_MEAN_CURVE_robustness:{}".format(series_average_robustness_acc))
# NOTE: this line reports the adversarial (robustness) BWT; the label is kept
# unchanged for compatibility with existing logs.
print("BWT:{}".format(bwt_robustness))
print('grad_similarity:{}'.format(grad_similarity))
