import numpy as np
import os
from typing import Optional

from sklearn.model_selection import train_test_split
import d4rl
from torch.utils.data import Dataset, dataloader
import dataclasses
import torch

DATAPATH = ''


def get_imitation_dataset(env_name, rollout_path=DATAPATH, n_max_demos=None):
    """Return the expert trajectories for ``env_name``, exporting them if needed.

    The rollout archive is looked up at
    ``<rollout_path>/<env_name with '-' replaced by '_'>_0/rollouts/final.pkl``.
    If that file is missing, the environment is created with ``gym.make``, the
    result of ``env.get_dataset()`` is re-keyed and written to that location
    with ``np.savez_compressed`` (through a ``.tmp`` file that is then moved
    into place), and the freshly written file is loaded.

    Args:
        env_name: Gym environment id.
        rollout_path: Root directory under which rollouts are stored.
        n_max_demos: If not ``None``, only the first ``n_max_demos``
            trajectories are returned.

    Returns:
        Whatever ``load`` returns for the archive, truncated to
        ``n_max_demos`` entries when that argument is given.

    Raises:
        ValueError: If fewer than ``n_max_demos`` trajectories are available.
    """
    rollout_file = os.path.join(rollout_path, f"{env_name.replace('-', '_')}_0/rollouts/final.pkl")

    if not os.path.exists(rollout_file):
        import gym
        # Presumably these imports register the custom environments with gym
        # as a side effect -- TODO confirm.
        if 'hirid' in env_name:
            import hirid_env
        elif 'sepsis' in env_name:
            import sepsis_env

        print(f"Saving rollouts for {env_name} via env.get_dataset()")
        environment = gym.make(env_name)
        raw = environment.get_dataset()

        # Re-key the dataset returned by the environment into the archive layout.
        print(raw.keys())
        print([raw[key].shape for key in raw.keys()])

        condensed = {
            "obs": raw['observations'],
            "acts": raw['actions'],
            "rews": raw['rewards'],
            "infos": raw['infos'],
            "terminal": raw['terminals'],
        }
        # A new trajectory starts right after every terminal step except the
        # last one found.
        condensed["indices"] = np.where(raw['terminals'] == 1)[0][:-1] + 1

        os.makedirs(rollout_file.replace("/final.pkl", ''), exist_ok=True)
        # Write to a temporary name first so that a partially written archive
        # never appears under the final path.
        staging_file = f"{rollout_file}.tmp"
        with open(staging_file, "wb") as handle:
            np.savez_compressed(handle, **condensed)
        os.replace(staging_file, rollout_file)

    expert_trajs = load(rollout_file)
    if n_max_demos is None:
        return expert_trajs

    if len(expert_trajs) < n_max_demos:
        raise ValueError(
            f"Want to use n_expert_demos={n_max_demos} trajectories, but only "
            f"{len(expert_trajs)} are available via {rollout_path}.",
        )
    return expert_trajs[:n_max_demos]


def compute_norm_statistics(loader, device=None):
    """Compute per-feature mean and variance of ``obs`` and ``acts``.

    The loader is iterated once per key. Every batch is flattened to
    ``(-1, feature_dim)`` and merged into running statistics, so the result
    does not depend on how the data is split into batches.

    Args:
        loader: Iterable of batches, each indexable by ``'obs'`` and
            ``'acts'``; ``None`` is passed straight through.
        device: Torch device used for the accumulation. Defaults to
            ``'cuda'`` when CUDA is available and to ``'cpu'`` otherwise.

    Returns:
        ``None`` if ``loader`` is ``None``; otherwise a dict mapping
        ``'obs'`` and ``'acts'`` to ``(mean, var)`` float32 tensors on
        ``device``, where ``var`` is the population (biased) variance.
    """
    if loader is None:
        return None
    if device is None:
        # Fall back to the CPU instead of failing on hosts without a GPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    with torch.no_grad():
        statistics = {}
        for dim in ['obs', 'acts']:
            input_dim = next(iter(loader))[dim].shape[-1]
            running_mean = torch.zeros(input_dim, dtype=torch.float32, device=device)
            # The initial value is irrelevant: it is multiplied by count == 0
            # when the first batch is merged.
            running_var = torch.ones(input_dim, dtype=torch.float32, device=device)
            count = 0
            for batch in loader:
                x = torch.as_tensor(batch[dim], device=device).detach().to(torch.float32)
                x = x.reshape(-1, x.shape[-1])

                batch_mean = torch.mean(x, dim=0)
                batch_var = torch.var(x, dim=0, unbiased=False)
                batch_count = x.shape[0]

                # Merge the batch moments into the running moments.
                delta = batch_mean - running_mean
                tot_count = count + batch_count
                running_mean += delta * batch_count / tot_count

                # Combine the sums of squared deviations, then renormalise.
                running_var *= count
                running_var += batch_var * batch_count
                running_var += torch.square(delta) * count * batch_count / tot_count
                running_var /= tot_count

                count += batch_count

            statistics[dim] = (running_mean, running_var)
        return statistics

