import torch
from torch.functional import norm
from torch.nn.modules import linear
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from math import log, tanh
from numpy.core.fromnumeric import size
import numpy as np
from torch import autograd
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.optim import optimizer
from numpy import float32, pi
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.autograd import Variable, forward_ad
from torchvision import transforms
import ge
from scipy import optimize
from scipy import linalg
import torch
from torch.functional import norm
from torch.nn.modules import linear
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from math import log, tanh
from numpy.core.fromnumeric import size
import numpy as np
from torch import autograd
import torch.nn as nn
from torch.nn import Sequential, ReLU, Tanh, Sigmoid
import matplotlib.pyplot as plt
from torch.optim import optimizer
from numpy import float32, pi
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.autograd import Variable, forward_ad
from torchvision import transforms
import ge
from scipy import optimize
from scipy import linalg
from resnet18 import resnet18
import torchvision.models as models
from densenet import densenet121
from densenet import densenet201
from myvgg import vgg16
import random
from resnet50 import resnet50
def sampling_by_category(data, num=3):
    """Randomly pick up to ``num`` samples per class from ``data``.

    Samples are visited in a random order (``random.shuffle``). Each sample is
    kept while its class still has fewer than ``num`` picks.

    Args:
        data: indexable dataset whose items are ``(img, label_idx)`` pairs. If it
            exposes ``class_to_idx`` (class name -> integer label), that mapping
            is used. Otherwise the module-level ``train_dataset.class_to_idx`` is
            used, which is what this function always did before.
        num: maximum number of samples kept per class.

    Returns:
        A flat list of the selected ``img`` objects, in the order they were picked.
    """
    class_to_idx = getattr(data, "class_to_idx", None)
    if class_to_idx is None:
        # Backward-compatible fallback: the original always read the global training set.
        class_to_idx = train_dataset.class_to_idx
    # Invert the mapping explicitly. list(class_to_idx.keys()) is only correct when
    # the labels are exactly 0..n-1 in insertion order.
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}
    image_wall = {name: [] for name in class_to_idx}
    indices = list(range(len(data)))
    random.shuffle(indices)
    all_image = []
    # Stop as soon as every class is full instead of loading the whole dataset.
    remaining = len(image_wall) * num
    for i in indices:
        if remaining <= 0:
            break
        img, label_idx = data[i]
        label = idx_to_class[int(label_idx)]
        if len(image_wall[label]) < num:
            image_wall[label].append(img)
            all_image.append(img)
            remaining -= 1
    return all_image
