""" Mostly inspired by the repository https://github.com/PolymathicAI/multiple_physics_pretraining """
from typing import Tuple, List, TypeVar, Iterator, Dict
import os
from itertools import product
from pathlib import Path
import numpy as np
import torch

from src.sf_euler_data.datamodule import SfEulerDataModule
from torch.utils.data import (
    Sampler, Dataset, DataLoader, RandomSampler, DistributedSampler
)
from src.global_constants import *
from src.data.hdf5_datasets import (
    SWEDataset, IncompNSDataset, DiffRe2DDataset, CompNSDataset, BurgersDataset,
)


# Registry mapping the dataset-type string used in ``path_list`` entries (the second element
# of each tuple given to MixedDataset) to the HDF5 dataset class that loads it.
# Iteration order matters: with ``use_all_fields`` MixedDataset lays out the global field
# indices by walking this dict in order.
DSET_NAME_TO_OBJECT = {
    'swe': SWEDataset,
    'incompNS': IncompNSDataset,
    'diffre2d': DiffRe2DDataset,
    'compNS': CompNSDataset,
    'burgers': BurgersDataset,
}


class MixedDataset(Dataset):
    """Concatenate several HDF5-backed PDE datasets into one indexable dataset.

    Each entry of ``path_list`` is a 4-tuple ``(path, dataset_type, include_string, resolution)``
    where ``dataset_type`` is a key of ``DSET_NAME_TO_OBJECT`` and ``resolution`` is anything
    whose ``str()`` is an ``x``-separated list of ints (e.g. ``"128x128"``).

    A global index is mapped to ``(sub-dataset, local index)`` through the cumulative
    ``offsets`` list. Every returned sample is the sub-dataset's dict augmented with
    ``'field_labels'`` (a tensor of global field indices taken from ``subset_dict``) and
    ``'file_index'`` (position of the sub-dataset in ``path_list``).
    """

    def __init__(self, path_list=(), n_steps=1, dt=1, train_val_test=(.8, .1, .1),
                 split='train', tie_fields=True, use_all_fields=True, extended_names=False,
                 enforce_max_steps=False, train_offset=0):
        """
        Args:
            path_list: Non-empty sequence of ``(path, dataset_type, include_string, resolution)``.
            n_steps, dt, train_val_test, split: Forwarded to every sub-dataset constructor.
            tie_fields: Use the hard-coded shared field layout (see ``_build_subset_dict``).
            use_all_fields: Reserve field indices for every registered dataset type.
            extended_names: Key sub-datasets by their extended name (only honoured when
                ``tie_fields`` and ``use_all_fields`` are both False).
            enforce_max_steps: Accepted for interface compatibility; not used here.
            train_offset: First field index in the per-dataset layout.

        Raises:
            ValueError: If ``path_list`` is empty or a sub-dataset is empty.
        """
        super().__init__()
        if not path_list:
            # zip(*path_list) below would otherwise fail with an opaque unpacking error.
            raise ValueError('MixedDataset requires a non-empty path_list.')
        self.train_offset = train_offset
        self.path_list, self.type_list, self.include_string, self.reso = zip(*path_list)
        # "128x128" -> [128, 128]
        self.reso = [[int(d) for d in str(res).split('x')] for res in self.reso]
        self.tie_fields = tie_fields
        self.extended_names = extended_names
        self.split = split
        self.sub_dsets = []
        # offsets[i] is the global index of the first sample of sub-dataset i.
        self.offsets = [0]
        self.train_val_test = train_val_test
        self.use_all_fields = use_all_fields

        for dset, path, include_string, reso in zip(self.type_list, self.path_list,
                                                    self.include_string, self.reso):
            subdset = DSET_NAME_TO_OBJECT[dset](
                path, reso, include_string, n_steps=n_steps, dt=dt,
                train_val_test=train_val_test, split=split,
            )
            # Make sure the dataset actually contains samples with these settings.
            try:
                len(subdset)
            except ValueError as err:
                raise ValueError(
                    f'Dataset {path} is empty. Check that n_steps < trajectory_length in file.'
                ) from err
            self.sub_dsets.append(subdset)
            self.offsets.append(self.offsets[-1] + len(subdset))
        # Kept from the reference implementation; __getitem__ and MultisetSampler clamp
        # this value back to 0 with max(..., 0).
        self.offsets[0] = -1

        self.subset_dict = self._build_subset_dict()

    def _subset_key(self, dset):
        """Return the ``subset_dict`` key under which ``dset``'s field indices are stored.

        Only the per-dataset branch of ``_build_subset_dict`` (``tie_fields`` and
        ``use_all_fields`` both False) keys by ``get_name(extended_names)``; the other two
        branches key by the plain dataset name.
        """
        if self.extended_names and not (self.tie_fields or self.use_all_fields):
            return dset.get_name(self.extended_names)
        return dset.get_name()

    def get_state_names(self):
        """Return the flat list of field names covered by this dataset.

        With ``use_all_fields`` the field names of every registered dataset type are listed
        in ``DSET_NAME_TO_OBJECT`` order. Otherwise each distinct sub-dataset (first
        occurrence wins, keyed like ``subset_dict``) contributes its ``field_names``.
        """
        if self.use_all_fields:
            name_list = []
            for dset in DSET_NAME_TO_OBJECT.values():
                name_list += dset._specifics()[2]
            return name_list

        visited = set()
        name_list = []
        for dset in self.sub_dsets:
            name = self._subset_key(dset)
            if name not in visited:
                visited.add(name)
                name_list.extend(dset.field_names)
        return name_list

    def _build_subset_dict(self):
        """Map each dataset name to the list of global field indices it occupies.

        * ``tie_fields``: a hard-coded table in which some datasets share indices
          (e.g. ``compNS`` reuses the ``incompNS`` slots 0-2).
          NOTE(review): ``burgers`` has no entry here, so its samples would fail the
          lookup in ``__getitem__`` with ``tie_fields=True``.
        * ``use_all_fields``: every registered type gets its own contiguous range, laid out
          in ``DSET_NAME_TO_OBJECT`` order, regardless of which sub-datasets are loaded.
        * otherwise: only loaded sub-datasets get ranges, starting at ``train_offset`` and
          keyed by ``get_name(extended_names)``.
        """
        if self.tie_fields:  # Hardcoded, but seems less effective anyway
            return {
                'swe': [3],
                'incompNS': [0, 1, 2],
                'compNS': [0, 1, 2, 3],
                'diffre2d': [4, 5],
            }

        subset_dict = {}
        if self.use_all_fields:
            cur_max = 0
            for name, dset in DSET_NAME_TO_OBJECT.items():
                n_fields = len(dset._specifics()[2])
                subset_dict[name] = list(range(cur_max, cur_max + n_fields))
                cur_max += n_fields
        else:
            cur_max = self.train_offset
            for dset in self.sub_dsets:
                name = dset.get_name(self.extended_names)
                if name not in subset_dict:
                    subset_dict[name] = list(range(cur_max, cur_max + len(dset.field_names)))
                    cur_max += len(dset.field_names)
        return subset_dict

    def __getitem__(self, index):
        """Return sample ``index`` with ``'field_labels'`` and ``'file_index'`` added.

        Raises:
            Whatever the underlying sub-dataset lookup raises (e.g. IndexError), after
            printing the failing indices and the process RANK.
        """
        # Sub-dataset owning ``index``: the last offset that is <= index.
        file_idx = np.searchsorted(self.offsets, index, side='right') - 1
        local_idx = index - max(self.offsets[file_idx], 0)
        try:
            dataset = self.sub_dsets[file_idx]
            sample = dataset[local_idx]
        except Exception:
            # Report where the lookup failed (useful in multi-rank runs), then propagate
            # the original error instead of masking it.
            print('FAILED AT ', file_idx, local_idx, index, int(os.environ.get("RANK", 0)))
            raise
        sample['field_labels'] = torch.tensor(self.subset_dict[self._subset_key(dataset)])
        sample['file_index'] = file_idx
        return sample

    def __len__(self):
        return sum(len(dset) for dset in self.sub_dsets)


