import os
import random
import time
from dataclasses import dataclass
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from torch.utils.tensorboard import SummaryWriter
from types import SimpleNamespace
from cleanrl_utils.buffers import ReplayBuffer
import mo_gymnasium

@dataclass
class Args:
    """Command-line configuration (parsed with tyro) for multi-objective SAC training."""

    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = None
    """the wandb's project name"""
    wandb_entity: str = None
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""
    env_id: str = "mo-walker2d-v5"
    """the environment id of the task"""
    total_timesteps: int = 1000000
    """total timesteps of the experiments"""
    num_envs: int = 1
    """the number of parallel game environments"""
    buffer_size: int = int(1e6)
    """the replay memory buffer size"""
    gamma: float = 0.99
    """the discount factor gamma"""
    tau: float = 0.005
    """target smoothing coefficient (default: 0.005)"""
    batch_size: int = 256
    """the batch size of sample from the replay memory"""
    # int(...) so the default matches the declared int type (5e3 is a float).
    learning_starts: int = int(5e3)
    """timestep to start learning"""
    policy_lr: float = 3e-4
    """the learning rate of the policy network optimizer"""
    q_lr: float = 1e-3
    """the learning rate of the Q network optimizer"""
    policy_frequency: int = 2
    """the frequency of training policy (delayed)"""
    target_network_frequency: int = 1
    """the frequency of updates for the target networks"""
    alpha: float = 0.2
    """Entropy regularization coefficient."""
    autotune: bool = True
    """automatic tuning of the entropy coefficient"""
    pref: str = "1,0"
    """comma-separated preference for vector reward, e.g. '1,0' or '0.8,0.2'"""
    m: int = 2
    """number of objectives (mo-walker2d-v5 uses 2)"""
    mu: float = 0.05
    """used in softmax weighting (e.g., 0.01, 1)"""
    jub: str = "2000,1000"
    """comma-separated upper bounds J^ub_i per objective (e.g., forward, cost)"""
    ema_beta: float = 0.05
    # Raw string: "\h" is an invalid escape sequence in a normal string literal.
    r"""EMA factor for estimating \hat{J}_i"""
    net_depth: int = 2
    """number of layers for both Actor and Critic networks"""
    net_hidden: int = 256
    """hidden width for both Actor and Critic networks"""
    
class ScalarizeMOReward(gym.RewardWrapper):
    """Turn a vector (multi-objective) reward into a scalar via a preference vector.

    The scalar reward is ``dot(pref, reward_vec)``. ``pref`` is normalized to
    sum to 1 when its sum is positive; otherwise it is used as given.

    Every step, the raw reward vector (as float32) is exposed in
    ``info["reward_vec"]``. On a step that this wrapper sees as terminated or
    truncated, the per-episode sum of reward vectors is exposed in
    ``info["vec_return"]``.
    """

    def __init__(self, env, pref):
        super().__init__(env)
        self.pref = np.asarray(pref, dtype=np.float32)
        total = self.pref.sum()
        if total > 0:
            self.pref = self.pref / total
        # Running sum of reward vectors for the current episode.
        self._ep_sum = np.zeros_like(self.pref, dtype=np.float32)

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self._ep_sum[...] = 0.0
        return obs, info

    def reward(self, reward):
        """Return the preference-weighted scalar of a reward vector.

        Implements ``gym.RewardWrapper.reward``, which would otherwise raise
        NotImplementedError when called directly.
        """
        return float(np.dot(self.pref, np.asarray(reward, dtype=np.float32)))

    def step(self, action):
        obs, reward_vec, terminated, truncated, info = self.env.step(action)
        # Copy so the vector stored in info cannot alias a buffer the env may reuse.
        reward_vec = np.array(reward_vec, dtype=np.float32)
        self._ep_sum += reward_vec
        scalar_r = self.reward(reward_vec)

        info = {} if info is None else dict(info)
        info["reward_vec"] = reward_vec

        if terminated or truncated:
            info["vec_return"] = self._ep_sum.copy()

        return obs, scalar_r, terminated, truncated, info