def cl1(a, dataset, bs, c_0):
    """Apply "optimum shifting" to the rows of ``a.l1.weight`` using one random batch.

    The function draws a single shuffled batch of ``bs`` samples from
    ``dataset`` and runs a forward pass of ``a``. Then, for every output unit
    ``j`` of ``a.l1``, it:

    * builds the linear system ``cin @ (w - c_0) = cout[:, j] - bias[j]`` from
      the batch;
    * row-reduces it with ``ge.Gaussain_elimination``;
    * if the current row still satisfies the reduced system to within 1e-3,
      overwrites the row in place with
      ``w = c_0 + A.T @ inv(A @ A.T) @ b``.

    Here ``A``/``b`` are the first ``rk`` reduced rows. When ``A`` has full
    row rank, this is the minimum-norm solution of ``A (w - c_0) = b``.

    Only the first batch is processed: the function returns from inside the
    loop.

    Args:
        a: model exposing ``l1`` (with ``.weight`` and ``.bias``) plus ``cin``
            and ``cout`` attributes that are read after the forward pass.
            NOTE(review): presumably ``forward`` caches the input of ``l1`` in
            ``a.cin`` and its output in ``a.cout``. Confirm against the model
            definition.
        dataset: dataset yielding ``(data, label)`` pairs.
        bs: batch size, i.e. the number of equations per unit.
        c_0: reference point the shift is measured from. It must broadcast
            against a column of shape ``(in_features, 1)``. The visible caller
            passes ``0``.

    Returns:
        None.

    Raises:
        AssertionError: if ``cin @ (w_j - c_0)`` does not reproduce
            ``cout[:, j] - bias[j]`` to within 1e-8.
    """
    #bj = 0
    dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)
    for data, label in dataloader:
        print("start optimum shifting")
        #all_image = sampling_by_category(dataset)
        #label = np.array( list(image_wall.keys())).astype(float)
        #data = torch.stack(list(all_image), 0)
        #print(torch.tensor(data).shape)
        #print(torch.tensor(label).shape)
        #print("hhh")
        if torch.cuda.is_available():
            data = data.cuda()
            label = label.cuda()
        # This output is not used directly. The forward pass is what makes
        # a.cin / a.cout available below (see NOTE(review) in the docstring).
        ans = a(data)
        #print("end")
        #print("Loss_before:",F.cross_entropy(ans, label))
        print("norm_before:", torch.norm(a.l1.weight))
        # One independent least-norm problem per output unit (row) of l1.
        for j in range(len(a.l1.weight)):
            # Current weight row of unit j as an (in_features, 1) column.
            c = a.l1.weight[j].cpu().detach().numpy().flatten()
            c = c.reshape(len(c),1)
            # Coefficient matrix: the cached inputs of l1 for this batch.
            tmp = a.cin.cpu().detach().numpy()
            # Right-hand side: unit j's cached output minus its bias, as a column.
            tmp1 = a.cout[:,j].cpu().detach().numpy().reshape((a.cout[:,j].shape[0], 1)) - a.l1.bias[j].cpu().detach().numpy() #- tmp@np.array([c_0]*200).reshape(200,1)
            # Augmented matrix [cin | rhs], limited to the first bs rows.
            mat = np.concatenate([tmp, tmp1], axis=1)
            mat = mat[0:bs,:]
            # Sanity check that the cached tensors are consistent with the weights.
            # NOTE(review): the rhs no longer subtracts tmp @ c_0 (see the commented
            # expression above), so this only holds when cin @ c_0 ~= 0, e.g. c_0 = 0.
            # The 1e-8 absolute tolerance is very tight for float32 activations.
            # TODO confirm it does not trip.
            assert(np.linalg.norm(tmp@(c-c_0) - tmp1)<1e-8)
            # ans: row-reduced augmented matrix. r: second return value, unused here.
            # NOTE(review): exact semantics of ge.Gaussain_elimination not visible here.
            ans, r = ge.Gaussain_elimination(mat)
            
            # Scan upward from the last row for the first row whose rhs entry is
            # non-negligible. rk ends as that row's index, or -1 if there is none.
            # NOTE(review): the slices below use rows 0:rk, which EXCLUDE row rk itself.
            # This looks like a possible off-by-one. The rank is also judged from the
            # rhs column rather than from the coefficient part. Verify intent.
            rk = bs - 1
            for xx in range(bs):
                if abs(ans[rk,-1]) > 1e-5:
                    break
                else:
                    rk = rk - 1
            # Disabled one-shot debug output, kept as a no-op string literal.
            """ if bj == 0:
                print("error before elimination:", np.linalg.norm(tmp@(c-c_0) - tmp1))
                print("error after elimination:", np.linalg.norm(ans[:,0:200]@(c-c_0) - ans[:, -1].reshape(bs,1)))
                print(rk)
                bj = 1 """
            
            # Skip this unit if the current row does not satisfy the reduced system
            # (elimination lost too much accuracy). Note: reshape(bs,1) assumes the
            # batch really has bs rows.
            if (np.linalg.norm(ans[:,0:len(a.l1.weight[0])]@(c-c_0) - ans[:, -1].reshape(bs,1))>1e-3):
                #print(np.linalg.norm(ans[:,0:len(a.l1.weight[0])]@(c-c_0) - ans[:, -1].reshape(bs,1)))
                continue
            #ans[np.where(np.abs(ans)<1e-10)] = 0
            # Overwrite row j with c_0 + A^T (A A^T)^{-1} b, with A = ans[0:rk, :in_features]
            # and b = ans[0:rk, -1]. This is done in place, outside autograd. The two
            # branches differ only in moving the result to the GPU.
            with torch.no_grad():
                if torch.cuda.is_available():
                    a.l1.weight[j] = (torch.tensor(c_0+ans[0:rk,0:len(a.l1.weight[0])].T@linalg.inv(ans[0:rk,0:len(a.l1.weight[0])]@ans[0:rk,0:len(a.l1.weight[0])].T)@(ans[0:rk, -1]).reshape((rk,1))))[:,0].cuda()
                else:
                    a.l1.weight[j] = (torch.tensor(c_0+ans[0:rk,0:len(a.l1.weight[0])].T@linalg.inv(ans[0:rk,0:len(a.l1.weight[0])]@ans[0:rk,0:len(a.l1.weight[0])].T)@(ans[0:rk, -1]).reshape((rk,1))))[:,0]
        # Second forward pass with the shifted weights. Its output is only used by
        # the commented-out loss print.
        ans = a(data)
        #print("Loss_after:", F.cross_entropy(ans, label))
        print("norm_after:", torch.norm(a.l1.weight))
        # Deliberately stop after the first batch.
        return
