import numpy as np
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
from scipy import io as sio
import random

def shift(inputs, step=1):
    """Shear a (row, col, nC) cube so that band ``i`` moves ``step * i`` columns right.

    The column axis (dim 1) is first zero-padded on the right by
    ``(nC - 1) * step`` columns. Each band is then rolled along that axis, so
    for ``step >= 0`` the result has shape ``(row, col + (nC - 1) * step, nC)``
    and no band wraps around.

    Args:
        inputs: 3-D tensor indexed as (row, col, band).
        step: column displacement between consecutive bands.

    Returns:
        New tensor holding the sheared cube; ``inputs`` is not modified.
    """
    _, _, n_bands = inputs.shape
    padded = F.pad(inputs, (0, 0, 0, (n_bands - 1) * step, 0, 0), "constant", 0)

    sheared = torch.zeros_like(padded)
    for band in range(n_bands):
        # Roll within the padded width; the padding zeros rotate to the front.
        sheared[..., band] = padded[..., band].roll(band * step, dims=-1)
    return sheared

class dataset(Dataset):
    """Randomly sampled simulated coded-aperture training patches.

    Each item picks a random scene from ``KAIST``, crops a random
    ``size x size`` patch, and applies a random 90-degree rotation and random
    flips. The patch is multiplied by ``response`` and a noisy measurement is
    simulated through a random crop of the mask ``Phi`` (see ``__getitem__``).

    Args:
        KAIST: indexable collection of scene tensors, each indexed as
            (rows, cols, bands) by the slicing below (TODO confirm with caller).
        Phi: mask tensor indexed as (bands, rows, cols); it is stored permuted
            to (rows, cols, bands). The number of bands must match the scenes.
        response: per-band weights multiplied into the (size, size, bands)
            patch (presumably a spectral response curve -- verify).
        num_per_epoch: length reported by ``__len__``.
        size: spatial side length of each square patch.
    """

    def __init__(self, KAIST, Phi, response, num_per_epoch, size=128):
        super().__init__()

        self.KAIST = KAIST
        # (bands, rows, cols) -> (rows, cols, bands) to match scene indexing.
        # NOTE: this is a view that shares storage with the caller's tensor.
        self.Phi = Phi.permute(1, 2, 0)
        self.response = response
        self.num = num_per_epoch
        self.size = size

    def __getitem__(self, index):
        """Draw and return one freshly simulated training example.

        ``index`` is ignored. Scene, crops and augmentation come from the
        global ``random`` module, and measurement noise comes from torch's
        RNG, so every call yields a new sample.

        Returns:
            gt_chroma: (bands, size, size) patch divided by its per-pixel sum
                over bands (plus 1e-6) and multiplied by the band count.
            Phi: (bands, size, size + bands - 1) copy of the mask crop. Band
                ``i`` is multiplied by ``pan`` in columns ``i:i+size``.
            PhiPhi_T: (size, size + bands - 1) sum over bands of ``Phi ** 2``,
                with zero entries replaced by 1.
            mea: (size, size + bands - 1) noisy simulated measurement.
        """
        size = self.size
        n_bands = self.Phi.shape[-1]
        # shift(..., 1) widens the cube by one column per extra band, so the
        # mask crop needs that many extra columns.
        margin = n_bands - 1

        scene = self.KAIST[random.randint(0, len(self.KAIST) - 1)]

        px = random.randint(0, scene.shape[0] - size)
        py = random.randint(0, scene.shape[1] - size)
        pxm = random.randint(0, self.Phi.shape[0] - size)
        pym = random.randint(0, self.Phi.shape[1] - size - margin)
        label = scene[px:px + size, py:py + size, :]
        Phi = self.Phi[pxm:pxm + size, pym:pym + size + margin, :]

        rot_times = random.randint(0, 3)
        v_flip = random.randint(0, 1)
        h_flip = random.randint(0, 1)

        # Only the scene patch is augmented; the mask crop is not rotated or
        # flipped.
        label = torch.rot90(label, rot_times)
        if v_flip:
            label = torch.flipud(label)
        if h_flip:
            label = torch.fliplr(label)

        label = label * self.response
        temp = Phi * shift(label, 1)

        mea = torch.sum(temp, dim=-1)
        mea = mea / n_bands * 2 * 1.2

        # Noise model: sample Binomial(mea * bit / QE, QE) and rescale by
        # 1 / bit (presumably photon shot noise with quantum efficiency QE --
        # confirm).
        QE, bit = 0.4, 2**11
        mea = torch.binomial(mea * bit / QE, QE * torch.ones_like(mea)) / bit
        # Undo the scaling applied before sampling.
        mea = mea * n_bands / 2 / 1.2

        label = label.permute(2, 0, 1)
        # Slicing and permute only produce views of self.Phi (and thus of the
        # caller's mask). Copy before the in-place weighting below; otherwise
        # every call would permanently rescale the shared mask.
        Phi = Phi.permute(2, 0, 1).clone()

        pan = torch.sum(label, dim=0) + 1e-6
        gt_chroma = label / pan * n_bands
        pan = pan / n_bands

        # After the one-column-per-band shift, band i occupies columns i:i+size.
        for i in range(n_bands):
            Phi[i, :, i:i + size] = Phi[i, :, i:i + size] * pan

        PhiPhi_T = torch.sum(Phi**2, dim=0)
        # Replace zeros with 1 (presumably so PhiPhi_T can be used as a divisor
        # downstream -- confirm).
        PhiPhi_T[PhiPhi_T == 0] = 1

        return gt_chroma, Phi, PhiPhi_T, mea

    def __len__(self):
        """Return the nominal epoch length given at construction."""
        return self.num