from concurrent.futures import ThreadPoolExecutor
import time
import os
from collections import defaultdict
import copy

import hydra
from omegaconf import DictConfig, OmegaConf
import wandb
import torch
from nle import nethack

from il_scale.nethack.agent import Agent
from il_scale.nethack.data.tty_data import TTYData
from il_scale.nethack.data.parquet_data import ParquetData
from il_scale.nethack.utils.model import load_checkpoint

# Maximum number of training batches scanned for errors.
# NOTE(review): the original loop broke after one batch but also carried dead
# code that stopped after five; the effective limit of one is kept — confirm.
_MAX_BATCHES = 1


def _action_name(action) -> str:
    """Return a stable ``ClassName.MEMBER`` label for an NLE action.

    ``str()`` of an ``IntEnum`` member returns only the integer on Python
    3.11+, which silently defeated the ``CompassDirection`` prefix filter.
    Values without a ``name`` attribute fall back to ``str()``.
    """
    name = getattr(action, "name", None)
    if name is None:
        return str(action)
    return f"{type(action).__name__}.{name}"


def _is_navigation(action) -> bool:
    """Return True if *action* belongs to a ``CompassDirection*`` enum."""
    return _action_name(action).startswith("CompassDirection")


def _write_screen(f, screen) -> None:
    """Write a 2-D grid of character codes to *f*, one text line per row."""
    # tolist() turns a tensor into nested Python ints in one call instead of
    # calling chr() on a 0-d tensor per character.
    rows = screen.tolist() if hasattr(screen, "tolist") else screen
    for row in rows:
        f.write("".join(map(chr, row)))
        f.write("\n")


def _collect_errors(agent, dataloader, max_batches=_MAX_BATCHES):
    """Run *agent* over up to *max_batches* batches and gather mispredictions.

    Returns:
        ``(err_examples, err_10_examples)``: top-1 errors with surrounding
        screens, and steps whose label is not among the top-10 predictions.
    """
    from itertools import islice

    err_examples = []
    err_10_examples = []
    # islice stops without pulling an extra batch from the loader.
    for batch in islice(dataloader, max_batches):
        print(f'GAME ID: {batch["gameids"][0][0].item()}')

        with torch.no_grad():
            agent_outputs = agent.predict(batch)

        # policy_logits is read as (T, B, ...) here; flatten to (T*B, actions).
        policy_logits = agent_outputs['policy_logits']
        T, B = policy_logits.shape[:2]
        logits = policy_logits.reshape(T * B, -1)
        # The model runs on CUDA; the batch labels may not, so align devices.
        labels = batch['labels'].reshape(T * B).to(logits.device)

        preds = torch.argmax(logits, dim=1)
        _, top_k = torch.topk(logits, k=10, dim=1)

        acc = (preds == labels).sum().item() / labels.numel()
        print('acc', acc)

        # True where the label is one of the 10 highest-scoring actions.
        in_top_10 = (top_k == labels.view(-1, 1)).any(dim=-1).view(T, B)
        top_k = top_k.view(T, B, 10)
        preds = preds.view(T, B)
        labels = labels.view(T, B)

        for b in range(B):
            for t in range(T):
                true_act = nethack.ACTIONS[labels[t, b].item()]
                pred_act = nethack.ACTIONS[preds[t, b].item()]

                # Record top-1 errors only when the full context window exists:
                # t-1 must be a real step (index -1 would wrap to the end of
                # the sequence) and t+2 must be in range.
                if true_act != pred_act and 0 < t < T - 2:
                    err_examples.append({
                        'prev_state': batch['tty_chars'][t - 1][b][0],
                        'state': batch['tty_chars'][t][b][0],
                        # NOTE(review): "next" state is taken at t+2, not t+1 — confirm.
                        'next_state': batch['tty_chars'][t + 2][b][0],
                        'prev_label': nethack.ACTIONS[labels[t - 1, b].item()],
                        'label': true_act,
                        'pred': pred_act,
                        'blstats': batch['blstats'][t][b],
                    })

                if not in_top_10[t, b]:
                    err_10_examples.append({
                        'state': batch['tty_chars'][t][b][0],
                        'label': true_act,
                        'pred': [nethack.ACTIONS[a] for a in top_k[t, b].tolist()],
                    })

    return err_examples, err_10_examples


def _write_errors(path, err_examples) -> None:
    """Write top-1 errors to *path*, skipping direction-for-direction mistakes."""
    with open(path, 'w', encoding='utf-8') as f:
        for err in err_examples:
            # Skip navigational errors: one compass direction mistaken for another.
            if _is_navigation(err['label']) and _is_navigation(err['pred']):
                continue

            f.write(f"Prev action: {_action_name(err['prev_label'])}\n")
            f.write(f"True action: {_action_name(err['label'])}\n")
            f.write(f"Predicted action: {_action_name(err['pred'])}\n")
            f.write("PREV STATE\n")
            _write_screen(f, err['prev_state'])
            f.write("STATE\n")
            _write_screen(f, err['state'])
            f.write(f"blstats: {err['blstats']}\n")
            f.write("NEXT STATE\n")
            _write_screen(f, err['next_state'])
            f.write('\n')


def _write_top10_errors(path, err_10_examples) -> None:
    """Write steps whose label fell outside the top-10 predictions to *path*."""
    with open(path, 'w', encoding='utf-8') as f:
        for err in err_10_examples:
            # Action values are presumably ASCII key codes, so chr() shows the
            # key pressed — confirm against nle's action enums.
            f.write(f"True action: {chr(err['label'])}\n")
            for a in err['pred']:
                f.write(f"Predicted action: {chr(a)}\n")
            _write_screen(f, err['state'])
            f.write('\n')


@hydra.main(version_base=None, config_path="../../../conf", config_name="nethack_config")
def main(cfg: DictConfig) -> None:
    """Load a trained agent, run it on training batches, and dump its errors.

    Top-1 mispredictions (with neighbouring screens) go to ``errors.txt``;
    steps whose label is outside the top-10 predictions go to ``errors_10.txt``.
    """
    print(cfg.data)
    if cfg.data.dataset_type == 'glyph':
        data = ParquetData(cfg.data)
        print('Using parquet loader ...')
    else:
        data = TTYData(cfg.data)
        print('Using ttyrec loader ...')

    # Load the model weights from the configured checkpoint.
    agent = Agent(cfg, None)
    agent.construct_model(cfg)
    checkpoint = load_checkpoint(cfg.setup.model_load_name, cfg.setup.wandb_id, savedir=cfg.setup.wandb_load_dir)
    agent.load(checkpoint["model_state_dict"])
    agent.to('cuda')

    # The thread pool is handed to the data loader, so batches are consumed
    # while the pool is still open.
    with ThreadPoolExecutor(max_workers=50) as tp:
        dataloader = data.get_train_dataloader(tp)
        err_examples, err_10_examples = _collect_errors(agent, dataloader)

    _write_errors('errors.txt', err_examples)
    _write_top10_errors('errors_10.txt', err_10_examples)

if __name__ == "__main__":
    # main() is wrapped by @hydra.main, which builds cfg from the config
    # files, so it is called here with no arguments.
    main()