import torch
from typing import Tuple, List
from robert.agent.config import CONFIGS
from robert.agent.params import PARAMS
from robert.utils.common import get_device
from robert.agent.args import task, DEVICE3
from torch.distributions.multinomial import Multinomial
import math
import numpy as np
from functools import reduce
from config import CONFIGS

# Module-level alias for the device object imported from robert.agent.args.
# NOTE(review): in the code visible here the alias only appears in
# commented-out `.to(DEVICE)` calls, while live code uses DEVICE3 directly;
# confirm whether other modules import DEVICE before touching it.
DEVICE = DEVICE3


def preprocess_pointmaze(
    _states: torch.Tensor, _actions: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Copy the raw pointmaze tensors and fold them into per-episode form.

    Both inputs are cloned before reshaping, so the returned tensors never
    share storage with the arguments. No device transfer is performed.

    Args:
        _states: tensor reshaped to ``(-1, episode_steps, state_dim)``.
        _actions: tensor reshaped to ``(-1, episode_steps - 1, action_dim)``.

    Returns:
        The ``(states, actions)`` pair in the shapes given above, where
        ``episode_steps``, ``state_dim`` and ``action_dim`` come from
        ``CONFIGS``.
    """
    steps_per_episode = CONFIGS["dataset"]["episode_steps"]

    states = torch.reshape(
        _states.clone(), (-1, steps_per_episode, CONFIGS["state_dim"])
    )
    # One action fewer than states per episode.
    actions = torch.reshape(
        _actions.clone(), (-1, steps_per_episode - 1, CONFIGS["action_dim"])
    )
    return states, actions


def coordinate_transform_pointmaze(states: torch.Tensor) -> torch.Tensor:
    """Re-express every step of a window relative to its pivot step.

    The pivot is the step at index ``PARAMS["pre_steps"]`` along dim 1. Its
    value is subtracted from all earlier and all later steps, and the pivot
    step itself is then overwritten with zero. The argument is not modified;
    a transformed copy is returned.

    Raises:
        AssertionError: if dim 1 of ``states`` is not
            ``pre_steps + post_steps + 1`` long.
    """
    shifted = states.clone()

    n_before = PARAMS["pre_steps"]
    n_after = PARAMS["post_steps"]

    assert shifted.size(1) == n_before + n_after + 1

    # Indexing with a list returns a copy that keeps dim 1 (length 1), so the
    # pivot is unaffected by the writes below and broadcasts over the steps.
    pivot = shifted[:, [n_before]]

    shifted[:, :n_before] = shifted[:, :n_before] - pivot
    shifted[:, n_before + 1 :] = shifted[:, n_before + 1 :] - pivot

    shifted[:, n_before] = 0.0
    return shifted


def preprocess_shadow(
    _states: torch.Tensor, _actions: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fold flat shadow-task tensors into per-episode form.

    The inputs are reshaped without cloning, so the results may share
    storage with the arguments (``reshape`` returns a view when it can).

    Args:
        _states: tensor reshaped to ``(-1, episode_steps, state_dim)``.
        _actions: tensor reshaped to ``(-1, episode_steps - 1, action_dim)``.

    Returns:
        The ``(states, actions)`` pair in the shapes given above; the sizes
        are read from ``CONFIGS``.
    """
    steps_per_episode = CONFIGS["dataset"]["episode_steps"]

    states = torch.reshape(
        _states, (-1, steps_per_episode, CONFIGS["state_dim"])
    )
    actions = torch.reshape(
        _actions, (-1, steps_per_episode - 1, CONFIGS["action_dim"])
    )
    return states, actions


def coordinate_transform_shadow(states: torch.Tensor) -> torch.Tensor:
    """Shift the steps around the pivot so they are relative to it.

    The pivot is the step at index ``PARAMS["pre_steps"]`` along dim 1. Its
    value is subtracted from every earlier and every later step; the pivot
    step itself keeps its original value. The argument is not modified; a
    transformed copy is returned.

    Raises:
        AssertionError: if dim 1 of ``states`` is not
            ``pre_steps + post_steps + 1`` long.
    """
    shifted = states.clone()

    n_before = PARAMS["pre_steps"]
    n_after = PARAMS["post_steps"]

    assert shifted.size(1) == n_before + n_after + 1

    # List indexing copies the pivot and keeps dim 1, so it broadcasts below.
    pivot = shifted[:, [n_before]]

    head = shifted[:, :n_before] - pivot
    shifted[:, :n_before] = head

    tail = shifted[:, n_before + 1 :] - pivot
    shifted[:, n_before + 1 :] = tail

    return shifted


def preprocess_ur5e(
    _states: torch.Tensor, _actions: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fold flat UR5e tensors into per-episode form.

    No clone is taken, so the outputs may be views of the inputs
    (``reshape`` returns a view when it can).

    Args:
        _states: tensor reshaped to ``(-1, episode_steps, state_dim)``.
        _actions: tensor reshaped to ``(-1, episode_steps - 1, action_dim)``.

    Returns:
        The ``(states, actions)`` pair in the shapes given above; the sizes
        are read from ``CONFIGS``.
    """
    episode_len = CONFIGS["dataset"]["episode_steps"]

    state_shape = (-1, episode_len, CONFIGS["state_dim"])
    states = _states.reshape(state_shape)

    action_shape = (-1, episode_len - 1, CONFIGS["action_dim"])
    actions = _actions.reshape(action_shape)

    return states, actions


def coordinate_transform_ur5e(states: torch.Tensor) -> torch.Tensor:
    """Shift the steps around the pivot so they are relative to it.

    The pivot is the step at index ``PARAMS["pre_steps"]`` along dim 1. Its
    value is subtracted from every earlier and every later step; the pivot
    step itself keeps its original value. The argument is not modified; a
    transformed copy is returned.

    Raises:
        AssertionError: if dim 1 is not ``pre_steps + post_steps + 1`` long,
            or dim 2 does not match ``CONFIGS["dataset"]["state_dim"]``.
    """
    shifted = states.clone()

    n_before = PARAMS["pre_steps"]
    n_after = PARAMS["post_steps"]

    assert shifted.size(1) == n_before + n_after + 1
    # NOTE(review): preprocess_ur5e reads CONFIGS["state_dim"], whereas this
    # check reads CONFIGS["dataset"]["state_dim"] — confirm both keys exist
    # and agree.
    assert shifted.size(2) == CONFIGS["dataset"]["state_dim"]

    # List indexing copies the pivot and keeps dim 1, so it broadcasts below.
    pivot = shifted[:, [n_before]]

    shifted[:, :n_before] = shifted[:, :n_before] - pivot
    shifted[:, n_before + 1 :] = shifted[:, n_before + 1 :] - pivot

    return shifted


def sample(
    states: torch.Tensor,
    actions: torch.Tensor,
    batch_size: int,
):
    """Draw ``batch_size`` random (dim-0, dim-1) entries from both tensors.

    Indices along the first two dimensions are drawn independently and
    uniformly (so repeats are possible), and the same index pairs are used
    for ``states`` and ``actions``.

    Args:
        states: tensor with at least three dimensions.
        actions: tensor matching ``states`` on dims 0 and 1 and exactly one
            shorter on dim 2.
        batch_size: number of entries to draw.

    Returns:
        ``(states[e, s], actions[e, s])`` for the drawn index vectors.

    Raises:
        AssertionError: if the shape relationships above do not hold.
    """
    assert len(states) == len(actions)
    assert states.shape[0] == actions.shape[0]
    assert states.size(1) == actions.size(1)
    assert states.shape[2] == actions.shape[2] + 1

    n_dim0, n_dim1 = states.shape[0], states.shape[1]

    # Draw order matters for reproducibility under a fixed seed.
    picked_dim0 = torch.randint(0, n_dim0, (batch_size,))
    picked_dim1 = torch.randint(0, n_dim1, (batch_size,))

    state_batch = states[picked_dim0, picked_dim1]
    action_batch = actions[picked_dim0, picked_dim1]
    return (state_batch, action_batch)


def enumer2(tensor: torch.Tensor, seq_len: int):
    """Slide a length-``seq_len`` window with stride 1 along dim 1.

    Args:
        tensor: 3-D tensor; windows are taken along its second dimension.
        seq_len: number of consecutive entries per window.

    Returns:
        A 4-D tensor of shape ``(d0, d1 - seq_len + 1, seq_len, d2)``, where
        ``(d0, d1, d2)`` is the shape of ``tensor``.

    Raises:
        AssertionError: if ``tensor`` is not 3-D.
    """
    assert tensor.dim() == 3
    windows = tensor.unfold(dimension=1, size=seq_len, step=1)
    # unfold puts the window axis last; swap it in front of the last input axis.
    return windows.transpose(-2, -1)


def normalize(state: torch.Tensor, mean: torch.Tensor, std: torch.Tensor):
    """Return ``(state - mean) / std``.

    No guard against a zero ``std`` is applied here.
    """
    centered = state - mean
    return centered / std


def dataset_split2(
    states: torch.Tensor,
    actions: torch.Tensor,
    splits: Tuple[float, float],
) -> Tuple[
    Tuple[torch.Tensor, torch.Tensor],
    Tuple[torch.Tensor, torch.Tensor],
    Tuple[List[int], List[int]],
]:
    """Randomly partition ``states``/``actions`` along dim 0 into train/test.

    Test indices are drawn without replacement, so the two index lists are
    disjoint and together cover ``range(size)`` exactly once.

    Args:
        states: tensor whose first dimension enumerates the samples.
        actions: tensor with the same first-dimension length as ``states``.
        splits: ``(train_ratio, test_ratio)``; must sum to 1 (within float
            tolerance).

    Returns:
        ``((train_states, test_states), (train_actions, test_actions),
        (train_idx, test_idx))``. ``train_idx`` is in ascending order;
        ``test_idx`` is in random order.

    Raises:
        AssertionError: if the ratios do not sum to 1 or the first
            dimensions of ``states`` and ``actions`` differ.
    """
    # Exact float equality would reject valid ratios on rounding noise.
    assert math.isclose(sum(splits), 1.0)

    size = states.size(0)
    assert size == actions.size(0)

    train_ratio, _test_ratio = splits

    # Deriving the test size from the remainder keeps the sizes consistent
    # even when ``size * ratio`` is off by a float rounding error.
    train_size = math.floor(size * train_ratio)
    test_size = size - train_size

    # A random permutation yields unique indices; sampling with randint could
    # repeat indices, duplicating test samples and inflating the train set.
    test_idx = torch.randperm(size)[:test_size].tolist()
    held_out = set(test_idx)
    train_idx = [i for i in range(size) if i not in held_out]

    # Index with long tensors so empty partitions are handled uniformly.
    train_sel = torch.tensor(train_idx, dtype=torch.long)
    test_sel = torch.tensor(test_idx, dtype=torch.long)

    return (
        (states[train_sel], states[test_sel]),
        (actions[train_sel], actions[test_sel]),
        (train_idx, test_idx),
    )


def calc_state_seq_mean_std(state_seq):
    """Compute per-position mean and std of the transformed state sequences.

    ``state_seq`` is unpacked as a 4-D tensor ``(d0, d1, d2, d3)``. For each
    index ``i`` along dim 1, the slice ``state_seq[:, i]`` is moved to
    ``DEVICE3`` and passed through the module-level ``transform``; statistics
    are reduced over dim 0 and then pooled over ``i``. The std is the root
    mean squared deviation from the pooled mean (no Bessel correction).

    The result is also printed.

    Returns:
        ``(mean, std + 1e-8)``, each of shape ``(d2, d3)``.
    """
    _, n_slots, window_len, n_feat = state_seq.shape
    target_device = DEVICE3
    buffer_shape = (n_slots, window_len, n_feat)

    def _transformed_slot(slot):
        # One dim-1 slice at a time, presumably to bound device memory.
        return transform(state_seq[:, slot].to(DEVICE3))

    # Pass 1: mean of each slot, then the mean across slots.
    slot_means = torch.zeros(buffer_shape, device=target_device, dtype=torch.float32)
    for slot in range(n_slots):
        slot_means[slot] = _transformed_slot(slot).mean(dim=0)
    _state_seq_mean = slot_means.mean(dim=0)
    assert _state_seq_mean.shape == (window_len, n_feat)
    del slot_means

    # Pass 2: RMS deviation of each slot from the pooled mean.
    slot_stds = torch.zeros(buffer_shape, device=target_device, dtype=torch.float32)
    for slot in range(n_slots):
        deviation = _transformed_slot(slot) - _state_seq_mean
        slot_stds[slot] = torch.sqrt((deviation**2).mean(dim=0))

    # Pool the per-slot stds through their squares.
    _state_seq_std = torch.sqrt((slot_stds**2).mean(dim=0))
    assert _state_seq_std.shape == (window_len, n_feat)
    del slot_stds

    print(
        f"state transformed seqs mean is: {_state_seq_mean}, std is: {_state_seq_std}"
    )

    # The small offset keeps a later division by std away from zero.
    return (_state_seq_mean, _state_seq_std + 1e-8)


def extract_env_observation_pointmaze(t):
    """Return the first two entries of ``t.observation``."""
    observation = t.observation
    return observation[:2]


def extract_env_observation_shadow(t):
    """Return the first 22 entries of a length-44 ``t.observation``.

    Raises:
        AssertionError: if ``t.observation.shape`` is not ``(44,)``.
    """
    observation = t.observation
    assert observation.shape == (44,)
    return observation[:22]


def extract_env_observation_ur5e(t):
    """Return the first 5 entries of a length-10 ``t.observation``.

    Raises:
        AssertionError: if ``t.observation.shape`` is not ``(10,)``.
    """
    observation = t.observation
    assert observation.shape == (10,)
    return observation[:5]


# Per-task implementations, in the order:
# (preprocess, coordinate transform, observation extractor).
_TASK_FUNCTIONS = {
    "pointmaze": (
        preprocess_pointmaze,
        coordinate_transform_pointmaze,
        extract_env_observation_pointmaze,
    ),
    "shadow": (
        preprocess_shadow,
        coordinate_transform_shadow,
        extract_env_observation_shadow,
    ),
    "ur5e": (
        preprocess_ur5e,
        coordinate_transform_ur5e,
        extract_env_observation_ur5e,
    ),
}

# Resolved once at import time; a ``task`` that is not a key above raises
# KeyError here.
preprocess, transform, extract_observation = _TASK_FUNCTIONS[task]
