import os
import random
import time
from dataclasses import dataclass, field
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
import metaworld
from sentence_transformers import SentenceTransformer
from torch.distributions import Normal, TransformedDistribution
from sklearn.decomposition import DictionaryLearning, sparse_encode
from copy import deepcopy
from sklearn.utils import check_array, check_random_state

@dataclass
class Args:
    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""

    # Algorithm specific arguments
    env_id: str = "mt4"
    """the environment id of the task"""
    env_names: list = field(
        default_factory=lambda: ['window-open-v2', 'window-close-v2', 'drawer-open-v2', 'drawer-close-v2'])
    """the Meta-World environment names of the tasks, in training order"""
    env_hints: list = field(
        default_factory=lambda: ['Push and open a window.', 'Push and close a window.', 'Open a drawer.', 'Push and close a drawer.'])
    """the environment description of the task"""
    total_timesteps: int = 500000
    """total timesteps of the experiments"""
    test_timesteps: int = 100000
    """test timesteps of the experiments"""
    buffer_size: int = int(1e6)
    """the replay memory buffer size"""
    gamma: float = 0.99
    """the discount factor gamma"""
    tau: float = 0.005
    """target smoothing coefficient (default: 0.005)"""
    batch_size: int = 256 # 1280
    """the batch size of sample from the reply memory"""
    # int(...) so the default really is an int (1e4 alone is a float literal).
    learning_starts: int = int(1e4)
    """timestep to start learning"""
    policy_lr: float = 3e-4
    """the learning rate of the policy network optimizer"""
    q_lr: float = 1e-3
    """the learning rate of the Q network network optimizer"""
    policy_frequency: int = 2
    """the frequency of training policy (delayed)"""
    target_network_frequency: int = 1  # Denis Yarats' implementation delays this by 2.
    """the frequency of updates for the target nerworks"""
    noise_clip: float = 0.5
    """noise clip parameter of the Target Policy Smoothing Regularization"""
    alpha: float = 0.2
    """Entropy regularization coefficient."""
    autotune: bool = True
    """automatic tuning of the entropy coefficient"""
    average_num: int = 100
    """the number of episode to compute average success rate"""
    threshold: float = 0.9
    """the threshold representing finishing training"""
    cycle_num: int = 2
    """the number of cycling"""
    """the number of cycling"""


def make_env(env_name, gamma):
    """Return a zero-argument factory that builds one Meta-World MT1 env.

    The factory constructs the MT1 benchmark for ``env_name``, instantiates
    that environment class, fixes it to the benchmark's first training task
    and wraps it in ``gym.wrappers.RecordEpisodeStatistics``.

    ``gamma`` is accepted for signature compatibility but is not used.
    """
    def thunk():
        benchmark = metaworld.MT1(env_name)
        raw_env = benchmark.train_classes[env_name]()
        # Always the first sampled task variant of this environment.
        raw_env.set_task(benchmark.train_tasks[0])
        return gym.wrappers.RecordEpisodeStatistics(raw_env)

    return thunk

def clip_by_norm(x, c):
    """Rescale ``x`` so its L2 norm is at most ``c``; leave it unchanged otherwise.

    A 1e-6 term in the denominator guards against division by zero.
    """
    scale = c / (np.linalg.norm(x) + 1e-6)
    # Same selection rule as min(1.0, scale): keep scale only when strictly below 1.
    scale = scale if scale < 1.0 else 1.0
    return x * scale

