#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the CC-BY-NC license found in the
# LICENSE file in the root directory of this source tree.

#!/usr/bin/env python3
import os
import random
from datetime import datetime
from itertools import count
import time

import cv2
import hydra
import imageio
import numpy as np
from omegaconf import DictConfig, OmegaConf
import torch
import torch.nn.functional as F
from torch import distributed as distrib
from scipy.stats import spearmanr

from habitat import logger
from habitat.config import Config
from habitat.config.default import Config as CN
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_vc.config import get_config
from habitat_vc.visual_encoder import VisualEncoder
from habitat_baselines.rl.ddppo.ddp_utils import (
    EXIT,
    REQUEUE,
    rank0_only,
    add_signal_handlers,
    init_distrib_slurm,
    load_resume_state,
    requeue_job,
    save_resume_state,
)
from habitat_baselines.utils.common import (
    batch_obs,
    generate_video,
    linear_decay,
    get_checkpoint_id,
)
from torch import optim
from torch.optim.lr_scheduler import LambdaLR
import triton
import triton.language as tl


def build_model(config: Config) -> tuple:
    r"""Builds the visual backbone and its input transform.

    A ``VisualEncoder`` is constructed from ``config.MODEL.RGB_ENCODER`` and
    ``config.model`` with augmentations enabled. The backbone's ``fc_hidden``
    attribute is deleted before the backbone is returned.

    Side effect: ``config.MODEL.TORCH_GPU_ID`` is overwritten with
    ``config.TORCH_GPU_ID`` (the node is defrosted and re-frozen to do so).

    Args:
        config: experiment config; reads ``MODEL``, ``MODEL.RGB_ENCODER``,
            ``model`` and ``TORCH_GPU_ID``.

    Returns:
        A ``(backbone, visual_transform)`` tuple taken from the constructed
        ``VisualEncoder``.
    """
    model_config = config.MODEL
    rgb_config = model_config.RGB_ENCODER
    backbone_config = config.model

    # Mirror the top-level GPU id into the model node.
    model_config.defrost()
    model_config.TORCH_GPU_ID = config.TORCH_GPU_ID
    model_config.freeze()

    visual_encoder = VisualEncoder(
        input_resize_size=rgb_config.input_resize_size,
        input_crop_size=rgb_config.input_crop_size,
        backbone_config=backbone_config,
        global_pool=rgb_config.global_pool,
        use_cls=rgb_config.use_cls,
        use_augmentations=True,
        rgb_config=rgb_config,
    )

    # Drop the hidden FC head from the backbone.
    # NOTE(review): presumably so that the checkpoint can be loaded with
    # strict=True without fc_hidden weights -- confirm.
    del visual_encoder.backbone.fc_hidden

    return visual_encoder.backbone, visual_encoder.visual_transform

@hydra.main(config_path="configs", config_name="config_imagenav")
def main(cfg: DictConfig) -> None:
    r"""Hydra entry point for habitat_vc.

    The decorator loads the config named ``config_imagenav`` from the
    ``configs`` directory and passes it in as ``cfg``; this function only
    forwards it to ``run_exp``.

    Args:
        cfg: DictConfig object containing the configs for the experiment.
    """
    run_exp(cfg)


def execute_exp(config: Config) -> None:
    r"""Seeds the RNGs, prepares the environment and runs the evaluation.

    Args:
        config: Habitat.config; ``TASK_CONFIG.SEED`` is overwritten with a
            freshly drawn seed.
    """
    # Draw a fresh seed (recipe from detectron2): pid + clock + OS entropy.
    pid_part = os.getpid()
    clock_part = int(datetime.now().strftime("%S%f"))
    entropy_part = int.from_bytes(os.urandom(2), "big")
    seed = pid_part + clock_part + entropy_part

    config.defrost()
    config.TASK_CONFIG.SEED = seed
    config.freeze()

    # Seed every RNG from the value stored in the config.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(config.TASK_CONFIG.SEED)

    single_threaded = (
        config.FORCE_TORCH_SINGLE_THREADED and torch.cuda.is_available()
    )
    if single_threaded:
        torch.set_num_threads(1)

    setup_experiment(config)

    Trainer(config).evaluate(config.NUM_SAMPLES)


