import random
import os
import wandb
import pyrallis
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from dataclasses import asdict, dataclass
import numpy as np
import gym
import d4rl
import uuid
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


# A sampled mini-batch: the list [states, actions, rewards, next_states, dones]
# returned by ReplayBuffer.sample.
TensorBatch = List[torch.Tensor]


@dataclass
class TrainConfig:
    """Run configuration for ``train_vae``; pyrallis fills it from the CLI.

    ``env``, ``checkpoints_path`` and ``name`` default to ``None`` and are
    derived in ``__post_init__`` from the fields they depend on. As a result,
    e.g. ``--env_sim_name walker2d`` also changes the environment id, the run
    name and the checkpoint path. Non-None values passed explicitly are kept.
    """

    # Experiment
    device: str = "cuda"
    env_sim_name: str = "hopper"
    level: str = "medium"
    # OpenAI gym / D4RL environment id; None -> f"{env_sim_name}-{level}-v2"
    env: Optional[str] = None
    seed: int = 0  # Sets Gym, PyTorch and Numpy seeds
    eval_freq: int = int(5e3)  # How often (time steps) we evaluate (unused)
    save_freq: int = int(2e4)  # How often (training steps) we save a VAE checkpoint
    n_episodes: int = 10  # How many episodes run during evaluation (unused)
    max_timesteps: int = int(4e5)  # Number of VAE training steps
    # Save path; None -> f"checkpoints/{env}". NOTE(review): train_vae writes
    # checkpoints to a hard-coded "sas_vae_checkpoints/" and ignores this field.
    checkpoints_path: Optional[str] = None
    load_model: str = ""  # Model load file name, "" doesn't load (unused)
    # Training. Fields marked "(unused)" are not read by train_vae (they are
    # only logged to wandb via asdict); they look like IQL-config leftovers.
    buffer_size: int = 2_000_000  # Replay buffer size
    vae_hidden_dim: int = 400  # hidden dimension of vae network
    batch_size: int = 256  # Batch size for VAE updates
    discount: float = 0.99  # Discount factor (unused)
    tau: float = 0.005  # Target network update rate (unused)
    iql_tau: float = 0.7  # Coefficient for asymmetric loss (unused)
    beta: float = 0.5  # Maximum KL weight; the VAE loss anneals up to this value
    iql_deterministic: bool = False  # Use deterministic actor (unused)
    normalize: bool = True  # Normalize states (unused)
    normalize_reward: bool = False  # Normalize reward (unused)
    vf_lr: float = 3e-4  # V function learning rate (unused)
    lr: float = 1e-5  # VAE learning rate
    qf_lr: float = 3e-4  # Critic learning rate (unused)
    actor_lr: float = 3e-4  # Actor learning rate (unused)
    actor_dropout: Optional[float] = None  # Adroit uses dropout for policy network (unused)
    # Wandb logging
    project: str = "offline-vae"
    name: Optional[str] = None  # wandb run name; None -> env

    def __post_init__(self):
        # Derive dependent defaults here rather than with class-level
        # f-strings. Those are evaluated once at import time and would ignore
        # overrides of env_sim_name / level / env.
        if self.env is None:
            self.env = f"{self.env_sim_name}-{self.level}-v2"
        if self.checkpoints_path is None:
            self.checkpoints_path = f"checkpoints/{self.env}"
        if self.name is None:
            self.name = self.env


