import os
import os.path as osp

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from .autoaugment import AutoAugImageNetPolicy

class CUB200(Dataset):

    def __init__(self, root='/data/zhoudw/FSCIL', train=True,
                 index_path=None, index=None, base_sess=None,autoaug=1):
        self.root = os.path.expanduser(root)
        self.train = train  # training set or test set
        self._pre_operate(self.root)

        if autoaug==0:
            #do not use autoaug
            if train:
                self.transform = transforms.Compose([
                    transforms.Resize(256),
                    # transforms.CenterCrop(224),
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    #add autoaug
                    #AutoAugImageNetPolicy(),
                    #transforms.ToTensor(),
                    #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                ])
                # self.data, self.targets = self.SelectfromTxt(self.data2label, index_path)
                if base_sess:
                    self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index)
                else:
                    self.data, self.targets = self.SelectfromTxt(self.data2label, index_path)
            else:
                self.transform = transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    #transforms.ToTensor(),
                    #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ])
                self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index)
        else:
            #use autoaug
            if train:
                self.transform = transforms.Compose([
                    transforms.Resize(256),
                    # transforms.CenterCrop(224),
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    #add autoaug
                    #AutoAugImageNetPolicy(),
                    #transforms.ToTensor(),
                    #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                ])
                # self.data, self.targets = self.SelectfromTxt(self.data2label, index_path)
                if base_sess:
                    self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index)
                else:
                    self.data, self.targets = self.SelectfromTxt(self.data2label, index_path)
            else:
                self.transform = transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    #transforms.ToTensor(),
                    #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ])
                self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index)
    def text_read(self, file):
        with open(file, 'r') as f:
            lines = f.readlines()
            for i, line in enumerate(lines):
                lines[i] = line.strip('\n')
        return lines

    def list2dict(self, list):
        dict = {}
        for l in list:
            s = l.split(' ')
            id = int(s[0])
            cls = s[1]
            if id not in dict.keys():
                dict[id] = cls
            else:
                raise EOFError('The same ID can only appear once')
        return dict

    def _pre_operate(self, root):
        image_file = os.path.join(root, 'CUB_200_2011/images.txt')
        split_file = os.path.join(root, 'CUB_200_2011/train_test_split.txt')
        class_file = os.path.join(root, 'CUB_200_2011/image_class_labels.txt')
        id2image = self.list2dict(self.text_read(image_file))
        id2train = self.list2dict(self.text_read(split_file))  # 1: train images; 0: test iamges
        id2class = self.list2dict(self.text_read(class_file))
        train_idx = []
        test_idx = []
        for k in sorted(id2train.keys()):
            if id2train[k] == '1':
                train_idx.append(k)
            else:
                test_idx.append(k)

        self.data = []
        self.targets = []
        self.data2label = {}
        if self.train:
            for k in train_idx:
                image_path = os.path.join(root, 'CUB_200_2011/images', id2image[k])
                self.data.append(image_path)
                self.targets.append(int(id2class[k]) - 1)
                self.data2label[image_path] = (int(id2class[k]) - 1)

        else:
            for k in test_idx:
                image_path = os.path.join(root, 'CUB_200_2011/images', id2image[k])
                self.data.append(image_path)
                self.targets.append(int(id2class[k]) - 1)
                self.data2label[image_path] = (int(id2class[k]) - 1)

    def SelectfromTxt(self, data2label, index_path):
        index = open(index_path).read().splitlines()
        data_tmp = []
        targets_tmp = []
        for i in index:
            img_path = os.path.join(self.root, i)
            data_tmp.append(img_path)
            targets_tmp.append(data2label[img_path])

        return data_tmp, targets_tmp

    def SelectfromClasses(self, data, targets, index):
        data_tmp = []
        targets_tmp = []
        for i in index:
            ind_cl = np.where(i == targets)[0]
            for j in ind_cl:
                data_tmp.append(data[j])
                targets_tmp.append(targets[j])

        return data_tmp, targets_tmp

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        path, targets = self.data[i], self.targets[i]
        image = self.transform(Image.open(path).convert('RGB'))
        image = np.array(image)
        image = np.transpose(image, [2, 0, 1])
        #print(image.shape)
        r_img = image[0]
        g_img = image[1]
        b_img = image[2]
        
        # R
        r_f = np.fft.fft2(r_img)
        r_fshift = np.fft.fftshift(r_f)
        #r_magnitude_spectrum = 20*np.log(abs(r_f))
        #r_phase = np.angle(r_fshift)

        # G
        g_f = np.fft.fft2(g_img)
        g_fshift = np.fft.fftshift(g_f)
        #g_magnitude_spectrum = 20*np.log(abs(g_f))
        #g_phase = np.angle(g_fshift)

        # B
        b_f = np.fft.fft2(b_img)
        b_fshift = np.fft.fftshift(b_f)
        #b_magnitude_spectrum = 20*np.log(abs(b_f))
        #b_phase = np.angle(b_fshift)

        (w, h) = r_fshift.shape
        half_w, half_h = int(w/2), int(h/2)
        n = 112
        
        ch_arr = []
        r_fshift_ori = r_fshift.copy()
        g_fshift_ori = g_fshift.copy()
        b_fshift_ori = b_fshift.copy()

        for i in range(2):
            r_fshift = r_fshift_ori.copy()
            r_fshift[0:half_h - int(n/2)*i - int(n/2), :] = 0
            r_fshift[half_h + int(n/2)*i + int(n/2):, :] = 0
            r_fshift[:, 0:half_h - int(n/2)*i - int(n/2)] = 0
            r_fshift[:, half_h + int(n/2)*i + int(n/2):] = 0
            r_fshift[half_w - int(n/2)*i : half_w + int(n/2)*i, half_h - int(n/2)*i : half_h + int(n/2)*i] = 0
            r_fshift = np.fft.ifftshift(r_fshift)
            r_liu = np.fft.ifft2(r_fshift).real
            ch_arr.append(r_liu)

            g_fshift = g_fshift_ori.copy()
            g_fshift[0:half_h - int(n/2)*i - int(n/2), :] = 0
            g_fshift[half_h + int(n/2)*i + int(n/2):, :] = 0
            g_fshift[:, 0:half_h - int(n/2)*i - int(n/2)] = 0
            g_fshift[:, half_h + int(n/2)*i + int(n/2):] = 0
            g_fshift[half_w - int(n/2)*i : half_w + int(n/2)*i, half_h - int(n/2)*i : half_h + int(n/2)*i] = 0
            g_fshift = np.fft.ifftshift(g_fshift)
            g_liu = np.fft.ifft2(g_fshift).real
            ch_arr.append(g_liu)

            b_fshift = b_fshift_ori.copy()
            b_fshift[0:half_h - int(n/2)*i - int(n/2), :] = 0
            b_fshift[half_h + int(n/2)*i + int(n/2):, :] = 0
            b_fshift[:, 0:half_h - int(n/2)*i - int(n/2)] = 0
            b_fshift[:, half_h + int(n/2)*i + int(n/2):] = 0
            b_fshift[half_w - int(n/2)*i : half_w + int(n/2)*i, half_h - int(n/2)*i : half_h + int(n/2)*i] = 0
            b_fshift = np.fft.ifftshift(b_fshift)
            b_liu = np.fft.ifft2(b_fshift).real
            ch_arr.append(b_liu)
        

        ch_arr = np.array(ch_arr)
        ch_arr = ch_arr.astype(np.float32)

        return ch_arr, targets