def run_exp(cfg: DictConfig) -> None:
    r"""Merges the Hydra config into the default config and runs it.

    Args:
        cfg: DictConfig object containing the configs for the experiment.

    Returns:
        None.
    """
    # Resolve interpolations and wrap the plain container as a config node.
    overrides = CN(OmegaConf.to_container(cfg, resolve=True))

    # Start from the defaults and layer the Hydra-provided values on top.
    merged = get_config()
    merged.merge_from_other_cfg(overrides)

    execute_exp(merged)


def setup_experiment(config: Config) -> None:
    r"""Prepares output folders, dataset paths and process environment.

    * On rank 0, creates ``config.CHECKPOINT_FOLDER`` and ``config.LOG_DIR``.
    * Rewrites ``TASK_CONFIG.DATASET.SCENES_DIR`` and ``DATA_PATH`` via
      ``hydra.utils.to_absolute_path``.
    * Prepends the NVIDIA OpenGL library directory to ``LD_LIBRARY_PATH`` and
      sets ``GLOG_minloglevel`` / ``MAGNUM_LOG`` to reduce log output.

    Args:
        config: experiment config; mutated in place (defrosted, then frozen).
    """
    if rank0_only():
        os.makedirs(config.CHECKPOINT_FOLDER, exist_ok=True)
        os.makedirs(config.LOG_DIR, exist_ok=True)

    config.defrost()
    config.TASK_CONFIG.DATASET.SCENES_DIR = hydra.utils.to_absolute_path(
        config.TASK_CONFIG.DATASET.SCENES_DIR
    )
    config.TASK_CONFIG.DATASET.DATA_PATH = hydra.utils.to_absolute_path(
        config.TASK_CONFIG.DATASET.DATA_PATH
    )
    config.freeze()

    # LD_LIBRARY_PATH may be unset (or empty); indexing os.environ directly
    # would raise KeyError, and a trailing ":" would add the current
    # directory to the search path.
    nvidia_gl_dir = "/usr/lib/x86_64-linux-gnu/nvidia-opengl"
    existing_path = os.environ.get("LD_LIBRARY_PATH", "")
    os.environ["LD_LIBRARY_PATH"] = (
        nvidia_gl_dir + ":" + existing_path if existing_path else nvidia_gl_dir
    )
    os.environ["GLOG_minloglevel"] = "3"
    os.environ["MAGNUM_LOG"] = "quiet"


class EMAState:
    """Exponential moving average of a tensor-valued statistic.

    Args:
        default: initial value of the running state.
        decay: weight given to the previous state on each update; the new
            observation is weighted by ``1 - decay``.
    """

    def __init__(self, default=1.0, decay=0.999) -> None:
        self.decay = decay
        self.state = torch.tensor(default).float()

    def to(self, device: torch.device) -> None:
        """Moves the running state to ``device`` (in place, returns None)."""
        self.state = self.state.to(device)

    def update(self, state: torch.Tensor) -> None:
        """Folds a new observation into the running average.

        When ``torch.distributed`` is initialized, the observation is first
        averaged across all processes. The caller's tensor is left untouched:
        the reduction runs on a copy.

        Args:
            state: the new observation.
        """
        if distrib.is_available() and distrib.is_initialized():
            # all_reduce works in place, so reduce a private copy instead of
            # silently overwriting the tensor the caller handed in.
            state = state.clone()
            distrib.all_reduce(state, op=distrib.ReduceOp.SUM)
            state = state / distrib.get_world_size()

        if self.state is None:
            # Only reachable if the state was cleared externally; __init__
            # always sets a tensor.
            self.state = state
        else:
            self.state = self.decay * self.state + (1 - self.decay) * state

    def get(self) -> torch.Tensor:
        """Returns the current running state."""
        return self.state


