import torch
import time
import tqdm
import numpy as np
from torch.autograd import Variable
import torch.nn as nn
import random
from lib.util.mytoolbag import cal_para, get_gradient_tensor, multi_tensor_gra
from lib.dataset.mydata import CifarData
from torch.utils.data import DataLoader
from lib.model.cifarnet import Net
import torch.optim as optim
from lib.util.mytoolbag import setup_seed


# Cross-entropy loss shared at module level; train_net hands it to
# mixup_criterion for every training batch.
criterion = nn.CrossEntropyLoss().cuda()


class Pic:
    """Pair a sample's position in the batch with its ranking score.

    Instances are sorted by ``ntk`` so that ``id`` can be read back as the
    original batch position of each sample in ranked order.
    """

    def __init__(self, _id, ntk):
        # `id` is the sample's batch position, `ntk` the value it is ranked by.
        self.id, self.ntk = _id, ntk


def mix_up_f(inp, lab, rnk=None, lam=None, method=None, idd=1):
    """Mix every sample with a partner picked by rank distance and label.

    The batch is ordered by ``rnk`` (stable sort, ascending) and treated as a
    ring.  For sample ``i`` the first candidate sits ``offset`` places further
    along that ring, where ``offset`` is

    * ``1`` for ``method='near'``,
    * ``idd`` for ``method='near_rk'``,
    * a fresh ``random.randint(1, 50)`` per sample for ``method='near_r'``,
    * ``batch_size // 2`` for anything else (e.g. ``'far'``).

    With probability 9/10 the partner must carry a *different* label than
    sample ``i``; otherwise it must carry the *same* label.  Starting from the
    first candidate the ring is walked forward one place at a time until a
    sample satisfying that requirement is found.

    Args:
        inp: Batch tensor; the first dimension is the batch dimension and it
            must have at least two dimensions.
        lab: Labels, one comparable entry per sample.
        rnk: Tensor with one ranking score per sample (required).
        lam: Mixing coefficient.  ``None`` draws one from ``Beta(1, 1)``.
            An explicit ``0`` is honoured (the previous ``if not lam`` check
            silently replaced it with a random draw).
        method: Partner-selection strategy, see above.
        idd: Rank offset used by ``method='near_rk'``.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``lam`` has been folded into
        ``[0.5, 1]`` so the original sample always dominates the mix,
        ``y_a`` is ``lab`` and ``y_b`` holds the partners' labels.

    Raises:
        ValueError: if ``rnk`` is missing or its length differs from the
            batch size.

    Notes:
        * The partner index is created on ``inp``'s device instead of being
          forced onto CUDA, so CPU batches (as produced by a ``DataLoader``)
          can be indexed.
        * If no sample in the batch satisfies the label requirement (e.g. all
          labels are identical), the first rank-shifted candidate is used.
          The previous implementation looped forever in that situation.
        * The previous version drew a ``torch.randperm`` whose values were
          always overwritten; that draw is gone, so the torch RNG is no
          longer advanced by this function.
    """
    if rnk is None:
        raise ValueError("mix_up_f requires `rnk` with one score per sample")
    if lam is None:
        lam = np.random.beta(1, 1)

    batch_size = inp.size(0)
    if rnk.size(0) != batch_size:
        raise ValueError(
            "rnk has %d entries but the batch holds %d samples"
            % (rnk.size(0), batch_size)
        )

    # Work on plain Python values: avoids one tensor->int conversion (and a
    # device sync for CUDA tensors) per comparison inside the loops below.
    scores = rnk.tolist()
    labels = lab.tolist()

    # order[r] -> batch position of the sample with rank r (stable for ties)
    # position[s] -> rank of the sample at batch position s
    order = sorted(range(batch_size), key=scores.__getitem__)
    position = [0] * batch_size
    for rank, sample in enumerate(order):
        position[sample] = rank

    if method == 'near':
        offset = 1
    elif method == 'near_rk':
        offset = idd
    else:
        offset = batch_size // 2

    partners = []
    for i in range(batch_size):
        if method == 'near_r':
            offset = random.randint(1, 50)
        # 9 times out of 10 look for a partner from another class.
        want_different = random.randint(1, 10) != 1

        start = position[i] + offset
        chosen = order[start % batch_size]  # fallback when nothing qualifies
        # One full lap around the ring is enough to see every sample.
        for step in range(batch_size):
            candidate = order[(start + step) % batch_size]
            if (labels[candidate] != labels[i]) == want_different:
                chosen = candidate
                break
        partners.append(chosen)

    index = torch.tensor(partners, dtype=torch.long, device=inp.device)

    if lam < 0.5:
        lam = 1 - lam
    mixed_x = lam * inp + (1 - lam) * inp[index, :]
    y_a, y_b = lab, lab[index.to(lab.device)]

    return mixed_x, y_a, y_b, lam


