from collections import OrderedDict
import os
import time
import numpy as np
import torch
from torch import distributed as distrib
from torch import nn as nn

import hydra
from omegaconf import DictConfig, OmegaConf
import logging

from typing import DefaultDict, Dict, Optional
from torch import Tensor
from torch.optim.lr_scheduler import LinearLR
from torch import optim

from habitat import Config
from habitat.config.default import Config as CN
from habitat_vc.config import get_config
from habitat_vc.il.objectnav.custom_baseline_registry import custom_baseline_registry
import communication_utils
from tqdm import tqdm
from timing_tools import SilentTqdm
from stats_tracker import StatsTracker
import threading
from torch.utils.tensorboard import SummaryWriter

from dataset import Dataset

# Module-level logger used by Trainer for progress, checkpoint-loading and
# evaluation messages.
logger = logging.getLogger("Trainer")

class Trainer:
    """Distributed imitation-learning trainer and evaluator for a recurrent policy.

    Constructing a ``Trainer`` joins the NCCL process group, selects a CUDA
    device for the process and initialises ``communication_utils``.
    ``train`` runs the optimisation loop, ``evaluate`` measures teacher-forced
    action accuracy, and ``save_checkpoint`` writes the unwrapped policy
    weights under ``checkpoints/``.
    """

    def __init__(self, config: Optional[Config] = None) -> None:
        """Join the process group and bind this process to a GPU.

        Args:
            config: Experiment configuration. ``train`` and ``evaluate`` read
                it unconditionally, so it must be set before calling them.
        """
        distrib.init_process_group(backend="nccl")
        # NOTE(review): the global rank doubles as the CUDA device index, which
        # presumes a single node with one process per GPU -- confirm before
        # launching on several nodes.
        torch.cuda.set_device(distrib.get_rank())
        communication_utils.init_communication()
        self.local_rank = distrib.get_rank()
        self.device = torch.device("cuda", self.local_rank)

        # if distrib.get_rank() == 0:
        #     os.makedirs(config.LOG_DIR, exist_ok=True)
        #     logger.add_filehandler(config.LOG_FILE)
        self.config = config

    def setup_policy(self, config: Config, observation_space, action_space):
        """Build the policy, optionally load pretrained weights, wrap it in DDP.

        The result is stored in ``self.policy`` as a
        ``DistributedDataParallel`` module living on ``self.device``.

        Args:
            config: Configuration used to look up the policy class
                (``config.IL.POLICY.name``) and passed to ``from_config``.
            observation_space: Forwarded to the policy's ``from_config``.
            action_space: Forwarded to the policy's ``from_config``.

        Note:
            The distillation options (pretrained weights, unused-parameter
            detection) are read from ``self.config``, not from ``config``.
        """
        distill_cfg = self.config.IL.Distillation

        # Initialize the policy using the custom baseline registry
        policy_cls = custom_baseline_registry.get_policy(config.IL.POLICY.name)
        self.policy = policy_cls.from_config(
            config, observation_space, action_space
        )

        # Load pretrained state
        if distill_cfg.pretrained:
            pretrained_state = torch.load(
                distill_cfg.pretrained_weights, map_location="cpu"
            )
            logger.info("Loading pretrained state")

            # ``save_checkpoint`` prefixes every key with "model."; this strips
            # every occurrence of that substring again.  ``strict=False`` lets
            # partially matching checkpoints load; the mismatch report
            # returned by ``load_state_dict`` is logged.
            load_result = self.policy.load_state_dict(
                {
                    k.replace("model.", ""): v
                    for k, v in pretrained_state["state_dict"].items()
                },
                strict=False,
            )
            logger.info("Loading checkpoint missing keys: {}".format(load_result))

        self.policy.to(self.device)
        self.policy = torch.nn.parallel.DistributedDataParallel(
            self.policy,
            device_ids=[self.local_rank],
            output_device=self.local_rank,
            find_unused_parameters=distill_cfg.find_unused_parameters,
        )

    def _init_rnn_hidden_states(self, batch_size: int) -> Tensor:
        """Return zero-initialised recurrent state for ``batch_size`` sequences."""
        encoder_cfg = self.config.MODEL.STATE_ENCODER
        # The layer dimension is doubled for "LSTM" (presumably to hold both
        # the hidden and the cell state -- confirm against the policy).
        layer_multiplier = 2 if encoder_cfg.rnn_type == "LSTM" else 1
        return torch.zeros(
            encoder_cfg.num_recurrent_layers * layer_multiplier,
            batch_size,
            encoder_cfg.hidden_size,
            device=self.device,
        )

    def _build_optimizer(self):
        """Create AdamW with separate learning rates per parameter family.

        Trainable parameters are split by name: ``visual_encoder.backbone``
        parameters use ``encoder_lr``, the remaining ``visual_encoder``
        parameters use ``translate_lr`` and everything else uses ``policy_lr``.
        """
        distill_cfg = self.config.IL.Distillation
        visual_encoder_params = []
        translate_params = []
        policy_params = []
        for name, param in self.policy.named_parameters():
            if not param.requires_grad:
                continue
            # Order matters: the backbone check must precede the broader
            # "visual_encoder" match.
            if "visual_encoder.backbone" in name:
                visual_encoder_params.append(param)
            elif "visual_encoder" in name:
                translate_params.append(param)
            else:
                policy_params.append(param)

        return optim.AdamW(
            [
                {"params": visual_encoder_params, "lr": distill_cfg.encoder_lr},
                {"params": translate_params, "lr": distill_cfg.translate_lr},
                {"params": policy_params, "lr": distill_cfg.policy_lr},
            ],
            lr=distill_cfg.encoder_lr,
            eps=distill_cfg.eps,
            weight_decay=distill_cfg.wd,
        )

    def train(self):
        """Run the imitation-learning optimisation loop.

        Iterates over ``dataset.data_loader`` (re-starting it as often as
        needed) until ``NUM_UPDATES`` optimiser steps were taken.  Each step
        computes an inflection-weighted cross-entropy loss between the policy
        logits and the demonstrated actions, adds any auxiliary ``*loss*``
        terms exposed through the policy's ``forward_dict``, clips gradients
        when ``max_grad_norm > 0`` and advances the linear LR schedule.
        The main process tracks metrics, logs every ``LOG_INTERVAL`` updates
        and saves a checkpoint every ``CHECKPOINT_INTERVAL`` updates.

        Raises:
            RuntimeError: If the data loader yields no batch at all, since the
                loop could then never reach ``NUM_UPDATES``.
        """
        dataset = Dataset(self.config.DATASET.DATA_ROOT)
        self.config.defrost()
        self.config.RUN_TYPE = "train"
        self.config.freeze()
        self.setup_policy(self.config, dataset.observation_space, dataset.action_space)
        self.policy.train()

        distill_cfg = self.config.IL.Distillation
        num_updates = self.config.NUM_UPDATES
        num_envs = self.config.NUM_ENVS
        num_steps = distill_cfg.num_steps
        max_grad_norm = distill_cfg.max_grad_norm
        inflection_coef = distill_cfg.inflection_coef

        # Initialize the optimizer and learning rate scheduler
        optimizer = self._build_optimizer()
        scheduler = LinearLR(
            optimizer,
            start_factor=1.0,
            end_factor=0.01,
            total_iters=num_updates,
        )

        # Per-element losses are required so that the inflection weights below
        # actually re-weight individual steps; with the default "mean"
        # reduction the weights would cancel out of the final ratio.
        loss_fn = nn.CrossEntropyLoss(reduction="none")

        # init rnn hidden states
        rnn_hidden_states = self._init_rnn_hidden_states(num_envs)
        update = 0
        count_checkpoints = 0

        if communication_utils._MAIN_PROCESS:
            # prepare tqdm without bar
            progress_bar = SilentTqdm(total=num_updates)

            # Initialize the status tracker
            status_tracker = StatsTracker(
                ema_alpha=0.05,
                log_interval=self.config.LOG_INTERVAL * distrib.get_world_size(),
                tensorboard="tensorboard",
            )

        # Training loop
        while update < num_updates:
            logger.info(f"Start an epoch: {update}/{num_updates}")
            updates_at_epoch_start = update
            for batch in dataset.data_loader(shuffle=True, num_steps=num_steps, num_envs=num_envs):
                if update >= num_updates:
                    break

                # Prepare the batch
                observation = batch.to(self.device)
                masks = observation.pop("continued_mask")
                demonstration = observation.pop("demonstration")
                prev_actions = observation.pop("prev_actions")
                # Weight is ``inflection_coef`` where the demonstrated action
                # differs from the previous one and the mask is non-zero,
                # otherwise 1.
                inflection_weight = (
                    (demonstration != prev_actions).float()
                    * masks.float()
                    * (inflection_coef - 1)
                    + 1.0
                )

                # Forward pass
                logits, rnn_hidden_states, _ = self.policy(
                    observation, rnn_hidden_states, prev_actions, masks
                )

                # Compute the weighted mean of the per-step losses
                per_step_loss = loss_fn(
                    logits.view(-1, logits.size(-1)),
                    demonstration.view(-1),
                )
                flat_weight = inflection_weight.view(-1)
                action_loss = (per_step_loss * flat_weight).sum() / flat_weight.sum()

                metrics = OrderedDict({
                    "Accuracy": ((logits.argmax(dim=-1).flatten() == demonstration.flatten()).float().mean().item()),
                    "loss/action": action_loss.item(),
                })

                total_loss = action_loss
                if hasattr(self.policy.module, 'forward_dict') and 'merge_function' in self.policy.module.forward_dict:
                    forward_dict = self.policy.module.forward_dict
                    forward_dict = forward_dict['merge_function'](forward_dict=forward_dict)
                    for key, value in forward_dict.items():
                        metrics[key] = value.item() if isinstance(value, Tensor) else value
                        if "loss" in key:
                            # Out-of-place add so ``action_loss`` is not mutated.
                            total_loss = total_loss + value

                # Backward pass and optimization
                optimizer.zero_grad()
                total_loss.backward()
                if max_grad_norm > 0:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self.policy.parameters(), max_grad_norm
                    )
                    metrics["grad_norm"] = grad_norm.item()
                optimizer.step()
                scheduler.step()

                # Carry the recurrent state into the next batch without
                # keeping the autograd graph of this one alive.
                rnn_hidden_states = rnn_hidden_states.detach()
                update += 1

                gathered_metrics = communication_utils.gather_messages(metrics)
                if communication_utils._MAIN_PROCESS:
                    progress_bar.update(1)
                    average_metrics = {
                        key: sum(m[key] for m in gathered_metrics) / len(gathered_metrics)
                        for key in gathered_metrics[0]
                    }
                    status_tracker.update(average_metrics)

                if update % self.config.CHECKPOINT_INTERVAL == 0 and communication_utils._MAIN_PROCESS:
                    self.save_checkpoint(
                        f"ckpt.{count_checkpoints}.pth",
                        dict(step=update),
                    )
                    count_checkpoints += 1

                if distrib.get_rank() == 0 and update % self.config.LOG_INTERVAL == 0 and communication_utils._MAIN_PROCESS:
                    progress_msg = progress_bar.status()
                    metrics_msg = status_tracker.format(precision=4)
                    message = f"{progress_msg} | {metrics_msg}"
                    logger.info(message)

            if update == updates_at_epoch_start:
                # An empty loader would otherwise spin in this loop forever.
                raise RuntimeError(
                    "The training data loader yielded no batches; "
                    "cannot make progress towards NUM_UPDATES."
                )

    def save_checkpoint(
        self, file_name: str, extra_state: Optional[Dict] = None
    ) -> None:
        r"""Save the policy weights and config to ``checkpoints/<file_name>``.

        The weights are taken from the module wrapped by DDP, moved to the
        CPU, and stored with a ``model.`` prefix on every key.

        Args:
            file_name: file name for checkpoint
            extra_state: optional extra data stored under ``"extra_state"``

        Returns:
            None
        """
        save_state_dict = OrderedDict(
            ("model." + key, tensor.cpu())
            for key, tensor in self.policy.module.state_dict().items()
        )

        checkpoint = {
            "state_dict": save_state_dict,
            "config": self.config,
        }
        if extra_state is not None:
            checkpoint["extra_state"] = extra_state

        os.makedirs("checkpoints", exist_ok=True)
        torch.save(checkpoint, os.path.join("checkpoints", file_name))

    @torch.no_grad()
    def evaluate(self):
        """Measure teacher-forced action accuracy on the evaluation dataset.

        Every process evaluates the episodes it holds after
        ``eval_dataset.distribute()``, stepping the policy one timestep at a
        time.  The main process shows a progress bar fed by messages from all
        processes.  Correct-prediction and step counts are summed across
        processes with ``all_reduce`` so every rank obtains the same result.

        Returns:
            The fraction of steps whose arg-max action equals the
            demonstration, over all processes; ``0.0`` if no step was
            evaluated.
        """
        eval_dataset = Dataset(self.config.DATASET.EVAL_DATA_ROOT)
        self.config.defrost()
        self.config.RUN_TYPE = "eval"
        self.config.freeze()
        self.setup_policy(self.config, eval_dataset.observation_space, eval_dataset.action_space)
        self.policy.eval()

        eval_dataset.distribute()

        total_correct = 0.0
        total_count = 0
        length = len(eval_dataset)
        lengths = communication_utils.gather_messages(length)
        if communication_utils._MAIN_PROCESS:
            total_length = sum(lengths)
            progress_bar = tqdm(total=total_length, desc="Evaluating")

            def threaded_progress_bar():
                # Drain progress messages until every episode is accounted for.
                while not progress_bar.n >= total_length:
                    messages = communication_utils.collect_messages()
                    for message in messages:
                        progress_bar.update(message)
                    time.sleep(0.1)

            progress_thread = threading.Thread(target=threaded_progress_bar)
            progress_thread.start()

        for episode_idx in range(length):
            track = eval_dataset[episode_idx]
            label = track.pop('demonstration')
            masks = track.pop('continued_mask')
            prev_actions = torch.zeros((1,), dtype=torch.long, device=self.device)
            rnn_hidden_states = self._init_rnn_hidden_states(1)

            pred_actions = torch.zeros_like(label, device=self.device)

            for i in range(label.shape[0]):
                observation = track[i:i + 1].to(self.device)
                continued_mask = masks[i:i + 1].to(self.device)
                logits, rnn_hidden_states, _ = self.policy(
                    observation,
                    rnn_hidden_states,
                    prev_actions,
                    continued_mask,
                )
                action = logits.argmax(dim=-1).squeeze(-1)
                # Teacher forcing: the next step is conditioned on the
                # demonstrated action, not on the prediction.
                # NOTE(review): ``label[i]`` is passed on without a device
                # move -- confirm the policy accepts it as-is.
                prev_actions = label[i]

                pred_actions[i] = action

            # Calculate accuracy
            total_correct += (pred_actions.cpu() == label).sum().item()
            total_count += label.numel()
            communication_utils.send_message(1)

        if communication_utils._MAIN_PROCESS:
            while progress_bar.n < total_length:
                time.sleep(1.0)
            progress_bar.close()
            progress_thread.join()

        communication_utils.broadcast_message("finished")

        correct_tensor = torch.tensor(
            total_correct, device=self.device, dtype=torch.float32
        )
        count_tensor = torch.tensor(
            total_count, device=self.device, dtype=torch.float32
        )
        # Sum across all processes.  ``all_reduce`` (rather than ``reduce`` to
        # rank 0) leaves the global totals on every rank, so the value
        # returned below is meaningful everywhere.
        distrib.all_reduce(correct_tensor, op=distrib.ReduceOp.SUM)
        distrib.all_reduce(count_tensor, op=distrib.ReduceOp.SUM)
        global_correct = correct_tensor.item()
        global_count = count_tensor.item()
        accuracy = global_correct / global_count if global_count > 0 else 0.0

        # Log the accuracy only from the main process
        if distrib.get_rank() == 0:
            logger.info(f"Evaluation Accuracy: {accuracy:.4f}")

        return accuracy

