#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the CC-BY-NC license found in the
# LICENSE file in the root directory of this source tree.

#!/usr/bin/env python3
import os
import random
from datetime import datetime
from itertools import count

import hydra
import numpy as np
from omegaconf import DictConfig, OmegaConf
import torch
import torch.nn.functional as F
from torch import distributed as distrib
from scipy.stats import spearmanr
import torch.cuda.amp as amp
from kneed import KneeLocator

from habitat import logger
from habitat.config import Config
from habitat.config.default import Config as CN
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_vc.config import get_config
from habitat_vc.visual_encoder import VisualEncoder
from habitat_baselines.rl.ddppo.ddp_utils import (
    EXIT,
    REQUEUE,
    rank0_only,
    add_signal_handlers,
    init_distrib_slurm,
    load_resume_state,
    requeue_job,
    save_resume_state,
)
from habitat_baselines.utils.common import (
    batch_obs,
    generate_video,
    linear_decay,
    get_checkpoint_id,
)
from torch import optim
from torch.optim.lr_scheduler import LambdaLR


def build_model(config: Config) -> tuple:
    r"""Build the visual encoder used for pre-training.

    Args:
        config: Full experiment config. ``config.MODEL.RGB_ENCODER`` supplies the
            encoder options and ``config.model`` the backbone config handed to
            ``VisualEncoder``.

    Returns:
        A ``(backbone, visual_transform)`` pair taken from the constructed
        ``VisualEncoder``.

    Side effects:
        Sets ``config.MODEL.TORCH_GPU_ID`` to ``config.TORCH_GPU_ID`` (the sub-config
        is frozen again afterwards).
    """
    model_config = config.MODEL
    rgb_config = model_config.RGB_ENCODER
    backbone_config = config.model

    model_config.defrost()
    model_config.TORCH_GPU_ID = config.TORCH_GPU_ID
    model_config.freeze()

    visual_encoder = VisualEncoder(
        input_resize_size=rgb_config.input_resize_size,
        input_crop_size=rgb_config.input_crop_size,
        backbone_config=backbone_config,
        global_pool=rgb_config.global_pool,
        use_cls=rgb_config.use_cls,
        # Augmentations are always enabled when building the pre-training encoder.
        use_augmentations=True,
        rgb_config=rgb_config,
    )

    return visual_encoder.backbone, visual_encoder.visual_transform

@hydra.main(config_path="configs", config_name="config_imagenav")
def main(cfg: DictConfig) -> None:
    r"""Hydra entry point that forwards the composed config to ``run_exp``.

    Args:
        cfg: Config composed by Hydra from ``configs/config_imagenav`` together
            with any command-line overrides.
    """
    return run_exp(cfg)


def execute_exp(config: Config) -> None:
    r"""Seed all RNGs, prepare the experiment and run training.

    Args:
        config: Habitat config for the run; ``TASK_CONFIG.SEED`` is overwritten
            with a freshly generated seed.
    """
    # Seed recipe from detectron2: mix the pid, the wall-clock seconds+microseconds
    # and two bytes of OS entropy.
    pid_part = os.getpid()
    clock_part = int(datetime.now().strftime("%S%f"))
    entropy_part = int.from_bytes(os.urandom(2), "big")

    config.defrost()
    config.TASK_CONFIG.SEED = pid_part + clock_part + entropy_part
    config.freeze()

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(config.TASK_CONFIG.SEED)

    if config.FORCE_TORCH_SINGLE_THREADED and torch.cuda.is_available():
        torch.set_num_threads(1)

    setup_experiment(config)
    Trainer(config).train()


def run_exp(cfg: DictConfig) -> None:
    r"""Turn the Hydra config into a Habitat ``Config`` and execute the experiment.

    The resolved Hydra values are merged on top of the defaults returned by
    ``get_config()``.

    Args:
        cfg: DictConfig object containing the configs for the experiment.
    """
    overrides = CN(OmegaConf.to_container(cfg, resolve=True))
    config = get_config()
    config.merge_from_other_cfg(overrides)
    execute_exp(config)


