import torch
import numpy as np
from tqdm import tqdm
from config import Args
from policy import Actor
from envs import VectorizedEnv


class Runner:
    """Roll out a recurrent multi-agent actor in vectorized environments.

    ``collect`` gathers fixed-length training rollouts from ``self.envs``;
    ``evaluate`` plays whole episodes on a separate ``VectorizedEnv`` with
    deterministic actions and reports averaged episode statistics.
    """

    def __init__(self, envs: VectorizedEnv, config: Args, device="cuda"):
        """Reset ``envs`` and cache the problem dimensions from ``config``.

        Args:
            envs: Training environments driven by ``collect``.
            config: Supplies ``ob_dim``, ``st_dim``, ``ac_dim``, ``n_agents``,
                ``n_enemies`` and ``n_envs``.
            device: Torch device that actor inputs are moved to.
        """
        self.envs = envs
        self.envs.reset()

        # Dimensions are taken from the config, not queried from the envs.
        self.ob_dim = config.ob_dim
        self.st_dim = config.st_dim
        self.ac_dim = config.ac_dim
        self.n_agents = config.n_agents
        self.n_enemies = config.n_enemies
        self.n_envs = config.n_envs
        self.device = device

        # Request the first observations now: ``collect`` begins every step
        # with ``curr_states_wait`` and re-requests after each ``step_wait``.
        self.envs.curr_states_async()

    def _reset_finished_rnn_state(self, rnn_state, dones_np):
        """Return ``rnn_state`` with the rows of finished environments zeroed.

        ``rnn_state`` is a tuple of tensors (an ``(h, c)`` pair in this class)
        whose dim 1 flattens (env, agent) in env-major order, i.e. shaped
        ``(1, n_envs * n_agents, h_dim)`` like the initial state built in
        ``evaluate``. ``dones_np`` holds one done flag per environment.

        New tensors are returned; the caller's tensors are not modified in
        place. When no environment is done, ``rnn_state`` is returned as is.
        """
        dones = np.asarray(dones_np, dtype=bool)
        if not dones.any():
            return rnn_state
        device = rnn_state[0].device
        # Expand per-env flags to one flag per (env, agent) row, then add
        # broadcast dims for the leading layer axis and the hidden axis.
        done_rows = torch.as_tensor(dones, device=device).repeat_interleave(self.n_agents)
        mask = done_rows.view(1, -1, 1)
        return tuple(state.masked_fill(mask, 0.0) for state in rnn_state)

    @torch.no_grad()
    def collect(self, actor: Actor, n_steps, rnn_state, desc=None, verbose=False):
        """Run ``actor`` in ``self.envs`` for ``n_steps`` steps and record the rollout.

        The actor is put in eval mode for the rollout and switched back to
        train mode afterwards, also when an exception escapes. Environments
        are not reset here, so consecutive calls continue the same episodes.

        Args:
            actor: Policy exposing ``sample(obs, avails, rnn_state)`` returning
                ``(actions, log_probs, rnn_state)``.
            n_steps: Number of vectorized environment steps to collect.
            rnn_state: Initial recurrent state, expected to be an ``(h, c)``
                pair shaped ``(1, n_envs * n_agents, h_dim)``.
            desc: Progress-bar label.
            verbose: Show a tqdm progress bar when true.

        Returns:
            ``(obs, states, avails, next_obs, next_states, next_avails,
            actions, log_probs, rewards, dones, infos)`` where every array has
            leading dims ``(n_steps, n_envs)``: obs/next_obs
            ``(..., n_agents, ob_dim)`` float32, states/next_states
            ``(..., st_dim)`` float32, avails/next_avails
            ``(..., n_agents, ac_dim)`` bool, actions ``(..., n_agents)``
            int64, log_probs ``(..., n_agents)`` float32, rewards float32,
            dones bool; ``infos`` is the list of per-step infos from
            ``step_wait``.

        NOTE(review): the final recurrent state is not returned, so callers
        cannot resume the RNN where this rollout stopped even though the
        episodes continue -- confirm that callers re-initialize it on purpose.
        """
        n_envs, n_agents = self.n_envs, self.n_agents
        ob_dim, st_dim, ac_dim = self.ob_dim, self.st_dim, self.ac_dim

        all_obs = np.zeros((n_steps, n_envs, n_agents, ob_dim), dtype=np.float32)
        all_states = np.zeros((n_steps, n_envs, st_dim), dtype=np.float32)
        all_avails = np.zeros((n_steps, n_envs, n_agents, ac_dim), dtype=bool)
        all_next_obs = np.zeros((n_steps, n_envs, n_agents, ob_dim), dtype=np.float32)
        all_next_states = np.zeros((n_steps, n_envs, st_dim), dtype=np.float32)
        all_next_avails = np.zeros((n_steps, n_envs, n_agents, ac_dim), dtype=bool)
        all_actions = np.zeros((n_steps, n_envs, n_agents), dtype=np.int64)
        all_log_probs = np.zeros((n_steps, n_envs, n_agents), dtype=np.float32)
        all_rewards = np.zeros((n_steps, n_envs), dtype=np.float32)
        all_dones = np.zeros((n_steps, n_envs), dtype=bool)
        all_infos = []

        actor.eval()
        try:
            for step in tqdm(range(n_steps), desc=desc, ncols=80, leave=False, disable=not verbose):
                obs_np, states_np, avails_np = self.envs.curr_states_wait()

                obs = torch.tensor(obs_np, dtype=torch.float32).to(self.device)
                avails = torch.tensor(avails_np, dtype=torch.bool).to(self.device)
                actions, log_probs, rnn_state = actor.sample(obs, avails, rnn_state)

                # Dispatch actions before the bookkeeping below, presumably so
                # environment stepping overlaps with the buffer copies.
                actions_np = actions.cpu().numpy()
                self.envs.step_async(actions_np)

                all_obs[step] = obs_np
                all_states[step] = states_np
                all_avails[step] = avails_np
                all_actions[step] = actions_np
                all_log_probs[step] = log_probs.cpu().numpy()

                next_obs_np, next_states_np, next_avails_np, rewards_np, dones_np, infos = self.envs.step_wait()
                # Request the observations consumed at the top of the next iteration.
                self.envs.curr_states_async()

                # Agents in environments that just finished start from a zero state.
                rnn_state = self._reset_finished_rnn_state(rnn_state, dones_np)

                all_next_obs[step] = next_obs_np
                all_next_states[step] = next_states_np
                all_next_avails[step] = next_avails_np
                all_rewards[step] = rewards_np
                all_dones[step] = dones_np
                all_infos.append(infos)
        finally:
            actor.train()

        return (all_obs, all_states, all_avails, all_next_obs, all_next_states, all_next_avails,
                all_actions, all_log_probs, all_rewards, all_dones, all_infos)

    @torch.no_grad()
    def evaluate(self, envs: VectorizedEnv, actor: Actor, n_episodes=32, verbose=False):
        """Play ``n_episodes`` episodes with deterministic actions and summarize them.

        ``envs`` is reset first. The actor runs in eval mode and its previous
        train/eval mode is restored afterwards, also when an exception escapes.
        An episode counts as finished when ``envs.step`` returns a non-empty
        info for it. Exactly the first ``n_episodes`` such infos are used,
        even when several environments finish on the same step.

        Args:
            envs: Evaluation environments (may differ from ``self.envs``).
            actor: Policy exposing ``sample(obs, avails, rnn_state,
                deterministic=True)``.
            n_episodes: Number of finished episodes to average over.
            verbose: Show a tqdm progress bar when true.

        Returns:
            ``(dead_allies, dead_enemies, winrates)``: means of the
            ``'dead_allies'``, ``'dead_enemies'`` and ``'won'`` info entries
            over the evaluated episodes (NaN when ``n_episodes <= 0``).
        """
        was_training = getattr(actor, "training", True)
        actor.eval()
        try:
            envs.reset()

            n_rows = envs.n_envs * self.n_agents
            h_dim = actor.h_dim
            rnn_state = (torch.zeros(1, n_rows, h_dim, device=self.device),
                         torch.zeros(1, n_rows, h_dim, device=self.device))
            episode_infos = []

            with tqdm(desc="Evaluating", total=n_episodes, ncols=80, leave=False,
                      disable=not verbose) as p_bar:
                while len(episode_infos) < n_episodes:
                    obs_np, _, avails_np = envs.get_current_states()

                    obs = torch.tensor(obs_np, dtype=torch.float32).to(self.device)
                    avails = torch.tensor(avails_np, dtype=torch.bool).to(self.device)
                    actions, _, rnn_state = actor.sample(obs, avails, rnn_state, deterministic=True)

                    _, _, _, _, dones_np, infos = envs.step(actions.cpu().numpy())
                    rnn_state = self._reset_finished_rnn_state(rnn_state, dones_np)

                    for info in infos:
                        # Empty infos belong to environments still mid-episode.
                        if len(info) == 0:
                            continue
                        episode_infos.append(info)
                        p_bar.update()
                        if len(episode_infos) == n_episodes:
                            # Several envs can finish on one step; stop at exactly n_episodes.
                            break
        finally:
            if was_training:
                actor.train()

        dead_allies = np.mean([info['dead_allies'] for info in episode_infos])
        dead_enemies = np.mean([info['dead_enemies'] for info in episode_infos])
        winrates = np.mean([info['won'] for info in episode_infos])

        return dead_allies, dead_enemies, winrates
