import os, sys, numpy as np, torch.nn as nn
from algs.utilities.Noises import DiscreteUniformNoise
from algs.utilities.ContinuousAgentsMaker import ContinuousAgentsMaker
from algs.utilities.ZSQModels import ZSQModel_United
from algs.utilities.Buffers.ExperienceReplayBuffer import ExperienceReplayBuffer
from running.utils import divide_config, save_models, load_models, get_action_values
from running.create_model import create_model
from solvers.ZeroSumSolver import ZeroSumSolver


def init(env, timesteps, config, threshold=1):
    """Build the DIDQN agents object for ``env`` with one Q-model and noise per side.

    Args:
        env: Environment object. Reads ``state_dim`` and, for each side
            ``u`` / ``v``, ``<side>_action_dim``, ``<side>_action_min`` and
            ``<side>_action_max``.
        timesteps: Number of steps used to derive the noise schedule; each
            noise object gets ``threshold_decrease=1/timesteps`` (must be
            non-zero).
        config: Agent configuration mapping. Must contain ``'action_n'``
            (used for the action grid, the model output size and the noise);
            may contain ``'model_structure'`` (defaults to ``'256128'``).
            Every remaining key is forwarded as a keyword argument to the
            agents constructor. The caller's mapping is NOT modified.
        threshold: Initial ``threshold`` passed to both
            ``DiscreteUniformNoise`` instances.

    Returns:
        The object produced by ``ContinuousAgentsMaker(DIDQN)(...)``.

    Raises:
        KeyError: if ``config`` has no ``'action_n'`` entry.
        ZeroDivisionError: if ``timesteps`` is 0.
    """
    # Imported inside the function like the other entry points of this
    # module (presumably to avoid an import cycle — confirm).
    from algs.DIDQN import DIDQN

    # Copy before popping the init-only keys so the caller's dict is left
    # intact; previously a second call with the same config hit KeyError.
    config = dict(config)
    action_n = config.pop('action_n')
    model_structure = config.pop('model_structure', '256128')

    # Discretized action values for each side, built from the env bounds.
    u_action_values = get_action_values(env.u_action_dim, env.u_action_min,
                                        env.u_action_max, action_n)
    v_action_values = get_action_values(env.v_action_dim, env.v_action_min,
                                        env.v_action_max, action_n)
    # Distinct name so the imported DIDQN class is not shadowed.
    agents_cls = ContinuousAgentsMaker(DIDQN)

    u_q_model = create_model(env.state_dim, model_structure, action_n)
    v_q_model = create_model(env.state_dim, model_structure, action_n)

    u_noise = DiscreteUniformNoise(action_n, threshold_decrease=1 / timesteps,
                                   threshold=threshold)
    v_noise = DiscreteUniformNoise(action_n, threshold_decrease=1 / timesteps,
                                   threshold=threshold)

    return agents_cls(u_q_model, v_q_model, u_noise, v_noise,
                      u_action_values=u_action_values,
                      v_action_values=v_action_values, **config)


def learn(env, timesteps, config):
    """Train DIDQN agents on ``env`` with a ``ZeroSumSolver``.

    ``config`` is split by ``divide_config`` into the agent part (which,
    besides DIDQN's own arguments, keeps ``'action_n'`` and
    ``'model_structure'``) and the solver part.

    Returns:
        ``(agents, info)`` where ``info['trs']`` is the solver's
        ``total_rewards`` and ``info['tts']`` its ``total_timesteps``.
    """
    from algs.DIDQN import DIDQN

    agent_cfg, solver_cfg = divide_config(config, DIDQN,
                                          ['action_n', 'model_structure'])
    agents = init(env, timesteps, agent_cfg)

    solver = ZeroSumSolver(env, agents, **solver_cfg)
    solver.go(timesteps)

    return agents, {
        'trs': solver.total_rewards,
        'tts': solver.total_timesteps,
    }


def save(agents, info, path):
    """Persist the agents' Q-models and the training statistics under ``path``.

    The ``u_q_model`` / ``v_q_model`` attributes are written by
    ``save_models``; ``info['trs']`` and ``info['tts']`` are written with
    ``np.save`` to ``trs.npy`` and ``tts.npy`` inside ``path``.
    """
    save_models(agents, ['u_q_model', 'v_q_model'], path)
    for key, filename in (('trs', 'trs.npy'), ('tts', 'tts.npy')):
        np.save(os.path.join(path, filename), info[key])


def load(env, config, path):
    """Rebuild DIDQN agents for ``env`` and load saved Q-model weights from ``path``.

    Only the agent part of ``config`` (as split by ``divide_config``) is
    used; the solver part is ignored.

    Returns:
        ``(u_agent, v_agent)`` taken from the rebuilt agents object.
    """
    from algs.DIDQN import DIDQN

    agent_cfg, _ = divide_config(config, DIDQN, ['action_n', 'model_structure'])
    # timesteps=1 and threshold=0 are passed to init's noise objects;
    # presumably this turns exploration noise off for evaluation — confirm.
    agents = init(env, 1, agent_cfg, threshold=0)
    load_models(agents, ['u_q_model', 'v_q_model'], path)

    u_agent, v_agent = agents.u_agent, agents.v_agent
    return u_agent, v_agent