def setup_experiment(config: Config) -> None:
    r"""Create output directories, absolutise dataset paths and set env vars.

    Args:
        config: Experiment config. ``TASK_CONFIG.DATASET.SCENES_DIR`` and
            ``TASK_CONFIG.DATASET.DATA_PATH`` are rewritten in place by
            ``hydra.utils.to_absolute_path`` (resolved against the directory the
            app was launched from, not Hydra's run directory).
    """
    if rank0_only():
        os.makedirs(config.CHECKPOINT_FOLDER, exist_ok=True)
        os.makedirs(config.LOG_DIR, exist_ok=True)

    config.defrost()
    dataset_cfg = config.TASK_CONFIG.DATASET
    dataset_cfg.SCENES_DIR = hydra.utils.to_absolute_path(dataset_cfg.SCENES_DIR)
    dataset_cfg.DATA_PATH = hydra.utils.to_absolute_path(dataset_cfg.DATA_PATH)
    config.freeze()

    # Prepend the NVIDIA OpenGL directory. The variable may be unset, so use .get();
    # when it is empty, avoid a trailing ":" (an empty entry means "current
    # directory" to the dynamic loader).
    # NOTE(review): the loader of an already-running process reads
    # LD_LIBRARY_PATH only at startup, so this mainly affects child processes.
    nvidia_gl_dir = "/usr/lib/x86_64-linux-gnu/nvidia-opengl"
    current = os.environ.get("LD_LIBRARY_PATH", "")
    os.environ["LD_LIBRARY_PATH"] = (
        f"{nvidia_gl_dir}:{current}" if current else nvidia_gl_dir
    )
    # Reduce glog and Magnum log output.
    os.environ["GLOG_minloglevel"] = "3"
    os.environ["MAGNUM_LOG"] = "quiet"


class EMAState:
    r"""Exponential moving average of a scalar, averaged across DDP ranks.

    Each ``update`` first averages the incoming value over all processes (when
    ``torch.distributed`` is initialised) and then blends it into the running
    state: ``state = decay * state + (1 - decay) * value``.

    Args:
        default: Initial value of the running state.
        decay: Weight given to the previous state on every update.
    """

    def __init__(self, default: float = 1.0, decay: float = 0.999) -> None:
        self.decay = decay
        self.state = torch.tensor(default).float()

    def to(self, device: torch.device) -> None:
        """Move the running state to ``device``."""
        self.state = self.state.to(device)

    def update(self, state) -> None:
        """Blend ``state`` (a number or tensor) into the running average.

        When distributed, this performs an ``all_reduce``; every rank must therefore
        call ``update`` the same number of times, or the job will hang.
        """
        # Work on a private copy with the state's dtype/device: ``all_reduce`` and
        # ``/=`` are in place and must neither mutate a tensor owned by the caller
        # nor fail on integer dtypes, and the state's dtype must not drift.
        if isinstance(state, torch.Tensor):
            value = state.detach().to(
                device=self.state.device, dtype=self.state.dtype
            ).clone()
        else:
            value = torch.tensor(
                state, dtype=self.state.dtype, device=self.state.device
            )

        if distrib.is_initialized():
            # Average the new observation over all processes.
            distrib.all_reduce(value, op=distrib.ReduceOp.SUM)
            value /= distrib.get_world_size()

        self.state = self.decay * self.state + (1 - self.decay) * value

    def get(self) -> torch.Tensor:
        """Return the current running average (a 0-d tensor)."""
        return self.state


