# -*- coding: utf-8 -*-
import os
import random
import argparse
from datetime import datetime
import numpy as np
from tqdm import tqdm
import torch
from torch.distributions import Categorical
from torch.utils.tensorboard import SummaryWriter

from pettingzoo.mpe import simple_tag_v3
from smac.env import StarCraft2Env

from models.mappo import MAPPO
from models.rppo import PPO_RNN
from models.multi_obj_ppo import PPO_MO
from models.policy_encoder import StylePredictor, Policy_Encoder

from utils.funcs import z_sample, logit2onehot, get_available_gpu, set_seed, policy_encoder_data_process
from utils.replay_buffer import ReplayBufferDict_MA, ReplayBufferDict


def parse_args(argv=None):
    """Parse command-line options for robust MARL adversarial training.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse falls back to ``sys.argv[1:]``, which is the behaviour
            existing callers rely on.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser(description="Robust MARL Adversarial Training")

    # Environment and Model Parameters
    parser.add_argument('--env_type', type=str, default='mpe', choices=['mpe', 'smac'], help='env type')
    parser.add_argument('--env_name', type=str, default='simple_tag_v3', help='env name')
    parser.add_argument('--map_name', type=str, default='3m', help='SMAC map name')
    parser.add_argument('--learning_rate_encoder', type=float, default=0.0003)
    parser.add_argument('--lr_rl_actor', type=float, default=0.0005)
    parser.add_argument('--lr_rl_critic', type=float, default=0.0005)
    parser.add_argument('--gamma', type=float, default=0.98)
    parser.add_argument('--lmbda', type=float, default=0.95)
    parser.add_argument('--eps_clip', type=float, default=0.1)
    parser.add_argument('--K_epochs', type=int, default=30)
    parser.add_argument('--T_horizon', type=int, default=25)
    parser.add_argument('--update_episode', type=int, default=50)
    parser.add_argument('--update_encoder', type=int, default=10)

    # Training Schedule
    parser.add_argument('--alternate_interval', type=int, default=10000)
    parser.add_argument('--total_episode', type=int, default=200000)

    # Style Parameters
    parser.add_argument('--style_num', type=int, default=3)
    parser.add_argument('--alpha_style', type=float, default=0.5)

    # Policy Encoder
    parser.add_argument('--pooling', type=str, default='att', choices=['mean', 'att', 'cls'])
    parser.add_argument('--ecd_output_dim', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--save_freq', type=int, default=10000)
    parser.add_argument('--print_interval', type=int, default=100)

    parser.add_argument('--seed', type=int, default=42)

    # `--warmup` alone (store_true with default=True) could never yield False,
    # so a paired `--no_warmup` switch writes False into the same destination.
    parser.add_argument('--warmup', dest='warmup', action='store_true', default=True,
                        help='enable warmup (default)')
    parser.add_argument('--no_warmup', dest='warmup', action='store_false',
                        help='disable warmup')

    return parser.parse_args(argv)


def init_env_and_models(args, device):
    """Create the environment and every model used by adversarial training.

    For ``args.env_type == 'mpe'`` a PettingZoo ``simple_tag_v3`` parallel
    environment is built with hard-coded observation/action sizes and a
    separate recurrent prey policy. Any other value takes the SMAC path, where
    sizes are read from ``env.get_env_info()``, ``ally_0`` plays the adversary
    and no prey model is created (``None`` is returned in its place).

    Args:
        args: Parsed CLI namespace (see ``parse_args``).
        device: Torch device handed to every model constructor.

    Returns:
        Tuple ``(env, coop_agents, adv_agent, prey_model, style_predictor,
        policy_encoder, policy_encoder_target, past_policy_encoder,
        agent_names, coop_names, adv_name, obs_dim, action_dim)``.
    """
    # Representation sizes are identical for both environment families.
    policy_loc_dim = 6
    policy_glb_dim = 10
    style_vec_dim = args.style_num

    use_mpe = args.env_type == 'mpe'

    if use_mpe:
        env = simple_tag_v3.parallel_env()
        coop_names = ['adversary_0', 'adversary_1']
        adv_name = 'adversary_2'
        prey_name = 'agent_0'
        agent_names = [*coop_names, adv_name, prey_name]

        obs_dim, prey_obs_dim, action_dim = 16, 14, 5
        state_dim = obs_dim * 2
    else:  # SMAC
        env = StarCraft2Env(map_name=args.map_name, difficulty='7', seed=args.seed)
        env_info = env.get_env_info()
        agent_names = [f"ally_{i}" for i in range(env_info["n_agents"])]
        adv_name = "ally_0"
        coop_names = agent_names[1:]  # every ally except index 0

        obs_dim = env_info["obs_shape"]
        state_dim = env_info["state_shape"]
        action_dim = env_info["n_actions"]

    # Construction order (cooperative -> adversary -> prey) matches the
    # per-branch order so seeded parameter initialisation stays the same.
    coop_agents = MAPPO(
        obs_dim + policy_loc_dim,
        state_dim + policy_glb_dim,
        action_dim,
        policy_loc_dim,
        args.lr_rl_actor,
        args.lr_rl_critic,
        args.gamma,
        args.K_epochs,
        args.eps_clip,
        False,
        device,
        policy_rep=True,
    )
    adv_agent = PPO_MO(obs_dim + style_vec_dim, action_dim, device)
    # SMAC does not use a separate prey model.
    prey_model = PPO_RNN(prey_obs_dim, action_dim, device) if use_mpe else None

    style_predictor = StylePredictor(
        obs_dim + action_dim, args.ecd_output_dim, args.style_num,
        args.learning_rate_encoder, device,
    )

    def build_encoder():
        return Policy_Encoder(obs_dim, policy_loc_dim, policy_glb_dim,
                              args.learning_rate_encoder, device)

    policy_encoder = build_encoder()

    # Frozen snapshot of the online encoder, kept in evaluation mode.
    past_policy_encoder = build_encoder()
    past_policy_encoder.load_state_dict(policy_encoder.state_dict())
    past_policy_encoder.eval()

    # Target encoder starts as an exact copy of the online encoder.
    policy_encoder_target = build_encoder()
    policy_encoder_target.load_state_dict(policy_encoder.state_dict())

    return (
        env, coop_agents, adv_agent, prey_model, style_predictor,
        policy_encoder, policy_encoder_target, past_policy_encoder,
        agent_names, coop_names, adv_name, obs_dim, action_dim,
    )

def _zero_lstm_state(device, hidden_size=32):
    """Return a fresh ``(h, c)`` pair of zero tensors shaped ``[1, 1, hidden_size]`` on ``device``."""
    return (torch.zeros([1, 1, hidden_size], dtype=torch.float).to(device),
            torch.zeros([1, 1, hidden_size], dtype=torch.float).to(device))


def _momentum_update(target_net, online_net, momentum=0.995):
    """Blend online parameters into the target: ``target <- m * target + (1 - m) * online``."""
    for pt, po in zip(target_net.parameters(), online_net.parameters()):
        pt.data = momentum * pt.data + (1.0 - momentum) * po.data


def adversarial_training(args, env, coop_agents, adv_agent,
                         prey_model, style_predictor, policy_encoder,
                         policy_encoder_target, past_policy_encoder,
                         device, writer,
                         agent_names, coop_names, adv_name,
                         pred_obs_dim, action_dim, style_num,
                         ):
    """Alternate between training the style-conditioned adversary and the cooperative team.

    Episodes are grouped into phases of ``args.alternate_interval`` episodes.
    Even-numbered phases (starting with the first) train the adversary and the
    style predictor; odd-numbered phases train the cooperative agents and the
    policy encoder. Every ``args.update_encoder`` episodes the phase's
    auxiliary model is trained and the target encoder is momentum-updated;
    every ``args.update_episode`` episodes the phase's RL agent is updated and
    the buffers belonging to the other side are cleared. Checkpoints are
    written to the current working directory every ``args.save_freq`` episodes.

    Args:
        args: Parsed CLI namespace (see ``parse_args``).
        env: Environment exposing ``reset()`` and ``step(actions)``.
        coop_agents: Cooperative multi-agent learner.
        adv_agent: Adversarial learner conditioned on the style vector ``z``.
        prey_model: Policy for any agent that is neither cooperative nor the
            adversary; may be ``None`` when no such agent exists.
        style_predictor: Model predicting the style index from the adversary's
            (observation, one-hot action) trajectory.
        policy_encoder: Online policy encoder being trained.
        policy_encoder_target: Momentum-averaged copy used during rollouts.
        past_policy_encoder: Encoder loaded from the previous epoch's checkpoint.
        device: Torch device for tensors created here.
        writer: TensorBoard ``SummaryWriter`` receiving loss scalars.
        agent_names: All agent names, in environment order.
        coop_names: Names of the cooperative agents.
        adv_name: Name of the adversarial agent.
        pred_obs_dim: Per-agent observation size used for trajectory buffers.
        action_dim: Number of discrete actions.
        style_num: Number of adversary styles.
    """
    traj_buffer_encoder = ReplayBufferDict_MA()
    traj_buffer_predictor = ReplayBufferDict()

    score_history_coop = []
    score_history_adv = []
    score_history_predict = []

    for n_epi in range(args.total_episode):
        s, _ = env.reset()
        if args.env_type == "smac":
            s = {name: s[i] for i, name in enumerate(agent_names)}

        # whether train adversarial agent
        train_adv = (n_epi // args.alternate_interval) % 2 == 0
        epoch_idx = int(n_epi // (args.alternate_interval * 2))

        # LSTM's hidden states
        h_out_adv = _zero_lstm_state(device)
        h_out_prey = _zero_lstm_state(device)

        score_coop = 0.0
        score_adv = 0.0
        score_predict = 0.0

        # style embedding
        z, style_idx = z_sample(style_num, 1, device)
        z = z.detach().cpu().numpy().flatten()
        # log of the uniform prior over styles: log(1 / style_num)
        log_p_z = -torch.log(torch.tensor(style_num, dtype=torch.float, device=device))

        # store current episode trajectory (zero-padded beyond the current step)
        currt_ep_traj_coop = {name: np.zeros((args.T_horizon, pred_obs_dim), dtype=np.float32)
                              for name in coop_names}
        currt_ep_traj_adv = np.zeros((args.T_horizon, pred_obs_dim + action_dim), dtype=np.float32)
        policy_rep = {name: 0 for name in coop_names}

        # ================= Rollout =================
        for t in range(args.T_horizon):
            actions = {}

            for agent in agent_names:
                if agent in coop_names:
                    # Cooperative agents: observation augmented with the local policy representation
                    currt_ep_traj_coop[agent][t] = s[agent]
                    policy_rep[agent] = policy_encoder_target.local_encoder(
                        torch.from_numpy(currt_ep_traj_coop[agent]).float().to(device).unsqueeze(0)
                    )
                    s_c = np.concatenate((s[agent], policy_rep[agent].detach().cpu().numpy().flatten()))
                    agent_index = agent.split('_')[-1]
                    actions[agent] = coop_agents.select_action(torch.from_numpy(s_c).float(), agent_index)

                elif agent == adv_name:
                    # Adversarial agent: observation augmented with the style vector
                    s_z = np.concatenate((s[agent], z))
                    # Keep the pre-step hidden state so the stored transition carries
                    # (h_in, h_out) rather than the post-step state twice.
                    # NOTE(review): presumably PPO_MO.put_data expects (.., h_in, h_out, done) - confirm.
                    h_in_adv = h_out_adv
                    prob_adv, h_out_adv = adv_agent.pi(torch.from_numpy(s_z).float().to(device), h_in_adv)
                    if prob_adv.dim() > 2:
                        prob_adv = prob_adv.squeeze(dim=1)
                    actions[agent] = Categorical(prob_adv.view(-1)).sample().item()
                    # .cpu() keeps the numpy conversion valid if the one-hot tensor lives on a GPU
                    action_adv = logit2onehot(actions[agent], action_dim).detach().squeeze().cpu().numpy()
                    currt_ep_traj_adv[t] = np.concatenate((s[agent], action_adv))

                else:
                    # Prey or other allies
                    prob_prey, h_out_prey = prey_model.pi(torch.from_numpy(s[agent]).float().to(device), h_out_prey)
                    actions[agent] = Categorical(prob_prey.view(-1)).sample().item()

            # NOTE(review): this 5-tuple unpack and dict-of-actions call match the PettingZoo
            # parallel API; StarCraft2Env.step appears to take a list of actions and return
            # (reward, terminated, info), so the SMAC path likely needs an adapter - confirm.
            s_prime, r, done, _, _ = env.step(actions)
            if args.env_type == "smac":
                s_prime = {name: s_prime[i] for i, name in enumerate(agent_names)}

            # Force episode termination at the rollout horizon.
            if t == args.T_horizon - 1:
                done = {name: True for name in agent_names}

            # Per-step rewards used for logging. They are reset every step so values
            # from an earlier adversary phase never leak into cooperative phases.
            r_task = -r[adv_name]
            r_style = None

            # ===== Buffer Update =====
            for key in agent_names:
                if key in coop_names and not train_adv:
                    agent_index = key.split('_')[-1]
                    coop_agents.buffers[agent_index].rewards.append(r[key])
                    coop_agents.buffers[agent_index].is_terminals.append(done[key])
                    if key == coop_names[0]:
                        # Team-level data is stored once per step, on the first cooperative agent.
                        coop_global_state = np.concatenate([s[n] for n in coop_names])
                        # np.hstack needs a real sequence; a dict view is a non-sequence iterable.
                        pred_traj_global = np.hstack(list(currt_ep_traj_coop.values()))
                        _, global_rep_curr, _ = policy_encoder_target(
                            torch.from_numpy(pred_traj_global).float().to(device).unsqueeze(0),
                            len(coop_names)
                        )
                        coop_agents.policy_rep_buffer.policy_rep.append(global_rep_curr.detach())
                        traj_buffer_encoder.add_transition(coop_global_state, done[adv_name], style_idx)

                elif key == adv_name and train_adv:
                    # Style reward: alpha * (log q(z | trajectory) - log p(z))
                    style_index_generate = style_idx % style_num
                    # The reward is a constant for the RL update; computing it without
                    # autograd avoids keeping one graph alive per stored transition.
                    with torch.no_grad():
                        traj_tensor_adv = torch.tensor(currt_ep_traj_adv, device=device).unsqueeze(0)
                        style_logits_predict, _ = style_predictor(traj_tensor_adv)
                        style_idx_tensor = torch.tensor(style_index_generate, dtype=torch.long, device=device)
                        r_style = Categorical(logits=style_logits_predict).log_prob(style_idx_tensor) - log_p_z
                        r_style = args.alpha_style * r_style
                    r_lst = [r_task, r_style]
                    adv_agent.put_data((
                        s_z, actions[key], r_lst,
                        np.concatenate((s_prime[key], z)),
                        prob_adv[:, actions[key]].item(),
                        h_in_adv, h_out_adv, done[key]
                    ))

                    traj_buffer_predictor.add_transition(currt_ep_traj_adv[t], done[adv_name], style_index_generate)

            s = s_prime
            score_coop += r[coop_names[0]]
            score_adv += r_task
            if r_style is not None:
                score_predict += r_style.detach().cpu().numpy()

        score_history_coop.append(score_coop)
        score_history_adv.append(score_adv)
        score_history_predict.append(score_predict)

        # ================= Update Models =================
        if (n_epi + 1) % args.update_encoder == 0:
            if train_adv:
                # Train style predictor
                sampled_episode = traj_buffer_predictor.sample_episodes(args.batch_size)
                if sampled_episode is not None:
                    batch_traj, batch_idx, _ = sampled_episode
                    batch_traj = torch.tensor(batch_traj, dtype=torch.float, device=device)
                    batch_idx = torch.tensor(batch_idx, dtype=torch.long, device=device)
                    loss_z = style_predictor.train_step(batch_traj, batch_idx)
                    writer.add_scalar("Loss/Style prediction loss", loss_z, n_epi * args.T_horizon)
            else:
                # Train policy encoder
                sampled_traj = traj_buffer_encoder.sample_episodes(args.batch_size)
                if sampled_traj is not None:
                    traj, contrast_label, _ = policy_encoder_data_process(sampled_traj, device)

                    # load past policy encoder from the previous epoch's checkpoint, if present
                    if epoch_idx >= 1:
                        past_encoder_path = f'policy_encoder_{epoch_idx-1}.pth'
                        if os.path.exists(past_encoder_path):
                            past_policy_encoder.load_model(past_encoder_path)
                            past_policy_encoder.eval()

                    # NOTE(review): the rollout calls the encoder with len(coop_names) while
                    # training passes len(agent_names) - confirm which agent count is intended.
                    total_loss = policy_encoder.train_step(
                        traj, contrast_label, len(agent_names),
                        past_encoder=past_policy_encoder if epoch_idx >= 1 else None
                    )
                    writer.add_scalar("Loss/Total encoder contrastive loss", total_loss, n_epi * args.T_horizon)

            # Momentum update target encoder (runs in both phases)
            _momentum_update(policy_encoder_target, policy_encoder, momentum=0.995)

        # ================= Update RL =================
        if (n_epi + 1) % args.update_episode == 0:
            if not train_adv:
                coop_agents.update()
                # Drop adversary-side data collected while it was not being trained.
                adv_agent.data = []
                traj_buffer_predictor.clear()
            else:
                adv_agent.train()
                adv_agent.train_net()
                # Drop cooperative-side data collected while it was not being trained.
                coop_agents.clear_buffer()
                traj_buffer_encoder.clear()

        # ================= Logging =================
        if n_epi % args.print_interval == 0 and n_epi != 0:
            avg_score1 = np.mean(score_history_coop[-args.print_interval:])
            avg_score2 = np.mean(score_history_adv[-args.print_interval:])
            avg_score3 = np.mean(score_history_predict[-args.print_interval:])
            print(f"[Episode {n_epi}] Coop Score: {avg_score1:.2f} | Adv Score: {avg_score2:.2f} | Style Reward: {avg_score3:.2f}")

        # ================= Save Models =================
        if (n_epi + 1) % args.save_freq == 0:
            torch.save(coop_agents.state_dict(), f'coop_agents_{epoch_idx}.pth')
            torch.save(adv_agent.state_dict(), f'adv_agent_{epoch_idx}.pth')
            policy_encoder.save_model(f'policy_encoder_{epoch_idx}.pth')
            style_predictor.save_model('style_predictor.pth')


# ================= main =================
def main():
    """Entry point: parse options, build environment and models, run training.

    TensorBoard logs go to ``./result/tensorboard/<env_name>_style_num_<k>_<timestamp>``.
    The summary writer and the environment are closed even if training raises,
    so buffered events are flushed and environment resources are released.
    """
    args = parse_args()
    set_seed(args.seed)

    available_device = get_available_gpu()
    device = torch.device(available_device)

    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    experiment_name = f"{args.env_name}_style_num_{args.style_num}_{current_time}"
    base_log_dir = f"./result/tensorboard/{experiment_name}"
    os.makedirs(base_log_dir, exist_ok=True)
    writer = SummaryWriter(base_log_dir)

    env = None  # stays None if construction fails, so cleanup can skip it
    try:
        env, coop_agents, adv_agent, prey_model, style_predictor, policy_encoder, policy_encoder_target, policy_encoder_past, \
            agent_names, coop_names, adv_name, obs_dim, action_dim = init_env_and_models(args, device)
        print(f"Environment: {args.env_name}, Agents: {len(agent_names)}, Coop Agents: {len(coop_names)}, Adversarial Agent: {adv_name}")

        adversarial_training(args, env, coop_agents, adv_agent,
                             prey_model, style_predictor, policy_encoder,
                             policy_encoder_target, policy_encoder_past,
                             device, writer,
                             agent_names, coop_names, adv_name,
                             obs_dim, action_dim, args.style_num,
                             )
    finally:
        # Flush pending TensorBoard events and release the environment.
        writer.close()
        if env is not None:
            env.close()


# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

