""" 
Remember to parameterize the file paths eventually
"""
import torch
import torch.nn
import numpy as np
from torch.utils.data import Dataset, DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
import os
try:
    from mixed_dset_sampler import MultisetSampler
    from hdf5_datasets import *
except ImportError:
    from .mixed_dset_sampler import MultisetSampler
    from .hdf5_datasets import *
import os
import glob

# Module-level list; nothing in the visible part of this file appends to or
# reads from it.
broken_paths = []

# Registry mapping the short dataset-type string (the second element of each
# entry handed to MixedDataset) to the dataset class that loads it.
# MixedDataset looks classes up here by name, and get_state_names() iterates
# this mapping in insertion order, so the order of entries matters.
# IF YOU ADD A NEW DSET MAKE SURE TO UPDATE THIS MAPPING SO MIXED DSET KNOWS
# HOW TO USE IT.
DSET_NAME_TO_OBJECT = {
    'burgers': Burgers1DDataset,
    'diffsorb': DiffSorb1DDataset,
    'readiff': ReaDiff1DDataset,
    'compNS1d': CompNS1dDataset,
    'adv': Adv1DDataset,
    'swe': SWEDataset,
    'incompNS': IncompNSDataset,
    'diffre2d': DiffRe2DDataset,
    'compNS': CompNSDataset,
    'compNS3d': CompNS3dDataset,
}

def get_data_loader(params, paths, distributed, split='train', rank=0, train_offset=0):
    """Build the mixed dataset, its multiset sampler and a DataLoader over them.

    Args:
        params: Configuration object. The attributes read here are
            ``n_steps``, ``train_val_test``, ``tie_fields``,
            ``use_all_fields``, ``enforce_max_steps``, ``batch_size``,
            ``epoch_size`` and ``num_data_workers``.
        paths: Iterable of ``(path, dataset_type, include_string)`` entries,
            forwarded unchanged to ``MixedDataset``.
        distributed: When truthy, ``DistributedSampler`` is used as the base
            sampler class; otherwise ``RandomSampler``.
        split: Name of the split to load; forwarded to ``MixedDataset``.
            Worker persistence is enabled only for ``'train'``.
        rank: Process rank forwarded to ``MultisetSampler``.
        train_offset: Forwarded to ``MixedDataset``.

    Returns:
        Tuple ``(dataloader, dataset, sampler)``.
    """
    dataset = MixedDataset(
        paths,
        n_steps=params.n_steps,
        train_val_test=params.train_val_test,
        split=split,
        tie_fields=params.tie_fields,
        use_all_fields=params.use_all_fields,
        enforce_max_steps=params.enforce_max_steps,
        train_offset=train_offset,
    )

    # NOTE(review): the returned seed value was never used by the original
    # code, but torch.random.seed() also re-seeds torch's global RNG
    # non-deterministically. The call is kept so training runs behave as
    # before; confirm whether that side effect is actually intended.
    if split == 'train':
        torch.random.seed()

    base_sampler = DistributedSampler if distributed else RandomSampler
    sampler = MultisetSampler(
        dataset,
        base_sampler,
        params.batch_size,
        distributed=distributed,
        max_samples=params.epoch_size,
        rank=rank,
    )

    # prefetch_factor and persistent_workers are only legal when worker
    # processes are used; DataLoader raises ValueError if they are supplied
    # together with num_workers == 0 (single-process / debugging mode).
    num_workers = params.num_data_workers
    worker_options = {}
    if num_workers > 0:
        worker_options['prefetch_factor'] = 32
        worker_options['persistent_workers'] = split == 'train'

    dataloader = DataLoader(
        dataset,
        batch_size=int(params.batch_size),
        num_workers=num_workers,
        shuffle=False,  # ordering is decided entirely by `sampler`
        sampler=sampler,
        drop_last=True,
        pin_memory=torch.cuda.is_available(),
        **worker_options,
    )
    return dataloader, dataset, sampler
    

