import os, sys, importlib, numpy as np, torch.nn as nn 
sys.path.insert(0, os.path.abspath('.'))
from algs.utilities.SequentialNetwork import SequentialNetwork
from algs.utilities.Noises import UniformNoise, OUNoise, DiscreteUniformNoise
from algs.utilities.ContinuousAgentMaker import ContinuousAgentMaker
from solvers.OneAgentSolver import OneAgentSolver
from running.utils import divide_config, get_action_values


def learn(env, timesteps, config):
    """Build the agent named by ``config['alg_name']``, train it on ``env``, and return run statistics.

    Args:
        env: Environment object. The custom algorithms read ``state_dim``,
            ``action_dim``, ``action_min`` and ``action_max`` from it; after
            training, ``total_rewards``, ``total_timesteps``, ``u_best_session``
            and ``v_best_session`` are read from it (presumably filled in by the
            env while it is stepped -- verify against the env implementation).
        timesteps: Base training budget. It is multiplied by
            ``config['ts_multiplier']`` and truncated to ``int`` to obtain the
            number of timesteps actually trained.
        config: Mapping with the keys ``'alg_name'`` and ``'ts_multiplier'``
            (plus ``'action_n'`` for ``'DCEM'`` / ``'DDQN'``). The remaining keys
            are split by ``divide_config`` into agent and solver keyword
            arguments. The caller's mapping is not modified.

    Supported ``alg_name`` values are ``'CCEM'``, ``'DCEM'``, ``'DDQN'``,
    ``'DDPG'``, and ``'SB3<Name>'``, where ``<Name>`` is an attribute of the
    ``stable_baselines3`` package (e.g. ``'SB3PPO'``).

    Returns:
        dict with keys ``'trs'``, ``'tts'``, ``'ubs'``, ``'vbs'`` taken from the
        corresponding ``env`` attributes listed above.

    Raises:
        KeyError: if a required config key is missing.
        ValueError: if ``alg_name`` is not one of the supported values.
    """
    # Work on a copy: the pops below would otherwise strip keys from the
    # caller's dict, so reusing one config for several runs raised KeyError.
    config = dict(config)
    alg_name = config.pop('alg_name')
    ts_multiplier = config.pop('ts_multiplier')
    test_timesteps = int(ts_multiplier * timesteps)
    # NOTE(review): the custom branches compute 1 / test_timesteps for the
    # noise schedule, so test_timesteps == 0 raises ZeroDivisionError there.

    if alg_name == 'CCEM':
        from algs.CEM import CEM_Continuous
        pi_model = SequentialNetwork([env.state_dim, 256, 128, env.action_dim], nn.ReLU(), nn.Tanh())
        noise = UniformNoise(env.action_dim, threshold_decrease=1 / test_timesteps)
        config_agent, config_solver = divide_config(config, CEM_Continuous)
        agent = CEM_Continuous(env.action_min, env.action_max, pi_model, noise, **config_agent)
        _run_solver(env, agent, test_timesteps, 'by_sessions', config_solver)

    elif alg_name == 'DCEM':
        from algs.CEM import CEM_Discrete
        _learn_discrete(env, config, test_timesteps, CEM_Discrete, 'by_sessions')

    elif alg_name == 'DDQN':
        from algs.DQN import DDQN
        _learn_discrete(env, config, test_timesteps, DDQN, 'by_fives')

    elif alg_name == 'DDPG':
        from algs.DDPG import DDPG
        q_model = SequentialNetwork([env.state_dim + env.action_dim, 256, 128, 1], nn.ReLU())
        pi_model = SequentialNetwork([env.state_dim, 256, 128, env.action_dim], nn.ReLU(), nn.Tanh())
        noise = OUNoise(env.action_dim, threshold_decrease=1 / test_timesteps)
        config_agent, config_solver = divide_config(config, DDPG)
        agent = DDPG(env.action_min, env.action_max, q_model, pi_model, noise, **config_agent)
        _run_solver(env, agent, test_timesteps, 'by_fives', config_solver)

    elif alg_name.startswith('SB3'):
        # The class name follows the 'SB3' prefix, e.g. 'SB3PPO' -> 'PPO'.
        # startswith (not a substring test) keeps the slice below correct.
        alg = getattr(importlib.import_module('stable_baselines3'), alg_name[len('SB3'):])
        config_agent, _ = divide_config(config, alg)
        agent = alg("MlpPolicy", env, **config_agent)
        agent.learn(total_timesteps=test_timesteps)

    else:
        # Previously an unknown name silently skipped training and still
        # returned env statistics; fail loudly instead.
        raise ValueError(f'Unknown alg_name: {alg_name!r}')

    return {'trs': env.total_rewards, 'tts': env.total_timesteps, 'ubs': env.u_best_session, 'vbs': env.v_best_session}


def _run_solver(env, agent, test_timesteps, learning_type, config_solver):
    """Train ``agent`` on ``env`` for ``test_timesteps`` with ``OneAgentSolver``."""
    solver = OneAgentSolver(env, agent, test_timesteps, learning_type=learning_type, **config_solver)
    solver.go()


def _learn_discrete(env, config, test_timesteps, agent_cls, learning_type):
    """Train a discrete-action agent class on ``env``, adapted via ``ContinuousAgentMaker``.

    Pops ``'action_n'`` from ``config`` (the caller passes its private copy),
    builds an ``action_n``-output network and discrete noise, and maps the
    discrete actions to values in ``[env.action_min, env.action_max]`` via
    ``get_action_values``.
    """
    action_n = config.pop('action_n')
    # Split on the raw class (before wrapping), as the original code did.
    config_agent, config_solver = divide_config(config, agent_cls)
    model = SequentialNetwork([env.state_dim, 256, 128, action_n], nn.ReLU())
    noise = DiscreteUniformNoise(action_n, threshold_decrease=1 / test_timesteps)
    wrapped_cls = ContinuousAgentMaker(agent_cls)
    action_values = get_action_values(env.action_dim, env.action_min, env.action_max, action_n)
    agent = wrapped_cls(model, noise, action_values=action_values, **config_agent)
    _run_solver(env, agent, test_timesteps, learning_type, config_solver)


def save(info, opponent_index, path):
    """Persist the statistics returned by ``learn`` into directory ``path``.

    Writes three NumPy files:
      * ``trs.npy`` from ``info['trs']``,
      * ``tts.npy`` from ``info['tts']``,
      * ``bs.npy``  from ``info[f'{opponent_index}bs']``. For example,
        ``opponent_index='u'`` selects ``info['ubs']``.

    Files are written in that order, so a missing key only stops the writes
    that come after it.
    """
    def _dump(file_name, value):
        # One .npy file per statistic, all inside the target directory.
        np.save(os.path.join(path, file_name), value)

    _dump('trs.npy', info['trs'])
    _dump('tts.npy', info['tts'])
    _dump('bs.npy', info[f'{opponent_index}bs'])
    