from collections import OrderedDict
import os
import time
import numpy as np
import torch
from torch import distributed as distrib
from torch import nn as nn

import hydra
from omegaconf import DictConfig, OmegaConf
import logging

from typing import DefaultDict, Dict, Optional
from torch import Tensor
from torch.optim.lr_scheduler import LinearLR
from torch import optim

from habitat import Config
from habitat.config.default import Config as CN
from habitat_vc.config import get_config
from habitat_vc.il.objectnav.custom_baseline_registry import custom_baseline_registry
import communication_utils
from tqdm import tqdm
from timing_tools import SilentTqdm
from stats_tracker import StatsTracker
import threading
from torch.utils.tensorboard import SummaryWriter

from dataset import Dataset

# Module-level logger, registered under the fixed name "Trainer".
logger = logging.getLogger("Trainer")

class Trainer:
    """Distributed evaluation driver that dumps gradient-times-activation maps.

    For every episode of the evaluation dataset the policy is stepped through
    the demonstration one timestep at a time. At each step the logit of the
    demonstrated action is back-propagated and, for every tensor the visual
    backbone exposes in ``forward_dict['xs']``, ``x * x.grad`` is summed over
    the last dimension. The per-step results are stacked and written to
    ``<EVAL_DATA_ROOT>/grad_cam/<episode>.npz`` under the key ``attentions``.
    """

    def __init__(self, config: Optional[Config] = None) -> None:
        """Join the NCCL process group and bind this process to one GPU.

        Args:
            config: Experiment configuration. ``run`` and ``setup_policy``
                read it, so it must be provided before either is called.
        """
        distrib.init_process_group(backend="nccl")
        self.local_rank = self._resolve_local_rank()
        torch.cuda.set_device(self.local_rank)
        communication_utils.init_communication()
        self.device = torch.device("cuda", self.local_rank)
        self.config = config

    @staticmethod
    def _resolve_local_rank() -> int:
        """Return the CUDA device index this process should use.

        The global rank is only a valid device index on a single node where
        every GPU is visible. Prefer the launcher-provided ``LOCAL_RANK`` and
        otherwise wrap the global rank by the number of visible devices.
        """
        env_rank = os.environ.get("LOCAL_RANK")
        if env_rank is not None:
            return int(env_rank)
        rank = distrib.get_rank()
        device_count = torch.cuda.device_count()
        # With no visible device, fall through with the raw rank so that
        # ``set_device`` reports the problem.
        return rank % device_count if device_count > 0 else rank

    @staticmethod
    def _token_grad_cams(xs) -> Tensor:
        """Return ``sum(x * x.grad, dim=-1)`` for each ``x``, concatenated on dim 0.

        Args:
            xs: Iterable of tensors whose ``.grad`` was populated by a
                preceding ``backward()`` call.

        Raises:
            RuntimeError: If any tensor has no gradient.
        """
        grad_cams = []
        for x in xs:
            grad = x.grad
            if grad is None:
                raise RuntimeError(
                    "Gradients are None, check if token_grad is set to True"
                )
            # Detach the activation: the product is only stored, never
            # differentiated, and keeping ``x`` attached would hold this
            # step's autograd graph alive for as long as the result is kept.
            grad_cams.append((x.detach() * grad).sum(dim=-1))
        return torch.cat(grad_cams, dim=0)

    def setup_policy(self, config: Config, observation_space, action_space):
        """Build ``self.policy`` and optionally load pretrained weights.

        Args:
            config: Configuration used to select and construct the policy
                (``config.IL.POLICY.name``).
            observation_space: Passed through to ``policy.from_config``.
            action_space: Passed through to ``policy.from_config``.

        Note:
            The pretrained-weights settings are read from ``self.config``,
            not from the ``config`` argument.
        """
        # Initialize the policy using the custom baseline registry
        policy_cls = custom_baseline_registry.get_policy(config.IL.POLICY.name)
        self.policy = policy_cls.from_config(
            config, observation_space, action_space
        )

        # Load pretrained state
        distill_cfg = self.config.IL.Distillation
        if distill_cfg.pretrained:
            # NOTE(security): torch.load unpickles the file; only point
            # ``pretrained_weights`` at checkpoints from a trusted source.
            pretrained_state = torch.load(
                distill_cfg.pretrained_weights, map_location="cpu"
            )
            logger.info("Loading pretrained state")

            # NOTE: ``str.replace`` removes every occurrence of "model." in a
            # key, not only a leading prefix.
            renamed_state = {
                k.replace("model.", ""): v
                for k, v in pretrained_state["state_dict"].items()
            }
            # With strict=False the result reports both missing and
            # unexpected keys instead of raising.
            load_result = self.policy.load_state_dict(renamed_state, strict=False)
            logger.info("Loading checkpoint missing keys: %s", load_result)

        self.policy.to(self.device)

    def run(self):
        """Compute and save gradient-times-activation maps for every episode.

        Each process handles its share of the dataset (after
        ``eval_dataset.distribute()``), writes one ``.npz`` per episode and
        sends a progress message per finished episode. The main process
        additionally runs a background thread that drains those messages into
        a tqdm progress bar.
        """
        eval_root = self.config.DATASET.EVAL_DATA_ROOT
        eval_dataset = Dataset(eval_root)
        output_dir = os.path.join(eval_root, "grad_cam")
        os.makedirs(output_dir, exist_ok=True)

        self.config.defrost()
        self.config.RUN_TYPE = "eval"
        self.config.freeze()
        self.setup_policy(
            self.config, eval_dataset.observation_space, eval_dataset.action_space
        )
        self.policy.eval()
        # The gradients read back from forward_dict['xs'] below presumably
        # depend on this flag (see the error raised when they are missing).
        self.policy.net.visual_encoder.backbone.token_grad = True
        # NOTE(review): presumably needed because cuDNN RNN backward is only
        # supported in training mode — confirm.
        self.policy.net.state_encoder.train()

        eval_dataset.distribute()

        length = len(eval_dataset)
        lengths = communication_utils.gather_messages(length)
        if communication_utils._MAIN_PROCESS:
            total_length = sum(lengths)
            progress_bar = tqdm(total=total_length, desc="Evaluating")

            def threaded_progress_bar():
                while progress_bar.n < total_length:
                    for message in communication_utils.collect_messages():
                        progress_bar.update(message)
                    time.sleep(0.1)

            # Daemon thread: if the evaluation loop below raises, the
            # interpreter can still exit instead of waiting on this poller.
            progress_thread = threading.Thread(
                target=threaded_progress_bar, daemon=True
            )
            progress_thread.start()

        # The recurrent state shape does not depend on the episode.
        state_cfg = self.config.MODEL.STATE_ENCODER
        num_rnn_layer_multiplier = 2 if state_cfg.rnn_type == "LSTM" else 1
        hidden_shape = (
            state_cfg.num_recurrent_layers * num_rnn_layer_multiplier,
            1,
            state_cfg.hidden_size,
        )

        for episode_idx in range(length):
            track = eval_dataset[episode_idx]
            label = track.pop('demonstration')
            masks = track.pop('continued_mask')
            prev_actions = torch.zeros((1,), dtype=torch.long, device=self.device)
            rnn_hidden_states = torch.zeros(*hidden_shape, device=self.device)

            attentions = []
            for i in range(label.shape[0]):
                observation = track[i:i + 1].to(self.device)
                continued_mask = masks[i:i + 1].to(self.device)

                # Parameter gradients are never used here; clear them so they
                # do not accumulate over every step of every episode.
                self.policy.zero_grad(set_to_none=True)
                logits, rnn_hidden_states, _ = self.policy(
                    observation,
                    rnn_hidden_states.detach(),
                    prev_actions,
                    continued_mask,
                )
                # Attribute with respect to the demonstrated action, and feed
                # that same action back in as the previous action.
                # NOTE(review): label[i] is used as-is; confirm its shape and
                # device match what the policy expects for prev_actions.
                logits[0, label[i]].backward()
                prev_actions = label[i]

                xs = self.policy.net.visual_encoder.backbone.forward_dict['xs']
                attentions.append(self._token_grad_cams(xs))

            if attentions:
                stacked = torch.stack(attentions, dim=0).detach().cpu().numpy()
                np.savez_compressed(
                    os.path.join(output_dir, f"{track.metadata['episode']}.npz"),
                    attentions=stacked,
                )
            else:
                # torch.stack cannot handle an empty list; nothing to save.
                logger.warning(
                    "Episode %s has no steps; skipping grad_cam output",
                    track.metadata['episode'],
                )
            communication_utils.send_message(1)

        if communication_utils._MAIN_PROCESS:
            # The poller returns once every episode has been reported.
            progress_thread.join()
            progress_bar.close()

        communication_utils.broadcast_message("finished")

        return

@hydra.main(
    version_base=None,
    config_path="configs",
    config_name="config_objectnav_deit-t-freeze.yaml",
)
def run(cfg: DictConfig):
    """Hydra entry point: merge the Hydra config into the defaults and evaluate.

    The resolved Hydra config is converted to a plain container, wrapped in a
    Habitat config node and merged on top of ``get_config()``. A ``Trainer``
    is then built from the merged config and its ``run`` method is called.
    """
    overrides = CN(OmegaConf.to_container(cfg, resolve=True))

    merged_config = get_config()
    merged_config.merge_from_other_cfg(overrides)

    Trainer(config=merged_config).run()

if __name__ == "__main__":
    # Script entry point; the @hydra.main decorator supplies `cfg` to run().
    run()