def mix_up(inp, lab, rnk=None, lam=None, method=None, idd=1, test_long=False):
    """Standard mixup: blend each sample with a randomly permuted partner.

    Args:
        inp: Batch tensor; the first dimension is the batch dimension and it
            must have at least two dimensions.
        lab: Labels, indexable along the batch dimension.
        rnk: Per-sample ranking scores; only read when ``test_long`` is true.
        lam: Mixing coefficient.  ``None`` draws one from ``Beta(1, 1)``.
            An explicit ``0`` is honoured (the previous ``if not lam`` check
            silently replaced it with a random draw).
        method: Unused; kept so the signature mirrors ``mix_up_f``.
        idd: Unused; kept so the signature mirrors ``mix_up_f``.
        test_long: When true, skip the mixing and instead return the mean
            absolute ``rnk`` difference between each sample and its partner.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` normally, or a tensor holding the mean
        absolute rank gap when ``test_long`` is true.

    Raises:
        ValueError: if ``test_long`` is true but ``rnk`` is ``None``.

    Notes:
        The permutation is still drawn on the CPU (so seeded runs consume the
        torch RNG exactly as before) and then moved to the device of the
        tensor it indexes, rather than being forced onto CUDA.
    """
    if lam is None:
        lam = np.random.beta(1, 1)
    batch_size = inp.size(0)
    perm = torch.randperm(batch_size)

    if test_long:
        if rnk is None:
            raise ValueError("test_long=True requires `rnk`")
        partner_rnk = rnk[perm.to(rnk.device)]
        # Reduce over the batch dimension only, matching the shape produced
        # by the former per-sample accumulation loop.
        return (rnk - partner_rnk).abs().sum(dim=0) / batch_size

    index = perm.to(inp.device)
    mixed_x = lam * inp + (1 - lam) * inp[index, :]
    y_a, y_b = lab, lab[perm.to(lab.device)]

    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the loss against both label sets of a mixed batch.

    ``criterion`` is evaluated once for ``y_a`` and once for ``y_b``; the two
    results are combined with weights ``lam`` and ``1 - lam`` respectively.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    weight_b = 1 - lam
    return lam * loss_a + weight_b * loss_b


_MIXUP_MODES = ('None', 'ori', 'far', 'near_rk', 'near_r', 'near', 'nfnf', 'fnfn')


def _mix_batch(mixup, inputs, labels, rnk, epoch_idx, total_epochs):
    """Apply the augmentation selected by ``mixup`` to one training batch."""
    if mixup == 'None':
        # NOTE: the *string* 'None' selects plain random-partner mixup.
        return mix_up(inputs, labels)
    if mixup == 'ori':
        # No mixing; lam=0 makes mixup_criterion reduce to the plain loss.
        return inputs, labels, labels, 0
    if mixup == 'near':
        return mix_up_f(inputs, labels, rnk, method='near')
    if mixup in ('far', 'near_rk', 'near_r'):
        return mix_up_f(inputs, labels, rnk, method=mixup, idd=epoch_idx)
    if mixup == 'nfnf':
        # 'far' for the first 90% of the epochs, 'near' afterwards.
        method = 'far' if epoch_idx < total_epochs * 0.9 else 'near'
        return mix_up_f(inputs, labels, rnk, method=method, idd=epoch_idx)
    if mixup == 'fnfn':
        # 'near' for the first 40% of the epochs, 'far' afterwards.
        method = 'near' if epoch_idx < total_epochs * 0.4 else 'far'
        return mix_up_f(inputs, labels, rnk, method=method, idd=epoch_idx)
    raise ValueError(
        "unknown mixup mode %r; expected one of %s" % (mixup, ', '.join(_MIXUP_MODES))
    )


