import torch
import numpy as np
from gym.spaces import Box
from torch.autograd import Variable
from tqdm import trange

from marl import MARL
from config import *

if cfg.random_imputation:       # Random imputation for ablation study
    from random_imputation_buffer import ReplayBuffer
else:                           # GAIN imputation
    from gain_buffer import ReplayBuffer

# Frequency scaling.
# With the mask-based update, only part of the collected data can be used for
# training, so IA-MARL trains its networks less frequently than MADDPG:
# MADDPG updates every cfg.freq_update steps, while IA-MARL updates every
#     1 / (1 - cfg.drop_p) ** cfg.mask_type * cfg.freq_update
# steps (cfg.mask_type is 2 for the mask-based update of IA-MARL and 0 for
# Simple-IA-MARL, in which case the scale is 1). The exploration schedule
# (cfg.n_epi_exp) and g_max_step, which is passed to ReplayBuffer, are scaled
# by the same factor.
#
# The unscaled variant is not supported. The option is rejected before cfg is
# mutated and, unlike an assert, the check is not stripped under `python -O`.
if cfg.unscale:
    raise ValueError('cfg.unscale is not supported; run with unscale disabled')
if not 0 <= cfg.drop_p < 1:
    # scale divides by (1 - drop_p): drop_p == 1 would divide by zero, and a
    # value outside [0, 1) is not a valid drop probability.
    raise ValueError('cfg.drop_p must be in [0, 1), got {}'.format(cfg.drop_p))

scale = 1 / (1 - cfg.drop_p) ** cfg.mask_type       # Frequency scale
cfg.freq_update = int(scale * cfg.freq_update)      # Network update scale
cfg.n_epi_exp = int(scale * cfg.n_epi_exp)          # Exploration scale
g_max_step = 50 * scale                             # Scaled value passed to ReplayBuffer


# Select action and update replay buffer
def step(all_obs, algo, drop_p=0, rep_buf=None, explore=False):
    """Take one joint action in the module-level ``env`` and optionally store it.

    Args:
        all_obs: Per-agent observations; ``all_obs[a_j]`` is agent ``a_j``'s.
        algo: MARL algorithm; must expose ``n_agent`` and
            ``step(obs_list, explore=...)`` returning one action tensor per agent.
        drop_p: Probability that an agent's transition is marked missing
            (its ``all_mask`` entry is ``False``).
        rep_buf: Replay buffer. When given, the transition and mask are pushed
            with ``push_imputation`` and its return value is reported as the
            imputation loss; when ``None`` nothing is stored.
        explore: Forwarded to ``algo.step``.

    Returns:
        ``(all_n_obs, all_rew, transition, imputation_loss)`` where
        ``transition`` is ``(all_obs, all_act, all_rew, all_n_obs, all_don,
        all_mask)`` and ``imputation_loss`` is 0 when ``rep_buf`` is ``None``.
    """
    imputation_loss = 0
    # One batch-of-one tensor per agent, with the same shape and default dtype
    # as the former ``torch.Tensor([obs])``. Converting through a single
    # ndarray avoids PyTorch's slow list-of-ndarrays path (which also emits a
    # UserWarning). ``torch.tensor`` always copies, and ``Variable`` has been a
    # no-op wrapper since PyTorch 0.4, so it is no longer needed.
    r_obs = [torch.tensor(np.asarray(all_obs[a_j])[None], dtype=torch.get_default_dtype())
             for a_j in range(algo.n_agent)]
    r_act = algo.step(r_obs, explore=explore)           # Action
    # Drop the batch dimension and detach from the autograd graph before
    # handing the actions to the environment.
    all_act = [act[0].detach().cpu().numpy() for act in r_act]
    all_n_obs, all_rew, all_don, _ = env.step(all_act)
    # Each entry is True (observed) with probability 1 - drop_p.
    all_mask = np.random.rand(algo.n_agent) > drop_p
    if rep_buf is not None:
        imputation_loss = rep_buf.push_imputation(all_obs, all_act, all_rew, all_n_obs, all_don, all_mask)
    return all_n_obs, all_rew, (all_obs, all_act, all_rew, all_n_obs,
                                all_don, all_mask), imputation_loss


