import datetime
import os
import pprint
import time
import sys
import threading
import torch as th
from types import SimpleNamespace as SN
from utils.logging import Logger
from utils.timehelper import time_left, time_str
from os.path import dirname, abspath
import select
import re
import random

from learners import REGISTRY as le_REGISTRY
from runners import REGISTRY as r_REGISTRY
from controllers import REGISTRY as mac_REGISTRY
from components.episode_buffer import ReplayBuffer, EpisodeBatch
from components.transforms import OneHot

from smac.env import StarCraft2Env

from components.reward_model import reward_model, team_reward_model
from components.get_preference import get_preference
from components.prompt_generator import get_prompt, get_traj_placeholder, get_step_placeholder

def get_agent_own_state_size(env_args):
    """Return the length of an agent's own-state feature block for a SMAC map.

    A throwaway ``StarCraft2Env`` is built from ``env_args`` only to read its
    ``shield_bits_ally`` and ``unit_type_bits`` attributes. The result is the
    fixed constant 4 plus those two bit counts.

    Only used by the qatten parameter setting.
    """
    probe_env = StarCraft2Env(**env_args)
    own_state_size = 4  # base own-state features (constant kept from the original)
    own_state_size += probe_env.shield_bits_ally
    own_state_size += probe_env.unit_type_bits
    return own_state_size

