# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/sac/#sac_continuous_actionpy
import os
import random
import time
from dataclasses import dataclass
import csv

import robosuite as suite
from robosuite.wrappers import GymWrapper
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
from tqdm import tqdm

try:
    import gymnasium.spaces as gymnasium_spaces
except ImportError:  # pragma: no cover
    gymnasium_spaces = None


def _ensure_gymnasium_box(space, dtype=np.float32):
    """Return `space` as a gymnasium `Box` whenever that is possible.

    The input is handed back untouched when gymnasium could not be imported,
    when it already is a gymnasium `Box`, or when it exposes no `low`/`high`
    bounds. Otherwise a new gymnasium `Box` is built from the bounds, cast to
    `dtype`.
    """
    gymnasium_missing = gymnasium_spaces is None
    if gymnasium_missing or isinstance(space, gymnasium_spaces.Box):
        return space

    has_bounds = hasattr(space, "low") and hasattr(space, "high")
    if not has_bounds:
        return space

    # Copy both bounds into arrays of the requested dtype before rebuilding.
    bounds = {name: np.asarray(getattr(space, name), dtype=dtype) for name in ("low", "high")}
    return gymnasium_spaces.Box(dtype=dtype, **bounds)


@dataclass
class Args:
    """Command-line configuration for the SAC + LLE training script."""

    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "cleanRL"
    """the wandb's project name"""
    wandb_entity: str = None
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""

    # Algorithm specific arguments
    env_id: str = "EggHandOver-v0"
    """the environment id of the task"""
    total_timesteps: int = 1000000
    """total timesteps of the experiments"""
    buffer_size: int = int(1e6)
    """the replay memory buffer size"""
    gamma: float = 0.99
    """the discount factor gamma"""
    tau: float = 0.005
    """target smoothing coefficient (default: 0.005)"""
    batch_size: int = 256
    """the batch size of sample from the replay memory"""
    # Was the float literal 5e3, which contradicted the `int` annotation.
    learning_starts: int = int(5e3)
    """timestep to start learning"""
    policy_lr: float = 3e-4
    """the learning rate of the policy network optimizer"""
    q_lr: float = 1e-3
    """the learning rate of the Q network network optimizer"""
    policy_frequency: int = 2
    """the frequency of training policy (delayed)"""
    target_network_frequency: int = 1  # Denis Yarats' implementation delays this by 2.
    """the frequency of updates for the target networks"""
    alpha: float = 0.2
    """Entropy regularization coefficient."""
    autotune: bool = True
    """automatic tuning of the entropy coefficient"""

    # LLE specific arguments
    use_lcr: bool = False
    """if toggled, the LCR branch is run instead of the two-stage LLE branch"""
    local_window_size: int = 40
    """number of neighbouring replay-buffer entries per sample (also the input width of the LLE head)"""
    lle_batch_size: int = 128
    """number of centre samples per LLE fit; LLE training runs when the step is a multiple of this value"""
    lle_learning_rate_W: float = 1e-2
    """the learning rate of the optimizer for the LLE head weights (W)"""
    lle_learning_rate_Phi: float = 1e-5
    """the learning rate of the optimizer used by the LCR branch"""
    lle_gradient_steps_W: int = 1000
    """maximum number of gradient steps when fitting the LLE head weights (W)"""
    lle_gradient_steps_Phi: int = 100
    """maximum number of gradient steps when updating a network trunk (Phi)"""
    lle_loss_reduction_threshold_W: float = 1e-6
    """loss change below which a W step counts towards early stopping"""
    lle_loss_reduction_threshold_Phi: float = 1e-6
    """loss change below which a Phi step counts towards early stopping"""
    lle_epochs: int = 1
    """number of LLE rounds (fresh sample of centres each) per network and trigger"""
    use_lle_projection: bool = False
    """accepted for compatibility; it does not alter behaviour in the visible script"""
    train_trunk: bool = False
    """not read anywhere in the visible script"""
    next_state: bool = False
    """not read anywhere in the visible script"""
    lle_learning_rate_trunk: float = 1e-3
    """the learning rate of the trunk optimizers (fc1, fc2 and decoder of actor and critics)"""


