from concurrent.futures import ThreadPoolExecutor
import os
import random
import time
from typing import Optional

import torch
import hydra
from omegaconf import DictConfig, OmegaConf
import torch.nn as nn
import numpy as np
import scipy.stats as stats

from il_scale.nethack.agent import Agent
from il_scale.nethack.data.parquet_data import ParquetData
from il_scale.nethack.utils.model import load_checkpoint
from il_scale.nethack.utils.model import count_params

def mask_labels_from_gameids(gameids: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Return a copy of ``labels`` with entries whose game id is 0 set to -100.

    -100 is the default ``ignore_index`` of ``nn.CrossEntropyLoss``, so the
    masked positions contribute zero loss when scored with that criterion.

    Args:
        gameids: Tensor of game ids, broadcastable to the shape of ``labels``.
            A value of 0 marks a position to ignore (presumably padding /
            "no game" -- confirm against the data loader).
        labels: Tensor of target labels. It is not modified.

    Returns:
        A new tensor on the same device as ``labels``.
    """
    # The mask is moved to the labels' device so the fill works even when the
    # two tensors live on different devices.
    ignore_mask = gameids.eq(0).to(labels.device)
    masked = labels.clone()
    masked.masked_fill_(ignore_mask, -100)
    return masked

@torch.no_grad()
def evaluate(data, agent, overlapping: bool = False, eval_ctx_len: Optional[int] = None, use_amp: bool = False):
    """Compute the mean cross-entropy loss of ``agent`` over one pass of ``data``.

    Each batch is scored with ``nn.CrossEntropyLoss(reduction='none')``.
    Labels equal to -100 (positions masked by ``mask_labels_from_gameids`` or
    belonging to carried-over context) are excluded from the sample count and
    from the standard error.

    Args:
        data: Object exposing ``get_train_dataloader(executor, loop_forever=False)``
            that yields dict-like batches with at least ``'labels'`` and
            ``'gameids'`` entries. Batches are presumably time-major
            (T, B, ...) -- confirm against the data loader.
        agent: Object whose ``predict(batch)`` returns ``(outputs, _)`` where
            ``outputs['policy_logits']`` holds the logits.
        overlapping: If True, every batch after the first is prepended with the
            previously evaluated context (concatenated on dim 0 and truncated
            to the last ``eval_ctx_len`` steps); only the newly added steps
            are scored.
        eval_ctx_len: Maximum context length kept when ``overlapping`` is True.
            Required (and must be >= 1) in that mode.
        use_amp: Enables CUDA float16 autocast around the forward pass and loss.

    Returns:
        ``(mean_loss, standard_error)``: ``mean_loss`` is a 0-dim tensor (sum
        of all losses divided by the number of scored samples) and
        ``standard_error`` is ``scipy.stats.sem`` of the per-sample losses.

    Raises:
        ValueError: If ``overlapping`` is True and ``eval_ctx_len`` is missing
            or smaller than 1.
        RuntimeError: If the dataloader yields no batches.
    """
    if overlapping and (eval_ctx_len is None or eval_ctx_len < 1):
        # Fail early: `-None` would only blow up on the second batch, and a
        # length of 0 would keep the whole (ever growing) context.
        raise ValueError(
            f"eval_ctx_len must be a positive int when overlapping=True, got {eval_ctx_len!r}"
        )

    criterion = nn.CrossEntropyLoss(reduction='none')

    batch_loss_sums = []
    sample_losses = []
    total_samples = 0
    context = None  # batch evaluated on the previous iteration

    with ThreadPoolExecutor(max_workers=30) as tp:
        dataloader = data.get_train_dataloader(tp, loop_forever=False)

        for i, batch in enumerate(dataloader, 1):
            T, B = batch['labels'].shape[:2]

            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
                if overlapping and i > 1:
                    # Prepend the previous context and keep only the most
                    # recent eval_ctx_len steps.
                    eval_batch = {
                        k: torch.cat([context[k], batch[k]], dim=0)[-eval_ctx_len:, ...]
                        for k in batch.keys()
                    }
                    # Everything except the last T steps was already scored.
                    eval_batch['labels'][:-T] = -100
                else:
                    eval_batch = batch

                agent_outputs, _ = agent.predict(eval_batch)

                logits = agent_outputs['policy_logits']
                labels = mask_labels_from_gameids(
                    eval_batch['gameids'], eval_batch['labels'].contiguous().long()
                )

                # Flatten the two leading dimensions so the criterion sees
                # (N, num_classes) logits and (N,) labels.
                num_rows = B * eval_batch['labels'].shape[0]
                logits = logits.view(num_rows, -1)
                labels = labels.view(num_rows)

                loss = criterion(logits, labels)

            scored = labels != -100
            sample_losses.append(loss[scored].cpu())
            # Ignored positions contribute 0 to the sum (ignore_index=-100).
            batch_loss_sums.append(loss.sum().cpu())
            total_samples += scored.sum().item()

            context = eval_batch

    if not batch_loss_sums:
        raise RuntimeError("evaluate() received no batches from the dataloader")

    # return mean, SE
    mean_loss = torch.sum(torch.stack(batch_loss_sums)) / total_samples
    standard_error = stats.sem(torch.cat(sample_losses, dim=0).numpy())
    return mean_loss, standard_error

@hydra.main(version_base=None, config_path="../../../conf", config_name="nethack_config")
def main(cfg: DictConfig) -> None:
    """Evaluate a saved checkpoint and store its loss statistics on disk.

    Steps: seed the RNGs, adjust the data config for overlapping evaluation,
    load the model config saved under ``models/<wandb_id>/cfg.omega`` (with
    hard-coded overrides for two resumed runs), build the data source and the
    agent, run ``evaluate`` and ``torch.save`` a dict with the loss, its
    standard error, the parameter count and the checkpoint's sample count.
    """
    # Same seed for every RNG so the dataset ordering is reproducible.
    for seed_rng in (random.seed, torch.manual_seed, np.random.seed):
        seed_rng(0)

    if cfg.setup.overlapping:
        # Evaluate with the full unroll length as context, but advance in
        # chunks of 10% of it.
        full_ctx_len = cfg.data.unroll_length
        cfg.data.eval_ctx_len = full_ctx_len
        cfg.data.unroll_length = int(0.1 * full_ctx_len)

    def read_model_cfg(run_id):
        # Model configs are stored per run as models/<run_id>/cfg.omega.
        return OmegaConf.load(os.path.join('models', run_id, 'cfg.omega'))

    # 0. Load the config the model was trained with
    model_cfg = read_model_cfg(cfg.setup.wandb_id)

    # HARD CODED EXCEPTIONS DUE TO MODEL RESUMING:
    # (run id, hidden dim, number of transformer layers); heads = hdim // 64.
    resumed_runs = (
        ('2rx53f2b', 512, 5),
        ('ysjiriyk', 384, 4),
    )
    for run_id, hdim, num_layers in resumed_runs:
        if cfg.setup.wandb_id == run_id:
            model_cfg = read_model_cfg('1dj2vzju')
            model_cfg.network.hdim = hdim
            model_cfg.network.tf_num_layers = num_layers
            model_cfg.network.tf_num_heads = hdim // 64
            break

    # add cfg modifications consistent with nethack_config.yaml
    for flag in ('use_message', 'use_crop', 'use_observation'):
        if flag not in model_cfg.network:
            model_cfg.network[flag] = True

    # 1. Setup data
    dataset = ParquetData(cfg.data)

    # 2. Load model
    model_agent = Agent(model_cfg, None)
    model_agent.construct_model(model_cfg)
    ckpt = load_checkpoint(
        cfg.setup.model_load_name,
        cfg.setup.wandb_id,
        savedir=cfg.setup.wandb_load_dir,
    )
    model_agent.load(ckpt["model_state_dict"])
    model_agent.to('cuda')

    # 3. Evaluate
    print(f'AMP State: {model_cfg.setup.use_amp}')
    eval_loss, eval_se = evaluate(
        dataset,
        model_agent,
        cfg.setup.overlapping,
        cfg.data.eval_ctx_len,
        use_amp=model_cfg.setup.use_amp,
    )

    # 4. Log results
    print(f'Saving eval result of {eval_loss:.3f} ({eval_se:.4f}) to {cfg.setup.save_dir}')
    os.makedirs(cfg.setup.save_dir, exist_ok=True)
    network = model_agent.model
    num_params = (
        count_params(network.core)
        + count_params(network.policy_head)
        + count_params(network.modality_mixer)
    )
    summary = {
        'eval_loss': eval_loss.item(),
        'eval_se': eval_se.item(),
        'params': num_params,
        'samples': ckpt['num_samples'],
    }
    out_path = os.path.join(
        cfg.setup.save_dir,
        f'eval_loss_overlap_{cfg.setup.overlapping}_{cfg.setup.model_load_name}',
    )
    torch.save(summary, out_path)
        

# Run the hydra-decorated entry point only when executed as a script
# (not when this module is imported).
if __name__ == "__main__":
    main()