import os
import time
import torch
import numpy as np
from tqdm import tqdm
from itertools import count
from trainer.trainer import ALGOS
from trainer.buffer import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
from trainer.utils import analyze_logs, silence_stderr


class Runner_QMIX_SMAC(object):
    """Drive training and evaluation of a multi-agent learner on a SMAC-style environment.

    The runner can obtain an expert replay buffer (``collect_expert_buffer`` / ``init``).
    ``run`` then alternates between environment interaction or training and periodic
    TensorBoard logging. The environment is used through ``get_env_info``, ``reset``,
    ``step``, ``get_obs``, ``get_state``, ``get_avail_actions`` and ``close``, plus the
    best-effort seeding hooks in ``set_seed``.
    """

    # Shared location of cached expert buffers. Site-specific: override on the class or on
    # an instance to use a different directory.
    EXPERT_FOLDER = "/common/scratch/users/t/tvbui/expert_policies"

    def __init__(self, args, env_name, env, seed, n_expert_episodes, algo=None):
        """Set up the runner.

        Args:
            args: config namespace. This class itself reads ``device``, ``buffer_size``,
                ``max_train_steps``, ``lr`` and ``evaluate_times``. It is also passed
                whole to the agent constructor.
            env_name: map name, used in log and cache paths.
            env: environment instance.
            seed: base random seed.
            n_expert_episodes: number of expert episodes kept (also part of the log path).
            algo: key into ``ALGOS`` (case-insensitive). ``None`` builds a runner with no
                agent and no TensorBoard writer (presumably used only to generate the expert
                buffer).
        """
        self.args = args
        self.env_name = env_name
        self.seed = seed
        self.total_steps = 0
        self.n_expert_episodes = n_expert_episodes
        self.device = self.args.device
        self.buffer_size = self.args.buffer_size
        self.algo = algo.lower() if algo else None
        # Attach the env *before* seeding. set_seed also pushes the seed into the env; it
        # used to be called first, so that part failed silently on a missing self.env.
        self.env = env
        self.set_seed(self.seed)
        self.env_info = self.env.get_env_info()
        self.n_agents = self.env_info["n_agents"]
        # Keep the lazy conditional: the fallback attribute may not exist when the key is present.
        self.n_enemies = self.env_info["n_enemies"] if "n_enemies" in self.env_info else self.env.env.n_enemies
        self.ob_dim = self.env_info["obs_shape"]
        self.st_dim = self.env_info["state_shape"]
        self.ac_dim = self.env_info["n_actions"]
        self.episode_limit = self.env_info["episode_limit"]
        self.agent = ALGOS[algo.upper()](self.n_agents, self.ob_dim, self.st_dim, self.ac_dim, self.args) if algo else None
        folder = f"logs/eps{n_expert_episodes}_seed{seed}/{env_name}"
        os.makedirs(folder, exist_ok=True)
        self.writer = SummaryWriter(f"{folder}/{self.algo}") if algo else None
        self.ex_buffer = None
        self.ex_policy = None
        # Re-initialised by init(); defined here so the attributes always exist.
        self.pi_buffer = None
        self.win_rates = []
        self.logs_train = []
        self.logs_eval = []

    def init(self, expert_episodes_limit=10000, ex_buffer=None, use_expert=True):
        """Prepare buffers and log lists before ``run``.

        Args:
            expert_episodes_limit: forwarded to ``collect_expert_buffer`` as
                ``n_expert_episodes``.
            ex_buffer: optional pre-built expert buffer to use instead of loading or collecting.
            use_expert: when false, no expert buffer is obtained.
        """
        self.ex_buffer = self.collect_expert_buffer(expert_episodes_limit, ex_buffer, verbose=True) if use_expert else None
        # collect_expert_buffer may return None (agent-less runner whose cache file already
        # exists). In that case keep the env's episode limit instead of crashing on None.
        if self.ex_buffer is not None:
            self.episode_limit = self.ex_buffer.episode_limit
        # Behaviour cloning trains only on expert data, so it gets no policy buffer.
        self.pi_buffer = ReplayBuffer(self.ob_dim, self.st_dim, self.ac_dim, self.n_agents, self.episode_limit, self.buffer_size) if self.algo is not None and self.algo != "bc" else None
        self.win_rates = []
        self.logs_train = []
        self.logs_eval = []

    def collect_expert_buffer(self, n_expert_episodes, ex_buffer=None, seed=0, verbose=False, deterministic=True) -> "ReplayBuffer":
        """Obtain the expert replay buffer.

        The source is chosen in this order:
          1. ``ex_buffer`` is given: use it.
          2. The cache file exists, or this runner has an agent: poll every 0.5 s until the
             cache file exists. With an agent, keep polling until ``torch.jit.load`` succeeds.
             With an agent this blocks indefinitely if no other process writes the file.
          3. Otherwise: roll out ``self.ex_policy`` for ``n_expert_episodes`` episodes, fill
             the buffer and save it to the cache path.

        A resulting buffer has ``limit(self.n_expert_episodes)`` applied. When a writer
        exists, its summary stats are logged under ``game_expert/``.

        Args:
            n_expert_episodes: episodes to roll out; also part of the cache file name.
            ex_buffer: optional buffer to use directly.
            seed: part of the cache file name; also used to derive per-episode env seeds.
            verbose: show a progress bar while rolling out.
            deterministic: forwarded to the expert policy's ``_forward``.

        Returns:
            ``self.ex_buffer``. This may be ``None`` for an agent-less runner whose cache
            file already existed.
        """
        expert_folder = self.EXPERT_FOLDER
        expert_path = f"{expert_folder}/{self.env_name}_n{n_expert_episodes}_seed{seed}_buffer.pt"
        if ex_buffer is not None:
            self.ex_buffer = ex_buffer
        elif os.path.exists(expert_path) or self.agent is not None:
            self._wait_for_expert_buffer(expert_path)
            if self.agent is not None:
                print(f"Expert: {self.env_name:.<22} - Load expert buffer successfully!")
        else:
            self._generate_expert_buffer(n_expert_episodes, expert_folder, expert_path, seed, verbose, deterministic)
        if self.ex_buffer is not None:
            self.ex_buffer.limit(self.n_expert_episodes)
            print(f"Expert: {self.env_name:.<22} - buffer size: {self.ex_buffer.buffer_size}")
        if self.writer and self.ex_buffer is not None:
            buf = self.ex_buffer
            self._log_game_stats("game_expert", buf.total_reward, buf.dead_allies, buf.dead_enemies, buf.battle_won)
        return self.ex_buffer

    def _wait_for_expert_buffer(self, expert_path):
        """Poll every 0.5 s until ``expert_path`` exists.

        With an agent, keep polling until the file also loads into ``self.ex_buffer``.
        """
        while True:
            time.sleep(0.5)
            if not os.path.exists(expert_path):
                continue
            if self.agent is None:
                return
            try:
                self.ex_buffer = torch.jit.load(expert_path)
                return
            except Exception:
                # Presumably the file is still being written by another process; retry.
                pass

    def _generate_expert_buffer(self, n_expert_episodes, expert_folder, expert_path, seed, verbose, deterministic):
        """Roll out the expert policy into ``self.ex_buffer``, attach stats, and save it to ``expert_path``."""
        if self.ex_buffer is None:
            self.ex_buffer = ReplayBuffer(self.ob_dim, self.st_dim, self.ac_dim, self.n_agents, self.episode_limit, n_expert_episodes)
        if self.ex_policy is None:
            self.ex_policy = torch.jit.load(f"expert_policies/{self.env_name}.pt").to(self.device)
        logs = []
        # Episode k is seeded with base_seed + k.
        base_seed = 2**16 - self.episode_limit * seed
        # Context manager so the progress bar is closed (and, with leave=False, cleared).
        with tqdm(total=n_expert_episodes, desc=f"Collecting ... {self.env_name}", disable=not verbose, ncols=0, leave=False, mininterval=4) as progress:
            for episode in range(n_expert_episodes):
                self.set_seed(base_seed + episode)
                with silence_stderr():
                    self.env.reset()
                total_reward = 0
                rnn_actor = torch.zeros((self.n_agents, 1, self.ex_policy.h_dim), device=self.device)
                masks = torch.ones((self.n_agents, 1), dtype=torch.bool, device=self.device)
                for t in range(self.episode_limit):
                    obs, state, avails = self._observe()
                    with torch.no_grad():
                        actions, _, rnn_actor = self.ex_policy._forward(obs.to(self.device), rnn_actor, masks, avails.to(self.device), deterministic)
                        actions = actions.squeeze(-1).detach().cpu()
                    with silence_stderr():
                        reward, done, info = self.env.step(actions.numpy())
                    total_reward += reward
                    # A `done` on the last allowed step is stored as non-terminal
                    # (presumably treated as time-limit truncation).
                    self.ex_buffer.store_transition(t, obs, state, avails, actions, reward, done and t + 1 != self.episode_limit)
                    if done:
                        info["total_reward"] = total_reward
                        logs.append(info)
                        break
                obs, state, avails = self._observe()
                self.ex_buffer.store_last_step(t + 1, obs, state, avails)
                _, dead_allies, dead_enemies, battle_won = analyze_logs(logs, self.n_agents, self.n_enemies)
                progress.desc = f"Expert: {self.env_name:.<22} - {dead_allies:.2f}/{dead_enemies:.2f}/{battle_won:.3f}"
                progress.update()
        self.ex_buffer.rstrip()
        self.ex_buffer.total_reward, self.ex_buffer.dead_allies, self.ex_buffer.dead_enemies, self.ex_buffer.battle_won = analyze_logs(logs, self.n_agents, self.n_enemies)
        os.makedirs(expert_folder, exist_ok=True)
        self.ex_buffer.save(expert_path)

    def _observe(self):
        """Read (obs, state, avail_actions) from the env as tensors; obs and state are float32."""
        obs = torch.tensor(np.array(self.env.get_obs()), dtype=torch.float32)
        state = torch.tensor(self.env.get_state(), dtype=torch.float32)
        avails = torch.tensor(self.env.get_avail_actions())
        return obs, state, avails

    def _log_game_stats(self, prefix, total_reward, dead_allies, dead_enemies, battle_won):
        """Write the four per-game scalars under ``prefix/`` at the current step count."""
        self.writer.add_scalar(f"{prefix}/total_reward", total_reward, self.total_steps)
        self.writer.add_scalar(f"{prefix}/win_rate", battle_won, self.total_steps)
        self.writer.add_scalar(f"{prefix}/dead_allies", dead_allies, self.total_steps)
        self.writer.add_scalar(f"{prefix}/dead_enemies", dead_enemies, self.total_steps)

    def run(self, n_logs=512):
        """Train until ``args.max_train_steps`` environment steps have been counted.

        BC (``algo == "bc"``):
            Each iteration trains on the expert buffer and advances the step counter by
            the env's episode limit.

        Other algorithms:
            Each iteration plays one training episode. Once ``pi_buffer.current_size``
            reaches ``agent.batch_size``, the learner trains on the policy and expert
            buffers. The learning rate decays linearly, floored at 10% of ``args.lr``.

        Logging:
            Happens roughly ``n_logs`` times over the run. It records the losses, and also
            training game stats once at least 16 training episodes have accumulated. Every
            4th logging point also runs ``evaluate_policy`` and logs its stats.

        The env is closed (best-effort) at the end.
        """
        step_id = 0
        loss_1, loss_2 = 0, 0
        task_name = f"{self.env_name}-{self.algo}-{self.n_expert_episodes}"
        self.agent.batch_size = min(self.agent.batch_size, self.buffer_size)
        print("Running ...", task_name)
        while self.total_steps < self.args.max_train_steps:
            self.set_seed(self.seed + self.total_steps)
            if self.algo == "bc":
                self.total_steps += self.env_info["episode_limit"]
                loss_1, loss_2 = self.agent.train(self.ex_buffer)
            else:
                self.run_episode_smac(evaluate=False)
                if self.pi_buffer.current_size >= self.agent.batch_size:
                    optimizer: torch.optim.Adam = self.agent.optimizer
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = self.args.lr * max(1 - self.total_steps / self.args.max_train_steps, 0.1)
                    loss_1, loss_2 = self.agent.train(self.pi_buffer, self.ex_buffer)
            # Integer-only check of total_steps / max_train_steps > step_id / n_logs.
            if n_logs * self.total_steps > step_id * self.args.max_train_steps:
                step_id += 1
                self.writer.add_scalar("losses/loss_1", loss_1, self.total_steps)
                self.writer.add_scalar("losses/loss_2", loss_2, self.total_steps)
                if len(self.logs_train) >= 16:
                    train_stats = analyze_logs(self.logs_train, self.n_agents, self.n_enemies)
                    _, dead_allies, dead_enemies, battle_won = train_stats
                    print(f"Map : {task_name:.<30} - step: {self.total_steps:<7} - {dead_allies:.2f}/{dead_enemies:.2f}/{battle_won:.3f}")
                    self._log_game_stats("game", *train_stats)
                    self.logs_train = []
                if step_id % 4 == 0:
                    self.evaluate_policy()
                    eval_stats = analyze_logs(self.logs_eval, self.n_agents, self.n_enemies)
                    _, dead_allies, dead_enemies, battle_won = eval_stats
                    print(f"Eval: {task_name:.<30} - step: {self.total_steps:<7} - {dead_allies:.2f}/{dead_enemies:.2f}/{battle_won:.3f}")
                    self._log_game_stats("game_eval", *eval_stats)
                    # self.agent.qmix.save(f"{self.writer.log_dir}/qmix.pt")
                    self.logs_eval = []
        # self.agent.qmix.save(f"{self.writer.log_dir}/qmix.pt")
        try:
            self.env.close()
        except Exception:
            # Best-effort shutdown; a failing close must not mask a finished run.
            pass

    def set_seed(self, seed):
        """Seed NumPy and torch, then best-effort push ``seed`` into the environment.

        Two hooks are tried independently: ``env.env._seed`` and ``env.seed``. Any failure,
        e.g. a missing attribute or no env attached yet, is ignored.
        """
        np.random.seed(seed)
        torch.manual_seed(seed)
        try:
            self.env.env._seed = seed
        except Exception:
            pass
        try:
            # NOTE(review): if the env exposes a seed() method, this assignment shadows it.
            # Confirm that this is intended.
            self.env.seed = seed
        except Exception:
            pass

    def evaluate_policy(self):
        """Play ``args.evaluate_times`` evaluation episodes, seeding episode i with ``self.seed + i``."""
        for step in range(self.args.evaluate_times):
            self.set_seed(self.seed + step)
            self.run_episode_smac(evaluate=True)

    def run_episode_smac(self, evaluate=False):
        """Play one episode with the agent's policy; the whole episode runs under ``silence_stderr``.

        Evaluation episodes:
            No step cap is applied here; the loop runs until the env reports ``done``.
            They are appended to ``logs_eval``; nothing is stored.

        Training episodes:
            - Transitions go into ``pi_buffer`` and ``total_steps`` is incremented per
              stored step.
            - The episode is cut at step ``episode_limit - 1``. That step is executed but
              neither stored nor counted.
            - The episode goes into ``logs_train`` and the final observation is stored via
              ``store_last_step``.
        """
        with silence_stderr():
            self.env.reset()
            self.agent.qmix.eval_Q_net.reset()
            last_onehot_a_n = torch.zeros((self.n_agents, self.ac_dim))
            total_reward = 0
            for episode_step in count():
                is_last_step = False
                if not evaluate:
                    if episode_step >= self.episode_limit - 1:
                        is_last_step = True
                    else:
                        self.total_steps += 1
                obs, state, avails = self._observe()
                actions = self.agent.choose_action(obs, last_onehot_a_n, avails, evaluate)
                last_onehot_a_n = torch.eye(self.ac_dim)[actions]
                reward, done, info = self.env.step(actions.numpy())
                total_reward += reward if isinstance(reward, float) else np.mean(reward)
                if not evaluate and not is_last_step:
                    self.pi_buffer.store_transition(episode_step, obs, state, avails, actions, reward, done and episode_step + 1 != self.episode_limit)
                if done or is_last_step:
                    info["total_reward"] = total_reward
                    if "battle_won" not in info:
                        info["battle_won"] = False
                    if evaluate:
                        self.logs_eval.append(info)
                    else:
                        self.logs_train.append(info)
                    break
            if not evaluate:
                # NOTE(review): on a truncated episode, transition slot episode_limit-1 is never
                # written, while store_last_step writes index episode_limit. Confirm this gap is
                # intended.
                obs, state, avails = self._observe()
                self.pi_buffer.store_last_step(episode_step + 1, obs, state, avails)
