import itertools
import os
import pickle
import random
from abc import ABC, abstractmethod

from PIL import Image
from os.path import join

import numpy as np
from typing import Callable, Optional, List, Tuple, Any
import pytorch_lightning as pl

import torch
from torch.utils import data
from torch.utils.data import Subset, random_split

from conf.dataset import SupervisionDatasetConfig, DatasetParams
from utils.utils import display_tensor, display_mask, read_list_from_file, write_list_to_file

# Separator used to turn a domain combination into a string key, e.g. (0, 2) -> '0v2'.
sep = 'v'


class Modes:
    """
    Process-wide registry of every subset ("mode") of `nb_dom` domains.

    After ``Modes.set(nb_dom)``, each subset (including the empty one) can be looked up
    either by its sorted int tuple, e.g. ``(0, 2)`` (a list such as ``[0, 2]`` is also
    accepted), or by its string form ``'0v2'`` (ids joined with ``sep``). Every lookup
    yields the triple ``(int_tuple, str_tuple, multi_hot_tensor)``.
    """

    # Lookup table built by `set`; None until `set` has been called.
    modes = None
    # Number of domains passed to the most recent `set` call.
    nb_dom = None

    @staticmethod
    def set(nb_dom: int):
        """(Re)build the registry for `nb_dom` domains, replacing any previous one."""
        Modes.nb_dom = nb_dom

        modes = {}
        for n in range(nb_dom + 1):
            for int_tuple in itertools.combinations(range(nb_dom), n):
                str_tuple = sep.join(str(j) for j in int_tuple)
                # Multi-hot vector: 1 at the index of every domain present in the subset.
                tensor = torch.zeros(nb_dom, requires_grad=False)
                tensor[list(int_tuple)] = 1

                values = (int_tuple, str_tuple, tensor)
                # Both the tuple form and the string form resolve to the same triple.
                modes[int_tuple] = values
                modes[str_tuple] = values

        Modes.modes = modes

    @staticmethod
    def get(key):
        """
        Return ``(int_tuple, str_tuple, tensor)`` for `key`.

        `key` is an int tuple/list such as ``(0, 2)`` or a string such as ``'0v2'``.
        Raises KeyError for an unknown key.
        """
        if isinstance(key, list):
            # Lists are unhashable; the registry is keyed by tuples.
            key = tuple(key)
        return Modes.modes[key]

    @staticmethod
    def get_str(key) -> str:
        """Return the string form (e.g. ``'0v2'``) of `key`."""
        return Modes.get(key)[1]

    @staticmethod
    def get_tensor(key) -> torch.Tensor:
        """
        Return the multi-hot tensor of `key`.

        The same tensor object is returned on every call, so clone it before any
        in-place modification.
        """
        return Modes.get(key)[2]

    @staticmethod
    def get_list(key) -> List:
        """Return the domain ids of `key` as a new list of ints."""
        return list(Modes.get(key)[0])


