import os, sys, torch.nn as nn, numpy as np
from running.utils import divide_config, save_models, load_models
from running.create_model import create_model
from algs.MADDPG import MADDPG
from algs.utilities.SequentialNetwork import SequentialNetwork
from algs.utilities.Noises import OUNoise
from solvers.ZeroSumSolver import ZeroSumSolver


def init(env, timesteps, config, threshold=1):
    """Build a MADDPG object holding the u and v players' networks and noise.

    Two Q-networks take the concatenated state and both players' actions and
    output one value. Two policy networks map the state to each player's
    action through a final ``nn.Tanh()``. Each player gets an ``OUNoise``
    with ``sigma=0.5``.

    Args:
        env: Environment providing ``state_dim``, ``u_action_dim``,
            ``v_action_dim`` and ``u_action_min``/``u_action_max``/
            ``v_action_min``/``v_action_max``.
        timesteps: Used as ``threshold_decrease=1/timesteps`` for both noises;
            must be non-zero (``0`` raises ``ZeroDivisionError``).
        config: Keyword arguments forwarded to ``MADDPG``. It may also contain
            ``'model_structure'`` (default ``'256128'``), which is passed to
            ``create_model`` and is NOT forwarded to ``MADDPG``. The caller's
            dict is not modified.
        threshold: Initial ``threshold`` passed to both ``OUNoise`` objects.

    Returns:
        The constructed ``MADDPG`` instance.
    """
    q_input_dim = env.state_dim + env.u_action_dim + env.v_action_dim

    # Work on a shallow copy. Popping from the caller's dict would drop
    # 'model_structure', so reusing that config would silently fall back to
    # the default architecture.
    config = dict(config)
    model_structure = config.pop('model_structure', '256128')

    u_q_model = create_model(q_input_dim, model_structure, 1)
    v_q_model = create_model(q_input_dim, model_structure, 1)
    u_pi_model = create_model(env.state_dim, model_structure, env.u_action_dim, nn.Tanh())
    v_pi_model = create_model(env.state_dim, model_structure, env.v_action_dim, nn.Tanh())

    u_noise = OUNoise(env.u_action_dim, threshold_decrease=1/timesteps,
                      sigma=0.5, threshold=threshold)
    v_noise = OUNoise(env.v_action_dim, threshold_decrease=1/timesteps,
                      sigma=0.5, threshold=threshold)

    return MADDPG(env.u_action_min, env.u_action_max, env.v_action_min, env.v_action_max,
                  u_q_model, v_q_model, u_pi_model, v_pi_model, u_noise, v_noise, **config)


def learn(env, timesteps, config):
    """Train MADDPG agents on ``env`` with a ``ZeroSumSolver``.

    ``config`` is split by ``divide_config`` into agent arguments (those of
    ``MADDPG`` plus ``'model_structure'``) and solver arguments. The solver
    is then run for ``timesteps`` steps.

    Returns:
        ``(agents, info)``. ``info['trs']`` holds the solver's
        ``total_rewards`` and ``info['tts']`` holds its ``total_timesteps``.
    """
    agent_cfg, solver_cfg = divide_config(config, MADDPG, ['model_structure'])
    agents = init(env, timesteps, agent_cfg)

    trainer = ZeroSumSolver(env, agents, **solver_cfg)
    trainer.go(timesteps)

    return agents, dict(trs=trainer.total_rewards, tts=trainer.total_timesteps)


def save(agents, info, path):
    """Write the four networks and the training curves under ``path``.

    The networks are written by ``save_models``. ``info['trs']`` and
    ``info['tts']`` are written with ``np.save`` to ``trs.npy`` and
    ``tts.npy`` inside ``path``. Returns ``None``.
    """
    save_models(agents, ['u_q_model', 'v_q_model', 'u_pi_model', 'v_pi_model'], path)
    for key, filename in (('trs', 'trs.npy'), ('tts', 'tts.npy')):
        np.save(os.path.join(path, filename), info[key])


def load(env, config, path):
    """Rebuild agents from ``config`` and restore their weights from ``path``.

    Only the agent part of ``config`` (as split by ``divide_config``) is used.
    The solver part is ignored.

    Returns:
        ``(u_agent, v_agent)`` taken from the loaded MADDPG object.
    """
    agent_cfg, _unused_solver_cfg = divide_config(config, MADDPG, ['model_structure'])
    # timesteps=1 keeps 1/timesteps finite. threshold=0 presumably disables
    # exploration noise for evaluation -- NOTE(review): confirm against OUNoise.
    agents = init(env, 1, agent_cfg, threshold=0)
    load_models(agents, ['u_q_model', 'v_q_model', 'u_pi_model', 'v_pi_model'], path)
    return agents.u_agent, agents.v_agent