def make_env(env_id, seed, idx, capture_video, run_name):
    """Return a zero-argument factory that builds one seeded robosuite env.

    `env_id` must have the form "<task>-<robot>"; the two halves are passed to
    robosuite as the environment name and the robot. `idx`, `capture_video`
    and `run_name` are accepted but not used by the factory.
    """

    def thunk():
        # "<task>-<robot>" -> robosuite environment name and robot model.
        task_name, robot_name = env_id.split('-')
        robosuite_env = suite.make(
            env_name=task_name,
            robots=robot_name,
            has_renderer=False,
            has_offscreen_renderer=False,
            use_object_obs=True,
            use_camera_obs=False,
            reward_shaping=True,
        )
        # Expose the gym interface, then track per-episode return and length.
        wrapped = gym.wrappers.RecordEpisodeStatistics(GymWrapper(robosuite_env))
        wrapped.seed(seed)
        wrapped.action_space.seed(seed)
        return wrapped

    return thunk

class NonNegLinear(nn.Linear):
    """Linear layer that clips its weights to be non-negative on every call.

    Only the weight matrix takes part in the forward pass; the bias that
    `nn.Linear` may have created is not applied.
    """

    def forward(self, input):
        # Negative entries are replaced by zero; the stored parameter is untouched.
        non_negative_weight = torch.clamp(self.weight, min=0.)
        return F.linear(input, non_negative_weight)
    
class FeatureAttention(nn.Module):
    """Self-attention across the entries of two concatenated feature vectors.

    Each scalar feature is treated as one token with an embedding size of 1.
    """

    def __init__(self, num_heads=1):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim=1, num_heads=num_heads, batch_first=True)

    def forward(self, reward_features, lle_features):
        """Attend over both feature sets; return (attended features, attention weights)."""
        # Join along the feature axis, then add a trailing embedding axis of size 1.
        tokens = torch.cat([reward_features, lle_features], dim=1)[..., None]
        attended, weights = self.attention(query=tokens, key=tokens, value=tokens)
        return attended.squeeze(-1), weights

class LLE(nn.Module):
    """Bias-free linear head that mixes `args.local_window_size` neighbours into one value.

    The window size is read from the module-level `args`; the `env` argument
    is accepted but not used.
    """

    def __init__(self, env):
        super().__init__()
        window_to_scalar = nn.Linear(args.local_window_size, 1, bias=False)
        self.lle_head = nn.Sequential(window_to_scalar)

    def lle_head_forward(self, x):
        """Apply the head along the last axis of `x` (which must have at least 3 dims)."""
        # Axes 1 and 2 are swapped and then swapped back, so the head sees the
        # same layout that was passed in.
        swapped = x.transpose(1, 2)
        restored = swapped.transpose(1, 2)
        return self.lle_head(restored)