def _update_dict(
    dictionary: np.ndarray,
    Y: np.ndarray,
    code: np.ndarray,
    A: np.ndarray = None,
    B: np.ndarray = None,
    c: float = 1e-1,
    verbose: bool = False,
    random_state: int = None,
    positive: bool = False,
):
    """Run one block-coordinate pass over the dictionary atoms, in place.

    Args:
        dictionary: ``(n_components, n_features)`` array, updated in place and
            also returned.
        Y: ``(n_samples, n_features)`` samples.
        code: ``(n_samples, n_components)`` codes. Columns of atoms that get
            resampled are zeroed in place.
        A: ``code.T @ code`` statistic. If None, it is computed from ``code``.
        B: ``Y.T @ code`` statistic. If None, it is computed from ``Y`` and ``code``.
        c: each atom is projected onto the ball ``||d_k||_2 <= c``.
        verbose: print how many unused atoms were resampled.
        random_state: seed / RandomState used when resampling unused atoms.
        positive: clamp atom entries to be non-negative.

    Returns:
        The updated ``dictionary``.
    """
    n_samples, n_components = code.shape
    random_state = check_random_state(random_state)

    # Plain matmul keeps numpy inputs numpy; the former torch.Tensor(...) wrapper
    # made numpy @ torch mixes that raised at runtime.
    if A is None:
        A = code.T @ code
    if B is None:
        B = Y.T @ code

    n_unused = 0
    for k in range(n_components):
        if A[k, k] > 1e-6:
            dictionary[k] += (B[:, k] - A[k] @ dictionary) / A[k, k]
        else:
            # Atom k is (almost) unused: re-seed it from a random sample.
            newd = Y[random_state.choice(n_samples)]

            # add small noise to avoid making the sparse coding ill conditioned
            noise_level = 1.0 * (newd.std().item() or 1)  # avoid 0 std
            noise = random_state.normal(0, noise_level, size=len(newd))
            dictionary[k] = newd + noise
            code[:, k] = 0
            n_unused += 1

        if positive:
            # .clip(min=0) exists on both numpy arrays and torch tensors,
            # unlike torch.clamp, which rejected the numpy rows used here.
            dictionary[k] = dictionary[k].clip(min=0)

        # Projection on the constraint set ||V_k||_2 <= c
        dictionary[k] = clip_by_norm(dictionary[k], c)

    if verbose and n_unused > 0:
        print(f"{n_unused} unused atoms resampled.")

    return dictionary