@triton.jit
def bool_gather_kernel(
    input_ptr,  # pointer to the input matrix
    mask_ptr,   # pointer to the per-row mask
    indices_ptr,  # pointer to the prefix-sum indices of the mask
    output_ptr,  # pointer to the output matrix
    N, K,  # input dimensions (N is not referenced in the kernel body)
    input_row_stride, output_row_stride,  # row strides
    BLOCK_SIZE_K: tl.constexpr,  # number of row elements handled per block
):
    # Copies every masked-in input row to its compacted position in the
    # output. Each program instance handles one input row; rows whose mask
    # value is false are skipped. Columns are addressed as consecutive
    # elements (k_offsets is added directly to the row start).

    # index of the row (token) handled by this program instance
    row_idx = tl.program_id(0)
    
    # check whether the mask is True for this row
    mask_val = tl.load(mask_ptr + row_idx)
    if mask_val:
        # destination row in output: the prefix sum is 1-based, so shift it
        output_row = tl.load(indices_ptr + row_idx) - 1  # 0-based
        
        # copy input[row_idx] -> output[output_row] block by block
        for k in range(0, K, BLOCK_SIZE_K):
            k_offsets = k + tl.arange(0, BLOCK_SIZE_K)
            mask_k = k_offsets < K  # guard against reading/writing past K
            
            # load from input
            input_row_start = row_idx * input_row_stride
            input_vals = tl.load(
                input_ptr + input_row_start + k_offsets,
                mask=mask_k,
            )
            
            # write to output
            output_row_start = output_row * output_row_stride
            tl.store(
                output_ptr + output_row_start + k_offsets,
                input_vals,
                mask=mask_k,
            )

def bool_gather(input: torch.Tensor, mask: torch.Tensor):
    """Selects the rows of ``input`` whose ``mask`` entry is true.

    Equivalent to ``input[mask]``. CPU tensors use plain boolean indexing;
    tensors on other devices are compacted with ``bool_gather_kernel``.

    Args:
        input: 2D tensor of shape ``(N, K)``.
        mask: 1D tensor of length ``N``; true entries mark rows to keep.

    Returns:
        Tensor of shape ``(number of selected rows, K)`` with the dtype and
        device of ``input``, rows in their original order.

    Raises:
        AssertionError: if the ranks or lengths of the arguments mismatch.
    """
    assert input.dim() == 2, "input must be 2D"
    assert mask.dim() == 1, "mask must be 1D"
    assert input.size(0) == mask.size(0), "input and mask must have same length"

    if input.device.type == "cpu":
        # CPU path: native boolean indexing.
        return input[mask]

    num_rows, num_cols = input.size(0), input.size(1)
    if num_rows == 0:
        # Nothing to gather; indices[-1] and a zero-sized launch grid would
        # both fail below.
        return torch.empty((0, num_cols), device=input.device, dtype=input.dtype)

    # The kernel addresses columns (and mask entries) as consecutive
    # elements, so make sure both tensors really are laid out that way.
    input = input.contiguous()
    mask = mask.contiguous()

    # Inclusive prefix sum: for a selected row this is its 1-based position
    # in the output (the kernel subtracts 1).
    indices = torch.cumsum(mask, dim=0)

    # The last prefix-sum entry is the number of selected rows.
    n_selected = int(indices[-1].item())
    output = torch.empty(
        (n_selected, num_cols), device=input.device, dtype=input.dtype
    )
    if n_selected == 0:
        return output

    # Launch the Triton kernel with one program instance per input row.
    grid = (num_rows,)
    bool_gather_kernel[grid](
        input_ptr=input,
        mask_ptr=mask,
        indices_ptr=indices,
        output_ptr=output,
        N=num_rows,
        K=num_cols,
        input_row_stride=input.stride(0),
        output_row_stride=output.stride(0),
        BLOCK_SIZE_K=128,  # tunable, typically 64-256
    )

    return output


