# based on https://github.com/chwoong/LiRE
from copy import deepcopy
import os, sys, random, datetime
from dataclasses import asdict, dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import pyrallis
import tqdm

import gym
import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from env import utils_env
from logger import Logger
from base_models import MLP, ActorProb, Critic, DiagGaussian
sys.path.append("./reward_learning")
from reward_learning import reward_model

# Alias for the lists of tensors returned by ReplayBuffer.sample / sample_trajectory.
TensorBatch = List[torch.Tensor]
# Nested variant: a list of TensorBatch lists.
TrajTensorBatch = List[TensorBatch]


@dataclass
class TrainConfig:
    """Experiment configuration for APPO training (consumed via pyrallis in ``train``).

    ``__post_init__`` derives ``log_path`` from the dataset type and the
    experiment identifiers. Only ``"medium-replay"`` and ``"medium-expert"``
    are supported; any other ``dataset`` raises ``ValueError`` immediately
    instead of failing later on a missing ``log_path``.
    """

    # Experiment
    device: Optional[str] = None  # None: cuda if available else cpu; a digit string selects that GPU index
    dataset: str = "medium-replay"  # "medium-replay" or "medium-expert"
    env: str = "metaworld_box-close-v2"  # environment name
    seed: int = 0  # Sets Gym, PyTorch and Numpy seeds
    eval_freq: int = int(5e3)  # How often (time steps) we evaluate
    n_episodes: int = 50  # How many episodes run during evaluation
    max_timesteps: int = 250000  # Max time steps to run environment
    checkpoints_path: Optional[str] = None  # Save path
    load_model: str = ""  # Model load file name, "" doesn't load
    data_quality: Optional[float] = None  # Replay buffer size (data_quality * 100000)
    trivial_reward: int = 0  # 0: GT reward, 1: zero reward, 2: random reward, 3: negative reward
    # Algorithm
    buffer_size: int = 2_000_000  # Replay buffer size (overwritten with the dataset size in train)
    batch_size: int = 256  # Batch size (transition batches; also available to the reward model)
    traj_batch_size: int = 16  # Number of segment pairs per trajectory batch
    discount: float = 0.99  # Discount factor
    lam: float = 1e-3  # adversarial loss coefficient
    tau: float = 0.005  # Target network update rate
    alpha: float = 0.2  # entropy regularization (initial value)
    auto_alpha: bool = True  # optimize alpha
    alpha_lr: float = 3e-4  # alpha learning rate
    normalize: bool = True  # Normalize states
    normalize_reward: bool = True  # Apply modify_reward (normalization / trivial_reward transform)
    critic_lr: float = 3e-4  # Critic learning rate
    actor_lr: float = 3e-4  # Actor learning rate
    hidden_size: int = 256  # Hidden size for networks
    # Reward model
    feedback_num: int = 1000
    use_reward_model: bool = True  # Relabel dataset rewards with a reward model loaded from log_path
    epochs: int = 0
    activation: str = "tanh"
    lr: float = 1e-3
    threshold: float = 0.5
    segment_size: int = 25  # Length of trajectory segments sampled for the critic loss
    data_aug: str = "none"
    hidden_sizes: int = 128
    ensemble_num: int = 3
    ensemble_method: str = "mean"
    q_budget: int = 1
    feedback_type: str = "RLT"
    model_type: str = "BT"
    noise: float = 0.0
    human: bool = False

    def __post_init__(self):
        if self.dataset == "medium-replay":
            self.log_path = (
                f"log/{self.env}/medium-replay/data_{self.data_quality}_fn_{self.feedback_num}"
                f"_qb_{self.q_budget}_ft_{self.feedback_type}_m_{self.model_type}/s_{self.seed}"
            )
        elif self.dataset == "medium-expert":
            self.log_path = f"log/{self.env}/medium-expert/fn_{self.feedback_num}/s_{self.seed}"
        else:
            # Without this, log_path would silently stay undefined.
            raise ValueError(
                f"Unknown dataset {self.dataset!r}; expected 'medium-replay' or 'medium-expert'"
            )