class ReplayBuffer:
    """Fixed-capacity store of transitions in preallocated float32 tensors.

    Each transition is (state, action, reward, next_state, done). Rewards and
    dones are stored as column vectors of shape (buffer_size, 1). All tensors
    are allocated on ``device``.
    """

    def __init__(
            self,
            state_dim: int,
            action_dim: int,
            buffer_size: int,
            device: str = "cuda",
    ):
        self._buffer_size = buffer_size
        self._pointer = 0
        self._size = 0
        # Number of loaded transitions. Initialised here so that len() on a
        # fresh buffer returns 0 instead of raising AttributeError.
        self.length = 0
        self._device = device

        self._states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._actions = torch.zeros(
            (buffer_size, action_dim), dtype=torch.float32, device=device
        )
        self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._next_states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)

    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
        """Copy ``data`` into a new float32 tensor on the buffer's device."""
        return torch.tensor(data, dtype=torch.float32, device=self._device)

    def load_d4rl_dataset(self, data: Dict[str, np.ndarray]):
        """Copy a D4RL-format dict into the (empty) buffer.

        Reads the keys ``observations``, ``actions``, ``rewards``,
        ``next_observations`` and ``terminals``. Rewards and terminals get a
        trailing axis so they match the (N, 1) storage.

        Raises:
            ValueError: if the buffer already holds data, or if the dataset has
                more transitions than ``buffer_size``.
        """
        if self._size != 0:
            raise ValueError("Trying to load data into non-empty replay buffer")
        n_transitions = data["observations"].shape[0]
        if n_transitions > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )
        self._states[:n_transitions] = self._to_tensor(data["observations"])
        self._actions[:n_transitions] = self._to_tensor(data["actions"])
        self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None])
        self._next_states[:n_transitions] = self._to_tensor(data["next_observations"])
        self._dones[:n_transitions] = self._to_tensor(data["terminals"][..., None])
        self._size += n_transitions
        self._pointer = min(self._size, n_transitions)
        self.length = n_transitions

        print(f"Dataset size: {n_transitions}")

    def sample(self, batch_size: int) -> "TensorBatch":
        """Draw ``batch_size`` transitions uniformly, with replacement.

        Indices come from ``np.random``, so they follow ``np.random.seed``.

        Returns:
            [states, actions, rewards, next_states, dones], each indexed by the
            same random indices.

        Raises:
            ValueError: if the buffer holds no transitions.
        """
        high = min(self._size, self._pointer)
        if high <= 0:
            raise ValueError("Cannot sample from an empty replay buffer")
        indices = np.random.randint(0, high, size=batch_size)
        states = self._states[indices]
        actions = self._actions[indices]
        rewards = self._rewards[indices]
        next_states = self._next_states[indices]
        dones = self._dones[indices]
        return [states, actions, rewards, next_states, dones]

    def __len__(self):
        # Loaded transitions (0 before load_d4rl_dataset is called).
        return self._size

    def __getitem__(self, item):
        """Return the transition(s) at ``item``, each with a new leading axis of size 1."""
        states = self._states[item].unsqueeze(0)
        actions = self._actions[item].unsqueeze(0)
        rewards = self._rewards[item].unsqueeze(0)
        next_states = self._next_states[item].unsqueeze(0)
        dones = self._dones[item].unsqueeze(0)
        return [states, actions, rewards, next_states, dones]


class VAE(nn.Module):
    """Variational auto-encoder over (observation, action) -> observation-sized output.

    The encoder maps ``(obs, act)`` to a diagonal Gaussian over a latent ``z``.
    The decoder maps ``(obs, act, z)`` to an ``obs_dim``-wide output. In this
    file, the trainer regresses that output onto the next observation.

    Args:
        obs_dim (int): The dimension of the observation space.
        act_dim (int): The dimension of the action space.
        hidden_size (int): Hidden units in each encoder and decoder layer.
        latent_dim (int): The dimensionality of the latent space.
        act_lim (float): Upper action bound; stored but not used by this class.
        obs_lim: Observation bounds; stored but not used by this class.
        device (str): The device to use for computation (cpu or cuda).
    """

    def __init__(self, obs_dim, act_dim, hidden_size, latent_dim, act_lim, obs_lim, device='cuda'):
        super().__init__()
        self.device = device
        self.latent_dim = latent_dim
        self.e1 = nn.Linear(obs_dim + act_dim, hidden_size).to(self.device)
        self.e2 = nn.Linear(hidden_size, hidden_size).to(self.device)

        self.mean = nn.Linear(hidden_size, latent_dim).to(self.device)
        self.log_std = nn.Linear(hidden_size, latent_dim).to(self.device)

        # Decoder input is [obs, act, z] concatenated along the last axis.
        self.d1 = nn.Linear(obs_dim + act_dim + latent_dim, hidden_size).to(self.device)
        self.d2 = nn.Linear(hidden_size, hidden_size).to(self.device)
        self.d3 = nn.Linear(hidden_size, obs_dim).to(self.device)

        self.act_lim = act_lim
        self.obs_lim = obs_lim

    def forward(self, obs, act, next_obs):
        """Encode, sample ``z`` and decode. Returns ``(reconstruction, mean, std)``."""
        z, mean, std = self.encoder(obs, act, next_obs)
        u = self.decode(obs, act, next_obs, z)

        return u, mean, std

    def encoder(self, obs, act, next_obs):
        """Return ``(z, mean, std)``, with ``z`` drawn by the reparameterisation trick.

        NOTE(review): ``next_obs`` is accepted but not fed to the encoder, so
        the posterior is conditioned on ``(obs, act)`` only. Confirm this is
        intended; changing it would alter ``e1``'s input width and break
        existing checkpoints.
        """
        z = F.relu(self.e1(torch.cat([obs, act], 1)))
        z = F.relu(self.e2(z))

        mean = self.mean(z)
        # clamp for numerical stability
        log_std = self.log_std(z).clamp(-4, 15)
        std = torch.exp(log_std)
        z = mean + std * torch.randn_like(std)

        return z, mean, std

    def decode(self, obs, act, next_obs, z=None):
        """Decode ``(obs, act, z)``.

        If ``z`` is None, it is sampled from N(0, 1) and clipped to
        [-0.5, 0.5]. ``next_obs`` is not used.
        """
        if z is None:
            z = torch.randn((obs.shape[0], self.latent_dim)).clamp(-0.5, 0.5).to(self.device)
        s = F.relu(self.d1(torch.cat([obs, act, z], 1)))
        s = F.relu(self.d2(s))

        return self.d3(s)

    # for BEARL only
    def decode_multiple(self, obs, z=None, num_decode=10, act=None):
        """Decode ``num_decode`` samples per row of ``obs``.

        ``d1`` expects ``obs_dim + act_dim + latent_dim`` inputs. Pass ``act``
        (presumably shape ``(batch, act_dim)``) so the input width matches;
        without it, the call only works when ``act_dim`` is 0.

        Returns:
            ``(tanh(out), out)``, where ``out`` has shape
            ``(batch, num_decode, obs_dim)``.
        """
        if z is None:
            z = torch.randn(
                (obs.shape[0], num_decode, self.latent_dim)).clamp(-0.5, 0.5).to(self.device)

        # Tile each per-row input to (batch, num_decode, dim) so every sample
        # of a row sees the same conditioning.
        parts = [obs.unsqueeze(0).repeat(num_decode, 1, 1).permute(1, 0, 2)]
        if act is not None:
            parts.append(act.unsqueeze(0).repeat(num_decode, 1, 1).permute(1, 0, 2))
        parts.append(z)

        a = F.relu(self.d1(torch.cat(parts, 2)))
        a = F.relu(self.d2(a))
        out = self.d3(a)  # evaluated once and reused for both return values
        return torch.tanh(out), out


