from typing import Any, Dict, List, Optional, Tuple, Union
from dataclasses import asdict, dataclass
import os
from pathlib import Path
import random
import uuid
import diffuser.utils as utils
from ml_logger import logger, RUN
import torch
from torch.distributions.multivariate_normal import MultivariateNormal
from copy import deepcopy
import numpy as np
import os
import pickle
import gym
from diffuser.utils.timer import Timer
from config.locomotion_config import Config
import d4rl
import gym
import numpy as np
# import pyrallis
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb

# A sampled mini-batch: a list of tensors. ReplayBuffer.sample returns them in
# the order [states, actions, rewards, next_states].
TensorBatch = List[torch.Tensor]


@dataclass
class TrainConfig:
    """Settings for one behaviour-cloning run.

    After construction, ``name`` is extended with the environment id and a
    random 8-character tag, and ``checkpoints_path`` (when given) is pointed
    at a sub-directory carrying that unique name.
    """

    # --- Experiment ---
    device: str = "cuda"  # Torch device string
    env: str = "antmaze-umaze-v2"  # OpenAI gym environment name
    seed: int = 0  # Seed for Gym, PyTorch and Numpy
    eval_freq: int = 5_000  # Evaluate every this many time steps
    n_episodes: int = 15  # Episodes rolled out per evaluation
    max_timesteps: int = 1_000_000  # Total number of time steps to run
    checkpoints_path: Optional[str] = None  # Save path (None: not set)
    load_model: str = ""  # Model file to load; "" means do not load
    batch_size: int = 256  # Batch size for all networks
    discount: float = 0.99  # Discount factor
    # --- BC ---
    buffer_size: int = 3_300_000  # Replay buffer capacity
    frac: float = 0.1  # Fraction of best data to use
    max_traj_len: int = 1000  # Max trajectory length
    normalize: bool = True  # Normalize states
    # --- Wandb logging ---
    project: str = "CORL"
    group: str = "BC-D4RL"
    name: str = "BC"

    def __post_init__(self):
        """Make the run name unique and nest the checkpoint path under it."""
        run_tag = str(uuid.uuid4())[:8]
        self.name = f"{self.name}-{self.env}-{run_tag}"
        if self.checkpoints_path is None:
            return
        self.checkpoints_path = os.path.join(self.checkpoints_path, self.name)


def soft_update(target: nn.Module, source: nn.Module, tau: float):
    """Blend ``source`` parameters into ``target`` in place.

    Each target parameter becomes ``(1 - tau) * target + tau * source``.
    Parameters are paired in the order ``parameters()`` yields them.
    """
    keep = 1 - tau
    for dst, src in zip(target.parameters(), source.parameters()):
        blended = keep * dst.data + tau * src.data
        dst.data.copy_(blended)


def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]:
    """Return the mean and ``eps``-padded standard deviation along axis 0.

    ``eps`` is added to the standard deviation so that a later division by it
    cannot hit zero.
    """
    per_dim_mean = np.mean(states, axis=0)
    per_dim_std = np.std(states, axis=0) + eps
    return per_dim_mean, per_dim_std


def normalize_states(states: np.ndarray, mean: np.ndarray, std: np.ndarray):
    """Standardise ``states``: subtract ``mean``, then divide by ``std``."""
    centered = states - mean
    return centered / std


def wrap_env(
    env: gym.Env,
    state_mean: Union[np.ndarray, float] = 0.0,
    state_std: Union[np.ndarray, float] = 1.0,
    reward_scale: float = 1.0,
) -> gym.Env:
    """Wrap ``env`` so observations are standardised and rewards optionally scaled.

    Observations are always passed through ``(state - state_mean) / state_std``.
    A reward wrapper is only added when ``reward_scale`` differs from 1.0.
    """

    # Named functions rather than lambdas (PEP 8, E731).
    def normalize_state(state):
        # No epsilon here: it is expected to be folded into state_std already.
        centered = state - state_mean
        return centered / state_std

    wrapped = gym.wrappers.TransformObservation(env, normalize_state)
    if reward_scale == 1.0:
        return wrapped

    def scale_reward(reward):
        # Careful: the reward is *multiplied* by the scale.
        return reward_scale * reward

    return gym.wrappers.TransformReward(wrapped, scale_reward)


