import torch
import torchvision
import torchvision.transforms as transforms
import tqdm
import numpy as np
from torch.autograd import Variable
import torch.nn as nn
import random
from lib.util.mytoolbag import cal_para, get_gradient_tensor, multi_tensor_gra
from lib.dataset.myset import MyDataset
from torch.utils.data import DataLoader
from lib.model.mnistnet import MNISTNet
import torch.optim as optim

# Convert images to tensors, then normalize with the values commonly quoted
# as the MNIST pixel mean / std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])


# Training split; fetched into ./data on first use.
train_set = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Evaluation split. download=False: the files are assumed to already be in
# ./data (presumably fetched by the train-split download above).
testset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=False,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=1000,
    shuffle=False,
    num_workers=1,
)


class Pic:
    """A training-image index paired with the ``ntk`` value used to rank it."""

    def __init__(self, _id: int, ntk: float):
        # ``id`` is used by main() to index the training set; ``ntk`` is the
        # sort key applied in get_rnk().
        self.id, self.ntk = _id, ntk


# Loss used by the training loop; 'mean' is CrossEntropyLoss's default
# reduction, spelled out for clarity.
criterion = nn.CrossEntropyLoss(reduction='mean')


# Training mini-batch size; main() also samples SIZE // 10 images per class.
SIZE = 50


