import collections
import pathlib
import random
import json
import os

import math
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Sampler, Dataset
from torch import Tensor
import torch.distributed as dist
from collections.abc import Sequence,Iterator

from collections import deque
from tqdm import tqdm


# Per-dataset weights, keyed by dataset directory name.
# ``TrajDataset.load_weights`` looks up the last path component of every
# indexed window's dataset name in this table, so each indexed dataset needs
# an entry here (a missing key raises ``KeyError``).
# NOTE(review): presumably a weight of 0.0 keeps a dataset from ever being
# drawn by a weighted sampler -- confirm against the training script.
WEIGHTS = {
    # Google RT-1 Robot Data (Large-Scale)
    "fractal20220817_data": 1.0,
    "kuka": 0.0,
    # Original Version of Bridge V2 from Project Website
    "bridge_orig": 1.0,
    "taco_play": 0.0,
    "jaco_play": 1.0,
    "berkeley_cable_routing": 0.0,
    "roboturk": 2.0,
    "viola": 2.0,
    "berkeley_autolab_ur5": 0.0,
    "toto": 1.0,
    "language_table": 0.1,
    "stanford_hydra_dataset_converted_externally_to_rlds": 2.0,
    "austin_buds_dataset_converted_externally_to_rlds": 0.0,
    "nyu_franka_play_dataset_converted_externally_to_rlds": 0.0,
    "furniture_bench_dataset_converted_externally_to_rlds": 0.0,
    "ucsd_kitchen_dataset_converted_externally_to_rlds": 0.0,
    "austin_sailor_dataset_converted_externally_to_rlds": 0.0,
    "austin_sirius_dataset_converted_externally_to_rlds": 0.0,
    "dlr_edan_shared_control_converted_externally_to_rlds": 0.0,
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds": 0.0,
    "utaustin_mutex": 1.0,
    "berkeley_fanuc_manipulation": 0.0,
    "cmu_stretch": 0.0,

    # --- New Datasets in MagicSoup++ ---
    # Note: use v0.1.0 --> later versions broken
    "bc_z": 0.2,
    "fmb": 0.0,
    "dobbe": 0.0,
    "droid": 0.0,
}


def convert(value):
    """Return *value* as a NumPy array with a normalised dtype.

    Floating-point data becomes ``float32``, signed integers become ``int32``
    and ``uint8`` stays ``uint8``. Any other dtype (wider unsigned integers,
    booleans, strings, ...) is returned as converted by ``np.array`` without a
    further cast.
    """
    arr = np.array(value)
    # (dtype family, target dtype) pairs, checked in order.
    casts = (
        (np.floating, np.float32),
        (np.signedinteger, np.int32),
        (np.uint8, np.uint8),
    )
    for family, target in casts:
        if np.issubdtype(arr.dtype, family):
            return arr.astype(target)
    return arr


def cycle(dl):
    """Yield the items of *dl* endlessly, restarting it whenever it runs out.

    *dl* is re-iterated from scratch on every pass, so it must be a re-iterable
    object rather than a one-shot iterator. If *dl* is empty the generator
    never yields and loops forever.
    """
    while True:
        yield from dl