class SupervisionDataset(data.Dataset):
    """
    Wrap a dataset and pair each kept sample with a supervision token.

    A token is a domain-combination string such as ``'0v1'`` (see ``Modes``). The wrapped
    dataset is reduced to a ``Subset`` that holds only the samples that received a token,
    optionally filtered by ``params.supervision_mode``. ``__getitem__`` returns the sample,
    plus the token's multi-hot tensor when ``params.return_supervision`` is set.
    """

    def __init__(
        self,
        dataset: data.Dataset,
        params: SupervisionDatasetConfig,
    ):
        super().__init__()
        self.dataset = dataset
        self.p = params
        # Modes is process-wide state: this (re)builds it for the configured number of domains.
        Modes.set(params.number_of_domains)

        uses_randomness = params.random_supervision or params.random_from_dataset
        if params.random_file is not None and uses_randomness:
            # An existing random file must be replayed, not regenerated.
            if os.path.exists(params.random_file):
                raise ValueError(f'{params.random_file=} is incompatible with {params.random_supervision=} or {params.random_from_dataset=}')

        print(
            f"""
            SupervisionDataset: {self.p.proportions=}
        """)

        # One token per sample that will be kept. In 'abso' mode nothing here guarantees
        # len(token_list) <= len(dataset); later steps truncate or raise accordingly.
        token_list = self.get_token_list()

        if params.random_supervision:
            random.shuffle(token_list)

        nb_tokens = len(token_list)
        indices_dataset = self._select_dataset_indices(len(dataset), nb_tokens)

        if params.random_file is not None and uses_randomness:
            if not os.path.exists(params.random_file):
                # NOTE(review): only the dataset indices are persisted. A token shuffle from
                # `random_supervision` is not recorded, so replaying the file does not reproduce it.
                print(f'[SupervisionDataset] writing random file {params.random_file}')
                write_list_to_file(file_path=params.random_file, integer_list=indices_dataset)

        # Only keep the samples whose token is part of supervision_mode.
        final_dataset_indices, final_tokens = self.select_on_supervision_mode(indices_dataset, token_list)

        # Tokens and indices are paired position by position, so they must have the same length.
        if len(final_dataset_indices) < len(final_tokens):
            print(f'[SupervisionDataset] {len(final_dataset_indices)=} < {len(final_tokens)=}')
            final_tokens = final_tokens[:len(final_dataset_indices)]
        elif len(final_dataset_indices) > len(final_tokens):
            print(f'[SupervisionDataset] {len(final_dataset_indices)=} > {len(final_tokens)=}')
            final_dataset_indices = final_dataset_indices[:len(final_tokens)]

        self.dataset = data.Subset(dataset, final_dataset_indices)
        self.tokens = final_tokens

    def _select_dataset_indices(self, dataset_length: int, nb_tokens: int) -> List[int]:
        """
        Choose which dataset index is paired with each token, in token order.

        - ``random_from_dataset``: `nb_tokens` distinct random indices. Raises ValueError
          if `nb_tokens` exceeds `dataset_length`.
        - ``random_file`` set and ``random_supervision`` off: replay the stored indices,
          truncated to `nb_tokens`.
        - Otherwise: the first `nb_tokens` indices, in order.
        """
        params = self.p
        if params.random_from_dataset:
            return random.sample(range(dataset_length), nb_tokens)

        # With random_supervision, __init__ has already rejected an existing file, so reading
        # here could only fail. Only replay the file when no randomness is requested.
        if params.random_file is not None and not params.random_supervision:
            indices = list(read_list_from_file(params.random_file))
            # The random file stores a permutation of the whole dataset; keep as many as there are tokens.
            if len(indices) > dataset_length:
                raise ValueError(
                    f'Expect the random file to hold at most {dataset_length} indices (dataset length), got {len(indices)}'
                )
            return indices[:nb_tokens]

        return list(range(nb_tokens))

    def get_token_list(self) -> List[str]:
        """
        Build the list of supervision tokens (domain-combination strings such as ``'0v1'``).

        - If neither ``proportions`` nor ``proportions_file`` is set, every sample of the
          wrapped dataset gets the token listing all domains.
        - Otherwise the proportions are a flat sequence ``[token, amount, token, amount, ...]``.
          In ``'frac'`` mode ``amount`` is a fraction of the dataset length; in ``'abso'`` mode
          it is a sample count. Tokens are emitted grouped, in the listed order.
          ``proportions_file`` holds the same sequence, comma-separated, and requires ``'abso'``.

        Raises:
            ValueError: if both proportion sources are set, if a file is used outside
                ``'abso'`` mode, or if the proportions mode is unsupported.
        """
        dataset_length = len(self.dataset)
        params = self.p

        if params.proportions is not None and params.proportions_file is not None:
            raise ValueError(f'Only one of {params.proportions=} and {params.proportions_file=} can be set')

        if params.proportions is None and params.proportions_file is None:
            all_domains = sep.join(str(i) for i in range(params.number_of_domains))
            return [all_domains] * dataset_length

        if params.proportions_file is not None:
            if params.proportions_mode != 'abso':
                raise ValueError(f'Only {params.proportions_mode=} is supported with {params.proportions_file=}')
            with open(params.proportions_file, 'r') as file:
                content = file.read()
            # Strip whitespace/newlines so tokens match Modes keys exactly ('0v1\n' -> '0v1').
            proportions = [x.strip() for x in content.split(',') if x.strip()]
            print(proportions)
        else:
            proportions = params.proportions

        tokens = []
        for i in range(len(proportions) // 2):
            domain_rpz = proportions[i * 2]

            if params.proportions_mode == 'frac':
                nb_tokens = int(dataset_length * float(proportions[i * 2 + 1]))
            elif params.proportions_mode == 'abso':
                nb_tokens = int(proportions[i * 2 + 1])
            else:
                raise ValueError(f'Unsupported {params.proportions_mode=}')

            tokens += [domain_rpz] * nb_tokens

        return tokens

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i: int):
        """Return sample `i`, plus its multi-hot supervision tensor if ``return_supervision``."""
        sample = self.dataset[i]
        tensor_token = Modes.get_tensor(self.tokens[i])

        if self.p.return_supervision:
            return sample, tensor_token
        return sample

    def select_on_supervision_mode(self, indices_dataset: List[int], token_list: List[str]):
        """
        Keep only the (index, token) pairs whose token is in ``params.supervision_mode``.

        Returns the inputs unchanged when ``supervision_mode`` is None. Otherwise returns
        two new lists, truncated to the shorter input because of ``zip``.
        """
        supervision_mode = self.p.supervision_mode
        if supervision_mode is None:
            return indices_dataset, token_list

        # NOTE(review): `in` is a substring test if supervision_mode is a plain str rather
        # than a collection of tokens. Confirm the config type.
        filtered_dataset_indices = []
        filtered_tokens = []
        for dataset_indice_i, token_i in zip(indices_dataset, token_list):
            if token_i in supervision_mode:
                filtered_dataset_indices.append(dataset_indice_i)
                filtered_tokens.append(token_i)

        return filtered_dataset_indices, filtered_tokens