class OnlineDictLearnerV2:
    """Online dictionary learner that turns sample embeddings into sparse codes.

    Keeps a dictionary ``D`` of shape ``(n_components, n_features)`` whose rows
    have L2 norm at most ``c``.

    - ``get_alpha`` sparse-codes a sample against ``D``.
    - ``update_dict`` records a (code, sample) pair, accumulates
      ``A += codes.T @ codes`` and ``B += sample.T @ codes``, then runs one
      ``_update_dict`` pass on ``D``.

    Args:
        n_features: length of each sample, and of each dictionary row.
        n_components: number of dictionary atoms, i.e. length of each code.
        seed: seed of the private ``RandomState`` used for init and resampling.
        init_sample: accepted for interface compatibility; currently unused.
        c: norm bound applied to every dictionary atom.
        scale: std of the Gaussian used to initialise ``D``.
        alpha: sparsity penalty passed to ``sparse_encode``.
        method: ``sparse_encode`` algorithm name.
        positive_code: request non-negative codes from ``sparse_encode``. Also
            clamps atoms to be non-negative in ``_update_dict``.
        scale_code: needs ``_scale_coeffs`` / ``_rescale_coeffs``. This class
            does not define them, so construction raises
            ``NotImplementedError`` unless a subclass provides both.
        verbose: print coding / reconstruction diagnostics.
    """

    def __init__(self, 
                 n_features: int, 
                 n_components: int,
                 seed: int=0,
                 init_sample: np.ndarray=None,
                 c: float=1e-2,
                 scale: float=1.0,
                 alpha: float=1e-3,
                 method: str='lasso_lars',
                 positive_code: bool=False,
                 scale_code: bool=False,
                 verbose=True):
        # Fail fast instead of raising AttributeError on the first get_alpha call.
        if scale_code and not (hasattr(self, "_scale_coeffs") and hasattr(self, "_rescale_coeffs")):
            raise NotImplementedError(
                "scale_code=True requires _scale_coeffs/_rescale_coeffs, "
                "which OnlineDictLearnerV2 does not implement"
            )
        self.N = 0
        self.rng = np.random.RandomState(seed=seed)
        self.A = np.zeros((n_components, n_components))
        self.B = np.zeros((n_features, n_components))
        
        dictionary = self.rng.normal(loc=0.0, scale=scale, size=(n_components, n_features))
        for j in range(n_components):
            dictionary[j] = clip_by_norm(dictionary[j], c)
            
        dictionary = check_array(dictionary, order="F", copy=False)
        self.D = np.require(dictionary, requirements="W")
        
        self.C = None               # stacked codes passed to update_dict, one row per call
        self.c = c
        self.alpha = alpha
        self.method = method
        self.archives = None        # stacked samples passed to update_dict
        self._verbose = verbose
        self.arch_code = None       # stacked codes produced by get_alpha
        self._positive_code = positive_code
        self._scale_code = scale_code
        self.change_of_dict = []    # mean squared change of D per update_dict call
        
    def get_alpha(self, sample: np.ndarray):
        """Sparse-code ``sample`` against the current dictionary.

        The raw code is appended to ``self.arch_code``. Returns the code, or
        its scaled version when ``scale_code`` is enabled.
        """
        code = sparse_encode(
            sample,
            self.D,
            algorithm=self.method, 
            alpha=self.alpha,
            check_input=False,
            positive=self._positive_code,
            max_iter=10000
        )
        if self.arch_code is None:
            self.arch_code = code
        else:
            self.arch_code = np.vstack([self.arch_code, code])

        if self._scale_code:
            scaled_code = self._scale_coeffs(code)
            assert np.max(scaled_code) == 1.0
        else:
            scaled_code = code
        
        if self._verbose:
            recon = np.dot(code, self.D)
            print('Spare Coding of Task Embedding')
            print(f'Rate of deactivate: {1-np.mean(np.heaviside(scaled_code, 0)):.4f}')
            print(f'Rate of activate: {np.mean(np.heaviside(scaled_code, 0)):.4f}')
            print(f'Recontruction loss: {np.mean((sample - recon)**2):.4e}')
            print('----------------------------------')

        return scaled_code
    
    def update_dict(self, codes: np.ndarray, sample: np.ndarray):
        """Record ``(codes, sample)`` and run one dictionary-update pass.

        Assumes ``codes`` has exactly one row per call, which the
        ``self.C.shape[0] == self.N`` assertion enforces.
        """
        self.N += 1
        if self._scale_code:
            codes = self._rescale_coeffs(codes)
        # recording
        if self.C is None:
            self.C = codes
            self.archives = sample
        else:
            self.C = np.vstack([self.C, codes])
            self.archives = np.vstack([self.archives, sample])

        assert self.C.shape[0] == self.N
            
        # Update the auxiliary variables
        self.A += np.dot(codes.T, codes)
        self.B += np.dot(sample.T, codes)

        # pre-verbose
        if self._verbose:
            recons = np.dot(self.C, self.D)
            print('Dictionary Learning')
            print(f'Pre-MSE loss: {np.mean((self.archives - recons)**2):.4e}')

        old_D = deepcopy(self.D)

        # Update dictionary
        self.D = _update_dict(
            self.D,
            sample,
            codes,
            self.A,
            self.B,
            self.c,
            verbose=self._verbose,
            random_state=self.rng,
            positive=self._positive_code
        )
        
        self.change_of_dict.append(np.linalg.norm(old_D - self.D) ** 2 / self.D.size)

        # post-verbose
        if self._verbose:
            recons = np.dot(self.C, self.D)
            print(f'Post-MSE loss: {np.mean((self.archives - recons)**2):.4e}')
            print('----------------------------------')

    def _compute_overlapping(self):
        """Return the pairwise overlap of active atoms between recorded codes.

        Entry ``[i, j]`` is the fraction of atoms with a positive code in row
        ``i`` of ``self.C`` that are also positive in row ``j``. The entry is 0
        when row ``i`` has no active atom. Returns a ``(0, 0)`` array when no
        codes have been recorded yet.
        """
        if self.C is None:
            return np.zeros((0, 0))
        active = (np.asarray(self.C) > 0).astype(float)
        shared = active @ active.T
        n_active = active.sum(axis=1, keepdims=True)
        return np.divide(shared, n_active, out=np.zeros_like(shared), where=n_active > 0)
            
                
# ALGO LOGIC: initialize agent here:
class SoftQNetwork(nn.Module):
    """Soft Q-function Q(s, a): a ReLU MLP with three 768-wide hidden layers.

    The observation and action are concatenated along dim 1. The output has
    one scalar per row, i.e. shape ``(batch, 1)``.
    """

    def __init__(self, env):
        super().__init__()
        obs_width = np.array(env.single_observation_space.shape).prod()
        act_width = np.prod(env.single_action_space.shape)
        hidden = 768
        self.fc1 = nn.Linear(obs_width + act_width, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, hidden)
        self.fc4 = nn.Linear(hidden, 1)
        # Plain lists: hidden layers (ReLU after each) and the linear head.
        self.layers = [self.fc1, self.fc2, self.fc3]
        self.layers_plus = [self.fc4]

    def forward(self, x, a):
        h = torch.cat([x, a], 1)
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            h = F.relu(hidden_layer(h))
        return self.fc4(h)

