import os, sys
import lightning
import numpy as np
import imageio
import json
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm, trange

import matplotlib.pyplot as plt
sys.path.append(os.path.join(os.getcwd(), "third_party", "LDMI"))
sys.path.append(os.path.join(os.getcwd(), "third_party", "taming-transformers"))
from model.hocelebahq import HoCelebAHQ
from dataloader.celebahq_datamodule import *

from run_nerf_helpers import *

from hypernet_core import NeRF_tar, Hypernet
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
from pdb import set_trace as bb
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from send2trash import send2trash
import hydra
from omegaconf import DictConfig
from hydra.core.hydra_config import HydraConfig
import logging

from third_party.LDMI.main import ImageLogger 
logging.basicConfig(filename="run.log", level=logging.INFO)

# Append the CUDA toolkit's bin directory to PATH (presumably so CUDA tooling
# such as nvcc is found by code that builds extensions — TODO confirm).
# Tolerates an unset/empty PATH without creating a stray empty entry (an empty
# PATH entry means "current directory"), and never adds the directory twice.
_CUDA_BIN_DIR = "/usr/local/cuda/bin/"
_path = os.environ.get("PATH", "")
env_list = _path.split(os.pathsep) if _path else []
if _CUDA_BIN_DIR not in env_list:
    env_list.append(_CUDA_BIN_DIR)
os.environ["PATH"] = os.pathsep.join(env_list)
            
def _build_trainer(cfg, logger):
    """Build the Lightning ``Trainer`` shared by the fit and test stages.

    Args:
        cfg: Composed Hydra config. Reads ``gpu``, ``num_nodes``,
            ``val_check_interval``, ``bf16`` and, if present, ``log_freq``.
        logger: TensorBoard logger; checkpoints are written under
            ``<logger.log_dir>/checkpoints``.

    Returns:
        A configured ``lightning.Trainer``.
    """
    ckpt_dir = os.path.join(logger.log_dir, "checkpoints")
    return lightning.Trainer(
        max_epochs=100000,
        accelerator="auto",
        devices=cfg.gpu,
        num_nodes=cfg.num_nodes,
        logger=logger,
        strategy="ddp_find_unused_parameters_true",
        val_check_interval=cfg.val_check_interval,
        precision="bf16-mixed" if cfg.bf16 else "32",
        log_every_n_steps=5,
        num_sanity_val_steps=0,
        callbacks=[
            # Every 10k steps: write a checkpoint named with the validation
            # loss; save_top_k=-1 keeps all of them.
            ModelCheckpoint(
                monitor="val/total_loss",
                dirpath=ckpt_dir,
                filename="hoshapenet-{step:06d}-loss{val/total_loss:.6f}",
                save_top_k=-1,
                mode="min",
                every_n_train_steps=10000,
                save_last=False,
                auto_insert_metric_name=False,
            ),
            # Every 2k steps: only refresh ``last.ckpt`` (save_top_k=0,
            # save_last=True, no version suffix). This is the file that
            # _resolve_ckpt_path falls back to.
            ModelCheckpoint(
                dirpath=ckpt_dir,
                filename="hoshapenet-{step:06d}",
                every_n_train_steps=2000,
                save_top_k=0,
                save_last=True,
                enable_version_counter=False,
            ),
            ImageLogger(
                batch_frequency=cfg.log_freq if hasattr(cfg, "log_freq") else 50000,
                max_images=12,
                cols=6,
                clamp=True,
                increase_log_steps=False,
                log_first_step=False,
            ),
        ],
    )


def _resolve_ckpt_path(cfg, logger):
    """Return the checkpoint path to resume from (fit) or evaluate (test).

    If ``cfg.ckpt_path`` is ``None`` or the string ``"None"``, the path falls
    back to ``<logger.log_dir>/checkpoints/last.ckpt``. Otherwise
    ``cfg.ckpt_path`` is returned unchanged.
    """
    ckpt_path = cfg.ckpt_path
    # Accept a real null as well as the literal string; previously a null
    # slipped through and crashed os.path.isfile(None) in the fit branch.
    if ckpt_path is None or ckpt_path == "None":
        return os.path.join(logger.log_dir, "checkpoints", "last.ckpt")
    return ckpt_path


@hydra.main(config_path="configs/hydra", config_name="celebahq_ivae_debug")
def main(cfg):
    """Train or evaluate the HoCelebAHQ model.

    ``cfg.stage == "fit"`` runs ``trainer.fit``. It resumes from the resolved
    checkpoint only if that file exists. Any other stage runs
    ``trainer.test`` with the resolved checkpoint.

    Args:
        cfg: Config composed by ``@hydra.main`` from ``configs/hydra``.
    """
    from omegaconf import open_dict

    lightning.seed_everything(cfg.seed)
    torch.set_float32_matmul_precision("medium")
    logger = TensorBoardLogger("logs", name=cfg.expname, version=cfg.version)
    if not hasattr(cfg, "gpu"):
        # Hydra hands over the config in struct mode, where assigning a key
        # that does not exist raises; open_dict lifts that restriction here.
        with open_dict(cfg):
            cfg.gpu = -1
    print("============== gpu config is", cfg.gpu)

    trainer = _build_trainer(cfg, logger)

    data_module = CelebAHQDataModule(cfg)
    if cfg.stage == "test":
        data_module.setup("test")
    else:
        data_module.setup()

    hocelebahq = HoCelebAHQ(cfg, data_module)
    ckpt_path = _resolve_ckpt_path(cfg, logger)

    if cfg.stage == "fit":
        # Start from scratch when there is no checkpoint to resume from.
        trainer.fit(
            hocelebahq,
            datamodule=data_module,
            ckpt_path=ckpt_path if os.path.isfile(ckpt_path) else None,
        )
    else:
        trainer.test(hocelebahq, datamodule=data_module, ckpt_path=ckpt_path)

if __name__ == "__main__":
    # ``main`` is wrapped by ``@hydra.main``, so it is called with no
    # arguments here; the decorator supplies ``cfg``.
    main()