class CustomDataModule(pl.LightningDataModule, ABC):
    """
    Base LightningDataModule that wraps train/valid/test splits in ``SupervisionDataset``.

    Subclasses implement ``_fetch_base_dataset``. They may use ``split_dataset`` to cut a
    single dataset into the three splits.
    """

    def __init__(self, params: DatasetParams):
        super().__init__()
        self.train_dataset = None
        self.valid_dataset = None
        self.test_dataset = None
        self.p = params
        # May be lowered by train_dataloader (see use_min_for_batch_size).
        self.batch_size = params.batch_size

    @abstractmethod
    def _fetch_base_dataset(self) -> Tuple[data.Dataset, data.Dataset, data.Dataset]:
        """
        Return train, valid and test dataset
        """
        pass

    def setup(self, stage: Optional[str] = None) -> None:
        """Fetch the base splits and wrap each in ``SupervisionDataset``. `stage` is not used."""
        base_train_dataset, base_valid_dataset, base_test_dataset = self._fetch_base_dataset()

        self.train_dataset = SupervisionDataset(base_train_dataset, self.p.supervision_params_train)
        self.valid_dataset = SupervisionDataset(base_valid_dataset, self.p.supervision_params_valid)
        self.test_dataset = SupervisionDataset(base_test_dataset, self.p.supervision_params_test)

        print(f'{len(self.train_dataset)=}')
        print(f'{len(self.valid_dataset)=}')
        print(f'{len(self.test_dataset)=}')

    def train_dataloader(self):
        """Shuffled train DataLoader; may shrink ``self.batch_size`` so drop_last keeps one batch."""
        dataset = self.train_dataset
        dataset_length = len(dataset)
        # With drop_last, a batch larger than the dataset would yield zero batches.
        # An empty dataset is left alone: batch_size=0 would make DataLoader raise.
        if self.p.use_min_for_batch_size and self.p.drop_last_train and 0 < dataset_length < self.batch_size:
            print(f'[DropLast + Train dataset size = {dataset_length} < {self.batch_size=}] => set batch size to dataset size, '
                  f'this ensures that we do not have an empty dataset with drop last = True')
            self.batch_size = dataset_length

        return data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.p.workers,
            pin_memory=self.p.pin_memory,
            drop_last=self.p.drop_last_train,
        )

    def val_dataloader(self):
        """Unshuffled validation DataLoader."""
        return data.DataLoader(
            self.valid_dataset,
            batch_size=self.p.batch_size_val,
            shuffle=False,
            num_workers=self.p.workers,
            pin_memory=self.p.pin_memory,
            drop_last=self.p.drop_last_valid,
        )

    def test_dataloader(self):
        """Unshuffled test DataLoader."""
        return data.DataLoader(
            self.test_dataset,
            batch_size=self.p.batch_size_test,
            shuffle=False,
            num_workers=self.p.workers,
            pin_memory=self.p.pin_memory,
            drop_last=self.p.drop_last_test,
        )

    @staticmethod
    def _limit_dataset(dataset: data.Dataset, limit: Optional[int]) -> data.Dataset:
        """Keep the first `limit` samples of `dataset` (clamped to its length); no-op if None."""
        if limit is None:
            return dataset
        return Subset(dataset, range(min(limit, len(dataset))))

    def split_dataset(self, dataset: data.Dataset):
        """
        Split `dataset` into train, valid and test subsets.

        Sizes come from ``proportion_mode``:
        - ``'frac'``: fractions of the length; test gets the remainder.
        - ``'perc'``: percentages of the length; test gets the remainder.
        - ``'abso'``: absolute counts. Samples not covered by the three counts are discarded.

        If ``file_path`` exists, the stored index permutation (one entry per sample) is
        replayed. Otherwise a random split is drawn and, when ``file_path`` is set, the
        full permutation is written there. ``limit_*`` then caps each split.

        Raises:
            ValueError: for an unknown mode, sizes that are negative or exceed the dataset,
                or an index file whose length differs from the dataset's.
        """
        len_d = len(dataset)
        proportion_mode = self.p.proportion_mode

        if proportion_mode == 'frac':
            train_size = int(len_d * self.p.train_prop)
            valid_size = int(len_d * self.p.valid_prop)
            test_size = len_d - train_size - valid_size
        elif proportion_mode == 'perc':
            train_size = int(len_d * self.p.train_prop / 100)
            valid_size = int(len_d * self.p.valid_prop / 100)
            test_size = len_d - train_size - valid_size
        elif proportion_mode == 'abso':
            train_size = self.p.train_prop
            valid_size = self.p.valid_prop
            test_size = self.p.test_prop
        else:
            raise ValueError(f'{self.p.proportion_mode=}')

        self.train_size = train_size
        self.valid_size = valid_size
        self.test_size = test_size
        print(
            f"""
            Splitting size: {train_size=} {valid_size=} {test_size=}
        """)

        # Samples assigned to no split. Always 0 in 'frac'/'perc' mode; may be > 0 in 'abso' mode.
        rest_size = len_d - train_size - valid_size - test_size
        if min(train_size, valid_size, test_size, rest_size) < 0:
            raise ValueError(
                f'Invalid split for a dataset of {len_d} samples: {train_size=} {valid_size=} {test_size=}'
            )

        if self.p.file_path is None or not os.path.exists(self.p.file_path):
            # The extra "rest" split makes the lengths sum to len_d, as random_split requires.
            # A single randperm of len_d is drawn either way, so frac/perc splits are unchanged.
            train_dataset, valid_dataset, test_dataset, rest_dataset = random_split(
                dataset, [train_size, valid_size, test_size, rest_size]
            )

            if self.p.file_path is not None:
                # Persist the whole permutation, including unused indices, so that the replay
                # branch's length check (one index per sample) holds.
                ids_list = [
                    *train_dataset.indices,
                    *valid_dataset.indices,
                    *test_dataset.indices,
                    *rest_dataset.indices,
                ]
                print(f'[CustomDataModule:split_dataset] Write indices to {self.p.file_path}')
                write_list_to_file(file_path=self.p.file_path, integer_list=ids_list)
        else:
            indices = read_list_from_file(self.p.file_path)
            if len(indices) != len_d:
                raise ValueError(f'{len(indices)=} != {len(dataset)=}')
            valid_end = train_size + valid_size
            test_end = valid_end + test_size

            train_dataset = Subset(dataset, indices[:train_size])
            valid_dataset = Subset(dataset, indices[train_size:valid_end])
            test_dataset = Subset(dataset, indices[valid_end:test_end])

        print("split_dataset >")
        print(f'{len(train_dataset)=}')
        print(f'{len(valid_dataset)=}')
        print(f'{len(test_dataset)=}')

        train_dataset = self._limit_dataset(train_dataset, self.p.limit_train)
        valid_dataset = self._limit_dataset(valid_dataset, self.p.limit_valid)
        test_dataset = self._limit_dataset(test_dataset, self.p.limit_test)

        print("After limit: split_dataset >")
        print(f'{len(train_dataset)=}')
        print(f'{len(valid_dataset)=}')
        print(f'{len(test_dataset)=}')

        return train_dataset, valid_dataset, test_dataset
