# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/sac/#sac_continuous_actionpy
import os
import random
import time
from dataclasses import dataclass

import robosuite as suite
from robosuite.wrappers import GymWrapper
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

try:
    import gymnasium.spaces as gymnasium_spaces
except ImportError:  # pragma: no cover
    gymnasium_spaces = None


def _ensure_gymnasium_box(space, dtype=np.float32):
    if gymnasium_spaces is None:
        return space
    if isinstance(space, gymnasium_spaces.Box):
        return space
    if hasattr(space, "low") and hasattr(space, "high"):
        low = np.asarray(space.low, dtype=dtype)
        high = np.asarray(space.high, dtype=dtype)
        return gymnasium_spaces.Box(low=low, high=high, dtype=dtype)
    return space

@dataclass
class Args:
    """Command-line configuration for the SAC + SPR training script.

    Each field is followed by a string literal describing it, matching the
    convention already used throughout this class.
    """

    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "cleanRL"
    """the wandb's project name"""
    wandb_entity: str = None
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""

    # Algorithm specific arguments
    env_id: str = "EggHandOver-v0"
    """the environment id of the task"""
    total_timesteps: int = 1000000
    """total timesteps of the experiments"""
    buffer_size: int = int(1e6)
    """the replay memory buffer size"""
    gamma: float = 0.99
    """the discount factor gamma"""
    tau: float = 0.005
    """target smoothing coefficient (default: 0.005)"""
    batch_size: int = 256
    """the batch size of sample from the reply memory"""
    # Written as an integer literal: `5e3` is a float and would contradict the
    # `int` annotation (the default would be 5000.0 while CLI values are ints).
    learning_starts: int = 5000
    """timestep to start learning"""
    policy_lr: float = 3e-4
    """the learning rate of the policy network optimizer"""
    q_lr: float = 1e-3
    """the learning rate of the Q network network optimizer"""
    policy_frequency: int = 2
    """the frequency of training policy (delayed)"""
    target_network_frequency: int = 1  # Denis Yarats' implementation delays this by 2.
    """the frequency of updates for the target nerworks"""
    alpha: float = 0.2
    """Entropy regularization coefficient."""
    autotune: bool = True
    """automatic tuning of the entropy coefficient"""

    # SPR specific arguments
    spr_lr: float = 1e-3
    """learning rate for SPR head"""
    spr_hidden_dim: int = 256
    """hidden dimension for SPR head"""
    spr_projection_dim: int = 128
    """projection dimension for SPR head"""
    spr_prediction_steps: int = 10
    """number of steps to predict into the future"""
    spr_momentum_tau: float = 0.01
    """momentum coefficient for target network updates"""
    spr_weight: float = 1.0
    """weight for SPR loss"""

    lle_learning_rate_trunk: float = 1e-3
    """trunk learning rate (not referenced elsewhere in this script)"""
    ssl_method: str = 'spr'
    """name of the self-supervised method; only used to build the run name in this script"""

def make_env(env_id, seed, idx, capture_video, run_name):
    """Build a zero-argument factory that creates one seeded robosuite environment.

    ``env_id`` is split on ``'-'`` into exactly two parts: the robosuite task
    name and the robot name. ``idx``, ``capture_video`` and ``run_name`` are
    accepted but not used by this function.
    """

    def thunk():
        # Unpacking raises ValueError unless env_id contains exactly one '-'.
        task_name, robot_name = env_id.split('-')
        robosuite_options = dict(
            env_name=task_name,  # try with other tasks like "Stack" and "Door"
            robots=robot_name,  # try with other robots like "Sawyer" and "Jaco"
            has_renderer=False,
            has_offscreen_renderer=False,
            use_object_obs=True,
            use_camera_obs=False,
            reward_shaping=True,
        )
        wrapped = gym.wrappers.RecordEpisodeStatistics(GymWrapper(suite.make(**robosuite_options)))
        wrapped.seed(seed)
        wrapped.action_space.seed(seed)
        return wrapped

    return thunk

