import copy

import numpy as np
import robomimic.utils.file_utils as FileUtils
import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.tensor_utils as TensorUtils
from PIL import Image
from robomimic.utils.dataset import SequenceDataset
from torch.utils.data import Dataset
import torch
import random
import h5py

"""
    Helper function from Robomimic to read hdf5 demonstrations into sequence dataset

    ISSUE: robomimic's SequenceDataset has two properties: seq_len and frame_stack,
    we should in principle use seq_len, but the paddings of the two are different.
    So that's why we currently use frame_stack instead of seq_len.
"""


def get_dataset(
    dataset_path,
    obs_modality,
    initialize_obs_utils=True,
    seq_len=1,
    frame_stack=1,
    filter_key=None,
    hdf5_cache_mode="low_dim",
    *args,
    **kwargs
):
    """
    Build a robomimic ``SequenceDataset`` from an hdf5 demonstration file.

    Args:
        dataset_path: path of the hdf5 file; passed to robomimic unchanged.
        obs_modality: dict mapping a modality name to the list of observation
            keys that belong to it.
        initialize_obs_utils: when True, ``obs_modality`` is registered with
            ``ObsUtils`` before the dataset is created.
        seq_len: forwarded to ``SequenceDataset`` as ``seq_length``.
        frame_stack: forwarded to ``SequenceDataset`` as ``frame_stack``.
        filter_key: forwarded to ``SequenceDataset`` as ``filter_by_attribute``.
        hdf5_cache_mode: forwarded to ``SequenceDataset`` as ``hdf5_cache_mode``.
        *args, **kwargs: accepted but not used by this function.

    Returns:
        Tuple ``(dataset, shape_meta)``: the constructed dataset and the shape
        metadata returned by ``FileUtils.get_shape_metadata_from_dataset``.
    """
    if initialize_obs_utils:
        ObsUtils.initialize_obs_utils_with_obs_specs({"obs": obs_modality})

    # Flatten the per-modality key lists into one list of observation keys.
    flat_obs_keys = []
    for modality_keys in obs_modality.values():
        flat_obs_keys.extend(modality_keys)

    shape_meta = FileUtils.get_shape_metadata_from_dataset(
        dataset_path=dataset_path, all_obs_keys=flat_obs_keys, verbose=False
    )

    dataset_options = dict(
        hdf5_path=dataset_path,
        obs_keys=shape_meta["all_obs_keys"],
        dataset_keys=["actions"],
        load_next_obs=False,
        frame_stack=frame_stack,
        seq_length=seq_len,
        # both kinds of padding are always enabled here
        pad_frame_stack=True,
        pad_seq_length=True,
        get_pad_mask=False,
        goal_mode=None,
        hdf5_cache_mode=hdf5_cache_mode,
        hdf5_use_swmr=False,
        hdf5_normalize_obs=None,
        filter_by_attribute=filter_key,
    )
    dataset = SequenceDataset(**dataset_options)
    return dataset, shape_meta


class SequenceVLDataset(Dataset):
    """
    Wrap a sequence dataset so that every sample also carries one fixed task
    embedding under the ``"task_emb"`` key.
    """

    def __init__(self, sequence_dataset, task_emb):
        """
        Args:
            sequence_dataset: wrapped dataset; must expose ``n_demos``,
                ``total_num_sequences``, ``__len__`` and ``__getitem__``.
            task_emb: value attached to every returned sample.
        """
        self.sequence_dataset = sequence_dataset
        self.task_emb = task_emb
        # Mirror the counters of the wrapped dataset.
        self.n_demos = sequence_dataset.n_demos
        self.total_num_sequences = sequence_dataset.total_num_sequences

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.sequence_dataset)

    def __getitem__(self, idx):
        """Return the wrapped dataset's sample at ``idx`` with ``"task_emb"`` added."""
        sample = self.sequence_dataset[idx]
        sample["task_emb"] = self.task_emb
        return sample