EPSILON = 1e-8


def normalise(loader, statistics):
    """Standardise the ``obs`` and ``acts`` of a loader's dataset.

    Args:
        loader: Object exposing a ``dataset`` attribute, or ``None``.
        statistics: Mapping from ``'obs'`` / ``'acts'`` to ``(mean, var)``
            tensors (the layout produced by ``compute_norm_statistics``),
            or ``None``.

    Returns:
        ``loader`` unchanged when either argument is ``None``. Otherwise a new
        ``Transitions`` built from the dataset fields reported by
        ``dataclass_quick_asdict``, with ``obs`` and ``acts`` replaced by
        ``(value - mean) / sqrt(var + EPSILON)``. Note that this is a
        ``Transitions`` object, not a loader.
    """
    if loader is None or statistics is None:
        return loader

    with torch.no_grad():
        fields = dataclass_quick_asdict(loader.dataset)
        # NOTE(review): only 'obs' and 'acts' are rescaled here; confirm
        # whether any other observation-like field needs the same treatment.
        for name in ('obs', 'acts'):
            mean, var = statistics[name]
            centred = fields[name] - mean.cpu().numpy()
            # EPSILON keeps the division finite for zero-variance features.
            scale = np.sqrt(var.cpu().numpy() + EPSILON)
            fields[name] = centred / scale

        return Transitions(**fields)

def get_fo_env_name(env_name):
    """Translate a HiRID environment id into the id stored in the lookup table.

    Args:
        env_name: Environment id; must contain the substring ``'hirid'``.

    Returns:
        The id the table associates with ``env_name``.

    Raises:
        NotImplementedError: If ``env_name`` does not contain ``'hirid'``.
        KeyError: If ``env_name`` contains ``'hirid'`` but is not in the table.
    """
    if 'hirid' in env_name:
        lookup = {
            'hirid-circ-v1': 'hirid-circ-v1',
            'hirid-circage-v1': 'hirid-circ-v1',
            'hirid-circZ-v1': 'hirid-circ-v1',
            'hirid-fluidsZ-v1': 'hirid-fluids-v1',
            'hirid-vasoZ-v1': 'hirid-vaso-v1',
        }
        return lookup[env_name]
    raise NotImplementedError