class CUB200_concate(Dataset):
    def __init__(self, train,x1,y1,x2,y2):
        
        if train:
            self.transform = transforms.Compose([
                transforms.Resize(256),
                # transforms.CenterCrop(224),
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                #transforms.ToTensor(),
                #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        else:
            self.transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                #transforms.ToTensor(),
                #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
        self.data=x1+x2
        self.targets=y1+y2
        print(len(self.data),len(self.targets))

    

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        path, targets = self.data[i], self.targets[i]
        image = self.transform(Image.open(path).convert('RGB'))

        image = np.array(image)
        image = np.transpose(image, [2, 0, 1])
        #print(image.shape)
        r_img = image[0]
        g_img = image[1]
        b_img = image[2]
        
        # R
        r_f = np.fft.fft2(r_img)
        r_fshift = np.fft.fftshift(r_f)
        #r_magnitude_spectrum = 20*np.log(abs(r_f))
        #r_phase = np.angle(r_fshift)

        # G
        g_f = np.fft.fft2(g_img)
        g_fshift = np.fft.fftshift(g_f)
        #g_magnitude_spectrum = 20*np.log(abs(g_f))
        #g_phase = np.angle(g_fshift)

        # B
        b_f = np.fft.fft2(b_img)
        b_fshift = np.fft.fftshift(b_f)
        #b_magnitude_spectrum = 20*np.log(abs(b_f))
        #b_phase = np.angle(b_fshift)

        (w, h) = r_fshift.shape
        half_w, half_h = int(w/2), int(h/2)
        n = 112
        
        ch_arr = []
        r_fshift_ori = r_fshift.copy()
        g_fshift_ori = g_fshift.copy()
        b_fshift_ori = b_fshift.copy()

        for i in range(2):
            r_fshift = r_fshift_ori.copy()
            r_fshift[0:half_h - int(n/2)*i - int(n/2), :] = 0
            r_fshift[half_h + int(n/2)*i + int(n/2):, :] = 0
            r_fshift[:, 0:half_h - int(n/2)*i - int(n/2)] = 0
            r_fshift[:, half_h + int(n/2)*i + int(n/2):] = 0
            r_fshift[half_w - int(n/2)*i : half_w + int(n/2)*i, half_h - int(n/2)*i : half_h + int(n/2)*i] = 0
            r_fshift = np.fft.ifftshift(r_fshift)
            r_liu = np.fft.ifft2(r_fshift).real
            ch_arr.append(r_liu)

            g_fshift = g_fshift_ori.copy()
            g_fshift[0:half_h - int(n/2)*i - int(n/2), :] = 0
            g_fshift[half_h + int(n/2)*i + int(n/2):, :] = 0
            g_fshift[:, 0:half_h - int(n/2)*i - int(n/2)] = 0
            g_fshift[:, half_h + int(n/2)*i + int(n/2):] = 0
            g_fshift[half_w - int(n/2)*i : half_w + int(n/2)*i, half_h - int(n/2)*i : half_h + int(n/2)*i] = 0
            g_fshift = np.fft.ifftshift(g_fshift)
            g_liu = np.fft.ifft2(g_fshift).real
            ch_arr.append(g_liu)

            b_fshift = b_fshift_ori.copy()
            b_fshift[0:half_h - int(n/2)*i - int(n/2), :] = 0
            b_fshift[half_h + int(n/2)*i + int(n/2):, :] = 0
            b_fshift[:, 0:half_h - int(n/2)*i - int(n/2)] = 0
            b_fshift[:, half_h + int(n/2)*i + int(n/2):] = 0
            b_fshift[half_w - int(n/2)*i : half_w + int(n/2)*i, half_h - int(n/2)*i : half_h + int(n/2)*i] = 0
            b_fshift = np.fft.ifftshift(b_fshift)
            b_liu = np.fft.ifft2(b_fshift).real
            ch_arr.append(b_liu)
        

        ch_arr = np.array(ch_arr)
        ch_arr = ch_arr.astype(np.float32)

        return ch_arr, targets

if __name__ == '__main__':
    txt_path = "../../data/index_list/cub200/session_1.txt"
    # class_index = open(txt_path).read().splitlines()
    base_class = 100
    class_index = np.arange(base_class)
    dataroot = '~/dataloader/data'
    batch_size_base = 400
    trainset = CUB200(root=dataroot, train=False,  index=class_index,
                      base_sess=True)
    cls = np.unique(trainset.targets)
    trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size_base, shuffle=True, num_workers=8,
                                              pin_memory=True)

    