from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import os
import random
import d4rl
import wandb
import gym
import numpy as np
import torch
from scipy import stats
from embedding import Embedding

# A sampled batch as returned by ReplayBuffer.sample:
# [states, actions, rewards, next_states, dones].
TensorBatch = List[torch.Tensor]


class TrainConfig:
    """Hyper-parameters for embedding training, stored as class attributes.

    The f-string defaults (``env``, the paths, ``group``, ``name``) are evaluated
    once when the class body runs, so reassigning ``env_sim_name`` or ``level``
    afterwards does not update them.
    """

    # Experiment
    device: str = "cuda"
    env_sim_name: str = "halfcheetah"
    level: str = "medium-replay"
    env: str = f"{env_sim_name}-{level}-v2"  # OpenAI gym environment name
    seed: int = 0  # Sets Gym, PyTorch and Numpy seeds
    eval_seed: int = 0  # Eval environment seed
    eval_freq: int = int(2e3)  # How often (time steps) we evaluate
    save_freq: int = int(1e4)  # train() saves a checkpoint every save_freq training steps
    n_episodes: int = 100  # How many episodes run during evaluation
    offline_iterations: int = int(300000)  # Number of offline updates
    max_timesteps: int = int(2e5)  # Number of embedding training steps run by train()
    checkpoints_path: Optional[str] = f'checkpoints/{env}'  # Save path; NOTE(review): train() in this file writes to embedding_checkpoints/{env} and does not read this
    load_model: str = f"../offline/checkpoints/{env}/checkpoint_999999.pt"  # Model load file name, "" doesn't load
    vae_file: str = f"../../vae_checkpoints/{env}/checkpoint_999999.pt"  # VAE checkpoint path; not read in this file
    # IQL-style settings; of these, train() reads only buffer_size, batch_size,
    # normalize and normalize_reward.
    actor_dropout: float = 0.0  # Dropout in actor network
    buffer_size: int = 2_000_000  # Replay buffer size
    vae_hidden_dim: int = 400  # hidden dimension of vae network
    batch_size: int = 64  # Batch size for all networks
    discount: float = 0.99  # Discount factor
    tau: float = 0.005  # Target network update rate
    beta: float = 3.0  # Inverse temperature. Small beta -> BC, big beta -> maximizing Q
    iql_tau: float = 0.7  # Coefficient for asymmetric loss
    expl_noise: float = 0.03  # Std of Gaussian exploration noise
    noise_clip: float = 0.5  # Range to clip noise
    iql_deterministic: bool = False  # Use deterministic actor
    use_off_policy: bool = True  # Off-policy toggle; not read in this file — TODO confirm meaning
    normalize: bool = False  # Normalize states
    normalize_reward: bool = False  # Normalize reward
    vf_lr: float = 3e-4  # V function learning rate
    qf_lr: float = 3e-4  # Critic learning rate
    actor_lr: float = 3e-4  # Actor learning rate
    # Wandb logging
    project: str = "offline-embedding"
    group: str = f"{env_sim_name}"
    name: str = f"{env}"


def set_seed(
    seed: int, env: Optional[gym.Env] = None, deterministic_torch: bool = False
):
    """Seed the environment (when given) plus the Python, NumPy and PyTorch RNGs.

    Args:
        seed: Value applied to every generator.
        env: Optional gym environment; the env itself and its action space are
            both seeded.
        deterministic_torch: Passed straight to
            ``torch.use_deterministic_algorithms``.
    """
    if env is not None:
        env.seed(seed)
        env.action_space.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Library-level global generators, seeded in a fixed order.
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)
    torch.use_deterministic_algorithms(deterministic_torch)