def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]:
    """Return per-feature ``(mean, std + eps)`` of *states*, reduced over axis 0.

    ``eps`` is added to the standard deviation so later division by it is
    safe for constant features.
    """
    return states.mean(axis=0), states.std(axis=0) + eps


def normalize_states(states: np.ndarray, mean: np.ndarray, std: np.ndarray):
    """Standardize *states* with precomputed statistics: ``(states - mean) / std``."""
    centered = states - mean
    return centered / std


def wrap_env(
    env: gym.Env,
    state_mean: Union[np.ndarray, float] = 0.0,
    state_std: Union[np.ndarray, float] = 1.0,
    reward_scale: float = 1.0,
) -> gym.Env:
    """Wrap *env* so observations are standardized and rewards optionally scaled.

    Args:
        env: environment to wrap.
        state_mean: value subtracted from every observation.
        state_std: divisor applied after centering (callers are expected to
            have already added an epsilon).
        reward_scale: multiplier for rewards; the reward wrapper is only added
            when it differs from 1.0.

    Returns:
        The wrapped environment.
    """

    def normalize_state(state):
        # Presumably guards against the ``(obs, info)`` pair some gym versions
        # return from ``reset()``. Only unwrap real tuples/lists: a plain
        # ``len(state) == 2`` check would also truncate a genuine 2-D
        # observation array to its first element.
        if isinstance(state, (tuple, list)) and len(state) == 2:
            state = state[0]
        return (state - state_mean) / state_std  # epsilon should be already added in std.

    def scale_reward(reward):
        # Please be careful, here reward is multiplied by scale!
        return reward_scale * reward

    env = gym.wrappers.TransformObservation(env, normalize_state)
    if reward_scale != 1.0:
        env = gym.wrappers.TransformReward(env, scale_reward)
    return env


