import os, sys
import lightning
import numpy as np
import imageio
import json
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm, trange

import matplotlib.pyplot as plt
sys.path.append(os.path.join(os.getcwd(), "third_party", "LDMI"))
sys.path.append(os.path.join(os.getcwd(), "third_party", "taming-transformers"))
from model.hoimagenet import HoImageNet
from dataloader.imagenet_datamodule import ImageNetDataModule

from run_nerf_helpers import *

from hypernet_core import NeRF_tar, Hypernet
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
from pdb import set_trace as bb
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from send2trash import send2trash
import hydra
from omegaconf import DictConfig
from hydra.core.hydra_config import HydraConfig
import logging

from third_party.LDMI.main import ImageLogger
# Configure the root logger at import time: records at INFO level and above
# are written to the relative file "run.log".
logging.basicConfig(
    level=logging.INFO,
    filename="run.log",
)

# Append /usr/local/cuda/bin/ to PATH for this process and its children
# (presumably so CUDA command-line tools can be located -- confirm).
# A missing or empty PATH is treated as "no entries": indexing os.environ
# directly would raise KeyError, and splitting "" would leave an empty entry,
# which POSIX shells interpret as the current directory.
_current_path = os.environ.get("PATH", "")
env_list = _current_path.split(os.pathsep) if _current_path else []
env_list.append("/usr/local/cuda/bin/")
os.environ["PATH"] = os.pathsep.join(env_list)
            
def _save_config_info(log_dir):
    """Write the Hydra config name and group choices to ``config_info.yaml``.

    Args:
        log_dir: Directory to write into; created if it does not exist.

    Returns:
        Path of the YAML file that was written.
    """
    import yaml  # function-local import, as in the original script

    os.makedirs(log_dir, exist_ok=True)
    hydra_cfg = HydraConfig.get()
    choices = hydra_cfg.runtime.choices

    config_info = {
        # NOTE(review): runtime.choices holds config-group selections, so this
        # lookup presumably always yields the fallback; kept for compatibility
        # with existing config_info.yaml consumers.
        "config_path": choices.get("config_path", "configs/hydra"),
        # The primary config name lives under hydra.job.config_name; looking
        # it up in runtime.choices always produced the hard-coded fallback.
        "config_name": hydra_cfg.job.config_name
        or choices.get("config_name", "imagenet_ivqvae"),
        "all_choices": dict(choices),
    }

    config_info_path = os.path.join(log_dir, "config_info.yaml")
    with open(config_info_path, "w") as f:
        yaml.dump(config_info, f, default_flow_style=False)
    print(f"Config info saved to: {config_info_path}")
    return config_info_path


def _warm_start_from_checkpoint(hoimagenet, ckpt_file):
    """Copy matching ``first_stage_model.*`` weights from a checkpoint into the model.

    Keys under ``first_stage_model.encoder.`` are loaded into
    ``hoimagenet.voxel_encoder`` (when that attribute exists); all keys under
    ``first_stage_model.`` are then offered to ``hoimagenet`` itself. Both
    loads use ``strict=False``. A failure of the second load is reported and
    swallowed so that training can still start.

    Args:
        hoimagenet: Model to receive the weights.
        ckpt_file: Path to an existing checkpoint file readable by ``torch.load``.
    """
    print(f"Loading encoder from ckpt: {ckpt_file}")
    ckpt = torch.load(ckpt_file, map_location="cpu")
    # Accept both a wrapped checkpoint ({"state_dict": ...}) and a raw state dict.
    state_dict = ckpt.get("state_dict", ckpt)

    # Strip prefixes by slicing: str.replace would also remove any later
    # occurrence of the same substring inside a key.
    encoder_prefix = "first_stage_model.encoder."
    encoder_state_dict = {
        k[len(encoder_prefix):]: v
        for k, v in state_dict.items()
        if k.startswith(encoder_prefix)
    }

    if hasattr(hoimagenet, "voxel_encoder"):
        missing, unexpected = hoimagenet.voxel_encoder.load_state_dict(
            encoder_state_dict, strict=False
        )
        print(f"[voxel_encoder] Loaded. Missing keys: {missing}")
        print(f"[voxel_encoder] Unexpected keys: {unexpected}")
        print("[voxel_encoder] Loaded keys:")
        for k in encoder_state_dict:
            print(f"  - {k}")
    else:
        print("Warning: hoimagenet has no attribute 'voxel_encoder'")

    model_prefix = "first_stage_model."
    remaining_state_dict = {
        k[len(model_prefix):]: v
        for k, v in state_dict.items()
        if k.startswith(model_prefix)
    }

    print("[General] Attempting to load additional matching keys into hoimagenet:")
    try:
        missing, unexpected = hoimagenet.load_state_dict(
            remaining_state_dict, strict=False
        )
    except Exception as e:
        # Deliberately best-effort: a failed load (e.g. a shape mismatch)
        # must not abort the run.
        print(f"[General] Failed to load additional state_dict into hoimagenet: {e}")
    else:
        successfully_loaded_keys = (
            set(remaining_state_dict) - set(missing) - set(unexpected)
        )
        loaded_keys = [
            k for k in successfully_loaded_keys if not k.startswith("encoder.")
        ]
        for k in loaded_keys:
            print(f"  - {k}")
        print(
            f"[General] Loaded {len(loaded_keys)} non-encoder keys. "
            f"Missing: {len(missing)}, Unexpected: {len(unexpected)}"
        )


