import gymnasium
import argparse
from tensorboardX import SummaryWriter
import cv2
import numpy as np
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
from tqdm import tqdm
import copy
import colorama
import random
import json
import shutil
import pickle
import os
import re
from utils import seed_np_torch, Logger, load_config
from replay_buffer import ReplayBuffer
import env_wrapper
import agents
from sub_models.functions_losses import symexp 
from sub_models.mixture_world_models import MixtureWorldModel 
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP  
import time 
 

def build_single_env(env_name, image_size, seed):
    """Create one wrapped Gymnasium environment.

    The base environment is made with the full action space, RGB-array
    rendering and ``frameskip=1``; the project wrappers are then stacked in
    a fixed order (seeding, 4-step frame skip, resize, life-loss info).

    Args:
        env_name: Gymnasium environment id passed to ``gymnasium.make``.
        image_size: Target shape forwarded to ``ResizeObservation``.
        seed: Seed forwarded to ``SeedEnvWrapper``.

    Returns:
        The outermost wrapper (``env_wrapper.LifeLossInfo``).
    """
    base_env = gymnasium.make(
        env_name,
        full_action_space=True,
        render_mode="rgb_array",
        frameskip=1,
    )
    seeded_env = env_wrapper.SeedEnvWrapper(base_env, seed=seed)
    # Frame skipping is done by the wrapper (skip=4), not by the emulator.
    skipping_env = env_wrapper.MaxLast2FrameSkipWrapper(seeded_env, skip=4)
    resized_env = gymnasium.wrappers.ResizeObservation(skipping_env, shape=image_size)
    return env_wrapper.LifeLossInfo(resized_env)
 
def build_vec_env(env_name, image_size, num_envs, seed):
    """Create an ``AsyncVectorEnv`` of ``num_envs`` identically configured envs.

    Every worker is built by ``build_single_env`` with the same ``env_name``,
    ``image_size`` and ``seed``.

    Returns:
        The ``gymnasium.vector.AsyncVectorEnv`` instance.
    """
    # lambda pitfall refs to: https://python.plainenglish.io/python-pitfalls-with-variable-capture-dcfc113f39b7
    def make_factory(name, size):
        # Bind the arguments through a function call so each factory owns them.
        def factory():
            return build_single_env(name, size, seed)
        return factory

    factories = [make_factory(env_name, image_size) for _ in range(num_envs)]
    return gymnasium.vector.AsyncVectorEnv(env_fns=factories)

 
def train_world_model_step_warm_up(replay_buffer, mixture_world_model, env, batch_size, demonstration_batch_size, batch_length, mode, log_video, logger):
    """Run one warm-up optimisation step of the world model for ``env``.

    Samples a batch from ``replay_buffer[env]``, runs the (wrapped) world
    model forward to obtain a scalar loss, then performs a gradient-scaled
    backward pass, gradient clipping (max norm 1000) and an optimiser step on
    ``mixture_world_model.module.miwm_optimizer``.

    Args:
        replay_buffer: Mapping from env name to a buffer exposing
            ``sample(batch_size, demonstration_batch_size, batch_length)``.
        mixture_world_model: Callable wrapper exposing the real model as
            ``.module``.
        env: Key of the environment to train on.
        batch_size, demonstration_batch_size, batch_length: Forwarded to
            ``sample``.
        mode: Forwarded to the model's forward call.
        log_video: Forwarded to the model's forward call.
        logger: Forwarded to the model's forward call.

    Returns:
        None.
    """
    obs, action, reward, termination = replay_buffer[env].sample(
        batch_size, demonstration_batch_size, batch_length)
    final_loss = mixture_world_model(
        obs, action, reward, termination, env, mode, log_video, logger=logger)

    wrapped = mixture_world_model.module
    optimizer = wrapped.miwm_optimizer

    # A GradScaler must outlive a single step: its loss scale only backs off
    # after an overflow (and grows after a run of clean steps) inside
    # ``update()``. Re-creating it every call discarded that state, so an
    # overflowing scale would be retried forever. Cache one on the module.
    scaler = getattr(wrapped, "_warm_up_grad_scaler", None)
    if scaler is None:
        scaler = torch.cuda.amp.GradScaler(enabled=True)
        wrapped._warm_up_grad_scaler = scaler

    scaler.scale(final_loss).backward()
    scaler.unscale_(optimizer)  # clip on true (unscaled) gradients
    torch.nn.utils.clip_grad_norm_(wrapped.parameters(), max_norm=1000.0)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
    

# def train_world_model_step(replay_buffer, mixture_world_model, env, batch_size, demonstration_batch_size, batch_length, mode, adaptive_weights, w_k, last_loss, last_two_steps_loss, log_video, logger):
#     #obs, action, reward, termination = {},{},{},{}  
#     obs, action, reward, termination = \
#             replay_buffer[env].sample(batch_size, demonstration_batch_size, batch_length)  
#     model_device = replay_buffer[env].obs_buffer.device
#     #mixture_world_model(obs, action, reward, termination, env, len_rank, log_video, logger=logger)
#     grad_for_weights = {}
#     final_loss, grad_for_weights[env] = mixture_world_model(obs, action, reward, termination, env, mode, log_video, logger=logger) 
#     if torch.distributed.is_initialized():
#         torch.distributed.barrier()
#     grad_for_weights = gather_dict_across_ranks(grad_for_weights)  
#     print( list(grad_for_weights.keys())) 
#     # # print(f"env {env} adaptive_weights[env] {adaptive_weights[env]}")
#     # #logger.log(f"{env}_MixtureWorldModel/adaptive_weight", adaptive_weights[env].item())
#     # # logger.log(f"{env}_MixtureWorldModel/adaptive_weight", weights.item())
#     # # if last_loss[env]==None and last_two_steps_loss[env]==None:
#     # #     last_loss[env] = final_loss.detach()
#     # #     grad_for_weights[env]=torch.tensor(1.0)
#     # # else: 
#     # #     grad_for_weights[env] = final_loss/last_loss[env] 
#     #grad_for_weights = gather_dict_across_ranks(grad_for_weights)
#     #weights = torch.tensor(1.0)
#     weights = get_grad_sim_weights(grad_for_weights, env, model_device).detach() 
#     final_loss = weights *final_loss 
#     #final_loss = adaptive_weights[env]*final_loss
#     logger.log(f"{env}_MixtureWorldModel/weight", weights.item())
#     #w_k[env] = torch.max(w_k[env], torch.tensor(1.0))
#         # print(f"env {env} last_loss {last_loss[env]}")
#         # print(f"env {env} last_two_steps_loss {last_two_steps_loss[env]}")
#         # print(f"env {env} w_k {w_k[env]}")
#     # final_loss = weights *final_loss 
    