class VAETrainer:
    """Owns a VAE and its Adam optimiser and performs gradient steps on it.

    The VAE reconstructs ``next_observations`` from ``(observations, actions)``.
    The KL weight ``this_beta`` starts at 0. Once the step counter ``t``
    exceeds 50000, it grows by ``beta / 200000`` per call and is capped at
    ``beta``.
    """

    def __init__(self, dataset, state_dim, action_dim, max_action, obs_lim, vae_hidden_dim=64,
                 lr=1e-4, reward_scale=1.0, device='cuda', beta=1.5):
        # Stored for reference only; no method of this class reads it.
        self.dataset = dataset
        self.device = device
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.vae_hidden_dim = vae_hidden_dim
        # One latent unit per state and per action dimension.
        self.latent_dim = state_dim + action_dim
        self.max_action = max_action
        self.obs_lim = obs_lim

        # KL-annealing state: cap, current weight and per-step increment.
        self.beta = beta
        self.this_beta = 0
        self.epsilon = beta / 200000

        self.vae = VAE(state_dim, action_dim, vae_hidden_dim, self.latent_dim,
                       max_action, obs_lim, device).to(device)
        self.vae_optim = torch.optim.Adam(self.vae.parameters(), lr=lr)
        self.vae_list = []

    def vae_loss(self, obs, act, next_observations, t):
        """Take one optimiser step on the VAE and return ``(loss, stats)``.

        The loss is MSE(reconstruction, next_observations) plus
        ``this_beta * KL(q(z) || N(0, I))``. ``this_beta`` is advanced first
        according to the annealing schedule described on the class.
        """
        recon, mu, sigma = self.vae(obs, act, next_observations)
        recon_loss = F.mse_loss(recon, next_observations)
        variance = sigma.pow(2)
        kl = -0.5 * (1 + torch.log(variance) - mu.pow(2) - variance).mean()

        if t > 50000:
            self.this_beta += self.epsilon
        self.this_beta = min(self.this_beta, self.beta)
        total = recon_loss + self.this_beta * kl

        self.vae_optim.zero_grad()
        total.backward()
        self.vae_optim.step()

        # NOTE(review): the logged KL is scaled by the cap ``beta``, not by the
        # current ``this_beta`` actually used in ``total``.
        stats = {
            "loss/loss_vae": total.item(),
            "loss/KL_loss": kl.item() * self.beta,
            "loss/recon_loss": recon_loss.item(),
            "beta": self.this_beta,
        }
        return total, stats

    def train_one_step(self, observations, next_observations, actions, rewards, done, t):
        """Perform one VAE update and return its stats dict (``rewards``/``done`` are unused)."""
        _, stats = self.vae_loss(observations, actions, next_observations, t)
        return stats

    def test_one_step(self, observations, next_observations, actions, rewards, done):
        """Encode a batch without gradient tracking and return ``(mean, std)``; no update is made."""
        with torch.no_grad():
            _, mu, sigma = self.vae.encoder(observations, actions, next_observations)
            return mu, sigma


