import logging
import os
from collections import defaultdict, Counter
import fcntl
import random
import configargparse
from typing import List
from functools import reduce

# Necessary for multithreading: restrict OpenMP to a single thread.
# This assignment is placed before the `import torch` below, so the variable
# is already present in the environment when torch is imported.
os.environ["OMP_NUM_THREADS"] = "1"

import torch
from torch import multiprocessing as mp
from torch.distributions import Categorical
import wandb
from nle import nethack
import numpy as np
import scipy.stats as stats
from mamba_ssm.utils.generation import InferenceParams
import hydra
from omegaconf import DictConfig, OmegaConf

from il_scale.nethack.utils.setup import DDPUtil, get_wandb_name, create_env
from il_scale.nethack.utils.model import load_checkpoint, count_params
from il_scale.nethack.agent import Agent
from il_scale.nethack.resetting_env import ResettingEnvironment

# Use the 'file_system' tensor sharing strategy for multiprocessing;
# see https://github.com/pytorch/pytorch/issues/11201
mp.set_sharing_strategy('file_system')

# Root logger: INFO and above, every record prefixed with level, process id,
# module:line and timestamp.
logging.basicConfig(
    level=logging.INFO,
    format="[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] %(message)s",
)

class RolloutDashboard:
    """
    Class to keep track of any state or analytics
    during an individual rollout.

    Tracked state:
        steps:         number of calls to `step` so far.
        max_dlvl:      largest dungeon level read from `blstats` (starts at 1).
        temp_act_dict: action -> number of times it was taken.
        act_history:   actions in the order they were taken.
        ppl:           running sum of exp(entropy) of the policy distribution.
        top_k_count:   rank -> number of steps where the taken action had that
                       rank among the policy logits (rank 1 = largest logit).
    """
    def __init__(self):
        self.steps = 0
        self.max_dlvl = 1
        self.temp_act_dict = defaultdict(int)
        self.act_history = []
        self.ppl = 0
        self.top_k_count = defaultdict(int)

    def avg_ppl(self):
        """
        Return average perplexity across time steps.

        Returns a tensor. When no step has been recorded yet the average is
        undefined, so a NaN tensor is returned instead of dividing by zero.
        """
        if self.steps == 0:
            return torch.tensor(float("nan"))
        return self.ppl / self.steps

    def get_metrics(self):
        """
        Return all tracked metrics as a dict with the keys
        "avg_ppl", "act_counts", "steps", "max_dlvl" and "top_k_count".
        """
        metrics = dict()
        # float() accepts both a one-element tensor and a plain number.
        metrics["avg_ppl"] = float(self.avg_ppl())
        metrics["act_counts"] = self.temp_act_dict
        metrics["steps"] = self.steps
        metrics["max_dlvl"] = self.max_dlvl
        metrics["top_k_count"] = self.top_k_count

        return metrics

    def step(self, observation: dict, policy_outputs: dict, action: int):
        """
        Perform any state/analytics updates.

        Args:
            observation: mapping that provides "blstats"; the dungeon level is
                read from `observation["blstats"][0][0][nethack.NLE_BL_DLEVEL]`.
            policy_outputs: mapping that provides "policy_logits" (a tensor,
                flattened here before use).
            action: index of the action that was taken.
        """
        # Count actions
        self.temp_act_dict[action] += 1

        # Build action history
        self.act_history.append(action)

        # Flatten once; used for both the perplexity and the rank below.
        logits = policy_outputs["policy_logits"].flatten()

        # Update perplexity sum
        self.ppl += torch.exp(Categorical(logits=logits).entropy())

        # Maybe update max dlvl achieved
        dungeon_level = observation["blstats"][0][0][nethack.NLE_BL_DLEVEL].item()
        if dungeon_level > self.max_dlvl:
            self.max_dlvl = dungeon_level

        # Update top-k count. The rank of the taken action is one plus the
        # number of logits that would precede it in a stable descending sort:
        # every strictly larger logit, plus equal logits at a lower index.
        # An out-of-range action is not counted.
        if 0 <= action < logits.numel():
            chosen = logits[action]
            ahead = (logits > chosen).sum() + (logits[:action] == chosen).sum()
            self.top_k_count[int(ahead) + 1] += 1

        # Increase step counter
        self.steps += 1