#     scaler = torch.cuda.amp.GradScaler(enabled=True)
#     scaler.scale(final_loss).backward()
#     scaler.unscale_(mixture_world_model.module.miwm_optimizer)  # for clip grad
#     grad_norm = torch.nn.utils.clip_grad_norm_(mixture_world_model.module.parameters(), max_norm=100.0)
#     logger.log(f"{env}_MixtureWorldModel/grad_norm", grad_norm.item()) 
#     scaler.step(mixture_world_model.module.miwm_optimizer) 
#     scaler.update()
#     mixture_world_model.module.miwm_optimizer.zero_grad(set_to_none=True)
#     return  last_loss[env], last_two_steps_loss[env], w_k[env]
    
def train_world_model_step(replay_buffer, mixture_world_model, env, batch_size, demonstration_batch_size, batch_length, mode, log_video, logger):
    """Sample a batch for ``env`` and run the world model forward on it.

    No backward pass or optimiser step happens here; the caller owns that.

    Args:
        replay_buffer: Mapping from env name to a buffer exposing
            ``sample(batch_size, demonstration_batch_size, batch_length)``.
        mixture_world_model: Callable invoked as
            ``model(obs, action, reward, termination, env, mode, log_video,
            logger=logger)``.
        env: Key of the environment to sample from.
        mode: ``'train'`` or ``'get_grad'``; forwarded to the model.
        log_video, logger: Forwarded to the model.

    Returns:
        Whatever the model returns for ``'train'`` / ``'get_grad'``.
        For any other ``mode`` a batch is still sampled but the model is not
        called and ``None`` is returned.
    """
    obs, action, reward, termination = replay_buffer[env].sample(
        batch_size, demonstration_batch_size, batch_length)

    # Both supported modes issue the same forward call; only the value the
    # model hands back differs.
    if mode not in ('train', 'get_grad'):
        return None
    return mixture_world_model(
        obs, action, reward, termination, env, mode, log_video, logger=logger)
 





def train_agent_step(agent, rank_env, latent_input, agent_action, agent_logprob, agent_value, imagine_reward, imagine_termination,logger):
    """Run one optimisation step of the agent and return its loss dict.

    The wrapped agent is called to obtain ``(loss, loss_dict)``; the loss is
    then back-propagated through the agent's own grad scaler, gradients are
    clipped to a max norm of 100 and the agent's optimiser is stepped.

    Returns:
        The ``loss_dict`` produced by the agent's forward call.
    """
    forward_kwargs = dict(
        rank_env=rank_env,
        latent=latent_input,
        action=agent_action,
        old_logprob=agent_logprob,
        old_value=agent_value,
        reward=imagine_reward,
        termination=imagine_termination,
        logger=logger,
    )
    loss, loss_dict = agent(**forward_kwargs)

    core = agent.module
    grad_scaler = core.scaler
    optimizer = core.optimizer

    grad_scaler.scale(loss).backward()
    # Unscale first so the clipping threshold applies to real gradient values.
    grad_scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(core.parameters(), max_norm=100.0)
    grad_scaler.step(optimizer)
    grad_scaler.update()
    optimizer.zero_grad(set_to_none=True)
    return loss_dict
 
                    


@torch.no_grad()
def mixture_world_model_imagine_data(replay_buffer, mixture_world_model, agent: agents.ActorCriticAgent, env_names, env,
                             imagine_batch_size, imagine_demonstration_batch_size,
                             imagine_context_length, imagine_batch_length,
                             log_video, logger):
    '''
    Sample a context batch for ``env`` from its replay buffer, then let the
    world model and the agent imagine a rollout from it.

    Both the world model and the agent are switched to eval mode. The whole
    function runs with gradients disabled (see the decorator).

    Returns a 6-tuple
    ``(latent, action, old_logprob, old_value, reward_hat, termination_hat)``
    where the tensors are clones of the ``imagine_data`` outputs and
    ``old_logprob`` / ``old_value`` are always ``None``.
    '''
    mixture_world_model.eval()
    agent.eval()

    # Only observations and actions of the sampled context are used.
    context_obs, context_action, _, _ = replay_buffer[env].sample(
        imagine_batch_size, imagine_demonstration_batch_size, imagine_context_length)

    imagined = mixture_world_model.module.imagine_data(
        agent, env, context_obs, context_action,
        imagine_batch_size=imagine_batch_size,
        imagine_batch_length=imagine_batch_length,
        log_video=log_video,
        logger=logger,
    )
    latent, action, reward_hat, termination_hat = imagined

    return (
        latent.clone(),
        action.clone(),
        None,  # old_logprob
        None,  # old_value
        reward_hat.clone(),
        termination_hat.clone(),
    )

 