def vector_infos_to_list(infos, n_envs):
    """Split a batched vector-env ``infos`` dict into one dict per sub-env.

    Vector envs batch per-env info as ``{key: array_of_length_n_envs}``. They
    may add a boolean mask under ``"_" + key`` marking which sub-envs actually
    reported ``key``. Masked-out slots of numeric arrays hold a filler value
    (e.g. 0) rather than ``None``, so the mask is honored when present.

    Values that are not per-env sequences of length ``n_envs`` are ignored.
    This includes scalars, 0-d arrays and nested dicts. ``None`` entries are
    always skipped.

    Args:
        infos: the info object returned by a vector env's ``step``/``reset``.
        n_envs: number of sub-environments.

    Returns:
        A list of ``n_envs`` dicts. If ``infos`` is not a dict, it is
        returned unchanged.
    """
    if not isinstance(infos, dict):
        return infos

    out = [{} for _ in range(n_envs)]
    for key, values in infos.items():
        if not isinstance(values, (list, tuple, np.ndarray)):
            continue
        # 0-d arrays have no len(); they cannot be per-env values.
        if isinstance(values, np.ndarray) and values.ndim == 0:
            continue
        if len(values) != n_envs:
            continue

        mask = infos.get(f"_{key}")
        if not isinstance(mask, (list, tuple, np.ndarray)) or np.ndim(mask) != 1 or len(mask) != n_envs:
            mask = None

        for i in range(n_envs):
            if mask is not None and not mask[i]:
                continue
            if values[i] is not None:
                out[i][key] = values[i]
    return out
    
class CastObs(gym.ObservationWrapper):
    """Cast observations to ``float32`` and expose an unbounded float32 Box.

    The wrapped observation space must be a ``gym.spaces.Box``. The new space
    keeps its shape but replaces both bounds with -inf / +inf.
    """

    def __init__(self, env):
        super().__init__(env)
        assert isinstance(self.observation_space, gym.spaces.Box)
        shape = self.observation_space.shape
        upper = np.full(shape, np.inf, dtype=np.float32)
        self.observation_space = gym.spaces.Box(low=-upper, high=upper, dtype=np.float32)

    def observation(self, observation):
        return np.asarray(observation, dtype=np.float32)

class MORLReplayBuffer(ReplayBuffer):
    """Replay buffer that additionally stores a per-step reward vector.

    Adds a ``reward_vecs`` array of shape ``(buffer_size, n_envs, m)``. It is
    filled on every :meth:`add` from ``infos_list[i]["reward_vec"]``, with zeros
    when sub-env ``i`` did not report one. :meth:`sample` returns the usual
    fields plus ``reward_vecs`` as a ``SimpleNamespace`` of tensors.
    """

    def __init__(self, *args, m: int, **kwargs):
        super().__init__(*args, **kwargs)
        self.m = m
        self.reward_vecs = np.zeros((self.buffer_size, self.n_envs, m), dtype=np.float32)

    def add(self, obs, next_obs, actions, rewards, dones, infos_list):
        super().add(obs, next_obs, actions, rewards, dones, infos_list)
        # The base add() has already advanced (and possibly wrapped) ``pos``,
        # so the slot just written is pos - 1 modulo the capacity.
        idx = (self.pos - 1) % self.buffer_size
        for i in range(self.n_envs):
            rv = None
            if infos_list is not None and i < len(infos_list):
                rv = infos_list[i].get("reward_vec", None)
            if rv is None:
                self.reward_vecs[idx, i, :] = 0.0
            else:
                self.reward_vecs[idx, i, :] = np.asarray(rv, dtype=np.float32)

    def sample(self, batch_size):
        """Sample ``batch_size`` transitions uniformly over stored steps and envs."""
        max_size = self.buffer_size if self.full else self.pos
        assert max_size > 0, "ReplayBuffer is empty!"
        b_inds = np.random.randint(0, max_size, size=batch_size)
        env_inds = np.random.randint(0, self.n_envs, size=batch_size)

        done_flags = self.dones[b_inds, env_inds]
        # Mirror the base buffer's sampling: transitions flagged as timeouts are
        # not treated as terminal. Assumes the base class keeps a ``timeouts``
        # array (zeros unless timeout handling recorded any) — checked via getattr.
        timeouts = getattr(self, "timeouts", None)
        if timeouts is not None:
            done_flags = done_flags * (1 - timeouts[b_inds, env_inds])

        obs      = torch.tensor(self.observations[b_inds, env_inds], device=self.device)
        actions  = torch.tensor(self.actions[b_inds, env_inds], device=self.device)
        next_obs = torch.tensor(self.next_observations[b_inds, env_inds], device=self.device)
        dones    = torch.tensor(done_flags, device=self.device).view(-1, 1)
        rewards  = torch.tensor(self.rewards[b_inds, env_inds], device=self.device).view(-1, 1)
        rvecs    = torch.tensor(self.reward_vecs[b_inds, env_inds], device=self.device, dtype=torch.float32)

        return SimpleNamespace(
            observations=obs,
            actions=actions,
            next_observations=next_obs,
            dones=dones,
            rewards=rewards,
            reward_vecs=rvecs,
        )