class Rollout:
    """
    A class used to rollout trained models on the NLE environment.
    """
    def __init__(self, config: DictConfig):
        self.config = config
        self.model_flags = self._get_wandb_config_from_id(config.rollout.wandb_id)

    def _get_wandb_config_from_id(self, wandb_id: str):
        """
        Get wandb config from wandb_id.
        """
        model_cfg = OmegaConf.load(os.path.join('models', wandb_id, 'cfg.omega'))

        # HARD CODED EXCEPTIONS DUE TO MODEL RESUMING
        if wandb_id == '2rx53f2b':
            model_cfg = OmegaConf.load(os.path.join('models', '1dj2vzju', 'cfg.omega'))
            model_cfg.network.hdim = 512
            model_cfg.network.tf_num_layers = 5
            model_cfg.network.tf_num_heads = 512 // 64
            
        elif wandb_id == 'ysjiriyk':
            model_cfg = OmegaConf.load(os.path.join('models', '1dj2vzju', 'cfg.omega'))
            model_cfg.network.hdim = 384
            model_cfg.network.tf_num_layers = 4
            model_cfg.network.tf_num_heads = 384 // 64
        
        # add cfg modifications consistent with nethack_config.yaml
        if 'use_message' not in model_cfg.network:
            model_cfg.network.use_message = True
        if 'use_crop' not in model_cfg.network:
            model_cfg.network.use_crop = True
        if 'use_observation' not in model_cfg.network:
            model_cfg.network.use_observation = True
            
        print('use message', model_cfg.network.use_message)
        return model_cfg

    def _agent_setup(self):
        """
        Construct agent and load in weights.
        """
        # Construct agent
        agent = Agent(self.model_flags, None)
        agent.construct_model(self.model_flags)

        logging.info(f"AMP state: {self.model_flags.setup.use_amp}")

        # Load checkpoint & weights
        checkpoint = load_checkpoint(self.config.rollout.model_load_name, self.config.rollout.wandb_id, savedir=self.config.rollout.wandb_load_dir)
        agent.load(checkpoint["model_state_dict"])

        # Put agent in eval model
        agent.model.eval()

        return agent

    def _submit_actor(self, ctx, seed: int, idx: int):
        """
        Submit and return actor idx with given seed.
        """
        actor = ctx.Process(
            target=self._single_rollout,
            args=(
                seed,
                idx
            ),
            name="Actor-%i" % idx,
        )
        actor.start()

        return actor

    def _get_seeds(self):
        """
        Generate num_rollouts number of seeds.
        """
        return random.sample(list(range(int(1e6))), self.config.rollout.num_rollouts)

    def _setup_env(self, ttyrec_save_folder: str, seed: int, device: torch.device = torch.device('cpu')):
        """
        All logic related to setting up the appropriate NLE environment.
        """

        # Setup environment
        if self.config.rollout.env == 'NetHackChallenge-v0':
            gym_env = create_env(
                self.config.rollout.env, 
                save_ttyrec_every=self.config.rollout.save_ttyrec_every,
                savedir=ttyrec_save_folder, # will only save here if save_ttyrec_every is nonzero
                penalty_time=0.0,
                penalty_step=self.config.rollout.rollout_penalty_step,
                max_episode_steps=self.config.rollout.max_episode_steps,
                no_progress_timeout=10_000,
                character=self.config.rollout.rollout_character,
            )
            logging.info(f"Rolling out with {self.config.rollout.rollout_character} ...")
        else:
            gym_env = create_env(
                self.config.rollout.env, 
                save_ttyrec_every=self.config.rollout.save_ttyrec_every,
                savedir=ttyrec_save_folder, # will only save here if save_ttyrec_every is nonzero
                penalty_time=0.0,
                penalty_step=self.config.rollout.rollout_penalty_step,
                max_episode_steps=self.config.rollout.max_episode_steps,
            )

        # Set seed
        if self.config.rollout.env != 'NetHackChallenge-v0':
            gym_env.seed(seed, seed)

        env_keys = ("tty_chars", "tty_colors", "tty_cursor", "blstats", "glyphs", "inv_glyphs", "message")
        env = ResettingEnvironment(
            gym_env, 
            num_lagged_actions=0,
            env_keys=env_keys,
            device=device
        )

        return env

    @torch.no_grad()
    def _single_rollout(self, seed: int, actor_num: int, device: torch.device = torch.device('cpu')):
        """
        Rollout and log relevant objects (observations, actions, returns).
        """
        inference_params = InferenceParams(max_seqlen=self.config.rollout.max_seqlen, max_batch_size=1)

        env = self._setup_env(None, seed, device)

        observation = env.initial()
        observation["prev_action"] = observation["last_action"] # key name conversion

        frame_stack_chars = torch.zeros((1, self.model_flags.network.obs_frame_stack - 1, nethack.TERMINAL_SHAPE[0], nethack.TERMINAL_SHAPE[1])).to(device)
        frame_stack_colors = frame_stack_chars.clone()
        # Zeros are unseen in training, add 32 to make it like end of game frame
        if self.model_flags.network.obs_frame_stack > 1:
            frame_stack_chars += 32

        dashboard = RolloutDashboard()

        while dashboard.steps < self.config.rollout.max_episode_steps:
            # Stack frames
            observation["tty_chars"] = observation["tty_chars"].unsqueeze(2)
            observation["tty_colors"] = observation["tty_colors"].unsqueeze(2)

            # Update frame stack
            if self.model_flags.network.obs_frame_stack > 1:
                frame_stack_chars = observation["tty_chars"][:, 0, -(self.model_flags.network.obs_frame_stack - 1):].clone()
                frame_stack_colors = observation["tty_colors"][:, 0, -(self.model_flags.network.obs_frame_stack - 1):].clone()

            observation["done"] = observation["done"].bool()

            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.model_flags.setup.use_amp):
                policy_outputs, _ = self.agent.predict(observation, inference_params=inference_params)

            # TEST: see if autoregressive vs. parallel gets you the same outputs
            # copy_input = {k: v[-1:, ...] for k, v in observation.items()}
            # policy_outputs_copy = self.agent_copy.predict(copy_input, inference_params=inference_params, mamba_use_inference_params=True)

            # print('policy_logits', policy_outputs['policy_logits'])
            # print('policy_logits_copy', policy_outputs_copy['policy_logits'])

            # assert torch.allclose(policy_outputs['policy_logits'], policy_outputs_copy['policy_logits'], rtol=10, atol=1e-1), f"policy_logits do NOT match! {torch.max(torch.abs(policy_outputs['policy_logits'] - policy_outputs_copy['policy_logits']))}"

            new_observation = env.step(policy_outputs["action"])
            inference_params.seqlen_offset += 1

            new_observation["prev_action"] = new_observation["last_action"] # key name conversion

            # Update dashboard
            dashboard.step(new_observation, policy_outputs, policy_outputs["action"].item())

            # Check if rollout is done
            if new_observation["done"].item():
                logging.info("Reached done signal.")
                self._wrap_up_rollout(new_observation, dashboard, None)
                break

            # concat observations
            for k in observation.keys():
                if k in ["tty_chars", "tty_colors"]:
                    observation[k] = observation[k][:, :, 0, ...]
                observation[k] = torch.cat([observation[k], new_observation[k]], dim=0)[-self.config.rollout.max_seqlen:, ...]

        else: 
            logging.info("Cutting episode short ...")
            # Episode might not have finished
            self._wrap_up_rollout(new_observation, dashboard, None)

        env.close()

    def _wrap_up_rollout(self, observation, dashboard: RolloutDashboard, ttyrec_save_folder: str):
        """
        Do any final logging/saving/etc. that needs to happen
        when the game ends.
        """
        metrics = dashboard.get_metrics()
        metrics["episode_return"] = observation["episode_return"].item()

        logging.info(
            "Episode ended after %d steps. Return: %.1f",
            observation["episode_step"].item(),
            observation["episode_return"].item(),
        )
        logging.info(f'{metrics}')

        self.metrics_q.put(metrics)

        if self.done_q:
            self.done_q.put('done!')

        model_name = self.config.rollout.model_load_name.split('.')[0]
        returns_file = os.path.join(self.config.setup.save_dir, f'eval_returns_{model_name}.txt')
        lens_file = os.path.join(self.config.setup.save_dir, f'eval_lens_{model_name}.txt')
        dlvls_file = os.path.join(self.config.setup.save_dir, f'eval_dlvls_{model_name}.txt')

        os.makedirs(self.config.setup.save_dir, exist_ok=True)
        with open(returns_file, 'a') as f:
            f.write(str(metrics["episode_return"]) + '\n')

        with open(lens_file, 'a') as f:
            f.write(str(metrics["steps"]) + '\n')

        with open(dlvls_file, 'a') as f:
            f.write(str(metrics["max_dlvl"]) + '\n')

    def rollout_gpu(self):
        """
        Rollout trained model ~flags.num_rollouts number of times on GPU.
        """
        self.agent = self._agent_setup()

        ddp_util = DDPUtil()

        seeds = self._get_seeds()

        self.metrics_q = mp.Manager().Queue()
        self.done_q = mp.Manager().Queue()
        mp.spawn(
            self._rollout_chunk_gpu,
            args=(self.config.rollout.num_gpus, ddp_util, seeds),
            nprocs=self.config.rollout.num_gpus,
            join=True
        )

        return self._post_process()

    def _rollout_chunk_gpu(self, rank: int, world_size: int, ddp_util: DDPUtil, seeds: List[int]):
        """
        TODO
        """
        ddp_util.setup(rank, world_size)
        self.agent.to(rank)
        # self.agent_copy.to(rank)
        self.agent.move_to_ddp(rank, world_size)
        # self.agent_copy.move_to_ddp(rank, world_size)
        seeds = seeds[rank * len(seeds)//world_size:(rank + 1) * len(seeds)//world_size]
 
        for idx, seed in enumerate(seeds):
            self._single_rollout(seed, idx, rank)

    def _post_process(self):
        """
        Compute and save final metrics.
        """
        returns = []
        episode_lens = []
        dlvls = []
        while not self.metrics_q.empty():
            metrics = self.metrics_q.get()

            # returns
            returns.append(metrics["episode_return"])

            # episode lens
            episode_lens.append(metrics["steps"])

            # dungeon levels
            dlvls.append(metrics["max_dlvl"])

        return returns, episode_lens, dlvls

@hydra.main(version_base=None, config_path="../../../conf", config_name="nethack_config")
def main(cfg: DictConfig) -> None:
    """
    Print the config as YAML, then run the GPU rollouts it describes.
    """
    config_yaml = OmegaConf.to_yaml(cfg)
    print(config_yaml)

    Rollout(cfg).rollout_gpu()


if __name__ == "__main__":
    main()