class ReplayBuffer:
    """Fixed-capacity transition store that serves uniformly sampled batches.

    Sampling indices are drawn at random (with replacement) so consecutive,
    time-correlated transitions do not end up in the same batch by
    construction.

    Storage is pre-allocated on ``device`` as four float32 tensors:
    states ``(buffer_size, state_dim)``, actions ``(buffer_size, action_dim)``,
    rewards ``(buffer_size, 1)`` and next states ``(buffer_size, state_dim)``.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        buffer_size: int,
        device: str = "cpu",
    ):
        self._buffer_size = buffer_size
        self._pointer = 0
        self._size = 0

        self._states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._actions = torch.zeros(
            (buffer_size, action_dim), dtype=torch.float32, device=device
        )
        self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._next_states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._device = device

    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
        """Copy ``data`` into a float32 tensor on the buffer's device."""
        return torch.tensor(data, dtype=torch.float32, device=self._device)

    def load_d4rl_dataset(self, data: Dict[str, np.ndarray]):
        """Fill an empty buffer from a D4RL-style dict of arrays.

        Args:
            data: mapping with keys ``"observations"``, ``"actions"``,
                ``"rewards"`` and ``"next_observations"``. The first axis of
                every array indexes transitions. Rewards may be flat ``(N,)``
                or a column ``(N, 1)``.

        Raises:
            ValueError: if the buffer already holds data or the dataset has
                more transitions than the buffer capacity.
        """
        if self._size != 0:
            raise ValueError("Trying to load data into non-empty replay buffer")
        n_transitions = data["observations"].shape[0]
        if n_transitions > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )
        self._states[:n_transitions] = self._to_tensor(data["observations"])
        self._actions[:n_transitions] = self._to_tensor(data["actions"])
        # The reward storage is a column; a flat (N,) reward array cannot be
        # broadcast into an (N, 1) slice, so force the column shape explicitly.
        self._rewards[:n_transitions] = self._to_tensor(data["rewards"]).reshape(
            n_transitions, 1
        )
        self._next_states[:n_transitions] = self._to_tensor(data["next_observations"])
        self._size += n_transitions
        self._pointer = min(self._size, n_transitions)

        print(f"Dataset size: {n_transitions}")

    def sample(self, batch_size: int) -> "TensorBatch":
        """Draw ``batch_size`` transitions uniformly at random, with replacement.

        Returns:
            ``[states, actions, rewards, next_states]``, each with
            ``batch_size`` rows.

        Raises:
            ValueError: if the buffer holds no data.
        """
        upper = min(self._size, self._pointer)
        if upper <= 0:
            # np.random.randint would fail with an opaque "low >= high".
            raise ValueError("Cannot sample from an empty replay buffer")
        indices = np.random.randint(0, upper, size=batch_size)
        states = self._states[indices]
        actions = self._actions[indices]
        rewards = self._rewards[indices]
        next_states = self._next_states[indices]
        return [states, actions, rewards, next_states]

    def add_transition(self):
        """Placeholder for online fine-tuning; always raises NotImplementedError."""
        # Use this method to add new data into the replay buffer during fine-tuning.
        # Left unimplemented because no fine-tuning is done at the moment.
        raise NotImplementedError


def set_seed(
    seed: int, env: Optional[gym.Env] = None, deterministic_torch: bool = False
):
    """Seed every random source used here with the same ``seed``.

    Seeds, in order: the environment and its action space (when ``env`` is
    given), ``PYTHONHASHSEED``, NumPy, Python's ``random`` and PyTorch.
    Finally toggles PyTorch's deterministic-algorithms mode according to
    ``deterministic_torch``.
    """
    if env is not None:
        env.seed(seed)
        env.action_space.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)
    torch.use_deterministic_algorithms(deterministic_torch)


def wandb_init(config: dict) -> None:
    """Start a Weights & Biases run described by ``config`` and save it.

    ``config`` must provide the keys ``"project"``, ``"group"`` and ``"name"``.
    The whole dict is also attached to the run as its configuration, and a
    fresh UUID is used as the run id.
    """
    init_kwargs = {
        "config": config,
        "project": config["project"],
        "group": config["group"],
        "name": config["name"],
        "id": str(uuid.uuid4()),
    }
    wandb.init(**init_kwargs)
    wandb.run.save()