# Clamp bounds that Actor reads into log_std_min / log_std_max and applies to
# the raw output of its log_std_layer. Actor then passes the clamped value
# through softplus (not exp) to get the Gaussian scale.
LOG_STD_MAX = 10
LOG_STD_MIN = -10

class Actor(nn.Module):
    """Gaussian policy whose hidden activations are gated by per-task masks.

    Each hidden layer ``backbones[i]`` has an embedding table ``embeds_bb[i]``
    with ``task_num`` rows of width ``hidden_dims[i]``. The row for task ``t``,
    clamped to ``[0, 1]``, multiplies the layer's output before the ReLU.

    Args:
        env: vector env; its ``single_observation_space`` and
            ``single_action_space`` shapes size the input and output layers.
        task_num: number of task-embedding rows.
        final_fc_init_scale: accepted but unused.
        clip_mean: means are clamped to ``[-clip_mean, clip_mean]``.
        tanh_squash: stored but unused; no tanh squashing is applied.
    """

    def __init__(self, env, task_num, final_fc_init_scale=1.0, clip_mean=1.0, tanh_squash=True):
        
        super(Actor, self).__init__()
        self.hidden_dims = [768, 768, 768]
        
        self.action_dim = np.prod(env.single_action_space.shape)
        self.task_num = task_num

        self.clip_mean = clip_mean
        self.log_std_min = LOG_STD_MIN
        self.log_std_max = LOG_STD_MAX
        self.tanh_squash = tanh_squash

        pre_dim = np.array(env.single_observation_space.shape).prod()
        backbones_layer = []
        for hidn in self.hidden_dims:
            backbones_layer.append(nn.Linear(pre_dim, hidn))
            pre_dim = hidn
        self.backbones = nn.ModuleList(backbones_layer)
        
        self.embeds_bb = nn.ModuleList([nn.Embedding(self.task_num, hidn) for hidn in self.hidden_dims])
        
        self.mean_layer = nn.Linear(pre_dim, self.action_dim, bias=False)
        self.log_std_layer = nn.Linear(pre_dim, self.action_dim)

        self.activation = F.relu
        self.tanh = F.tanh

    def forward(self, x, t, temperature=1.0):
        """Return ``(Normal, info)`` for observations ``x`` under task ids ``t``.

        ``info`` holds:
        - ``'masks'``: per-layer masks, keyed by the backbone ``nn.Linear`` module.
        - ``'means'``: the clamped means.
        - ``'stddev'``: ``softplus`` of the clamped log-std output, before
          scaling by ``temperature``.
        """
        masks = {}
        for layer, embed in zip(self.backbones, self.embeds_bb):
            x = layer(x)
            # Soft mask: the task embedding clamped to [0, 1] (no
            # straight-through estimator is applied here).
            mask_l = embed(t).clamp(0, 1).expand_as(x)
            masks[layer] = mask_l
            # Out-of-place product so the Linear output is never modified in place.
            x = self.activation(x * mask_l)

        # Avoid numerical issues by limiting the mean of the Gaussian
        means = self.mean_layer(x).clamp(-self.clip_mean, self.clip_mean)

        log_stds = self.log_std_layer(x).clamp(self.log_std_min, self.log_std_max)
        # softplus (rather than exp) keeps the scale positive and well-behaved
        stddev = F.softplus(log_stds)
        base_dist = Normal(loc=means, scale=stddev * temperature)

        return base_dist, {
            'masks': masks, 
            'means': means, 
            'stddev': stddev
        }

    def get_grad_masks(self, masks, input_dim = 39):
        """Build gradient masks in PyTorch ``nn.Linear`` layout.

        Weights are ``(out_features, in_features)``; biases are ``(out_features,)``.
        For layer ``i`` with output mask ``post`` and previous-layer mask ``pre``:

        - first layer: ``weight[o, :] = 1 - post[o]``
        - later layers: ``weight[o, j] = 1 - min(post[o], pre[j])``
        - all layers: ``bias[o] = 1 - post[o]``

        Each mask in ``masks`` is flattened. This assumes it was produced from
        a single observation row (batch of one), as ``CoTASPLearner`` does.
        ``input_dim`` is the first layer's input width.
        """
        grad_masks = {}
        pre_m = None
        for i, layer in enumerate(self.backbones):
            post_m = masks[layer].reshape(-1)
            if i == 0:
                keep = post_m.unsqueeze(1).expand(-1, input_dim)
            else:
                keep = torch.minimum(post_m.unsqueeze(1), pre_m.unsqueeze(0))
            grad_masks[(layer, 'weight')] = 1 - keep
            grad_masks[(layer, 'bias')] = 1 - post_m
            pre_m = post_m

        return grad_masks
  