def make_env(env_id, seed, idx, capture_video, run_name, pref_vec, max_episode_steps=500):
    """Return a thunk that builds one wrapped environment for a vector env.

    Wrapper stack, innermost first:
    ``gym.make`` -> ``TimeLimit(max_episode_steps)`` -> ``RecordVideo`` (only
    when ``capture_video`` and ``idx == 0``) -> ``ScalarizeMOReward(pref_vec)``
    -> ``CastObs`` -> ``RecordEpisodeStatistics``.

    ``TimeLimit`` is placed innermost so that the reward scalarizer (which
    emits ``vec_return`` on truncation) and the video recorder both observe
    time-limit truncations.

    Args:
        env_id: gymnasium environment id.
        seed: seed applied to the action space.
        idx: index of this env inside the vector env; only env 0 records video.
        capture_video: whether env 0 should record videos to ``videos/{run_name}``.
        run_name: run identifier used for the video folder.
        pref_vec: preference weights passed to ``ScalarizeMOReward``.
        max_episode_steps: episode step cap enforced by ``TimeLimit`` (default 500).
    """
    record = capture_video and idx == 0

    def thunk():
        if record:
            env = gym.make(env_id, render_mode="rgb_array", disable_env_checker=True)
        else:
            env = gym.make(env_id, disable_env_checker=True)

        env = gym.wrappers.TimeLimit(env, max_episode_steps=max_episode_steps)
        if record:
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")

        env = ScalarizeMOReward(env, pref=pref_vec)
        # float32 observations with an unbounded float32 Box space.
        env = CastObs(env)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        return env

    return thunk

class SoftQNetwork(nn.Module):
    """MLP critic mapping a concatenated (observation, action) pair to one Q value.

    Architecture: ``depth`` hidden ``Linear + ReLU`` layers of width
    ``hidden_size`` (the first taking the flattened obs+action size), followed
    by a ``Linear(hidden_size, 1)`` output.
    """

    def __init__(self, env, depth=2, hidden_size=256):
        super().__init__()
        # Plain ints (as Actor does): np.prod returns a NumPy integer.
        obs_dim = int(np.prod(env.single_observation_space.shape))
        act_dim = int(np.prod(env.single_action_space.shape))

        layers = [nn.Linear(obs_dim + act_dim, hidden_size), nn.ReLU()]
        for _ in range(depth - 1):
            layers += [nn.Linear(hidden_size, hidden_size), nn.ReLU()]
        layers.append(nn.Linear(hidden_size, 1))

        self.net = nn.Sequential(*layers)

    def forward(self, x, a):
        """Return Q(x, a) with shape ``(batch, 1)``; ``x`` and ``a`` are concatenated on dim 1."""
        return self.net(torch.cat([x, a], dim=1))


# Bounds for the actor's log standard deviation: Actor.forward squashes its raw
# log-std output with tanh and rescales it into [LOG_STD_MIN, LOG_STD_MAX].
LOG_STD_MAX = 2
LOG_STD_MIN = -5