@dataclasses.dataclass()
class Trajectory:
    """A single rollout: observations, actions and optional per-step extras.

    Shapes are validated on construction (see ``__post_init__``).
    """

    obs: np.ndarray
    """Observations, shape (trajectory_len + 1, ) + observation_shape."""

    acts: np.ndarray
    """Actions, shape (trajectory_len, ) + action_shape."""

    infos: Optional[np.ndarray]
    """An array of info dicts, length trajectory_len."""

    # Flag attached to the trajectory as a whole (not per step).
    terminal: bool

    # Optional rewards. Defaults to None so that a trajectory can be built
    # from data that carries no reward information.
    rews: Optional[np.ndarray] = None

    def __len__(self) -> int:
        """Returns number of transitions, equal to the number of actions."""
        return len(self.acts)

    def __eq__(self, other) -> bool:
        """Field-wise equality; ``infos=None`` equals a sequence of empty dicts."""
        if not isinstance(other, Trajectory):
            return False

        # Read the fields directly rather than through dataclasses.asdict,
        # which would deep-copy every array just to compare it.
        own_names = {field.name for field in dataclasses.fields(self)}
        other_names = {field.name for field in dataclasses.fields(other)}
        # Trajectory objects may still have different fields if they are
        # instances of different subclasses.
        if own_names != other_names:
            return False

        if len(self) != len(other):
            # Short-circuit: if trajectories are of different length, then unequal.
            # Redundant as later checks would catch this, but speeds up common case.
            return False

        for name in own_names:
            self_v = getattr(self, name)
            other_v = getattr(other, name)
            if name == "infos":
                # Treat None equivalent to sequence of empty dicts
                self_v = [{}] * len(self) if self_v is None else self_v
                other_v = [{}] * len(other) if other_v is None else other_v
            if not np.array_equal(self_v, other_v):
                return False

        return True

    def __post_init__(self):
        """Performs input validation: check shapes are as specified in docstring.

        Raises:
            ValueError: If ``obs`` is not exactly one longer than ``acts``, if
                ``infos`` is given with a different length than ``acts``, or
                if there are no actions at all.
        """
        if len(self.obs) != len(self.acts) + 1:
            raise ValueError(
                "expected one more observations than actions: "
                f"{len(self.obs)} != {len(self.acts)} + 1",
            )
        if self.infos is not None and len(self.infos) != len(self.acts):
            raise ValueError(
                "infos when present must be present for each action: "
                f"{len(self.infos)} != {len(self.acts)}",
            )
        if len(self.acts) == 0:
            raise ValueError("Degenerate trajectory: must have at least one action.")

    def __setstate__(self, state):
        """Restore the instance attributes from ``state`` (used by pickle)."""
        self.__dict__.update(state)



def load(path):
    """Load trajectories from an ``.npz`` rollout archive.

    The archive must contain ``obs``, ``acts``, ``infos``, ``terminal`` and
    ``indices``; ``rews`` is optional. ``indices`` holds the positions at
    which the concatenated per-step arrays are split into trajectories.

    If ``obs`` has no extra final observation per trajectory (it is not longer
    than ``acts``), a zero observation is appended to every trajectory.

    ``terminal`` may hold either one flag per trajectory or one flag per
    transition (the layout written by ``get_imitation_dataset``). In the
    per-transition case the flag of each trajectory's last transition is used.

    Args:
        path: Location of the archive.

    Returns:
        A list of ``Trajectory`` objects, one per split segment.
    """
    # NOTE(security): allow_pickle=True unpickles object arrays (the info
    # dicts need it), which can execute arbitrary code. Only load archives
    # from a trusted source.
    with np.load(path, allow_pickle=True) as data:
        # Every lookup on the archive re-reads and decompresses the array, so
        # fetch each one exactly once; the context manager closes the file.
        indices = data["indices"]
        all_obs = data["obs"]
        all_acts = data["acts"]
        all_infos = data["infos"]
        terminal = data["terminal"]
        all_rews = data["rews"] if "rews" in data else None

    num_splits = len(indices)

    # Account for the extra obs in each trajectory
    if len(all_obs) > len(all_acts):
        obs = np.split(all_obs, indices + np.arange(num_splits) + 1)
    else:
        obs = [
            np.concatenate([ob, np.zeros_like(ob[:1])])
            for ob in np.split(all_obs, indices)
        ]

    acts = np.split(all_acts, indices)

    # A per-transition terminal array must be reduced to one flag per
    # trajectory; zipping it directly would give trajectory i the flag of
    # step i. When every trajectory has a single step both layouts coincide.
    is_per_step = (
        np.ndim(terminal) == 1
        and len(terminal) == len(all_acts)
        and len(acts) != len(all_acts)
    )
    if is_per_step:
        terminal = [
            bool(chunk[-1]) if len(chunk) else False
            for chunk in np.split(terminal, indices)
        ]

    fields = (
        obs,
        acts,
        np.split(all_infos, indices),
        terminal,
    )
    if all_rews is not None:
        fields += (np.split(all_rews, indices),)
    return [Trajectory(*args) for args in zip(*fields)]