class CoTASPLearner:
    """Per-task bookkeeping for a task-masked ``Actor``.

    - Keeps one ``OnlineDictLearnerV2`` per ``embeds_bb.<i>.weight`` table.
    - Sets each task's embedding rows from sparse codes of a sentence
      embedding of the task description.
    - After a task, accumulates its masks into ``cumul_masks`` and derives
      ``param_masks`` via ``actor.get_grad_masks``.

    Args:
        actor: the ``Actor`` to manage. All tensors are moved to the device of
            its parameters.
        seed: base seed. Dictionary learner ``i`` uses ``seed + i + 1``.
        observations: dummy observation batch used to probe the masks.
        actions: accepted for interface compatibility; not used.
        task_num: accepted for interface compatibility; not used.
        update_dict: update the dictionaries in ``end_task``.
        update_coef: stored only.
    """

    def __init__(
        self,
        actor,
        seed: int,
        observations: torch.Tensor,
        actions: torch.Tensor,
        task_num: int,
        update_dict=True,
        update_coef=True,
        ):
        self.seed = seed
        self.actor = actor
        self.dummy_o = observations

        # preset dict learner for each layer; 384 is presumably the output width
        # of the sentence encoder loaded below (all-MiniLM-L12-v2) — confirm.
        self.dict4layers = {}
        for id_layer, hidn in enumerate(self.actor.hidden_dims):
            dict_learner = OnlineDictLearnerV2(
                384, hidn, seed + id_layer + 1, None)
            self.dict4layers[f'embeds_bb.{id_layer}.weight'] = dict_learner

        # initialize param_masks (nothing is frozen before the first task)
        _, dicts = self.sample_actions(self.dummy_o, torch.tensor([0]))
        self.cumul_masks = {k: torch.zeros_like(v) for k, v in dicts['masks'].items()}
        self.param_masks = self.actor.get_grad_masks(self.cumul_masks)

        # initialize other things
        self.update_dict = update_dict
        self.update_coef = update_coef

        self.task_embeddings = []
        self.task_encoder = SentenceTransformer('all-MiniLM-L12-v2')

    def _device(self):
        """Device of the actor's parameters (CPU if it has none)."""
        first = next(self.actor.parameters(), None)
        return first.device if first is not None else torch.device('cpu')

    def start_task(self, task_id, description):
        """Encode ``description`` and write its sparse codes into row ``task_id``.

        Writes into each ``embeds_bb.<i>.weight`` table. The rows are modified
        in place, so the actor is updated directly.
        """
        task_e = self.task_encoder.encode(description)[np.newaxis, :]
        self.task_embeddings.append(task_e)

        with torch.no_grad():
            for name, param in self.actor.named_parameters():
                if 'embeds' not in name:
                    continue
                alpha_l = self.dict4layers[name].get_alpha(task_e)
                # Replace the task_id-th row in place (copy_ handles device moves)
                param[task_id].copy_(torch.as_tensor(alpha_l.flatten(), dtype=param.dtype))

    def end_task(self, task_id):
        """Merge task ``task_id``'s masks into ``cumul_masks`` and refresh ``param_masks``.

        When ``update_dict`` is set, also feeds the learned embedding rows back
        to the dictionary learners.

        Returns:
            A dict of per-layer stats. It is empty when ``update_dict`` is set.
        """
        self.step = 0

        # Probe the masks currently produced for this task
        _, dicts = self.sample_actions(self.dummy_o, torch.tensor([task_id]))
        current_masks = dicts['masks']

        # Update cumulative masks
        for key, value in current_masks.items():
            if key in self.cumul_masks:
                self.cumul_masks[key] = torch.maximum(self.cumul_masks[key], value)
            else:
                self.cumul_masks[key] = value

        self.param_masks = self.actor.get_grad_masks(self.cumul_masks)

        # Update dictionary learners
        dict_stats = {}
        if self.update_dict:
            for name, param in self.actor.named_parameters():
                if 'embeds' in name:
                    optimal_alpha_l = param[task_id].detach().cpu().numpy().flatten()
                    optimal_alpha_l = np.array([optimal_alpha_l])
                    
                    task_e = self.task_embeddings[task_id]
                    self.dict4layers[name].update_dict(optimal_alpha_l, task_e)
        else:
            for name in self.actor.state_dict().keys():
                if 'embeds' in name:
                    dict_stats[name] = {
                        'sim_mat': self.dict4layers[name]._compute_overlapping(),
                        'change_of_d': 0
                    }

        return dict_stats

    def sample_actions(self, observations, task_id, temperature=1.0):
        """Sample actions without tracking gradients.

        Inputs are moved to the actor's device; ``observations`` is converted
        to float32.
        """
        device = self._device()
        with torch.no_grad():
            obs_t = torch.as_tensor(observations, dtype=torch.float32, device=device)
            dist, dicts = self.actor(obs_t, task_id.to(device), temperature)
            actions = dist.sample()
        return actions, dicts

