from concurrent.futures import ThreadPoolExecutor
import time
import os
from collections import defaultdict

import torch
import hydra
from omegaconf import DictConfig, OmegaConf
import numpy as np
import torch.nn as nn

from il_scale.nethack.agent import Agent
from il_scale.nethack.data.tty_data import TTYData
from il_scale.nethack.data.parquet_data import ParquetData
from il_scale.nethack.utils.model import load_checkpoint
from il_scale.nethack.logger import Logger

@hydra.main(version_base=None, config_path="../../../conf", config_name="nethack_config")
def main(cfg: DictConfig) -> None:
    """Evaluate a saved checkpoint on training batches and dump per-action recall.

    Builds the data source selected by ``cfg.data.dataset_type``, loads the
    checkpoint identified by ``cfg.setup.model_load_name`` /
    ``cfg.setup.wandb_id`` into a freshly constructed ``Agent``, runs it
    without gradients over at most 100 batches of the train dataloader, and
    hands each batch's loss, logits and labels to ``Logger.update_metrics``.
    Afterwards one ``Action <act> recall: <value>`` line per entry of
    ``logger.per_act_recall`` is written to the relative path
    ``action_errors_<wandb_id>_<model_load_name>.txt``.

    Args:
        cfg: Hydra config. ``cfg.data`` and ``cfg.setup`` are read here; the
            whole config is also passed on to ``Agent`` and ``Logger``.
    """
    # Evaluation stops after this many batches.
    max_batches = 100

    print(cfg.data)
    if cfg.data.dataset_type == 'glyph':
        data = ParquetData(cfg.data)
        print('Using parquet loader ...')
    else:
        data = TTYData(cfg.data)
        print('Using ttyrec loader ...')

    # load model
    agent = Agent(cfg, None)
    agent.construct_model(cfg)
    checkpoint = load_checkpoint(
        cfg.setup.model_load_name,
        cfg.setup.wandb_id,
        savedir=cfg.setup.wandb_load_dir,
    )
    agent.load(checkpoint["model_state_dict"])
    # NOTE: the device is hard-coded, so this script needs a CUDA device.
    agent.to('cuda')

    # logger
    logger = Logger(cfg)
    logger.start()

    criterion = nn.CrossEntropyLoss()

    with ThreadPoolExecutor(max_workers=50) as tp:
        dataloader = data.get_train_dataloader(tp)

        for i, batch in enumerate(dataloader, 1):
            with torch.no_grad():
                agent_outputs = agent.predict(batch)

            # Flatten the two leading dims into one sample axis. They are
            # unpacked as (T, B) -- presumably (time, batch); confirm against
            # Agent.predict. reshape() is used instead of view() so this works
            # whether or not the tensors are contiguous.
            policy_logits = agent_outputs['policy_logits']
            T, B = policy_logits.shape[:2]
            logits = policy_logits.reshape(B * T, -1)
            labels = batch['labels'].reshape(B * T)
            batch_samples = labels.shape[0]

            loss = criterion(logits, labels)

            logger.update_metrics(
                B, T, loss, logits, labels, batch_samples, batch, i,
                compute_per_act_recall=True,
            )
            logger.sample_step(batch_samples)

            if i >= max_batches:
                break

        # compute recall per action
        out_path = f'action_errors_{cfg.setup.wandb_id}_{cfg.setup.model_load_name}.txt'
        with open(out_path, 'w', encoding='utf-8') as f:
            for act in logger.per_act_recall:
                counts = logger.per_act_recall[act]
                true_pos = counts['tp']
                support = counts['fn'] + true_pos
                # Recall is undefined when there are no tp and no fn for this
                # entry; write nan rather than dividing by zero.
                recall = true_pos / support if support else float('nan')
                f.write(f"Action {act} recall: {recall:.3f}\n")



# Script entry point. main() is called with no arguments here; the
# @hydra.main decorator on it is what supplies the `cfg` argument.
if __name__ == "__main__":
    main()