# ALGO LOGIC: initialize agent here:
class SoftQNetwork(nn.Module):
    """Q-network over concatenated (observation, action) inputs with a state decoder.

    NOTE(review): `forward` squashes the second hidden layer with tanh while
    `get_representation` uses ReLU, so the representation returned here is not
    the feature vector that feeds the Q-value head - confirm this is intended.
    """

    def __init__(self, env):
        super().__init__()
        obs_dim = np.array(env.single_observation_space.shape).prod()
        act_dim = np.prod(env.single_action_space.shape)
        self.fc1 = nn.Linear(obs_dim + act_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

        # The attention module and its weight slot are created here but are not
        # used by any method of this class.
        self.qnet_attention = FeatureAttention()
        self.qnet_attention_weights = None
        # Maps a 256-d representation back to a flattened observation.
        self.decoder = nn.Linear(256, obs_dim)

    def forward(self, x, a):
        """Return the Q-value for observations `x` and actions `a` (joined on dim 1)."""
        state_action = torch.cat([x, a], 1)
        hidden = F.relu(self.fc1(state_action))
        reward_features = torch.tanh(self.fc2(hidden))
        return self.fc3(reward_features)

    def get_representation(self, x, a):
        """Return the ReLU-activated output of the second layer (inputs joined on the last dim)."""
        state_action = torch.cat([x, a], -1)
        return F.relu(self.fc2(F.relu(self.fc1(state_action))))

    def reconstruction(self, x):
        """Decode a representation `x` into observation space."""
        return self.decoder(x)

# Bounds that the policy's log standard deviation is rescaled into.
LOG_STD_MAX = 2
LOG_STD_MIN = -5


class Actor(nn.Module):
    """Tanh-squashed Gaussian policy with an additional state decoder.

    NOTE(review): `forward` squashes the second hidden layer with tanh while
    `get_representation` uses ReLU, so the representation returned here is not
    the feature vector that feeds the policy heads - confirm this is intended.
    """

    def __init__(self, env):
        super().__init__()
        obs_dim = np.array(env.single_observation_space.shape).prod()
        act_space = env.single_action_space
        self.fc1 = nn.Linear(obs_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mean = nn.Linear(256, np.prod(act_space.shape))
        self.fc_logstd = nn.Linear(256, np.prod(act_space.shape))
        # action rescaling TODO: CHECK IF THIS WORKS FOR MULTIPLE ENVIRONMENTS
        half_range = (act_space.high - act_space.low) / 2.0
        midpoint = (act_space.high + act_space.low) / 2.0
        self.register_buffer("action_scale", torch.tensor(half_range, dtype=torch.float32))
        self.register_buffer("action_bias", torch.tensor(midpoint, dtype=torch.float32))
        # The attention module and its weight slot are created here but are not
        # used by any method of this class.
        self.actor_attention = FeatureAttention()
        self.actor_attention_weights = None
        # Maps a 256-d representation back to a flattened observation.
        self.decoder = nn.Linear(256, obs_dim)

    def forward(self, x):
        """Return the Gaussian mean and the bounded log standard deviation for `x`."""
        hidden = F.relu(self.fc1(x))
        reward_features = torch.tanh(self.fc2(hidden))
        mean = self.fc_mean(reward_features)
        squashed = torch.tanh(self.fc_logstd(reward_features))
        # From SpinUp / Denis Yarats: map (-1, 1) onto (LOG_STD_MIN, LOG_STD_MAX).
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (squashed + 1)
        return mean, log_std

    def get_representation(self, x):
        """Return the ReLU-activated output of the second layer."""
        return F.relu(self.fc2(F.relu(self.fc1(x))))

    def reconstruction(self, x):
        """Decode a representation `x` into observation space."""
        return self.decoder(x)

    def get_action(self, x):
        """Sample an action; return (action, summed log-probability, deterministic action)."""
        mean, log_std = self(x)
        normal = torch.distributions.Normal(mean, log_std.exp())
        # Reparameterised sample (mean + std * N(0, 1)), squashed into (-1, 1).
        x_t = normal.rsample()
        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias
        # Enforcing Action Bound: correct the density for the tanh and rescaling.
        bound_correction = torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6)
        log_prob = (normal.log_prob(x_t) - bound_correction).sum(1, keepdim=True)
        deterministic_action = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, deterministic_action


if __name__ == "__main__":
    # Script entry point: trains SAC on a robosuite task and, every
    # `lle_batch_size` environment steps, runs the auxiliary LLE/LCR fitting on
    # the actor and critic trunks. TensorBoard events, args.txt, CSV loss logs
    # and the final checkpoints are written under runs-robosuite/<run_name>/.
    import stable_baselines3 as sb3

    # Compare the major version as an integer: comparing the raw strings is
    # lexicographic and would, for example, rank "10.0" below "2.0".
    if int(sb3.__version__.split(".")[0]) < 2:
        raise ValueError(
            """Ongoing migration: run the following command to install the new dependencies:
poetry run pip install "stable_baselines3==2.0.0a1"
"""
        )

    args = tyro.cli(Args)
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs-robosuite/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )
    # Write all arguments to a file
    with open(f"runs-robosuite/{run_name}/args.txt", "w") as f:
        for key, value in vars(args).items():
            f.write(f"{key}: {value}\n")

    # Open CSV files once at the start; they are closed after the models are saved.
    recon_csv_file = open(f"runs-robosuite/{run_name}/log_loss_recon.csv", mode='a', newline='\n')
    lle_loss_w_csv_file = open(f"runs-robosuite/{run_name}/lle_loss_log_W.csv", mode='a', newline='\n')
    lle_loss_phi_csv_file = open(f"runs-robosuite/{run_name}/lle_loss_log_Phi.csv", mode='a', newline='\n')

    recon_csv_writer = csv.writer(recon_csv_file)
    lle_loss_w_csv_writer = csv.writer(lle_loss_w_csv_file)
    lle_loss_phi_csv_writer = csv.writer(lle_loss_phi_csv_file)

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
    # env setup
    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
    observation_space = _ensure_gymnasium_box(envs.single_observation_space, dtype=np.float32)
    action_space = _ensure_gymnasium_box(envs.single_action_space, dtype=np.float32)
    box_types = (gym.spaces.Box,)
    if gymnasium_spaces is not None:
        box_types = box_types + (gymnasium_spaces.Box,)
    assert isinstance(action_space, box_types), "only continuous action space is supported"
    # This (possibly freshly built) space is the one sampled during warm-up, so
    # it needs its own seed for the random actions to be reproducible.
    action_space.seed(args.seed)

    max_action = float(action_space.high[0])

    actor = Actor(envs).to(device)
    lle_model = LLE(envs).to(device)
    qf1 = SoftQNetwork(envs).to(device)
    qf2 = SoftQNetwork(envs).to(device)
    qf1_target = SoftQNetwork(envs).to(device)
    qf2_target = SoftQNetwork(envs).to(device)
    qf1_target.load_state_dict(qf1.state_dict())
    qf2_target.load_state_dict(qf2.state_dict())
    q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr)
    actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr)

    def _representation(network, states, actions):
        # The actor encodes states only; the critics encode (state, action) pairs.
        if network is actor:
            return network.get_representation(states)
        return network.get_representation(states, actions)

    # Optimizer for the LLE head weights (W). `use_lle_projection` used to pick
    # between two identical constructions, so a single one is kept.
    lle_optimizer_W = optim.Adam(lle_model.lle_head.parameters(), lr=args.lle_learning_rate_W)
    # Trunk optimizers: first two layers plus the decoder of each network.
    actor_trunk_optimizer = optim.Adam(
        list(actor.fc1.parameters()) +
        list(actor.fc2.parameters()) +
        list(actor.decoder.parameters()),
        lr=args.lle_learning_rate_trunk
    )
    qf1_trunk_optimizer = optim.Adam(
        list(qf1.fc1.parameters()) +
        list(qf1.fc2.parameters()) +
        list(qf1.decoder.parameters()),
        lr=args.lle_learning_rate_trunk
    )
    qf2_trunk_optimizer = optim.Adam(
        list(qf2.fc1.parameters()) +
        list(qf2.fc2.parameters()) +
        list(qf2.decoder.parameters()),
        lr=args.lle_learning_rate_trunk
    )

    if args.use_lcr:
        lcr_optimizer = optim.Adam(lle_model.parameters(), lr=args.lle_learning_rate_Phi)

    # Automatic entropy tuning
    if args.autotune:
        target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(device)).item()
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha = log_alpha.exp().item()
        a_optimizer = optim.Adam([log_alpha], lr=args.q_lr)
    else:
        alpha = args.alpha

    rb = ReplayBuffer(
        args.buffer_size,
        observation_space,
        action_space,
        device,
        handle_timeout_termination=False,
    )
    start_time = time.time()

    # TRY NOT TO MODIFY: start the game
    obs = envs.reset()
    for global_step in tqdm(range(args.total_timesteps)):
        # ALGO LOGIC: put action logic here
        if global_step < args.learning_starts:
            actions = np.array([action_space.sample() for _ in range(envs.num_envs)])
        else:
            actions, _, _ = actor.get_action(torch.Tensor(obs).to(device))
            actions = actions.detach().cpu().numpy()

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, terminations, infos = envs.step(actions)

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        if 1 in terminations:
            for info in infos:
                writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                break

        # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
        real_next_obs = next_obs.copy()
        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
        obs = next_obs

        # ALGO LOGIC: training.
        if global_step > args.learning_starts:
            data = rb.sample(args.batch_size)
            with torch.no_grad():
                next_state_actions, next_state_log_pi, _ = actor.get_action(data.next_observations)
                qf1_next_target = qf1_target(data.next_observations, next_state_actions)
                qf2_next_target = qf2_target(data.next_observations, next_state_actions)
                min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
                next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1)

            qf1_a_values = qf1(data.observations, data.actions).view(-1)
            qf2_a_values = qf2(data.observations, data.actions).view(-1)
            qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
            qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
            qf_loss = qf1_loss + qf2_loss

            # optimize the model
            q_optimizer.zero_grad()
            qf_loss.backward()
            q_optimizer.step()

            if global_step % args.policy_frequency == 0:  # TD 3 Delayed update support
                for _ in range(
                    args.policy_frequency
                ):  # compensate for the delay by doing 'actor_update_interval' instead of 1
                    pi, log_pi, _ = actor.get_action(data.observations)
                    qf1_pi = qf1(data.observations, pi)
                    qf2_pi = qf2(data.observations, pi)
                    min_qf_pi = torch.min(qf1_pi, qf2_pi)
                    actor_loss = ((alpha * log_pi) - min_qf_pi).mean()

                    actor_optimizer.zero_grad()
                    actor_loss.backward()
                    actor_optimizer.step()

                    if args.autotune:
                        with torch.no_grad():
                            _, log_pi, _ = actor.get_action(data.observations)
                        alpha_loss = (-log_alpha.exp() * (log_pi + target_entropy)).mean()

                        a_optimizer.zero_grad()
                        alpha_loss.backward()
                        a_optimizer.step()
                        alpha = log_alpha.exp().item()

                # Actor reconstruction
                actual_state = data.observations
                actual_phi = actor.get_representation(actual_state)
                recon_state = actor.reconstruction(actual_phi)
                recon_loss = F.mse_loss(actual_state, recon_state)
                actor_trunk_optimizer.zero_grad()
                recon_loss.backward()
                actor_trunk_optimizer.step()

                # QF1 reconstruction
                actual_phi_qf1 = qf1.get_representation(actual_state, data.actions)
                recon_state_qf1 = qf1.reconstruction(actual_phi_qf1)
                recon_loss_qf1 = F.mse_loss(actual_state, recon_state_qf1)
                qf1_trunk_optimizer.zero_grad()
                recon_loss_qf1.backward()
                qf1_trunk_optimizer.step()

                # QF2 reconstruction
                actual_phi_qf2 = qf2.get_representation(actual_state, data.actions)
                recon_state_qf2 = qf2.reconstruction(actual_phi_qf2)
                recon_loss_qf2 = F.mse_loss(actual_state, recon_state_qf2)
                qf2_trunk_optimizer.zero_grad()
                recon_loss_qf2.backward()
                qf2_trunk_optimizer.step()

            # update the target networks
            if global_step % args.target_network_frequency == 0:
                for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
                for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)

            # Logging happens after the policy update so that `actor_loss` and
            # `alpha_loss` have been assigned when a logging step coincides with
            # the first training step.
            if global_step % 100 == 0:
                writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
                writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
                writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
                writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
                writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
                writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step)
                writer.add_scalar("losses/alpha", alpha, global_step)
                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
                if args.autotune:
                    writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)

        # LLE Algorithm Training
        # Runs every `lle_batch_size` steps, and only once the first `batch_size`
        # buffer slots (the only ones indexed below) have actually been written.
        if global_step % args.lle_batch_size == 0 and rb.size() >= args.batch_size:
            half_window = args.local_window_size // 2
            # An odd window takes its extra neighbour from the "before" side, so
            # that side (and the lowest valid centre index) is one larger.
            neighbours_before = args.local_window_size - half_window
            for network, optimizer in zip([actor, qf1, qf2], [actor_trunk_optimizer, qf1_trunk_optimizer, qf2_trunk_optimizer]):
                # Reset the LLE head so every network starts from fresh weights
                for module in lle_model.lle_head:
                    if hasattr(module, 'reset_parameters'):
                        module.reset_parameters()
                for _ in range(args.lle_epochs):
                    # Neighbours are taken from both the past and the future of each centre.
                    # NOTE(review): centres are always drawn from the first
                    # `batch_size` slots of the replay buffer, regardless of how
                    # many transitions it holds - confirm this is intended.
                    idxs = np.random.randint(low=neighbours_before, high=args.batch_size - half_window, size=args.lle_batch_size)
                    all_neighbouring_idxs = [
                        list(range(center - neighbours_before, center)) + list(range(center + 1, center + half_window + 1))
                        for center in idxs
                    ]

                    neighbour_batch = rb._get_samples(np.reshape(all_neighbouring_idxs, (1, -1)))
                    # (lle_batch, window, feature) layouts, used by the networks ...
                    neighbour_states = torch.reshape(neighbour_batch.observations, shape=(args.lle_batch_size, args.local_window_size, -1))
                    neighbour_actions = torch.reshape(neighbour_batch.actions, shape=(args.lle_batch_size, args.local_window_size, -1))
                    # ... and the (lle_batch, feature, window) layout fed to the LLE head.
                    state_neighbours = torch.transpose(neighbour_states, 1, 2)
                    center_batch = rb._get_samples(idxs)
                    actual_state = center_batch.observations
                    actual_action = center_batch.actions

                    # LCR Version
                    if args.use_lcr:
                        # The critics need the actions as well; calling them with
                        # states only raised a TypeError.
                        actual_phi = _representation(network, actual_state, actual_action)
                        for _ in range(args.lle_gradient_steps_W):
                            neighbouring_phi = _representation(network, neighbour_states, neighbour_actions)
                            predicted_phi = lle_model.lle_head(torch.transpose(neighbouring_phi, 1, 2))
                            lcr_loss = F.mse_loss(actual_phi.detach(), torch.squeeze(predicted_phi))
                            lcr_optimizer.zero_grad()
                            lcr_loss.backward()
                            lcr_optimizer.step()
                    # LLE Version
                    else:
                        # Stage 1: fit the head weights W so that the neighbours
                        # reconstruct the centre state.
                        prev_lle_loss_W = float('inf')
                        consecutive_below_threshold = 0  # consecutive steps with a negligible loss change

                        for epoch_W in range(args.lle_gradient_steps_W):
                            predicted_state = lle_model.lle_head_forward(state_neighbours)

                            lle_loss_W = F.mse_loss(actual_state, torch.squeeze(predicted_state))

                            # Check if the reduction in loss is below the threshold
                            if abs(prev_lle_loss_W - lle_loss_W.item()) < args.lle_loss_reduction_threshold_W:
                                consecutive_below_threshold += 1
                            else:
                                consecutive_below_threshold = 0  # Reset counter if not below threshold

                            if consecutive_below_threshold >= 5:
                                break

                            prev_lle_loss_W = lle_loss_W.item()
                            if epoch_W == 0:
                                first_lle_loss_W = prev_lle_loss_W

                            lle_optimizer_W.zero_grad()
                            lle_loss_W.backward()
                            lle_optimizer_W.step()

                        # Write to CSV instead of printing
                        lle_loss_w_csv_writer.writerow([epoch_W, first_lle_loss_W, lle_loss_W.item()])
                        lle_loss_w_csv_file.flush()  # Ensure data is written to disk

                        # Stage 2: update the trunk (Phi) so that the same weights
                        # reconstruct the centre representation from its neighbours.
                        actual_phi = _representation(network, actual_state, actual_action)

                        prev_lle_loss_Phi = float('inf')
                        consecutive_below_threshold = 0  # consecutive steps with a negligible loss change
                        for epoch_Phi in range(args.lle_gradient_steps_Phi):
                            neighbouring_phi = _representation(network, neighbour_states, neighbour_actions)
                            predicted_phi = lle_model.lle_head(torch.transpose(neighbouring_phi, 1, 2))

                            lle_loss_Phi = 1 - F.cosine_similarity(actual_phi.detach(), torch.squeeze(predicted_phi), dim=1).mean()

                            # Check if the reduction in loss is below the threshold
                            if abs(prev_lle_loss_Phi - lle_loss_Phi.item()) < args.lle_loss_reduction_threshold_Phi:
                                consecutive_below_threshold += 1
                            else:
                                consecutive_below_threshold = 0  # Reset counter if not below threshold

                            # NOTE(review): the extra equality test means the loop
                            # only stops early when the loss equals its initial
                            # value exactly - confirm this is intended.
                            if consecutive_below_threshold >= 5 and first_lle_loss_Phi == lle_loss_Phi.item():
                                break
                            prev_lle_loss_Phi = lle_loss_Phi.item()

                            if epoch_Phi == 0:
                                first_lle_loss_Phi = prev_lle_loss_Phi

                            optimizer.zero_grad()
                            lle_loss_Phi.backward()
                            optimizer.step()

                        lle_loss_phi_csv_writer.writerow([epoch_Phi, first_lle_loss_Phi, lle_loss_Phi.item()])
                        lle_loss_phi_csv_file.flush()  # Ensure data is written to disk

    # Save models
    torch.save(actor.state_dict(), f"runs-robosuite/{run_name}/actor.pth")
    torch.save(qf1.state_dict(), f"runs-robosuite/{run_name}/qf1.pth")
    torch.save(qf2.state_dict(), f"runs-robosuite/{run_name}/qf2.pth")
    torch.save(lle_model.state_dict(), f"runs-robosuite/{run_name}/lle_model.pth")

    # Release the CSV log handles opened at start-up.
    recon_csv_file.close()
    lle_loss_w_csv_file.close()
    lle_loss_phi_csv_file.close()

    envs.close()
    writer.close()