class PretrainDataset:
    r"""Streams RGB frames, episode by episode, from a directory of ``.npz`` files.

    Each ``<NPZ_DIR>/<episode>.npz`` must contain an ``"rgb"`` array whose first
    axis indexes frames (TODO confirm per-frame layout/dtype of the stored arrays).

    Args:
        config: Must provide ``NPZ_DIR`` (episode files) and ``VIDEO_DIR`` (stored,
            but not used by this class).
    """

    def __init__(self, config: "Config") -> None:
        self.npz_dir = config.NPZ_DIR
        self.video_dir = config.VIDEO_DIR
        # Keep only ``.npz`` files: ``__iter__`` re-appends the extension, so any other
        # file in the directory would fail to open. Sorting makes the order
        # independent of the filesystem's listing order.
        self.episodes = sorted(
            os.path.splitext(name)[0]
            for name in os.listdir(self.npz_dir)
            if name.endswith(".npz")
        )

    def __iter__(self):
        """Yield one dict per frame with keys ``rgb``, ``pre_mask`` and ``episode``.

        ``pre_mask`` is False for the first frame of each episode and True for every
        later frame, i.e. True means the previous yielded frame belongs to the same
        episode.
        """
        for episode in self.episodes:
            npz_file = os.path.join(self.npz_dir, episode + ".npz")
            # NOTE: allow_pickle=True lets np.load unpickle object arrays, which can
            # execute arbitrary code; only point NPZ_DIR at trusted data.
            with np.load(npz_file, allow_pickle=True) as data:
                rgb_images = data["rgb"]
            for frame_idx, frame in enumerate(rgb_images):
                yield {
                    "rgb": frame,
                    "pre_mask": frame_idx > 0,
                    "episode": episode,
                }

    def shuffle(self):
        """Shuffle the episode order in place with the ``random`` module's RNG."""
        random.shuffle(self.episodes)

    def batch_iter(self, batch_size: int):
        """Yield collated batches of consecutive frames in a single pass over the data.

        Batches may span episode boundaries; ``pre_mask`` tells the consumer which
        neighbouring frames belong to the same episode. The final batch may hold
        fewer than ``batch_size`` frames. Call ``shuffle`` between passes to change
        the episode order.

        Yields:
            ``{"rgb": np.ndarray, "pre_mask": np.ndarray[bool]}`` stacked along axis 0.

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        batch = []
        for sample in self:
            batch.append(sample)
            if len(batch) == batch_size:
                yield self._collate(batch)
                batch = []
        if batch:
            # The tail is collated like every other batch so consumers always get a dict.
            yield self._collate(batch)

    @staticmethod
    def _collate(samples):
        """Stack per-frame samples into a ``{"rgb", "pre_mask"}`` batch dict."""
        return {
            "rgb": np.stack([s["rgb"] for s in samples], axis=0),
            "pre_mask": np.stack([s["pre_mask"] for s in samples], axis=0),
        }
        
        
class Model(torch.nn.Module):
    r"""Reference-frame prediction objective on top of a ViT-style backbone.

    The wrapped ``model`` must provide ``patch_embed``, ``pos_embed``, ``cls_token``,
    ``mask_ratio``, ``random_masking``, ``blocks``, ``reference_last_frame_layer_idx``
    and ``reference_net`` (which itself exposes a ``certainty`` attribute).

    Frame ``t`` of a batch is paired with frame ``t-1`` wherever ``pre_mask[t]`` is
    True. The tokens of frame ``t`` entering block ``reference_last_frame_layer_idx``
    are the queries; frame ``t-1``'s tokens at that point are the keys and its
    final tokens the values. ``reference_net``'s output is regressed onto frame
    ``t``'s detached final tokens.
    """

    # Lower bound for ``max_error`` so the certainty target never divides by zero.
    _MIN_ERROR_SCALE = 1e-12

    def __init__(self, model) -> None:
        super().__init__()
        self.model = model

    @staticmethod
    def _mean_or_zero(values: torch.Tensor) -> torch.Tensor:
        """``values.mean()``, or a graph-connected zero when ``values`` is empty."""
        # The mean of an empty tensor is NaN, which would poison every gradient.
        return values.mean() if values.numel() > 0 else values.sum()

    def forward(self, x, pre_mask, max_error):
        """Compute the reference-prediction loss for one batch.

        Args:
            x: Batch of preprocessed frames, consecutive in time along dim 0.
            pre_mask: Bool tensor, one entry per frame; only frames with a True
                entry are paired with the frame just before them.
            max_error: Error scale (tensor or number) mapping per-sample errors to
                certainty targets in ``[0, 1]``.

        Returns:
            ``(loss, stats)``. ``stats`` holds ``"reg_loss"``, ``"certainty_loss"``
            (only for ``certainty == 'mlp'``), and ``"sample"`` = detached
            ``(per-sample errors, certainties)``.

        Raises:
            ValueError: if ``reference_last_frame_layer_idx`` matches no block.
        """
        backbone = self.model
        batch_size = x.shape[0]
        x = backbone.patch_embed(x)

        # add pos embed w/o cls token
        x = x + backbone.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        if backbone.mask_ratio is not None:
            x, _, _ = backbone.random_masking(x, mask_ratio=backbone.mask_ratio)

        # append cls token
        cls_token = backbone.cls_token + backbone.pos_embed[:, :1, :]
        x = torch.cat((cls_token.expand(batch_size, -1, -1), x), dim=1)

        refs = None
        for idx, layer in enumerate(backbone.blocks):
            if idx == backbone.reference_last_frame_layer_idx:
                refs = x  # tokens entering the reference layer
            x = layer(x)
        if refs is None:
            raise ValueError(
                f"reference_last_frame_layer_idx={backbone.reference_last_frame_layer_idx} "
                f"does not index any of the {len(backbone.blocks)} blocks"
            )

        # Pair frame t (queries/targets) with frame t-1 (keys/values); a pair is kept
        # when pre_mask of the later frame is True.
        pair_mask = pre_mask[1:]
        key = refs[:-1][pair_mask]
        value = x[:-1][pair_mask]
        querys = refs[1:][pair_mask]
        targets = x[1:].detach()[pair_mask]
        output, certainty = backbone.reference_net(
            querys, key, value, backbone.pos_embed
        )
        diff = F.mse_loss(output, targets, reduction="none").sum(dim=-1)

        # Per-sample error (no gradient) and predicted certainty, both flattened to 1-D.
        sample_errors = diff.detach().flatten()
        certainty = certainty.flatten()

        loss = 0
        stats = dict()
        if backbone.reference_net.certainty == "mlp":
            # Learnable certainty (requires gradient): the target is 1 at zero error
            # and falls linearly to 0 at ``max_error``.
            scale = torch.as_tensor(
                max_error, dtype=sample_errors.dtype, device=sample_errors.device
            ).clamp_min(self._MIN_ERROR_SCALE)
            target_certainty = 1 - torch.clip(sample_errors / scale, 0, 1)
            certainty_loss = self._mean_or_zero(
                F.binary_cross_entropy(certainty, target_certainty, reduction="none")
            )
            loss += certainty_loss
            stats["certainty_loss"] = certainty_loss.item()

        regression_loss = self._mean_or_zero(diff)
        loss += regression_loss
        stats["reg_loss"] = regression_loss.item()
        stats["sample"] = (sample_errors.detach(), certainty.detach())

        return loss, stats


class Trainer:
    r"""Distributed pre-training driver for ``Model``.

    ``__init__`` joins the process group, seeds RNGs per rank, builds (and optionally
    restores) the model and wraps it in ``DistributedDataParallel``. ``train`` runs
    the optimisation loop for ``config.NUM_UPDATES`` updates.
    """

    def __init__(self, config: Config) -> None:
        self.config = config
        self.local_rank, _tcp_store = init_distrib_slurm(self.config.IL.distrib_backend)
        add_signal_handlers()

        # Pick the device first; only touch CUDA when it is actually available.
        if torch.cuda.is_available():
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")

        self.world_rank = distrib.get_rank()
        self.world_size = distrib.get_world_size()

        self.config.defrost()
        self.config.TORCH_GPU_ID = self.local_rank
        # Multiply by the number of simulators to make sure they also get unique seeds
        self.config.TASK_CONFIG.SEED += self.world_rank * self.config.NUM_PROCESSES
        self.config.freeze()

        random.seed(self.config.TASK_CONFIG.SEED)
        np.random.seed(self.config.TASK_CONFIG.SEED)
        torch.manual_seed(self.config.TASK_CONFIG.SEED)

        model, self.pre_transform = build_model(self.config)
        self.model = Model(model)

        resume_path = getattr(self.config, "RESUME", None)
        if resume_path:
            self._load_checkpoint(resume_path)

        # DDP model
        self.model.to(self.device)
        if distrib.is_initialized():
            on_cuda = self.device.type == "cuda"
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                # DDP requires device_ids/output_device to be None for CPU modules.
                device_ids=[self.local_rank] if on_cuda else None,
                output_device=self.local_rank if on_cuda else None,
                find_unused_parameters=True,
            )

        # Only global rank 0 logs to the file (local rank 0 exists once per node).
        if self.world_rank == 0:
            logger.add_filehandler(self.config.LOG_FILE)
            logger.info(f"config: {self.config}")

    @staticmethod
    def _strip_ddp_prefix(state_dict):
        """Remove the ``module.`` prefix DDP adds to state-dict keys, if present.

        Raises:
            ValueError: if only some keys carry the prefix.
        """
        prefix = "module."
        if not state_dict or not next(iter(state_dict)).startswith(prefix):
            return state_dict
        if not all(k.startswith(prefix) for k in state_dict):
            logger.error(
                "Checkpoint is not compatible with the current model. "
                "Please check the checkpoint and model architecture."
            )
            raise ValueError("Checkpoint is not compatible with the current model.")
        return {k[len(prefix):]: v for k, v in state_dict.items()}

    def _load_checkpoint(self, path) -> None:
        """Restore model weights (non-strict) from the checkpoint at ``path``."""
        ckpt = load_resume_state(path)
        if ckpt is None:
            # load_resume_state returns None when there is nothing to load.
            logger.warning(f"Resume checkpoint {path} not found; training from scratch.")
            return
        state_dict = self._strip_ddp_prefix(ckpt["model"])
        try:
            self.model.load_state_dict(state_dict, strict=False)
        except RuntimeError:
            # strict=False still raises on shape mismatches; retry without any
            # ``certainty_mlp`` entries.
            state_dict = {
                k: v for k, v in state_dict.items() if "certainty_mlp" not in k
            }
            self.model.load_state_dict(state_dict, strict=False)
        logger.info(f"Loaded checkpoint {path}")

    @staticmethod
    def _spearman(a, b):
        """Spearman correlation of ``a`` and ``b``, or NaN with fewer than 2 points."""
        if len(a) < 2:
            return float("nan")
        return spearmanr(a, b).correlation

    @staticmethod
    def _find_elbow(sorted_errors):
        """Index of the elbow of an ascending error curve, or ``None`` if none is found."""
        # Too few points for a meaningful polynomial fit / knee search.
        if len(sorted_errors) < 3:
            return None
        knee = KneeLocator(
            np.arange(len(sorted_errors)),
            sorted_errors,
            curve="convex",
            direction="increasing",
            interp_method="polynomial",
        )
        # KneeLocator reports ``None`` when it cannot locate a knee.
        if knee.elbow is None:
            return None
        return int(knee.elbow)

    def _collect_stats(self, stats, max_error) -> dict:
        """Convert the model's stats into loggable values and update ``max_error``.

        Must run on every rank for every update: ``max_error.update`` all-reduces
        when distributed. Rank 0 additionally computes the logged statistics.
        """
        parsed_stats = dict()
        for key, value in stats.items():
            if "loss" in key:
                parsed_stats[key] = value
                continue

            sample_errors, certainty = value
            if self.world_rank == 0:
                # Spearman correlation between error and uncertainty (1 - certainty).
                parsed_stats["spearman_corr"] = self._spearman(
                    sample_errors.cpu().numpy(), 1 - certainty.cpu().numpy()
                )
                parsed_stats["sample_errors"] = sample_errors.mean().item()
                parsed_stats["certainty"] = (
                    certainty.mean().item(),
                    certainty.std().item(),
                )

            # Sort errors ascending and drop the smallest 10% before knee detection.
            sorted_errors, order = torch.sort(sample_errors)
            sorted_errors = sorted_errors.cpu().numpy()
            start = int(0.1 * len(sorted_errors))
            sorted_errors = sorted_errors[start:]
            order = order[start:]

            # The error at the elbow becomes the new max-error observation.
            elbow = self._find_elbow(sorted_errors)
            if elbow is None:
                # No usable knee: feed back the current estimate (so every rank still
                # joins the all-reduce) and treat all samples as matched.
                threshold = max_error.get().item()
                elbow = len(sorted_errors)
            else:
                threshold = sorted_errors[elbow]
            max_error.update(threshold)

            matched_errors = sorted_errors[:elbow]
            matched_certainty = certainty[order[:elbow]]
            if self.world_rank == 0:
                parsed_stats["matched_certainty"] = (
                    matched_certainty.mean().item(),
                    matched_certainty.std().item(),
                )
                parsed_stats["matched_sample_errors"] = np.mean(matched_errors)
                parsed_stats["matched_threshold"] = max_error.get().item()
                parsed_stats["matched_spearman_corr"] = self._spearman(
                    matched_errors, 1 - matched_certainty.cpu().numpy()
                )
        return parsed_stats

    def _save_checkpoint(self, epoch: int, update: int) -> None:
        """Write ``{"epoch", "model"}`` to ``CHECKPOINT_FOLDER``."""
        # Unwrap DDP (if wrapped) so saved keys carry no "module." prefix.
        model = getattr(self.model, "module", self.model)
        checkpoint = {
            "epoch": epoch,
            "model": model.state_dict(),
        }
        checkpoint_path = os.path.join(
            self.config.CHECKPOINT_FOLDER, f"checkpoint_{epoch}_{update}.pth"
        )
        torch.save(checkpoint, checkpoint_path)
        logger.info(f"Saved checkpoint to {checkpoint_path}")

    def train(self) -> None:
        r"""Run the pre-training loop until ``config.NUM_UPDATES`` updates are done.

        Every rank performs the same number of updates and the same collectives (DDP
        gradient all-reduce, ``EMAState.update``, the final barrier); only rank 0
        logs and writes checkpoints.

        Raises:
            RuntimeError: if an epoch yields no batches (e.g. empty ``NPZ_DIR``).
        """
        self.model.train()
        max_error = EMAState(1.0, 0.99)
        max_error.to(self.device)

        bc_cfg = self.config.IL.BehaviorCloning
        num_updates = self.config.NUM_UPDATES
        batch_size = self.config.BATCH_SIZE
        dataset = PretrainDataset(self.config)
        optimizer = optim.AdamW(
            [{"params": self.model.parameters(), "lr": bc_cfg.encoder_lr}],
            lr=bc_cfg.encoder_lr,
            eps=bc_cfg.eps,
            weight_decay=bc_cfg.wd,
        )
        lr_scheduler = LambdaLR(
            optimizer=optimizer,
            lr_lambda=lambda x: linear_decay(x, num_updates),  # type: ignore
        )
        # NOTE: mixed precision (torch.cuda.amp) is currently not used in this loop.

        update = 0
        for epoch in count(1, 1):
            dataset.shuffle()
            updates_before_epoch = update
            for batch in dataset.batch_iter(batch_size):
                update += 1
                optimizer.zero_grad()
                rgb = torch.from_numpy(batch["rgb"]).to(self.device)
                # Channels-last -> channels-first, divided by 255
                # (assumes 0-255 pixel values — TODO confirm).
                rgb = rgb.permute(0, 3, 1, 2).float() / 255
                rgb = self.pre_transform(rgb)
                pre_mask = torch.from_numpy(batch["pre_mask"]).to(self.device)

                loss, stats = self.model(rgb, pre_mask, max_error.get())
                loss.backward()
                optimizer.step()
                lr_scheduler.step()

                # Contains a collective (max_error update): never gate it by rank.
                parsed_stats = self._collect_stats(stats, max_error)
                finished = update >= num_updates

                if self.world_rank == 0:
                    if update % self.config.LOG_INTERVAL == 0:
                        msg = (
                            f"Epoch: {epoch}, Step: {update}, Loss: {loss.item()}, "
                            f"LR: {optimizer.param_groups[0]['lr']}"
                        )
                        msg += "".join(
                            f", {key}: {value}" for key, value in parsed_stats.items()
                        )
                        logger.info(msg)
                        logger.info(f"Max error: {max_error.get().item()}")

                    if update % self.config.CHECKPOINT_INTERVAL == 0 or finished:
                        self._save_checkpoint(epoch, update)

                if finished:
                    if self.world_rank == 0:
                        logger.info("Reached max updates, exiting...")
                    # All ranks reach this point at the same update, so the barrier
                    # is matched instead of deadlocking the non-zero ranks.
                    if distrib.is_initialized():
                        distrib.barrier()
                    return

            if update == updates_before_epoch:
                raise RuntimeError(
                    f"PretrainDataset yielded no batches from {self.config.NPZ_DIR}"
                )
                    


if __name__ == "__main__":
    # The @hydra.main decorator on ``main`` parses the command line and supplies cfg.
    main()