def set_seed(
        seed: int, env: Optional[gym.Env] = None, deterministic_torch: bool = False
):
    """Seed the environment (if given), Python, NumPy and PyTorch with ``seed``.

    Also exports ``PYTHONHASHSEED`` and passes ``deterministic_torch`` to
    ``torch.use_deterministic_algorithms``. That call toggles deterministic
    algorithms process-wide, either on or off.
    """
    if env is not None:
        # Seed the environment first, before the global RNGs are reset below.
        for seedable in (env, env.action_space):
            seedable.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)
    torch.use_deterministic_algorithms(deterministic_torch)


def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]:
    """Return the per-feature mean and the per-feature std plus ``eps``, reduced over axis 0."""
    return states.mean(axis=0), states.std(axis=0) + eps


def normalize_states(states: np.ndarray, mean: np.ndarray, std: np.ndarray):
    """Standardise ``states`` as ``(states - mean) / std``, using NumPy broadcasting."""
    centred = states - mean
    return centred / std


def wandb_init(config: dict) -> None:
    """Start a wandb run configured from a plain dict, then call ``wandb.run.save()``.

    Reads ``config["project"]`` and ``config["name"]``. The whole dict is
    attached as the run config.
    """
    project, run_name = config["project"], config["name"]
    wandb.init(config=config, project=project, name=run_name)
    wandb.run.save()


@pyrallis.wrap()
def train_vae(config: TrainConfig):
    """Train a transition VAE on a D4RL dataset and checkpoint it periodically.

    1. Build the gym env named by ``config.env`` and load its D4RL
       ``qlearning_dataset`` into a ReplayBuffer on ``config.device``.
    2. Seed everything with ``config.seed`` and start a wandb run.
    3. For ``config.max_timesteps`` steps, sample ``config.batch_size``
       transitions, take one VAE gradient step and log the stats to wandb.
    4. Whenever ``(t + 1)`` is a multiple of ``config.save_freq`` and ``t``
       exceeds 100000, pickle the VAE module and its optimiser to
       ``sas_vae_checkpoints/<env>-<seed>/checkpoint_<t>.pt``.

    ``config.seed`` is incremented by one before returning.
    """
    env = gym.make(config.env)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    print(state_dim, action_dim, '\n\n\n')

    dataset = d4rl.qlearning_dataset(env)
    replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    replay_buffer.load_d4rl_dataset(dataset)

    # Only the first coordinate of each observation bound is kept.
    obs_lim = [env.observation_space.low[0], env.observation_space.high[0]]
    max_action = float(env.action_space.high[0])
    seed = config.seed
    set_seed(seed, env)
    print(env.observation_space)
    print("---------------------------------------")
    print(f"Training VAE network, Env: {config.env}, Seed: {seed}")
    print("---------------------------------------")

    wandb_init(asdict(config))

    trainer = VAETrainer(
        max_action=max_action,
        dataset=dataset,
        state_dim=state_dim,
        action_dim=action_dim,
        vae_hidden_dim=config.vae_hidden_dim,
        beta=config.beta,
        device=config.device,
        lr=config.lr,
        obs_lim=obs_lim,
    )
    print(obs_lim, max_action)

    checkpoint_dir = f'sas_vae_checkpoints/{config.env}-{seed}'
    # exist_ok avoids the race of a separate os.path.exists check.
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Checkpoints are written only once t exceeds this step.
    save_after = 100_000

    for t in range(int(config.max_timesteps)):
        batch = replay_buffer.sample(config.batch_size)
        observations, actions, rewards, next_observations, dones = [
            b.to(config.device) for b in batch
        ]
        log_dict = trainer.train_one_step(observations, next_observations, actions, rewards, dones, t)
        wandb.log(log_dict)
        if (t + 1) % config.save_freq == 0 and t > save_after:
            torch.save(
                {"vae": trainer.vae, "vae_optim": trainer.vae_optim},
                f'{checkpoint_dir}/checkpoint_{t}.pt',
            )
    wandb.finish()
    # Mutates the caller's config; presumably so a reused config gets a fresh
    # seed on the next run. TODO confirm this is still needed.
    config.seed += 1


if __name__ == "__main__":
    # Log offline by default (no network needed). A WANDB_MODE already set in
    # the environment (e.g. "online" or "disabled") takes precedence.
    os.environ.setdefault("WANDB_MODE", "offline")
    train_vae()