def setup_seed(seed):
    """Seed every RNG the script uses and request deterministic cuDNN kernels.

    Seeds, in this order: torch (CPU), torch (all CUDA devices), NumPy's
    global generator, and Python's ``random`` module.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True


def cal_ntk_c(now_set: MyDataset, net, optimizer, tot_para, num_classes=10):
    """Return the smallest eigenvalue of the gradient Gram matrix on ``now_set``.

    Each sample is fed through ``net`` on its own (batch size 1, on CUDA). For
    each of the first ``num_classes`` output logits, ``get_gradient_tensor`` is
    called to obtain a vector. NOTE(review): presumably this is the gradient of
    that logit w.r.t. all parameters, flattened to length ``tot_para``; in that
    case the Gram matrix built here is the empirical NTK. Confirm in
    lib.util.mytoolbag.

    The matrix of pairwise ``dot`` products is symmetric, so it is filled once
    per pair and its eigenvalues come from ``np.linalg.eigvalsh``. The minimum
    eigenvalue is printed and returned.

    Args:
        now_set: dataset yielding ``(input, label)`` pairs; labels are unused.
        net: model to probe; its output is indexed as ``outputs[0][k]``.
        optimizer: forwarded to ``get_gradient_tensor``.
        tot_para: forwarded to ``get_gradient_tensor`` as ``tot``.
        num_classes: how many output logits to take per sample (default 10,
            the original hard-coded value).

    Raises:
        ValueError: if ``now_set`` is empty (no vectors to form a matrix).
    """
    ntk_loader = DataLoader(now_set, batch_size=1)
    vectors = []
    for inputs, _labels in ntk_loader:
        outputs = net(inputs.cuda())
        for k in range(num_classes):
            vectors.append(get_gradient_tensor(net, optimizer, outputs[0][k], tot=tot_para))

    if not vectors:
        raise ValueError('cal_ntk_c: empty dataset, cannot build the Gram matrix')

    siz = len(vectors)
    mat = np.zeros((siz, siz), dtype=float)
    for i in range(siz):
        for j in range(i, siz):
            # Symmetric: compute each pair once and mirror it.
            mat[i][j] = mat[j][i] = vectors[i].dot(vectors[j])
    czz = np.min(np.linalg.eigvalsh(mat))
    print(czz)
    return czz


def _test_accuracy(net):
    """Return ``net``'s accuracy over the whole of ``testloader`` as a percentage (0-100).

    Runs under ``torch.no_grad()`` with the module in eval mode. The caller's
    train/eval mode is restored afterwards, even on error.
    """
    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.cuda(), labels.cuda()
                predicted = torch.max(net(images), 1)[1]
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
    finally:
        net.train(was_training)
    return 100.0 * correct / total if total else 0.0


def train_net(now_set: MyDataset, net, optimizer, max_epochs=None):
    """Train ``net`` on ``now_set`` until test accuracy stops improving.

    Each epoch makes one pass over ``now_set`` in mini-batches of ``SIZE``,
    using ``criterion``. It then evaluates accuracy (percent) on the full
    test set. Training stops once more than 10 epochs have run and the summed
    accuracy of the last 5 epochs does not exceed that of the 5 before. It
    also stops when ``max_epochs`` (if given) is reached.

    Args:
        now_set: training dataset yielding ``(input, label)`` batches.
        net: model (moved to CUDA by the caller).
        optimizer: optimizer over ``net``'s parameters.
        max_epochs: optional hard cap on epochs; ``None`` keeps the original
            open-ended behavior.

    Returns:
        ``(accl, epoch)``: per-epoch test accuracies and the number of epochs run.
    """
    train_loader = DataLoader(now_set, batch_size=SIZE, num_workers=1)
    accl = []
    epoch = 0
    acc = 0.0
    while True:
        epoch += 1
        for inputs, labels in train_loader:
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Score the whole test set; previously only the last batch was kept.
        acc = _test_accuracy(net)
        accl.append(acc)

        plateaued = epoch > 10 and sum(accl[-5:]) <= sum(accl[-10:-5])
        capped = max_epochs is not None and epoch >= max_epochs
        if plateaued or capped:
            print('')
            break
    print(' epoch : %d  ' % epoch, end='')
    print('acc : %.1f ' % acc, end='')
    return accl, epoch


def round1(now_set: MyDataset, ids=None, path='./record.txt'):
    """Run one experiment on ``now_set`` and append its results to ``path``.

    The run seeds all RNGs with 20 and builds a fresh ``MNISTNet`` on CUDA with
    Adam (lr=0.001). It computes the minimum eigenvalue from ``cal_ntk_c`` and
    then trains with ``train_net``. Three lines are appended to ``path``:

    1. ``"<min eigenvalue> <epochs>"``
    2. the per-epoch test accuracies, each followed by a space
    3. the entries of ``ids``, each followed by a space (empty if ``ids`` is None)
    """
    print('loaded')
    setup_seed(20)
    net = MNISTNet().cuda()
    tot_para = cal_para(net)
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    czz = cal_ntk_c(now_set, net, optimizer, tot_para)

    # train_net returns per-epoch accuracies, not losses.
    accs, epoch = train_net(now_set, net, optimizer)
    id_list = () if ids is None else ids  # default ids=None previously crashed
    with open(path, 'a', encoding='utf-8') as f:
        f.write(str(czz) + ' ' + str(epoch) + '\n')
        f.write(''.join(str(float(a)) + ' ' for a in accs) + '\n')
        f.write(''.join(str(i) + ' ' for i in id_list) + '\n')


def get_rnk(ty, path='../../logs/record_mnist.txt'):
    """Load the records whose third column equals ``ty``, sorted by ascending ``ntk``.

    Each non-blank line of ``path`` must have exactly three whitespace-separated
    fields ``a b t``. ``a`` becomes ``Pic.id`` (int), ``b`` becomes ``Pic.ntk``
    (float), and ``t`` (int) is compared with ``ty``. The third column is
    presumably the class label, given that main() calls this once per class
    0..9.

    Args:
        ty: value of the third column to keep.
        path: record file to read (defaults to the original hard-coded path).

    Returns:
        list of ``Pic`` sorted by ``ntk`` ascending.

    Raises:
        ValueError: on a non-blank line that does not have three fields or
            whose fields do not parse.
    """
    prl = []
    # ``with`` closes the file; the old readline loop leaked the handle.
    with open(path) as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines (e.g. a trailing newline)
            a, b, t = fields
            if int(t) == ty:
                prl.append(Pic(int(a), float(b)))
    prl.sort(key=lambda pic: pic.ntk)
    return prl


def _pick_subset(prl, inp, lab):
    """Draw ``SIZE // 10`` distinct high-``ntk`` training images from each of 10 classes.

    ``prl[t]`` is the list from ``get_rnk(t)``, sorted by ascending ``ntk``, so
    the item of rank ``r`` counted from the top is ``prl[t][-(r + 1)]``. Each
    slot first draws from the top 16 ranks. On a duplicate it retries from the
    top 51, matching the original sampling ranges. Ranges are clamped to the
    list length.

    Returns:
        ``(inn, lan, ids)``: the sampled inputs, their labels, and their
        indices into ``inp`` / ``lab``.

    Raises:
        ValueError: if some class has fewer than ``SIZE // 10`` records.
    """
    per_class = SIZE // 10
    inn, lan, ids = [], [], []
    chosen = set()
    for t in range(10):
        ranked = prl[t]
        if len(ranked) < per_class:
            raise ValueError('class %d has only %d ranked records, need %d'
                             % (t, len(ranked), per_class))
        top_first = min(15, len(ranked) - 1)
        top_retry = min(50, len(ranked) - 1)
        for _ in range(per_class):
            # -(r + 1) so that r == 0 is the highest-ntk item. The old -r
            # turned 0 into index 0, the *lowest*-ntk item.
            pid = ranked[-(random.randint(0, top_first) + 1)].id
            # Deduplicate on dataset ids. The old check compared a rank index
            # against ids, so duplicates could slip through.
            while pid in chosen:
                pid = ranked[-(random.randint(0, top_retry) + 1)].id
            chosen.add(pid)
            ids.append(pid)
            inn.append(inp[pid])
            lan.append(lab[pid])
    return inn, lan, ids


def main(rounds=100, path='logs/acc_4.txt'):
    """Repeatedly train on small high-``ntk`` subsets of MNIST and log the results.

    The ranked records for classes 0..9 are loaded (printing each count), and
    the full transformed training set is materialized. Then ``rounds`` times:
    a subset is sampled with ``_pick_subset``, ``round1`` is run on it with
    results appended to ``path``, and the wall-clock time of the round is
    printed.
    """
    import time

    prl = [get_rnk(i) for i in range(10)]
    for ranked in prl:
        print(len(ranked))

    inp, lab = [], []
    for data_i, data_l in train_set:
        inp.append(data_i)
        lab.append(data_l)

    for _ in range(rounds):
        bg = time.time()
        # round1() calls setup_seed(20), which reseeds ``random``. Reseeding
        # from the clock here keeps each round's subset different.
        random.seed(bg)
        inn, lan, ids = _pick_subset(prl, inp, lab)
        round1(MyDataset(inn, lan), ids, path=path)
        print(time.time() - bg)


if __name__ == '__main__':
    main()