class ReplayBuffer:
    """Fixed-capacity store of transitions held in preallocated float32 tensors.

    Filled once from a D4RL-style dict with ``load_d4rl_dataset`` and then
    sampled uniformly with ``sample``. Online insertion is not implemented.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        buffer_size: int,
        device: str = "cpu",
    ):
        """Allocate zero-filled storage for ``buffer_size`` transitions on ``device``."""
        self._buffer_size = buffer_size
        self._pointer = 0
        self._size = 0
        # Public mirror of the number of loaded transitions. It is initialised here
        # so the attribute exists (and ``len()`` works) before any data is loaded.
        self.length = 0
        self._device = device

        self._states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._actions = torch.zeros(
            (buffer_size, action_dim), dtype=torch.float32, device=device
        )
        self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._next_states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)

    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
        """Copy ``data`` into a float32 tensor on the buffer's device."""
        return torch.tensor(data, dtype=torch.float32, device=self._device)

    # Loads data in d4rl format, i.e. from Dict[str, np.array].
    def load_d4rl_dataset(self, data: Dict[str, np.ndarray]):
        """Load transitions from a D4RL-format dict into an empty buffer.

        Args:
            data: Mapping with ``observations``, ``actions``, ``rewards``,
                ``next_observations`` and ``terminals`` arrays that share the
                same leading dimension.

        Raises:
            ValueError: If the buffer already holds data, or if it is smaller
                than the dataset.
        """
        if self._size != 0:
            raise ValueError("Trying to load data into non-empty replay buffer")
        n_transitions = data["observations"].shape[0]
        if n_transitions > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )
        self._states[:n_transitions] = self._to_tensor(data["observations"])
        self._actions[:n_transitions] = self._to_tensor(data["actions"])
        # rewards/terminals are presumably 1-D; [..., None] adds the trailing axis
        # of the (N, 1) storage.
        self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None])
        self._next_states[:n_transitions] = self._to_tensor(data["next_observations"])
        self._dones[:n_transitions] = self._to_tensor(data["terminals"][..., None])
        self._size += n_transitions
        self._pointer = min(self._size, n_transitions)
        self.length = n_transitions

        print(f"Dataset size: {n_transitions}")

    def sample(self, batch_size: int) -> TensorBatch:
        """Draw ``batch_size`` stored transitions uniformly, with replacement.

        Indices come from NumPy's global RNG (the one seeded by ``np.random.seed``).

        Returns:
            ``[states, actions, rewards, next_states, dones]``, all indexed by the
            same random indices.

        Raises:
            ValueError: If the buffer holds no transitions.
        """
        upper = min(self._size, self._pointer)
        if upper == 0:
            raise ValueError("Cannot sample from an empty replay buffer")
        indices = np.random.randint(0, upper, size=batch_size)
        states = self._states[indices]
        actions = self._actions[indices]
        rewards = self._rewards[indices]
        next_states = self._next_states[indices]
        dones = self._dones[indices]
        return [states, actions, rewards, next_states, dones]

    def add_transition(self):
        """Not implemented: online insertion is only needed for fine-tuning."""
        # Use this method to add new data into the replay buffer during fine-tuning.
        # I left it unimplemented since now we do not do fine-tuning.
        raise NotImplementedError

    def __len__(self):
        """Return the number of transitions currently stored (0 before loading)."""
        return self._size

    def __getitem__(self, item):
        """Return transition ``item`` as tensors, each with an extra leading axis.

        The leading axis is added by ``unsqueeze(0)``.
        """
        states = self._states[item].unsqueeze(0)
        actions = self._actions[item].unsqueeze(0)
        rewards = self._rewards[item].unsqueeze(0)
        next_states = self._next_states[item].unsqueeze(0)
        dones = self._dones[item].unsqueeze(0)
        return [states, actions, rewards, next_states, dones]


def wandb_init(config) -> None:
    """Start a Weights & Biases run described by ``config`` and save it.

    Args:
        config: Object exposing ``project``, ``group`` and ``name`` attributes.
    """
    project, group, run_name = config.project, config.group, config.name
    wandb.init(
        project=project,
        group=group,
        name=run_name,
        settings=wandb.Settings(start_method="thread"),
    )
    wandb.run.save()


def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]:
    """Return the axis-0 mean of ``states`` and its axis-0 std shifted by ``eps``.

    Adding ``eps`` keeps later division by the std away from zero.
    """
    return states.mean(axis=0), states.std(axis=0) + eps


def normalize_states(states: np.ndarray, mean: np.ndarray, std: np.ndarray):
    """Standardize ``states`` as ``(states - mean) / std``, with NumPy broadcasting."""
    centered = states - mean
    return centered / std


def wrap_env(
    env: gym.Env,
    state_mean: Union[np.ndarray, float] = 0.0,
    state_std: Union[np.ndarray, float] = 1.0,
    reward_scale: float = 1.0,
) -> gym.Env:
    """Wrap ``env`` so observations are standardized and rewards optionally scaled.

    Args:
        env: Environment to wrap.
        state_mean: Value subtracted from every observation.
        state_std: Divisor applied after centering; any epsilon is expected to
            be included already.
        reward_scale: Multiplier for rewards. The reward wrapper is only added
            when this differs from 1.0.

    Returns:
        The wrapped environment.
    """

    def _standardize(obs):
        return (obs - state_mean) / state_std

    wrapped = gym.wrappers.TransformObservation(env, _standardize)
    if reward_scale == 1.0:
        return wrapped

    def _rescale(reward):
        # Note: the reward is multiplied (not divided) by reward_scale.
        return reward_scale * reward

    return gym.wrappers.TransformReward(wrapped, _rescale)