def clip_fn(x):
    """Clamp every element of tensor ``x`` into the interval ``[0, 1]``."""
    low, high = 0, 1.0
    return x.clamp(low, high)

def ste_step_fn(x):
    """Heaviside step with a straight-through gradient.

    Forward value: ``heaviside(x, 0)``, i.e. 1 where ``x > 0``, else 0.
    Backward: the gradient of ``clamp(x, 0, 1)`` (the same clipping as
    ``clip_fn``).

    The previous version subtracted ``x.detach()`` instead of the detached
    clipped value. Its "zero" term was therefore non-zero for ``x`` outside
    ``[0, 1]``, corrupting the forward value.
    """
    clipped = torch.clamp(x, 0, 1.0)
    zero = clipped - clipped.detach()  # exactly 0 in value, carries clamp's grad
    # values tensor built with x's dtype/device so float64/CUDA inputs work
    step = torch.heaviside(x.detach(), x.new_zeros(()))
    return zero + step

def train_env(envs, device, args, env_name, task_index, task_idx, hint):
    """Train the shared actor on one task with SAC, then close the task in CoTASP.

    Relies on module-level globals created under ``__main__``: ``agent``,
    ``actor``, ``actor_optimizer`` and ``writer``.

    Each call creates fresh twin critics, target critics, a fresh replay
    buffer and (with autotune) a fresh entropy temperature. Training stops
    early once the success rate over a window of ``args.average_num`` episodes
    exceeds ``args.threshold``.

    Args:
        envs: single-env vector env for this task. It is left open so the
            caller can keep reusing it.
        device: torch device holding the actor, critics and replay buffer.
        args: parsed ``Args``.
        env_name: TensorBoard tag prefix.
        task_index: stage index used in TensorBoard tags.
        task_idx: task-embedding row of the actor for this task.
        hint: task description passed to ``agent.start_task``.

    Returns:
        The stats dict returned by ``agent.end_task``.
    """
    agent.start_task(task_idx, hint)
    task_ids = torch.IntTensor([task_idx]).to(device)

    qf1 = SoftQNetwork(envs).to(device)
    qf2 = SoftQNetwork(envs).to(device)
    qf1_target = SoftQNetwork(envs).to(device)
    qf2_target = SoftQNetwork(envs).to(device)
    qf1_target.load_state_dict(qf1.state_dict())
    qf2_target.load_state_dict(qf2.state_dict())
    qf1_optimizer = optim.Adam(list(qf1.parameters()), lr=args.q_lr)
    qf2_optimizer = optim.Adam(list(qf2.parameters()), lr=args.q_lr)

    # Automatic entropy tuning
    if args.autotune:
        target_entropy = -torch.prod(torch.Tensor(envs.single_action_space.shape).to(device)).item()
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha = log_alpha.exp().item()
        a_optimizer = optim.Adam([log_alpha], lr=args.q_lr)
    else:
        alpha = args.alpha
    envs.single_observation_space.dtype = np.float32
    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device,
        handle_timeout_termination=False,
    )
    start_time = time.time()

    # Latest losses for logging; None until the first corresponding update.
    actor_loss = None
    alpha_loss = None

    # TRY NOT TO MODIFY: start the game
    obs, _ = envs.reset(seed=args.seed)
    success_total = 0
    episode_total = 0
    end = False
    for global_step in range(args.total_timesteps):
        if end:
            break
        # ALGO LOGIC: put action logic here
        if global_step < args.learning_starts and task_idx == 0:
            actions = envs.action_space.sample()
        else:
            if global_step < args.learning_starts:
                # uniform-previous strategy: explore with the policy of a
                # uniformly chosen earlier task (this task's mask is untrained)
                behaviour_idx = int(np.random.choice(task_idx))
            else:
                behaviour_idx = task_idx
            actions, _ = agent.sample_actions(obs, torch.IntTensor([behaviour_idx]))
            actions = actions.detach().cpu().numpy()

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, terminations, truncations, infos = envs.step(actions)

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        if "final_info" in infos:
            for info in infos["final_info"]:
                if info is None:  # sub-env that did not finish this step
                    continue
                if info['success'] == 1:
                    success_total += 1
                episode_total += 1
                rate = success_total / episode_total
                if episode_total % args.average_num == 0:
                    if rate > args.threshold:
                        end = True
                    episode_total = success_total = 0
                print(
                    f"{env_name}: global_step={global_step}, episodic_return={info['episode']['r']}, success_rate={rate}")
                writer.add_scalar(f"{env_name}/{task_index}/episodic_return", info["episode"]["r"], global_step)
                writer.add_scalar(f"{env_name}/{task_index}/episodic_length", info["episode"]["l"], global_step)
                writer.add_scalar(f"{env_name}/{task_index}/success_rate", rate, global_step)

        # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
        real_next_obs = next_obs.copy()
        for idx, trunc in enumerate(truncations):
            if trunc:
                real_next_obs[idx] = infos["final_observation"][idx]
        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
        obs = next_obs

        # ALGO LOGIC: training.
        if global_step > args.learning_starts:
            data = rb.sample(args.batch_size)
            with torch.no_grad():
                dist, _ = actor(data.next_observations.float().to(device), task_ids)
                next_state_actions = dist.sample()
                next_state_log_pi = dist.log_prob(next_state_actions).sum(1, keepdim=True)
                qf1_next_target = qf1_target(data.next_observations.float(), next_state_actions)
                qf2_next_target = qf2_target(data.next_observations.float(), next_state_actions)
                min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
                next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (
                    min_qf_next_target).view(-1)

            qf1_a_values = qf1(data.observations.to(torch.float32), data.actions.to(torch.float32)).view(-1)
            qf2_a_values = qf2(data.observations.to(torch.float32), data.actions.to(torch.float32)).view(-1)
            qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
            qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
            qf_loss = qf1_loss + qf2_loss

            # optimize the model
            qf1_optimizer.zero_grad()
            qf2_optimizer.zero_grad()
            qf_loss.backward()
            qf1_optimizer.step()
            qf2_optimizer.step()

            if global_step % args.policy_frequency == 0:  # TD 3 Delayed update support
                # compensate for the delay by doing 'actor_update_interval' instead of 1
                for _ in range(args.policy_frequency):
                    dist, _ = actor(data.observations.float(), task_ids)
                    pi = dist.sample()
                    log_pi = dist.log_prob(pi).sum(1, keepdim=True)

                    qf1_pi = qf1(data.observations.float(), pi)
                    qf2_pi = qf2(data.observations.float(), pi)
                    min_qf_pi = torch.min(qf1_pi, qf2_pi)
                    actor_loss = ((alpha * log_pi) - min_qf_pi).mean()

                    actor_optimizer.zero_grad()
                    actor_loss.backward()
                    # NOTE(review): agent.param_masks is never applied to the actor's
                    # gradients here, so backbone weights used by earlier tasks are
                    # not protected — confirm whether gradient masking was intended.
                    actor_optimizer.step()

                    if args.autotune:
                        with torch.no_grad():
                            dist, _ = actor(data.observations.float(), task_ids)
                            pi = dist.sample()
                            log_pi = dist.log_prob(pi).sum(1, keepdim=True)
                        alpha_loss = (-log_alpha.exp() * (log_pi + target_entropy)).mean()

                        a_optimizer.zero_grad()
                        alpha_loss.backward()
                        a_optimizer.step()
                        alpha = log_alpha.exp().item()

            # update the target networks
            if global_step % args.target_network_frequency == 0:
                for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
                for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)

            if global_step % 100 == 0:
                # (tags previously interpolated the builtin `type`)
                tag = f"{env_name}/{task_index}"
                writer.add_scalar(f"{tag}/qf1_values", qf1_a_values.mean().item(), global_step)
                writer.add_scalar(f"{tag}/qf2_values", qf2_a_values.mean().item(), global_step)
                writer.add_scalar(f"{tag}/qf1_loss", qf1_loss.item(), global_step)
                writer.add_scalar(f"{tag}/qf2_loss", qf2_loss.item(), global_step)
                writer.add_scalar(f"{tag}/qf_loss", qf_loss.item() / 2.0, global_step)
                if actor_loss is not None:
                    writer.add_scalar(f"{tag}/actor_loss", actor_loss.item(), global_step)
                writer.add_scalar(f"{tag}/alpha", alpha, global_step)
                sps = int(global_step / (time.time() - start_time))
                print("SPS:", sps)
                writer.add_scalar(f"{tag}/SPS", sps, global_step)
                if args.autotune and alpha_loss is not None:
                    writer.add_scalar(f"{tag}/alpha_loss", alpha_loss.item(), global_step)
    # The env is intentionally NOT closed: __main__ reuses it for testing and
    # later cycles.
    return agent.end_task(task_idx)

