import os, sys
import lightning
import numpy as np
import imageio
import json
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm, trange

import matplotlib.pyplot as plt

from model.honerf2vec import HoNeRF2Vec
from dataloader.honf2vec_datamodule import HoNeRF2VecDataModule
from run_nerf_helpers import *

from torch.utils.tensorboard import SummaryWriter
from PIL import Image
from pdb import set_trace as bb
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from send2trash import send2trash
import hydra
from omegaconf import DictConfig, OmegaConf
from hydra.core.hydra_config import HydraConfig
from hydra.core.global_hydra import GlobalHydra
import logging
import argparse

# Send INFO-and-above log records to "run.log" in the current working directory.
logging.basicConfig(filename="run.log", level=logging.INFO)

# Make sure the CUDA toolkit's bin directory is on PATH for this process and
# any child process it spawns.
#  * os.pathsep is used instead of a hard-coded ":" so the split/join is
#    correct on every platform.
#  * A missing or empty PATH is tolerated (no KeyError, no leading empty
#    entry, which would implicitly mean "current directory").
#  * The directory is only appended when it is not already listed, so
#    re-imports / re-launched worker processes that inherit the environment
#    do not keep growing PATH with duplicates.
_CUDA_BIN_DIR = "/usr/local/cuda/bin/"
_current_path = os.environ.get("PATH", "")
env_list = _current_path.split(os.pathsep) if _current_path else []
if _CUDA_BIN_DIR not in env_list:
    env_list.append(_CUDA_BIN_DIR)
os.environ["PATH"] = os.pathsep.join(env_list)

def main(cfg):
    """Build the trainer, data module and model from ``cfg`` and run one stage.

    Args:
        cfg: Experiment configuration. The attributes read here are ``seed``,
            ``expname``, ``version``, ``gpu`` (optional, defaults to ``-1``),
            ``num_nodes``, ``val_check_interval``, ``bf16``, ``stage`` and
            ``ckpt_path``. The same object is also handed to
            ``HoNeRF2VecDataModule`` and ``HoNeRF2Vec``.

    Behaviour:
        * ``cfg.stage == "fit"``: trains, resuming from the checkpoint only if
          the checkpoint file exists.
        * any other stage: runs ``trainer.test`` with the checkpoint path.

    A ``ckpt_path`` of ``None`` or the string ``"None"`` selects
    ``<log_dir>/checkpoints/last.ckpt``.
    """
    lightning.seed_everything(cfg.seed)
    torch.set_float32_matmul_precision("medium")
    logger = TensorBoardLogger("logs", name=cfg.expname, version=cfg.version)

    if not hasattr(cfg, "gpu"):
        if isinstance(cfg, DictConfig):
            # Configs composed by Hydra are in struct mode, where a plain
            # ``cfg.gpu = -1`` raises for an unknown key; force_add bypasses
            # the struct check for this one key.
            OmegaConf.update(cfg, "gpu", -1, force_add=True)
        else:
            cfg.gpu = -1
    print("============== gpu config is", cfg.gpu)

    checkpoint_dir = os.path.join(logger.log_dir, "checkpoints")
    trainer = lightning.Trainer(
        max_epochs=100000,
        accelerator="auto",
        devices=cfg.gpu,
        num_nodes=cfg.num_nodes,
        logger=logger,
        strategy="ddp_find_unused_parameters_true",
        val_check_interval=cfg.val_check_interval,
        precision="bf16-mixed" if cfg.bf16 else "32",
        log_every_n_steps=5,
        num_sanity_val_steps=0,
        callbacks=[
            # Periodic snapshots whose file name records step and val/psnr.
            ModelCheckpoint(
                monitor="val/psnr",
                dirpath=checkpoint_dir,
                filename="honerf-{step:06d}-val_psnr={val/psnr:.2f}",
                save_top_k=-1,
                mode="max",
                every_n_train_steps=5000,
                save_last=False,
                auto_insert_metric_name=False,
            ),
            # Rolling "last" checkpoint used for resuming (see ckpt_path below).
            ModelCheckpoint(
                dirpath=checkpoint_dir,
                filename="honerf-{step:06d}",
                every_n_train_steps=1000,
                save_top_k=0,
                save_last=True,
                enable_version_counter=False,
            ),
        ],
    )

    data_module = HoNeRF2VecDataModule(cfg)
    # The data module is set up before the model is built because the model
    # constructor receives it.
    if cfg.stage == "test":
        data_module.setup("test")
    else:
        data_module.setup()

    honerf2vec = HoNeRF2Vec(cfg, data_module)

    # Accept both a real None (YAML ``null``) and the literal string "None"
    # as "no explicit checkpoint"; previously a real None reached
    # os.path.isfile() and raised TypeError.
    ckpt_path = cfg.ckpt_path
    if ckpt_path is None or ckpt_path == 'None':
        ckpt_path = os.path.join(checkpoint_dir, "last.ckpt")

    if cfg.stage == "fit":
        trainer.fit(
            honerf2vec,
            datamodule=data_module,
            # Start from scratch when there is nothing to resume from.
            ckpt_path=ckpt_path if os.path.isfile(ckpt_path) else None,
        )
    else:
        trainer.test(honerf2vec, datamodule=data_module, ckpt_path=ckpt_path)

DEFAULT_CONFIG_DIR = 'configs/hydra'
DEFAULT_CONFIG_NAME = 'honerf2vec'