class DistributedWeightedRandomSampler(Sampler[int]):
    r"""Weighted random sampler that restricts every distributed rank to its own shard.

    The indices ``[0, .., len(weights)-1]`` are shuffled once at construction
    time. When ``torch.distributed`` is initialised with more than one process,
    the shuffled positions are dealt out round-robin (position ``i`` belongs to
    rank ``i % world_size``) and each rank only ever draws from its own
    positions. Otherwise every index is eligible.

    Args:
        weights (sequence): a 1-D sequence of weights, not necessarily summing
            up to one.
        num_samples (int): number of indices drawn per iteration (on each rank).
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that an
            index can be drawn at most once per iteration.
        generator (Generator): Generator used for the shuffle and for sampling.
            Must be set when the world size is greater than one.
            NOTE(review): every rank builds its own permutation, so the shards
            are only disjoint if all ranks seed the generator identically --
            confirm that callers do so.

    Raises:
        ValueError: if ``num_samples`` is not a positive int, ``replacement``
            is not a bool, or ``weights`` is not one-dimensional.

    Example:
        >>> # Only index 1 has a non-zero weight, so it is the only possible draw.
        >>> list(DistributedWeightedRandomSampler([0.0, 1.0, 0.0], 4, replacement=True))
        [1, 1, 1, 1]
    """

    weights: Tensor
    num_samples: int
    replacement: bool

    def __init__(self, weights: Sequence[float], num_samples: int,
                 replacement: bool = True, generator: torch.Generator = None) -> None:
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={num_samples}")
        if not isinstance(replacement, bool):
            raise ValueError(f"replacement should be a boolean value, but got replacement={replacement}")

        # Validate the shape before anything is indexed.
        weights_tensor = torch.as_tensor(weights, dtype=torch.double)
        if weights_tensor.dim() != 1:
            raise ValueError("weights should be a 1d sequence but given "
                             f"weights have shape {tuple(weights_tensor.shape)}")

        # The permutation must span the whole population (len(weights)), not
        # num_samples: the two are independent quantities.
        self.indices = torch.randperm(len(weights_tensor), generator=generator)
        # Weights are stored in shuffled order so that they line up with self.indices.
        weights_tensor = weights_tensor[self.indices]

        # By default every shuffled position is eligible.
        self.mask = torch.ones_like(weights_tensor, dtype=torch.bool)

        # is_available() guards builds of torch compiled without distributed support.
        if dist.is_available() and dist.is_initialized():
            num_processes = dist.get_world_size()
            if num_processes > 1:
                assert generator is not None, "A generator should be set when num_processes > 1"
                # Keep only the positions owned by this rank: i % world_size == rank.
                self.mask = torch.zeros_like(weights_tensor, dtype=torch.bool)
                self.mask[dist.get_rank()::num_processes] = True

        # Set parameters...
        self.weights = weights_tensor
        self.num_samples = num_samples
        self.replacement = replacement
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        """Draw ``num_samples`` dataset indices from this rank's shard."""
        # Sample positions within the shard according to the shard's weights...
        shard_positions = torch.multinomial(
            self.weights[self.mask], self.num_samples, self.replacement,
            generator=self.generator,
        )
        # ...and translate them back to the original (unshuffled) indices.
        shard_indices = self.indices[self.mask]
        yield from shard_indices[shard_positions].tolist()

    def __len__(self) -> int:
        """Return the number of indices produced by one iteration."""
        return self.num_samples

def list_files(root):
    """Return every entry of directory *root* as a relative ``pathlib.Path``.

    The paths hold the bare entry names (they are not joined with *root*),
    cover files as well as directories, and come in ``os.listdir`` order.
    """
    return [pathlib.Path(entry) for entry in os.listdir(root)]