class Actor(nn.Module):
    """Tanh-squashed Gaussian policy for a continuous (Box) action space.

    An MLP backbone of ``depth`` ``Linear + ReLU`` layers of width ``hidden``
    feeds two linear heads: the Gaussian mean and a log-std. ``forward``
    squashes the log-std into ``[LOG_STD_MIN, LOG_STD_MAX]``. Actions are
    mapped from ``[-1, 1]`` to the action-space bounds via the
    ``action_scale`` / ``action_bias`` buffers.
    """

    def __init__(self, env, hidden=256, depth=2):
        super().__init__()
        act_space = env.single_action_space
        in_dim = int(np.prod(env.single_observation_space.shape))
        out_dim = int(np.prod(act_space.shape))

        modules = [nn.Linear(in_dim, hidden), nn.ReLU()]
        for _ in range(depth - 1):
            modules.extend((nn.Linear(hidden, hidden), nn.ReLU()))
        self.backbone = nn.Sequential(*modules)

        self.fc_mean = nn.Linear(hidden, out_dim)
        self.fc_logstd = nn.Linear(hidden, out_dim)

        half_range = (act_space.high - act_space.low) / 2.0
        midpoint = (act_space.high + act_space.low) / 2.0
        self.register_buffer("action_scale", torch.tensor(half_range, dtype=torch.float32))
        self.register_buffer("action_bias", torch.tensor(midpoint, dtype=torch.float32))

    def forward(self, x):
        """Return ``(mean, log_std)`` of the pre-squash Gaussian."""
        features = self.backbone(x)
        unit = torch.tanh(self.fc_logstd(features))  # in (-1, 1)
        # Affine map (-1, 1) -> (LOG_STD_MIN, LOG_STD_MAX).
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (unit + 1)
        return self.fc_mean(features), log_std

    def get_action(self, x):
        """Sample an action and return ``(action, log_prob, deterministic_action)``.

        ``log_prob`` has shape ``(batch, 1)``. It includes the change-of-variables
        correction for the tanh squash and the affine rescale.
        """
        mean, log_std = self(x)
        dist = torch.distributions.Normal(mean, log_std.exp())
        pre_tanh = dist.rsample()  # reparameterized so gradients reach the actor
        squashed = torch.tanh(pre_tanh)
        action = squashed * self.action_scale + self.action_bias

        log_det = torch.log(self.action_scale * (1 - squashed.pow(2)) + 1e-6)
        log_prob = (dist.log_prob(pre_tanh) - log_det).sum(1, keepdim=True)

        deterministic = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, deterministic