@hydra.main(config_path="configs/hydra", config_name="imagenet_ivqvae_debug")
def main(cfg):
    """Build the trainer, data module and model from ``cfg`` and run training.

    Steps: seed everything, create a TensorBoard logger under ``logs/``, save
    the Hydra config info next to the logs, construct the Lightning trainer
    with two checkpoint callbacks and an image logger, set up the data module,
    and -- when ``cfg.stage == "fit"`` -- optionally warm-start the model from
    ``cfg.ckpt_path`` before calling ``trainer.fit``.

    ``trainer.fit`` is always called with ``ckpt_path=None``: the checkpoint is
    used only as a source of weights, trainer state is not resumed.

    NOTE(review): for ``cfg.stage == "test"`` only ``data_module.setup("test")``
    runs; no ``trainer.test`` call is made here -- confirm this is intended.

    Args:
        cfg: Hydra config. Reads ``seed``, ``expname``, ``version``, ``gpu``
            (optional, defaults to -1), ``num_nodes``, ``val_check_interval``,
            ``bf16``, ``log_freq`` (optional, defaults to 1000), ``stage`` and
            ``ckpt_path`` (the string ``"None"`` or null means no checkpoint).
    """
    from omegaconf import open_dict

    lightning.seed_everything(cfg.seed)
    torch.set_float32_matmul_precision("medium")
    logger = TensorBoardLogger("logs", name=cfg.expname, version=cfg.version)

    if not hasattr(cfg, "gpu"):
        # Hydra configs are struct-mode by default, so adding a new key
        # outside open_dict raises instead of creating it.
        with open_dict(cfg):
            cfg.gpu = -1
    print("============== gpu config is", cfg.gpu)

    _save_config_info(logger.log_dir)

    checkpoint_dir = os.path.join(logger.log_dir, "checkpoints")
    trainer = lightning.Trainer(
        max_epochs=100000,
        accelerator="auto",
        devices=cfg.gpu,
        num_nodes=cfg.num_nodes,
        logger=logger,
        strategy="ddp_find_unused_parameters_true",
        val_check_interval=cfg.val_check_interval,
        precision="bf16-mixed" if cfg.bf16 else "32",
        log_every_n_steps=5,
        num_sanity_val_steps=2,
        callbacks=[
            # Keeps every periodic snapshot (save_top_k=-1), named by step and
            # validation loss.
            ModelCheckpoint(
                monitor="val/total_loss",
                dirpath=checkpoint_dir,
                filename="hoimagenet-{step:06d}-loss{val/total_loss:.6f}",
                save_top_k=-1,
                mode="min",
                every_n_train_steps=10000,
                save_last=False,
                auto_insert_metric_name=False,
            ),
            # Maintains only the rolling "last" checkpoint (save_top_k=0).
            ModelCheckpoint(
                dirpath=checkpoint_dir,
                filename="hoshapenet-{step:06d}",
                every_n_train_steps=5000,
                save_top_k=0,
                save_last=True,
                enable_version_counter=False,
            ),
            ImageLogger(
                batch_frequency=cfg.log_freq if hasattr(cfg, "log_freq") else 1000,
                max_images=9,
                increase_log_steps=False,
                log_first_step=False,
            ),
        ],
    )

    data_module = ImageNetDataModule(cfg)
    if cfg.stage == "test":
        data_module.setup("test")
    else:
        data_module.setup()

    hoimagenet = HoImageNet(cfg, data_module)

    if cfg.stage == "fit":
        ckpt_file = cfg.ckpt_path
        if ckpt_file is None or ckpt_file == "None":
            print("[General] No ckpt_path given; skipping weight loading.")
        elif os.path.isfile(ckpt_file):
            _warm_start_from_checkpoint(hoimagenet, ckpt_file)
        else:
            print(
                f"Warning: ckpt_path '{ckpt_file}' is not a file; "
                "skipping weight loading."
            )

        # Weights (if any) were loaded above; do not resume trainer state.
        trainer.fit(hoimagenet, datamodule=data_module, ckpt_path=None)



# Script entry point. main() is called without arguments because it is
# wrapped by the @hydra.main decorator at its definition.
if __name__ == "__main__":
    main()