def run(_run, _config, _log):
    """Experiment entry point: set up logging, run training, then force-exit.

    Steps performed here:
      1. Sanity-check ``_config`` with ``args_sanity_check`` (not shown in this
         chunk) and wrap the result in a ``SimpleNamespace`` called ``args``.
      2. Log the full parameter dict. Optionally set up TensorBoard and
         Weights & Biases logging; the W&B group name is built from many
         config flags. Finally hand ``_run`` to ``logger.setup_sacred``.
      3. Call ``run_sequential`` to do the actual training/evaluation.
      4. Join every non-main thread (1 s timeout each), then end the process
         with ``os._exit``.

    Args:
        _run: run object passed to ``logger.setup_sacred`` (presumably the
            sacred Run -- confirm against the caller).
        _config: experiment configuration dict.
        _log: logger used for the console messages below.
    """

    # check args sanity
    _config = args_sanity_check(_config, _log)

    args = SN(**_config)
    args.device = "cuda" if args.use_cuda else "cpu"

    # setup loggers
    logger = Logger(_log)
    # For the "gfootball" env, forward the configured agent count into env_args.
    if args.env == "gfootball" :
        args.env_args['num_agents'] = args.num_agents
        
    _log.info("Experiment Parameters:")
    experiment_params = pprint.pformat(_config,
                                       indent=4,
                                       width=1)
    _log.info("\n\n" + experiment_params + "\n")

    # configure tensorboard logger
    # unique_token = "<name>__<timestamp>"; reused for the TB directory and the W&B run name.
    unique_token = "{}__{}".format(args.name, datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    args.unique_token = unique_token
    if args.use_tensorboard:
        # results/tb_logs lives three directory levels above this file.
        tb_logs_direc = os.path.join(dirname(dirname(dirname(abspath(__file__)))), "results", "tb_logs")
        # NOTE(review): .format() is applied to the joined path, so any '{' or '}'
        # in tb_logs_direc would break here; joining unique_token directly is safer.
        tb_exp_direc = os.path.join(tb_logs_direc, "{}").format(unique_token)
        logger.setup_tb(tb_exp_direc)
    if args.use_wandb : 
        # Build the W&B group name: a prefix chosen by env / preference source,
        # then one suffix per enabled experiment flag.
        if args.env == "gfootball" :
            group_name = '{}_{}'.format(args.env_args['map_name'],args.reward_model_info)
        else :
            if args.use_human_pref :
                group_name = 'HP_{}_{}_{}_{}'.format(args.reward_model_info,args.wandb_group_info, args.env_args['map_name'], args.name)
            else :
                if args.use_llm_pref :
                    # Prefix encodes which comparisons are used: p_B = both, P_T = team only, P_I = individual only.
                    # NOTE(review): if compare_agents and compare_team are both False, no branch
                    # binds group_name and the next access raises UnboundLocalError.
                    if args.compare_agents and args.compare_team:
                        group_name = 'p_B_{}_{}_{}_{}_{}_{}'.format(args.reward_model_info,args.wandb_group_info,args.LLM_model, args.env_args['map_name'],args.stop_epoch, args.name)
                    elif args.compare_agents == False and args.compare_team :
                        group_name = 'P_T_{}_{}_{}_{}_{}_{}'.format(args.reward_model_info,args.wandb_group_info, args.LLM_model,args.env_args['map_name'],args.stop_epoch, args.name)
                    elif args.compare_agents and args.compare_team == False :
                        group_name = 'P_I_{}_{}_{}_{}_{}_{}'.format(args.reward_model_info,args.wandb_group_info, args.LLM_model,args.env_args['map_name'],args.stop_epoch, args.name)                    
                    # NOTE(review): these two templates put the underscore in front of the
                    # existing name ('_<name>_team'), unlike the '<name>_suffix' pattern used
                    # below -- confirm this is intended.
                    if args.use_team_reward :
                        group_name = '_{}_team'.format(group_name)
                    if args.use_ori_reward :
                        group_name = '_{}_ori'.format(group_name)
                else :
                    group_name = '{}_{}_{}'.format(args.wandb_group_info, args.env_args['map_name'], args.name)
        if args.scalability_test :
            group_name = group_name+'_scalable'
        if args.test_only :
            group_name = group_name+'_test_only'
        if args.use_extrinsic_reward :
            group_name = group_name+'_ext'
        if args.save_test_data :
            group_name = group_name+'_save_test'
        if args.use_kendall :
            group_name = group_name+'_kendall'
        else :
            group_name = group_name+'_wo_kendall'
        if args.use_intrinsic_reward_as_contribution :
            group_name = group_name +'_IQL'
        if args.replay_small :
            group_name = group_name+'_1000'
        if args.sc2_random :
            group_name = group_name+'_Rand'
        group_name = group_name+'_{}'.format(str(args.step_weight))
        group_name = group_name+'_reward_model_num_{}'.format(args.n_reward_functions)
        scenario_name = '{}_{}'.format(args.env_args['map_name'],unique_token)
        project_name = args.project_name
        logger.setup_wandb(project_name,group_name,scenario_name)

    # sacred is on by default
    logger.setup_sacred(_run)

    # Run and train
    run_sequential(args=args, logger=logger)

    # Clean up after finishing
    print("Exiting Main")

    # Give lingering threads a chance to finish; each join waits at most 1 s.
    print("Stopping all threads")
    for t in threading.enumerate():
        if t.name != "MainThread":
            print("Thread {} is alive! Is daemon: {}".format(t.name, t.daemon))
            t.join(timeout=1)
            print("Thread joined")

    print("Exiting script")

    # Making sure framework really exits
    # os._exit ends the process immediately, without running cleanup handlers
    # or flushing stdio buffers, so any still-running threads cannot keep it alive.
    os._exit(os.EX_OK)

def _build_inputs(args, batch, max_length,i):
    # Assumes homogenous agents with flat observations.
    # Other MACs might want to e.g. delegate building inputs to each agent
    #bs = batch.batch_size
    bs = 1
    d = batch["obs"].shape[3] + batch["actions_onehot"].shape[3]+args.n_agents
    inputs = th.zeros((1,max_length,args.n_agents,d))
    for t in range(max_length) :
        input_t = []
        input_t.append(batch["obs"][i:i+1,t])  # b1av
        if args.obs_last_action:
            if t == 0:
                input_t.append(th.zeros_like(batch["actions_onehot"][i:i+1,t]))
            else:
                input_t.append(batch["actions_onehot"][i:i+1,t-1])
        if args.obs_agent_id:
            input_t.append(th.eye(args.n_agents, device=batch.device).unsqueeze(0).expand(bs, -1, -1))
        input_t = th.cat([x.reshape(bs, args.n_agents, -1) for x in input_t], dim=-1)
        inputs[:,t] = input_t
    return inputs

def evaluate_sequential(args, runner):
    """Run ``args.test_nepisode`` evaluation episodes, then shut the env down.

    Every episode runs in test mode with ``kendall_need=False``. If
    ``args.save_replay`` is set, a replay is saved before the environment is
    closed.
    """
    n_episodes = args.test_nepisode
    for _episode in range(n_episodes):
        runner.run(test_mode=True, kendall_need=False)

    # Optionally persist a replay while the environment is still open.
    should_save = args.save_replay
    if should_save:
        runner.save_replay()

    runner.close_env()

def write_process(processing_path,
                  data) :
    """Write the four process-status counters to ``processing_path``.

    The file content is ``#a#b#c#d#``, where a..d are ``data[0]``..``data[3]``
    (entry order: preference_done, reward_updating_done, preference_doing,
    reward_updating_doing). ``read_process`` parses this format back.

    Other processes poll this file. If it were truncated and rewritten in
    place, a reader could see an empty file and then crash when indexing the
    parsed result. To avoid that, the content is written to a per-process
    temp file next to the target, then moved into place with ``os.replace``,
    which is atomic on POSIX.
    """
    processing_data = '#{}#{}#{}#{}#'.format(data[0],data[1],data[2],data[3])
    tmp_path = '{}.{}.tmp'.format(processing_path, os.getpid())
    try:
        with open(tmp_path, 'w') as tmp_file:
            tmp_file.write(processing_data)
        os.replace(tmp_path, processing_path)
    except OSError:
        # Don't leave a stray temp file behind if writing or replacing failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
        
def read_process(processing_path) :
    """Return every (optionally negative) integer in the status file, in order.

    The file normally has the ``#a#b#c#d#`` form produced by
    ``write_process``. Index meaning of the returned list:
        0 : preference_done
        1 : reward_updating_done
        2 : preference_doing
        3 : reward_updating_doing
    """
    with open(processing_path,'r') as status_file :
        contents = status_file.read()

    return [int(token) for token in re.findall(r'-?\d+', contents)]

def wait_or_go(processing_path,num,seq) :
    """Decide whether this process should perform stage ``num`` for round ``seq``.

    ``num`` selects a stage counter (0 = preferences, 1 = reward-model
    update). Its "doing" flag sits at index ``num + 2`` of the status file.

    Returns True only when the stage's done-counter equals ``seq - 1`` and its
    doing-flag is 0. In that case the flag is set to 1 and written back,
    claiming the work for this process. If the counter equals ``seq - 1`` but
    the flag is 1, the function polls the file every 30 s until the flag
    drops back to 0, then returns False. In every other case it returns
    False immediately.

    NOTE(review): the read-then-write claim is not atomic across processes.
    """
    status = read_process(processing_path)
    doing_idx = num + 2

    if status[num] != seq - 1:
        # This stage is either already finished for `seq` or not reached yet.
        return False

    if status[doing_idx] == 0:
        # Nobody is working on it: claim it.
        status[doing_idx] = 1
        write_process(processing_path, status)
        return True

    if status[doing_idx] == 1:
        # Someone else is working on it: block until they clear the flag.
        while read_process(processing_path)[doing_idx] != 0:
            time.sleep(30)

    return False
    
    
def diff_ind_team(epi,mean_save_data,threshold=1.0,kendal=False,after_update_num=False) :
    """Pick episodes from a batch whose mean Kendall's tau is low.

    With ``kendal=True``, each episode ``i`` gets the mean of
    ``epi["kendalltau"][i]`` up to and including its first step with
    ``terminated == 1``. If no step is flagged (e.g. a truncated episode),
    the mean covers all stored steps; previously ``.item()`` raised in that
    case and whenever more than one step was flagged. Episode ``i`` is
    selected when:
      * ``threshold == 1.0``: its mean is ``<= mean_save_data``;
      * otherwise: its mean is ``<= threshold``, or ``<= 0.7`` when
        ``after_update_num`` is truthy.

    With ``kendal=False``, three random episode indices (drawn with
    replacement) are returned with zero placeholder differences. Indices lie
    in ``[0, batch_size)``, where batch size is ``epi["int_ind_reward"].shape[0]``.
    The previous hard-coded ``randint(0, 7)`` was only correct for batches of 8.

    Returns:
        ``(seq, differences, overall_mean)``: selected episode indices, their
        mean-tau tensors, and the mean over all episodes as a Python float
        (``0`` in the random branch, ``nan`` for an empty batch).
    """
    batch_size = epi["int_ind_reward"].shape[0]

    if not kendal :
        seq = [random.randint(0, batch_size - 1) for _ in range(3)]
        return seq,[th.zeros(1),th.zeros(1),th.zeros(1)],0

    differences = list()
    seq = list()
    all_means = list()
    for i in range(batch_size) :
        tau = epi["kendalltau"][i]
        terminated = epi["terminated"][i]
        end_steps = th.where(terminated == 1)[0]
        if end_steps.numel() > 0 :
            end_point = end_steps[0].item()
        else :
            # No terminal flag in the stored steps: average over all of them.
            end_point = tau.shape[0] - 1
        kendal_mean = th.mean(tau[:end_point + 1])
        print(kendal_mean)
        if threshold == 1.0 :
            selected = kendal_mean.item() <= mean_save_data
        else :
            selected = kendal_mean.item() <= threshold or (kendal_mean.item() <= 0.7 and after_update_num)
        if selected :
            differences.append(kendal_mean)
            seq.append(i)
        all_means.append(kendal_mean)

    overall_mean = th.stack(all_means).mean().item() if all_means else float('nan')
    return seq,differences,overall_mean
        
    
    
def run_sequential(args, logger):
    
    print('!!!!!!!!!!!!!!!!!!!! start !!!!!!!!!!!!!!!!!!!!')
    if args.use_team_reward == False :
        if args.compare_agents and args.compare_team :
            model_info_num = 0 
        elif args.compare_agents == False and args.compare_team :
            model_info_num = 1 
        elif args.compare_agents and args.compare_team == False :
            model_info_num = 2 
    else :
        if args.compare_agents and args.compare_team :
            model_info_num = 3 
        elif args.compare_agents == False and args.compare_team :
            model_info_num = 4 
        elif args.compare_agents and args.compare_team == False :
            model_info_num = 5 
            
    args.model_info_num = model_info_num 
    args.map_name = args.env_args['map_name']
    # Set up schemes and groups here
    runner = r_REGISTRY[args.runner](args=args, logger=logger)
    env_info = runner.get_env_info()
    
    args.n_agents = env_info["n_agents"]
    args.n_actions = env_info["n_actions"]
    args.state_shape = env_info["state_shape"]
    args.obs_shape = env_info["obs_shape"]
    
    args.accumulated_episodes = getattr(args, "accumulated_episodes", None)

    # Init runner so we can get env info

    if getattr(args, 'agent_own_state_size', False):
        args.agent_own_state_size = get_agent_own_state_size(args.env_args)
        
    
    # Default/Base scheme
    scheme = {
        "state": {"vshape": env_info["state_shape"]},
        "obs": {"vshape": env_info["obs_shape"], "group": "agents"},
        "actions": {"vshape": (1,), "group": "agents", "dtype": th.long},
        "avail_actions": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int},
        "probs": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.float},
        "reward": {"vshape": (1,)},
        "ori_reward":{"vshape":(1,)},
        "int_ind_reward":{"vshape":(1,)},
        "int_ind_reward_std":{"vshape":(1,)},
        "int_team_reward":{"vshape":(1,)},
        "int_team_reward_std":{"vshape":(1,)},
        "ind_reward":{"vshape":(1,), "group": "agents"},
        "ind_terminated":{"vshape":(1,),"group":"agents"},
        "terminated": {"vshape": (1,), "dtype": th.uint8},
        "kendalltau" :{"vshape":(1,)},
        "int_ind_reward_all":{"vshape":(args.n_reward_functions,args.n_agents,)} 
    }
    groups = {
        "agents": args.n_agents
    }
    preprocess = {
        "actions": ("actions_onehot", [OneHot(out_dim=args.n_actions)])
    }

    buffer = ReplayBuffer(scheme, groups, args.buffer_size, env_info["episode_limit"] + 1,
                          preprocess=preprocess,
                          device="cpu" if args.buffer_cpu_only else args.device)
    # Setup multiagent controller here
    mac = mac_REGISTRY[args.mac](buffer.scheme, groups, args)

    # Give runner the scheme
    runner.setup(scheme=scheme, groups=groups, preprocess=preprocess, mac=mac)

    # Learner
    learner = le_REGISTRY[args.learner](mac, buffer.scheme, logger, args)

    if args.use_cuda:
        learner.cuda()

    if args.checkpoint_path != "":

        timesteps = []
        timestep_to_load = 0

        if not os.path.isdir(args.checkpoint_path):
            logger.console_logger.info("Checkpoint directiory {} doesn't exist".format(args.checkpoint_path))
            return

        # Go through all files in args.checkpoint_path
        for name in os.listdir(args.checkpoint_path):
            full_name = os.path.join(args.checkpoint_path, name)
            # Check if they are dirs the names of which are numbers
            if os.path.isdir(full_name) and name.isdigit():
                timesteps.append(int(name))

        if args.load_step == 0:
            # choose the max timestep
            timestep_to_load = max(timesteps)
        else:
            # choose the timestep closest to load_step
            timestep_to_load = min(timesteps, key=lambda x: abs(x - args.load_step))

        model_path = os.path.join(args.checkpoint_path, str(timestep_to_load))

        logger.console_logger.info("Loading model from {}".format(model_path))
        learner.load_models(model_path)
        runner.t_env = timestep_to_load

        if args.evaluate or args.save_replay:
            evaluate_sequential(args, runner)
            return

    # start training
    episode = 0
    n_epi = 0
    last_test_T = -args.test_interval - 1
    last_log_T = 0
    model_save_time = 0
    epi_threshold = 1000
    start_time = time.time()
    last_time = start_time
    
    logger.console_logger.info("Beginning training for {} timesteps".format(args.t_max))
    
    reward_model_updating_num = 0 
    save_cnt = 0 
    
    processing_path = '{}/{}_{}_{}_{}_{}.txt'.format(args.processing_path,args.LLM_model,args.map_name,args.reward_model_info,reward_model_updating_num,model_info_num)
    preference_done=0
    reward_updating_done=0
    preference_doing=0
    reward_updating_doing=0
    reward_model_update_interval = args.update_reward_model_interval
    save_diff = th.zeros((1000,2))
    gather_threshold = args.gather_threshold
    compare_cnt = 0
    d_mean = 0 
    if args.use_team_reward :
        previous_diff_reward_mean = 0
    else :
        previous_diff_reward_mean = 1
    
    processing_data = '#{}#{}#{}#{}#'.format(preference_done,reward_updating_done,preference_doing,reward_updating_doing)
    if not os.path.exists(processing_path) :
        with open(processing_path,'w') as file :
            file.write(processing_data)
    else :
        process_data = read_process(processing_path)
        preference_done=process_data[0]
        reward_updating_done=process_data[1]
        preference_doing=process_data[2]
        reward_updating_doing=process_data[3]

    if args.test_only :
        updated_reward_path1 = '{}/{}_{}_{}_{}.pt'.format(args.reward_model_path,args.LLM_model,args.map_name,args.reward_model_info,reward_updating_done)
        updated_reward_path2 = '{}/{}_{}_{}_{}_team.pt'.format(args.reward_model_path,args.LLM_model,args.map_name,args.reward_model_info,reward_updating_done)
        runner.get_reward_function(updated_reward_path1,updated_reward_path2)

    update_when = th.zeros((100,2))
    while runner.t_env <= args.t_max :

        # Run for a whole episode at a time
        with th.no_grad():
            if reward_model_updating_num < args.stop_epoch and args.use_kendall :
                episode_batch = runner.run(test_mode=False,kendall_need=True)
            else :
                episode_batch = runner.run(test_mode=False,kendall_need=False)
                
            buffer.insert_episode_batch(episode_batch)                

        if args.save_test_data == False or reward_model_updating_num == 0 :

            ##############################################################################################
            ########################SAVE REPLAY BUFFER ###################################################
            ##############################################################################################
            if args.use_team_reward :
                seq_diff_reward_functions,diff_reward_functions,d_mean = diff_ind_team(episode_batch,previous_diff_reward_mean,threshold=gather_threshold,kendal=True)
            else :
                if args.use_kendall :
                    if reward_model_updating_num != 0 :
                        seq_diff_reward_functions,diff_reward_functions,d_mean = diff_ind_team(episode_batch,previous_diff_reward_mean,threshold=gather_threshold,kendal=True)
                    else :
                        if (n_epi >= reward_model_update_interval) and (save_cnt < args.n_pref) :
                            seq_diff_reward_functions,diff_reward_functions,d_mean = diff_ind_team(episode_batch,previous_diff_reward_mean,threshold=1.0,kendal=True,after_update_num=True)
                        else :
                            seq_diff_reward_functions,diff_reward_functions,d_mean = diff_ind_team(episode_batch,previous_diff_reward_mean,threshold=1.0,kendal=True)
                else :
                    seq_diff_reward_functions,diff_reward_functions,d_mean = diff_ind_team(episode_batch,previous_diff_reward_mean,threshold=gather_threshold,kendal=False)


            if n_epi != 0 :
                gather_threshold = (gather_threshold * (n_epi-args.batch_size_run) + (d_mean*args.batch_size_run) ) / n_epi
            else :
                gather_threshold = d_mean
            print('Epi : {} / Buffer : {} / updating_num : {} / previous mean : {:.3f} / now mean : {:.3f}'.format(episode,save_cnt,reward_model_updating_num,gather_threshold,d_mean))
            

            ##################################################################################################
            ### To compare the cases when Kendall's Tau is used or not #######################################

            if len(seq_diff_reward_functions) != args.batch_size_run and compare_cnt < args.n_pref :
                for i in range(args.batch_size_run) :
                    if i not in seq_diff_reward_functions :
                        compare_cnt+=1
                        if not os.path.exists('{}/{}_{}_{}_{}'.format(args.replay_buffer_save_path,args.LLM_model,args.map_name,args.reward_model_info,reward_model_updating_num)) :
                            os.makedirs('{}/{}_{}_{}_{}'.format(args.replay_buffer_save_path,args.LLM_model,args.map_name,args.reward_model_info,reward_model_updating_num))
                        save_traj_dir = '{}/{}_{}_{}_{}/Comparing_{}_'.format(args.replay_buffer_save_path,args.LLM_model,args.map_name,args.reward_model_info,reward_model_updating_num,compare_cnt)
                        for key in scheme :
                            if not os.path.exists('{}{}.pt'.format(save_traj_dir,key)) :                    
                                if key == 'obs' :
                                    obs_include_previous_action = _build_inputs(args, episode_batch, episode_batch.max_seq_length,i)
                                    th.save(th.unsqueeze(obs_include_previous_action[0],dim=0),'{}{}.pt'.format(save_traj_dir,key))
                                else :
                                    th.save(th.unsqueeze(episode_batch[key][i],dim=0),'{}{}.pt'.format(save_traj_dir,key))

            ##################################################################################################
            ##################################################################################################

            if len(seq_diff_reward_functions) != 0 and reward_model_updating_num < args.stop_epoch and args.test_only == False and save_cnt < args.n_pref : #and (n_epi >= reward_model_update_interval or reward_model_updating_num == 0) : 
                for i in range(len(seq_diff_reward_functions)) :
                    save_diff[save_cnt,0] = save_cnt
                    save_diff[save_cnt,1] = diff_reward_functions[i].item()
                    save_cnt+=1
                    if not os.path.exists('{}/{}_{}_{}_{}'.format(args.replay_buffer_save_path,args.LLM_model,args.map_name,args.reward_model_info,reward_model_updating_num)) :
                        os.makedirs('{}/{}_{}_{}_{}'.format(args.replay_buffer_save_path,args.LLM_model,args.map_name,args.reward_model_info,reward_model_updating_num))
                    save_traj_dir = '{}/{}_{}_{}_{}/{}_'.format(args.replay_buffer_save_path,args.LLM_model,args.map_name,args.reward_model_info,reward_model_updating_num,save_cnt)
                    for key in scheme :
                        if not os.path.exists('{}{}.pt'.format(save_traj_dir,key)) :                    
                            if key == 'obs' :
                                obs_include_previous_action = _build_inputs(args, episode_batch, episode_batch.max_seq_length,seq_diff_reward_functions[i])
                                th.save(th.unsqueeze(obs_include_previous_action[0],dim=0),'{}{}.pt'.format(save_traj_dir,key))
                            else :
                                th.save(th.unsqueeze(episode_batch[key][seq_diff_reward_functions[i]],dim=0),'{}{}.pt'.format(save_traj_dir,key))


            ##############################################################################################
            ######################## Update Reward Function ##############################################
            ##############################################################################################
            #if save_cnt >= reward_model_update_interval or n_epi >= args.maxium_update_num :
            if  (n_epi >= reward_model_update_interval or reward_model_updating_num == 0) and save_cnt >= args.n_pref :
                n_epi = 0 
                if args.epsilon_reset :
                    runner.mac.action_selector.epsilon = args.epsilon_start
                update_when[reward_model_updating_num,0] = episode
                update_when[reward_model_updating_num,1] = runner.t_env
                th.save(update_when,'save_update_log/{}_{}_{}_{}_save_when.pt'.format(args.LLM_model,args.reward_model_info,model_info_num,args.stop_epoch))
                reward_model_updating_num += 1
                th.save(save_diff,'save_update_log/{}_{}_save_diff_{}_{}.pt'.format(args.LLM_model,args.reward_model_info,model_info_num,reward_model_updating_num))
                previous_diff_reward_mean = th.mean(save_diff[:save_cnt,1])
                save_diff = th.zeros((1000,2))

                if reward_model_updating_num == 1 :
                    for reward_num in range(args.n_reward_functions) :
                        save_reward_path = '{}/{}_{}_{}_{}_{}_{}.pt'.format(args.reward_model_path,args.LLM_model,args.map_name,args.reward_model_info,model_info_num,reward_model_updating_num-1,reward_num)
                        th.save(runner.reward_models[reward_num].state_dict(),save_reward_path)
                        if args.use_team_reward :
                            save_reward_path2 = '{}/{}_{}_{}_{}_{}_{}_team.pt'.format(args.reward_model_path,args.LLM_model,args.map_name,args.reward_model_info,model_info_num,reward_model_updating_num-1,reward_num)
                            th.save(runner.team_reward_models[reward_num].state_dict(),save_reward_path2)

                print('########################################')
                print('waiting for training')
                runner.close_env()

                ##############################################################################################
                ################### Preference Setup #########################################################
                ##############################################################################################

                ##############################################################################################
                ### 1. Prompt Generation & Preference gathering ###
                ## traj_preference 
                print('########################################')
                print('Started obtaining preferences(traj) from the LLM.')
                skip_or_go = wait_or_go(processing_path,0,reward_model_updating_num)
                if skip_or_go :
                    if args.compare_team :
                        get_preference(gpt_model=args.LLM_model,
                                       seq=reward_model_updating_num-1,
                                       n_pref=args.n_pref,
                                       step_n_pref=args.step_n_pref,
                                       n_epi=args.n_pref,
                                       traj_or_step='traj',
                                       model_name=args.reward_model_info,
                                       scenario=args.map_name,
                                       n_agents = args.n_agents,
                                       n_repeat=args.n_repeat,
                                       key=args.gpt_api_key,
                                      envs = args.env,
                                      scalability=args.scalability_test)

                #######################################
                ## step_preference 
                    if args.compare_agents :
                        print('########################################')
                        print('Started obtaining preferences(step) from the LLM.')
                        if args.scalability_test :
                            get_preference(gpt_model=args.LLM_model,
                                           seq=reward_model_updating_num-1,
                                           n_pref=args.n_pref,
                                           step_n_pref=args.step_n_pref,
                                           n_epi=args.n_pref,
                                           traj_or_step='step',
                                           model_name=args.reward_model_info,
                                           scenario=args.map_name,
                                           n_agents = args.compare_agents_num,
                                           n_repeat=args.n_repeat,
                                           pref_per_epi = 3,
                                           key=args.gpt_api_key,
                                          envs = args.env,
                                          scalability=args.scalability_test)                            
                        else :
                            get_preference(gpt_model=args.LLM_model,
                                           seq=reward_model_updating_num-1,
                                           n_pref=args.n_pref,
                                           step_n_pref=args.step_n_pref,
                                           n_epi=args.n_pref,
                                           traj_or_step='step',
                                           model_name=args.reward_model_info,
                                           scenario=args.map_name,
                                           n_agents = args.n_agents,
                                           n_repeat=args.n_repeat,
                                           pref_per_epi = 3,
                                           key=args.gpt_api_key,
                                          envs = args.env,
                                          scalability=args.scalability_test)
                    process_data = read_process(processing_path)
                    process_data[2] = 0
                    process_data[0] = reward_model_updating_num
                    write_process(processing_path,process_data)

                ##############################################################################################
                ### 2. Training Reward Model ###
                print('########################################')
                print('Started training reward model')
                skip_or_go = wait_or_go(processing_path,1,reward_model_updating_num)
                if skip_or_go :
                    runner.update_reward_model(model_path=args.reward_model_path,
                                               seq=reward_model_updating_num,
                                               training_steps=args.reward_model_epoch,
                                               info_num = model_info_num,
                                               update_threshold = args.update_threshold
                                              )

                    process_data = read_process(processing_path)
                    process_data[3] = 0
                    process_data[1] = reward_model_updating_num
                    write_process(processing_path,process_data)

                ##############################################################################################
                ### 3. Load Reward Model ###         
                print('########################################')
                print('Started loading reward model')
                while True :
                    updated_reward_path = '{}/{}_{}_{}_{}_{}_'.format(args.reward_model_path,args.LLM_model,args.map_name,args.reward_model_info,model_info_num,reward_model_updating_num)
                    updated_reward_path2 = '{}/{}_{}_{}_{}_{}_'.format(args.reward_model_path,args.LLM_model,args.map_name,args.reward_model_info,model_info_num,reward_model_updating_num)

                    if os.path.exists(updated_reward_path+'0.pt') :
                        runner.get_reward_function(path1=updated_reward_path,path2=updated_reward_path2)
                        runner.open_env()
                        print('########################################')
                        print('ok! good!')
                        break

                ##############################################################################################
                '''
                ### 4. Change all rewards in replay buffer
                prev_replay_memory = buffer.sample(buffer.episodes_in_buffer)
                updated_replay_memory = runner.update_reward_in_replay_memory(prev_replay_memory)
                buffer.update_all(updated_replay_memory,buffer.episodes_in_buffer)
                '''

                #args.update_reward_model_interval += args.update_reward_model_interval
                ##############################################################################################
                if reward_model_updating_num == 1 :
                    reward_model_update_interval = args.update_reward_model_interval
                save_cnt = 0 
                compare_cnt = 0             
            

            
        if buffer.can_sample(args.batch_size) and (reward_model_updating_num >= 1 or args.use_intrinsic_reward_as_contribution or args.ind_q_learning) :
            next_episode = episode + args.batch_size_run
            if args.accumulated_episodes and next_episode % args.accumulated_episodes != 0:
                continue

            episode_sample = buffer.sample(args.batch_size)
            
            # Truncate batch to only filled timesteps
            max_ep_t = episode_sample.max_t_filled()
            episode_sample = episode_sample[:, :max_ep_t]

            if episode_sample.device != args.device:
                episode_sample.to(args.device)

            learner.train(runner.reward_models,episode_sample, runner.t_env, episode)
            del episode_sample

        # Execute test runs once in a while
        n_test_runs = max(1, args.test_nepisode // runner.batch_size)
        if (runner.t_env - last_test_T) / args.test_interval >= 1.0:

            logger.console_logger.info("t_env: {} / {}".format(runner.t_env, args.t_max))
            logger.console_logger.info("Estimated time left: {}. Time passed: {}".format(
                time_left(last_time, last_test_T, runner.t_env, args.t_max), time_str(time.time() - start_time)))
            last_time = time.time()

            last_test_T = runner.t_env
            for _ in range(n_test_runs):
                episode_batch = runner.run(test_mode=True,kendall_need=False)
                
                if args.save_test_data and reward_model_updating_num != 0  :

                    ##############################################################################################
                    ########################SAVE REPLAY BUFFER ###################################################
                    ##############################################################################################
                    if args.use_team_reward :
                        #seq_diff_reward_functions,diff_reward_functions = diff_ind_team(episode_batch,(previous_diff_reward_mean+th.mean(save_diff[:save_cnt+1,1]))/2,threshold=args.gather_threshold)
                        seq_diff_reward_functions,diff_reward_functions = diff_ind_team(episode_batch,previous_diff_reward_mean,threshold=args.gather_threshold,kendal=True)
                    else :
                        #seq_diff_reward_functions,diff_reward_functions = diff_ind_team(episode_batch,(previous_diff_reward_mean+th.mean(save_diff[:save_cnt+1,1]))/2,threshold=args.gather_threshold,kendal=True)
                        if reward_model_updating_num != 0 :
                            seq_diff_reward_functions,diff_reward_functions = diff_ind_team(episode_batch,previous_diff_reward_mean,threshold=args.gather_threshold,kendal=True)
                        else :
                            seq_diff_reward_functions,diff_reward_functions = diff_ind_team(episode_batch,previous_diff_reward_mean,threshold=1.0,kendal=True)

                    print('Epi : {} / Buffer : {} / updating_num : {} / previous mean : {:.3f} / now mean : {:.3f}'.format(episode,save_cnt,reward_model_updating_num,previous_diff_reward_mean,th.mean(save_diff[:save_cnt+1,1])))


                    ###############################################################################################################################################################################################################################
                    ### To compare the cases when Kendall's Tau is used or not ####################################################################################################################################################################

                    if len(seq_diff_reward_functions) != args.batch_size_run and compare_cnt < args.n_pref :
                        for i in range(args.batch_size_run) :
                            if i not in seq_diff_reward_functions :
                                compare_cnt+=1
                                if not os.path.exists('{}/{}_{}_{}_{}'.format(args.replay_buffer_save_path,args.LLM_model,args.map_name,args.reward_model_info,reward_model_updating_num)) :
                                    os.makedirs('{}/{}_{}_{}_{}'.format(args.replay_buffer_save_path,args.LLM_model,args.map_name,args.reward_model_info,reward_model_updating_num))
                                save_traj_dir = '{}/{}_{}_{}_{}/Comparing_{}_'.format(args.replay_buffer_save_path,args.LLM_model,args.map_name,args.reward_model_info,reward_model_updating_num,compare_cnt)
                                for key in scheme :
                                    if not os.path.exists('{}{}.pt'.format(save_traj_dir,key)) :                    
                                        if key == 'obs' :
                                            obs_include_previous_action = _build_inputs(args, episode_batch, episode_batch.max_seq_length,i)
                                            th.save(th.unsqueeze(obs_include_previous_action[0],dim=0),'{}{}.pt'.format(save_traj_dir,key))
                                        else :
                                            th.save(th.unsqueeze(episode_batch[key][i],dim=0),'{}{}.pt'.format(save_traj_dir,key))

                    ##############################################################################################################################################################################################################################
                    ##############################################################################################################################################################################################################################

                    if len(seq_diff_reward_functions) != 0 and reward_model_updating_num < args.stop_epoch and args.test_only == False and save_cnt < args.n_pref : 
                        for i in range(len(seq_diff_reward_functions)) :
                            save_diff[save_cnt,0] = save_cnt
                            save_diff[save_cnt,1] = diff_reward_functions[i].item()
                            save_cnt+=1
                            if not os.path.exists('{}/{}_{}_{}_{}'.format(args.replay_buffer_save_path,args.LLM_model,args.map_name,args.reward_model_info,reward_model_updating_num)) :
                                os.makedirs('{}/{}_{}_{}_{}'.format(args.replay_buffer_save_path,args.LLM_model,args.map_name,args.reward_model_info,reward_model_updating_num))
                            save_traj_dir = '{}/{}_{}_{}_{}/{}_'.format(args.replay_buffer_save_path,args.LLM_model,args.map_name,args.reward_model_info,reward_model_updating_num,save_cnt)
                            for key in scheme :
                                if not os.path.exists('{}{}.pt'.format(save_traj_dir,key)) :                    
                                    if key == 'obs' :
                                        obs_include_previous_action = _build_inputs(args, episode_batch, episode_batch.max_seq_length,seq_diff_reward_functions[i])
                                        th.save(th.unsqueeze(obs_include_previous_action[0],dim=0),'{}{}.pt'.format(save_traj_dir,key))
                                    else :
                                        th.save(th.unsqueeze(episode_batch[key][seq_diff_reward_functions[i]],dim=0),'{}{}.pt'.format(save_traj_dir,key))


                    ##############################################################################################
                    ######################## Update Reward Function ##############################################
                    ##############################################################################################
                    #if save_cnt >= reward_model_update_interval or n_epi >= args.maxium_update_num :
                    if  (n_epi >= reward_model_update_interval or reward_model_updating_num == 0) and save_cnt >= args.n_pref :
                        n_epi = 0 
                        if args.epsilon_reset :
                            runner.mac.action_selector.epsilon = args.epsilon_start
                        update_when[reward_model_updating_num,0] = episode
                        update_when[reward_model_updating_num,1] = runner.t_env
                        th.save(update_when,'save_update_log/{}_{}_{}_{}_save_when.pt'.format(args.LLM_model,args.reward_model_info,model_info_num,args.stop_epoch))
                        reward_model_updating_num += 1
                        th.save(save_diff,'save_update_log/{}_{}_save_diff_{}_{}.pt'.format(args.LLM_model,args.reward_model_info,model_info_num,reward_model_updating_num))
                        previous_diff_reward_mean = th.mean(save_diff[:save_cnt,1])
                        save_diff = th.zeros((1000,2))

                        if reward_model_updating_num == 1 :
                            for reward_num in range(args.n_reward_functions) :
                                save_reward_path = '{}/{}_{}_{}_{}_{}_{}.pt'.format(args.reward_model_path,args.LLM_model,args.map_name,args.reward_model_info,model_info_num,reward_model_updating_num-1,reward_num)
                                th.save(runner.reward_models[reward_num].state_dict(),save_reward_path)
                                if args.use_team_reward :
                                    save_reward_path2 = '{}/{}_{}_{}_{}_{}_{}_team.pt'.format(args.reward_model_path,args.LLM_model,args.map_name,args.reward_model_info,model_info_num,reward_model_updating_num-1,reward_num)
                                    th.save(runner.team_reward_models[reward_num].state_dict(),save_reward_path2)

                        print('########################################')
                        print('waiting for training')
                        runner.close_env()

                        ##############################################################################################
                        ################### Preference Setup #########################################################
                        ##############################################################################################

                        ##############################################################################################
                        ### 1. Prompt Generation & Preference gathering ###
                        ## traj_preference 
                        print('########################################')
                        print('Started obtaining preferences(traj) from the LLM.')
                        skip_or_go = wait_or_go(processing_path,0,reward_model_updating_num)
                        if skip_or_go :
                            get_preference(gpt_model=args.LLM_model,
                                           seq=reward_model_updating_num-1,
                                           n_pref=args.n_pref,
                                           step_n_pref=args.step_n_pref,
                                           n_epi=args.n_pref,
                                           traj_or_step='traj',
                                           model_name=args.reward_model_info,
                                           scenario=args.map_name,
                                           n_agents = args.n_agents,
                                           n_repeat=args.n_repeat,
                                           key=args.gpt_api_key,
                                          envs = args.env)


                        #######################################
                        ## step_preference 
                            if args.compare_agents :
                                print('########################################')
                                print('Started obtaining preferences(step) from the LLM.')
                                get_preference(gpt_model=args.LLM_model,
                                               seq=reward_model_updating_num-1,
                                               n_pref=args.n_pref,
                                               step_n_pref=args.step_n_pref,
                                               n_epi=args.n_pref,
                                               traj_or_step='step',
                                               model_name=args.reward_model_info,
                                               scenario=args.map_name,
                                               n_agents = args.n_agents,
                                               n_repeat=args.n_repeat,
                                               pref_per_epi = int(args.n_pref/args.update_reward_model_interval)+1,
                                               key=args.gpt_api_key,
                                              envs = args.env)
                            process_data = read_process(processing_path)
                            process_data[2] = 0
                            process_data[0] = reward_model_updating_num
                            write_process(processing_path,process_data)

                        ##############################################################################################
                        ### 2. Training Reward Model ###
                        print('########################################')
                        print('Started training reward model')
                        skip_or_go = wait_or_go(processing_path,1,reward_model_updating_num)
                        if skip_or_go :
                            runner.update_reward_model(model_path=args.reward_model_path,
                                                       seq=reward_model_updating_num,
                                                       training_steps=args.reward_model_epoch,
                                                       info_num = model_info_num,
                                                       update_threshold = args.update_threshold
                                                      )

                            process_data = read_process(processing_path)
                            process_data[3] = 0
                            process_data[1] = reward_model_updating_num
                            write_process(processing_path,process_data)

                        ##############################################################################################
                        ### 3. Load Reward Model ###         
                        print('########################################')
                        print('Started loading reward model')
                        while True :
                            updated_reward_path = '{}/{}_{}_{}_{}_{}_'.format(args.reward_model_path,args.LLM_model,args.map_name,args.reward_model_info,model_info_num,reward_model_updating_num)
                            updated_reward_path2 = '{}/{}_{}_{}_{}_{}_'.format(args.reward_model_path,args.LLM_model,args.map_name,args.reward_model_info,model_info_num,reward_model_updating_num)

                            if os.path.exists(updated_reward_path+'0.pt') :
                                runner.get_reward_function(path1=updated_reward_path,path2=updated_reward_path2)
                                runner.open_env()
                                print('########################################')
                                print('ok! good!')
                                break

                        ##############################################################################################
                        '''
                        ### 4. Change all rewards in replay buffer
                        prev_replay_memory = buffer.sample(buffer.episodes_in_buffer)
                        updated_replay_memory = runner.update_reward_in_replay_memory(prev_replay_memory)
                        buffer.update_all(updated_replay_memory,buffer.episodes_in_buffer)
                        '''

                        #args.update_reward_model_interval += args.update_reward_model_interval
                        ##############################################################################################
                        if reward_model_updating_num == 1 :
                            reward_model_update_interval = args.update_reward_model_interval
                        save_cnt = 0 
                        compare_cnt = 0 
                      


        if args.save_model and (runner.t_env - model_save_time >= args.save_model_interval or model_save_time == 0):
            model_save_time = runner.t_env
            save_path = os.path.join(args.local_results_path, "models", args.unique_token, str(runner.t_env))
            #"results/models/{}".format(unique_token)
            os.makedirs(save_path, exist_ok=True)
            logger.console_logger.info("Saving models to {}".format(save_path))

            # learner should handle saving/loading -- delegate actor save/load to mac,
            # use appropriate filenames to do critics, optimizer states
            learner.save_models(save_path)

        episode += args.batch_size_run
        n_epi += args.batch_size_run
        
        if (runner.t_env - last_log_T) >= args.log_interval:
            logger.log_stat("episode", episode, runner.t_env)
            logger.log_stat("kendall_prev_mean", gather_threshold, runner.t_env)
            logger.log_stat("kendall_now_mean", gather_threshold, runner.t_env)
            logger.print_recent_stats()
            last_log_T = runner.t_env
    
                        
    runner.close_env()
    logger.console_logger.info("Finished Training")


def args_sanity_check(config, _log):
    """Normalise the experiment configuration in place and return it.

    Two adjustments are made:

    * If ``config["use_cuda"]`` is truthy but ``th.cuda.is_available()``
      returns False, ``use_cuda`` is forced to ``False`` and a warning is
      emitted through ``_log.warning``.
    * ``config["test_nepisode"]`` is aligned to ``config["batch_size_run"]``:
      a value smaller than ``batch_size_run`` is raised to ``batch_size_run``;
      otherwise it is rounded down to a multiple of ``batch_size_run``.

    Args:
        config: Experiment configuration dict. It is mutated in place.
        _log: Logger object providing a ``warning`` method.

    Returns:
        The same ``config`` object that was passed in.
    """
    # Only query the CUDA runtime when the user actually asked for it.
    wants_cuda = config["use_cuda"]
    if wants_cuda and not th.cuda.is_available():
        config["use_cuda"] = False
        _log.warning("CUDA flag use_cuda was switched OFF automatically because no CUDA devices are available!")

    # Test episodes are run in batches of ``batch_size_run``, so keep the
    # requested count aligned to that batch size.
    runs_per_batch = config["batch_size_run"]
    requested_tests = config["test_nepisode"]
    config["test_nepisode"] = (
        runs_per_batch
        if requested_tests < runs_per_batch
        else runs_per_batch * (requested_tests // runs_per_batch)
    )

    return config