class GroupedTaskDataset(Dataset):
    """
    Interleave several per-task datasets into one dataset, visiting the tasks
    in round-robin order and attaching each task's embedding to its samples.
    """

    def __init__(self, sequence_datasets, task_embs):
        """
        Args:
            sequence_datasets: list of per-task datasets; each must expose
                ``n_demos``, ``total_num_sequences``, ``__len__`` and
                ``__getitem__``.
            task_embs: per-task embeddings, indexed like ``sequence_datasets``.
        """
        self.sequence_datasets = sequence_datasets
        self.task_embs = task_embs
        self.group_size = len(sequence_datasets)
        self.task_group_size = len(sequence_datasets)
        self.n_demos = sum(ds.n_demos for ds in sequence_datasets)
        self.total_num_sequences = sum(
            ds.total_num_sequences for ds in sequence_datasets
        )
        self.lengths = [len(ds) for ds in sequence_datasets]

        # Map each global index to (index inside the task, task index).
        # With three tasks of sizes 3, 5 and 4 the global indices are laid out as
        #   task-1  task-2  task-3
        #     0       1       2
        #     3       4       5
        #     6       7       8
        #             9       10
        #             11
        # i.e. a row-major walk over the grid that skips tasks which have
        # already run out of samples: cell (row, col) exists iff row < size[col].
        cells = [
            (row, col)
            for row in range(max(self.lengths, default=0))
            for col, size in enumerate(self.lengths)
            if row < size
        ]
        self.map_dict = dict(enumerate(cells))
        self.n_total = sum(self.lengths)

    def __len__(self):
        """Return the total number of samples over all tasks."""
        return self.n_total

    def __get_original_task_idx(self, idx):
        """Return ``(index inside the task, task index)`` for global index ``idx``."""
        return self.map_dict[idx]

    def __getitem__(self, idx):
        """Return the mapped task sample with that task's ``"task_emb"`` added."""
        within_task_idx, task_idx = self.__get_original_task_idx(idx)
        sample = self.sequence_datasets[task_idx][within_task_idx]
        sample["task_emb"] = self.task_embs[task_idx]
        return sample


class TruncatedSequenceDataset(Dataset):
    """
    View of a task dataset restricted to its first ``n_demos`` demonstrations,
    with a per-sample weight vector attached to every returned item.
    """

    def __init__(self, sequence_dataset, n_demos=8):
        """
        Args:
            sequence_dataset: wrapper that exposes ``sequence_dataset`` (the
                underlying dataset with ``demos``, ``_demo_id_to_start_indices``
                and ``_demo_id_to_demo_length``), ``task_emb`` and
                ``__getitem__``.
            n_demos: number of leading demos to keep.
        """
        robomimic_dataset = sequence_dataset.sequence_dataset
        kept_demos = robomimic_dataset.demos[:n_demos]
        start_lookup = robomimic_dataset._demo_id_to_start_indices
        length_lookup = robomimic_dataset._demo_id_to_demo_length

        self.sequence_dataset = sequence_dataset
        self.sequence_dataset_robomimic = robomimic_dataset
        self.n_demos = n_demos
        self.demo_ids = kept_demos
        self.demo_start_indices = {demo: start_lookup[demo] for demo in kept_demos}
        self.demo_lengths = {demo: length_lookup[demo] for demo in kept_demos}
        self.total_num_sequences = sum(self.demo_lengths[demo] for demo in kept_demos)
        self.task_emb = sequence_dataset.task_emb

        # One row of ones per sample; the width of 10 is hard-coded here.
        self.weights = np.ones((self.total_num_sequences, 10))

    def __len__(self):
        """Return the number of sequences contained in the kept demos."""
        return self.total_num_sequences

    def __getitem__(self, idx):
        """
        Return the wrapped dataset's sample at ``idx`` with ``'weights'`` added.

        ``idx`` is passed to the wrapped dataset unchanged.
        NOTE(review): this presumably relies on the kept demos occupying the
        first indices of the wrapped dataset - confirm.
        """
        sample = self.sequence_dataset[idx]
        sample['weights'] = self.weights[idx]
        return sample