class ReplayBuffer:
    """Fixed-capacity transition storage backed by preallocated torch tensors.

    Stores ``states``, ``actions``, ``rewards`` (N, 1), ``next_states`` and
    ``dones`` (N, 1) on ``device``. When filled via ``load_dataset`` the data
    is assumed to be a concatenation of trajectories of exactly ``traj_len``
    transitions each; ``sample_trajectory`` relies on that layout.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        buffer_size: int,
        device: str = "cpu",
        traj_len: int = 500,
    ):
        self._buffer_size = buffer_size
        self._pointer = 0  # next slot written by ``add`` (wraps around)
        self._size = 0  # number of valid transitions currently stored
        self._traj_len = traj_len  # transitions per trajectory in loaded datasets
        self.num_traj = 0  # whole trajectories available; set by load_dataset

        self._states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._actions = torch.zeros(
            (buffer_size, action_dim), dtype=torch.float32, device=device
        )
        self._rewards = torch.zeros(
            (buffer_size, 1), dtype=torch.float32, device=device
        )
        self._next_states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._device = device

    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
        """Copy *data* into a float32 tensor on the buffer's device."""
        return torch.tensor(data, dtype=torch.float32, device=self._device)

    def add(self, state, action, reward, done, next_state):
        """Insert one transition, overwriting the oldest slot once full."""
        self._states[self._pointer] = self._to_tensor(state)
        self._actions[self._pointer] = self._to_tensor(action)
        self._rewards[self._pointer] = self._to_tensor(reward)
        self._next_states[self._pointer] = self._to_tensor(next_state)
        self._dones[self._pointer] = self._to_tensor(done)
        self._pointer = (self._pointer + 1) % self._buffer_size
        self._size = min(self._size + 1, self._buffer_size)

    def load_dataset(self, data: Dict[str, np.ndarray]):
        """Bulk-load an offline dataset into an empty buffer.

        ``data`` must provide ``observations``, ``actions``, ``rewards``,
        ``next_observations`` and ``terminals`` sharing the leading dimension.
        ``rewards`` and ``terminals`` are expected to be 1-D (a trailing axis
        is added before storing).

        Raises:
            ValueError: if the buffer is non-empty or too small for the data.
        """
        if self._size != 0:
            raise ValueError("Trying to load data into non-empty replay buffer")
        n_transitions = data["observations"].shape[0]
        if n_transitions > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )
        self._states[:n_transitions] = self._to_tensor(data["observations"])
        self._actions[:n_transitions] = self._to_tensor(data["actions"])
        self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None])
        self._next_states[:n_transitions] = self._to_tensor(data["next_observations"])
        self._dones[:n_transitions] = self._to_tensor(data["terminals"][..., None])
        self._size += n_transitions
        self._pointer = min(self._size, n_transitions)
        self.num_traj = n_transitions // self._traj_len

        print(f"Dataset size: {n_transitions}")

    def sample(self, batch_size: int) -> "TensorBatch":
        """Uniformly sample ``batch_size`` stored transitions with replacement.

        Returns:
            ``[states, actions, rewards, next_states, dones]``.
        """
        # Every slot below ``_size`` holds valid data, including after ``add``
        # has wrapped ``_pointer`` back to the start of the buffer.
        indices = np.random.randint(0, self._size, size=batch_size)
        return [
            self._states[indices],
            self._actions[indices],
            self._rewards[indices],
            self._next_states[indices],
            self._dones[indices],
        ]

    def sample_trajectory(self, batch_size: int, segment_size: int) -> "TensorBatch":
        """Sample ``batch_size`` contiguous segments of ``segment_size`` steps.

        Trajectories are chosen uniformly with replacement. Start offsets are
        uniform over every position that keeps the segment inside its
        trajectory. Segments are concatenated along dim 0, so each returned
        tensor has ``batch_size * segment_size`` rows.

        Returns:
            ``[states, actions, rewards, next_states]``.

        Raises:
            ValueError: if no whole trajectory is loaded, or ``segment_size``
                is outside ``[1, traj_len]``.
        """
        if self.num_traj <= 0:
            raise ValueError("No whole trajectories loaded; call load_dataset first")
        if not 1 <= segment_size <= self._traj_len:
            raise ValueError(
                f"segment_size must be in [1, {self._traj_len}], got {segment_size}"
            )
        traj_idx = np.random.choice(self.num_traj, batch_size, replace=True)
        # A segment starting at offset s covers s .. s + segment_size - 1, so
        # the largest valid offset is traj_len - segment_size (randint's upper
        # bound is exclusive, hence the + 1).
        offsets = np.random.randint(0, self._traj_len - segment_size + 1, size=batch_size)
        starts = self._traj_len * traj_idx + offsets
        indices = (starts[:, None] + np.arange(segment_size)).reshape(-1)
        return [
            self._states[indices],
            self._actions[indices],
            self._rewards[indices],
            self._next_states[indices],
        ]


def set_seed(
    seed: int, env: Optional[gym.Env] = None, deterministic_torch: bool = False
):
    """Seed the environment (if given), Python, NumPy and PyTorch RNGs with *seed*.

    Also exports ``PYTHONHASHSEED`` and toggles PyTorch deterministic
    algorithms according to ``deterministic_torch``.
    """
    # The environment is seeded first, before the global generators.
    if env is not None:
        env.seed(seed)
        env.action_space.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)
    torch.use_deterministic_algorithms(deterministic_torch)