# Covariant element-type parameter for the generic ``Sampler`` base class used below.
T_co = TypeVar('T_co', covariant=True)


class MultisetSampler(Sampler[T_co]):
    """Sampler yielding whole batches of global indices drawn from one sub-dataset at a time.

    At each step one not-yet-exhausted sub-dataset of a :class:`MixedDataset` is chosen
    uniformly at random and ``batch_size`` of its indices (shifted into the global index
    space) are yielded back to back. When consumed by a ``DataLoader`` with the same
    ``batch_size``, every batch therefore comes from a single sub-dataset. A sub-dataset
    that cannot fill a whole batch is dropped (drop-last behaviour). Iteration ends when all
    sub-datasets are exhausted or after at most ``max_samples`` batches.
    """

    def __init__(self,
                 dataset: MixedDataset, batch_size: int,
                 shuffle: bool = True, seed: int = 0,
                 max_samples: int = 10, rank: int = 0,
                 distributed: bool = True,
                 ):
        """
        Args:
            dataset: Mixed dataset whose ``sub_dsets`` and ``offsets`` are used.
            batch_size: Number of consecutive indices yielded per draw.
            shuffle: Accepted for interface compatibility; not used here.
            seed: Mixed into the per-epoch seed of the sub-dataset chooser.
            max_samples: Upper bound on the number of batches per epoch.
            rank: Mixed into the per-epoch seed so ranks choose differently.
            distributed: Use ``DistributedSampler`` per sub-dataset, else ``RandomSampler``.
        """
        self.batch_size = batch_size
        self.sub_dsets = dataset.sub_dsets
        sampler_cls = DistributedSampler if distributed else RandomSampler
        self.sub_samplers = [sampler_cls(sub_dset) for sub_dset in self.sub_dsets]
        self.dataset = dataset
        self.epoch = 0
        self.shuffle = shuffle
        self.seed = seed
        self.max_samples = max_samples
        self.rank = rank

    def __iter__(self) -> Iterator[T_co]:
        samplers = [iter(sampler) for sampler in self.sub_samplers]
        sampler_choices = list(range(len(samplers)))
        # Deterministic chooser per (epoch, seed, rank).
        generator = torch.Generator()
        generator.manual_seed(100 * self.epoch + 10 * self.seed + self.rank)
        # Counts draws, including draws that hit an exhausted sub-dataset.
        count = 0
        while sampler_choices:
            count += 1
            index_sampled = torch.randint(
                0, len(sampler_choices), size=(1,), generator=generator
            ).item()
            dset_sampled = sampler_choices[index_sampled]
            offset = max(0, self.dataset.offsets[dset_sampled])
            # Drop-last logic: only yield a full batch; otherwise retire this sub-dataset.
            try:
                queue = [next(samplers[dset_sampled]) + offset for _ in range(self.batch_size)]
            except StopIteration:
                sampler_choices.pop(index_sampled)
                print(f'Note: dset {dset_sampled} fully used. Dsets remaining: {len(sampler_choices)}')
                continue
            yield from queue
            if count >= self.max_samples:
                break

    def __len__(self) -> int:
        # NOTE(review): reports the full dataset length, not the number of indices actually
        # yielded per epoch (at most max_samples * batch_size).
        return len(self.dataset)

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch used to reseed the sub-dataset chooser and the sub-samplers.

        Args:
            epoch (int): Epoch number.
        """
        for sampler in self.sub_samplers:
            # RandomSampler (used when distributed=False) has no set_epoch method.
            if hasattr(sampler, 'set_epoch'):
                sampler.set_epoch(epoch)
        self.epoch = epoch
        self.epoch = epoch


def get_data_objects(
    paths: List[Tuple[Path|str, str]] | None,
    batch_size: int,
    epoch_size: int,
    train_val_test: Tuple[float],
    n_past: int, 
    n_future: int,
    distributed: bool, 
    num_data_workers: int,
    rank: int, 
    world_size: int,
    split: str,
    data_params: Tuple | None,
    template_name: str | None,
    mode: str | None,
) -> Tuple:
    if mode == "PDEBench":
        dataset = MixedDataset(
            paths, n_steps=n_past, train_val_test=train_val_test, split=split,
            tie_fields=False, use_all_fields=True, enforce_max_steps=False, 
            train_offset=0
        )
    elif mode == "sf_euler":
        # H = W = 512 // ss
        data_module = SfEulerDataModule(
            base_path=paths[0][0],
            dataset_name=paths[0][1],
            resolution=paths[0][2],
            batch_size=batch_size,
            n_steps_input=n_past, #n_timesteps,
            n_steps_output=n_future, #n_timesteps output
            data_workers=0,
            world_size=world_size,
            rank=rank,
            include_filters=paths[0][3],
        )
        dataloader = {
            'train': data_module.train_dataloader(),
            'val': data_module.val_dataloader(),
            'test': data_module.test_dataloader(),
        }[split]
        dataset = {
            'train': data_module.train_dataset,
            'val': data_module.val_dataset,
            'test': data_module.test_dataset,
        }[split]
        return dataset, dataloader.sampler, dataloader
    else:
        raise ValueError(f"Mode {mode} not recognized.")

    sampler = MultisetSampler(
        dataset, batch_size, distributed=distributed, 
        max_samples=epoch_size, rank=rank
    )

    dataloader = DataLoader(
        dataset, batch_size=int(batch_size), 
        num_workers=num_data_workers,
        shuffle=False, #(sampler is None),
        sampler=sampler, # Since validation is on a subset, use a fixed random subset,
        drop_last=True,
        pin_memory=torch.cuda.is_available()
    )

    return dataset, sampler, dataloader