@torch.no_grad()
def eval_actor(
    env: gym.Env, actor: nn.Module, device: str, n_episodes: int, seed: int, original_dataset
) -> np.ndarray:
    """Roll out ``actor`` for ``n_episodes`` episodes and return each episode's return.

    Every observation is normalised with
    ``original_dataset.normalizer.normalize(..., "observations")`` before it is
    given to ``actor.act``. The resulting action is passed to ``env.step``
    unchanged. The actor is put in eval mode for the rollouts and back in
    train mode before returning.

    Returns:
        Array with one summed (undiscounted) reward per episode.
    """
    env.seed(seed)
    actor.eval()

    def rollout():
        # One full episode: step until the environment reports done.
        observation = env.reset()
        finished = False
        total_reward = 0.0
        while not finished:
            observation = original_dataset.normalizer.normalize(observation, "observations")
            action = actor.act(observation, device)
            observation, reward, finished, _ = env.step(action)
            total_reward += reward
        return total_reward

    episode_rewards = [rollout() for _ in range(n_episodes)]

    actor.train()
    return np.asarray(episode_rewards)


def keep_best_trajectories(
    dataset: Dict[str, np.ndarray],
    frac: float,
    discount: float,
    max_episode_steps: int = 1000,
):
    """Keep only the highest-return trajectories of ``dataset``, in place.

    Transitions are grouped into trajectories in order. A trajectory ends when
    its terminal flag equals 1.0 or it reaches ``max_episode_steps``
    transitions. Trailing transitions that never reach either condition are
    not part of any trajectory and are dropped. Trajectories are ranked by
    discounted return and the top ``frac`` share (at least one) is kept, best
    first.

    The arrays under ``"observations"``, ``"actions"``, ``"next_observations"``,
    ``"rewards"`` and ``"terminals"`` are replaced; nothing is returned.
    """
    trajectories = []  # one list of transition indices per finished trajectory
    returns = []  # discounted return of each finished trajectory
    current = []
    running_return = 0
    weight = 1.0
    for step, (reward, done) in enumerate(zip(dataset["rewards"], dataset["terminals"])):
        running_return += weight * reward
        current.append(step)
        weight *= discount
        episode_over = done == 1.0 or len(current) == max_episode_steps
        if not episode_over:
            continue
        trajectories.append(current)
        returns.append(running_return)
        # Start a fresh trajectory.
        current, running_return, weight = [], 0, 1.0

    ranking = np.argsort(returns, axis=0)[::-1].reshape(-1)
    n_keep = max(1, int(frac * len(ranking)))

    order = np.array(
        [idx for traj in ranking[:n_keep] for idx in trajectories[traj]]
    )
    for key in ("observations", "actions", "next_observations", "rewards", "terminals"):
        dataset[key] = dataset[key][order]

# Presumably clamp bounds for a policy's log standard deviation (inferred from
# the names only). NOTE(review): neither constant is referenced anywhere in the
# visible part of this file -- confirm whether they are still needed.
LOG_STD_MIN = -5.0
LOG_STD_MAX = 2.0
class Actor(nn.Module):
    """Deterministic MLP policy with outputs bounded to ``[-max_action, max_action]``.

    The network is two hidden layers of 256 ReLU units followed by a tanh
    output layer; the tanh output is multiplied by ``max_action``.
    """

    def __init__(self, state_dim: int, action_dim: int, max_action: float):
        super().__init__()

        hidden_dim = 256
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

        # Scales the tanh output from [-1, 1] to [-max_action, max_action].
        self.max_action = max_action

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return the bounded action(s) for a batch of states."""
        squashed = self.net(state)
        return self.max_action * squashed

    @torch.no_grad()  # used at evaluation time; no gradients needed
    def act(self, state: np.ndarray, device: str = "cpu") -> np.ndarray:
        """Return the action for one state as a flat NumPy array.

        ``state`` is reshaped to a single-row batch and moved to ``device``
        before the forward pass.
        """
        single_row = torch.tensor(state.reshape(1, -1), device=device, dtype=torch.float32)
        action = self(single_row)
        return action.cpu().data.numpy().flatten()


class BC:  # noqa
    """Behaviour-cloning trainer: regress the actor's output onto dataset actions.

    ``max_action``, ``discount`` and ``device`` are stored on the instance but
    are not read by any method of this class.
    """

    def __init__(
        self,
        max_action: np.ndarray,
        actor: nn.Module,
        actor_optimizer: torch.optim.Optimizer,
        discount: float = 0.99,   # gamma; closer to 1 weights long-term reward more
        device: str = "cpu",
    ):
        self.actor, self.actor_optimizer = actor, actor_optimizer
        self.max_action, self.discount = max_action, discount

        self.total_it = 0
        self.device = device

    def train(self, batch: TensorBatch) -> Dict[str, float]:
        """Run one gradient step on ``batch`` and return ``{"actor_loss": value}``.

        ``batch`` must unpack into exactly four items; only the first two
        (states and actions) are used.
        """
        self.total_it += 1

        states, expert_actions, _, _ = batch

        # MSE between the actor's prediction (its forward pass) and the
        # dataset action.
        predicted_actions = self.actor(states)
        actor_loss = F.mse_loss(predicted_actions, expert_actions)
        loss_value = actor_loss.item()

        # One optimiser step on the actor.
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        return {"actor_loss": loss_value}

    def state_dict(self) -> Dict[str, Any]:
        """Return actor weights, optimiser state and the iteration counter."""
        return dict(
            actor=self.actor.state_dict(),
            actor_optimizer=self.actor_optimizer.state_dict(),
            total_it=self.total_it,
        )

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore everything saved by :meth:`state_dict`."""
        self.actor.load_state_dict(state_dict["actor"])
        self.actor_optimizer.load_state_dict(state_dict["actor_optimizer"])
        self.total_it = state_dict["total_it"]


