import json
import logging
import os
from pathlib import Path

import hydra
import torch
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, set_seed
from engine.train_loop import train_loop
from omegaconf import DictConfig, OmegaConf
from utils.checkpoint import load_checkpoint
from utils.data_setup import setup_data
from utils.log_utils import create_logger
from utils.model_setup import setup_model_and_optimizer

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Default config is ./cfgs/diffusion_sit_rohban. To use another one
# (e.g. diffusion_sit_full), pass `--config-name` on the command line instead
# of editing this decorator.
@hydra.main(
    config_path="./cfgs", config_name="diffusion_sit_rohban", version_base=None
)
def main(cfg: DictConfig) -> None:
    """Run diffusion-model training from a Hydra-composed config.

    Steps:
      1. Build an ``Accelerator`` from the config's gradient-accumulation,
         mixed-precision and tracker (``report_to``) settings.
      2. On the main process, create ``<output_dir>/<exp_name>/checkpoints``,
         dump the resolved config to ``<output_dir>/<exp_name>/args.json`` and
         create the experiment logger. All ranks then hit a barrier so the
         directories exist before anyone uses them.
      3. If ``seed`` is set, seed each process with ``seed + process_index``.
      4. Build model / EMA / VAE / encoders / optimizer / loss and the data.
      5. When ``resume_step > 0``, restore model, EMA, optimizer and the
         global step from ``checkpoint_dir``.
      6. Wrap model, optimizer and dataloader with ``accelerator.prepare``
         and hand everything to ``train_loop``.

    Args:
        cfg: Hydra configuration; fields such as ``output_dir``, ``exp_name``,
            ``logging_dir``, ``seed`` and ``resume_step`` are read as
            attributes.
    """
    args = cfg
    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(
        project_dir=args.output_dir, logging_dir=logging_dir
    )
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )
    # Initial "best FID" handed to train_loop; +inf means nothing recorded yet.
    min_recorded_avg_fid = float("inf")

    # Derive the experiment paths the same way on every rank.
    save_dir = os.path.join(args.output_dir, args.exp_name)
    checkpoint_dir = os.path.join(save_dir, "checkpoints")
    if accelerator.is_main_process:
        # makedirs also creates output_dir and save_dir as parents.
        os.makedirs(checkpoint_dir, exist_ok=True)
        args_path = os.path.join(save_dir, "args.json")
        with open(args_path, "w", encoding="utf-8") as f:
            json.dump(OmegaConf.to_container(args, resolve=True), f, indent=4)
        run_logger = create_logger(save_dir)
        run_logger.info(f"Experiment directory created at {save_dir}")
    else:
        run_logger = logger
    # Barrier: other ranks must not touch save_dir / checkpoint_dir before the
    # main process has created them.
    accelerator.wait_for_everyone()

    device = accelerator.device
    if torch.backends.mps.is_available():
        # Disable native AMP when the MPS backend is available.
        accelerator.native_amp = False
    if args.seed is not None:
        # Offset by process index so each rank gets a distinct seed.
        set_seed(args.seed + accelerator.process_index)

    # Model, optimizer, VAE, encoders, etc.
    (
        model,
        ema,
        vae,
        encoders,
        optimizer,
        loss_fn,
        latents_scale,
        latents_bias,
        encoder_types,
        architectures,
    ) = setup_model_and_optimizer(args, device, accelerator)

    # Data
    datamodule, train_dataloader = setup_data(args)

    # Resume from checkpoint if requested.
    global_step = 0
    if args.resume_step > 0:
        model, ema, optimizer, global_step = load_checkpoint(
            model, ema, optimizer, args, checkpoint_dir
        )

    # Wrap for (possibly) distributed execution.
    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )

    train_loop(
        args,
        accelerator,
        model,
        ema,
        vae,
        encoders,
        optimizer,
        loss_fn,
        latents_scale,
        latents_bias,
        encoder_types,
        architectures,
        datamodule,
        train_dataloader,
        checkpoint_dir,
        run_logger,
        device,
        min_recorded_avg_fid,
    )


if __name__ == "__main__":
    # @hydra.main composes the config (plus any command-line overrides) and
    # passes it to main() as `cfg`, so main() is called without arguments here.
    main()
