import os, sys, importlib, numpy as np, torch.nn as nn, torch, copy
sys.path.insert(0, os.path.abspath('.'))
from algs.DQN import DDQN
from algs.utilities.SequentialNetwork import SequentialNetwork
from algs.utilities.Noises import UniformNoise, OUNoise, DiscreteUniformNoise
from algs.utilities.ContinuousAgentMaker import ContinuousAgentMaker
from solvers.SimultaneousSolver import SimultaneousSolver
from running.utils import divide_config, save_models, load_models, get_action_values
from running.create_model import create_model


def init(state_dim, action_dim, action_min, action_max, action_n, timesteps, config, threshold=1):
    """Build a DDQN agent wrapped for continuous actions over a discrete action grid.

    Args:
        state_dim: input size handed to ``create_model``.
        action_dim: action-space dimensionality, forwarded to ``get_action_values``.
        action_min: lower action bound, forwarded to ``get_action_values``.
        action_max: upper action bound, forwarded to ``get_action_values``.
        action_n: number of discrete actions. It is also the Q-network output
            size and the ``DiscreteUniformNoise`` size.
        timesteps: the noise is created with ``threshold_decrease=1/timesteps``.
        config: extra keyword arguments for the agent constructor. An optional
            ``'model_structure'`` entry (default ``'256128'``) is popped from it,
            so this dict is mutated in place.
        threshold: NOTE(review): not used anywhere in this function. ``load``
            passes 0, presumably expecting exploration to be disabled. Confirm
            whether it should be forwarded to ``DiscreteUniformNoise``.

    Returns:
        An instance of ``ContinuousAgentMaker(DDQN)`` built from the Q-network,
        the noise, the action grid, and the remaining ``config`` entries.
    """
    from algs.DQN import DDQN

    structure = config.pop('model_structure', '256128')
    network = create_model(state_dim, structure, action_n)

    exploration = DiscreteUniformNoise(action_n, threshold_decrease=1/timesteps)

    grid = get_action_values(action_dim, action_min, action_max, action_n)
    agent_cls = ContinuousAgentMaker(DDQN)

    return agent_cls(network, exploration, action_values=grid, **config)


def learn(env, timesteps, config):
    """Train a pair of DDQN agents (u and v players) with a simultaneous solver.

    Args:
        env: environment exposing ``state_dim`` plus ``u_action_dim``,
            ``u_action_min``, ``u_action_max`` and the matching ``v_*``
            attributes.
        timesteps: number of steps passed to ``solver.go``. It also sets the
            noise decrease rate inside ``init``.
        config: hyper-parameter dict that must contain ``'action_n'``. The
            caller's dict is not modified.

    Returns:
        ``([u_agent, v_agent], info)``. ``info['trs']`` is the solver's
        ``total_rewards`` and ``info['tts']`` is its ``total_timesteps``.

    Raises:
        KeyError: if ``config`` has no ``'action_n'`` entry.
    """
    from algs.DQN import DDQN
    # Pop from a shallow copy so the caller's config keeps 'action_n'.
    # The original code mutated it, so reusing the same dict (e.g. for
    # ``load``) failed with KeyError.
    config = dict(config)
    action_n = config.pop('action_n')

    adding_args = ['action_n', 'model_structure']
    config_agent, config_solver = divide_config(config, DDQN, adding_args)

    # Each agent gets its own deep copy because ``init`` pops keys from it.
    u_agent = init(env.state_dim, env.u_action_dim, env.u_action_min, env.u_action_max,
                   action_n, timesteps, copy.deepcopy(config_agent), threshold=1)
    v_agent = init(env.state_dim, env.v_action_dim, env.v_action_min, env.v_action_max,
                   action_n, timesteps, copy.deepcopy(config_agent), threshold=1)

    solver = SimultaneousSolver(env, u_agent, v_agent, **config_solver)
    solver.go(timesteps)

    info = {'trs': solver.total_rewards, 'tts': solver.total_timesteps}
    return [u_agent, v_agent], info

def save(agents, info, path):
    """Write both agents' Q-network weights and the training curves into ``path``.

    Args:
        agents: two-element sequence ``(u_agent, v_agent)``. Each agent must
            expose ``q_model`` with a ``state_dict()`` method.
        info: mapping with ``'trs'`` and ``'tts'`` entries, stored as
            ``trs.npy`` and ``tts.npy``.
        path: target directory. It is not created here, so it must already
            exist.

    Returns:
        None.
    """
    u_agent, v_agent = agents

    for filename, agent in (('u_q_model', u_agent), ('v_q_model', v_agent)):
        torch.save(agent.q_model.state_dict(), os.path.join(path, filename))

    for filename, key in (('trs.npy', 'trs'), ('tts.npy', 'tts')):
        np.save(os.path.join(path, filename), info[key])

def load(env, config, path):
    """Rebuild the u/v DDQN agents and restore their Q-network weights from ``path``.

    Reads the ``u_q_model`` and ``v_q_model`` state dicts that ``save`` writes.

    Args:
        env: environment exposing ``state_dim`` plus ``u_action_dim``,
            ``u_action_min``, ``u_action_max`` and the matching ``v_*``
            attributes.
        config: hyper-parameter dict that must contain ``'action_n'``. The
            caller's dict is not modified.
        path: directory containing the saved state dicts.

    Returns:
        ``(u_agent, v_agent)`` with loaded Q-network weights.

    Raises:
        KeyError: if ``config`` has no ``'action_n'`` entry.
    """
    from algs.DQN import DDQN
    # Pop from a shallow copy so the caller's config keeps 'action_n'.
    config = dict(config)
    action_n = config.pop('action_n')

    adding_args = ['action_n', 'model_structure']
    config_agent, _ = divide_config(config, DDQN, adding_args)

    # NOTE(review): threshold=0 presumably intends greedy (noise-free)
    # agents, but ``init`` ignores its ``threshold`` argument. Confirm how
    # exploration is disabled for loaded agents.
    u_agent = init(env.state_dim, env.u_action_dim, env.u_action_min, env.u_action_max,
                   action_n, 1, copy.deepcopy(config_agent), threshold=0)
    v_agent = init(env.state_dim, env.v_action_dim, env.v_action_min, env.v_action_max,
                   action_n, 1, copy.deepcopy(config_agent), threshold=0)

    # map_location='cpu' lets checkpoints written on a GPU host load on a
    # CPU-only one. load_state_dict then copies the values into the model's
    # existing parameters.
    u_agent.q_model.load_state_dict(
        torch.load(os.path.join(path, 'u_q_model'), map_location='cpu'))
    v_agent.q_model.load_state_dict(
        torch.load(os.path.join(path, 'v_q_model'), map_location='cpu'))
    return u_agent, v_agent