# Reward scaling per dataset family. Families are matched by substring against
# the dataset name, in the order listed. Each entry is
# (family, scale used when filtering paths by return, scale applied to the
# rewards that are stored).
_BC_REWARD_SCALES = (
    ("hopper", 4, 4),
    ("halfcheetah", 15, 5),
    ("walker2d", 9, 4),
    # NOTE(review): kitchen paths were filtered with x3 but there was no branch
    # for storing their rewards (the function bailed out with "ERROR"). The
    # filter scale is reused here; BC.train ignores rewards, so the learned
    # policy does not depend on this value. Confirm the intended scale.
    ("kitchen", 3, 3),
    ("antmaze", 1, 1),
)


def _resolve_reward_scales(dataset_name: str) -> Tuple[float, float]:
    """Return ``(filter_scale, stored_scale)`` for the first matching family."""
    for family, filter_scale, stored_scale in _BC_REWARD_SCALES:
        if family in dataset_name:
            return filter_scale, stored_scale
    raise ValueError(f"Unsupported dataset for BC training: {dataset_name!r}")


def _collect_high_return_paths(
    data, dataset_name: str, min_scaled_return: float = 150
) -> Dict[str, np.ndarray]:
    """Concatenate the paths whose scaled return exceeds ``min_scaled_return``."""
    filter_scale, stored_scale = _resolve_reward_scales(dataset_name)
    columns = {
        "observations": [],
        "actions": [],
        "rewards": [],
        "next_observations": [],
    }
    for path in data:
        if np.sum(path[2] * filter_scale) > min_scaled_return:
            columns["observations"].append(path[0])
            columns["actions"].append(path[1])
            columns["rewards"].append(path[2] * stored_scale)
            columns["next_observations"].append(path[3])
    if not columns["observations"]:
        raise ValueError(
            f"No path has a scaled return above {min_scaled_return}; "
            "nothing to train on"
        )
    # A single concatenate per key instead of re-copying the growing array
    # once per path.
    return {key: np.concatenate(parts, axis=0) for key, parts in columns.items()}