class PretrainDataset:
    """Iterates over the RGB frames of episodes stored as ``.npz`` files.

    Every ``<episode>.npz`` file in ``config.NPZ_DIR`` is one episode; its
    ``"rgb"`` entry holds the frames. Iterating yields one dict per frame.

    Args:
        config: object exposing ``NPZ_DIR`` and ``VIDEO_DIR`` attributes.
    """

    def __init__(self, config: "Config") -> None:
        self.npz_dir = config.NPZ_DIR
        self.video_dir = config.VIDEO_DIR
        # Only real .npz files are episodes; any other directory entry would
        # otherwise turn into a bogus (or duplicated) "<stem>.npz" lookup.
        npz_files = sorted(
            name for name in os.listdir(self.npz_dir) if name.endswith(".npz")
        )
        self.episodes = [os.path.splitext(name)[0] for name in npz_files]

    def __iter__(self):
        """Yields ``{"rgb", "pre_mask", "episode"}`` dicts, frame by frame.

        ``pre_mask`` is False for the first frame of an episode and True for
        every later frame.
        """
        for episode in self.episodes:
            npz_file = os.path.join(self.npz_dir, episode + ".npz")

            # NOTE(security): allow_pickle=True can execute arbitrary code
            # while loading; only point NPZ_DIR at trusted files.
            # The context manager closes the archive once the frames have
            # been read into memory.
            with np.load(npz_file, allow_pickle=True) as data:
                rgb_images = data["rgb"]

            pre_mask = np.ones(len(rgb_images), dtype=bool)
            pre_mask[:1] = False
            for i in range(len(rgb_images)):
                yield {
                    "rgb": rgb_images[i],
                    "pre_mask": pre_mask[i],
                    "episode": episode,
                }

    def shuffle(self):
        """Shuffles the episode order in place (frames stay in order)."""
        random.shuffle(self.episodes)


