import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
import numpy as np
import torch
import time
import random


class MyDataset(data.Dataset):
    """In-memory dataset pairing already-loaded samples with their labels.

    ``photo`` and ``label`` are parallel sequences: position ``i`` of one
    belongs with position ``i`` of the other.
    """

    def __init__(self, photo: list, label: list):
        # Both sequences are stored by reference; nothing is copied.
        self.x, self.y = photo, label

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.x)

    def __getitem__(self, item):
        """Return the ``(sample, label)`` pair stored at position ``item``."""
        sample = self.x[item]
        target = self.y[item]
        return sample, target


class Pic:
    """Plain record tying a sample index (``id``) to a scalar score (``ntk``).

    Elsewhere in this file ``id`` is a sample's position in the training set
    and ``ntk`` is the float parsed from the matching line of a score file.
    """

    def __init__(self, _id, ntk):
        self.id, self.ntk = _id, ntk


def random_shuffle(lis):
    """Shuffle ``lis`` in place, uniformly at random, and return it.

    The previous implementation reseeded the global RNG with the current
    time truncated to whole seconds, so every call made within the same
    second produced the same permutation. It also swapped each position with
    an arbitrary index, which does not give a uniform shuffle.
    ``random.shuffle`` avoids both problems and leaves the caller's seeding
    of the global RNG untouched.

    Args:
        lis: mutable sequence to shuffle.

    Returns:
        The same ``lis`` object, now shuffled.
    """
    random.shuffle(lis)
    return lis


