


import random
import sys
from llrl.agents.umcts import UMCTS
from llrl.experiments import run_agents_lifelong
from llrl.utils.env_handler import make_env_distribution

import numpy as np
import torch

# Experiment configurations. One entry is selected by its 0-based index from
# the command line (``python <script> INDEX``) and passed as ``p`` to
# ``experiment()``.
# Keys:
#   version     -- forwarded as ``version`` to make_env_distribution
#   size        -- forwarded as both ``w`` and ``h`` to make_env_distribution
#   n_tasks     -- forwarded to run_agents_lifelong
#   n_episodes  -- forwarded to run_agents_lifelong
#   n_steps     -- forwarded to run_agents_lifelong
#   n_known     -- not passed to the UMCTS agent in experiment()
#   stochastic  -- forwarded as ``stochastic`` to make_env_distribution
PARAM = [
    # Indices 0-3: size 6, n_steps 4.
    {'version': 2, 'size': 6, 'n_tasks': 20, 'n_episodes': 50, 'n_steps': 4, 'n_known': 1, 'stochastic': False},
    {'version': 2, 'size': 6, 'n_tasks': 20, 'n_episodes': 50, 'n_steps': 4, 'n_known': 1, 'stochastic': True},
    {'version': 2, 'size': 6, 'n_tasks': 20, 'n_episodes': 150, 'n_steps': 4, 'n_known': 3, 'stochastic': True},
    {'version': 2, 'size': 6, 'n_tasks': 20, 'n_episodes': 500, 'n_steps': 4, 'n_known': 10, 'stochastic': True},

    # Indices 4-7: size 6, n_steps 7.
    {'version': 2, 'size': 6, 'n_tasks': 20, 'n_episodes': 50, 'n_steps': 7, 'n_known': 1, 'stochastic': False},
    {'version': 2, 'size': 6, 'n_tasks': 20, 'n_episodes': 50, 'n_steps': 7, 'n_known': 1, 'stochastic': True},
    {'version': 2, 'size': 6, 'n_tasks': 20, 'n_episodes': 150, 'n_steps': 7, 'n_known': 3, 'stochastic': True},
    {'version': 2, 'size': 6, 'n_tasks': 20, 'n_episodes': 500, 'n_steps': 7, 'n_known': 10, 'stochastic': True},

    # Indices 8-12: size 11, n_steps 10.
    {'version': 2, 'size': 11, 'n_tasks': 20, 'n_episodes': 200, 'n_steps': 10, 'n_known': 1, 'stochastic': False},
    {'version': 2, 'size': 11, 'n_tasks': 30, 'n_episodes': 400, 'n_steps': 10, 'n_known': 1, 'stochastic': True},
    {'version': 2, 'size': 11, 'n_tasks': 30, 'n_episodes': 1200, 'n_steps': 10, 'n_known': 3, 'stochastic': True},
    {'version': 2, 'size': 11, 'n_tasks': 15, 'n_episodes': 2000, 'n_steps': 10, 'n_known': 10, 'stochastic': True},  # 11 -> BIS
    {'version': 2, 'size': 11, 'n_tasks': 30, 'n_episodes': 4000, 'n_steps': 10, 'n_known': 10, 'stochastic': True},  # 11 -> SELECTED

    # Indices 13-16: size 14, n_steps 12.
    # NOTE(review): "FAIL" below presumably marks a run that did not succeed — confirm.
    {'version': 2, 'size': 14, 'n_tasks': 20, 'n_episodes': 200, 'n_steps': 12, 'n_known': 1, 'stochastic': False},
    {'version': 2, 'size': 14, 'n_tasks': 30, 'n_episodes': 400, 'n_steps': 12, 'n_known': 1, 'stochastic': True},
    {'version': 2, 'size': 14, 'n_tasks': 30, 'n_episodes': 1200, 'n_steps': 12, 'n_known': 3, 'stochastic': True},
    {'version': 2, 'size': 14, 'n_tasks': 30, 'n_episodes': 4000, 'n_steps': 12, 'n_known': 10, 'stochastic': True},  # FAIL

    # Indices 17-21: size 20, n_steps 20.
    # NOTE(review): the "11 -> ..." comments on the last two rows look copied
    # from the size-11 group (indices 11-12); confirm what they refer to here.
    {'version': 2, 'size': 20, 'n_tasks': 20, 'n_episodes': 200, 'n_steps': 20, 'n_known': 1, 'stochastic': False},
    {'version': 2, 'size': 20, 'n_tasks': 30, 'n_episodes': 400, 'n_steps': 20, 'n_known': 1, 'stochastic': True},
    {'version': 2, 'size': 20, 'n_tasks': 30, 'n_episodes': 1200, 'n_steps': 20, 'n_known': 3, 'stochastic': True},
    {'version': 2, 'size': 20, 'n_tasks': 10, 'n_episodes': 2000, 'n_steps': 20, 'n_known': 5, 'stochastic': True},  # 11 -> BIS
    {'version': 2, 'size': 20, 'n_tasks': 30, 'n_episodes': 4000, 'n_steps': 20, 'n_known': 10, 'stochastic': True},  # 11 -> SELECTED
]
def set_seed(seed):
    """Seed every random number generator used by the experiment.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices), and configures cuDNN to make runs with the same seed
    as reproducible as PyTorch allows.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU (covers the current device as well); silently
    # ignored when CUDA is not available.
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels alone are not enough: with benchmark mode on,
    # cuDNN's autotuner may select a different algorithm on each run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
def experiment(p, name, log_dir='logs'):
    """Run the UMCTS agent on a lifelong sequence of 'tight' environments.

    Args:
        p: One configuration dict (an entry of ``PARAM``). The keys read here
            are ``version``, ``size``, ``stochastic``, ``n_tasks``,
            ``n_episodes`` and ``n_steps``.
        name: Environment name forwarded to ``make_env_distribution``.
        log_dir: Forwarded as ``log_dir`` to ``run_agents_lifelong``.
            Defaults to ``'logs'``.
    """
    gamma = .95
    n_env = 5
    size = p['size']

    env_distribution = make_env_distribution(
        env_class='tight', n_env=n_env, gamma=gamma,
        env_name=name,
        version=p['version'],
        w=size,
        h=size,
        stochastic=p['stochastic'],
        verbose=False
    )
    init_state = env_distribution.get_init_state()
    state_dim = len(init_state)
    print('State dim:', state_dim)
    actions = env_distribution.get_actions()

    # UMCTS hyper-parameters.
    r_max = 1.
    max_depth = 10
    num_simulations = 1
    umcts = UMCTS(actions=actions, gamma=gamma, r_max=r_max,
                  num_simulations=num_simulations, max_depth=max_depth)

    agents_pool = [umcts]

    # NOTE(review): episodes_ma_width (100) exceeds n_episodes for some PARAM
    # rows (e.g. 50); confirm the plotting code handles a window that wide.
    run_agents_lifelong(agents_pool, env_distribution, n_instances=1, n_tasks=p['n_tasks'], n_episodes=p['n_episodes'],
                        n_steps=p['n_steps'], reset_at_terminal=False, open_plot=False, plot_title=False,
                        plot_legend=2, do_run=True, do_plot=True, parallel_run=False, n_processes=None,
                        episodes_moving_average=True, episodes_ma_width=100, tasks_moving_average=False,
                        latex_rendering=True, log_dir=log_dir)




    


if __name__ == '__main__':
    # Usage: python <script> EXPERIMENT_INDEX
    # EXPERIMENT_INDEX is the 0-based position of a configuration in PARAM.
    if len(sys.argv) < 2:
        sys.exit('usage: {} EXPERIMENT_INDEX (0-{})'.format(sys.argv[0], len(PARAM) - 1))
    try:
        experiment_index = int(sys.argv[1])
    except ValueError:
        sys.exit('EXPERIMENT_INDEX must be an integer, got {!r}'.format(sys.argv[1]))
    # Reject negative indices explicitly: Python would otherwise silently pick
    # a row from the end and build a name such as 'tight-v2--1'.
    if not 0 <= experiment_index < len(PARAM):
        sys.exit('EXPERIMENT_INDEX must be in [0, {}], got {}'.format(len(PARAM) - 1, experiment_index))

    param = PARAM[experiment_index]
    experiment_name = 'tight-v' + str(param['version']) + '-' + str(experiment_index)

    # Seed all RNGs before the environments and agent are created.
    set_seed(0)

    print(param)

    experiment(param, experiment_name)