def _evaluate_accuracy(net, testloader):
    """Return the accuracy of ``net`` on ``testloader`` as a percentage."""
    correct = 0
    total = 0
    net.eval()
    # Inference only: do not record autograd graphs for the test batches.
    with torch.no_grad():
        for images, labels in testloader:
            images = images.cuda()
            labels = labels.cuda()
            outputs = net(images)
            predicted = torch.max(outputs, 1)[1]
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return 100.0 * correct / total if total else 0.0


def train_net(train_loader, net, optimizer, testloader, mixup=None, epoch1=100, scheduler=None):
    """Train ``net`` with the chosen mixup variant and report test accuracy.

    Args:
        train_loader: Iterable yielding ``(inputs, labels, rnk)`` batches.
        net: Model to train; batches are moved with ``.cuda()`` before the
            forward pass, so the model is expected to live on the GPU.
        optimizer: Optimizer stepped once per training batch.
        testloader: Iterable yielding ``(images, labels)`` batches.
        mixup: One of ``'None'`` (random-partner mixup), ``'ori'`` (no mixup),
            ``'far'``, ``'near'``, ``'near_rk'``, ``'near_r'``, ``'nfnf'``
            or ``'fnfn'``.  The default ``None`` is *not* a valid mode.
        epoch1: Number of epochs.
        scheduler: Optional LR scheduler, stepped once per epoch.

    Returns:
        The best test accuracy (in percent) seen over all epochs.

    Raises:
        ValueError: if ``mixup`` is not a recognised mode.  Previously an
            unrecognised mode (including the default ``None``) surfaced as an
            ``UnboundLocalError`` on the first batch.

    Notes:
        Accuracy is computed as ``100 * correct / evaluated``.  The former
        ``correct / 100`` accumulation yields the same number for a
        10,000-sample test set and was not a percentage otherwise.
    """
    print(mixup)
    if mixup not in _MIXUP_MODES:
        raise ValueError(
            "unknown mixup mode %r; expected one of %s" % (mixup, ', '.join(_MIXUP_MODES))
        )

    accl = 0
    for epoch_idx in range(epoch1):
        bg = time.time()
        net.train()
        for inputs, labels, rnk in train_loader:
            inputs, l_a, l_b, lam = _mix_batch(
                mixup, inputs, labels, rnk, epoch_idx, epoch1
            )
            # Mixing happens on the loader's device; move to the GPU afterwards.
            inputs, l_a, l_b = inputs.cuda(), l_a.cuda(), l_b.cuda()
            outputs = net(inputs)
            loss = mixup_criterion(criterion, outputs, l_a, l_b, lam=lam)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        acc = _evaluate_accuracy(net, testloader)
        accl = max(accl, acc)
        print('epoch : %d  ' % (epoch_idx + 1), end='')
        print('acc : %.1f ' % acc, end='')
        print(time.time() - bg)

        if scheduler is not None:
            scheduler.step()
    print(accl)
    return accl