class TrajDataset(Dataset):
    """Dataset of fixed-length trajectory windows stored as ``.npz`` files.

    Every item is described by an index entry ``(dataset_dir, episode_id,
    start, end)`` in ``self.indices`` and is materialised by :meth:`load_traj`
    as an ``(imgs, actions, rewards)`` triple of ``float32`` arrays.

    Args:
        file_dir: root directory; each of its entries is treated as a dataset
            directory.
        seq_len: number of steps per window.
        indices_path: optional JSON file holding ``"indices"`` and
            ``"num_episodes"``. When omitted, the index is rebuilt by
            :meth:`load_indices` and written to ``<file_dir>/indices.json``.
        device: device that :meth:`sample_batch_dataset` moves its tensors to.

    NOTE(review): :meth:`load_indices` reads whole-episode files
    ``<dataset>/<id>.npz`` whereas :meth:`load_traj` reads per-step files
    ``<dataset>/<id>/<step>.npz``. Presumably both layouts exist on disk --
    confirm against the data preparation script.
    """

    def __init__(self, file_dir, seq_len=10, indices_path=None, device="cpu"):
        super().__init__()
        self._path = pathlib.Path(file_dir)
        self._paths = list_files(self._path)
        self._seq_len = seq_len
        self._weights = None
        self._action_dim = None

        # create a list of episode indices
        if indices_path is None:
            self.indices, self._num_episodes = self.load_indices()
        else:
            with open(indices_path, 'r') as f:
                info = json.load(f)
            self.indices, self._num_episodes = info['indices'], info['num_episodes']

        print(f'num_episodes: {self._num_episodes}, num_steps: {len(self.indices)}')
        self.load_weights()

        self.device = torch.device(device)

    @property
    def weights(self):
        """Per-window weights, aligned with ``self.indices``."""
        return self._weights

    @property
    def action_dim(self):
        """Size of the last action axis, read lazily from the first window."""
        if self._action_dim is None:
            _, acts, _ = self[0]
            self._action_dim = acts.shape[-1]
        return self._action_dim

    def load_weights(self):
        """Fill ``self._weights`` with one ``WEIGHTS`` entry per window.

        The lookup key is the last ``/``-separated component of the window's
        dataset name; an unknown dataset raises ``KeyError``.
        """
        self._weights = [
            WEIGHTS[entry[0].split('/')[-1]] for entry in self.indices
        ]

    def load_indices(self):
        """Scan every dataset directory and build the list of windows.

        For each ``*.npz`` episode with at least ``seq_len`` steps, one entry
        ``(dataset_dir, episode_id, start, start + seq_len)`` is emitted per
        possible start. Unreadable episodes are reported and skipped. The
        result is also written to ``<root>/indices.json``.

        Returns:
            ``(indices, num_episodes)``.
        """
        indices = []
        num_episodes = 0
        for file_path in self._paths:
            filenames = sorted((self._path / file_path).glob('*.npz'))
            # Fixed seed: the episode order is shuffled but reproducible.
            random.Random(0).shuffle(filenames)
            for filename in tqdm(filenames, desc='Loading indices'):
                try:
                    with filename.open('rb') as f:
                        episode = np.load(f)
                        # Read every array while the file is still open.
                        episode = {k: episode[k] for k in episode.keys()}
                except Exception as e:
                    print(f'Could not load episode {str(filename)}: {e}')
                    continue
                ep_len = len(episode['imgs'])
                if ep_len < self._seq_len:
                    continue
                num_episodes += 1
                for j in range(ep_len - self._seq_len + 1):
                    indices.append((str(file_path), int(filename.stem), j, j + self._seq_len))

        with open(self._path / 'indices.json', 'w') as file:
            json.dump({'indices': indices, 'num_episodes': num_episodes}, file)

        return indices, num_episodes

    def load_traj(self, file_path, ep_idx, start, end):
        """Load steps ``start..end-1`` of one episode.

        Steps are read from ``<root>/<file_path>/<ep_idx>/<step>.npz``. Steps
        beyond the last file of the episode are padded by repeating the last
        step.

        Returns:
            ``(imgs, actions, rewards)`` as ``float32`` arrays. ``imgs`` is
            rescaled with ``x / 127.5 - 1`` and moved to channel-first layout
            when its last axis has size 3. ``rewards`` is all zeros with one
            row per loaded step.
        """
        episode_dir = os.path.join(self._path, file_path, f'{ep_idx}')
        # The episode length is taken to be the number of entries in the directory.
        last_step = len(os.listdir(episode_dir)) - 1
        imgs, actions = [], []
        for i in range(start, end):
            step_path = os.path.join(episode_dir, f'{min(i, last_step)}.npz')
            # Context manager closes the underlying archive handle promptly.
            with np.load(step_path) as step:
                imgs.append(step['imgs'])
                actions.append(step['actions'])
        imgs, actions = np.array(imgs), np.array(actions)
        if imgs.shape[-1] == 3:
            imgs = imgs.transpose(0, 3, 1, 2)
        imgs = (imgs / 127.5 - 1.0).astype(np.float32)
        actions = actions.astype(np.float32)
        # One reward row per loaded step, so rewards always match imgs/actions
        # even when the requested window is not seq_len long.
        rewards = np.zeros((len(imgs), 1), dtype=np.float32)
        return imgs, actions, rewards

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        file_path, ep_idx, start, end = self.indices[idx]
        return self.load_traj(file_path, ep_idx, start, end)

    def sample_batch(self, batch_size):
        """Draw ``batch_size`` windows uniformly at random (with replacement).

        Returns:
            ``(imgs, acts)`` -- two arrays stacked along a new leading batch axis.
        """
        picks = np.random.choice(len(self), size=batch_size)
        imgs_batch, acts_batch = [], []
        for idx in picks:
            file_path, ep_idx, start, end = self.indices[idx]
            # load_traj returns three values; rewards are not part of this batch.
            imgs, acts, _ = self.load_traj(file_path, ep_idx, start, end)
            imgs_batch.append(imgs)
            acts_batch.append(acts)
        return np.stack(imgs_batch), np.stack(acts_batch)

    def sample_batch_dataset(self, ep_idx=4, start=0, end=32):
        """Load the same window from every dataset directory.

        Args:
            ep_idx: episode id to read in each dataset.
            start: first step of the window.
            end: one past the last step of the window.

        Returns:
            ``(imgs, actions, rewards, task_names)`` where the tensors live on
            ``self.device`` and ``task_names`` lists the datasets that could be
            loaded, in the same order as the batch axis.

        Raises:
            ValueError: if no dataset directory could be loaded.
        """
        results = {"imgs": [], "acts": [], "rewards": []}
        task_names = []
        for file_path in self._paths:
            try:
                # load_traj already prepends the root directory, so only the
                # dataset-relative path is passed here.
                imgs, acts, rewards = self.load_traj(file_path, ep_idx, start, end)
            except Exception:
                # Best effort: entries that are not datasets, or that lack this
                # episode, are skipped.
                continue
            results["imgs"].append(imgs)
            results["acts"].append(acts)
            results["rewards"].append(rewards)
            task_names.append(str(file_path))
        if not task_names:
            raise ValueError(
                f'No trajectory could be loaded for episode {ep_idx} under {self._path}'
            )
        imgs = torch.from_numpy(np.stack(results['imgs'])).to(self.device)
        actions = torch.from_numpy(np.stack(results['acts'])).to(self.device)
        rewards = torch.from_numpy(np.stack(results['rewards'])).to(self.device)
        return imgs, actions, rewards, task_names
    