# ALGO LOGIC: initialize agent here:
class SoftQNetwork(nn.Module):
    """Q-network over concatenated (observation, action) inputs with SPR heads.

    The trunk is two 256-unit ReLU layers followed by a scalar output layer.
    On top of the trunk sit an online SPR projection head, an SPR prediction
    head and a frozen copy of the projection head used as the target. Head
    sizes are read from the module-level ``args`` object.
    """

    def __init__(self, env):
        super().__init__()
        obs_dim = np.array(env.single_observation_space.shape).prod()
        act_dim = np.prod(env.single_action_space.shape)
        self.feature_dim = obs_dim + act_dim
        self.fc1 = nn.Linear(self.feature_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

        hidden = args.spr_hidden_dim
        proj_dim = args.spr_projection_dim

        def two_layer_mlp(in_features, out_features):
            # Linear -> ReLU -> Linear, shared shape of every SPR head.
            return nn.Sequential(
                nn.Linear(in_features, hidden),
                nn.ReLU(),
                nn.Linear(hidden, out_features),
            )

        # Online heads; the prediction head emits all K future steps at once.
        self.spr_projection = two_layer_mlp(256, proj_dim)
        self.spr_prediction = two_layer_mlp(proj_dim, args.spr_prediction_steps * proj_dim)

        # Target projection head: starts as a copy of the online one and is
        # excluded from gradient-based training.
        self.spr_projection_target = two_layer_mlp(256, proj_dim)
        self.spr_projection_target.load_state_dict(self.spr_projection.state_dict())
        self.spr_projection_target.requires_grad_(False)

    def forward(self, x, a):
        """Return the Q-value for observations ``x`` and actions ``a``."""
        return self.fc3(self.get_representation(x, a))

    def get_representation(self, x, a):
        """Return the activations of the second hidden layer."""
        hidden = torch.cat((x, a), dim=1)
        hidden = F.relu(self.fc1(hidden))
        return F.relu(self.fc2(hidden))

    def get_spr_projection(self, x, a):
        """Project the trunk representation with the online projection head."""
        return self.spr_projection(self.get_representation(x, a))

    def get_spr_target_projection(self, x, a):
        """Project the trunk representation with the frozen target head."""
        return self.spr_projection_target(self.get_representation(x, a))

    def get_spr_prediction(self, z):
        """Predict K future projections; output is (batch, K, projection_dim)."""
        flat = self.spr_prediction(z)
        return flat.view(flat.size(0), args.spr_prediction_steps, args.spr_projection_dim)

# Bounds the tanh-squashed log standard deviation is rescaled into.
LOG_STD_MAX = 2
LOG_STD_MIN = -5


class Actor(nn.Module):
    """Tanh-squashed Gaussian policy with SPR projection and prediction heads.

    The trunk is two 256-unit ReLU layers feeding separate mean and log-std
    layers. SPR head sizes are read from the module-level ``args`` object.
    """

    def __init__(self, env):
        super().__init__()
        act_space = env.single_action_space
        act_dim = np.prod(act_space.shape)
        self.feature_dim = np.array(env.single_observation_space.shape).prod()
        self.fc1 = nn.Linear(self.feature_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mean = nn.Linear(256, act_dim)
        self.fc_logstd = nn.Linear(256, np.prod(act_space.shape))

        # Affine map from the tanh range [-1, 1] onto the action bounds.
        half_range = (act_space.high - act_space.low) / 2.0
        midpoint = (act_space.high + act_space.low) / 2.0
        self.register_buffer("action_scale", torch.tensor(half_range, dtype=torch.float32))
        self.register_buffer("action_bias", torch.tensor(midpoint, dtype=torch.float32))

        hidden = args.spr_hidden_dim
        proj_dim = args.spr_projection_dim

        def two_layer_mlp(in_features, out_features):
            # Linear -> ReLU -> Linear, shared shape of every SPR head.
            return nn.Sequential(
                nn.Linear(in_features, hidden),
                nn.ReLU(),
                nn.Linear(hidden, out_features),
            )

        # Online heads; the prediction head emits all K future steps at once.
        self.spr_projection = two_layer_mlp(256, proj_dim)
        self.spr_prediction = two_layer_mlp(proj_dim, args.spr_prediction_steps * proj_dim)

        # Target projection head: starts as a copy of the online one and is
        # excluded from gradient-based training.
        self.spr_projection_target = two_layer_mlp(256, proj_dim)
        self.spr_projection_target.load_state_dict(self.spr_projection.state_dict())
        self.spr_projection_target.requires_grad_(False)

    def forward(self, x):
        """Return the Gaussian mean and the bounded log standard deviation."""
        features = self.get_representation(x)
        mean = self.fc_mean(features)
        squashed_log_std = torch.tanh(self.fc_logstd(features))
        # From SpinUp / Denis Yarats: rescale (-1, 1) into (LOG_STD_MIN, LOG_STD_MAX).
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (squashed_log_std + 1)
        return mean, log_std

    def get_representation(self, x):
        """Return the activations of the second hidden layer."""
        hidden = F.relu(self.fc1(x))
        return F.relu(self.fc2(hidden))

    def get_spr_projection(self, x):
        """Project the trunk representation with the online projection head."""
        return self.spr_projection(self.get_representation(x))

    def get_spr_target_projection(self, x):
        """Project the trunk representation with the frozen target head."""
        return self.spr_projection_target(self.get_representation(x))

    def get_spr_prediction(self, z):
        """Predict K future projections; output is (batch, K, projection_dim)."""
        flat = self.spr_prediction(z)
        return flat.view(flat.size(0), args.spr_prediction_steps, args.spr_projection_dim)

    def get_action(self, x):
        """Sample an action; return (action, summed log-probability, squashed mean action)."""
        mean, log_std = self(x)
        gaussian = torch.distributions.Normal(mean, log_std.exp())
        pre_tanh = gaussian.rsample()  # reparameterization trick (mean + std * N(0,1))
        squashed = torch.tanh(pre_tanh)
        action = squashed * self.action_scale + self.action_bias

        # Enforcing action bound: subtract the log-Jacobian of the tanh + scale map.
        log_jacobian = torch.log(self.action_scale * (1 - squashed.pow(2)) + 1e-6)
        log_prob = (gaussian.log_prob(pre_tanh) - log_jacobian).sum(1, keepdim=True)

        mean_action = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, mean_action

# Cosine similarity loss function (from SPR paper appendix)
def spr_loss_fn(pred, target):
    """Return the mean negative cosine similarity between ``pred`` and ``target``.

    Both inputs are L2-normalised along their last dimension, multiplied
    element-wise and summed over that dimension; the negated result is
    averaged over all remaining dimensions.
    """
    unit_pred, unit_target = (F.normalize(t, p=2, dim=-1) for t in (pred, target))
    cosine = torch.sum(unit_pred * unit_target, dim=-1)
    return (-cosine).mean()

if __name__ == "__main__":
    # Entry point: parse arguments, build the environment and networks, then
    # run SAC training with a periodic SPR auxiliary update on the SPR heads.
    import stable_baselines3 as sb3

    # Compare the major version numerically: a lexicographic string comparison
    # would treat e.g. "10.0" as older than "2.0".
    if int(sb3.__version__.split(".")[0]) < 2:
        raise ValueError(
            """Ongoing migration: run the following command to install the new dependencies:
poetry run pip install "stable_baselines3==2.0.0a1"
"""
        )

    args = tyro.cli(Args)
    run_name = f"{args.env_id}__{args.exp_name}__{args.ssl_method}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs-robosuite/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # Write all arguments to a file
    with open(f"runs-robosuite/{run_name}/args.txt", "w") as f:
        for key, value in vars(args).items():
            f.write(f"{key}: {value}\n")

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
    observation_space = _ensure_gymnasium_box(envs.single_observation_space, dtype=np.float32)
    action_space = _ensure_gymnasium_box(envs.single_action_space, dtype=np.float32)
    box_types = (gym.spaces.Box,)
    if gymnasium_spaces is not None:
        box_types = box_types + (gymnasium_spaces.Box,)
    assert isinstance(action_space, box_types), "only continuous action space is supported"

    actor = Actor(envs).to(device)
    qf1 = SoftQNetwork(envs).to(device)
    qf2 = SoftQNetwork(envs).to(device)
    qf1_target = SoftQNetwork(envs).to(device)
    qf2_target = SoftQNetwork(envs).to(device)
    qf1_target.load_state_dict(qf1.state_dict())
    qf2_target.load_state_dict(qf2.state_dict())

    # Initialize optimizers
    q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr)
    actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr)

    # Initialize SPR optimizers (they hold only the online projection and
    # prediction heads, not the fc1/fc2 trunk layers).
    spr_actor_optimizer = optim.Adam(
        list(actor.spr_projection.parameters()) +
        list(actor.spr_prediction.parameters()),
        lr=args.spr_lr
    )
    spr_qf1_optimizer = optim.Adam(
        list(qf1.spr_projection.parameters()) +
        list(qf1.spr_prediction.parameters()),
        lr=args.spr_lr
    )
    spr_qf2_optimizer = optim.Adam(
        list(qf2.spr_projection.parameters()) +
        list(qf2.spr_prediction.parameters()),
        lr=args.spr_lr
    )

    # Automatic entropy tuning
    if args.autotune:
        target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(device)).item()
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha = log_alpha.exp().item()
        a_optimizer = optim.Adam([log_alpha], lr=args.q_lr)
    else:
        alpha = args.alpha

    rb = ReplayBuffer(
        args.buffer_size,
        observation_space,
        action_space,
        device,
        handle_timeout_termination=False,
        n_envs=envs.num_envs # Needed for sequence sampling logic later
    )
    start_time = time.time()

    # Number of environment steps between two SPR auxiliary updates.
    spr_update_interval = 2048

    # TRY NOT TO MODIFY: start the game
    obs = envs.reset()

    for global_step in tqdm(range(args.total_timesteps)):
        # ALGO LOGIC: put action logic here
        if global_step < args.learning_starts:
            actions = np.array([action_space.sample() for _ in range(envs.num_envs)])
        else:
            actions, _, _ = actor.get_action(torch.Tensor(obs).to(device))
            actions = actions.detach().cpu().numpy()

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, terminations, infos = envs.step(actions)

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        if 1 in terminations:
            for info in infos:
                writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                break

        # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
        real_next_obs = next_obs.copy()
        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
        obs = next_obs

        # ALGO LOGIC: training.
        if global_step > args.learning_starts:
            data = rb.sample(args.batch_size)

            # --- SAC Loss Calculation (Q-function and Actor) ---
            with torch.no_grad():
                next_state_actions, next_state_log_pi, _ = actor.get_action(data.next_observations)
                qf1_next_target = qf1_target(data.next_observations, next_state_actions)
                qf2_next_target = qf2_target(data.next_observations, next_state_actions)
                min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
                next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1)

            qf1_a_values = qf1(data.observations, data.actions).view(-1)
            qf2_a_values = qf2(data.observations, data.actions).view(-1)
            qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
            qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
            qf_loss = qf1_loss + qf2_loss

            q_optimizer.zero_grad()
            qf_loss.backward()
            q_optimizer.step()

            if global_step % args.policy_frequency == 0:  # TD 3 Delayed update support
                for _ in range(args.policy_frequency):
                    pi, log_pi, _ = actor.get_action(data.observations)
                    qf1_pi = qf1(data.observations, pi)
                    qf2_pi = qf2(data.observations, pi)
                    min_qf_pi = torch.min(qf1_pi, qf2_pi)
                    actor_loss = ((alpha * log_pi) - min_qf_pi).mean()

                    actor_optimizer.zero_grad()
                    actor_loss.backward()
                    actor_optimizer.step()

                    if args.autotune:
                        with torch.no_grad():
                            _, log_pi, _ = actor.get_action(data.observations)
                        alpha_loss = (-log_alpha.exp() * (log_pi + target_entropy)).mean()

                        a_optimizer.zero_grad()
                        alpha_loss.backward()
                        a_optimizer.step()
                        alpha = log_alpha.exp().item()

            # --- SPR Loss Calculation ---
            # This check lives outside the policy-frequency branch on purpose:
            # (global_step + 1) % 2048 == 0 only holds for odd steps, so nesting
            # it under `global_step % policy_frequency == 0` with an even
            # policy_frequency (the default is 2) made the update unreachable.
            if (global_step + 1) % spr_update_interval == 0:
                K = args.spr_prediction_steps

                # Number of stored transitions and the buffer slot holding the
                # oldest one. Once the buffer is full the oldest entry sits at
                # the write head `rb.pos`.
                if rb.full:
                    num_stored, oldest = rb.buffer_size, rb.pos
                else:
                    num_stored, oldest = rb.pos, 0

                if num_stored > K:  # need K + 1 consecutive transitions
                    # Sample sequence starts in chronological order so that no
                    # sequence straddles the write head. Sequences may still
                    # cross episode boundaries; that approximation is kept.
                    seq_starts = np.random.randint(low=0, high=num_stored - K, size=args.batch_size)
                    seq_indices = (oldest + seq_starts[:, None] + np.arange(K + 1)[None, :]) % rb.buffer_size
                    # Pick one environment per sequence. Indexing the env axis
                    # drops it, giving (batch, K + 1, dim) arrays, so the
                    # networks receive 2-D (batch, dim) inputs.
                    env_indices = np.random.randint(low=0, high=rb.n_envs, size=(args.batch_size, 1))

                    obs_seq = torch.as_tensor(rb.observations[seq_indices, env_indices], dtype=torch.float32, device=device)
                    act_seq = torch.as_tensor(rb.actions[seq_indices, env_indices], dtype=torch.float32, device=device)

                    obs_t, act_t = obs_seq[:, 0], act_seq[:, 0]
                    # Fold the K future steps into the batch axis so every
                    # target projection is computed in a single forward pass.
                    future_obs = obs_seq[:, 1:].reshape(args.batch_size * K, -1)
                    future_act = act_seq[:, 1:].reshape(args.batch_size * K, -1)

                    with torch.no_grad():
                        # Target projections z_{t+1..t+K}, shaped (batch, K, projection_dim)
                        target_zs_actor = actor.get_spr_target_projection(future_obs).view(args.batch_size, K, -1)
                        target_zs_qf1 = qf1.get_spr_target_projection(future_obs, future_act).view(args.batch_size, K, -1)
                        target_zs_qf2 = qf2.get_spr_target_projection(future_obs, future_act).view(args.batch_size, K, -1)

                    # Online projections z_t and their K-step predictions
                    predicted_zs_actor = actor.get_spr_prediction(actor.get_spr_projection(obs_t))
                    predicted_zs_qf1 = qf1.get_spr_prediction(qf1.get_spr_projection(obs_t, act_t))
                    predicted_zs_qf2 = qf2.get_spr_prediction(qf2.get_spr_projection(obs_t, act_t))

                    # SPR cosine similarity losses
                    actor_spr_loss = spr_loss_fn(predicted_zs_actor, target_zs_actor)
                    qf1_spr_loss = spr_loss_fn(predicted_zs_qf1, target_zs_qf1)
                    qf2_spr_loss = spr_loss_fn(predicted_zs_qf2, target_zs_qf2)
                    spr_total_loss = args.spr_weight * (actor_spr_loss + qf1_spr_loss + qf2_spr_loss)

                    # NOTE(review): the SPR optimizers only hold the projection and
                    # prediction heads, so the gradients this loss produces for the
                    # fc1/fc2 trunk layers are never applied - confirm whether the
                    # trunk was meant to be trained by the SPR loss as well.
                    spr_actor_optimizer.zero_grad()
                    spr_qf1_optimizer.zero_grad()
                    spr_qf2_optimizer.zero_grad()
                    spr_total_loss.backward()
                    spr_actor_optimizer.step()
                    spr_qf1_optimizer.step()
                    spr_qf2_optimizer.step()

                    # Log SPR losses
                    writer.add_scalar("losses/actor_spr_loss", actor_spr_loss.item(), global_step)
                    writer.add_scalar("losses/qf1_spr_loss", qf1_spr_loss.item(), global_step)
                    writer.add_scalar("losses/qf2_spr_loss", qf2_spr_loss.item(), global_step)

            # update the target networks (SAC Q-networks)
            if global_step % args.target_network_frequency == 0:
                for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
                for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)

                # Update SPR target projection networks using momentum
                with torch.no_grad():
                    for param, target_param in zip(actor.spr_projection.parameters(), actor.spr_projection_target.parameters()):
                        target_param.data.mul_(1 - args.spr_momentum_tau)
                        target_param.data.add_((args.spr_momentum_tau) * param.data)
                    for param, target_param in zip(qf1.spr_projection.parameters(), qf1.spr_projection_target.parameters()):
                        target_param.data.mul_(1 - args.spr_momentum_tau)
                        target_param.data.add_((args.spr_momentum_tau) * param.data)
                    for param, target_param in zip(qf2.spr_projection.parameters(), qf2.spr_projection_target.parameters()):
                        target_param.data.mul_(1 - args.spr_momentum_tau)
                        target_param.data.add_((args.spr_momentum_tau) * param.data)

            if global_step % 100 == 0:
                writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
                writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
                writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
                writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
                writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
                writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step)
                writer.add_scalar("losses/alpha", alpha, global_step)
                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
                if args.autotune:
                    writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)

    envs.close()
    writer.close()