# Training MARL algorithm using completed data
def update(cfg, algo, rep_buf, logger=None, device='cuda'):
    """Run one round of MARL updates on (imputed) samples from ``rep_buf``.

    Switches ``algo`` to training on ``device``, performs one critic/policy
    update per agent, updates all target networks, and finally switches
    ``algo`` back to rollout mode on the CPU.

    Args:
        cfg: Config; ``cfg.size_batch`` and ``cfg.mask_type`` are read.
        algo: MARL algorithm providing ``n_agent``, ``post_rollout``,
            ``update``, ``update_all_targets`` and ``prep_rollout``.
        rep_buf: Replay buffer providing ``sample_mask``.
        logger: Forwarded to ``algo.update``.
        device: Device used for training and for the sampled batch. Defaults
            to ``'cuda'`` (the previously hard-coded value). Other torch
            device strings such as ``'cpu'`` presumably work as well --
            confirm against ``MARL.post_rollout``.
    """
    algo.post_rollout(device)           # Prepare the networks for training.
    for a_j in range(algo.n_agent):
        # Sample training data for agent a_j.
        # cfg.mask_type == 2: mask-based update (IA-MARL).
        # cfg.mask_type == 0: Simple-IA-MARL.
        sample = rep_buf.sample_mask(cfg.size_batch, a_j, mask_type=cfg.mask_type, device=device)
        algo.update(sample, a_j, logger=logger)     # MARL critic and policy update
    algo.update_all_targets()                       # Target networks update
    # Rollouts are always done on the CPU.
    algo.prep_rollout('cpu')


# Build two MARL learners with identical hyper-parameters:
#   algo_null -- drives the environment while GAIN is pretrained,
#   algo      -- the learner that is trained and evaluated afterwards.
marl_kwargs = dict(agent_algo=cfg.agent_algo, adv_algo=cfg.adv_algo,
                   tau=cfg.tau, lr=cfg.lr, hidden_dim=cfg.hidden_dim)
algo_null = MARL.init_from_env(env, **marl_kwargs)
algo = MARL.init_from_env(env, **marl_kwargs)

# Per-agent observation sizes, and per-agent action sizes (vector length for a
# Box space, otherwise the space's `.n`).
obs_dims = [space.shape[0] for space in env.observation_space]
act_dims = [space.n if not isinstance(space, Box) else space.shape[0]
            for space in env.action_space]
rep_buf = ReplayBuffer(cfg.len_buf, algo.n_agent, obs_dims, act_dims,
                       cfg.buf_train, cfg.buf_flush, g_max_step=g_max_step)

# One row per evaluation (every cfg.save_interval episodes), one column per agent.
n_evals = len(range(0, cfg.n_epi, cfg.save_interval))
y_vals = np.zeros((n_evals, algo.n_agent))

# Training of IA-MARL (Algorithm 2), phase 1: pretrain GAIN.
# algo_null interacts with the environment for cfg.pre_epi episodes. Every
# transition is pushed, together with a random missing-data mask, into
# rep_buf, and that push is where GAIN gets trained. algo_null itself is also
# updated every cfg.freq_update steps once the buffer holds a full batch.
t_null = 0
for ep_pre in trange(0, cfg.pre_epi):
    all_obs = env.reset()
    explore_left = max(0, cfg.n_epi_exp - ep_pre) / cfg.n_epi_exp
    # NOTE(review): the denominator is explore_left (in [0, 1]) + 1e6, so the
    # value passed to scale_noise is ~cfg.s_noise * 1e-6 for every episode and
    # the cfg.n_epi_exp schedule has practically no effect. Possibly 1e-6 or a
    # multiplicative schedule was intended -- confirm before changing.
    algo_null.scale_noise(cfg.s_noise / (explore_left + 1e6))
    for et_i in range(cfg.len_epi):
        # Collect data (the buffer push trains GAIN)
        all_obs, all_rew, _, imputation_loss = step(
            all_obs, algo_null, cfg.drop_p, rep_buf=rep_buf, explore=True)
        t_null += 1
        batch_ready = len(rep_buf) >= cfg.size_batch
        if batch_ready and t_null % cfg.freq_update == 0:
            update(cfg, algo_null, rep_buf, logger=logger)
        # Optionally record non-zero imputation losses for inspection.
        if cfg.imputation_loss and imputation_loss != 0:
            pre_gain[0].append(ep_pre)
            pre_gain[1].append(imputation_loss)

