from typing import Tuple
from ._utils import *
from abc import abstractmethod

import pandas as pd
import torch
from torchvision.transforms import transforms
from torch.utils import data
from PIL import Image



class DataLoader:
    """
    A bas class generator of generators for general-purpose meta-learning regression tasks.
    """
    def __init__(self, 
                 data_path, 
                 envs_batch_size=250, 
                 envs_shuffle=True, 
                 shots_batch_size=1,
                 shots_shuffle=False, 
                 data_split="train", 
                 key=None):

        self.data_path = data_path
        self.envs_batch_size = envs_batch_size
        self.envs_shuffle = envs_shuffle

        self.adaptation = data_split in ["adapt", "test"]
        self.shots_batch_size = shots_batch_size
        self.shots_shuffle = shots_shuffle

        ## Define in the child class
        self.nb_batches = None

        if shots_batch_size <= 0 or shots_batch_size <=0 :
            raise ValueError("A batch size must be greater than 0.")

        self.key = key
        if (self.envs_shuffle or self.shots_shuffle) and self.key is None:
            raise ValueError("Shuffling the dataset requires a key.")


    @abstractmethod
    def sample_environments(self, key, batch_id, nb_envs):
        """ Provides a stateless way to sample a batch of environments """
        pass

    @abstractmethod
    def __iter__(self) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """ Loads, transforms and yields a batch of environments """
        pass

    @abstractmethod
    def __len__(self):
        """ Total number of environments / envs_batch size. """
        pass

    def __str__(self) -> str:
        return f"Dataloader properties: \n" + \
                f"Total number of environments: {len(self)} \n" + \
                f"Batch size (envs): {self.envs_batch_size} \n" + \
                f"Number of points per environment = batch size (datapoints): {self.shots_batch_size} \n" + \
                f"Input dimension: {self.input_dim} \n" + \
                f"Output dimension: {self.output_dim} \n"