@hydra.main(
    version_base=None,
    config_path="configs",
    config_name="config_objectnav_deit-t-freeze.yaml",
)
def test(cfg: DictConfig):
    """Hydra entry point that evaluates a policy.

    The Hydra config is resolved into a plain container, wrapped in a
    habitat config node and merged on top of the defaults returned by
    ``get_config`` before ``Trainer.evaluate`` is invoked.
    """
    overrides = CN(OmegaConf.to_container(cfg, resolve=True))

    config = get_config()
    config.merge_from_other_cfg(overrides)

    Trainer(config=config).evaluate()

@hydra.main(
    version_base=None,
    config_path="configs",
    config_name="config_objectnav_deit-t-freeze.yaml",
)
def train(cfg: DictConfig):
    """Hydra entry point that trains a policy.

    The Hydra config is resolved into a plain container, wrapped in a
    habitat config node and merged on top of the defaults returned by
    ``get_config`` before ``Trainer.train`` is invoked.
    """
    overrides = CN(OmegaConf.to_container(cfg, resolve=True))

    config = get_config()
    config.merge_from_other_cfg(overrides)

    Trainer(config=config).train()

if __name__ == "__main__":
    # Script entry point: runs training. Call ``test()`` here instead to run
    # the evaluation entry point.
    train()