def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Returns mixed inputs, pairs of targets, and lambda.

    Each sample of ``x`` is blended with a randomly chosen partner from the same
    batch: ``mixed_x = lam * x + (1 - lam) * x[index]``.

    Args:
        x: batch tensor; the first dimension indexes samples (any rank >= 1).
        y: targets aligned with ``x`` along the first dimension, on the same device.
        alpha: parameter of the Beta(alpha, alpha) distribution ``lam`` is drawn
            from. ``alpha <= 0`` disables mixing (``lam = 1``).
        use_cuda: kept for backward compatibility. The permutation is now always
            created on ``x.device``, so it matches the data whatever this flag says.

    Returns:
        ``(mixed_x, y_a, y_b, lam)``, where ``y_a`` is ``y`` and ``y_b`` is ``y``
        permuted the same way as the mixing partners.
    '''
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0

    batch_size = x.size(0)
    # Build the permutation on the batch's own device. The old unconditional .cuda()
    # crashed on CPU-only machines because use_cuda defaults to True.
    index = torch.randperm(batch_size, device=x.device)

    # x[index] (not x[index, :]) so 1-D inputs work too.
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Mixup loss: the ``lam``-weighted blend of ``criterion`` against both target sets.

    Returns ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    loss_first = criterion(pred, y_a)
    loss_second = criterion(pred, y_b)
    weight_second = 1 - lam
    return lam * loss_first + weight_second * loss_second


# Initialise the project's ResNet-50 from torchvision's pretrained ResNet-50.
# Only parameters whose state-dict keys match are copied.
torch18 = models.resnet50(pretrained=True)  # NOTE: despite the name, this is a ResNet-50
a = resnet50()
# strict=False skips keys that exist in only one of the two models (a shape mismatch
# on a shared key still raises). Report what was skipped so a naming mismatch
# between the two architectures does not go unnoticed.
_load_report = a.load_state_dict(torch18.state_dict(), strict = False)
print("pretrained weights - missing keys:", _load_report.missing_keys)
print("pretrained weights - unexpected keys:", _load_report.unexpected_keys)

# The optimizer and LR scheduler used for training are created later, right before
# the training loop. An earlier duplicate pair built here was always overwritten
# before use and has been dropped.

if torch.cuda.is_available():
    a = a.cuda()
# Per-channel (mean, std) normalization shared by the CIFAR-100 train and test
# pipelines. Alternative statistics previously kept commented out were
# mean (0.4914, 0.4822, 0.4465), std (0.2023, 0.1994, 0.2010).
_normalize = transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))

# Training pipeline: tensor conversion, random horizontal flip, normalization.
data_transfrom = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    _normalize,
])
# Test pipeline: the same normalization, without augmentation.
data_transform2 = transforms.Compose([transforms.ToTensor(), _normalize])

# recx / recy collect epoch indices and training losses in the loop below;
# recy1 is not used in the visible code.
recx, recy, recy1 = [], [], []

_cifar_root = '/root/autodl-tmp/CIFAR100'
train_dataset = datasets.CIFAR100(root=_cifar_root, train=True, download=True, transform=data_transfrom)
test_data = datasets.CIFAR100(root=_cifar_root, train=False, download=True, transform=data_transform2)

bs = 64        # training mini-batch size
max_acc1 = 0   # best test-set correct-prediction count seen so far (checkpoint threshold)
dataloader = DataLoader(train_dataset, batch_size=bs, shuffle=True, num_workers=8)
t_d = DataLoader(test_data, batch_size=100, shuffle=True, num_workers=8)

# Loss accumulator and batch counter (both are reset at the start of every epoch).
losst, cnt = 0, 0
# Bookkeeping lists that are not read anywhere in the visible code.
a1_w2, a2_w2 = [], []
a1_w1, a2_w1 = [], []
a1_tl, a2_tl = [], []
print("start training")

# Optimizer and step-decay schedule that drive the training loop. sch.step() is
# called once per epoch, so the LR is multiplied by 0.1 at epochs 150 and 225.
opt = torch.optim.SGD(a.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
sch = torch.optim.lr_scheduler.MultiStepLR(optimizer=opt, milestones=[150, 225], gamma=0.1)

# Commented-out checkpoint loading, kept for reference:
# a.load_state_dict(torch.load("/root/autodl-tmp/model3/a1"+str(2*96+1)+"paras.pth"))
# a1.load_state_dict(torch.load("/root/autodl-tmp/model3/a2"+str(2*96+1)+"paras.pth"))
# a = a.cuda()
# a1 = a1.cuda()
acc_1total, acc_2total = [], []
import os

# Main loop: optimum shifting on one batch, one mixup training epoch, then
# evaluation on the test set with best-checkpoint saving.
for e in range(300):
    # cl1 has always run with the model in training mode. Restore that mode,
    # because the evaluation at the end of the previous epoch switches to eval().
    a.train()
    cl1(a, train_dataset, bs = 300, c_0 = 0)
    losst = 0
    cnt = 0
    for img1, label in dataloader:
        if torch.cuda.is_available():
            img1 = img1.cuda()
            label = label.cuda()
        # Pass use_cuda explicitly so the mixing permutation lives on the batch's device.
        inputs, targets_a, targets_b, lam = mixup_data(img1, label, use_cuda=torch.cuda.is_available())
        #with shift
        outputs = a(inputs)
        loss = mixup_criterion(nn.CrossEntropyLoss(), outputs, targets_a, targets_b, lam)
        loss.backward()
        opt.step()
        opt.zero_grad()
        # Mean mini-batch loss over the epoch. This used a hard-coded 50000 / bs.
        losst = losst + loss.item() / len(dataloader)
        cnt = cnt + 1
    recx.append(e)
    recy.append(losst)
    # Not used in the visible code; kept so any later code that reads them still works.
    test_loss = 0
    test_loss_2 = 0
    print("epoch: ", e)
    sch.step()
    print("a1 training loss:", losst)
    acca1 = 0
    acca2 = 0
    # Evaluate in inference mode. BatchNorm then uses (and does not update) its running
    # statistics, and no autograd graph is built for the test batches.
    a.eval()
    with torch.no_grad():
        for img, label in t_d:
            if torch.cuda.is_available():
                img = img.cuda()
            ans = a(img)
            pred = torch.argmax(ans, dim = 1).cpu()
            acca1 = acca1 + int((pred == label).sum())
    # NOTE: despite the label, acca1 is the number of correct predictions on the TEST set.
    print("a1 training acc:", acca1)
    print(torch.norm(a.l1.weight).item(), acca1)
    if acca1 > max_acc1:
        max_acc1 = acca1
        # Make sure the checkpoint directory exists; torch.save does not create it.
        os.makedirs("./resnet18", exist_ok=True)
        torch.save(a.state_dict(), "./resnet18/a1paras.pth")