class CelebADataLoader(data.DataLoader):
    """
    A celeb a dataloader for meta-learning.
    """
    def __init__(self, 
                 data_path="./data/", 
                 envs_batch_size=250, 
                 envs_shuffle=True, 
                 shots_batch_size=100,
                 shots_shuffle=False, 
                 data_split="train",
                 resolution=(32, 32),
                 order_pixels=False,
                 key=None):

        super().__init__(data_path, 
                        envs_batch_size, 
                        envs_shuffle, 
                        shots_batch_size, 
                        shots_shuffle, 
                        data_split, 
                        key)

        self.input_dim = 2
        self.output_dim = 3
        self.img_size = (*resolution, self.output_dim)
        self.order_pixels = order_pixels
        ## Read the partitioning file: train(0), val(1), test(2)

        partitions = pd.read_csv(self.data_path+'list_eval_partition.txt', 
                                 header=None, 
                                 sep=r'\s+', 
                                 names=['filename', 'partition'])
        if data_split in ["train"]:
            self.files = partitions[partitions['partition'] == 0]['filename'].values
        elif data_split in ["val"]:
            self.files = partitions[partitions['partition'] == 1]['filename'].values
        elif data_split in ["test"]:
            self.files = partitions[partitions['partition'] == 2]['filename'].values
        else:
            raise ValueError(f"Invalid data split provided. Got {data_split}")

        ## A list of MVPs images (or the worst during self-modulation) - Useful for active learning
        # self.mvp_files = self.files

        self.total_envs = len(self.files)
        if self.total_envs == 0:
            raise ValueError("No files found for the split.")
        if envs_batch_size > self.total_envs:
            raise ValueError(f"Envs batch size must be less than the total number of environments")

        self.total_pixels = self.img_size[0] * self.img_size[1]
        if shots_batch_size > self.total_pixels:
            raise ValueError(f"Few shots batch size must be less than the total number of pixels")

        ## Ssee CAVIA code: https://github.com/lmzintgraf/cavia)
        self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'),
                                            transforms.Resize((self.img_size[0], self.img_size[1]), Image.LANCZOS),
                                            transforms.ToTensor(),
                                            ])

        ## Batch bookeeping
        self.nb_batches = np.ceil(self.total_envs / self.envs_batch_size).astype(int)
        self.remainder = self.total_envs % self.envs_batch_size
        self.curr_batch_id = 0


    def get_image(self, filename) -> torch.Tensor:
        img_path = os.path.join(self.data_path+"img_align_celeba/", filename)
        img = self.transform(img_path).float()
        img = img.permute(1, 2, 0)
        return img

    def sample_pixels(self, key, img) -> Tuple[np.ndarray, jnp.ndarray]:
        total_pixels = self.img_size[0] * self.img_size[1]

        if self.order_pixels:
            flattened_indices = jnp.arange(self.shots_batch_size)
        else:
            flattened_indices = jax.random.choice(key=key, a=total_pixels, shape=(self.shots_batch_size,), replace=False)

        x, y = np.unravel_index(flattened_indices, (self.img_size[0], self.img_size[1]))
        coordinates = np.vstack((x, y)).T
        coords = torch.from_numpy(coordinates).float()
        normed_coords = (coords / torch.Tensor(self.img_size[:2])).numpy()

        pixel_values = img[coords[:, 0].long(), coords[:, 1].long(), :].numpy()

        return normed_coords, pixel_values


    def sample_environments(self, key, batch_id, nb_envs) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """ Sample a batch of environments """

        X = np.zeros((nb_envs, self.shots_batch_size, self.input_dim))
        Y = np.zeros((nb_envs, self.shots_batch_size, self.output_dim))

        if self.envs_shuffle:
            sample_idx = jax.random.choice(key=key, a=self.total_envs, shape=(nb_envs,))
            sampled_files = self.files[sample_idx]
        else:
            f_start = batch_id*self.envs_batch_size
            f_end = min([(batch_id+1)*self.envs_batch_size, self.total_envs])
            sampled_files = self.files[f_start:f_end]

        pixel_keys = jax.random.split(key, num=nb_envs)
        for e, img_name in enumerate(sampled_files):
            img = self.get_image(img_name)
            normed_coords, pixel_values = self.sample_pixels(pixel_keys[e], img)
            X[e, :, :] = normed_coords
            Y[e, :, :] = pixel_values

        return jnp.array(X), jnp.array(Y)


    def make_batch(self):

        ## Sample a batch of environments
        if self.curr_batch_id == self.nb_batches-1 and self.remainder != 0:
            X, Y = self.sample_environments(self.key, self.curr_batch_id, self.remainder)
        else:
            X, Y = self.sample_environments(self.key, self.curr_batch_id, self.envs_batch_size)

        ##  Usefull when pixels are ordered
        if self.shots_shuffle:
            X = jax.random.permutation(self.key, X, axis=1)
            Y = jax.random.permutation(self.key, Y, axis=1)

        ## Update the state of the dataloader
        self.key, _ = jax.random.split(self.key)
        self.curr_batch_id += 1

        return X, Y


    def __iter__(self):
        self.curr_batch_id = 0
        return self
        # while self.curr_batch_id < self.nb_batches:
        #     yield self.make_batch()

    def __next__(self):
        if self.curr_batch_id < self.nb_batches:
            return self.make_batch()
        else:
            raise StopIteration

    def __len__(self):
        return self.total_envs



































## Todo. It might be better to write a new DataLoader from the Torch class altogether (See NumpyLoader)
def collate_to_jax(batch):
    xs, ys = zip(*batch)
    return jnp.array(xs), jnp.array(ys)