def mix_up_f11(inp, lab, rnk=None, lam=None, method=None, idd=1):
    """Measure how far apart (in ``rnk``) rank-based mixup partners are.

    Partners are chosen exactly as in ``mix_up_f``: the batch is ordered by
    ``rnk`` (stable, ascending) and treated as a ring; the first candidate for
    sample ``i`` lies ``offset`` places ahead, with ``offset`` being ``1`` for
    ``'near'``, ``idd`` for ``'near_rk'``, a fresh ``random.randint(1, 50)``
    per sample for ``'near_r'`` and ``batch_size // 2`` otherwise.  With
    probability 9/10 the partner must have a different label, otherwise the
    same label, and the ring is walked forward until one is found.

    Args:
        inp: Batch tensor; only its first dimension (batch size) is used.
        lab: Labels, one comparable entry per sample.
        rnk: Tensor with one ranking score per sample (required).
        lam: Ignored; accepted so the signature matches ``mix_up_f``.
        method: Partner-selection strategy, see above.
        idd: Rank offset used by ``method='near_rk'``.

    Returns:
        Tensor holding the mean of ``|rnk[i] - rnk[partner(i)]|`` over the
        batch.

    Raises:
        ValueError: if ``rnk`` is missing or its length differs from the
            batch size.

    Notes:
        * If no sample satisfies the label requirement (e.g. all labels are
          identical) the first rank-shifted candidate is used; the previous
          implementation looped forever in that case.
        * Nothing is forced onto CUDA any more; the partner index lives on
          ``rnk``'s device.
        * The unused ``lam`` draw and the overwritten ``torch.randperm`` were
          removed, so this function no longer advances the NumPy or torch
          RNGs.
    """
    if rnk is None:
        raise ValueError("mix_up_f11 requires `rnk` with one score per sample")

    batch_size = inp.size(0)
    if rnk.size(0) != batch_size:
        raise ValueError(
            "rnk has %d entries but the batch holds %d samples"
            % (rnk.size(0), batch_size)
        )

    # Plain Python values keep the selection loop free of tensor indexing.
    scores = rnk.tolist()
    labels = lab.tolist()

    # order[r] -> batch position at rank r; position[s] -> rank of sample s.
    order = sorted(range(batch_size), key=scores.__getitem__)
    position = [0] * batch_size
    for rank, sample in enumerate(order):
        position[sample] = rank

    if method == 'near':
        offset = 1
    elif method == 'near_rk':
        offset = idd
    else:
        offset = batch_size // 2

    partners = []
    for i in range(batch_size):
        if method == 'near_r':
            offset = random.randint(1, 50)
        want_different = random.randint(1, 10) != 1

        start = position[i] + offset
        chosen = order[start % batch_size]  # fallback when nothing qualifies
        # One lap around the ring visits every sample once.
        for step in range(batch_size):
            candidate = order[(start + step) % batch_size]
            if (labels[candidate] != labels[i]) == want_different:
                chosen = candidate
                break
        partners.append(chosen)

    index = torch.tensor(partners, dtype=torch.long, device=rnk.device)
    return (rnk - rnk[index]).abs().sum(dim=0) / batch_size


def train_net1(train_loader, net, optimizer, testloader, mixup=None, epoch1=100):
    """Print rank-gap statistics of the chosen mixup pairing for one data pass.

    No training happens here.  For every batch of ``train_loader`` the mean
    rank gap between samples and their mixup partners is computed and printed
    together with the running average over the batches seen so far.

    Args:
        train_loader: Iterable yielding ``(inputs, labels, rnk)`` batches.
        net: Unused; kept for signature parity with ``train_net``.
        optimizer: Unused; kept for signature parity with ``train_net``.
        testloader: Unused; kept for signature parity with ``train_net``.
        mixup: ``'None'`` measures random-partner mixup; ``'far'``,
            ``'near_rk'`` and ``'near_r'`` select the matching rank-based
            strategy; any other value falls back to ``'near'``.
        epoch1: Unused; the data is traversed exactly once.

    Returns:
        ``(same, diff)``, both always ``0`` (the label-ratio bookkeeping that
        would update them is not active).
    """
    same = 0
    diff = 0

    # Running sum instead of re-summing a growing list on every batch.
    gap_total = 0
    batches_seen = 0
    for inputs, labels, rnk in train_loader:
        if mixup == 'None':
            long = mix_up(inputs, labels, rnk=rnk, test_long=True)
        elif mixup == 'far':
            long = mix_up_f11(inputs, labels, rnk, method='far')
        elif mixup == 'near_rk':
            # NOTE(review): the offset is fixed at 0 because the original
            # single-iteration loop always passed its index (0) here.
            long = mix_up_f11(inputs, labels, rnk, method='near_rk', idd=0)
        elif mixup == 'near_r':
            long = mix_up_f11(inputs, labels, rnk, method='near_r', idd=0)
        else:
            long = mix_up_f11(inputs, labels, rnk, method='near')

        # Rebind rather than `+=` so a returned tensor is never modified
        # in place.
        gap_total = gap_total + long
        batches_seen += 1
        print(long, end=' ')
        print(gap_total / batches_seen)

    return same, diff