class TruncatedSequenceDatasetSample(Dataset):
    """
    Dataset view that reports a fixed length (``buffer_size``) while reading
    items straight from the wrapped dataset.
    """

    def __init__(self, sequence_dataset, buffer_size):
        """
        Args:
            sequence_dataset: wrapped dataset providing ``__getitem__``.
            buffer_size: value reported by ``len()``.
        """
        self.buffer_size = buffer_size
        self.sequence_dataset = sequence_dataset

    def __len__(self):
        """Return ``buffer_size``, regardless of the wrapped dataset's length."""
        return self.buffer_size

    def __getitem__(self, idx):
        """Return the wrapped dataset's item at ``idx`` unchanged."""
        item = self.sequence_dataset[idx]
        return item


class MySequenceDataset(Dataset):
    """
    Sequence dataset built from a list of per-demo dictionaries.

    Every demo dictionary is read with the following keys:
      * ``'length'``: number of samples contributed by the demo,
      * each key of ``all_dataset_keys``: a sliceable trajectory,
      * ``'obs'``: dict of sliceable trajectories for keys in ``obs_keys_in_memory``,
      * ``'hdf5_file_path'`` / ``'demo_name'``: where to read observation keys
        that are not held in memory,
      * ``'demo_emb_list'``: list whose last element is used as ``'task_emb'``.

    An integer index addresses one sample; a slice (and ``slice_by_list``)
    selects *demos* and returns a new ``MySequenceDataset``.
    """

    def __init__(self, all_obs_keys=None, obs_keys_in_memory=None, all_dataset_keys=None, data=None, seq_len=None, pad_frame_stack=True, pad_seq_length=True, frame_stack=1):
        """
        Args:
            all_obs_keys: observation keys returned for every sample.
            obs_keys_in_memory: subset of ``all_obs_keys`` read from
                ``demo['obs']``; the others are read from the demo's hdf5 file.
            all_dataset_keys: non-observation keys returned for every sample.
            data: list of demo dicts, a single demo dict, or None (empty dataset).
            seq_len: number of steps taken forward from the sample index.
            pad_frame_stack: when False, asserts that no padding is needed at
                the start of a sequence.
            pad_seq_length: when False, asserts that no padding is needed at
                the end of a sequence.
            frame_stack: number of stacked frames (``frame_stack - 1`` past
                steps are prepended to every sequence).

        Raises:
            ValueError: if ``data`` is not a list, a dict or None.
        """
        self.all_obs_keys = all_obs_keys
        self.all_dataset_keys = all_dataset_keys
        self.obs_keys_in_memory = obs_keys_in_memory

        if isinstance(data, list):
            self.data = data
        elif isinstance(data, dict):
            self.data = [data]
        elif data is None:
            self.data = None
        else:
            raise ValueError("Data type not supported")

        if self.data is not None:
            self.n_demos = len(self.data)
            self.total_num_sequences = sum(self._to_int(demo['length']) for demo in self.data)
        else:
            self.n_demos = 0
            self.total_num_sequences = 0

        self.seq_len = seq_len

        self.pad_frame_stack = pad_frame_stack
        self.pad_seq_length = pad_seq_length
        self.n_frame_stack = frame_stack

        self.__generate_idx_to_demo_list()
        self.__generate_demo_start_idx_list()
        self.__generate_demo_ids()
        self.__generate_demo_start_indices()
        self.__generate_demo_lengths()

        # Lazily opened hdf5 handle; at most one file is kept open at a time.
        self.__hdf5_path = None
        self.__hdf5_file = None

        # Per-sample weights always start as ones (width 10 is hard-coded);
        # datasets derived through slicing or ``+`` get fresh weights as well.
        self.weights = np.ones((self.total_num_sequences, 10))

    @staticmethod
    def _to_int(value):
        """Convert a length stored as a Python int or a numpy/torch scalar to ``int``."""
        return int(value.item()) if hasattr(value, "item") else int(value)

    def _with_data(self, data):
        """Return a new dataset over ``data`` that keeps this dataset's configuration."""
        return MySequenceDataset(
            all_obs_keys=self.all_obs_keys,
            obs_keys_in_memory=self.obs_keys_in_memory,
            all_dataset_keys=self.all_dataset_keys,
            data=data,
            seq_len=self.seq_len,
            pad_frame_stack=self.pad_frame_stack,
            pad_seq_length=self.pad_seq_length,
            frame_stack=self.n_frame_stack
        )

    def __len__(self):
        """Return the number of samples (0 when ``all_dataset_keys`` is None)."""
        if self.all_dataset_keys is None:
            return 0

        return self.total_num_sequences

    def __getitem__(self, index):
        """
        Return one sample for an integer index, or a demo-level sub-dataset
        for a slice.

        Args:
            index: Python/numpy integer addressing a sample (negative values
                count from the end), or a slice applied to the list of demos.

        Raises:
            IndexError: if an integer index is out of range.
            TypeError: if ``index`` is neither an integer nor a slice.
        """
        assert self.all_dataset_keys is not None, "Dataset is empty, please generate dataset from DataModule"
        if isinstance(index, slice):
            # The slice selects demos, not individual samples.
            return self._with_data(self.data[index])
        if isinstance(index, (int, np.integer)):
            index = int(index)
            num_samples = len(self)
            if index < 0:
                # Without this, a negative index would pick the last demo but
                # yield a negative in-demo offset and thus the wrong window.
                index += num_samples
            if not 0 <= index < num_samples:
                raise IndexError("MySequenceDataset index out of range")
            meta = self.__get_dataset_sequence_from_demo(index)
            meta['weights'] = self.weights[index]
            return meta
        raise TypeError(f"Invalid argument type: {type(index)}. Expected int or slice.")

    def slice_by_list(self, indices):
        """Return a new dataset made of the demos at the given demo ``indices``."""
        return self._with_data([self.data[i] for i in indices])

    def __add__(self, other):
        """
        Concatenate the demos of two datasets.

        If either side has ``all_obs_keys`` set to None, the other side is
        returned unchanged. Otherwise both must agree on their keys and
        ``seq_len``; the result keeps this dataset's padding and frame-stack
        settings.
        """
        if self.all_obs_keys is None:
            return other

        if other.all_obs_keys is None:
            return self

        # assert that the two datasets have the same keys
        assert self.all_obs_keys == other.all_obs_keys
        assert self.obs_keys_in_memory == other.obs_keys_in_memory
        assert self.all_dataset_keys == other.all_dataset_keys
        assert self.seq_len == other.seq_len

        return self._with_data(self.data + other.data)

    def __generate_idx_to_demo_list(self):
        """Build ``idx_to_demo``: for every sample index, the index of its demo."""
        self.idx_to_demo = []
        if self.data is None:
            return

        for demo_idx, demo in enumerate(self.data):
            self.idx_to_demo += [demo_idx] * self._to_int(demo['length'])

    def __generate_demo_start_idx_list(self):
        """Build ``demo_start_idx``: for every demo, the index of its first sample."""
        self.demo_start_idx = []
        if self.data is None:
            return

        count = 0
        for demo in self.data:
            self.demo_start_idx.append(count)
            count += self._to_int(demo['length'])

    def __generate_demo_ids(self):
        """Build ``demo_ids``: the names ``demo_0``, ``demo_1``, ... in demo order."""
        self.demo_ids = []
        if self.data is None:
            return

        self.demo_ids = ['demo_' + str(idx) for idx in range(len(self.data))]

    def __generate_demo_start_indices(self):
        """Build ``demo_start_indices``: demo id -> index of its first sample."""
        self.demo_start_indices = {}
        if self.data is None:
            return

        self.demo_start_indices = dict(zip(self.demo_ids, self.demo_start_idx))

    def __generate_demo_lengths(self):
        """Build ``demo_lengths``: demo id -> the demo's stored ``'length'`` value."""
        self.demo_lengths = {}
        if self.data is None:
            return

        self.demo_lengths = {demo_id: demo['length'] for demo_id, demo in zip(self.demo_ids, self.data)}

    def __get_dataset_sequence_from_demo(self, index):
        """
        Assemble the sample at global (non-negative) sample ``index``.

        The window covers ``n_frame_stack - 1`` steps before the in-demo index
        and ``seq_len`` steps from it, clipped to the demo; the clipped part
        is filled in by ``TensorUtils.pad_sequence(..., pad_same=True)``.

        Returns:
            dict with every key of ``all_dataset_keys``, plus ``'obs'``,
            ``'task_emb'`` and ``'task_emb_list'``.
        """
        demo_id = self.idx_to_demo[index]
        demo = self.data[demo_id]
        demo_start_index = self.demo_start_idx[demo_id]
        demo_length = self._to_int(demo['length'])

        # NOTE(review): every demo contributes `length` sample indices no
        # matter how the pad flags are set, so with padding disabled the
        # assertions below can fire near demo boundaries - confirm intended.
        demo_index_offset = 0 if self.pad_frame_stack else (self.n_frame_stack - 1)
        index_in_demo = index - demo_start_index + demo_index_offset

        num_frames_to_stack = self.n_frame_stack - 1
        seq_begin_index = max(0, index_in_demo - num_frames_to_stack)
        seq_end_index = min(demo_length, index_in_demo + self.seq_len)

        seq_begin_pad = max(0, num_frames_to_stack - index_in_demo)  # pad for frame stacking
        seq_end_pad = max(0, index_in_demo + self.seq_len - demo_length)  # pad for sequence length

        if not self.pad_frame_stack:
            assert seq_begin_pad == 0
        if not self.pad_seq_length:
            assert seq_end_pad == 0

        meta = dict()
        for key in self.all_dataset_keys:
            meta[key] = demo[key][seq_begin_index:seq_end_index]

        meta = TensorUtils.pad_sequence(meta, padding=(seq_begin_pad, seq_end_pad), pad_same=True)

        obs = dict()
        for key in self.all_obs_keys:
            if key in self.obs_keys_in_memory:
                obs[key] = demo['obs'][key][seq_begin_index:seq_end_index]
            else:
                # Keys that are not held in memory are read from the demo's hdf5 file.
                file_handler = self.__get_hdf5_file(demo['hdf5_file_path'])
                trajectory = file_handler['data/{}/obs/{}'.format(demo['demo_name'], key)]
                obs[key] = trajectory[seq_begin_index:seq_end_index].astype('float32')

        obs = TensorUtils.pad_sequence(obs, padding=(seq_begin_pad, seq_end_pad), pad_same=True)
        obs = ObsUtils.process_obs_dict(obs)

        meta['obs'] = obs

        meta['task_emb'] = demo['demo_emb_list'][-1]
        meta['task_emb_list'] = demo['demo_emb_list']

        return meta

    @property
    def get_num_demos(self):
        """Number of demos in the dataset (a property, despite the ``get_`` name)."""
        assert self.all_dataset_keys is not None, "Dataset is empty, please generate dataset from DataModule"

        return len(self.data)

    def __get_hdf5_file(self, hdf5_path):
        """
        Return an open read-only handle for ``hdf5_path``, opening it lazily.

        Only one file is kept open: asking for a different path closes the
        previously opened file first.
        """
        if self.__hdf5_path == hdf5_path:
            return self.__hdf5_file

        self.__hdf5_path = hdf5_path
        if self.__hdf5_file is not None:
            self.__hdf5_file.close()
        self.__hdf5_file = h5py.File(self.__hdf5_path, 'r', swmr=False, libver='latest')
        return self.__hdf5_file

class EvalDataset(Dataset):
    """Dataset of ``n_eval`` initial states randomly drawn from a saved collection."""

    def __init__(self, init_states_path, n_eval):
        """
        Args:
            init_states_path: location passed to ``torch.load``.
            n_eval: number of initial states to draw.
        """
        loaded_states = torch.load(init_states_path)
        self.init_states = loaded_states

        # random.choices draws with replacement, so the same initial state
        # can be selected more than once.
        self.data = random.choices(loaded_states, k=n_eval)

    def __len__(self):
        """Return the number of drawn initial states."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the drawn initial state at position ``idx``."""
        chosen = self.data[idx]
        return chosen