def test_env(envs, device, args, env_name, task_index, task_idx):
    """Roll out the global ``actor`` on ``envs`` and log episode statistics.

    Runs for ``args.test_timesteps`` steps. Actions are sampled from the
    policy of task-embedding row ``task_idx``. Episodic return, length and a
    success rate windowed over ``args.average_num`` episodes are logged to the
    global ``writer`` under ``test_{env_name}/{task_index}/``. The env is left
    open because the caller reuses it.
    """
    task_ids = torch.IntTensor([task_idx]).to(device)
    obs, _ = envs.reset(seed=args.seed)
    success_total = 0
    episode_total = 0
    for global_step in range(args.test_timesteps):
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            dist, _ = actor(torch.Tensor(obs).to(device), task_ids)
            actions = dist.sample().cpu().numpy()

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, terminations, truncations, infos = envs.step(actions)

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        if "final_info" in infos:
            for info in infos["final_info"]:
                if info is None:  # sub-env that did not finish this step
                    continue
                if info['success'] == 1:
                    success_total += 1
                episode_total += 1
                rate = success_total / episode_total
                if episode_total % args.average_num == 0:
                    episode_total = success_total = 0
                print(
                    f"test_{env_name}: global_step={global_step}, episodic_return={info['episode']['r']}, success_rate={rate}")
                writer.add_scalar(f"test_{env_name}/{task_index}/episodic_return", info["episode"]["r"],
                                  global_step)
                writer.add_scalar(f"test_{env_name}/{task_index}/episodic_length", info["episode"]["l"],
                                  global_step)
                writer.add_scalar(f"test_{env_name}/{task_index}/success_rate", rate, global_step)
        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
        obs = next_obs