def split_datasets(trajs, test_size=0.2, val_size=0, seed=None):
    """Split trajectories into train/test or train/val/test subsets.

    Args:
        trajs: Collection of trajectories to split.
        test_size: Passed to ``train_test_split`` for the first split.
        val_size: If non-zero, passed to ``train_test_split`` to carve a
            validation set out of the *training* part of the first split.
        seed: If not ``None``, the global NumPy random generator is reseeded
            with this value before splitting.

    Returns:
        ``(train, test)`` when ``val_size`` is 0, otherwise
        ``(train, val, test)``.
    """
    if seed is not None:
        # This reseeds NumPy's global generator, so it also affects any later
        # use of np.random by the caller.
        np.random.seed(seed)

    remaining, held_out = train_test_split(trajs, test_size=test_size)
    if val_size != 0:
        fit_part, val_part = train_test_split(remaining, test_size=val_size)
        return fit_part, val_part, held_out
    return remaining, held_out

###########################################################
def dataclass_quick_asdict(obj):
    """Collect the underscore-free attributes of ``obj`` into a dict.

    Every name reported by ``dir(obj)`` that contains no underscore at all is
    read with ``getattr``. Names with an underscore anywhere are skipped, so
    besides private and dunder names this also leaves out attributes such as
    ``next_obs``.

    Args:
        obj: Any object.

    Returns:
        Dict from attribute name to attribute value, in ``dir`` order.
    """
    collected = {}
    for name in dir(obj):
        if '_' in name:
            continue
        collected[name] = getattr(obj, name)
    return collected

class Transitions:
    """A batch of obs-act-obs-done transitions.

    All fields are stored as plain attributes. Extra keyword arguments are
    stored as attributes under their own name.
    """

    def __init__(self, obs,
            acts,
            infos,
            next_obs: Optional[np.ndarray] = None,
            dones: Optional[np.ndarray] = None,
            history: Optional[np.ndarray] = None,
            future: Optional[np.ndarray] = None,
            masks: Optional[np.ndarray] = None,
            **kwargs
    ):
        """Store the given fields; unknown keyword arguments become attributes."""
        self.obs = obs
        self.acts = acts
        self.infos = infos
        self.next_obs = next_obs
        self.dones = dones
        self.history = history
        self.future = future
        self.masks = masks
        for attr, val in kwargs.items():
            setattr(self, attr, val)

    def __len__(self):
        """Returns number of transitions (the length of ``obs``)."""
        return len(self.obs)

    def __getitem__(self, key):
        """Index or slice the batch.

        The fields taken into account are those reported by
        ``dataclass_quick_asdict`` for this instance.

        Args:
            key: An integer (Python or NumPy) or a slice.

        Returns:
            For an integer, a dict mapping each non-``None`` field name to its
            ``key``-th element. For a slice, a new instance of the same class
            in which every non-``None`` field has been sliced.

        Raises:
            TypeError: If ``key`` is neither an integer nor a slice.
        """
        fields = dataclass_quick_asdict(self)

        if isinstance(key, slice):
            # This class is not a dataclass, so dataclasses.replace cannot be
            # used; rebuild through the constructor and keep unset fields None.
            sliced = {
                name: None if value is None else value[key]
                for name, value in fields.items()
            }
            return type(self)(**sliced)

        if not isinstance(key, (int, np.integer)):
            raise TypeError(
                f"Transitions indices must be integers or slices, not {type(key).__name__}"
            )

        # Return type is a dictionary. Array values have no batch dimension.
        #
        # A dictionary of np.ndarray values is a convenient
        # torch.utils.data.Dataset return type, as a torch.utils.data.DataLoader
        # taking in this `Dataset` as its first argument knows how to
        # automatically concatenate several dictionaries together to make
        # a single dictionary batch with `torch.Tensor` values.
        return {name: value[key] for name, value in fields.items() if value is not None}


def _pad_first_axis(val, target_len):
    """Zero-pad ``val`` at the end of axis 0 up to ``target_len`` entries."""
    pad_width = [(0, target_len - val.shape[0])] + [(0, 0)] * (val.ndim - 1)
    return np.pad(val, pad_width)