def build_arg_parser():
    """Return the parser for the config-location options.

    ``add_help=False`` and the use of ``parse_known_args`` by the caller mean
    every argument not recognised here is passed on as a config override.
    """
    parser = argparse.ArgumentParser(description='Train HoNeRF2Vec', add_help=False)
    parser.add_argument('--config-path', '--config-dir', type=str, default=None,
                        help='Path to config directory or full path to config file (default: configs/hydra)')
    parser.add_argument('--config-name', '--cn', type=str, default=None,
                        help='Config name without .yaml extension (default: honerf2vec)')
    parser.add_argument('--config-file', type=str, default=None,
                        help='Full path to config file (alternative to --config-path and --config-name)')
    return parser


def resolve_config_location(args):
    """Turn the parsed CLI options into ``(config_dir, config_name)``.

    Precedence: ``--config-file``, then ``--config-path`` (which may itself
    point at a file, in which case ``--config-name`` is ignored), then the
    default directory. Relative paths are made absolute against the current
    working directory.

    Args:
        args: Namespace with ``config_file``, ``config_path`` and
            ``config_name`` attributes.

    Returns:
        Tuple ``(config_dir, config_name)``; ``config_name`` has no extension.

    Raises:
        ValueError: If the given file, path, or resulting directory is missing.
    """
    def absolute(path):
        # Absolute inputs are kept verbatim (no normalisation).
        return path if os.path.isabs(path) else os.path.abspath(path)

    def split_file(path):
        return os.path.dirname(path), os.path.splitext(os.path.basename(path))[0]

    if args.config_file:
        config_file = absolute(args.config_file)
        if not os.path.isfile(config_file):
            raise ValueError(f"Config file not found: {config_file}")
        config_path, config_name = split_file(config_file)
    elif args.config_path:
        config_path = absolute(args.config_path)
        if os.path.isfile(config_path):
            config_path, config_name = split_file(config_path)
        elif os.path.isdir(config_path):
            config_name = args.config_name or DEFAULT_CONFIG_NAME
        else:
            raise ValueError(f"Config path not found: {config_path}")
    else:
        config_path = os.path.abspath(DEFAULT_CONFIG_DIR)
        config_name = args.config_name or DEFAULT_CONFIG_NAME

    if not os.path.isdir(config_path):
        raise ValueError(f"Config directory not found: {config_path}")
    return config_path, config_name


def is_external_config_dir(config_path, default_config_dir):
    """Return True when ``config_path`` is neither the default dir nor inside it."""
    default_norm = os.path.normpath(default_config_dir)
    path_norm = os.path.normpath(config_path)
    inside_default = path_norm == default_norm or path_norm.startswith(default_norm + os.sep)
    return not inside_default


def strip_hydra_add_prefix(overrides):
    """Drop Hydra's leading ``+`` / ``++`` from each override.

    ``OmegaConf.from_dotlist`` does not understand that prefix and would
    otherwise create a key literally named ``+key``.
    """
    return [item.lstrip('+') for item in overrides]


def load_external_config(config_path, config_name, default_config_dir, overrides):
    """Load ``<config_path>/<config_name>.yaml`` directly with OmegaConf.

    Used for configs stored outside the default config directory. If the
    file has a ``defaults`` key, it is merged on top of ``default.yaml``
    from ``default_config_dir`` (when that file exists) and the ``defaults``
    key is removed. ``overrides`` are applied last as a dotlist.

    Raises:
        FileNotFoundError: If the config file does not exist.
    """
    print(f"Loading config from log directory: {config_path}/{config_name}.yaml")
    print(f"Using default config directory: {default_config_dir} for defaults resolution")

    log_config_file = os.path.join(config_path, f"{config_name}.yaml")
    if not os.path.isfile(log_config_file):
        raise FileNotFoundError(f"Config file not found: {log_config_file}")

    cfg = OmegaConf.load(log_config_file)

    if 'defaults' in cfg:
        default_file = os.path.join(default_config_dir, 'default.yaml')
        if os.path.isfile(default_file):
            cfg = OmegaConf.merge(OmegaConf.load(default_file), cfg)
        else:
            print(f"Warning: default.yaml not found in {default_config_dir}, using log config only")
        if 'defaults' in cfg:
            del cfg['defaults']

    if overrides:
        override_cfg = OmegaConf.from_dotlist(strip_hydra_add_prefix(overrides))
        cfg = OmegaConf.merge(cfg, override_cfg)
    return cfg


def load_config(config_path, config_name, overrides):
    """Load the experiment config from ``config_path``.

    Directories outside the default config directory are read directly with
    OmegaConf; the default directory (and its sub-directories) go through
    Hydra's compose API so that the ``defaults`` list is resolved by Hydra.
    """
    default_config_dir = os.path.abspath(DEFAULT_CONFIG_DIR)
    if is_external_config_dir(config_path, default_config_dir):
        return load_external_config(config_path, config_name, default_config_dir, overrides)
    with hydra.initialize_config_dir(config_dir=config_path, version_base=None):
        return hydra.compose(config_name=config_name, overrides=overrides)


if __name__ == "__main__":
    args, hydra_overrides = build_arg_parser().parse_known_args()
    config_path, config_name = resolve_config_location(args)

    # Clear any Hydra state that is already initialised in this process
    # before the config is loaded.
    if GlobalHydra.instance().is_initialized():
        GlobalHydra.instance().clear()

    try:
        cfg = load_config(config_path, config_name, hydra_overrides)
    except Exception as e:
        # Top-level boundary: report which config failed, then re-raise.
        print(f"Error loading config from {config_path}/{config_name}.yaml")
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        raise

    print(f"Loaded config from: {config_path}/{config_name}.yaml")
    if hydra_overrides:
        print(f"Config overrides: {hydra_overrides}")

    main(cfg)