class MixedDataset(Dataset):
    """Several sub-datasets exposed through a single flat index space.

    One sub-dataset is built per ``(path, dataset_type, include_string)``
    entry, using the class registered for ``dataset_type`` in
    ``DSET_NAME_TO_OBJECT``. Flat indices are mapped to a sub-dataset by
    cumulative length (see ``offsets``).

    Attributes set by ``__init__``:
        path_list, type_list, include_string: Tuples obtained by unzipping
            the input entries.
        sub_dsets: The constructed sub-datasets, in input order.
        offsets: Cumulative sub-dataset lengths with a leading sentinel.
            The first element is deliberately ``-1`` rather than ``0``
            (``__getitem__`` clamps it back to 0); it is left that way
            because code outside this class may read it.
        subset_dict: Mapping from dataset name to channel indices, as built
            by ``_build_subset_dict``.
    """

    def __init__(self, path_list=None, n_steps=1, dt=1, train_val_test=(.8, .1, .1),
                 split='train', tie_fields=True, use_all_fields=True, extended_names=False,
                 enforce_max_steps=False, train_offset=0):
        """Construct every sub-dataset and the index offsets between them.

        Args:
            path_list: Iterable of ``(path, dataset_type, include_string)``
                triples. Must contain at least one entry.
            n_steps, dt, train_val_test, split: Forwarded to every
                sub-dataset constructor.
            tie_fields, use_all_fields, extended_names, train_offset: Stored
                on the instance. Only ``use_all_fields`` is consulted by the
                code currently active in this class.
            enforce_max_steps: Accepted for interface compatibility; not
                used by this class.

        Raises:
            ValueError: If ``path_list`` is empty/None, if an entry is not a
                triple, or if taking ``len()`` of a sub-dataset raises
                ``ValueError`` (reported as the dataset being empty).
            KeyError: If a dataset type is not in ``DSET_NAME_TO_OBJECT``.
        """
        super().__init__()
        if not path_list:
            # zip(*[]) would otherwise fail with an opaque unpacking error.
            raise ValueError(
                'MixedDataset needs at least one (path, dataset_type, include_string) entry.')
        # Global dicts used by Mixed DSET.
        self.train_offset = train_offset
        self.path_list, self.type_list, self.include_string = zip(*path_list)
        self.tie_fields = tie_fields
        self.extended_names = extended_names
        self.split = split
        self.sub_dsets = []
        self.offsets = [0]
        self.train_val_test = train_val_test
        self.use_all_fields = use_all_fields

        for dset_type, path, include in zip(self.type_list, self.path_list, self.include_string):
            subdset = DSET_NAME_TO_OBJECT[dset_type](
                path, include, n_steps=n_steps, dt=dt,
                train_val_test=train_val_test, split=split)
            # Check to make sure our dataset actually exists with these settings.
            try:
                n_samples = len(subdset)
            except ValueError as err:
                raise ValueError(
                    f'Dataset {path} is empty. Check that n_steps < trajectory_length in file.'
                ) from err
            self.sub_dsets.append(subdset)
            self.offsets.append(self.offsets[-1] + n_samples)
        self.offsets[0] = -1

        self.subset_dict = self._build_subset_dict()

    def get_state_names(self):
        """Return the list of state/field names for this mixed dataset.

        With ``use_all_fields`` set, the names are the concatenation of
        ``_specifics()[2]`` over every class in ``DSET_NAME_TO_OBJECT``, in
        registry order. Otherwise a fixed list of five placeholder names is
        returned.
        """
        if self.use_all_fields:
            name_list = []
            for dset_cls in DSET_NAME_TO_OBJECT.values():
                name_list += dset_cls._specifics()[2]
            return name_list

        # NOTE(review): an earlier version collected `field_names` from each
        # distinct sub-dataset and returned them flattened; that result was
        # no longer returned, so the dead collection loop has been dropped.
        # The placeholder strings are runtime values and are kept verbatim.
        return ["first", "second", "third", "forth", "fifth"]

    def _build_subset_dict(self):
        """Return the mapping from dataset name to channel indices.

        Every known dataset name currently maps to its own copy of
        ``[0, 1, 2, 3, 4]``. Earlier per-dataset schemes are preserved in
        the comment below for reference.
        """
        # Maps fields to subsets of variables
        # if self.tie_fields: # Hardcoded, but seems less effective anyway
        #     subset_dict = {
        #                 'burgers': [0],
        #                 'diffsorb': [9, 10],
        #                 'readiff': [7],
        #                 'compNS1d': [0, 3, 4],
        #                 'adv': [8],
        #                 'swe': [4],
        #                 'incompNS': [0, 1, 3],
        #                 'compNS': [0, 1, 3, 4],
        #                 'diffre2d': [5, 6],
        #                 'compNS3d': [0, 1, 2, 3, 4],
        #                 }
        # elif self.use_all_fields:
        #     cur_max = 0
        #     subset_dict = {}
        #     for name, dset in DSET_NAME_TO_OBJECT.items():
        #         field_names = dset._specifics()[2]
        #         subset_dict[name] = list(range(cur_max, cur_max + len(field_names)))
        #         cur_max += len(field_names)
        # else:
        #     subset_dict = {}
        #     cur_max = self.train_offset
        #     for dset in self.sub_dsets:
        #         name = dset.get_name(self.extended_names)
        #         if not name in subset_dict:
        #             subset_dict[name] = list(range(cur_max, cur_max + len(dset.field_names)))
        #             cur_max += len(dset.field_names)
        dataset_names = (
            "burgers", 'diffsorb', 'readiff', 'compNS1d', 'adv',
            'swe', 'incompNS', 'compNS', 'diffre2d', 'compNS3d',
        )
        # A fresh list per key so callers mutating one entry cannot affect
        # another.
        return {name: [0, 1, 2, 3, 4] for name in dataset_names}

    def __getitem__(self, index):
        """Fetch the sample at flat ``index``.

        Returns:
            Tuple ``(x, file_idx, subset, bcs, y)`` where ``x, bcs, y`` come
            from the owning sub-dataset, ``file_idx`` is that sub-dataset's
            position in ``sub_dsets`` and ``subset`` is a tensor of the
            channel indices registered for its name in ``subset_dict``.

        Raises:
            Whatever the sub-dataset lookup raises (for example
            ``IndexError`` for an out-of-range index). A diagnostic line is
            printed first and the original exception is re-raised unchanged.
        """
        # Which dataset are we on: rightmost offset that is <= index.
        file_idx = np.searchsorted(self.offsets, index, side='right') - 1
        # offsets[0] is the -1 sentinel, so clamp to 0 for the first dataset.
        local_idx = index - max(self.offsets[file_idx], 0)
        try:
            x, bcs, y = self.sub_dsets[file_idx][local_idx]
        except Exception:
            print('FAILED AT ', file_idx, local_idx, index, int(os.environ.get("RANK", 0)))
            # Re-raise the real error instead of masking it with an
            # unrelated NameError.
            raise
        subset = torch.tensor(self.subset_dict[self.sub_dsets[file_idx].get_name()])
        return x, file_idx, subset, bcs, y

    def __len__(self):
        """Return the total number of samples across all sub-datasets."""
        return sum(len(dset) for dset in self.sub_dsets)

        # field_names = ["0", 'h', "0", "0", "0"]
        # field_names = ["0", "0", "0", 'activator', 'inhibitor']
        # # field_names = ["0", 'Vx', 'Vy', 'particles']
        # field_names = ["0", 'Vx', 'Vy', 'density', 'pressure']
        # field_names = ['Vx', "0", "0", "0", "0"]
        # field_names = ["0", "0", 'u', "0", "0"]
        # field_names = ["0", "0", "0", 'u', "0"]
        # field_names = ["0", "0", "0", "0", 'u']
        # field_names = ['Vx', "0", "0", 'density', 'pressure']
        # field_names = ['Vx', 'Vy', 'Vz', 'density', 'pressure']
        # # field_names = ['Vx', 'Vy', 'density', 'pressure']