# Re-initialise the replay buffer's counters for the main phase while keeping
# the pretrained GAIN.
rep_buf.curr_i = rep_buf.fill_i = rep_buf.curr_g = rep_buf.fill_g = 0

# MARL training step (main phase of Algorithm 2).
# Logger export/close and env.close() run in `finally`, so an interrupted or
# crashed run still exports its logged scalars and releases the environment.
t = 0
try:
    for ep_i in trange(0, cfg.n_epi):
        all_obs = env.reset()
        tmp_rew = np.zeros(algo.n_agent)    # Per-agent reward summed over this episode
        # NOTE(review): the denominator is (remaining exploration fraction, in
        # [0, 1]) + 1e6, so the value passed to scale_noise is ~cfg.s_noise * 1e-6
        # for every episode and the cfg.n_epi_exp schedule has practically no
        # effect. Possibly 1e-6 or a multiplicative schedule was intended --
        # confirm before changing.
        algo.scale_noise(cfg.s_noise / (max(0, cfg.n_epi_exp - ep_i) / cfg.n_epi_exp + 1e6))

        for et_i in range(cfg.len_epi):
            # Collect data (the buffer push also trains GAIN)
            all_obs, all_rew, _, imputation_loss = step(all_obs, algo, cfg.drop_p, rep_buf=rep_buf, explore=True)
            t += 1
            tmp_rew += np.array(all_rew)

            if (len(rep_buf) >= cfg.size_batch) and (t % cfg.freq_update == 0):
                # Update MARL algorithm using mask-based update
                update(cfg, algo, rep_buf, logger=logger)

            if cfg.imputation_loss and imputation_loss != 0:   # Record imputation loss
                post_gain[0].append(ep_i)
                post_gain[1].append(imputation_loss)

        for a_j in range(algo.n_agent):     # Mean per-step training reward
            logger.add_scalar('a%i/mer' % a_j, tmp_rew[a_j] / cfg.len_epi, ep_i)

        if ep_i % cfg.save_interval == 0:   # Every save interval, evaluate the MARL algorithm
            eval_rew = np.zeros((algo.n_agent, cfg.save_iter))
            for eval_ep in range(cfg.save_iter):
                all_obs = env.reset()
                for et_i in range(cfg.len_epi):
                    # No exploration, no data drop, nothing stored (rep_buf=None).
                    all_obs, all_rew, _, _ = step(all_obs, algo)
                    eval_rew[:, eval_ep] += np.array(all_rew)
            # Average return over the cfg.save_iter evaluation episodes.
            y_vals[ep_i // cfg.save_interval, :] = eval_rew.mean(axis=1)
            np.save(eval_dir / '{}_R{}'.format(cfg.model_name, cfg.seed), y_vals)
finally:
    logger.export_scalars_to_json(str(log_dir / 'summary.json'))
    logger.close()
    env.close()

# DEBUG: save the recorded imputation losses for inspection.
# np.save(gain_path / 'pre_{}_{}'.format(cfg.model_name, cfg.seed), np.array(pre_gain))
# np.save(gain_path / 'post_{}_{}'.format(cfg.model_name, cfg.seed), np.array(post_gain))