class Trainer:
    r"""Compares memory-assisted inference with plain inference.

    The model is built with ``build_model`` and optionally restored from
    ``config.RESUME``. :meth:`evaluate` then runs every dataset frame through
    the model twice -- once without memory and once reusing the memory
    produced for the previous frame -- and logs timings and the difference
    between the two outputs.
    """

    def __init__(self, config: Config) -> None:
        r"""Builds the model, selects the device and loads the checkpoint.

        Args:
            config: experiment config; reads ``LOG_FILE``, the fields used by
                ``build_model`` and, if present and truthy, ``RESUME``.
        """
        self.config = config
        self.local_rank = 0
        if config is not None and self.local_rank == 0:
            logger.add_filehandler(config.LOG_FILE)

        if torch.cuda.is_available():
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.device)
            # NOTE(review): the CUDA device chosen above is immediately
            # replaced, so evaluation always runs on the CPU. This looks like
            # a leftover from CPU benchmarking -- confirm before removing.
            self.device = torch.device("cpu")
        else:
            self.device = torch.device("cpu")

        self.model, self.pre_transform = build_model(self.config)
        self.model.to(self.device)
        self.model.eval()

        if hasattr(self.config, "RESUME") and self.config.RESUME:
            ckpt = load_resume_state(config.RESUME)
            # Check for a missing checkpoint BEFORE touching its contents.
            if ckpt is None:
                logger.warning(
                    f"Checkpoint {config.RESUME} could not be loaded; "
                    "using the freshly built model"
                )
            else:
                # Keep only the entries stored under the "model." prefix and
                # strip that prefix; every other key is dropped.
                prefix = "model."
                ckpt["model"] = {
                    key[len(prefix):]: value
                    for key, value in ckpt["model"].items()
                    if key.startswith(prefix)
                }
                self.model.load_state_dict(ckpt["model"], strict=True)
                logger.info(f"Loaded checkpoint {config.RESUME}")

    @staticmethod
    def _log_diff_stats(diffs) -> None:
        r"""Logs the per-sample difference statistics averaged over ``diffs``."""
        diff_mean = np.mean([d['mean'] for d in diffs])
        diff_std = np.mean([d['std'] for d in diffs])
        diff_max = np.mean([d['max'] for d in diffs])
        diff_min = np.mean([d['min'] for d in diffs])
        logger.info(f"Diffs: max: {diff_max}, min: {diff_min}, mean: {diff_mean}, std: {diff_std}")

    def evaluate(self, num_samples: int = 100) -> None:
        r"""Benchmarks memory-assisted against plain inference.

        For each frame, the RGB image and a certainty map are written to the
        current working directory (``<n>.jpg`` / ``<n>_certainty.png``). The
        first two frames are treated as warm-up and excluded from all
        statistics. Timings and output differences are logged periodically
        (every sample on CPU, every 100 samples otherwise) and once more at
        the end, where the timing averages cover all measured samples.

        Args:
            num_samples: stop after this many measured samples.
        """
        memory = None
        dataset = PretrainDataset(self.config)
        self.model.eval()
        diffs = list()

        # Windowed accumulators: reset after every periodic log line.
        mem_time_all = 0
        inference_time_all = 0
        time_count = 0
        # Cumulative accumulators for the final summary. The windowed ones
        # cannot be used there: they are zero right after a periodic log,
        # which made the final division raise ZeroDivisionError.
        total_mem_time = 0.0
        total_inference_time = 0.0

        warmup = 2
        cnt = 0
        self.model.reference_last_frame_thr = 0.475
        log_every = 1 if self.device.type == "cpu" else 100

        with torch.no_grad():
            for data in dataset:
                cnt += 1
                imageio.imwrite(f"{cnt}.jpg", data["rgb"])
                rgb = torch.from_numpy(data["rgb"]).to(self.device).unsqueeze(0)
                # HWC uint8 -> NCHW float in [0, 1]
                rgb = (
                    rgb.permute(0, 3, 1, 2).float() / 255
                )
                rgb = self.pre_transform(rgb)
                rgb = rgb.repeat(1, 1, 1, 1)

                # First frame of an episode: nothing to reuse.
                premask = data["pre_mask"]
                if not premask:
                    memory = None

                # warm up
                std, _, _, _ = self.eval_model(rgb)

                # standard inference without memory
                start_time = time.perf_counter()
                std, _, _, _ = self.eval_model(rgb)
                inference_time = time.perf_counter() - start_time

                # inference reusing the previous frame's memory
                # NOTE(review): the timed region below also covers the
                # certainty image conversion and disk write -- confirm that
                # this is intended for the reported "Memory time".
                outputs = list()
                start_time = time.perf_counter()
                for i in range(len(rgb)):
                    out, mem, mem_mask, certainty = self.eval_model(rgb[i:i+1], memory)
                # Drop the first (cls) entry and view the rest as a 40x40 map.
                certainty_img = certainty.flatten()[1:].reshape(40, 40).cpu().numpy()
                certainty_img = (certainty_img * 255).astype(np.uint8)
                imageio.imwrite(f"{cnt}_certainty.png", certainty_img)
                outputs.append(out)
                mem_time = time.perf_counter() - start_time
                memory = mem

                if warmup > 0:
                    warmup -= 1
                    continue

                mem_time_all += mem_time
                inference_time_all += inference_time
                time_count += 1
                total_mem_time += mem_time
                total_inference_time += inference_time

                diff = F.mse_loss(torch.cat(outputs, dim=0), std, reduction="none")
                diffs.append({
                    'mean': diff.mean().item(),
                    'std': diff.std().item(),
                    'max': diff.max().item(),
                    'min': diff.min().item(),
                    'keep_ratio': mem_mask.sum().item() / mem_mask.numel(),
                })

                if time_count % log_every == 0:
                    logger.info(f"Memory time: {mem_time_all / time_count}, Inference time: {inference_time_all / time_count}")
                    self._log_diff_stats(diffs)
                    keep_ratio = np.mean([d['keep_ratio'] for d in diffs])
                    logger.info(f"Keep ratio: {keep_ratio}")

                    mem_time_all = 0
                    inference_time_all = 0
                    time_count = 0

                if len(diffs) >= num_samples:
                    break

        # One diff entry is recorded per measured sample.
        measured = len(diffs)
        if measured == 0:
            logger.warning("No samples were measured; nothing to summarize")
            return
        logger.info(f"Memory time: {total_mem_time / measured}, Inference time: {total_inference_time / measured}")
        self._log_diff_stats(diffs)

    def eval_model(self, x, memory=None):
        r"""Runs the transformer blocks, optionally reusing ``memory``.

        Args:
            x: input batch passed to ``self.model.patch_embed``.
            memory: ``None`` for a full forward pass, or the two-element
                memory returned by a previous call (unpacked as keys and
                values for ``self.model.reference_net``). Requires a batch
                size of 1.

        Returns:
            ``(x, new_memory, mem_mask, certainty)`` where ``x`` is the output
            of the last block, ``new_memory`` is ``[tokens entering the
            reference layer, final tokens]``, ``mem_mask`` marks the tokens
            that were run through the blocks after the reference layer, and
            ``certainty`` is the value returned by ``reference_net`` (all
            False when no memory is given).
        """
        B = x.shape[0]
        x = self.model.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.model.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        if self.model.mask_ratio is not None:
            x, _, _ = self.model.random_masking(x, mask_ratio=self.model.mask_ratio)

        # append cls token
        cls_token = self.model.cls_token + self.model.pos_embed[:, :1, :]
        x = torch.cat((cls_token.expand(B, -1, -1), x), dim=1)

        if memory is not None:
            new_memory = []
            assert B == 1, "Batch size must be 1 for memory"
            # NOTE: x_out, mem_mask and certainty are only bound when the loop
            # reaches reference_last_frame_layer_idx.
            for i, layer in enumerate(self.model.blocks):
                if i == self.model.reference_last_frame_layer_idx:
                    new_memory.append(x)
                    keys, values = memory
                    querys = x
                    output, certainty = self.model.reference_net(querys, keys, values, self.model.pos_embed)
                    # Tokens below the certainty threshold are recomputed;
                    # the cls token (index 0) is always recomputed.
                    mem_mask = certainty.flatten() < self.model.reference_last_frame_thr
                    mem_mask = torch.cat((torch.ones_like(mem_mask[:1]), mem_mask[1:]), dim=0)

                    # Remaining tokens keep the reference_net output.
                    x_out = output.squeeze(0)
                    x = bool_gather(x.squeeze(0), mem_mask).unsqueeze(0)
                x = layer(x)

            # Scatter the recomputed tokens back into the full sequence.
            x_out[mem_mask] = x.squeeze(0)
            x = x_out.unsqueeze(0)

            new_memory.append(x)
            return x, new_memory, mem_mask, certainty
        else:
            new_memory = []
            for i, block in enumerate(self.model.blocks):
                if i == self.model.reference_last_frame_layer_idx:
                    new_memory.append(x)
                x = block(x)
            new_memory.append(x)
            # Result unused; NOTE(review): presumably kept to force a device
            # sync so that the caller's timing is accurate -- confirm.
            x[0, 0, 0].item()

            return x, new_memory, torch.ones_like(x[:, :, 0], dtype=torch.bool), torch.zeros_like(x[:, :, 0], dtype=torch.bool)


if __name__ == "__main__":
    main()