if __name__ == "__main__":
    import stable_baselines3 as sb3

    # Compare the numeric major version; a plain string comparison misorders
    # versions such as "10.0" vs "2.0".
    if int(sb3.__version__.split(".")[0]) < 2:
        raise ValueError(
            """Ongoing migration: run the following command to install the new dependencies:
poetry run pip install "stable_baselines3==2.0.0a1"
"""
        )

    args = tyro.cli(Args)
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup: one single-env vector env per task, reused across all cycles
    envs = [gym.vector.SyncVectorEnv([make_env(env_name, args.gamma)])
            for env_name in args.env_names]
    env_num = len(envs)

    actor = Actor(envs[0], len(args.env_names)).to(device)
    actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr)

    # Dummy batch of ONE observation/action taken from the per-env spaces
    # (the batched spaces would add an extra leading dimension).
    agent = CoTASPLearner(
            actor,
            args.seed,
            envs[0].single_observation_space.sample()[np.newaxis],
            envs[0].single_action_space.sample()[np.newaxis], 
            env_num)

    for i in range(args.cycle_num):
        for j in range(env_num):
            train_env(envs[j], device, args, args.env_names[j], i * env_num + j, j, args.env_hints[j])

            # After each training stage evaluate every task with its OWN task id.
            for k in range(env_num):
                test_env(envs[k], device, args, args.env_names[k], i * env_num + j, k)

    for env in envs:
        env.close()
    writer.close()