def train(config: TrainConfig, data, original_dataset, logger):
    """Train a behaviour-cloning actor on the high-return paths in ``data``.

    Args:
        config: run settings (environment id, device, batch size, seed,
            evaluation schedule, optional checkpoint path and model to load).
        data: iterable of paths, each indexed as ``path[0]`` observations,
            ``path[1]`` actions, ``path[2]`` rewards, ``path[3]`` next
            observations. Only paths whose scaled return exceeds 150 are kept.
        original_dataset: object exposing ``normalizer.normalize``; passed to
            ``eval_actor`` to normalise observations during evaluation.
        logger: object providing ``log(dict, step=...)`` and ``prefix``.

    Raises:
        ValueError: if ``Config.dataset`` matches no supported family, or if
            no path passes the return filter.

    Every ``config.eval_freq`` steps the actor is evaluated. Whenever the mean
    return improves on the best so far, the trainer state is written to
    ``<Config.bucket>/<logger.prefix>/checkpoint/bc_best.pt``.
    """
    import json  # local import: only needed for the optional config dump

    env = gym.make(config.env)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    # Print each path's raw return so the effect of the filter is visible.
    for path in data:
        print(np.sum(path[2]))

    dataset = _collect_high_return_paths(data, Config.dataset)

    replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    replay_buffer.load_d4rl_dataset(dataset)

    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)
        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
            # pyrallis is not imported by this module, so the config is written
            # as JSON, which YAML 1.2 parsers also accept.
            json.dump(asdict(config), f, indent=2)

    # The actor's tanh output is scaled by this so actions stay inside the
    # environment's action range.
    max_action = float(env.action_space.high[0])

    # Set seeds
    seed = config.seed
    set_seed(seed, env)

    actor = Actor(state_dim, action_dim, max_action).to(config.device)
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=3e-4)

    params = sum(param.nelement() for param in actor.parameters())
    print(f"total parameter number: {params / 1e6:.2f}M")

    print("---------------------------------------")
    print(f"Training BC, Env: {config.env}, Seed: {seed}")
    print("---------------------------------------")

    # Initialize policy
    trainer = BC(
        max_action=max_action,
        actor=actor,
        actor_optimizer=actor_optimizer,
        discount=config.discount,
        device=config.device,
    )

    if config.load_model != "":
        policy_file = Path(config.load_model)
        trainer.load_state_dict(torch.load(policy_file))
        actor = trainer.actor

    # Start below any reachable score so the first evaluation is always
    # checkpointed, even when returns are zero or negative.
    best_score = float("-inf")
    for t in range(int(config.max_timesteps)):
        batch = replay_buffer.sample(config.batch_size)
        batch = [b.to(config.device) for b in batch]
        log_dict = trainer.train(batch)
        logger.log(log_dict, step=trainer.total_it)

        if (t + 1) % config.eval_freq != 0:
            continue

        # Evaluate episode
        print(f"Time steps: {t + 1}")
        print("current actor loss:", log_dict["actor_loss"])
        eval_scores = eval_actor(
            env,
            actor,
            device=config.device,
            n_episodes=config.n_episodes,
            seed=config.seed,
            original_dataset=original_dataset,
        )
        eval_score = eval_scores.mean()
        normalized_eval_score = env.get_normalized_score(eval_score) * 100.0

        if eval_score > best_score:
            best_score = eval_score
            checkpoint_dir = os.path.join(Config.bucket, logger.prefix, 'checkpoint')
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(
                trainer.state_dict(),
                os.path.join(checkpoint_dir, "bc_best.pt"),
            )

            print("---------------------------------------")
            print(
                f"Evaluation over {config.n_episodes} episodes: "
                f"{eval_score:.3f} , D4RL score: {normalized_eval_score:.3f}"
            )
            print("---------------------------------------")

        logger.log(
            {"d4rl_normalized_score": normalized_eval_score},
            step=trainer.total_it,
        )


def train_bc(**deps):
    """Entry point: configure the run from ``deps``, load data, and train BC.

    ``deps`` is forwarded to ``RUN._update`` and ``Config._update``. The
    training paths are read from
    ``<Config.bucket>/<logger.prefix>/checkpoint/new_dataset.dat``. The dataset
    object built from ``Config.loader`` is handed to ``train``, which uses its
    normalizer during evaluation.
    """
    RUN._update(deps)
    Config._update(deps)

    logger.remove('*.pkl')
    logger.remove("traceback.err")
    logger.log_params(Config=vars(Config), RUN=vars(RUN))

    # Prefer the GPU, but fall back to the CPU so that a machine without CUDA
    # does not fail on the first tensor allocation.
    Config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Pass the real settings to the constructor so the generated run name
    # embeds the actual environment rather than the default one.
    config = TrainConfig(
        device=Config.device,
        seed=Config.seed,
        env=Config.dataset,
    )

    checkpoint_dir = os.path.join(Config.bucket, logger.prefix, 'checkpoint')
    os.makedirs(checkpoint_dir, exist_ok=True)
    data_file = os.path.join(checkpoint_dir, "new_dataset.dat")
    # SECURITY: pickle.load can execute arbitrary code; only point this at
    # files produced by a trusted run.
    with open(data_file, "rb") as f:
        data = pickle.load(f)

    dataset_config = utils.Config(
        Config.loader,
        savepath='dataset_config.pkl',
        env=Config.dataset,
        horizon=Config.horizon,
        normalizer=Config.normalizer,
        preprocess_fns=Config.preprocess_fns,
        use_padding=Config.use_padding,
        max_path_length=Config.max_path_length,
        include_returns=Config.include_returns,
        returns_scale=Config.returns_scale,
    )

    dataset = dataset_config()

    train(config, data, dataset, logger)
