import os, sys, numpy as np, torch.nn as nn, copy
from algs.utilities.SequentialNetwork import SequentialNetwork
from algs.utilities.Noises import DiscreteUniformNoise
from algs.utilities.ZSQModels import ZSQModel_United
from running.utils import divide_config, save_models, load_models, get_action_values
from running.create_model import create_model
from solvers.ZeroSumSolver import ZeroSumSolver


def init(env, timesteps, config, pure_agent, threshold=1):
    """Build a CounterDQN agent bundle wrapped by ContinuousAgentsMaker.

    Args:
        env: Environment object. Its ``state_dim`` and the
            ``u_``/``v_`` ``action_dim``, ``action_min`` and ``action_max``
            attributes are read here.
        timesteps: Both noise objects receive ``threshold_decrease`` equal
            to ``1 / timesteps``, so this must be non-zero.
        config: Keyword arguments for the agent constructor. ``'action_n'``
            (required) and ``'model_structure'`` (optional, defaults to
            ``'256128'``) are popped from it, i.e. the dict is mutated in
            place; whatever remains is forwarded to the constructor.
        pure_agent: Forwarded as the fourth positional argument of the
            agent constructor (callers in this module pass ``'u'`` or
            ``'v'``).
        threshold: Initial ``threshold`` handed to both noise objects.

    Returns:
        The object produced by calling ``ContinuousAgentsMaker(CounterDQN)``
        with the model, the two noises and the action value tables.
    """
    from algs.CounterDQN import CounterDQN
    from algs.utilities.ContinuousAgentsMaker import ContinuousAgentsMaker

    n_actions = config.pop('action_n')

    # One action-value table per player, built from that player's bounds.
    action_tables = {
        'u_action_values': get_action_values(
            env.u_action_dim, env.u_action_min, env.u_action_max, n_actions),
        'v_action_values': get_action_values(
            env.v_action_dim, env.v_action_min, env.v_action_max, n_actions),
    }
    make_agents = ContinuousAgentsMaker(CounterDQN)

    layout = config.pop('model_structure', '256128')

    # The backbone emits n_actions ** 2 outputs -- presumably one value per
    # (u, v) action pair; confirm against ZSQModel_United.
    backbone = create_model(env.state_dim, layout, n_actions ** 2)
    joint_q = ZSQModel_United(backbone, n_actions, n_actions)

    # Two independent noise objects with identical settings (u first, then v).
    u_explore, v_explore = (
        DiscreteUniformNoise(n_actions,
                             threshold_decrease=1 / timesteps,
                             threshold=threshold)
        for _ in range(2)
    )

    return make_agents(joint_q, u_explore, v_explore, pure_agent,
                       **action_tables, **config)


def learn(env, timesteps, config):
    """Train a 'u' agent bundle and then a 'v' agent bundle, half the budget each.

    ``config`` is split by ``divide_config`` into agent and solver settings.
    For each of ``'u'`` and ``'v'`` (in that order) a fresh bundle is built
    with ``init`` from a deep copy of the agent settings and run through a
    new ``ZeroSumSolver`` for ``int(0.5 * timesteps)`` steps.

    Args:
        env: Environment passed to ``init`` and to ``ZeroSumSolver``.
        timesteps: Total budget; each phase gets ``int(0.5 * timesteps)``.
        config: Combined agent/solver configuration.

    Returns:
        ``(agents, info)`` where ``agents`` is ``[u_agents, v_agents]`` and
        ``info`` holds the concatenated ``total_rewards`` (key ``'trs'``) and
        ``total_timesteps`` (key ``'tts'``) of both solver runs.
    """
    from algs.CounterDQN import CounterDQN

    config_agent, config_solver = divide_config(
        config, CounterDQN, ['action_n', 'model_structure'])

    phase_budget = int(0.5 * timesteps)
    trained = []
    reward_log = None
    step_log = None

    for role in ('u', 'v'):
        bundle = init(env, phase_budget, copy.deepcopy(config_agent),
                      pure_agent=role)
        solver = ZeroSumSolver(env, bundle, **config_solver)
        solver.go(phase_budget)

        if reward_log is None:
            # First phase: adopt the solver's own lists as the accumulators.
            reward_log = solver.total_rewards
            step_log = solver.total_timesteps
        else:
            # Later phase: append its history after the earlier one.
            reward_log.extend(solver.total_rewards)
            step_log.extend(solver.total_timesteps)

        trained.append(bundle)

    return trained, {'trs': reward_log, 'tts': step_log}


def save(agents, info, path):
    """Write both agents' Q-models and the training logs under ``path``.

    Args:
        agents: Two-element sequence ``(u_agents, v_agents)``.
        info: Mapping with the keys ``'trs'`` and ``'tts'``.
        path: Target directory handed to ``save_models`` and used for the
            ``trs.npy`` / ``tts.npy`` files.

    Returns:
        None.
    """
    u_bundle, v_bundle = agents

    # Model files are distinguished by a per-player prefix.
    for bundle, prefix in ((u_bundle, 'u_'), (v_bundle, 'v_')):
        save_models(bundle, ['q_model'], path, prefix)

    # Training logs are stored as NumPy arrays next to the models.
    for filename, key in (('trs.npy', 'trs'), ('tts.npy', 'tts')):
        np.save(os.path.join(path, filename), info[key])


def load(env, config, path):
    """Rebuild both agent bundles and load their saved Q-models from ``path``.

    The bundles are created through ``init`` with ``timesteps=1`` and
    ``threshold=0`` (the value handed to the noise objects), then
    ``load_models`` fills in ``'q_model'`` using the ``'u_'`` and ``'v_'``
    prefixes.

    Args:
        env: Environment passed to ``init``.
        config: Combined agent/solver configuration; only the agent part
            returned by ``divide_config`` is used here.
        path: Directory handed to ``load_models``.

    Returns:
        Tuple ``(u_agent, v_agent)``: the ``u_agent`` attribute of the 'u'
        bundle and the ``v_agent`` attribute of the 'v' bundle.
    """
    from algs.CounterDQN import CounterDQN

    # The solver half of the split is not needed when only restoring agents.
    config_agent, _solver_settings = divide_config(
        config, CounterDQN, ['action_n', 'model_structure'])

    u_bundle = init(env, 1, copy.deepcopy(config_agent),
                    pure_agent='u', threshold=0)
    v_bundle = init(env, 1, copy.deepcopy(config_agent),
                    pure_agent='v', threshold=0)

    for bundle, prefix in ((u_bundle, 'u_'), (v_bundle, 'v_')):
        load_models(bundle, ['q_model'], path, prefix)

    return u_bundle.u_agent, v_bundle.v_agent
