import functools
import random
import math
from PIL import Image

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode 
import torch.nn.functional as F

from datasets import register
from utils import to_pixel_samples, get_band_interval, make_band_coords, resize_fn, interpolate_bands









@register('pavia-centra-sr-implicit-downsampled-paired')
class PaviaCentraSRImplicitDownsampled(Dataset):
    '''Paired MSI -> HSI super-resolution dataset for Pavia Centre.

    Each item takes a paired high-res (MSI, HSI) sample from ``dataset``, crops
    it, downsamples the MSI crop with ``resize_fn`` to form the low-res input,
    and turns the high-res HSI crop into (coordinate, value) pairs for implicit
    (coordinate-based) training. When ``band_path`` is given, the number of
    target bands can also be randomised and spectral band coordinates are
    returned alongside the spatial ones.
    '''

    # Size (rows, columns) of the upper-left region of the scene that is held
    # out for testing. In training mode with ``inp_size`` set, random crops
    # never overlap this region.
    TEST_REGION_ROWS = 1024
    TEST_REGION_COLS = 128

    def __init__(self, dataset, is_train=True, inp_size=None, scale_min=1, scale_max=None,
                 augment=False, sample_q=None,
                 band_path=None,
                 num_band_min=None, num_band_max=None, num_band_sample=None,
                 spec_min=None, spec_max=None):
        '''
        Args:
            dataset: paired dataset whose items are ``(hr_msi, hr_hsi)`` tensors of
                shape ``(c, H, W)`` and ``(C, H, W)`` with the same spatial size
            is_train: bool, when True (and ``inp_size`` is set) random crops avoid
                the held-out test region TEST_REGION_ROWS x TEST_REGION_COLS
            inp_size: int, side length of the low-res input crop; None means the
                whole image is used
            scale_min: smallest downsampling scale
            scale_max: largest downsampling scale (defaults to ``scale_min``)
            augment: bool, whether to apply random flips / transpose to the crops
            sample_q: int, number of (coord, value) pairs to sample per item;
                None keeps every pixel

            band_path: path to an array (loaded with ``np.load``) of band
                wavelength intervals, shape (num_band, 2)
            num_band_min: minimum number of target bands (defaults to num_band)
            num_band_max: maximum number of target bands (defaults to num_band)
            num_band_sample: number of bands to sample per item so items can be
                stacked into mini-batches

            spec_min: minimum wavelength (defaults to band_intervals[0, 0])
            spec_max: maximum wavelength (defaults to band_intervals[-1, -1])
        '''
        self.dataset = dataset
        self.inp_size = inp_size
        self.scale_min = scale_min
        self.is_train = is_train
        if scale_max is None:
            scale_max = scale_min
        self.scale_max = scale_max
        self.augment = augment
        self.sample_q = sample_q

        self.band_path = band_path
        self.num_band_sample = num_band_sample
        if self.band_path is not None:
            # band_intervals: shape (num_band, 2), the band wavelength intervals
            self.band_intervals = np.load(band_path)

            assert len(self.band_intervals.shape) == 2 and self.band_intervals.shape[-1] == 2

            # min / max number of target bands used for band interpolation
            num_band = self.band_intervals.shape[0]
            if num_band_max is None:
                num_band_max = num_band
            if num_band_min is None:
                num_band_min = num_band
            self.num_band_min = num_band_min
            self.num_band_max = num_band_max
            self.num_band = num_band

            # min / max wavelength, used to normalise spectral coordinates
            if spec_min is None:
                spec_min = self.band_intervals[0, 0]
            if spec_max is None:
                spec_max = self.band_intervals[-1, -1]
            assert spec_min <= spec_max
            self.spec_min = spec_min
            self.spec_max = spec_max

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        '''Build one sample.

        Returns a dict with:
            'inp': crop_lr_msi, shape (c, h_lr, w_lr), the low-res MSI input
            'coord': shape (sample_q or h_hr * w_hr, 2), HR pixel coordinates
            'cell': same shape as 'coord', HR cell size (2/h_hr, 2/w_hr)
            'gt': shape (sample_q or h_hr * w_hr, C / num_b / num_band_sample),
                HR HSI values at 'coord'
            'band_coord': shape (num_b or num_band_sample, 2), spectral band
                coordinates, or a single 0 when ``band_path`` is None
        '''
        # hr_msi: shape (c, H, W); hr_hsi: shape (C, H, W)
        hr_msi, hr_hsi = self.dataset[idx]
        assert hr_msi.shape[-2:] == hr_hsi.shape[-2:]
        # Unpacking also checks that both tensors are 3-D.
        _, H, W = hr_msi.shape
        C, _, _ = hr_hsi.shape

        # Uniformly sample a downsampling scale.
        s = random.uniform(self.scale_min, self.scale_max)

        crop_lr_msi, crop_hr_hsi = self._crop_pair(hr_msi, hr_hsi, s, H, W)

        if self.augment:
            crop_lr_msi, crop_hr_hsi = self._augment_pair(crop_lr_msi, crop_hr_hsi)

        if self.band_path is not None:
            assert C == self.num_band

            num_b = self._sample_num_bands()
            if num_b != self.num_band:
                # crop_hr_hsi: shape (num_b, h_hr, w_hr), band-interpolated image
                crop_hr_hsi = interpolate_bands(img=crop_hr_hsi, bands=num_b)

            # cur_band_intervals: shape (num_b, 2), intervals for the current band count
            cur_band_intervals = get_band_interval(s_min=self.spec_min, s_max=self.spec_max,
                                                   num_band=num_b)
            # band_coord: shape (num_b, 2), band interval coordinates
            band_coord = make_band_coords(s_intervals=cur_band_intervals,
                                          s_min=self.spec_min, s_max=self.spec_max)
            band_coord = torch.from_numpy(band_coord).float()
        else:
            band_coord = torch.zeros(1, dtype=torch.float32)

        # hr_coord: shape (h_hr * w_hr, 2); hr_rgb: shape (h_hr * w_hr, C or num_b)
        hr_coord, hr_rgb = to_pixel_samples(crop_hr_hsi)

        if self.sample_q is not None:
            sample_lst = np.random.choice(len(hr_coord), self.sample_q, replace=False)
            hr_coord = hr_coord[sample_lst]
            hr_rgb = hr_rgb[sample_lst]

        # cell: same shape as hr_coord, size of one HR pixel, (2/h, 2/w)
        # (presumably coordinates span a length-2 range such as [-1, 1]; see to_pixel_samples)
        cell = torch.ones_like(hr_coord)
        cell[:, 0] *= 2 / crop_hr_hsi.shape[-2]
        cell[:, 1] *= 2 / crop_hr_hsi.shape[-1]

        if self.band_path is not None and self.num_band_sample is not None:
            hr_rgb, band_coord = self._subsample_bands(hr_rgb, band_coord, num_b)

        return {
            'inp': crop_lr_msi,
            'coord': hr_coord,
            'cell': cell,
            'gt': hr_rgb,
            'band_coord': band_coord
        }

    def _crop_pair(self, hr_msi, hr_hsi, s, H, W):
        '''Crop the HR pair and downsample the MSI crop by scale ``s``.

        Returns:
            crop_lr_msi: shape (c, h_lr, w_lr)
            crop_hr_hsi: shape (C, round(h_lr * s), round(w_lr * s))
        '''
        if self.inp_size is None:
            # Whole image, trimmed so the HR size is round(lr_size * s).
            h_lr = math.floor(H / s + 1e-9)
            w_lr = math.floor(W / s + 1e-9)
            h_hr = round(h_lr * s)
            w_hr = round(w_lr * s)
            hr_msi = hr_msi[:, :h_hr, :w_hr]
            hr_hsi = hr_hsi[:, :h_hr, :w_hr]
            return resize_fn(hr_msi, (h_lr, w_lr)), hr_hsi

        # Square crop: w_lr x w_lr at low res, w_hr x w_hr at high res.
        w_lr = self.inp_size
        w_hr = round(w_lr * s)
        x0, y0 = self._sample_crop_origin(H, W, w_hr)
        crop_hr_msi = hr_msi[:, x0: x0 + w_hr, y0: y0 + w_hr]
        crop_hr_hsi = hr_hsi[:, x0: x0 + w_hr, y0: y0 + w_hr]
        return resize_fn(crop_hr_msi, w_lr), crop_hr_hsi

    def _sample_crop_origin(self, H, W, w_hr):
        '''Sample the top-left corner (x0, y0) of a w_hr x w_hr crop.

        In training mode, a crop that starts within the first TEST_REGION_ROWS
        rows is forced to start at column TEST_REGION_COLS or later, so the crop
        never overlaps the held-out upper-left region. This requires
        ``W - w_hr >= TEST_REGION_COLS``; otherwise ``random.randint`` raises
        ValueError.
        '''
        x0 = random.randint(0, H - w_hr)
        if self.is_train and x0 < self.TEST_REGION_ROWS:
            y0 = random.randint(self.TEST_REGION_COLS, W - w_hr)
        else:
            y0 = random.randint(0, W - w_hr)
        return x0, y0

    @staticmethod
    def _augment_pair(lr, hr):
        '''Apply the same random flips / transpose (on the last two dims) to both tensors.'''
        hflip = random.random() < 0.5
        vflip = random.random() < 0.5
        dflip = random.random() < 0.5

        def augment(x):
            if hflip:
                x = x.flip(-2)
            if vflip:
                x = x.flip(-1)
            if dflip:
                # swap the two spatial axes (transpose, not a flip)
                x = x.transpose(-2, -1)
            return x

        return augment(lr), augment(hr)

    def _sample_num_bands(self):
        '''Uniformly sample an integer band count in [num_band_min, num_band_max].

        When num_band_min == num_band_max, that value is always returned.
        '''
        num_b = math.floor(random.uniform(self.num_band_min, self.num_band_max + 1))
        # random.uniform(a, b) may return b itself because of floating-point
        # rounding, which would give num_band_max + 1. Clamp to the upper end
        # of the sampling range.
        upper = max(self.num_band_min, self.num_band_max)
        return min(num_b, upper)

    def _subsample_bands(self, hr_rgb, band_coord, num_b):
        '''Pick ``num_band_sample`` band indices (sorted) so items have a fixed band count.

        Sampling is without replacement when enough bands exist, otherwise with
        replacement.

        Returns:
            hr_rgb: shape (num_pixels, num_band_sample)
            band_coord: shape (num_band_sample, 2)
        '''
        replace = num_b < self.num_band_sample
        band_sample_idx = np.random.choice(np.arange(0, num_b), self.num_band_sample,
                                           replace=replace)
        band_sample_idx = np.sort(band_sample_idx)
        assert hr_rgb.shape[-1] == num_b
        return hr_rgb[:, band_sample_idx], band_coord[band_sample_idx, :]