import functools
import random
import math
from PIL import Image

import numpy as np
import torch
from torch.utils.data import Dataset
# from torchvision import transforms
# from torchvision.transforms import InterpolationMode
# import torch.nn.functional as F

from datasets import register
from utils import to_pixel_samples, get_band_interval, make_band_coords, resize_fn, interpolate_bands


@register('cave-sr-implicit-downsampled-paired')
class CAVESRImplicitDownsampled(Dataset):
    '''Paired MSI/HSI dataset for implicit (arbitrary-scale) super-resolution.

    Each item is built from an (hr_msi, hr_hsi) pair returned by the wrapped
    dataset. A scale ``s`` is drawn uniformly from [scale_min, scale_max].
    The HR MSI is downsampled by ``s`` to form the network input ('inp').
    The HR HSI supplies the targets ('gt') at pixel coordinates ('coord').
    When ``band_path`` is given, the HSI may also be re-interpolated to a
    random number of bands. A subset of those bands can then be sampled for
    mini-batch training.
    '''

    def __init__(self, dataset, inp_size=None, scale_min=1, scale_max=None,
                 augment=False, sample_q=None,
                 band_path=None,
                 num_band_min=None, num_band_max=None, num_band_sample=None,
                 spec_min=None, spec_max=None):
        '''
        Args:
            dataset: the dataset object created by CAVEPairedImageFolders() in
                cave_image_folder.py; indexing it yields (hr_msi, hr_hsi).
            inp_size: int or None, side length of the low-res input crop.
                When None, the whole image is used without random cropping.
            scale_min: smallest downsampling scale.
            scale_max: largest downsampling scale; defaults to scale_min.
            augment: bool, whether to apply random horizontal/vertical flips
                and a random transpose.
            sample_q: int or None, number of (coord, value) pairs to sample
                per item; None keeps every pixel.

            band_path: path to the saved band wavelength intervals, shape
                (num_band, 2). When None, no spectral processing is done and
                'band_coord' is a dummy tensor.
            num_band_min: minimum number of target bands (defaults to num_band).
            num_band_max: maximum number of target bands (defaults to num_band).
            num_band_sample: number of bands to sample per item for
                mini-batch training; None keeps every band.

            spec_min: minimum wavelength (defaults to band_intervals[0, 0]).
            spec_max: maximum wavelength (defaults to band_intervals[-1, -1]).
        '''
        self.dataset = dataset
        self.inp_size = inp_size
        self.scale_min = scale_min
        self.scale_max = scale_min if scale_max is None else scale_max
        self.augment = augment
        self.sample_q = sample_q

        self.band_path = band_path
        self.num_band_sample = num_band_sample

        # Spectral attributes always exist; they stay None without a band file.
        self.band_intervals = None
        self.num_band = None
        self.num_band_min = None
        self.num_band_max = None
        self.spec_min = None
        self.spec_max = None

        if band_path is not None:
            # band_intervals: shape (num_band, 2), the band wavelength intervals
            self.band_intervals = np.load(band_path)
            assert len(self.band_intervals.shape) == 2 and self.band_intervals.shape[-1] == 2

            # min / max band counts used for band-interpolation augmentation
            num_band = self.band_intervals.shape[0]
            self.num_band = num_band
            self.num_band_min = num_band if num_band_min is None else num_band_min
            self.num_band_max = num_band if num_band_max is None else num_band_max

            # wavelength range used to normalise the spectral coordinates
            if spec_min is None:
                spec_min = self.band_intervals[0, 0]
            if spec_max is None:
                spec_max = self.band_intervals[-1, -1]
            assert spec_min <= spec_max
            self.spec_min = spec_min
            self.spec_max = spec_max

    def __len__(self):
        return len(self.dataset)

    def _crop_and_downsample(self, hr_msi, hr_hsi, s, H, W):
        '''Crop the HR pair and downsample the MSI crop by scale ``s``.

        Returns (crop_lr_msi, crop_hr_hsi).
        '''
        if self.inp_size is None:
            # Whole image: trim so the HR size is round(lr_size * s).
            h_lr = math.floor(H / s + 1e-9)
            w_lr = math.floor(W / s + 1e-9)
            h_hr = round(h_lr * s)
            w_hr = round(w_lr * s)
            crop_hr_hsi = hr_hsi[:, :h_hr, :w_hr]
            # crop_lr_msi: shape (c, h_lr, w_lr); crop_hr_hsi: shape (C, h_hr, w_hr)
            crop_lr_msi = resize_fn(hr_msi[:, :h_hr, :w_hr], (h_lr, w_lr))
            return crop_lr_msi, crop_hr_hsi

        # Random square patch of side w_hr = round(inp_size * s).
        w_lr = self.inp_size
        w_hr = round(w_lr * s)
        if w_hr > H or w_hr > W:
            raise ValueError(
                f'HR crop size {w_hr} (inp_size={w_lr}, scale={s:.4f}) '
                f'exceeds image size {H}x{W}')
        x0 = random.randint(0, H - w_hr)
        y0 = random.randint(0, W - w_hr)
        crop_hr_msi = hr_msi[:, x0: x0 + w_hr, y0: y0 + w_hr]
        crop_hr_hsi = hr_hsi[:, x0: x0 + w_hr, y0: y0 + w_hr]
        # crop_lr_msi: shape (c, w_lr, w_lr); crop_hr_hsi: shape (C, w_hr, w_hr)
        crop_lr_msi = resize_fn(crop_hr_msi, w_lr)
        return crop_lr_msi, crop_hr_hsi

    @staticmethod
    def _random_augment(*tensors):
        '''Apply the same random flips / transpose to every tensor given.'''
        hflip = random.random() < 0.5
        vflip = random.random() < 0.5
        dflip = random.random() < 0.5

        def augment(x):
            if hflip:
                x = x.flip(-2)
            if vflip:
                x = x.flip(-1)
            if dflip:
                x = x.transpose(-2, -1)
            return x

        return tuple(augment(x) for x in tensors)

    def _sample_num_bands(self):
        '''Uniformly sample the number of target bands.

        When num_band_min == num_band_max, that band count is always used.
        '''
        num_b = math.floor(random.uniform(self.num_band_min, self.num_band_max + 1))
        # random.uniform(a, b) may return b itself through float rounding,
        # which would give num_band_max + 1 here; clamp back into range.
        if num_b > self.num_band_max >= self.num_band_min:
            num_b = math.floor(self.num_band_max)
        return num_b

    def _make_band_coord(self, num_b):
        '''Build the (num_b, 2) float tensor of band interval coordinates.'''
        # cur_band_intervals: shape (num_b, 2), the current band intervals
        cur_band_intervals = get_band_interval(s_min=self.spec_min, s_max=self.spec_max,
                                               num_band=num_b)
        band_coord = make_band_coords(s_intervals=cur_band_intervals,
                                      s_min=self.spec_min, s_max=self.spec_max)
        return torch.from_numpy(band_coord).float()

    def _sample_band_subset(self, hr_rgb, band_coord, num_b):
        '''Keep num_band_sample (sorted) band indices of hr_rgb and band_coord.'''
        # Without replacement when enough bands exist; otherwise with
        # replacement so every item has exactly num_band_sample bands.
        band_sample_idx = np.sort(np.random.choice(
            np.arange(0, num_b), self.num_band_sample,
            replace=num_b < self.num_band_sample))
        assert hr_rgb.shape[-1] == num_b
        # hr_rgb: shape (num_points, num_band_sample)
        # band_coord: shape (num_band_sample, 2)
        return hr_rgb[:, band_sample_idx], band_coord[band_sample_idx, :]

    def __getitem__(self, idx):
        '''
        Returns a dict with:
            'inp': crop_lr_msi, the low-res MSI input, shape (c, h_lr, w_lr)
            'coord': HR pixel coordinates, shape (sample_q or h_hr * w_hr, 2)
            'cell': cell size (2/h_hr, 2/w_hr) per coordinate, same shape as 'coord'
            'gt': HR HSI values at 'coord', last dim C, num_b or num_band_sample
            'band_coord': band interval coordinates, shape (num_b or
                num_band_sample, 2), or a single 0.0 when band_path is None
        '''
        # hr_msi: shape (c, H, W); hr_hsi: shape (C, H, W)
        hr_msi, hr_hsi = self.dataset[idx]
        assert hr_msi.shape[-2:] == hr_hsi.shape[-2:]
        _, H, W = hr_msi.shape  # unpacking also enforces 3-D inputs
        C, _, _ = hr_hsi.shape

        # uniformly sample a scale
        s = random.uniform(self.scale_min, self.scale_max)
        crop_lr_msi, crop_hr_hsi = self._crop_and_downsample(hr_msi, hr_hsi, s, H, W)

        if self.augment:
            crop_lr_msi, crop_hr_hsi = self._random_augment(crop_lr_msi, crop_hr_hsi)

        if self.band_path is not None:
            assert C == self.num_band
            num_b = self._sample_num_bands()
            if num_b != self.num_band:
                # crop_hr_hsi: shape (num_b, h_hr, w_hr), band-interpolated image
                crop_hr_hsi = interpolate_bands(img=crop_hr_hsi, bands=num_b)
            band_coord = self._make_band_coord(num_b)
        else:
            band_coord = torch.tensor([0.0], dtype=torch.float32)

        # hr_coord: shape (h_hr * w_hr, 2); hr_rgb: shape (h_hr * w_hr, C or num_b)
        hr_coord, hr_rgb = to_pixel_samples(crop_hr_hsi)

        if self.sample_q is not None:
            sample_lst = np.random.choice(len(hr_coord), self.sample_q, replace=False)
            hr_coord = hr_coord[sample_lst]
            hr_rgb = hr_rgb[sample_lst]

        # cell: same shape as hr_coord, cell size (2/h_hr, 2/w_hr)
        cell = torch.ones_like(hr_coord)
        cell[:, 0] *= 2 / crop_hr_hsi.shape[-2]
        cell[:, 1] *= 2 / crop_hr_hsi.shape[-1]

        if self.band_path is not None and self.num_band_sample is not None:
            hr_rgb, band_coord = self._sample_band_subset(hr_rgb, band_coord, num_b)

        return {
            'inp': crop_lr_msi,
            'coord': hr_coord,
            'cell': cell,
            'gt': hr_rgb,
            'band_coord': band_coord
        }