def pad_and_stack(array_dict):
    """Stack per-trajectory arrays into batch arrays, padding ragged lengths.

    Trajectory lengths are taken from ``array_dict['obs']``. When they differ,
    every entry is zero-padded at the end of its first axis to the longest
    length before stacking; arrays of any rank >= 1 are supported. When they
    are all equal the entries are stacked as they are.

    Args:
        array_dict: Dict from field name to a list with one array per
            trajectory. A value may be ``None``; with ragged lengths a list
            whose first element is ``None`` is also passed through unchanged.

    Returns:
        Dict with the same keys holding the stacked arrays, plus a boolean
        ``'masks'`` array of shape ``(n_trajectories, max_len)`` that is
        ``True`` on real steps and ``False`` on padding.

    Raises:
        NotImplementedError: If lengths are ragged and an entry is
            zero-dimensional.
    """
    lengths = [len(obs) for obs in array_dict['obs']]
    output = {}
    if len(set(lengths)) > 1:  # more than one length
        max_len = max(lengths)
        for key, arr in array_dict.items():
            if arr is None or arr[0] is None:
                output[key] = arr
            elif len(arr[0].shape) == 0:
                # Scalars have no time axis to pad along.
                raise NotImplementedError
            else:
                output[key] = np.stack(
                    [_pad_first_axis(val, max_len) for val in arr], 0)

        output['masks'] = np.stack(
            [np.arange(max_len) < length for length in lengths], 0)
    else:
        for key, arr in array_dict.items():
            # Pass a missing field through, as the ragged branch does.
            output[key] = arr if arr is None else np.stack(arr)
        output['masks'] = np.ones((len(lengths), lengths[0]), dtype=bool)

    return output

def flatten_trajectories(trajectories, lstm= False):
    """Flatten a list of trajectories into a single batch of Transitions.

    Besides ``obs``, ``next_obs``, ``acts``, ``dones`` and ``infos``, every
    other attribute of the first trajectory whose name contains no double
    underscore (except ``terminal``) is carried over as an extra field.

    Args:
        trajectories: Non-empty list of trajectories.
        lstm: If true, keep one row per trajectory by padding and stacking
            (see ``pad_and_stack``); otherwise concatenate all steps along
            the first axis.

    Returns:
        The trajectories flattened into a single batch of Transitions.
    """
    keys = ["obs", "next_obs", "acts", "dones", "infos"]
    extra = [
        name for name in dir(trajectories[0])
        if '__' not in name and name not in keys and name != 'terminal'
    ]
    keys.extend(extra)

    parts = {key: [] for key in keys}

    for traj in trajectories:
        parts["acts"].append(traj.acts)

        obs = traj.obs
        parts["obs"].append(obs[:-1])
        parts["next_obs"].append(obs[1:])

        # Only the last step of a trajectory can be marked as done.
        dones = np.zeros(len(traj.acts), dtype=bool)
        dones[-1] = traj.terminal
        parts["dones"].append(dones)

        if traj.infos is None:
            infos = np.array([{}] * len(traj))
        else:
            infos = traj.infos
        parts["infos"].append(infos)

        for key in extra:
            parts[key].append(getattr(traj, key))

    if lstm:
        cat_parts = pad_and_stack(parts)
    else:
        cat_parts = {}
        for key, part_list in parts.items():
            if all(part is None for part in part_list):
                # An optional field (e.g. rews) that no trajectory provides
                # cannot be concatenated; keep it as a missing field.
                cat_parts[key] = None
            else:
                cat_parts[key] = np.concatenate(part_list, axis=0)

    lengths = {len(value) for value in cat_parts.values() if value is not None}
    assert len(lengths) == 1, f"expected one length, got {lengths}"
    return Transitions(**cat_parts)

    
def transitions_collate_fn(batch):
    """Collate samples drawn from a `Transitions` dataset into one batch.

    Intended to be passed as the `collate_fn` argument of a
    `torch.utils.data.DataLoader` whose `dataset` is a `Transitions` instance.

    Args:
        batch: Sequence of per-sample dicts; each must contain an "infos" key.

    Returns:
        A dict batch. Every key other than "infos" whose value is not None is
        converted with `np.array` and collated by Torch's default collate
        function. "infos" is returned as a plain list with one entry per
        sample. (The default behavior would recursively collate every info
        dict into a single dict, which is incorrect.)
    """
    stripped = []
    for sample in batch:
        entry = {}
        for name, value in sample.items():
            if name == "infos" or value is None:
                continue
            entry[name] = np.array(value)
        stripped.append(entry)

    collated = dataloader.default_collate(stripped)
    assert isinstance(collated, dict)
    collated["infos"] = [sample["infos"] for sample in batch]
    return collated
