# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/sac/#sac_continuous_actionpy
import csv
import os
import random
import time
from dataclasses import dataclass
from typing import Optional

import robosuite as suite
from robosuite.wrappers import GymWrapper
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
from tqdm import tqdm

try:
    import gymnasium.spaces as gymnasium_spaces
except ImportError:  # pragma: no cover
    gymnasium_spaces = None


def _ensure_gymnasium_box(space, dtype=np.float32):
    """Return ``space`` as a ``gymnasium.spaces.Box`` whenever that is possible.

    Returns ``space`` unchanged in three cases:
    - gymnasium is not importable;
    - ``space`` is already a gymnasium Box;
    - ``space`` has no ``low``/``high`` bounds.

    Otherwise a new gymnasium Box is built from ``space.low`` / ``space.high``, both
    cast to ``dtype``.
    """
    if gymnasium_spaces is None or isinstance(space, gymnasium_spaces.Box):
        return space
    if not (hasattr(space, "low") and hasattr(space, "high")):
        return space
    # Evaluated in the same order as before: low first, then high.
    bounds = {name: np.asarray(getattr(space, name), dtype=dtype) for name in ("low", "high")}
    return gymnasium_spaces.Box(dtype=dtype, **bounds)


@dataclass
class Args:
    """Command-line arguments for SAC + LLE training on robosuite (parsed with ``tyro.cli``)."""

    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "cleanRL"
    """the wandb's project name"""
    wandb_entity: Optional[str] = None
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (currently ignored by `make_env`)"""

    # Algorithm specific arguments
    env_id: str = "EggHandOver-v0"
    """the environment id of the task, written as `<robosuite env name>-<robot>` (split on '-')"""
    total_timesteps: int = 1000000
    """total timesteps of the experiments"""
    buffer_size: int = int(1e6)
    """the replay memory buffer size"""
    gamma: float = 0.99
    """the discount factor gamma"""
    tau: float = 0.005
    """target smoothing coefficient (default: 0.005)"""
    batch_size: int = 256
    """the batch size of sample from the replay memory"""
    learning_starts: int = 5000
    """timestep to start learning"""
    policy_lr: float = 3e-4
    """the learning rate of the policy network optimizer"""
    q_lr: float = 1e-3
    """the learning rate of the Q network optimizer"""
    policy_frequency: int = 2
    """the frequency of training policy (delayed)"""
    target_network_frequency: int = 1  # Denis Yarats' implementation delays this by 2.
    """the frequency of updates for the target networks"""
    alpha: float = 0.2
    """Entropy regularization coefficient."""
    autotune: bool = True
    """automatic tuning of the entropy coefficient"""

    # LLE specific arguments
    use_lcr: bool = False
    """if toggled, an Adam optimizer over all LLE parameters is created (it is never stepped in this script)"""
    local_window_size: int = 40
    """LLE local window size (not read by this script)"""
    lle_batch_size: int = 128
    """LLE batch size (not read by this script)"""
    lle_learning_rate_W: float = 1e-2
    """LLE learning rate for W (not read by this script)"""
    lle_learning_rate_Phi: float = 1e-5
    """learning rate of the optimizer created by `use_lcr`"""
    lle_gradient_steps_W: int = 1000
    """LLE gradient steps for W (not read by this script)"""
    lle_gradient_steps_Phi: int = 100
    """LLE gradient steps for Phi (not read by this script)"""
    lle_loss_reduction_threshold_W: float = 1e-6
    """LLE loss-reduction threshold for W (not read by this script)"""
    lle_loss_reduction_threshold_Phi: float = 1e-6
    """LLE loss-reduction threshold for Phi (not read by this script)"""
    lle_epochs: int = 1
    """number of LLE epochs (not read by this script)"""
    use_lle_projection: bool = False
    """LLE projection switch (not read by this script)"""
    train_trunk: bool = False
    """if toggled, build the LLE decoder and train the LLE trunk + decoder with a reconstruction loss"""
    next_state: bool = False
    """if toggled together with `train_trunk`, the LLE decoder takes the trunk features concatenated with the action"""
    lle_learning_rate_trunk: float = 1e-3
    """learning rate of the LLE trunk/decoder optimizer (used with `train_trunk`)"""


def make_env(env_id, seed, idx, capture_video, run_name):
    """Return a zero-argument factory that builds one seeded robosuite environment.

    ``env_id`` is split on ``'-'`` into ``<robosuite env name>`` and ``<robot>``.
    Exactly one ``'-'`` is expected; otherwise the unpacking raises ``ValueError``
    when the factory is called.

    The environment is configured as follows:
    - built with rendering off, object observations on, camera observations off
      and shaped rewards;
    - wrapped in robosuite's ``GymWrapper`` and gym's ``RecordEpisodeStatistics``;
    - seeded, together with its action space, with ``seed``.

    ``idx``, ``capture_video`` and ``run_name`` are accepted for signature
    compatibility but are not used.
    """

    def thunk():
        task_name, robot_name = env_id.split('-')
        robosuite_kwargs = dict(
            env_name=task_name,  # try with other tasks like "Stack" and "Door"
            robots=robot_name,  # try with other robots like "Sawyer" and "Jaco"
            has_renderer=False,
            has_offscreen_renderer=False,
            use_object_obs=True,
            use_camera_obs=False,
            reward_shaping=True,
        )
        wrapped = gym.wrappers.RecordEpisodeStatistics(GymWrapper(suite.make(**robosuite_kwargs)))
        wrapped.seed(seed)
        wrapped.action_space.seed(seed)
        return wrapped

    return thunk

class NonNegLinear(nn.Linear):
    """``nn.Linear`` whose weights are clamped at zero in the forward pass.

    The stored ``weight`` parameter itself is not modified; only the copy used in
    the matrix product is clamped to be non-negative. The bias registered by
    ``nn.Linear`` (present unless constructed with ``bias=False``) is applied as usual.
    """

    def forward(self, input):
        # Pass self.bias explicitly: the previous version dropped it, leaving a
        # registered parameter that never affected the output or received gradients.
        return F.linear(input, self.weight.clamp(min=0.0), self.bias)
    
class FeatureAttention(nn.Module):
    """Self-attention in which every scalar feature is treated as its own token.

    The two feature tensors are concatenated along dim 1 and given a trailing
    singleton axis. ``nn.MultiheadAttention`` with ``embed_dim=1`` then attends
    across the features. Because of the ``unsqueeze(-1)`` with
    ``batch_first=True``, the inputs are expected to be 2-D ``(batch, features)``.

    NOTE: with ``embed_dim=1``, ``nn.MultiheadAttention`` only accepts ``num_heads=1``.
    """

    def __init__(self, num_heads=1):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim=1, num_heads=num_heads, batch_first=True)

    def forward(self, reward_features, lle_features):
        """Return ``(attended features, attention weights)``.

        The attended features have shape ``(batch, n_reward + n_lle)``.
        """
        tokens = torch.cat((reward_features, lle_features), dim=1)
        tokens = tokens.unsqueeze(-1)  # one length-1 embedding per feature
        attended, weights = self.attention(tokens, tokens, tokens)
        return attended.squeeze(-1), weights

class LLE(nn.Module):
    """Observation encoder (``lle_trunk``) with an optional decoder head.

    The trunk maps a flat observation to ``hidden_dim`` (128) tanh features through
    ``Linear -> ReLU -> Linear -> Tanh``.

    Args:
        env: vector environment exposing ``single_observation_space`` and
            ``single_action_space``; their shapes size the layers.
        train_trunk: when true, an ``lle_decoder`` is also built. ``None`` (the
            default) falls back to the script-level ``args.train_trunk``, which is
            how the original constructor behaved.
        next_state: when a decoder is built, make its input ``[trunk features,
            action]`` instead of the trunk features alone. ``None`` falls back to
            ``args.next_state``.
    """

    def __init__(self, env, train_trunk=None, next_state=None):
        super().__init__()
        if train_trunk is None:
            train_trunk = args.train_trunk  # legacy: read the script's CLI args
        hidden_dim = 128
        # Size the layers from the env that was passed in (previously the global `envs` was read).
        obs_dim = int(np.prod(env.single_observation_space.shape))

        self.lle_trunk = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
        )

        # LLE Decoder (only needed when the trunk is trained by reconstruction)
        if train_trunk:
            if next_state is None:
                next_state = args.next_state  # legacy: read the script's CLI args
            if next_state:
                act_dim = int(np.prod(env.single_action_space.shape))
                self.lle_decoder = nn.Sequential(
                    nn.Linear(hidden_dim + act_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, obs_dim),
                )
            else:
                self.lle_decoder = nn.Sequential(
                    nn.Linear(hidden_dim, obs_dim)
                )

    def forward(self, x):
        """Encode flat observations ``x`` into trunk features."""
        return self.lle_trunk(x)

    def lle_decode(self, x):
        """Run the decoder on ``x``; raises AttributeError if no decoder was built."""
        if "lle_decoder" not in self._modules:
            raise AttributeError("LLE has no 'lle_decoder': it was built with train_trunk disabled")
        return self.lle_decoder(x)

# ALGO LOGIC: initialize agent here:
class SoftQNetwork(nn.Module):
    """Soft Q-function ``Q(s, a)`` whose hidden features are mixed with LLE features.

    Pipeline:
    - ``[obs, action]`` goes through a 256-unit ReLU layer (``fc1``) and a
      128-unit tanh layer (``fc2``);
    - those 128 features are combined with ``lle_features`` by a
      ``FeatureAttention`` block;
    - the final linear head ``fc3`` produces the Q-value.

    ``fc3`` expects 256 inputs, so ``lle_features`` must supply 128 features per
    sample. The attention weights of the latest forward pass are kept in
    ``self.qnet_attention_weights``.
    """

    def __init__(self, env):
        super().__init__()
        obs_dim = np.array(env.single_observation_space.shape).prod()
        act_dim = np.prod(env.single_action_space.shape)
        self.fc1 = nn.Linear(obs_dim + act_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(256, 1)  # 128 Q features + 128 LLE features after attention

        self.qnet_attention = FeatureAttention()
        self.qnet_attention_weights = None  # set by forward()

    def apply_attention(self, reward_features, lle_features, attention_model):
        """Run ``attention_model`` on both feature sets.

        Returns ``(combined features, attention weights)``.
        """
        return attention_model(reward_features, lle_features)

    def forward(self, x, a, lle_features):
        hidden = F.relu(self.fc1(torch.cat([x, a], 1)))
        q_features = torch.tanh(self.fc2(hidden))
        mixed, self.qnet_attention_weights = self.apply_attention(
            q_features, lle_features, self.qnet_attention
        )
        return self.fc3(mixed)

# Log-std bounds for the Gaussian policy: Actor.forward maps the tanh output of its
# log-std head from [-1, 1] linearly onto [LOG_STD_MIN, LOG_STD_MAX].
LOG_STD_MAX = 2
LOG_STD_MIN = -5

class Actor(nn.Module):
    """Tanh-squashed Gaussian policy whose hidden features are mixed with LLE features.

    Pipeline:
    - observations go through a 256-unit ReLU layer (``fc1``) and a 128-unit
      tanh layer (``fc2``);
    - those features are combined with ``lle_features`` by a ``FeatureAttention``
      block;
    - the result feeds the mean and log-std heads.

    Both heads expect 256 inputs, so ``lle_features`` must supply 128 features per
    sample. Actions are rescaled into the env's action bounds via the
    ``action_scale`` / ``action_bias`` buffers. The attention weights of the latest
    forward pass are kept in ``self.actor_attention_weights``.
    """

    def __init__(self, env):
        super().__init__()
        obs_dim = np.array(env.single_observation_space.shape).prod()
        act_dim = np.prod(env.single_action_space.shape)
        self.fc1 = nn.Linear(obs_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc_mean = nn.Linear(256, act_dim)
        self.fc_logstd = nn.Linear(256, act_dim)
        # action rescaling TODO: CHECK IF THIS WORKS FOR MULTIPLE ENVIRONMENTS
        high = env.single_action_space.high
        low = env.single_action_space.low
        self.register_buffer("action_scale", torch.tensor((high - low) / 2.0, dtype=torch.float32))
        self.register_buffer("action_bias", torch.tensor((high + low) / 2.0, dtype=torch.float32))
        self.actor_attention = FeatureAttention()
        self.actor_attention_weights = None  # set by forward()

    def apply_attention(self, reward_features, lle_features, attention_model):
        """Run ``attention_model`` on both feature sets.

        Returns ``(combined features, attention weights)``.
        """
        return attention_model(reward_features, lle_features)

    def forward(self, x, lle_features):
        """Return the unsquashed action mean and a log-std bounded to [LOG_STD_MIN, LOG_STD_MAX]."""
        hidden = F.relu(self.fc1(x))
        policy_features = torch.tanh(self.fc2(hidden))
        mixed, self.actor_attention_weights = self.apply_attention(
            policy_features, lle_features, self.actor_attention
        )
        mean = self.fc_mean(mixed)
        squashed = torch.tanh(self.fc_logstd(mixed))
        # Affine map of [-1, 1] onto [LOG_STD_MIN, LOG_STD_MAX] (From SpinUp / Denis Yarats)
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (squashed + 1)
        return mean, log_std

    def get_action(self, x, lle_features):
        """Sample a bounded action.

        Returns ``(action, log_prob of shape (batch, 1), deterministic action)``.
        """
        mean, log_std = self(x, lle_features)
        normal = torch.distributions.Normal(mean, log_std.exp())
        x_t = normal.rsample()  # reparameterised sample: mean + std * N(0, 1)
        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias
        # Enforcing Action Bound: change-of-variables correction for tanh + rescaling
        log_prob = normal.log_prob(x_t) - torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6)
        log_prob = log_prob.sum(1, keepdim=True)
        deterministic_action = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, deterministic_action


if __name__ == "__main__":
    import stable_baselines3 as sb3

    if sb3.__version__ < "2.0":
        raise ValueError(
            """Ongoing migration: run the following command to install the new dependencies:
poetry run pip install "stable_baselines3==2.0.0a1"
"""
        )

    def save_attention_map(attention_weights, path, step):
        """Write ``attention_weights`` to ``path`` as a heat-map image.

        ``attention_weights`` is what an ``Actor``/``SoftQNetwork`` stored on its last
        forward pass. It is ``None`` until that network has run once, in which case
        nothing is written. A 3-D tensor is averaged over its first (batch) axis so
        the image is always 2-D.
        """
        if attention_weights is None:
            return
        attention_map = attention_weights.detach().cpu().numpy()
        if attention_map.ndim == 3:
            attention_map = attention_map.mean(axis=0)  # take average across batch
        fig, ax = plt.subplots()
        image = ax.imshow(attention_map, cmap="viridis", origin="lower")
        fig.colorbar(image, ax=ax)
        ax.set_title(f"Attention Map at Step {step}")
        fig.savefig(path)
        plt.close(fig)

    args = tyro.cli(Args)
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    run_dir = f"runs-robosuite/{run_name}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(run_dir)
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )
    os.makedirs(run_dir, exist_ok=True)  # harmless if the directory already exists
    # Write all arguments to a file
    with open(f"{run_dir}/args.txt", "w") as f:
        for key, value in vars(args).items():
            f.write(f"{key}: {value}\n")

    # CSV logs stay open for the whole run and are closed after training.
    recon_csv_file = open(f"{run_dir}/log_loss_recon.csv", mode="a", newline="\n")
    lle_loss_w_csv_file = open(f"{run_dir}/lle_loss_log_W.csv", mode="a", newline="\n")
    lle_loss_phi_csv_file = open(f"{run_dir}/lle_loss_log_Phi.csv", mode="a", newline="\n")

    recon_csv_writer = csv.writer(recon_csv_file)
    lle_loss_w_csv_writer = csv.writer(lle_loss_w_csv_file)  # NOTE(review): never written to in this script
    lle_loss_phi_csv_writer = csv.writer(lle_loss_phi_csv_file)  # NOTE(review): never written to in this script

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
    # env setup
    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
    observation_space = _ensure_gymnasium_box(envs.single_observation_space, dtype=np.float32)
    action_space = _ensure_gymnasium_box(envs.single_action_space, dtype=np.float32)
    box_types = (gym.spaces.Box,)
    if gymnasium_spaces is not None:
        box_types = box_types + (gymnasium_spaces.Box,)
    assert isinstance(action_space, box_types), "only continuous action space is supported"
    # The converted space can be a fresh object whose RNG was never seeded; seed it so the
    # random warm-up actions below are reproducible.
    action_space.seed(args.seed)

    actor = Actor(envs).to(device)
    lle_model = LLE(envs).to(device)
    qf1 = SoftQNetwork(envs).to(device)
    qf2 = SoftQNetwork(envs).to(device)
    qf1_target = SoftQNetwork(envs).to(device)
    qf2_target = SoftQNetwork(envs).to(device)
    qf1_target.load_state_dict(qf1.state_dict())
    qf2_target.load_state_dict(qf2.state_dict())
    q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr)
    actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr)

    if args.train_trunk:
        lle_trunk_optimizer = optim.Adam(
            list(lle_model.lle_trunk.parameters()) + list(lle_model.lle_decoder.parameters()),
            lr=args.lle_learning_rate_trunk,
        )
    if args.use_lcr:
        # NOTE(review): created but never stepped anywhere in this script.
        lcr_optimizer = optim.Adam(lle_model.parameters(), lr=args.lle_learning_rate_Phi)

    # Automatic entropy tuning
    if args.autotune:
        target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(device)).item()
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha = log_alpha.exp().item()
        a_optimizer = optim.Adam([log_alpha], lr=args.q_lr)
    else:
        alpha = args.alpha

    rb = ReplayBuffer(
        args.buffer_size,
        observation_space,
        action_space,
        device,
        handle_timeout_termination=False,
    )
    start_time = time.time()

    # Attention maps are dumped at ~50%, 75% and 98% of training (range() never reaches 100%).
    attention_log_steps = {round(args.total_timesteps * frac) for frac in (0.5, 0.75, 0.98)}
    # Bound up front so the periodic logging never touches an unbound name.
    actor_loss = None
    alpha_loss = None

    # TRY NOT TO MODIFY: start the game
    obs = envs.reset()
    for global_step in tqdm(range(args.total_timesteps)):
        # ALGO LOGIC: put action logic here
        if global_step < args.learning_starts:
            actions = np.array([action_space.sample() for _ in range(envs.num_envs)])
        else:
            with torch.no_grad():  # acting only: no autograd graph needed
                obs_tensor = torch.Tensor(obs).to(device)
                actions, _, _ = actor.get_action(obs_tensor, lle_model(obs_tensor))
            actions = actions.cpu().numpy()

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, terminations, infos = envs.step(actions)

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        if np.any(terminations):
            for info in infos:
                if "episode" in info:
                    writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                    writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                    break

        # TRY NOT TO MODIFY: save data to replay buffer
        # NOTE(review): if the vector env auto-resets, next_obs of a finished episode is the reset
        # observation, not the terminal one; confirm for the installed gym version.
        real_next_obs = next_obs.copy()
        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
        obs = next_obs

        if global_step in attention_log_steps:
            save_attention_map(
                actor.actor_attention_weights, f"{run_dir}/actor_attention_map_{global_step}.png", global_step
            )
            save_attention_map(
                qf2.qnet_attention_weights, f"{run_dir}/qf2_attention_map_{global_step}.png", global_step
            )
            save_attention_map(
                qf1.qnet_attention_weights, f"{run_dir}/qf1_attention_map_{global_step}.png", global_step
            )

        # ALGO LOGIC: training.
        if global_step > args.learning_starts:
            data = rb.sample(args.batch_size)
            with torch.no_grad():
                lle_next_features = lle_model(data.next_observations)
                next_state_actions, next_state_log_pi, _ = actor.get_action(data.next_observations, lle_next_features)
                qf1_next_target = qf1_target(data.next_observations, next_state_actions, lle_next_features)
                qf2_next_target = qf2_target(data.next_observations, next_state_actions, lle_next_features)
                min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
                next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1)
                # No optimizer steps the LLE trunk from the critic/actor losses (only the
                # reconstruction loss below does), so no graph is built through it here.
                lle_features = lle_model(data.observations)

            qf1_a_values = qf1(data.observations, data.actions, lle_features).view(-1)
            qf2_a_values = qf2(data.observations, data.actions, lle_features).view(-1)
            qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
            qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
            qf_loss = qf1_loss + qf2_loss

            # optimize the model
            q_optimizer.zero_grad()
            qf_loss.backward()
            q_optimizer.step()

            if global_step % args.policy_frequency == 0:  # TD 3 Delayed update support
                # compensate for the delay by doing `policy_frequency` updates instead of 1
                for _ in range(args.policy_frequency):
                    pi, log_pi, _ = actor.get_action(data.observations, lle_features)
                    qf1_pi = qf1(data.observations, pi, lle_features)
                    qf2_pi = qf2(data.observations, pi, lle_features)
                    min_qf_pi = torch.min(qf1_pi, qf2_pi)
                    actor_loss = ((alpha * log_pi) - min_qf_pi).mean()

                    actor_optimizer.zero_grad()
                    actor_loss.backward()
                    actor_optimizer.step()

                    if args.autotune:
                        with torch.no_grad():
                            _, log_pi, _ = actor.get_action(data.observations, lle_features)
                        alpha_loss = (-log_alpha.exp() * (log_pi + target_entropy)).mean()

                        a_optimizer.zero_grad()
                        alpha_loss.backward()
                        a_optimizer.step()
                        alpha = log_alpha.exp().item()

                # The decoder and its optimizer only exist with --train-trunk.
                if args.train_trunk:
                    actual_phi = lle_model.lle_trunk(data.observations)
                    if args.next_state:
                        # The decoder was built to take [phi(s), a]; regress onto the next observation.
                        recon_input = torch.cat([actual_phi, data.actions], 1)
                        recon_target = data.next_observations
                    else:
                        recon_input = actual_phi
                        recon_target = data.observations
                    recon_state = lle_model.lle_decoder(recon_input)
                    # Smooth-L1 (Huber) reconstruction loss
                    recon_loss = F.smooth_l1_loss(recon_state, recon_target)
                    # Optimize lle_trunk (and decoder) using the recon_loss
                    lle_trunk_optimizer.zero_grad()
                    recon_loss.backward()
                    lle_trunk_optimizer.step()
                    recon_csv_writer.writerow([global_step, recon_loss.item()])
                    recon_csv_file.flush()  # Ensure data is written to disk

            if global_step % 100 == 0:
                writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
                writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
                writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
                writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
                writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
                if actor_loss is not None:
                    writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step)
                writer.add_scalar("losses/alpha", alpha, global_step)
                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
                if args.autotune and alpha_loss is not None:
                    writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)

            # update the target networks
            if global_step % args.target_network_frequency == 0:
                for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
                for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)

    # Save models
    torch.save(actor.state_dict(), f"{run_dir}/actor.pth")
    torch.save(qf1.state_dict(), f"{run_dir}/qf1.pth")
    torch.save(qf2.state_dict(), f"{run_dir}/qf2.pth")
    torch.save(lle_model.state_dict(), f"{run_dir}/lle_model.pth")

    envs.close()
    writer.close()
    for csv_file in (recon_csv_file, lle_loss_w_csv_file, lle_loss_phi_csv_file):
        csv_file.close()