class CelebADataset(DataLoader):
    """
    A celeb a dataloader for meta-learning.
    """
    def __init__(self, 
                 data_path="./data/", 
                 data_split="train",
                 num_shots=100,
                 resolution=(32, 32),
                 order_pixels=False,
                 seed=None):

        # ## Set seed
        # if seed is not None:
        #     np.random.seed(seed)

        if num_shots <= 0:
            raise ValueError("Number of shots must be greater than 0.")
        elif num_shots > resolution[0]*resolution[1]:
            raise ValueError("Number of shots must be less than the total number of pixels.")
        self.nb_shots = num_shots

        self.input_dim = 2
        self.output_dim = 3
        self.img_size = (*resolution, self.output_dim)
        self.order_pixels = order_pixels
        ## Read the partitioning file: train(0), val(1), test(2)

        self.data_path = data_path
        partitions = pd.read_csv(self.data_path+'list_eval_partition.txt', 
                                 header=None, 
                                 sep=r'\s+', 
                                 names=['filename', 'partition'])
        if data_split in ["train"]:
            self.files = partitions[partitions['partition'] == 0]['filename'].values
        elif data_split in ["val"]:
            self.files = partitions[partitions['partition'] == 1]['filename'].values
        elif data_split in ["test"]:
            # self.files = partitions[partitions['partition'] == 2]['filename'].values

            ## To get the translation-equivariance img in front of the test set (incl. Ellen selfie)
            self.files = partitions[(partitions['partition'] == 2) | (partitions['partition'] == 3)]['filename'].values
            self.files = np.concatenate((self.files[-1:], self.files[:-1]))

        else:
            raise ValueError(f"Invalid data split provided. Got {data_split}")

        if data_split in ["train", "val"]:
            self.adaptation = False
        elif data_split in ["test"]:
            self.adaptation = True
        else:
            raise ValueError(f"Invalid data split provided. Got {data_split}")

        ## A list of MVPs images (or the worst during self-modulation) - Useful for active learning
        # self.mvp_files = self.files

        self.total_envs = len(self.files)
        if self.total_envs == 0:
            raise ValueError("No files found for the split.")

        self.total_pixels = self.img_size[0] * self.img_size[1]

        ## Ssee CAVIA code: https://github.com/lmzintgraf/cavia)
        self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'),
                                            transforms.Resize((self.img_size[0], self.img_size[1]), Image.LANCZOS),
                                            transforms.ToTensor(),
                                            ])

    def get_image(self, filename) -> torch.Tensor:
        img_path = os.path.join(self.data_path+"img_align_celeba/", filename)
        img = self.transform(img_path).float()
        img = img.permute(1, 2, 0)
        return np.array(img)

    def sample_pixels(self, img) -> Tuple[np.ndarray, jnp.ndarray]:        ## TODO: Stay in torch throughout this function!
        total_pixels = self.img_size[0] * self.img_size[1]

        if self.order_pixels:
            flattened_indices = np.arange(self.nb_shots)
        else:
            flattened_indices = np.random.choice(total_pixels, size=self.nb_shots, replace=False)

        x, y = np.unravel_index(flattened_indices, (self.img_size[0], self.img_size[1]))
        coords = np.vstack((x, y)).T
        normed_coords = (coords / np.array(self.img_size[:2]))

        pixel_values = img[coords[:, 0], coords[:, 1], :]

        return normed_coords, pixel_values

    def set_seed_sample_pixels(self, seed, idx):
        np.random.seed(seed)
        # np.random.set_state(seed)
        img = self.get_image(self.files[idx])
        return self.sample_pixels(img)


    def __getitem__(self, idx):
        img = self.get_image(self.files[idx])
        normed_coords, pixel_values = self.sample_pixels(img)
        return normed_coords, pixel_values


    def __len__(self):
        return self.total_envs











def numpy_collate(batch):
  return jax.tree.map(np.asarray, data.default_collate(batch))

class NumpyLoader(data.DataLoader):
  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
    super(self.__class__, self).__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)
    
    self.num_batches = np.ceil(len(dataset) / batch_size).astype(int)

class FlattenAndCast(object):
  def __call__(self, pic):
    return np.ravel(np.array(pic, dtype=jnp.float32))
  
