class Cifar100:
    def __init__(self, path='/mnt/cache/wangyudong1/ntk/data', size=32):
        """Load the CIFAR-100 train and test splits and prepare helper state.

        ``path`` is the directory the data is read from (and downloaded to,
        since ``download=True`` is passed). ``size`` selects the transform
        pipeline: 224 and 384 add a ``Resize`` to that edge length (384 also
        adds a random horizontal flip on the training split); any other value
        only converts to a tensor and normalizes.
        """
        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

        # The training pipelines resize before ToTensor, the test pipelines
        # after it; that ordering is kept exactly.
        if size == 224:
            train_steps = [transforms.Resize(224), transforms.ToTensor(), normalize]
            test_steps = [transforms.ToTensor(), transforms.Resize(224), normalize]
        elif size == 384:
            train_steps = [transforms.RandomHorizontalFlip(),
                           transforms.Resize(384),
                           transforms.ToTensor(),
                           normalize]
            test_steps = [transforms.ToTensor(), transforms.Resize(384), normalize]
        else:
            train_steps = [transforms.ToTensor(), normalize]
            test_steps = [transforms.ToTensor(), normalize]

        self.train_set = torchvision.datasets.CIFAR100(
            root=path, train=True, download=True,
            transform=transforms.Compose(train_steps))

        raw_test = torchvision.datasets.CIFAR100(
            root=path, train=False, download=True,
            transform=transforms.Compose(test_steps))

        # Read every test item once through the dataset and keep the results
        # in memory as a MyDataset.
        test_pics, test_labs = [], []
        for image, target in raw_test:
            test_pics.append(image)
            test_labs.append(target)
        self.test_set = MyDataset(test_pics, test_labs)

        # Caches filled on demand by the get_* methods.
        self.div_class = None
        self.ntk = None
        self.tr = None
        self.tr_nt = None
        self.train_pic = None
        self.train_lab = None
        self.tot_rnk = None

    def train_loader(self, data_set=None, batch=1, shuffle=True):
        """Wrap a dataset in a ``DataLoader``.

        Args:
            data_set: dataset to iterate; ``None`` selects ``self.train_set``.
            batch: batch size handed to the loader.
            shuffle: passed through as the loader's ``shuffle`` flag.

        Returns:
            A ``DataLoader`` over the chosen dataset.
        """
        # Compare with None explicitly: a sized but empty dataset is falsy,
        # and must not be silently replaced by the full training set.
        source = self.train_set if data_set is None else data_set
        return DataLoader(source, batch_size=batch, shuffle=shuffle)

    def get_single(self, data_set=None, item=0):
        """Return a one-element ``MyDataset`` holding sample ``item``.

        Previously a supplied ``data_set`` was ignored and the method fell
        through, returning ``None``. The sample is now taken from
        ``data_set`` when one is given.

        Args:
            data_set: indexable dataset yielding ``(sample, label)`` pairs;
                ``None`` selects ``self.train_set``.
            item: index of the sample to extract.

        Returns:
            A ``MyDataset`` containing exactly that sample and its label.
        """
        source = self.train_set if data_set is None else data_set
        pic, lab = source[item]
        return MyDataset([pic], [lab])

    def get_rnd_data(self, size=5):
        """Draw ``size`` distinct random samples from every class.

        The old code tracked the within-class positions already used in one
        set shared by all classes. A position used in one class was therefore
        forbidden in every later class, and the rejection loop could no
        longer finish once all positions were taken (e.g. ``size`` > 5 with
        100 classes of 500 images). Positions are now sampled without
        replacement separately for each class.

        Args:
            size: number of samples to draw per class.

        Returns:
            ``(dataset, ids)``. ``dataset`` is a ``MyDataset`` with the
            samples grouped class by class. ``ids`` is the set of
            within-class positions that were used; because it is a union
            over all classes it can hold fewer than ``classes * size``
            values.

        Raises:
            ValueError: if ``size`` exceeds the number of samples in a class.
        """
        # Reseeds the global RNG from the wall clock, as before.
        random.seed(time.time())
        if not self.div_class:
            self.get_class()
        ids = set()
        pic = []
        lab = []
        for cls, bucket in enumerate(self.div_class):
            chosen = random.sample(range(len(bucket)), size)
            ids.update(chosen)
            pic.extend(bucket[pos] for pos in chosen)
            lab.extend([cls] * size)
        return MyDataset(pic, lab), ids

    def get_class(self):
        """Group the training samples by label into ``self.div_class``.

        ``self.div_class`` becomes a list of 100 lists; entry ``k`` collects,
        in dataset order, every sample whose label is ``k``.
        """
        buckets = self.div_class = [[] for _ in range(100)]
        for image, target in self.train_set:
            buckets[target].append(image)

    def get_ntk(self, path='./logs/record1/ntk_cifar1.txt'):
        """Load per-sample scores from ``path`` and bucket them by class.

        Line ``k`` of the file holds the float score of training sample
        ``k``. The result, stored in ``self.ntk``, is a list of 100 lists
        (one per label) of ``Pic(sample_index, score)`` sorted by ascending
        score.

        Fixes: the buckets were allocated for 10 classes while both the
        labels and the sort loop cover 100, which raised ``IndexError``; the
        file handle was never closed.
        """
        prl = [[] for _ in range(100)]
        with open(path) as f:
            for con, line in enumerate(f):
                # The sample is fetched only to learn its label.
                _, lab = self.train_set[con]
                prl[lab].append(Pic(con, float(line)))
        for bucket in prl:
            bucket.sort(key=lambda pic: pic.ntk)
        self.ntk = prl

    def get_rnk(self, path='./logs/record1/tr_cifar1.txt'):
        """Compute each sample's rank by ascending score into ``self.tot_rnk``.

        Line ``k`` of the file holds the float score of sample ``k``.
        Afterwards ``self.tot_rnk[k]`` is the 0-based position of sample
        ``k`` when all samples are sorted by ascending score. Ties keep file
        order because the sort is stable.

        Fixes: the file handle was never closed, and ``self.tot_rnk`` was
        assigned inside the ranking loop, so an empty file left it unset. It
        is now always assigned (``[]`` for an empty file).
        """
        with open(path, 'r') as f:
            scores = [float(line) for line in f]
        # Sample indices ordered by ascending score.
        order = sorted(range(len(scores)), key=scores.__getitem__)
        ranks = [0] * len(scores)
        for rank, sample_idx in enumerate(order):
            ranks[sample_idx] = rank
        self.tot_rnk = ranks

    def get_tr_nt(self, path='./logs/record1/tr_cifar1.txt'):
        """Load per-sample scores into one class-agnostic sorted list.

        Line ``k`` of the file holds the float score of sample ``k``.
        ``self.tr_nt`` becomes a list of ``Pic(sample_index, score)`` sorted
        by ascending score.

        Fix: the file is now opened in a ``with`` block; the handle was
        previously never closed.
        """
        with open(path) as f:
            prl = [Pic(con, float(line)) for con, line in enumerate(f)]
        prl.sort(key=lambda pic: pic.ntk)
        self.tr_nt = prl

    def get_tr(self, path='./logs/record1/tr_cifar1.txt'):
        """Load per-sample scores from ``path`` and bucket them by class.

        Line ``k`` of the file holds the float score of training sample
        ``k``. ``self.tr`` becomes a list of 100 lists (one per label) of
        ``Pic(sample_index, score)`` sorted by ascending score.

        Fix: the file is now opened in a ``with`` block; the handle was
        previously never closed.
        """
        prl = [[] for _ in range(100)]
        with open(path) as f:
            for con, line in enumerate(f):
                # The sample is fetched only to learn its label.
                _, lab = self.train_set[con]
                prl[lab].append(Pic(con, float(line)))
        for bucket in prl:
            bucket.sort(key=lambda pic: pic.ntk)
        self.tr = prl

    def get_tr_suf(self, size=10, l=0, r=500, path='./logs/record1/tr_cifar_denesnet.txt'):
        """Pick ``size`` random samples per class from a score-rank window.

        The scores are reloaded from ``path`` on every call via ``get_tr``.
        For each class, ``size`` entries are drawn without replacement from
        positions ``[l, r)`` of that class's ascending-score list.

        Fix: the old code called ``random.shuffle(pt[l:r])``, which shuffles
        a temporary copy of the slice. The list itself was never reordered,
        so the first ``size`` entries of the window were always returned.

        Args:
            size: number of samples per class.
            l: first rank (inclusive) of the window.
            r: end rank (exclusive) of the window; a falsy value means
                ``size``.
            path: score file passed to ``get_tr``.

        Returns:
            A ``MyDataset`` of the selected samples, grouped class by class.

        Raises:
            ValueError: if the window holds fewer than ``size`` entries.
        """
        if not r:
            r = size
        self.get_tr(path=path)
        inp = []
        lab = []
        for pt in self.tr:
            window = pt[l:r]
            for entry in random.sample(window, size):
                pic, lan = self.train_set[entry.id]
                inp.append(pic)
                lab.append(lan)
        return MyDataset(inp, lab)

    def get_rnd_suf(self, size=10):
        """Return ``size`` random samples from each of the 100 classes.

        Each class list in ``self.div_class`` is shuffled in place (the new
        order persists on the instance) and its first ``size`` entries are
        taken. ``self.div_class`` is built first if it is still empty.
        """
        if not self.div_class:
            self.get_class()
        samples, targets = [], []
        for cls in range(100):
            bucket = self.div_class[cls]
            random.shuffle(bucket)
            samples += [bucket[k] for k in range(size)]
            targets += [cls] * size
        return MyDataset(samples, targets)

    def get_dataset_noty(self, lam='tr', l=0, r=0, siz=10):
        """Draw ``siz`` distinct samples from a class-agnostic rank window.

        Ranks ``tmp`` are drawn without replacement from ``[l, r)``, or
        ``[l, len(list))`` when ``r == -1``, and mapped to ``prl[-tmp]`` in
        the ascending-score list.

        Fixes relative to the previous version:
          * the cache check tested ``self.tr`` but loaded ``self.tr_nt``;
          * the duplicate check inspected ``prl[tmp]`` while ``prl[-tmp]``
            was the entry actually selected;
          * a retry redrew from ``[0, r)`` instead of ``[l, r)`` and could
            loop forever when the window was exhausted;
          * ``lam == 'ntk'`` indexed the per-class lists in ``self.ntk`` as
            if they were one flat list; they are now merged and re-sorted;
          * the ntk file is only loaded when ``lam == 'ntk'``.

        Args:
            lam: ``'ntk'`` ranks by ``self.ntk``; anything else by
                ``self.tr_nt``.
            l: first rank (inclusive).
            r: end rank (exclusive); ``-1`` means the whole list.
            siz: number of samples to draw.

        Returns:
            ``(dataset, ids, trsum)``: a ``MyDataset`` of the chosen samples,
            their training-set indices, and the sum of their scores.

        Raises:
            ValueError: if the window ``[l, r)`` holds fewer than ``siz``
                ranks (this includes the default ``r=0``).
        """
        # Reseeds the global RNG from the wall clock, as before.
        random.seed(time.time())
        if lam == 'ntk':
            if not self.ntk:
                self.get_ntk()
            # get_ntk() keeps one sorted list per class; merge for a global rank.
            prl = sorted((p for bucket in self.ntk for p in bucket),
                         key=lambda p: p.ntk)
        else:
            if not self.tr_nt:
                self.get_tr_nt()
            prl = self.tr_nt
        tpk = len(prl) if r == -1 else r
        inn, lan, ids = [], [], []
        trsum = 0
        for tmp in random.sample(range(l, tpk), siz):
            # NOTE(review): rank k > 0 selects the k-th largest score, but
            # rank 0 maps to prl[0] (the smallest). This looks like an
            # off-by-one; kept as-is until the intended order is confirmed.
            chosen = prl[-tmp]
            trsum += chosen.ntk
            ppc, lbb = self.train_set[chosen.id]
            ids.append(chosen.id)
            inn.append(ppc)
            lan.append(lbb)
        return MyDataset(inn, lan), ids, trsum

    def get_sta(self, l=0, r=100):
        """Count labels among entries ``[l, r)`` of the global score ranking.

        ``self.tr_nt`` (loaded on demand) is the ascending-score list of all
        samples. The label of each entry in the window is looked up in
        ``self.train_set`` and tallied.

        Fix: the counter had 10 slots although this dataset's labels run up
        to 99, so any label >= 10 raised ``IndexError``. It now has 100.

        Args:
            l: first rank (inclusive).
            r: end rank (exclusive).

        Returns:
            A list of 100 ints; entry ``k`` is how many samples in the window
            carry label ``k``.
        """
        if not self.tr_nt:
            self.get_tr_nt()
        cou = [0] * 100
        for i in range(l, r):
            _, lab = self.train_set[self.tr_nt[i].id]
            cou[lab] += 1
        return cou

    def get_dataset(self, lam='tr', l=0, r=0, t=0, siz=10):
        """Draw ``siz`` distinct samples of class ``t`` from a rank window.

        Ranks ``tmp`` are drawn without replacement from ``[l, r)``, or
        ``[l, len(class list))`` when ``r == -1``, and mapped to
        ``prl[t][-tmp]`` in that class's ascending-score list.

        Fixes relative to the previous version:
          * the first draw was ``random.randint(l, - 1)``; the upper bound
            had lost its ``tpk`` operand, so it raised for any ``l >= 0``;
          * the duplicate check inspected ``prl[t][tmp]`` while
            ``prl[t][-tmp]`` was the entry actually selected;
          * a retry redrew from ``[0, r)`` instead of ``[l, r)`` and could
            loop forever when the window was exhausted;
          * only the score table selected by ``lam`` is loaded.

        Args:
            lam: ``'ntk'`` ranks by ``self.ntk``; anything else by
                ``self.tr``.
            l: first rank (inclusive).
            r: end rank (exclusive); ``-1`` means the whole class list.
            t: class label to draw from.
            siz: number of samples to draw.

        Returns:
            ``(dataset, ids, trsum)``: a ``MyDataset`` of the chosen samples,
            their training-set indices, and the sum of their scores.

        Raises:
            ValueError: if the window ``[l, r)`` holds fewer than ``siz``
                ranks (this includes the default ``r=0``).
        """
        # Reseeds the global RNG from the wall clock, as before.
        random.seed(time.time())
        if lam == 'ntk':
            if not self.ntk:
                self.get_ntk()
            prl = self.ntk
        else:
            if not self.tr:
                self.get_tr()
            prl = self.tr
        ranked = prl[t]
        tpk = len(ranked) if r == -1 else r
        inn, lan, ids = [], [], []
        trsum = 0
        for tmp in random.sample(range(l, tpk), siz):
            # NOTE(review): rank k > 0 selects the k-th largest score, but
            # rank 0 maps to ranked[0] (the smallest). This looks like an
            # off-by-one; kept as-is until the intended order is confirmed.
            chosen = ranked[-tmp]
            trsum += chosen.ntk
            ppc, lbb = self.train_set[chosen.id]
            ids.append(chosen.id)
            inn.append(ppc)
            lan.append(lbb)
        return MyDataset(inn, lan), ids, trsum

    def get_lab_pic(self):
        """Split the training set into parallel sample and label lists.

        Afterwards ``self.train_pic[k]`` and ``self.train_lab[k]`` hold the
        sample and the label of the k-th item yielded by ``self.train_set``.
        """
        pics = self.train_pic = []
        labs = self.train_lab = []
        for image, target in self.train_set:
            pics.append(image)
            labs.append(target)