@torch.no_grad()
def eval_actor(
    env: "gym.Env",
    env_name: str,
    agent: Any,
    device: str,
    n_episodes: int,
    seed: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Roll out *agent* in *env* for ``n_episodes`` episodes without gradients.

    Actions are sampled stochastically (``deterministic=False``). The agent is
    switched to eval mode for the rollouts and is always switched back to
    train mode afterwards, even if an episode raises.

    Args:
        env: environment with the 4-tuple ``step`` API (obs, reward, done, info).
        env_name: if it contains ``"metaworld"``, per-episode success is the
            maximum of ``info["success"]`` over the episode; otherwise it is 0.
        agent: object providing ``eval()``, ``train()`` and
            ``sample_action(state, deterministic=...)`` (e.g. ``APPO``).
        device: unused; kept for interface compatibility.
        n_episodes: number of episodes to run.
        seed: unused; kept for interface compatibility.

    Returns:
        ``(episode_returns, episode_successes)`` as NumPy arrays of length
        ``n_episodes``.
    """
    agent.eval()
    episode_returns = []
    episode_successes = []
    try:
        for _ in range(n_episodes):
            state, done = env.reset(), False
            episode_return = 0.0
            episode_success = 0
            while not done:
                action = agent.sample_action(state, deterministic=False)
                state, reward, done, info = env.step(action)
                episode_return += reward
                if "metaworld" in env_name:
                    episode_success = max(episode_success, info["success"])
            episode_returns.append(episode_return)
            episode_successes.append(episode_success)
    finally:
        agent.train()
    return np.array(episode_returns), np.array(episode_successes)


def return_reward_range(dataset, max_episode_steps):
    """Return ``(min, max)`` of per-episode returns found in *dataset*.

    An episode ends at a truthy terminal flag or once it reaches
    ``max_episode_steps`` steps. A trailing unfinished episode contributes no
    return. If no episode completes, ``min``/``max`` of the empty list raise
    ``ValueError``.
    """
    episode_returns, episode_lengths = [], []
    running_return, running_length = 0.0, 0
    for reward, terminal in zip(dataset["rewards"], dataset["terminals"]):
        running_return += float(reward)
        running_length += 1
        if not (terminal or running_length == max_episode_steps):
            continue
        episode_returns.append(running_return)
        episode_lengths.append(running_length)
        running_return, running_length = 0.0, 0
    # Also count the steps of the trailing, unfinished episode.
    episode_lengths.append(running_length)
    assert sum(episode_lengths) == len(dataset["rewards"])
    return min(episode_returns), max(episode_returns)


def modify_reward(
    dataset,
    max_episode_steps=1000,
    trivial_reward=0,
):
    """Replace ``dataset["rewards"]`` according to the ``trivial_reward`` mode.

    Modes:
        0: ground-truth rewards min-max normalized to ``[0, 1]``.
        1: all-zero rewards.
        2: rewards drawn uniformly (``np.random.uniform``) from the range of
           the normalized rewards.
        3: inverted normalized rewards, ``1 - normalized``.

    Constant rewards (zero range) normalize to 0 instead of producing NaN.

    Args:
        dataset: dict-like with a ``"rewards"`` array; that entry is replaced.
        max_episode_steps: unused; kept for interface compatibility.
        trivial_reward: one of the mode codes above.

    Raises:
        ValueError: for an unknown ``trivial_reward`` code.
    """
    rewards = np.asarray(dataset["rewards"])

    def _min_max_normalize(values):
        # Scale into [0, 1]; a zero range would otherwise give 0/0 = NaN.
        low, high = np.min(values), np.max(values)
        span = high - low
        if span == 0:
            span = 1
        return (values - low) / span

    if trivial_reward == 0:
        dataset["rewards"] = _min_max_normalize(rewards)
    elif trivial_reward == 1:
        dataset["rewards"] = rewards * 0.0
    elif trivial_reward == 2:
        normalized = _min_max_normalize(rewards)
        dataset["rewards"] = np.random.uniform(
            np.min(normalized), np.max(normalized), size=normalized.shape
        )
    elif trivial_reward == 3:
        dataset["rewards"] = 1 - _min_max_normalize(rewards)
    else:
        raise ValueError(f"Unknown trivial_reward mode: {trivial_reward!r}")


class APPO:
    """Offline actor-critic learner with twin Q critics, a value network and a tanh-Gaussian actor.

    Each ``learn`` call performs, in order:

    1. a critic step on ``qf1``/``qf2``. It combines a pairwise L1 loss on
       per-segment sums of ``Q(s, a)`` versus ``r + gamma * V(s')`` with a
       ``lam``-weighted adversarial term
       ``mean(Q(s, pi(s)) - Q(s, a_data))``;
    2. a value step regressing ``vf`` onto ``min(qf1_old, qf2_old)`` at
       policy actions;
    3. a Polyak update of the target critics ``qf1_old``/``qf2_old``;
    4. an entropy-regularized actor step against one randomly chosen critic;
    5. optionally, a temperature (``alpha``) step.
    """

    def __init__(
        self,
        actor,
        qf1,
        qf2,
        vf,
        actor_optim,
        qf1_optim,
        qf2_optim,
        vf_optim,
        action_space,
        dist,
        device,
        lam,
        tau=0.001,
        gamma=0.99,
        alpha=0.2,
        batch_size=256,
        traj_batch_size=16,
        max_steps=250000
    ):
        """Store networks, optimizers and hyper-parameters.

        ``alpha`` is either a fixed coefficient or a
        ``(target_entropy, log_alpha, alpha_optim)`` tuple that enables
        automatic temperature tuning. ``max_steps`` is accepted for interface
        compatibility but not used.
        """
        super().__init__()

        self.actor = actor
        # Target critics start as exact copies of the online critics.
        self.qf1, self.qf1_old = qf1, deepcopy(qf1)
        self.qf2, self.qf2_old = qf2, deepcopy(qf2)
        self.vf = vf

        self.actor_optim = actor_optim
        self.qf1_optim = qf1_optim
        self.qf2_optim = qf2_optim
        self.vf_optim = vf_optim

        self.action_space = action_space
        self.dist = dist

        self._tau = tau
        self._gamma = gamma
        self.lam = lam
        self.batch_size = batch_size
        self.traj_batch_size = traj_batch_size

        self._is_auto_alpha = False
        if isinstance(alpha, tuple):
            self._is_auto_alpha = True
            self._target_entropy, self._log_alpha, self._alpha_optim = alpha
            self._alpha = self._log_alpha.detach().exp()
        else:
            self._alpha = alpha

        # Keeps log() finite where tanh saturates.
        self.__eps = np.finfo(np.float32).eps.item()

        self._device = device
        self.total_it = 0

    def train(self):
        """Put all networks in training mode."""
        self.actor.train()
        self.qf1.train()
        self.qf2.train()
        self.vf.train()

    def eval(self):
        """Put all networks in evaluation mode."""
        self.actor.eval()
        self.qf1.eval()
        self.qf2.eval()
        self.vf.eval()

    def _sync_weight(self):
        """Polyak-average online critic parameters into the targets at rate ``tau``."""
        for target, online in ((self.qf1_old, self.qf1), (self.qf2_old, self.qf2)):
            for o, n in zip(target.parameters(), online.parameters()):
                o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)

    def __call__(self, obs, deterministic=False):
        """Return ``(tanh(u), log_prob)`` for ``u`` drawn from the actor's distribution.

        ``u`` is the distribution mode when ``deterministic`` else a
        reparameterized sample. The log-probability is corrected by the
        change-of-variables term ``log(scale * (1 - tanh(u)^2))`` summed over
        the last dimension.

        NOTE(review): the returned action lies in (-1, 1) and is not multiplied
        by ``action_scale``; this matches the action space only when it is
        [-1, 1] — confirm for other spaces.
        """
        dist = self.actor.get_dist(obs)
        if deterministic:
            action = dist.mode()
        else:
            action = dist.rsample()
        log_prob = dist.log_prob(action)

        action_scale = torch.tensor((self.action_space.high - self.action_space.low) / 2, device=action.device)
        squashed_action = torch.tanh(action)
        log_prob = log_prob - torch.log(action_scale * (1 - squashed_action.pow(2)) + self.__eps).sum(-1, keepdim=True)

        return squashed_action, log_prob

    def log_prob(self, obs, action):
        """Return a tanh-corrected log-probability of *action* under the actor.

        NOTE(review): *action* is passed unchanged both to ``dist.log_prob``
        (which in ``__call__`` receives the pre-tanh sample) and to the tanh
        Jacobian term (which in ``__call__`` uses the squashed action). These
        two uses look inconsistent — confirm which space *action* is in before
        relying on this method.
        """
        dist = self.actor.get_dist(obs)
        log_prob = dist.log_prob(action)

        action_scale = torch.tensor((self.action_space.high - self.action_space.low) / 2, device=action.device)
        log_prob = log_prob - torch.log(action_scale * (1 - action.pow(2)) + self.__eps).sum(-1, keepdim=True)

        return log_prob

    def sample_action(self, obs, deterministic=False):
        """Return an action for *obs* as a NumPy array (detached, on CPU)."""
        action, _ = self(obs, deterministic)
        return action.cpu().detach().numpy()

    def learn(self, batch: "TensorBatch", traj_batch: "TensorBatch") -> Dict[str, float]:
        """Run one update of critics, value function, actor and (optionally) alpha.

        Args:
            batch: ``[obs, actions, rewards, next_obs, terminals]``; only
                ``obs`` and ``actions`` are used.
            traj_batch: ``[obs, actions, rewards, next_obs]`` holding
                ``2 * traj_batch_size`` contiguous segments flattened along
                dim 0. Segment ``i`` of the first half is paired with segment
                ``i`` of the second half.

        Returns:
            Dict of scalar diagnostics keyed ``"train/..."``.
        """
        self.total_it += 1
        obs, actions, rewards, next_obs, terminals = batch
        obs_t, action_t, reward_t, next_obs_t = traj_batch
        log_dict = {}

        segment_size = obs_t.size(0) // (2 * self.traj_batch_size)
        half_size = self.traj_batch_size

        # Trajectory L1 loss: compare per-segment sums of Q with sums of the
        # one-step target r + gamma * V(s'), pairwise across the two halves.
        with torch.no_grad():
            next_v_t = self.vf(next_obs_t).flatten()
            target_t = reward_t.flatten() + self._gamma * next_v_t.flatten()
            target_t = torch.sum(target_t.view(-1, segment_size), dim=-1)

        q1_t, q2_t = self.qf1(obs_t, action_t).flatten(), self.qf2(obs_t, action_t).flatten()
        q1_t, q2_t = torch.sum(q1_t.view(-1, segment_size), dim=-1), torch.sum(q2_t.view(-1, segment_size), dim=-1)
        traj_reg_loss_1 = (q1_t[:half_size] - target_t[:half_size] - q1_t[half_size:] + target_t[half_size:]).abs().mean() / segment_size
        traj_reg_loss_2 = (q2_t[:half_size] - target_t[:half_size] - q2_t[half_size:] + target_t[half_size:]).abs().mean() / segment_size

        # Adversarial loss: push Q down on policy actions relative to data actions.
        a, log_probs = self(obs)
        log_dict["train/log_probs"] = log_probs.mean().item()

        q1_cur, q2_cur = self.qf1(obs, a.detach()).flatten(), self.qf2(obs, a.detach()).flatten()
        q1, q2 = self.qf1(obs, actions).flatten(), self.qf2(obs, actions).flatten()
        adv_loss_1 = torch.mean(q1_cur - q1)
        adv_loss_2 = torch.mean(q2_cur - q2)
        critic1_loss = self.lam * adv_loss_1 + traj_reg_loss_1
        critic2_loss = self.lam * adv_loss_2 + traj_reg_loss_2

        # Update critics; both diagnostics are summed over the two critics.
        log_dict["train/traj_reg_loss"] = traj_reg_loss_1.item() + traj_reg_loss_2.item()
        log_dict["train/adv_loss"] = adv_loss_1.item() + adv_loss_2.item()
        self.qf1_optim.zero_grad()
        critic1_loss.backward()
        self.qf1_optim.step()
        self.qf2_optim.zero_grad()
        critic2_loss.backward()
        self.qf2_optim.step()

        # Update state-value function towards the clipped target-critic value.
        with torch.no_grad():
            target = torch.minimum(self.qf1_old(obs, a.detach()), self.qf2_old(obs, a.detach())).flatten()
        v = self.vf(obs).flatten()
        log_dict["train/v_scale"] = v.mean().item()
        advantage = target - v
        v_loss = torch.mean(advantage**2)
        self.vf_optim.zero_grad()
        v_loss.backward()
        self.vf_optim.step()
        log_dict["train/v_loss"] = v_loss.item()

        self._sync_weight()

        # Update actor against one randomly chosen online critic.
        idx = random.choice([0, 1])
        critic = self.qf1 if idx == 0 else self.qf2
        q_pi = critic(obs, a).flatten()
        actor_loss = (self._alpha * log_probs.flatten() - q_pi).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()
        log_dict["train/actor_loss"] = actor_loss.item()

        if self._is_auto_alpha:
            log_probs = log_probs.detach() + self._target_entropy
            alpha_loss = -(self._log_alpha * log_probs).mean()
            self._alpha_optim.zero_grad()
            alpha_loss.backward()
            self._alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()

            log_dict["train/alpha_loss"] = alpha_loss.item()
            log_dict["train/alpha"] = self._alpha.item()

        return log_dict

    def average_weight_norm(self):
        """Mean L2 norm over trainable, multi-element parameters of both critics.

        Returns 0.0 when there are no such parameters.
        """
        norms = [
            torch.norm(param).item()
            for critic in (self.qf1, self.qf2)
            for param in critic.parameters()
            if param.requires_grad and param.nelement() > 1
        ]
        if not norms:
            return 0.0  # Handle cases where there are no parameters
        return sum(norms) / len(norms)


@pyrallis.wrap()
def train(config: TrainConfig):
    """Train APPO offline on a MetaWorld dataset, logging and evaluating periodically.

    Steps:
      1. Resolve the device and set up TensorBoard logging.
      2. Build the environment and load the offline dataset.
      3. Normalize observations. Optionally relabel rewards with a learned
         reward model and/or transform them via ``modify_reward``.
      4. Fill a replay buffer and build the actor/critic networks.
      5. Run ``config.max_timesteps`` updates, recording training stats every
         5000 steps and evaluating every ``config.eval_freq`` steps.

    Raises:
        ValueError: on an out-of-range GPU index or an unknown dataset type.
    """
    # Device: None -> auto-detect; a digit string -> that GPU index.
    if config.device is None:
        config.device = "cuda" if torch.cuda.is_available() else "cpu"
    elif config.device.isdigit():
        if torch.cuda.device_count() <= int(config.device):
            raise ValueError("invalid device")
        # NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not
        # been initialized yet, and device_count() above has already queried
        # the driver — confirm the intended GPU is actually used.
        os.environ['CUDA_VISIBLE_DEVICES'] = f"{config.device}"
        config.device = "cuda"

    log_path = os.path.join(config.log_path, f"APPO_lambda_{config.lam}")
    writer = SummaryWriter(log_path)
    logger = Logger(writer=writer, log_path=log_path)

    env = utils_env.make_metaworld_env(config.env, config.seed)

    if config.dataset == "medium-replay":
        dataset = utils_env.MetaWorld_mr_dataset(config)
    elif config.dataset == "medium-expert":
        dataset = utils_env.MetaWorld_me_dataset(config)
    else:
        raise ValueError(f"Unknown dataset: {config.dataset!r}")

    state_dim = env.observation_space.shape[0]  # 39 for metaworld
    action_dim = env.action_space.shape[0]  # 4 for metaworld
    dimension = state_dim + action_dim

    seed = config.seed
    set_seed(seed, env)

    if config.normalize:
        state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
    else:
        state_mean, state_std = 0, 1

    dataset["observations"] = normalize_states(
        dataset["observations"], state_mean, state_std
    )
    dataset["next_observations"] = normalize_states(
        dataset["next_observations"], state_mean, state_std
    )

    if config.use_reward_model:
        # Replace dataset rewards with predictions of a reward model loaded from config.log_path.
        model = reward_model.RewardModel(config, None, None, None, dimension, None)
        model.load_model(config.log_path)
        dataset["rewards"] = model.get_reward(dataset)
        print("labeled by reward model")

    if config.normalize_reward:
        modify_reward(
            dataset,
            max_episode_steps=500,
            trivial_reward=config.trivial_reward,
        )

    # Evaluation observations are normalized with the dataset statistics.
    env = wrap_env(env, state_mean=state_mean, state_std=state_std)
    # Size the buffer to hold exactly the offline dataset.
    config.buffer_size = dataset["observations"].shape[0]

    replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    replay_buffer.load_dataset(dataset)

    # create policy model (config.hidden_size defaults to the former hard-coded 256)
    actor_backbone = MLP(input_dim=state_dim, hidden_dims=[config.hidden_size, config.hidden_size])
    qf1_backbone = MLP(input_dim=state_dim + action_dim, hidden_dims=[config.hidden_size, config.hidden_size])
    qf2_backbone = MLP(input_dim=state_dim + action_dim, hidden_dims=[config.hidden_size, config.hidden_size])
    vf_backbone = MLP(input_dim=state_dim, hidden_dims=[config.hidden_size, config.hidden_size])
    dist = DiagGaussian(
        latent_dim=getattr(actor_backbone, "output_dim"),
        output_dim=action_dim,
        unbounded=True,
        conditioned_sigma=True
    )

    actor = ActorProb(actor_backbone, dist, config.device)
    qf1 = Critic(qf1_backbone, config.device)
    qf2 = Critic(qf2_backbone, config.device)
    vf = Critic(vf_backbone, config.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=config.actor_lr)
    qf1_optim = torch.optim.Adam(qf1.parameters(), lr=config.critic_lr)
    qf2_optim = torch.optim.Adam(qf2.parameters(), lr=config.critic_lr)
    vf_optim = torch.optim.Adam(vf.parameters(), lr=config.critic_lr)

    if config.auto_alpha:
        target_entropy = - action_dim

        log_alpha = torch.ones(1, requires_grad=True, device=config.device)
        log_alpha.data.copy_(np.log(config.alpha))
        alpha_optim = torch.optim.Adam([log_alpha], lr=config.alpha_lr)
        alpha = (target_entropy, log_alpha, alpha_optim)
    else:
        alpha = config.alpha

    kwargs = {
        "actor": actor,
        "qf1": qf1,
        "qf2": qf2,
        "vf": vf,
        "actor_optim": actor_optim,
        "qf1_optim": qf1_optim,
        "qf2_optim": qf2_optim,
        "vf_optim": vf_optim,
        "action_space": env.action_space,
        "dist": dist,
        "device": config.device,
        # algorithm parameters
        "tau": config.tau,
        "gamma": config.discount,
        "lam": config.lam,
        "alpha": alpha,
        "batch_size": config.batch_size,
        "traj_batch_size": config.traj_batch_size,
        "max_steps": config.max_timesteps
    }

    # Initialize actor
    trainer = APPO(**kwargs)
    trainer.train()

    for t in tqdm.tqdm(range(int(config.max_timesteps))):
        batch = replay_buffer.sample(config.batch_size)
        batch = [b.to(config.device) for b in batch]
        # Two halves of traj_batch_size segments each, paired inside APPO.learn.
        traj_batch = replay_buffer.sample_trajectory(2 * config.traj_batch_size, config.segment_size)
        traj_batch = [b.to(config.device) for b in traj_batch]
        log_dict = trainer.learn(batch, traj_batch)
        if (t + 1) % 5000 == 0:
            for k, v in log_dict.items():
                logger.record(k, v, trainer.total_it)
        # Evaluate episode
        if (t + 1) % config.eval_freq == 0:
            print(f"Time steps: {t + 1}")
            eval_scores, eval_success = eval_actor(
                env,
                config.env,
                trainer,
                device=config.device,
                n_episodes=config.n_episodes,
                seed=config.seed,
            )
            eval_score = eval_scores.mean()
            eval_success = eval_success.mean() * 100
            print("---------------------------------------")
            print(
                f"Evaluation over {config.n_episodes} episodes: "
                f"{eval_score:.3f} , success: {eval_success:.3f}"
            )
            print("---------------------------------------")
            logger.record("eval/eval_success", eval_success, trainer.total_it)
            logger.record("eval/eval_score", eval_score, trainer.total_it)


if __name__ == "__main__":
    # train() is called with no arguments: the @pyrallis.wrap() decorator
    # supplies the TrainConfig (presumably parsed from CLI flags / a config file).
    train()