class SinusoidDataset:
    """
    Same regression task as in Finn et al. 2017 (MAML)
    """

    # def __init__(self, meta_tain=True, support_set=True):
    def __init__(self, num_envs, num_shots, adaptation=False):

        self.num_inputs = 1
        self.num_outputs = 1

        self.amplitude_range = [0.1, 5.0]
        self.phase_range = [0, np.pi]

        self.input_range = [-5, 5]

        # if meta_tain:
        #     self.total_envs = 16
        # else:   ## Meta-test
        #     self.total_envs = 32

        # if support_set:
        #     self.num_shots = 4
        # else:   ## Query set
        #     self.num_shots = 1

        self.num_shots = num_shots
        self.total_envs = num_envs

        self.adaptation = adaptation

    def sample_inputs(self, batch_size, *args, **kwargs):
        inputs = torch.rand((batch_size, self.num_inputs))
        inputs = inputs * (self.input_range[1] - self.input_range[0]) + self.input_range[0]
        return inputs

    def sample_task(self):
        amplitude = np.random.uniform(self.amplitude_range[0], self.amplitude_range[1])
        phase = np.random.uniform(self.phase_range[0], self.phase_range[1])
        return self.get_target_function(amplitude, phase)

    @staticmethod
    def get_target_function(amplitude, phase):
        def target_function(x):
            if isinstance(x, torch.Tensor):
                return torch.sin(x - phase) * amplitude
            else:
                return np.sin(x - phase) * amplitude

        return target_function

    def sample_tasks(self, num_tasks, return_specs=False):

        amplitude = np.random.uniform(self.amplitude_range[0], self.amplitude_range[1], num_tasks)
        phase = np.random.uniform(self.phase_range[0], self.phase_range[1], num_tasks)

        target_functions = []
        for i in range(num_tasks):
            target_functions.append(self.get_target_function(amplitude[i], phase[i]))

        if return_specs:
            return target_functions, amplitude, phase
        else:
            return target_functions

    def __getitem__(self, idx):     ## TODO Idx doesn't matter here. Check effect of seed as well.
        if not self.adaptation:
            np.random.seed(idx)
            # torch.manual_seed(idx)
        else:
            np.random.seed(np.iinfo(np.int32).max - idx)
            # torch.manual_seed(np.iinfo(np.int32).max - idx)

        target_func = self.sample_tasks(1, return_specs=False)[0]
        inputs = self.sample_inputs(self.num_shots)
        return inputs, target_func(inputs)

    def __len__(self):
        return self.total_envs








class DynamicsDataset:
    """
    For all dynamics tasks as in Kirchmeyer et al. 2022
    """

    # def __init__(self, meta_tain=True, support_set=True):
    def __init__(self, data_dir, num_shots=-1, skip_steps=1, adaptation=False):

        self.data_dir = data_dir
        self.skip_steps = skip_steps
        # self.adaptation = data_dir.find("adapt") != -1 or data_dir.find("ood") != -1
        self.adaptation = adaptation

        try:
            raw_data = np.load(data_dir)
        except:
            raise ValueError(f"Data not found at {data_dir}")

        self.dataset, self.t_eval = raw_data['X'][...,::self.skip_steps,:], raw_data['t'][::skip_steps]

        datashape = self.dataset.shape
        self.total_envs = datashape[0]

        if num_shots is None or num_shots == -1:
            num_shots = datashape[1]
        self.num_shots = num_shots
        if num_shots > datashape[1]:
            raise ValueError("Number of shots must be less than the total number of trajectories")

        self.num_steps = datashape[2]
        self.data_size = datashape[3]


    def __getitem__(self, idx):     ## Idx doesn't matter here
        inputs = self.dataset[idx, :, 0, :]
        outputs = self.dataset[idx, :, :, :]
        return inputs, outputs

    def __len__(self):
        return self.total_envs
