import os
import os.path as osp
import random

import numpy as np
import torch
from PIL import Image, ImageFilter
from torch.utils.data import Dataset
from torchvision import transforms


class CUB200(Dataset):

    def __init__(self, root='./', train=True, index_path=None, index=None, base_sess=None, two_images=False, validation=False):
        self.root = os.path.expanduser(root)
        self.train = train  # training set or test set
        self._pre_operate(self.root)

        image_size = 224
        train_transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomResizedCrop(image_size, scale=(0.2, 1.)),
            transforms.RandomHorizontalFlip()
            #transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),  # not strengthened
        ])

        train_transform_v2 = transforms.Compose([
            transforms.Resize(256),
            # transforms.CenterCrop(224),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip()
        ])

        transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224)
        ])

        base_transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224)
        ])

        base_transforms_v2 = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224)
        ])


        if train:
            if validation:
                print('---- CUB200 Base Transform ---')
                self.transform = base_transforms
            elif not base_sess:
                print('---- CUB200 Base Transform ---')
                self.transform = base_transforms
            else:
                if two_images:
                    print('---- CUB200 TwoCrops Training Transform ---')
                    self.transform = TwoCropsTransform(train_transforms)
                else:
                    print('---- CUB200 OneCrops Training Transform ---')
                    self.transform = train_transforms

            if base_sess:
                self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index)
            else:
                self.data, self.targets = self.SelectfromTxt(self.data2label, index_path)
        else:
            if validation:
                print('---- CUB200 Base Transform ---')
                self.transform = base_transforms
            else:
                print('---- CUB200 Testing Transform ---')
                self.transform = transform_test

            self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index)

    def text_read(self, file):
        with open(file, 'r') as f:
            lines = f.readlines()
            for i, line in enumerate(lines):
                lines[i] = line.strip('\n')
        return lines

    def list2dict(self, list):
        dict = {}
        for l in list:
            s = l.split(' ')
            id = int(s[0])
            cls = s[1]
            if id not in dict.keys():
                dict[id] = cls
            else:
                raise EOFError('The same ID can only appear once')
        return dict

    def _pre_operate(self, root):
        image_file = os.path.join(root, 'CUB_200_2011/images.txt')
        split_file = os.path.join(root, 'CUB_200_2011/train_test_split.txt')
        class_file = os.path.join(root, 'CUB_200_2011/image_class_labels.txt')
        id2image = self.list2dict(self.text_read(image_file))
        id2train = self.list2dict(self.text_read(split_file))  # 1: train images; 0: test iamges
        id2class = self.list2dict(self.text_read(class_file))
        train_idx = []
        test_idx = []
        for k in sorted(id2train.keys()):
            if id2train[k] == '1':
                train_idx.append(k)
            else:
                test_idx.append(k)

        self.data = []
        self.targets = []
        self.data2label = {}
        if self.train:
            for k in train_idx:
                image_path = os.path.join(root, 'CUB_200_2011/images', id2image[k])
                self.data.append(image_path)
                self.targets.append(int(id2class[k]) - 1)
                self.data2label[image_path] = (int(id2class[k]) - 1)

        else:
            for k in test_idx:
                image_path = os.path.join(root, 'CUB_200_2011/images', id2image[k])
                self.data.append(image_path)
                self.targets.append(int(id2class[k]) - 1)
                self.data2label[image_path] = (int(id2class[k]) - 1)

    def SelectfromTxt(self, data2label, index_path):
        data_tmp = []
        targets_tmp = []

        for i in range(len(index_path)):
            index = open(index_path[i]).read().splitlines()
            for i in index:
                img_path = os.path.join(self.root, i)
                data_tmp.append(img_path)
                targets_tmp.append(data2label[img_path])

        return data_tmp, targets_tmp

    def SelectfromClasses(self, data, targets, index):
        data_tmp = []
        targets_tmp = []
        for i in index:
            ind_cl = np.where(i == targets)[0]
            for j in ind_cl:
                data_tmp.append(data[j])
                targets_tmp.append(targets[j])

        return data_tmp, targets_tmp

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        path, targets = self.data[i], self.targets[i]
        image = self.transform(Image.open(path).convert('RGB'))
        image = np.array(image)
        n = 112
        c = 2
        image_size = 224
        if len(image) == 2:
            output = []
            for img in image:
                img = np.array(img)
                image = torch.Tensor(img)
            
                image = image.permute(2,0,1)
                #print(image.shape)
                r_img = image[0]
                g_img = image[1]
                b_img = image[2]
                
                # R
                r_f = torch.fft.fft2(r_img)
                r_fshift = torch.fft.fftshift(r_f)

                # G
                g_f = torch.fft.fft2(g_img)
                g_fshift = torch.fft.fftshift(g_f)

                # B
                b_f = torch.fft.fft2(b_img)
                b_fshift = torch.fft.fftshift(b_f)

                (w, h) = r_fshift.shape
                half_w, half_h = int(w/2), int(h/2)
                
                
                ch_arr = []
                r_fshift_ori = r_fshift.clone()
                g_fshift_ori = g_fshift.clone()
                b_fshift_ori = b_fshift.clone()

                for j in range(c):
                    filter = int(image_size/c/2)*j
                    r_filter = half_h - filter - int(n/2)
                    l_filter = half_h + filter + int(n/2)
                    if r_filter < 0:
                        r_filter = 0
                    if l_filter > image_size:
                        l_filter = 84
                    r_fshift = r_fshift_ori.clone()
                    r_fshift[0:r_filter, :] = 0
                    r_fshift[l_filter:, :] = 0
                    r_fshift[:, 0:r_filter] = 0
                    r_fshift[:, l_filter:] = 0
                    r_fshift[half_w - filter : half_w + filter, half_h - filter : half_h + filter] = 0
                    r_fshift = torch.fft.ifftshift(r_fshift)
                    r_liu = torch.fft.ifft2(r_fshift).real
                    ch_arr.append(r_liu)

                    g_fshift = g_fshift_ori.clone()
                    g_fshift[0:r_filter, :] = 0
                    g_fshift[l_filter:, :] = 0
                    g_fshift[:, 0:r_filter] = 0
                    g_fshift[:, l_filter:] = 0
                    g_fshift[half_w - filter : half_w + filter, half_h - filter : half_h + filter] = 0
                    g_fshift = torch.fft.ifftshift(g_fshift)
                    g_liu = torch.fft.ifft2(g_fshift).real
                    ch_arr.append(g_liu)

                    b_fshift = b_fshift_ori.clone()
                    b_fshift[0:r_filter, :] = 0
                    b_fshift[l_filter:, :] = 0
                    b_fshift[:, 0:r_filter] = 0
                    b_fshift[:, l_filter:] = 0
                    b_fshift[half_w - filter : half_w + filter, half_h - filter : half_h + filter] = 0
                    b_fshift = torch.fft.ifftshift(b_fshift)
                    b_liu = torch.fft.ifft2(b_fshift).real
                    ch_arr.append(b_liu)
                
                ch_arr = torch.stack(ch_arr, dim = 0)
                ch_arr = ch_arr.type(torch.float32)
                output.append(ch_arr)
        else:
            image = torch.Tensor(image)
            
            image = image.permute(2,0,1)
            r_img = image[0]
            g_img = image[1]
            b_img = image[2]
            
            # R
            r_f = torch.fft.fft2(r_img)
            r_fshift = torch.fft.fftshift(r_f)

            # G
            g_f = torch.fft.fft2(g_img)
            g_fshift = torch.fft.fftshift(g_f)

            # B
            b_f = torch.fft.fft2(b_img)
            b_fshift = torch.fft.fftshift(b_f)

            (w, h) = r_fshift.shape
            half_w, half_h = int(w/2), int(h/2)
                
            ch_arr = []
            r_fshift_ori = r_fshift.clone()
            g_fshift_ori = g_fshift.clone()
            b_fshift_ori = b_fshift.clone()

            for j in range(c):
                filter = int(image_size/c/2)*j
                r_filter = half_h - filter - int(n/2)
                l_filter = half_h + filter + int(n/2)
                if r_filter < 0:
                    r_filter = 0
                if l_filter > image_size:
                    l_filter = 84
                r_fshift = r_fshift_ori.clone()
                r_fshift[0:r_filter, :] = 0
                r_fshift[l_filter:, :] = 0
                r_fshift[:, 0:r_filter] = 0
                r_fshift[:, l_filter:] = 0
                r_fshift[half_w - filter : half_w + filter, half_h - filter : half_h + filter] = 0
                r_fshift = torch.fft.ifftshift(r_fshift)
                r_liu = torch.fft.ifft2(r_fshift).real
                ch_arr.append(r_liu)

                g_fshift = g_fshift_ori.clone()
                g_fshift[0:r_filter, :] = 0
                g_fshift[l_filter:, :] = 0
                g_fshift[:, 0:r_filter] = 0
                g_fshift[:, l_filter:] = 0
                g_fshift[half_w - filter : half_w + filter, half_h - filter : half_h + filter] = 0
                g_fshift = torch.fft.ifftshift(g_fshift)
                g_liu = torch.fft.ifft2(g_fshift).real
                ch_arr.append(g_liu)

                b_fshift = b_fshift_ori.clone()
                b_fshift[0:r_filter, :] = 0
                b_fshift[l_filter:, :] = 0
                b_fshift[:, 0:r_filter] = 0
                b_fshift[:, l_filter:] = 0
                b_fshift[half_w - filter : half_w + filter, half_h - filter : half_h + filter] = 0
                b_fshift = torch.fft.ifftshift(b_fshift)
                b_liu = torch.fft.ifft2(b_fshift).real
                ch_arr.append(b_liu)
                
            ch_arr = torch.stack(ch_arr, dim = 0)
            output = ch_arr.type(torch.float32)
        
        return output, targets

class TwoCropsTransform:
    """Produce two independently transformed views of the same input.

    ``base_transform`` is called twice on the input, and the results are
    returned as ``[query, key]``.
    """

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        # First result is the query view, second the key view.
        return [self.base_transform(x) for _ in range(2)]


class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    Args:
        sigma: ``[min, max]`` range. Each call blurs with a radius drawn
            uniformly from this range. Defaults to ``[0.1, 2.0]``.
    """

    def __init__(self, sigma=None):
        # Build the default per instance so instances never share one list.
        self.sigma = [.1, 2.] if sigma is None else sigma

    def __call__(self, x):
        """Return ``x`` blurred with a random radius from ``self.sigma``.

        ``x`` is expected to provide PIL's ``Image.filter`` method.
        """
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        return x.filter(ImageFilter.GaussianBlur(radius=sigma))