def modify_reward(dataset, env_name, max_episode_steps=1000):
    """Rescale ``dataset["rewards"]`` in place according to the task family.

    - Locomotion tasks (halfcheetah, hopper, walker2d): the rewards are divided by
      the spread of episode returns and multiplied by ``max_episode_steps``. After
      this, the max-minus-min episode return equals ``max_episode_steps``.
    - antmaze: 1.0 is subtracted from every reward.
    - Any other task: the dataset is left untouched.
    """
    locomotion_tasks = ("halfcheetah", "hopper", "walker2d")
    if any(task in env_name for task in locomotion_tasks):
        lowest, highest = return_reward_range(dataset, max_episode_steps)
        dataset["rewards"] /= highest - lowest
        dataset["rewards"] *= max_episode_steps
    elif "antmaze" in env_name:
        dataset["rewards"] -= 1.0


def return_reward_range(dataset, max_episode_steps):
    """Return ``(min, max)`` of the undiscounted episode returns in ``dataset``.

    An episode ends on a truthy terminal flag or after ``max_episode_steps``
    steps. A trailing unfinished episode counts toward the step total but its
    return is ignored.
    """
    episode_returns, episode_lengths = [], []
    running_return, running_len = 0.0, 0
    for reward, terminal in zip(dataset["rewards"], dataset["terminals"]):
        running_return += float(reward)
        running_len += 1
        if not (terminal or running_len == max_episode_steps):
            continue
        episode_returns.append(running_return)
        episode_lengths.append(running_len)
        running_return, running_len = 0.0, 0
    # Include the (possibly empty) unfinished tail so the step count adds up.
    episode_lengths.append(running_len)
    assert sum(episode_lengths) == len(dataset["rewards"])
    return min(episode_returns), max(episode_returns)


def train(config):
    """Train an ``Embedding`` model on a D4RL dataset and checkpoint it periodically.

    The function loads the dataset for ``config.env`` and optionally normalizes
    rewards and states. It fills a replay buffer, then runs
    ``config.max_timesteps`` training steps and logs each step to wandb. Every
    ``config.save_freq`` steps the whole model is saved under
    ``embedding_checkpoints/<env>/``.

    Args:
        config: Object exposing the ``TrainConfig`` attributes read below
            (``env``, ``normalize_reward``, ``normalize``, ``buffer_size``,
            ``device``, ``seed``, ``max_timesteps``, ``batch_size``,
            ``save_freq``).
    """
    env = gym.make(config.env)
    dataset = d4rl.qlearning_dataset(env)
    if config.normalize_reward:
        modify_reward(dataset, config.env)

    if config.normalize:
        state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
    else:
        state_mean, state_std = 0, 1

    checkpoint_dir = f'embedding_checkpoints/{config.env}'
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(checkpoint_dir, exist_ok=True)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    dataset["observations"] = normalize_states(
        dataset["observations"], state_mean, state_std
    )
    dataset["next_observations"] = normalize_states(
        dataset["next_observations"], state_mean, state_std
    )
    env = wrap_env(env, state_mean=state_mean, state_std=state_std)
    replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    replay_buffer.load_d4rl_dataset(dataset)

    seed = config.seed
    set_seed(seed, env)
    embedding_model = Embedding(obs_size=state_dim, num_outputs=action_dim, Mean_in=0, max_in_dis=0, min_in_dis=0)

    for t in range(int(config.max_timesteps)):
        batch = replay_buffer.sample(config.batch_size)
        log_dict = embedding_model.train_model(batch)
        wandb.log(log_dict)
        if (t + 1) % config.save_freq == 0:
            # The file index is the 0-based step, so the 10000th step writes
            # checkpoint_9999.pt.
            torch.save(embedding_model, f'{checkpoint_dir}/checkpoint_{t}.pt')


if __name__ == "__main__":
    # Every setting is a plain class attribute, so the class object itself is
    # handed around as the config; no instance is created.
    config = TrainConfig
    wandb_init(config)
    train(config)
    wandb.finish()
