import numpy as np
import torch
from torch import distributed as distrib
from torch import nn as nn

import hydra
from omegaconf import DictConfig, OmegaConf
import logging

from typing import DefaultDict, Optional
from torch.optim.lr_scheduler import LambdaLR

from habitat import Config
from habitat.config.default import Config as CN
from habitat_vc.config import get_config
from habitat_vc.il.objectnav.custom_baseline_registry import custom_baseline_registry
import habitat_vc.utils as utils
from tqdm import tqdm

from dataset import Dataset

class Trainer:
    """Evaluation harness for an imitation-learning policy.

    Builds the policy named by ``config.IL.POLICY.name`` and replays one
    recorded episode through it twice -- one step at a time and as a single
    batched call -- so the two sets of predicted actions can be compared.
    """

    # Episode replayed by ``evaluate`` when the caller does not name one.
    DEFAULT_EPISODE_ID = "A19WXS1CLVLEEX:3A7Y0R2P2QROKUD0IWVY59N8MI0JX3"
    # Debug dumps (resolved against the current working directory) that
    # ``_replay_reference_batch`` feeds back through the policy.
    REFERENCE_LOGITS_PATH = "logits.pth"
    REFERENCE_OBS_PATH = "obs.pth"

    def __init__(self, config: Optional[Config] = None, logger=None) -> None:
        """Join the NCCL process group, pick a CUDA device and load the dataset.

        Args:
            config: Experiment configuration. Required despite the ``None``
                default; ``config.DATASET.DATA_ROOT`` is read immediately.
            logger: Logger used for progress messages. Falls back to this
                module's logger when omitted.

        Raises:
            ValueError: If ``config`` is ``None``.
        """
        if config is None:
            raise ValueError("Trainer requires a config (got None).")

        # Local import keeps this block self-contained.
        import os

        distrib.init_process_group(backend="nccl")
        # The CUDA device index must be the rank *within this node*. Launchers
        # such as torchrun export it as LOCAL_RANK; the global rank is only a
        # valid device index on a single node, so it is used as a fallback.
        self.local_rank = int(os.environ.get("LOCAL_RANK", distrib.get_rank()))
        torch.cuda.set_device(self.local_rank)
        self.device = torch.device("cuda", self.local_rank)
        self.dataset = Dataset(config.DATASET.DATA_ROOT)
        self.logger = logger if logger is not None else logging.getLogger(__name__)
        self.config = config

    def setup_policy(self, config: Config):
        """Initialize the policy from the configuration.

        The policy class is looked up in ``custom_baseline_registry`` and built
        from ``config`` together with the observation/action spaces exposed by
        the dataset loaded in ``__init__``. When
        ``self.config.IL.BehaviorCloning.pretrained`` is truthy, the weights at
        ``pretrained_weights`` are loaded non-strictly. The result is stored in
        ``self.policy`` and moved to ``self.device``.

        Args:
            config: Configuration handed to the policy's ``from_config``.
        """
        policy_cls = custom_baseline_registry.get_policy(config.IL.POLICY.name)
        self.policy = policy_cls.from_config(
            config,
            self.dataset.observation_space,
            self.dataset.action_space,
        )

        # NOTE: the pretrained-weight settings are read from ``self.config``,
        # not from the ``config`` argument.
        bc_config = self.config.IL.BehaviorCloning
        if bc_config.pretrained:
            pretrained_state = torch.load(
                bc_config.pretrained_weights, map_location="cpu"
            )
            self.logger.info("Loading pretrained state")

            # NOTE(review): ``str.replace`` drops every "model." occurrence in
            # a key, not only a leading prefix -- confirm that is intended.
            renamed_state = {
                key.replace("model.", ""): value
                for key, value in pretrained_state["state_dict"].items()
            }
            # ``strict=False`` tolerates key mismatches; the returned object
            # reports them and is logged as-is.
            incompatible_keys = self.policy.load_state_dict(
                renamed_state, strict=False
            )
            self.logger.info(
                "Loading checkpoint missing keys: %s", incompatible_keys
            )

        self.policy.to(self.device)
        # self.policy = torch.nn.parallel.DistributedDataParallel(
        #     self.policy,
        #     device_ids=[self.local_rank],
        #     output_device=self.local_rank,
        #     find_unused_parameters=True,
        # )

    def _initial_recurrent_state(self):
        """Return zeroed ``(rnn_hidden_states, prev_actions)`` for one episode."""
        encoder_cfg = self.config.MODEL.STATE_ENCODER
        # The leading dimension is doubled for LSTM (presumably to hold both
        # the hidden and the cell state -- confirm against the policy).
        layer_multiplier = 2 if encoder_cfg.rnn_type == "LSTM" else 1
        rnn_hidden_states = torch.zeros(
            encoder_cfg.num_recurrent_layers * layer_multiplier,
            1,
            encoder_cfg.hidden_size,
            device=self.device,
        )
        prev_actions = torch.zeros((1,), dtype=torch.long, device=self.device)
        return rnn_hidden_states, prev_actions

    def _replay_reference_batch(self):
        """Feed the dumped batch in ``obs.pth`` back through the policy.

        Returns:
            ``None`` when either dump file is missing; otherwise a dict with
            the logits stored in ``logits.pth`` (``reference_logits``), the
            logits the current policy produces for the stored inputs
            (``replayed_logits``) and their argmax (``replayed_actions``).
        """
        try:
            reference_logits, _, _ = torch.load(
                self.REFERENCE_LOGITS_PATH, map_location=self.device
            )
            (
                obs_batch,
                recurrent_hidden_states_batch,
                _actions_batch,
                prev_actions_batch,
                masks_batch,
            ) = torch.load(self.REFERENCE_OBS_PATH, map_location=self.device)
        except FileNotFoundError as err:
            # The dumps are optional debugging artifacts; their absence should
            # not prevent the episode comparison from running.
            self.logger.warning("Skipping reference replay: %s", err)
            return None

        replayed_logits, _, _ = self.policy(
            obs_batch,
            recurrent_hidden_states_batch,
            prev_actions_batch,
            masks_batch,
        )
        return {
            "reference_logits": reference_logits,
            "replayed_logits": replayed_logits,
            "replayed_actions": replayed_logits.argmax(dim=-1).squeeze(-1),
        }

    @torch.no_grad()
    def evaluate(self, episode_id: Optional[str] = None):
        """Compare step-by-step and batched policy predictions on one episode.

        Switches the config to ``RUN_TYPE = "eval"``, rebuilds the policy in
        eval mode, optionally replays the reference dumps, then runs the
        selected episode through the policy one step at a time (feeding the
        demonstrated action back as the previous action) and once as a whole.

        Args:
            episode_id: Episode name without the ``.npz`` suffix. Defaults to
                ``DEFAULT_EPISODE_ID``.

        Returns:
            A dict with ``label`` (demonstrated actions), ``step_pred_actions``,
            ``batch_pred_actions`` and ``reference_replay`` (see
            ``_replay_reference_batch``).

        Raises:
            ValueError: If the episode is not listed in the eval dataset.
        """
        if episode_id is None:
            episode_id = self.DEFAULT_EPISODE_ID

        self.config.defrost()
        self.config.RUN_TYPE = "eval"
        self.config.freeze()
        self.setup_policy(self.config)
        self.policy.eval()

        eval_dataset = Dataset(self.config.DATASET.EVAL_DATA_ROOT)
        eval_dataset.distribute()

        reference_replay = self._replay_reference_batch()

        dataset_episode_idx = eval_dataset.episodes.index(episode_id + ".npz")
        track = eval_dataset[dataset_episode_idx]
        label = track.pop("demonstration")
        masks = track.pop("continued_mask")

        # Pass 1: one timestep per policy call, carrying the recurrent state.
        rnn_hidden_states, prev_actions = self._initial_recurrent_state()
        step_pred_actions = torch.zeros_like(label, device=self.device)
        for i in range(label.shape[0]):
            logits, rnn_hidden_states, _ = self.policy(
                track[i:i + 1].to(self.device),
                rnn_hidden_states,
                prev_actions,
                masks[i:i + 1].to(self.device),
            )
            step_pred_actions[i] = logits.argmax(dim=-1).squeeze(-1)
            # Teacher forcing: the next step sees the demonstrated action. It
            # is moved to the policy's device like every other input.
            # NOTE(review): ``label[i]`` may not have the (1,) shape of the
            # initial ``prev_actions`` -- confirm the policy accepts both.
            prev_actions = label[i].to(self.device)

        # Pass 2: the whole episode in a single policy call.
        rnn_hidden_states, prev_actions = self._initial_recurrent_state()
        logits, rnn_hidden_states, _ = self.policy(
            track.to(self.device),
            rnn_hidden_states,
            prev_actions,
            masks.to(self.device),
        )
        batch_pred_actions = logits.argmax(dim=-1).squeeze(-1)

        # Report the comparison instead of stopping in an interactive debugger,
        # which would stall non-interactive and multi-process runs.
        if step_pred_actions.shape == batch_pred_actions.shape:
            agreement = (
                (step_pred_actions == batch_pred_actions).float().mean().item()
            )
            self.logger.info(
                "Step-by-step vs. batched action agreement: %.4f", agreement
            )
        else:
            self.logger.warning(
                "Cannot compare predictions: step-by-step shape %s vs. batched shape %s",
                tuple(step_pred_actions.shape),
                tuple(batch_pred_actions.shape),
            )

        return {
            "label": label,
            "step_pred_actions": step_pred_actions,
            "batch_pred_actions": batch_pred_actions,
            "reference_replay": reference_replay,
        }



@hydra.main(
    version_base=None,
    config_path="configs",
    config_name="config_objectnav_deit-t-freeze.yaml",
)
def test(cfg: DictConfig):
    """Hydra entry point: build the merged config and run ``Trainer.evaluate``.

    The Hydra config is resolved into a plain container, wrapped in a ``CN``
    node and merged on top of the config returned by ``get_config()``. The
    merged config is what the trainer receives.

    Args:
        cfg: Configuration composed by Hydra from ``configs/``.
    """
    overrides = CN(OmegaConf.to_container(cfg, resolve=True))

    config = get_config()
    config.merge_from_other_cfg(overrides)

    logger = logging.getLogger("distillation")
    # Pass the merged config: handing over the raw Hydra overrides would make
    # the merge above dead code and drop every value that only get_config()
    # provides.
    trainer = Trainer(config=config, logger=logger)
    trainer.evaluate()

if __name__ == "__main__":
    # Script entry point: run the Hydra-decorated ``test`` function.
    test()