import os
import os.path as osp

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from jpeg2dct.numpy import load, loads
class MiniImageNet(Dataset):
    """miniImageNet split whose samples are spatial-frequency band decompositions.

    Each item is ``(bands, target)`` where ``bands`` is a float32 tensor of
    shape ``(NUM_BANDS * 3, H, W)``: for every band, the R, G and B channels
    are masked in the centred 2-D Fourier domain and transformed back.
    """

    # Default data locations. The original code joined these absolute paths
    # onto ``root``, which makes ``os.path.join`` discard ``root`` entirely;
    # use ``image_path`` / ``split_path`` to point elsewhere.
    DEFAULT_IMAGE_PATH = '/home/main/datasets/miniimagenet/images'
    DEFAULT_SPLIT_PATH = '/home/main/datasets/miniimagenet/split'

    # Number of frequency bands produced per channel, and the width whose
    # half (``BAND_WIDTH // 2``) pads each band's outer window.
    NUM_BANDS = 2
    BAND_WIDTH = 42

    def __init__(self, root='', train=True,
                 transform=None,
                 index_path=None, index=None, base_sess=None,
                 image_path=None, split_path=None):
        """Read the split CSV and keep the samples for the requested session.

        Args:
            root: Stored (user-expanded) as ``self.root``; it does not affect
                where images or split files are read from.
            train: If True read ``train.csv`` and use random augmentation,
                otherwise read ``test.csv`` and only resize.
            transform: Accepted for API compatibility but ignored. A fixed
                PIL-level transform is always installed because
                ``__getitem__`` converts its output with ``np.array`` and
                expects an HWC image.
            index_path: Text file listing the images of an incremental session
                (used when ``train`` is True and ``base_sess`` is falsy).
            index: Iterable of class labels to keep (base session and test).
            base_sess: When truthy in training mode, select by ``index``
                instead of ``index_path``.
            image_path: Optional image directory (defaults to
                ``DEFAULT_IMAGE_PATH``).
            split_path: Optional directory containing ``train.csv`` /
                ``test.csv`` (defaults to ``DEFAULT_SPLIT_PATH``).
        """
        setname = 'train' if train else 'test'
        self.root = os.path.expanduser(root)
        self.train = train  # training set or test set
        self.IMAGE_PATH = (os.path.expanduser(image_path) if image_path is not None
                           else self.DEFAULT_IMAGE_PATH)
        self.SPLIT_PATH = (os.path.expanduser(split_path) if split_path is not None
                           else self.DEFAULT_SPLIT_PATH)

        csv_path = osp.join(self.SPLIT_PATH, setname + '.csv')
        with open(csv_path, 'r') as f:
            lines = [x.strip() for x in f.readlines()][1:]  # drop header row

        self.data = []
        self.targets = []
        self.data2label = {}
        self.wnids = []
        wnid2label = {}

        for line in lines:
            if not line:
                continue  # tolerate blank/trailing lines
            name, wnid = line.split(',')
            path = osp.join(self.IMAGE_PATH, name)
            label = wnid2label.get(wnid)
            if label is None:
                # Labels are assigned in order of first appearance of each
                # class id; each row gets the label of its *own* class.
                label = len(self.wnids)
                wnid2label[wnid] = label
                self.wnids.append(wnid)
            self.data.append(path)
            self.targets.append(label)
            self.data2label[path] = label

        self.image_size = 84
        if train:
            self.transform = transforms.Compose([
                transforms.Resize([92, 92]),
                transforms.RandomResizedCrop(self.image_size),
                transforms.RandomHorizontalFlip()])
            if base_sess:
                self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index)
            else:
                self.data, self.targets = self.SelectfromTxt(self.data2label, index_path)
        else:
            self.transform = transforms.Compose([
                transforms.Resize([self.image_size, self.image_size])])
            self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index)

    def SelectfromTxt(self, data2label, index_path):
        """Select the images listed in ``index_path``.

        Each non-blank line is a ``/``-separated path whose last component is
        an image file name under ``self.IMAGE_PATH``.

        Returns:
            ``(paths, labels)`` lists in file order; labels come from
            ``data2label`` (a ``KeyError`` means the image is not in the split).
        """
        with open(index_path, 'r') as f:
            names = [line.strip().rsplit('/', 1)[-1] for line in f if line.strip()]
        data_tmp = [os.path.join(self.IMAGE_PATH, name) for name in names]
        targets_tmp = [data2label[path] for path in data_tmp]
        return data_tmp, targets_tmp

    def SelectfromClasses(self, data, targets, index):
        """Keep the samples whose label is in ``index``, grouped by ``index`` order.

        ``targets`` may be a list or array and ``index`` may contain Python or
        numpy integers: the comparison is done on a numpy array so it is
        element-wise in both cases.
        """
        targets_arr = np.asarray(targets)
        data_tmp = []
        targets_tmp = []
        for cls in index:
            for j in np.flatnonzero(targets_arr == cls):
                data_tmp.append(data[j])
                targets_tmp.append(targets[j])
        return data_tmp, targets_tmp

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.data)

    @staticmethod
    def _band_bounds(size, band, num_bands, half_width):
        """Return ``(outer_lo, outer_hi, inner_lo, inner_hi)`` slice bounds on one axis."""
        center = size // 2
        inner = size // (2 * num_bands) * band
        # Clamp: a negative lower bound would wrap around when slicing.
        outer_lo = max(center - inner - half_width, 0)
        outer_hi = min(center + inner + half_width, size)
        return outer_lo, outer_hi, center - inner, center + inner

    def _split_frequency_bands(self, image):
        """Decompose a ``(C, H, W)`` image into ``NUM_BANDS`` frequency bands.

        With ``step = size // (2 * NUM_BANDS)`` per axis, band ``j`` keeps the
        centred-spectrum square annulus between an inner half-size
        ``step * j`` and an outer half-size ``step * j + BAND_WIDTH // 2``
        (clipped to the spectrum). For 84x84 inputs and the defaults, band 0
        keeps the central 42x42 window and band 1 keeps everything outside it.

        Returns:
            float32 tensor of shape ``(NUM_BANDS * C, H, W)``, band-major:
            all channels of band 0, then all channels of band 1, ...
        """
        rows, cols = image.shape[-2:]
        half_width = self.BAND_WIDTH // 2
        # Shift only the spatial dims; the default (dim=None) would also roll
        # the channel axis.
        spectrum = torch.fft.fftshift(torch.fft.fft2(image), dim=(-2, -1))
        bands = []
        for j in range(self.NUM_BANDS):
            r_lo, r_hi, r_in_lo, r_in_hi = self._band_bounds(rows, j, self.NUM_BANDS, half_width)
            c_lo, c_hi, c_in_lo, c_in_hi = self._band_bounds(cols, j, self.NUM_BANDS, half_width)
            masked = spectrum.clone()
            masked[..., :r_lo, :] = 0
            masked[..., r_hi:, :] = 0
            masked[..., :, :c_lo] = 0
            masked[..., :, c_hi:] = 0
            masked[..., r_in_lo:r_in_hi, c_in_lo:c_in_hi] = 0
            band = torch.fft.ifft2(torch.fft.ifftshift(masked, dim=(-2, -1))).real
            bands.append(band)
        return torch.cat(bands, dim=0).to(torch.float32)

    def __getitem__(self, i):
        """Return ``(bands, target)`` for sample ``i``.

        The image is loaded as RGB, passed through ``self.transform`` and
        turned into an un-normalised float32 ``(3, H, W)`` tensor before
        frequency decomposition.
        """
        path, target = self.data[i], self.targets[i]
        with Image.open(path) as img:
            image = self.transform(img.convert('RGB'))
        # np.array of the PIL image is HWC; move channels first for the FFT.
        image = torch.as_tensor(np.array(image), dtype=torch.float32).permute(2, 0, 1)
        return self._split_frequency_bands(image), target
