import os, importlib, dill, numpy as np, copy
from algs.utilities.Noises import DiscreteUniformNoise
from algs.utilities.ContinuousAgentMaker import ContinuousAgentMaker
from envs.EnvironmentsWithFixedAgent import EnvironmentsWithFixedAgent as OneAgentEnv
from solvers.AlternateSolver import AlternateSolver
from running.utils import divide_config, get_action_values
from running.create_model import create_model

def init(env, subalg_name, Subalg, timesteps, config, fixed_agent_index):
    """Build one learning agent for the alternating two-agent setup.

    Args:
        env: Two-agent environment. The DDQN branch reads ``state_dim`` and
            the ``u_action_*`` / ``v_action_*`` (dim, min, max) attributes.
        subalg_name: ``'SB3<Name>'`` for a stable_baselines3 algorithm, or
            ``'DDQN'``.
        Subalg: Algorithm class matching ``subalg_name``.
        timesteps: Training budget; the DDQN exploration noise is built with
            ``threshold_decrease=1/timesteps``.
        config: Keyword arguments for the agent constructor. The DDQN branch
            pops ``'action_n'`` (required) and ``'model_structure'``
            (optional, default ``'256128'``) from it, so pass a copy if the
            caller needs the mapping intact.
        fixed_agent_index: ``'u'`` or ``'v'``, the agent held fixed. For DDQN
            the agent being built uses the *other* agent's action bounds.

    Returns:
        The constructed agent.

    Raises:
        ValueError: If ``subalg_name`` is unsupported, or (DDQN branch) if
            ``fixed_agent_index`` is neither ``'u'`` nor ``'v'``.
        KeyError: DDQN branch, if ``config`` has no ``'action_n'``.
    """
    if subalg_name[:3] == 'SB3':
        # Presumably OneAgentEnv exposes a single-agent view of `env` with the
        # agent named by fixed_agent_index held fixed — TODO confirm.
        agent_env = OneAgentEnv(env, fixed_agent_index=fixed_agent_index)
        return Subalg("MlpPolicy", agent_env, **config)

    if subalg_name != 'DDQN':
        raise ValueError(f"Unsupported subalg_name: {subalg_name!r}")

    # With 'v' fixed we are building the u agent (and vice versa), so the
    # discretised action grid comes from the trained agent's action bounds.
    if fixed_agent_index == 'v':
        action_bounds = (env.u_action_dim, env.u_action_min, env.u_action_max)
    elif fixed_agent_index == 'u':
        action_bounds = (env.v_action_dim, env.v_action_min, env.v_action_max)
    else:
        raise ValueError(
            f"fixed_agent_index must be 'u' or 'v', got {fixed_agent_index!r}"
        )

    action_n = config.pop('action_n')
    action_values = get_action_values(*action_bounds, action_n)
    Subalg = ContinuousAgentMaker(Subalg)

    model_structure = config.pop('model_structure', '256128')
    q_model = create_model(env.state_dim, model_structure, action_n)

    noise = DiscreteUniformNoise(action_n, threshold_decrease=1/timesteps)

    return Subalg(q_model, noise, action_values=action_values, **config)


def learn(env, timesteps, config):
    """Train a (u, v) agent pair with AlternateSolver.

    Args:
        env: Two-agent environment shared by both agents.
        timesteps: Training budget, passed to the agent builders and to
            ``solver.go``.
        config: Must contain ``'subalg_name'`` (``'SB3<Name>'`` or
            ``'DDQN'``). The remaining keys are split by ``divide_config``
            into agent and solver keyword arguments. The caller's mapping is
            not modified.

    Returns:
        ``([u_agent, v_agent], info)``, where ``info`` maps ``'trs'`` to
        ``solver.total_rewards`` and ``'tts'`` to ``solver.total_timesteps``.

    Raises:
        ValueError: If ``subalg_name`` is not supported.
        KeyError: If ``config`` has no ``'subalg_name'``.
    """
    # Shallow copy so popping 'subalg_name' does not mutate the caller's dict
    # (otherwise a second call with the same config raises KeyError).
    config = dict(config)
    subalg_name = config.pop('subalg_name')

    if subalg_name[:3] == 'SB3':
        Subalg = getattr(importlib.import_module('stable_baselines3'), subalg_name[3:])
    elif subalg_name == 'DDQN':
        Subalg = importlib.import_module('algs.DQN').DDQN
    else:
        raise ValueError(f"Unsupported subalg_name: {subalg_name!r}")

    # Extra keys routed to the agent config in addition to Subalg's own
    # parameters (presumably how divide_config uses them) — TODO confirm.
    adding_args = ['action_n', 'model_structure']
    config_agent, config_solver = divide_config(config, Subalg, adding_args)

    # Each builder gets its own deep copy because init() pops keys from it.
    u_agent = init(env, subalg_name, Subalg, timesteps, copy.deepcopy(config_agent), fixed_agent_index='v')
    v_agent = init(env, subalg_name, Subalg, timesteps, copy.deepcopy(config_agent), fixed_agent_index='u')

    solver = AlternateSolver(env, u_agent, v_agent, subalg_name, **config_solver)
    solver.go(timesteps)

    info = {'trs': solver.total_rewards, 'tts': solver.total_timesteps}
    return [solver.u_agent, solver.v_agent], info

def save(agents, info, path):
    """Persist a trained (u, v) agent pair and its training statistics.

    The agents are serialised with dill to ``u_model`` and ``v_model`` inside
    ``path``. ``info['trs']`` and ``info['tts']`` are written with ``np.save``
    to ``trs.npy`` and ``tts.npy``. ``path`` must already exist as a directory.

    Returns:
        None.
    """
    u_agent, v_agent = agents

    for filename, agent in (('u_model', u_agent), ('v_model', v_agent)):
        with open(os.path.join(path, filename), 'wb') as handle:
            dill.dump(agent, handle)

    for filename, key in (('trs.npy', 'trs'), ('tts.npy', 'tts')):
        np.save(os.path.join(path, filename), info[key])

    return None

def load(env, config, path):
    """Read back the (u_agent, v_agent) pair stored as ``u_model`` / ``v_model``.

    ``env`` and ``config`` are unused here. They are presumably kept so the
    signature matches other algorithm modules' ``load``; TODO confirm.

    Returns:
        A ``(u_agent, v_agent)`` tuple.
    """
    # NOTE(review): dill.load can execute arbitrary code embedded in the file,
    # so only load models from trusted directories.
    restored = []
    for filename in ('u_model', 'v_model'):
        with open(os.path.join(path, filename), 'rb') as handle:
            restored.append(dill.load(handle))
    return tuple(restored)