def joint_train_mixture_world_model_agent(env_names, env_values, vec_env, num_envs, max_steps, image_size,
                                  replay_buffer, mixture_world_model, agent: agents.ActorCriticAgent,
                                  train_dynamics_every_steps, train_agent_every_steps,
                                  batch_size, demonstration_batch_size, batch_length,
                                  imagine_batch_size, imagine_demonstration_batch_size,
                                  imagine_context_length, imagine_batch_length,
                                  save_every_steps, seed, rank, logdir, logger):
    """Main joint training loop for the world model and the agent on one rank.

    An environment is treated as "local" to this rank when its replay
    buffer's ``obs_buffer.device`` equals ``cuda:{rank}``. Each of the
    ``max_steps // num_envs`` iterations:

    1. picks an action for every local env (random while the buffer is not
       ready or no context has been collected, otherwise sampled from the
       agent on the encoded context of up to 16 past observations/actions);
    2. steps the local envs and appends the transition to the replay buffer;
    3. runs a world-model update per local env on the configured interval;
    4. imagines rollouts for local envs and calls ``train_agent_step``;
    5. asserts that selected parameters are identical to rank 0's;
    6. on rank 0, periodically saves agent and world-model checkpoints.

    Args:
        env_names: Environment keys used for all per-env dictionaries.
        env_values, image_size, seed: Not read by the active code in this
            function (only referenced from commented-out lines).
        vec_env: Mapping from env name to an already-built vector env.
        num_envs: Number of parallel envs per vector env.
        max_steps: Total sample budget; the loop runs ``max_steps // num_envs``
            iterations.
        replay_buffer: Mapping from env name to its replay buffer.
        mixture_world_model, agent: Wrapped models exposing ``.module``.
        train_dynamics_every_steps, train_agent_every_steps,
        save_every_steps: Intervals, each divided by ``num_envs`` before use.
        batch_size, demonstration_batch_size, batch_length: World-model batch
            sampling parameters.
        imagine_*: Imagination sampling parameters.
        rank: Process rank; also selects the CUDA device.
        logdir: Root directory for checkpoints.
        logger: Object exposing ``log(name, value)``.

    Returns:
        None.
    """
    sum_reward = {}
    current_obs = {}
    current_info = {}
    context_obs = {}
    context_action = {}
    # vec_env = {}
    #envs_str = "_".join(env_names)
    envs_str = "all_atari_tasks"
    model_device = torch.device(f"cuda:{rank}")

    if rank == 0:
        os.makedirs(f"{logdir}/ckpt/{envs_str}", exist_ok=True)
    for idx, env in enumerate(env_names):
        # Only initialise per-env state for envs whose buffer lives on this rank's GPU.
        if model_device == replay_buffer[env].obs_buffer.device:
            # os.makedirs(f"{logdir}/ckpt/{n_values[idx]}", exist_ok=True)
            # vec_env[env] = build_vec_env(env_values[idx], image_size, num_envs=num_envs, seed=seed)
            print("Current env: " + colorama.Fore.YELLOW + f"{env_names[idx]} on {model_device}" + colorama.Style.RESET_ALL)
            # reset
            sum_reward[env] = np.zeros(num_envs)
            current_obs[env], current_info[env] = vec_env[env].reset()
            # Rolling context windows of the last 16 observations / actions.
            context_obs[env] = deque(maxlen=16)
            context_action[env] = deque(maxlen=16)

    obs, reward, done, truncated, info, done_flag = {}, {}, {}, {}, {}, {}

    action = {}
    model_context_action = {}
    enc_obs_dict = {}
    context_latent = {}
    prior_flattened_sample = {}
    last_dist_feat = {}

    # Only rank 0 shows a progress bar; other ranks iterate a plain range.
    progress_bar = tqdm(range(max_steps//num_envs), desc=f"Training") if rank == 0 else range(max_steps//num_envs)

    for total_steps in progress_bar:
        # --- action selection, pass 1: random actions -------------------
        # Random action when the buffer is not ready, or when it is ready but
        # no context has been collected yet. Note that eval() is called for
        # every ready env, including envs that are not local to this rank.
        for idx, env in enumerate(env_names):
            if replay_buffer[env].ready():
                mixture_world_model.eval()
                agent.eval()
                if model_device == replay_buffer[env].obs_buffer.device:
                    with torch.no_grad():
                        if len(context_action[env]) == 0:
                            action[env] = vec_env[env].action_space.sample()
            else:
                if model_device == replay_buffer[env].obs_buffer.device:
                    action[env] = vec_env[env].action_space.sample()

        # --- action selection, pass 2: encode the collected context -----
        # This branch runs whenever context exists, independent of ready();
        # after the first step it therefore overrides the random action above.
        for idx, env in enumerate(env_names):
            if model_device == replay_buffer[env].obs_buffer.device:
                with torch.no_grad():
                    if len(context_action[env]) != 0:
                        # Context observations are concatenated along dim 1
                        # (the axis inserted as "1" by the rearrange below).
                        enc_obs_dict[env] = torch.cat(list(context_obs[env]), dim=1)
                        context_latent[env] = mixture_world_model.module.encode_obs_single(enc_obs_dict[env],env)
                        model_context_action[env] = np.stack(list(context_action[env]), axis=1)
                        model_context_action[env] = torch.Tensor(model_context_action[env]).cuda()
        # --- action selection, pass 3: query the agent -------------------
        for idx, env in enumerate(env_names):
            if model_device == replay_buffer[env].obs_buffer.device:
                with torch.no_grad():
                    if len(context_action[env]) != 0:
                        prior_flattened_sample[env], last_dist_feat[env] = mixture_world_model.module.calc_last_dist_feat(context_latent[env], model_context_action[env], env)
                        task_id = mixture_world_model.module.get_task_index(env)
                        action_input = mixture_world_model.module.single_task_emb(torch.cat([prior_flattened_sample[env], last_dist_feat[env]], dim=-1),task_id)
                        action[env] = agent.module.sample_as_env_action(
                                action_input,
                                greedy=False)
                # Push the current observation (B H W C -> B 1 C H W, scaled
                # by 1/255) and the chosen action into the context windows.
                context_obs[env].append(rearrange(torch.Tensor(current_obs[env]).cuda() , "B H W C -> B 1 C H W")/255)
                context_action[env].append(action[env])


        # --- environment step and replay-buffer write --------------------
        for env in env_names:
            if model_device == replay_buffer[env].obs_buffer.device:
                obs[env], reward[env], done[env], truncated[env], info[env] = vec_env[env].step(action[env])
                done_flag[env] = np.logical_or(done[env], truncated[env])
                # The stored termination flag is "done OR life lost", which is
                # different from the episode-end flag used for logging below.
                replay_buffer[env].append(current_obs[env], action[env], reward[env], np.logical_or(done[env], info[env]["life_loss"]))
                # NOTE(review): duplicate of the assignment two lines above.
                done_flag[env] = np.logical_or(done[env], truncated[env])
                if done_flag[env].any():
                    for i in range(num_envs):
                        if done_flag[env][i]:
                            # NOTE(review): the sum is logged and reset before
                            # this step's reward is added below, so the final
                            # reward of an episode is counted towards the next
                            # one — confirm this is intended.
                            logger.log(f"sample/{env}_reward", sum_reward[env][i])
                            # Uses the info dict from the previous step.
                            logger.log(f"sample/{env}_episode_steps", current_info[env]["episode_frame_number"][i]//4)  # frameskip=4
                            logger.log(f"replay_buffer/{env}_length", len(replay_buffer[env]))
                            sum_reward[env][i] = 0

                # update current_obs, current_info and sum_reward
                sum_reward[env] += reward[env]
                current_obs[env] = obs[env]
                current_info[env] = info[env]

                # <<< sample part
        # Videos are logged on the same cadence as checkpoints.
        if total_steps % (save_every_steps//num_envs) == 0:
            world_log_video = True
        else:
            world_log_video = False
        final_loss = {}

        # --- world-model update ------------------------------------------
        for env in env_names:

            if replay_buffer[env].ready() and \
            total_steps % (train_dynamics_every_steps//num_envs) == 0:
                if model_device == replay_buffer[env].obs_buffer.device:
                    final_loss[env] = train_world_model_step(
                            replay_buffer=replay_buffer,
                            mixture_world_model=mixture_world_model,
                            env=env,
                            batch_size=batch_size,
                            demonstration_batch_size=demonstration_batch_size,
                            batch_length=batch_length,
                            mode='train',
                            log_video=world_log_video,
                            logger=logger
                    )
                    # NOTE(review): a fresh GradScaler is created for every
                    # update, so its loss scale never carries over between
                    # steps — confirm whether a persistent scaler is intended.
                    scaler = torch.cuda.amp.GradScaler(enabled=True)
                    scaler.scale(final_loss[env]).backward()


                    scaler.unscale_(mixture_world_model.module.miwm_optimizer)  # for clip grad
                    grad_norm = torch.nn.utils.clip_grad_norm_(mixture_world_model.module.parameters(), max_norm=1000.0)
                    logger.log(f"{env}_MixtureWorldModel/grad_norm", grad_norm.item())

                    scaler.step(mixture_world_model.module.miwm_optimizer)
                    scaler.update()
                    mixture_world_model.module.miwm_optimizer.zero_grad(set_to_none=True)



        # train agent part >>>
        agent_loss = {}
        imagine_latent, agent_action, agent_logprob, agent_value, imagine_reward, imagine_termination, latent_input ={},{},{},{},{},{},{}
        rank_envs = []
        # The last clause of the condition below (total_steps*num_envs >= 0)
        # is always true for a non-negative step counter.
        for idx, env in enumerate(env_names):
            if replay_buffer[env].ready() \
              and total_steps % (train_agent_every_steps//num_envs) == 0 \
                and total_steps*num_envs >= 0:
                if total_steps % (save_every_steps//num_envs) == 0:
                    log_video = True
                else:
                    log_video = False
                if model_device == replay_buffer[env].obs_buffer.device:
                    rank_envs.append(env)
                    imagine_latent_, agent_action_, agent_logprob_, agent_value_, imagine_reward_, imagine_termination_ = mixture_world_model_imagine_data(
                    replay_buffer=replay_buffer,
                    mixture_world_model=mixture_world_model,
                    agent=agent,
                    env_names=env_names,
                    env=env,
                    imagine_batch_size=imagine_batch_size,
                    imagine_demonstration_batch_size=imagine_demonstration_batch_size,
                    imagine_context_length=imagine_context_length,
                    imagine_batch_length=imagine_batch_length,
                    log_video=log_video,
                    logger=logger)
                    imagine_latent[env] = imagine_latent_.clone()
                    agent_action[env] = agent_action_.clone()
                    imagine_reward[env] = imagine_reward_.clone()
                    imagine_termination[env] = imagine_termination_.clone()
                    # agent_logprob / agent_value stay empty: the imagine
                    # helper returns None for both.
                    task_id = mixture_world_model.module.get_task_index(env)
                    latent_input_ = mixture_world_model.module.single_task_emb(imagine_latent[env], task_id)
                    latent_input[env] = latent_input_.clone()
        # NOTE(review): train_agent_step is called on every iteration, also
        # when rank_envs is empty (buffers not ready or off-interval step) —
        # confirm the agent's forward handles empty inputs.
        agent_loss =train_agent_step(agent,
                                     rank_envs,
                                     latent_input,
                                     agent_action,
                                     agent_logprob,
                                     agent_value,
                                     imagine_reward,
                                     imagine_termination,
                                     logger)
        agent_loss = gather_dict_across_ranks(agent_loss)
        # The slow critic is only updated when the merged dict has exactly one
        # key per environment (presumably loss_dict is keyed by env name —
        # TODO confirm against the agent's forward).
        if len(list(agent_loss.keys()))==len(env_names):
            agent.module.update_slow_critic()


        # <<< train agent part

        # Sanity check: broadcast rank 0's copy of selected parameters and
        # assert this rank holds the same values. Runs every iteration.
        for name, param in mixture_world_model.module.shared_world_model.storm_transformerwoaction.named_parameters():
            tensor = param.data.clone()
            dist.broadcast(tensor, src=0)
            assert torch.allclose(param.data, tensor), f"{name} dif rank {rank}"
        #print(f"model parameters are the same")
        for name, param in agent.module.actor.named_parameters():
            tensor = param.data.clone()
            dist.broadcast(tensor, src=0)
            assert torch.allclose(param.data, tensor), f"{name} dif rank {rank}"
        #print(f"model parameters are the same")
        # Periodic checkpoints, written by rank 0 only.
        if rank == 0 and total_steps % (save_every_steps//num_envs) == 0:
            print(colorama.Fore.GREEN + f"Saving model at total steps {total_steps}" + colorama.Style.RESET_ALL)
            torch.save(agent.module.state_dict(), f"{logdir}/ckpt/{envs_str}/agent_{total_steps}.pth")
        if rank == 0 and total_steps % (save_every_steps//num_envs) == 0:
            # The world-model checkpoint also stores the cluster indices.
            torch.save({
                        'model': mixture_world_model.module.state_dict(),
                        'cluster_indices': mixture_world_model.module.cluster_indices,
                        },f"{logdir}/ckpt/{envs_str}/mixture_world_model_{total_steps}.pth")


def joint_warm_up_mixture_world_model_agent(env_names, env_values, n_values, num_envs, warm_up_steps, image_size,
                                  replay_buffer, mixture_world_model,
                                  train_dynamics_every_steps,
                                  batch_size, demonstration_batch_size, batch_length,
                                  save_every_steps, seed, rank, logdir, logger):
    """Warm-up phase: fill local replay buffers, then pre-train the world model.

    An environment is "local" when its replay buffer's ``obs_buffer.device``
    equals ``cuda:{rank}``. For each of the ``warm_up_steps // num_envs``
    iterations, every local env whose buffer is not yet ready is stepped with
    a random action and the transition is stored. Once a buffer reports
    ready, that env is no longer stepped here and the world model is trained
    on it in ``'warm_up'`` mode on the configured interval.

    Args:
        env_names: Environment keys used for all per-env dictionaries.
        env_values: Gymnasium ids, index-aligned with ``env_names``.
        n_values: Not read by the active code (only by a commented-out line).
        num_envs: Number of parallel envs per vector env.
        warm_up_steps: Sample budget; the loop runs
            ``warm_up_steps // num_envs`` iterations.
        image_size, seed: Forwarded to ``build_vec_env``.
        replay_buffer: Mapping from env name to its replay buffer.
        mixture_world_model: Wrapped world model exposing ``.module``.
        train_dynamics_every_steps, save_every_steps: Intervals, each divided
            by ``num_envs`` before use.
        batch_size, demonstration_batch_size, batch_length: Batch sampling
            parameters for the world-model step.
        rank: Process rank; also selects the CUDA device.
        logdir: Root directory for checkpoints.
        logger: Forwarded to the world-model step.

    Returns:
        dict: The vector envs created here, keyed by env name (local envs only).
    """

    current_obs = {}
    current_info = {}
    vec_env = {}
    #envs_str = "_".join(env_names)
    envs_str = "all_atari_tasks"
    model_device = torch.device(f"cuda:{rank}")

    if rank == 0:
        os.makedirs(f"{logdir}/ckpt/{envs_str}", exist_ok=True)
    for idx, env in enumerate(env_names):
        # Vector envs are only built for envs whose buffer is on this rank's GPU.
        if model_device == replay_buffer[env].obs_buffer.device:
            # os.makedirs(f"{logdir}/ckpt/{n_values[idx]}", exist_ok=True)
            vec_env[env] = build_vec_env(env_values[idx], image_size, num_envs=num_envs, seed=seed)
            print("Current env: " + colorama.Fore.YELLOW + f"{env_names[idx]} on {model_device}" + colorama.Style.RESET_ALL)
            # reset
            current_obs[env], current_info[env] = vec_env[env].reset()


    obs, reward, action, done, truncated, info, done_flag = {}, {}, {}, {}, {}, {}, {}

    # Only rank 0 shows a progress bar; other ranks iterate a plain range.
    warm_up_bar = tqdm(range(warm_up_steps//num_envs), desc=f"Warm up") if rank == 0 else range(warm_up_steps//num_envs)
    #progress_bar = tqdm(range(max_steps//num_envs), desc=f"Training rank {rank}") #if rank == 0 else range(max_steps//num_envs)
    for total_steps in warm_up_bar:
        # Choose random actions for local envs that still need data.
        for idx, env in enumerate(env_names):
            if replay_buffer[env].ready():
                # NOTE(review): the model is put in eval mode here and nothing
                # in this function switches it back before the training step
                # below — presumably handled inside the model; confirm.
                mixture_world_model.eval()
                pass
            else:
                if model_device == replay_buffer[env].obs_buffer.device:
                    action[env] = vec_env[env].action_space.sample()

        # Step only the local envs whose buffer is not ready yet.
        for env in env_names:
            if replay_buffer[env].ready():
                pass
            else:
                if model_device == replay_buffer[env].obs_buffer.device:
                    obs[env], reward[env], done[env], truncated[env], info[env] = vec_env[env].step(action[env])
                    # done_flag is computed but not read in this function.
                    done_flag[env] = np.logical_or(done[env], truncated[env])
                    # Stored termination flag is "done OR life lost".
                    replay_buffer[env].append(current_obs[env], action[env], reward[env], np.logical_or(done[env], info[env]["life_loss"]))

                    # update current_obs and current_info

                    current_obs[env] = obs[env]
                    current_info[env] = info[env]

                # <<< sample part
        # Videos are logged on the same cadence as checkpoints.
        if total_steps % (save_every_steps//num_envs) == 0:
            world_log_video = True
        else:
            world_log_video = False
        # final_loss = {}
        # losses = []
        # World-model warm-up step for every local env with a ready buffer.
        for env in env_names:
            if replay_buffer[env].ready() and \
                total_steps % (train_dynamics_every_steps//num_envs) == 0:
                if model_device == replay_buffer[env].obs_buffer.device:
                    train_world_model_step_warm_up(
                            replay_buffer=replay_buffer,
                            mixture_world_model=mixture_world_model,
                            env=env,
                            batch_size=batch_size,
                            demonstration_batch_size=demonstration_batch_size,
                            batch_length=batch_length,
                            mode='warm_up',
                            log_video=world_log_video,
                            logger=logger
                    )


        # Sanity check: broadcast rank 0's copy of the shared transformer
        # parameters and assert this rank holds the same values.
        for name, param in mixture_world_model.module.shared_world_model.storm_transformerwoaction.named_parameters():
            tensor = param.data.clone()
            dist.broadcast(tensor, src=0)
            assert torch.allclose(param.data, tensor), f"{name} dif rank {rank}"
        # Periodic warm-up checkpoint, written by rank 0 only.
        if rank == 0 and total_steps % (save_every_steps//num_envs) == 0:
            torch.save(mixture_world_model.module.state_dict(), f"{logdir}/ckpt/{envs_str}/warm_up_mixture_world_model_{total_steps}.pth")

    # Drop the local references to the per-env scratch dictionaries; the
    # vector envs themselves are handed back to the caller.
    del obs
    del reward
    del done
    del truncated
    del info
    del done_flag
    del current_obs
    del current_info
    del action
    return vec_env
    

def gather_full_gradvector_dict(local_gradvector_dict, world_size, device):
    """Collect every rank's gradient-vector dict and merge them on ``device``.

    Args:
        local_gradvector_dict: This rank's mapping from key to tensor.
        world_size: Number of ranks taking part in the gather.
        device: Device each gathered tensor is moved to.

    Returns:
        dict: Union of all ranks' entries. When several ranks provide the same
        key, the entry from the highest rank wins.
    """
    per_rank = [None] * world_size
    torch.distributed.all_gather_object(per_rank, local_gradvector_dict)

    merged = {}
    for rank_dict in per_rank:
        merged.update({key: vector.to(device) for key, vector in rank_dict.items()})
    return merged

def gather_dict_across_ranks(local_dict):
    """Gather ``local_dict`` from all ranks and merge them into one dict.

    Requires an initialised process group. On key collisions the entry from
    the highest rank wins, because dicts are merged in rank order.

    Returns:
        dict: The merged dictionary.
    """
    assert dist.is_initialized(), "Distributed not initialized"
    per_rank = [None] * dist.get_world_size()
    dist.all_gather_object(per_rank, local_dict)

    merged = {}
    for rank_dict in per_rank:
        merged.update(rank_dict)
    return merged
def get_grad_sim_weights(grad_vectors, env, model_device, eps=1e-8, T=0.1):
    """Derive per-task weights from pairwise gradient cosine similarity.

    Args:
        grad_vectors: Mapping from task name to a 1-D gradient tensor.
        env: Task name; it must be a key of ``grad_vectors`` (``ValueError``
            otherwise). It does not otherwise influence the result.
        model_device: Device the computation runs on.
        eps: Added to the norm product to avoid division by zero.
        T: Temperature dividing each row's mean similarity.

    Returns:
        tuple: ``(weights, reg_term)``. ``weights`` is a detached 1-D tensor
        with one entry per task (in dict order), the softmax of
        ``exp(mean_similarity / T)``; the mean is taken over the full row,
        including the zero diagonal. ``reg_term`` is ``0.1 / n_tasks`` times
        the sum of the absolute values of the negative similarities over
        unordered pairs; it is the float ``0.0`` when no pair is negative.
    """
    task_names = list(grad_vectors.keys())
    task_names.index(env)  # validates membership; raises ValueError if absent
    n_tasks = len(task_names)
    grads = [grad_vectors[name].to(model_device) for name in task_names]

    # Off-diagonal cosine similarities; the diagonal stays at zero.
    similarity = torch.zeros((n_tasks, n_tasks)).to(model_device)
    for row, g_row in enumerate(grads):
        for col, g_col in enumerate(grads):
            if row == col:
                continue
            similarity[row, col] = torch.dot(g_row, g_col) / (
                g_row.norm() * g_col.norm() + eps
            )

    row_scores = [torch.exp(torch.mean(similarity[row]) / T) for row in range(n_tasks)]
    weights = torch.softmax(torch.stack(row_scores), dim=0)

    # Penalise conflicting (negatively aligned) task pairs, each pair once.
    reg_term = 0.0
    for row in range(n_tasks):
        for col in range(row + 1, n_tasks):
            entry = similarity[row, col]
            if entry < 0:
                reg_term += torch.abs(entry)
    reg_term = (1/n_tasks)*0.1*reg_term
    return weights.detach(), reg_term

def update_task_weights(env_names, w_k, adaptive_weights, model_device, T=0.5):
    """Recompute per-task weights as a scaled, temperature-controlled softmax.

    For every env the weight is
    ``len(env_names) * exp((w_k[env] - max_w) / T) / sum_exp``, clamped below
    at ``1e-6``. Subtracting the maximum keeps the exponentials bounded.

    Args:
        env_names: Task names to weight.
        w_k: Mapping from task name to a scalar tensor score.
        adaptive_weights: Dict updated in place with the new weights.
        model_device: Device the computation runs on.
        T: Softmax temperature.

    Returns:
        dict: The same ``adaptive_weights`` object, after the update.
    """
    max_w = max(w_k[env].to(model_device) for env in env_names)

    # Compute each shifted exponential once and reuse it for both the
    # normaliser and the final weight.
    exp_terms = [torch.exp((w_k[env].to(model_device) - max_w) / T) for env in env_names]

    sum_exp_weights = torch.tensor(0.0, device=model_device)
    for exp_term in exp_terms:
        sum_exp_weights += exp_term
    sum_exp_weights = torch.clamp(sum_exp_weights, min=1e-10)

    for env, exp_term in zip(env_names, exp_terms):
        weight = len(env_names) * exp_term / sum_exp_weights
        adaptive_weights[env] = torch.clamp(weight, min=1e-6)
    return adaptive_weights
  
def get_weights(weights_vectors, env, model_device, eps=1e-6):
    """Return the normalised weight of ``env`` among all tasks.

    The result is ``n_tasks * weights_vectors[env] / sum(all weights)``, so
    the weights of all tasks average to one.

    Args:
        weights_vectors: Mapping from task name to a scalar tensor.
        env: Task whose weight is requested; must be a key of
            ``weights_vectors`` (``KeyError`` otherwise).
        model_device: Device the computation runs on.
        eps: Unused; kept so existing callers passing it keep working.

    Returns:
        torch.Tensor: The normalised weight for ``env``.
    """
    sum_weights = torch.tensor(0.0, device=model_device)
    for value in weights_vectors.values():
        sum_weights += value.to(model_device)
    # Index with the requested env. The previous code reused the loop
    # variable and therefore always returned the weight of the last key.
    return len(weights_vectors) * weights_vectors[env].to(model_device) / sum_weights
 
 
def broadcast_model_parameters(model, src=0):
    """Overwrite this rank's model state with the values held by rank ``src``.

    Every tensor in ``model.state_dict()`` is broadcast, so persistent buffers
    are synchronised as well as parameters.
    """
    for tensor in model.state_dict().values():
        dist.broadcast(tensor.data, src)

def build_mixture_world_model(conf, env_names, action_dim, rank):
    """Build the mixture world model on ``cuda:{rank}`` and wrap it in DDP.

    Hyper-parameters are read from ``conf.Models.WorldModel``. After a
    barrier, rank 0's state is broadcast so that all ranks start from the
    same values.

    Returns:
        DistributedDataParallel: The wrapped model (``broadcast_buffers=False``,
        ``find_unused_parameters=True``).
    """
    wm_conf = conf.Models.WorldModel
    world_model = MixtureWorldModel(
        in_channels=wm_conf.InChannels,
        env_names=env_names,
        if_harmony=wm_conf.if_harmony,
        stoch_dim=wm_conf.StochDim,
        action_dim=action_dim,
        task_dim=wm_conf.TaskDim,
        n_experts=wm_conf.n_experts,
        n_activate_experts=wm_conf.n_activate_experts,
        n_clusters=wm_conf.n_clusters,
        transformer_max_length=wm_conf.TransformerMaxLength,
        transformer_hidden_dim=wm_conf.TransformerHiddenDim,
        transformer_num_layers=wm_conf.TransformerNumLayers,
        expert_num_layers=wm_conf.ExpertNumLayers,
        transformer_num_heads=wm_conf.TransformerNumHeads,
    )
    world_model = world_model.to(torch.device(f"cuda:{rank}"))

    # Wait for every rank to finish construction before syncing from rank 0.
    dist.barrier()
    broadcast_model_parameters(world_model, src=0)
    return DDP(
        world_model,
        device_ids=[rank],
        output_device=rank,
        broadcast_buffers=False,
        find_unused_parameters=True,
    )
 

def build_agent(conf, env_names, action_dim, rank):
    """Build the actor-critic agent on ``cuda:{rank}`` and wrap it in DDP.

    The agent's input feature size is
    ``StochDim * StochDim + TransformerHiddenDim + TaskDim`` taken from the
    world-model config. After a barrier, rank 0's state is broadcast so that
    all ranks start from the same values.

    Returns:
        DistributedDataParallel: The wrapped agent (``broadcast_buffers=False``,
        ``find_unused_parameters=True``).
    """
    wm_conf = conf.Models.WorldModel
    agent_conf = conf.Models.Agent
    feature_dim = wm_conf.StochDim*wm_conf.StochDim+wm_conf.TransformerHiddenDim+wm_conf.TaskDim

    actor_critic = agents.ActorCriticAgent(
        feat_dim=feature_dim,
        env_names=env_names,
        n_clusters=wm_conf.n_clusters,
        num_layers=agent_conf.NumLayers,
        hidden_dim=agent_conf.HiddenDim,
        action_dim=action_dim,
        gamma=agent_conf.Gamma,
        lambd=agent_conf.Lambda,
        entropy_coef=agent_conf.EntropyCoef,
    )
    actor_critic = actor_critic.to(torch.device(f"cuda:{rank}"))

    # Wait for every rank to finish construction before syncing from rank 0.
    dist.barrier()
    broadcast_model_parameters(actor_critic, src=0)
    return DDP(
        actor_critic,
        device_ids=[rank],
        output_device=rank,
        broadcast_buffers=False,
        find_unused_parameters=True,
    )

def get_specified_gpus():
    """List the CUDA device strings this process may use.

    When ``CUDA_VISIBLE_DEVICES`` is set, one ``"cuda:i"`` entry is produced
    per non-empty comma-separated item (indices are renumbered from 0).
    Otherwise the count comes from ``torch.cuda.device_count()``.

    Returns:
        list[str]: e.g. ``["cuda:0", "cuda:1"]``; empty when no device is
        visible.
    """
    visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    if visible is not None:
        # "".split(",") yields [""], so an empty variable (GPUs hidden) used
        # to be counted as one device. Count only real entries.
        n_devices = sum(1 for item in visible.split(',') if item.strip())
    else:
        n_devices = torch.cuda.device_count()
    return [f"cuda:{i}" for i in range(n_devices)]
def setup(rank, world_size):
    """Initialise the NCCL process group for this rank.

    Sets the rendezvous address to ``localhost:1224`` through environment
    variables, selects CUDA device ``rank`` and joins the process group.
    """
    os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '1224'})
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    

def cleanup():
    """Tear down the process group of the current process."""
    dist.destroy_process_group()



def train(rank, world_size, replay_buffer, args, conf):
    """Worker entry point for one rank: warm-up followed by joint training.

    Steps, in order: join the process group, seed NumPy/Torch, create the
    logger, build the DDP-wrapped agent and world model, run the warm-up
    phase, gather the per-rank ``gradvector`` dicts and derive cluster
    indices from them, run joint training, and leave the process group.

    Args:
        rank: Index of this process; also the CUDA device index.
        world_size: Total number of processes.
        replay_buffer: Mapping from env name to its replay buffer.
        args: Parsed CLI arguments (``env_names``, ``seed``, ``logdir``,
            ``config_path`` are read here).
        conf: Loaded configuration object.
    """
    setup(rank, world_size)
    env_names = args.env_names.split(",")
    print(f"Rank {rank} using GPU {torch.cuda.current_device()}")
    # set seed
    seed_np_torch(seed=args.seed)

    # Run names and Gymnasium ids, index-aligned with env_names.
    n_values = [f"{env}-life_done-wm_2L512D8H-100k-seed{args.seed}" for env in env_names]
    env_values = [f"ALE/{env}-v5" for env in env_names]

    envs_str = "all_atari_tasks"
    logdir = args.logdir
    logger = Logger(path=f"{logdir}/runs/{envs_str}/{rank}")

    if rank == 0:
        print(colorama.Fore.RED + str(args) + colorama.Style.RESET_ALL)
        # One copy is enough; every rank used to write the same file.
        shutil.copy(args.config_path, f"{args.logdir}/runs/all_atari_tasks/config.yaml")

    # Probe each environment for its action-space size. The value from the
    # last environment is used for both models. Close every probe env so the
    # emulator instances are not leaked.
    for env_id in env_values:
        dummy_env = build_single_env(env_id, conf.BasicSettings.ImageSize, seed=0)
        current_action_dim = dummy_env.action_space.n
        dummy_env.close()
    action_dim = current_action_dim
    agent = build_agent(conf, env_names, current_action_dim, rank)
    mixture_world_model = build_mixture_world_model(conf, env_names, action_dim, rank)

    demonstration_batch_size = (
        conf.JointTrainAgent.DemonstrationBatchSize if conf.JointTrainAgent.UseDemonstration else 0
    )

    vec_env = joint_warm_up_mixture_world_model_agent(
            env_names=env_names,
            env_values=env_values,
            n_values=n_values,
            num_envs=conf.JointTrainAgent.NumEnvs,
            warm_up_steps=conf.JointTrainAgent.WarmupSteps,
            image_size=conf.BasicSettings.ImageSize,
            replay_buffer=replay_buffer,
            mixture_world_model=mixture_world_model,
            train_dynamics_every_steps=conf.JointTrainAgent.TrainDynamicsEverySteps,
            batch_size=conf.JointTrainAgent.BatchSize,
            demonstration_batch_size=demonstration_batch_size,
            batch_length=conf.JointTrainAgent.BatchLength,
            save_every_steps=conf.JointTrainAgent.SaveEverySteps,
            seed=args.seed,
            rank=rank,
            logdir=logdir,
            logger=logger)

    # Merge every rank's gradient vectors, then let the world model compute
    # cluster indices and hand them to the agent.
    local_device = torch.cuda.current_device()
    cluster_vector = gather_full_gradvector_dict(mixture_world_model.module.gradvector, dist.get_world_size(), local_device)
    mixture_world_model.module.gradvector = cluster_vector
    mixture_world_model.module.route_cluster_index()
    agent.module.initial_cluster(mixture_world_model.module.cluster_indices)

    joint_train_mixture_world_model_agent(
            env_names=env_names,
            env_values=env_values,
            vec_env=vec_env,
            num_envs=conf.JointTrainAgent.NumEnvs,
            max_steps=conf.JointTrainAgent.SampleMaxSteps,
            image_size=conf.BasicSettings.ImageSize,
            replay_buffer=replay_buffer,
            mixture_world_model=mixture_world_model,
            agent=agent,
            train_dynamics_every_steps=conf.JointTrainAgent.TrainDynamicsEverySteps,
            train_agent_every_steps=conf.JointTrainAgent.TrainAgentEverySteps,
            batch_size=conf.JointTrainAgent.BatchSize,
            demonstration_batch_size=demonstration_batch_size,
            batch_length=conf.JointTrainAgent.BatchLength,
            imagine_batch_size=conf.JointTrainAgent.ImagineBatchSize,
            imagine_demonstration_batch_size=conf.JointTrainAgent.ImagineDemonstrationBatchSize if conf.JointTrainAgent.UseDemonstration else 0,
            imagine_context_length=conf.JointTrainAgent.ImagineContextLength,
            imagine_batch_length=conf.JointTrainAgent.ImagineBatchLength,
            save_every_steps=conf.JointTrainAgent.SaveEverySteps,
            seed=args.seed,
            rank=rank,
            logdir=logdir,
            logger=logger)
    cleanup()

def get_envs_for_rank(rank, gpu_list, replay_buffer):
    """Return the env names whose replay buffer lives on this rank's GPU.

    Args:
        rank: Index into ``gpu_list``.
        gpu_list: Device strings such as ``"cuda:0"``.
        replay_buffer: Mapping from env name to a buffer exposing
            ``obs_buffer.device``.

    Returns:
        list: Env names, in the mapping's iteration order, whose
        ``str(obs_buffer.device)`` equals ``gpu_list[rank]``.
    """
    target_device = gpu_list[rank]
    return [
        name
        for name, buf in replay_buffer.items()
        if str(buf.obs_buffer.device) == target_device
    ]

if __name__ == "__main__":
    # ignore warnings
    import warnings
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    import torch.multiprocessing as mp
    mp.set_start_method('fork', force=True)

    # parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("-seed", type=int, required=True)
    parser.add_argument("-config_path", type=str, required=True)
    parser.add_argument("-logdir", type=str, required=True)
    parser.add_argument("-env_names", type=str, required=True)
    parser.add_argument("-n_gpus", type=int, default=torch.cuda.device_count())
    args = parser.parse_args()

    env_names = args.env_names.split(",")
    conf = load_config(args.config_path)
    world_size = args.n_gpus

    gpu_list = get_specified_gpus()
    if not gpu_list:
        # Previously this surfaced as a bare ZeroDivisionError in the modulo below.
        raise SystemExit("No CUDA devices are visible; cannot place replay buffers.")
    # Rank r only trains envs whose buffer is on cuda:r, so buffers must be
    # placed on devices that will actually host a rank. With the default
    # -n_gpus (all visible devices) this keeps the full list.
    gpu_list = list(reversed(gpu_list[:max(world_size, 1)]))

    # Assign the environments' replay buffers round-robin over those devices.
    replay_buffer = {}
    for i, env in enumerate(env_names):
        assigned_gpu = gpu_list[i % len(gpu_list)]
        replay_buffer[env] = ReplayBuffer(
            obs_shape=(conf.BasicSettings.ImageSize, conf.BasicSettings.ImageSize, 3),
            num_envs=conf.JointTrainAgent.NumEnvs,
            max_length=conf.JointTrainAgent.BufferMaxLength,
            warmup_length=conf.JointTrainAgent.BufferWarmUp,
            store_on_gpu=conf.BasicSettings.ReplayBufferOnGPU
        )
        replay_buffer[env].to(assigned_gpu)

    if world_size >= 1:
        mp.spawn(train, args=(world_size, replay_buffer, args, conf), nprocs=world_size, join=True)
    else:
        # train() takes the replay buffers as its third argument; the old
        # call omitted them and failed with a TypeError.
        train(0, 1, replay_buffer, args, conf)

    