if __name__ == "__main__":
    # Multi-objective SAC: one twin soft-Q critic pair per objective, and an actor
    # maximizing a weighted sum of per-objective Q values. The weights are a
    # softmax over preference-scaled shortfalls p_i * (J^ub_i - Jhat_i), where
    # Jhat is an EMA of episodic vector returns.
    args = tyro.cli(Args)
    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    pref_vec = [float(x) for x in args.pref.split(",")]
    if len(pref_vec) != args.m:
        raise ValueError(f"--pref has {len(pref_vec)} entries, expected m={args.m}")
    p = torch.tensor(pref_vec, dtype=torch.float32, device=device)
    # Normalize like ScalarizeMOReward: only when the sum is positive
    # (a zero sum would otherwise turn p into NaN).
    if float(p.sum()) > 0:
        p = p / p.sum()
    jub_vec = torch.tensor([float(x) for x in args.jub.split(",")], dtype=torch.float32, device=device)
    if len(jub_vec) != args.m:
        raise ValueError(f"--jub has {len(jub_vec)} entries, expected m={args.m}")

    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )
    wandb_defined = False
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name, pref_vec) for i in range(args.num_envs)]
    )
    assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

    # EMA estimate of each objective's episodic return.
    Jhat = torch.zeros(args.m, dtype=torch.float32, device=device)
    ema_beta = args.ema_beta

    D = args.net_depth
    H = args.net_hidden

    actor = Actor(envs, hidden=H, depth=D).to(device)

    qf1_list = nn.ModuleList([SoftQNetwork(envs, depth=D, hidden_size=H).to(device) for _ in range(args.m)])
    qf2_list = nn.ModuleList([SoftQNetwork(envs, depth=D, hidden_size=H).to(device) for _ in range(args.m)])
    qf1_targ_list = nn.ModuleList([SoftQNetwork(envs, depth=D, hidden_size=H).to(device) for _ in range(args.m)])
    qf2_targ_list = nn.ModuleList([SoftQNetwork(envs, depth=D, hidden_size=H).to(device) for _ in range(args.m)])

    for i in range(args.m):
        qf1_targ_list[i].load_state_dict(qf1_list[i].state_dict())
        qf2_targ_list[i].load_state_dict(qf2_list[i].state_dict())

    q_optim = optim.Adam(
        list(qf1_list.parameters()) + list(qf2_list.parameters()),
        lr=args.q_lr,
    )
    actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr)

    if args.autotune:
        target_entropy = -torch.prod(torch.Tensor(envs.single_action_space.shape).to(device)).item()
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha = log_alpha.exp().item()
        a_optimizer = optim.Adam([log_alpha], lr=args.q_lr)
    else:
        alpha = args.alpha

    envs.single_observation_space.dtype = np.float32
    rb = MORLReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device,
        n_envs=args.num_envs,
        handle_timeout_termination=True,
        m=args.m,
    )

    start_time = time.time()

    obs, _ = envs.reset(seed=args.seed)
    ep_ret = np.zeros(args.num_envs, dtype=np.float64)
    ep_len = np.zeros(args.num_envs, dtype=np.int32)
    vec_ret = np.zeros((args.num_envs, args.m), dtype=np.float64)
    # True for sub-envs whose previous step ended an episode without the vector
    # env reporting a final observation (gymnasium>=1.0 "next-step" autoreset).
    # On the following step such an env only resets: reward 0, and
    # (obs -> next_obs) is not a real transition.
    autoreset = np.zeros(args.num_envs, dtype=bool)
    episode_idx = 0
    last_actor_loss = None
    last_alpha_loss = None
    last_Qpi_mean = None
    last_q_means = None
    last_per_obj_td = None

    for global_step in range(args.total_timesteps):
        if global_step < args.learning_starts:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            actions, _, _ = actor.get_action(torch.Tensor(obs).to(device))
            actions = actions.detach().cpu().numpy()

        next_obs, rewards, terminations, truncations, infos = envs.step(actions)
        dones = np.logical_or(terminations, truncations)
        live = ~autoreset  # sub-envs that took a real environment step this iteration

        reward_vecs = infos.get("reward_vec") if isinstance(infos, dict) else None
        if reward_vecs is not None:
            vals = [rv for rv in reward_vecs if rv is not None]
            if vals:
                arr = np.stack(vals)
                for j in range(min(args.m, arr.shape[1])):
                    writer.add_scalar(f"charts/vec_reward_step_mean/obj{j}",
                                      float(arr[:, j].mean()), global_step)

        writer.add_scalar("charts/reward_per_step_mean", float(np.mean(rewards)), global_step)

        # Episode accounting only for real steps (reset-only steps would add 1 to the length).
        ep_ret[live] += rewards[live]
        ep_len[live] += 1
        if reward_vecs is not None:
            for i in range(min(args.num_envs, len(reward_vecs))):
                rv = reward_vecs[i]
                if live[i] and rv is not None:
                    rv = np.asarray(rv, dtype=np.float64)
                    if rv.shape:
                        vec_ret[i] += rv

        for i in range(args.num_envs):
            if not dones[i]:
                continue
            episode_idx += 1

            writer.add_scalar("charts/episodic_return", float(ep_ret[i]), global_step)
            writer.add_scalar("charts/episodic_length", int(ep_len[i]), global_step)
            if args.m >= 2:
                # Named tags for mo-walker2d-v5's two objectives; guarded so m == 1 does not IndexError.
                writer.add_scalar("charts/vec_return/forward_reward", float(vec_ret[i, 0]), global_step)
                writer.add_scalar("charts/vec_return/control_cost", float(vec_ret[i, 1]), global_step)
            for j in range(args.m):
                writer.add_scalar(f"charts/vec_return/obj{j}", float(vec_ret[i, j]), global_step)

            if args.track:
                if not wandb_defined:
                    wandb.define_metric("by_ep/episode")
                    wandb.define_metric("by_ep/*", step_metric="by_ep/episode")
                    wandb_defined = True

                wandb.log({"by_ep/episode": episode_idx})
                wb_payload = {
                    "by_ep/episodic_return": float(ep_ret[i]),
                    "by_ep/episodic_length": int(ep_len[i]),
                }
                for j in range(args.m):
                    wb_payload[f"by_ep/vec_return/obj{j}"] = float(vec_ret[i, j])
                    wb_payload[f"by_ep/Jhat_{j}"] = float(Jhat[j].item())
                wandb.log(wb_payload)

            vret_t = torch.tensor(vec_ret[i], dtype=torch.float32, device=device)
            Jhat = (1 - ema_beta) * Jhat + ema_beta * vret_t

            ep_ret[i] = 0.0
            ep_len[i] = 0
            vec_ret[i] = 0.0

        # With same-step autoreset, next_obs of a finished env is already the reset
        # observation and the true last observation is reported in infos.
        real_next_obs = next_obs.copy()
        has_final = np.zeros(args.num_envs, dtype=bool)
        final_obs = None
        if isinstance(infos, dict):
            final_obs = infos.get("final_observation", infos.get("final_obs"))
        if final_obs is not None:
            for i, done in enumerate(dones):
                if done and final_obs[i] is not None:
                    real_next_obs[i] = final_obs[i]
                    has_final[i] = True

        # Store `terminations` (not terminations | truncations) as the done flag so
        # the critic keeps bootstrapping through time-limit truncations. Reset-only
        # steps are skipped. The buffer stores all sub-envs in one row, so with
        # num_envs > 1 the whole row is dropped when any sub-env is resetting.
        if live.all():
            infos_list = vector_infos_to_list(infos, args.num_envs)
            rb.add(obs, real_next_obs, actions, rewards, terminations, infos_list)
        obs = next_obs
        autoreset = dones & ~has_final

        if global_step > args.learning_starts:
            # === Critic update ===
            data = rb.sample(args.batch_size)

            with torch.no_grad():
                next_a, next_logp, _ = actor.get_action(data.next_observations)
                min_q_next_per_obj = []
                for i in range(args.m):
                    q1n = qf1_targ_list[i](data.next_observations, next_a)
                    q2n = qf2_targ_list[i](data.next_observations, next_a)
                    min_q_next_per_obj.append(torch.min(q1n, q2n))
                min_q_next = torch.cat(min_q_next_per_obj, dim=1)   # (batch, m)
                r_vec = data.reward_vecs                             # (batch, m)
                next_q_value_vec = r_vec + (1 - data.dones) * args.gamma * (min_q_next - alpha * next_logp)

            q_loss = torch.zeros((), device=device)
            mean_q1, mean_q2 = [], []
            per_obj_td = []

            for i in range(args.m):
                q1 = qf1_list[i](data.observations, data.actions).view(-1)
                q2 = qf2_list[i](data.observations, data.actions).view(-1)
                target_i = next_q_value_vec[:, i].view(-1)

                q1_loss = F.mse_loss(q1, target_i)
                q2_loss = F.mse_loss(q2, target_i)
                q_loss = q_loss + q1_loss + q2_loss

                mean_q1.append(q1.mean().item())
                mean_q2.append(q2.mean().item())
                per_obj_td.append(0.5 * (q1_loss.item() + q2_loss.item()))

            q_optim.zero_grad()
            q_loss.backward()
            q_optim.step()

            last_q_means = {'q1': mean_q1, 'q2': mean_q2}
            last_per_obj_td = per_obj_td

            # === Actor update ===
            if global_step % args.policy_frequency == 0:
                # Objective weights depend only on p, jub_vec and Jhat. None of these
                # carry gradients, and Jhat does not change during this update, so
                # they are computed (and logged) once.
                logits = torch.clamp((p * (jub_vec - Jhat)) / args.mu, -20.0, 20.0)
                lam = torch.softmax(logits, dim=0)
                w = lam * p
                for i in range(args.m):
                    writer.add_scalar(f"morl/lam_{i}", float(lam[i]), global_step)
                    writer.add_scalar(f"morl/w_{i}", float(w[i]), global_step)
                    writer.add_scalar(f"morl/Jhat_{i}", float(Jhat[i]), global_step)

                # Freeze critics so the actor loss does not accumulate gradients in them.
                critic_params = [p_ for net in list(qf1_list) + list(qf2_list) for p_ in net.parameters()]
                for p_ in critic_params:
                    p_.requires_grad_(False)

                for _ in range(args.policy_frequency):
                    pi, log_pi, _ = actor.get_action(data.observations)
                    q_pi_per_obj = []
                    for i in range(args.m):
                        q1_pi = qf1_list[i](data.observations, pi)
                        q2_pi = qf2_list[i](data.observations, pi)
                        q_pi_per_obj.append(torch.min(q1_pi, q2_pi))
                    Qpi = torch.cat(q_pi_per_obj, dim=1)  # (batch, m)

                    weighted_q = (Qpi * w.view(1, -1)).sum(dim=1, keepdim=True)
                    actor_loss = (alpha * log_pi - weighted_q).mean()

                    actor_optimizer.zero_grad(set_to_none=True)
                    actor_loss.backward()
                    actor_optimizer.step()

                    last_actor_loss = actor_loss.item()
                    last_Qpi_mean = Qpi.mean(dim=0).detach().cpu()

                    if args.autotune:
                        with torch.no_grad():
                            _, log_pi_new, _ = actor.get_action(data.observations)
                        alpha_loss = (-log_alpha.exp() * (log_pi_new + target_entropy)).mean()
                        a_optimizer.zero_grad(set_to_none=True)
                        alpha_loss.backward()
                        a_optimizer.step()
                        alpha = log_alpha.exp().item()
                        last_alpha_loss = alpha_loss.item()

                for p_ in critic_params:
                    p_.requires_grad_(True)

            if global_step % args.target_network_frequency == 0:
                for i in range(args.m):
                    for p_src, p_tgt in zip(qf1_list[i].parameters(), qf1_targ_list[i].parameters()):
                        p_tgt.data.copy_(args.tau * p_src.data + (1 - args.tau) * p_tgt.data)
                    for p_src, p_tgt in zip(qf2_list[i].parameters(), qf2_targ_list[i].parameters()):
                        p_tgt.data.copy_(args.tau * p_src.data + (1 - args.tau) * p_tgt.data)

            if global_step % 100 == 0:
                if last_q_means is not None:
                    for i in range(args.m):
                        writer.add_scalar(f"losses/q{i}/q1_mean", last_q_means['q1'][i], global_step)
                        writer.add_scalar(f"losses/q{i}/q2_mean", last_q_means['q2'][i], global_step)
                if last_per_obj_td is not None:
                    for i in range(args.m):
                        writer.add_scalar(f"losses/q{i}_td_mse", last_per_obj_td[i], global_step)
                writer.add_scalar("losses/q_total_loss", q_loss.item(), global_step)

                if last_actor_loss is not None:
                    writer.add_scalar("losses/actor_loss", last_actor_loss, global_step)
                if last_Qpi_mean is not None:
                    for i in range(args.m):
                        writer.add_scalar(f"actor/Qpi_mean_{i}", float(last_Qpi_mean[i]), global_step)

                writer.add_scalar("losses/alpha", alpha, global_step)
                if args.autotune and (last_alpha_loss is not None):
                    writer.add_scalar("losses/alpha_loss", last_alpha_loss, global_step)

                sps = int(global_step / (time.time() - start_time))
                print("SPS:", sps)
                writer.add_scalar("charts/SPS", sps, global_step)

    envs.close()
    writer.close()
