import os
# Read the profiling switch once at import time. main() passes it through bool()
# to enable the ProfilerTimer, so any non-empty value turns profiling on.
DISTRL_DEBUG_PROFILING = os.environ.get('DISTRL_DEBUG_PROFILING', None)
# Exported before the distrl imports further down are executed.
# NOTE(review): presumably distrl modules read DISTRL_RL at import time, which
# would be why this sits above the other imports -- confirm before moving it.
os.environ['DISTRL_RL'] = "1"

import json

from typing import Callable
import numpy as np
from datetime import timedelta

import torch
from torch.nn.parallel import DistributedDataParallel
from tqdm import tqdm

from datasets import load_dataset

from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, set_seed, InitProcessGroupKwargs


# distrl.autofid imports (argument parsing, utilities, policy/rollout, image pool)
from distrl.autofid.parse_args import parse_args
from distrl.autofid.utils import (
    _update_output_dir, _trim_buffer, _save_model, init_state_dict,
    safe_distributed_all_reduce, create_symlinks
)
from distrl.autofid.models.policy import TrainPolicyFuncData, _train_policy_func, _train_policy_func_eachimg, _train_policy_func_flat
from distrl.autofid.models.rollout import _collect_rollout, prepare_policy_samples, _collect_rollout_flat
from distrl.autofid.image_pool import ImagePool

# Model initialization imports
from distrl.autofid.reject_sampling.edm2 import init_edm2_model
from distrl.autofid.reject_sampling.sit import init_sit_model
from distrl.profiling import ProfilerTimer

from distrl.autofid.reject_sampling.rollout import _get_global_best_samples, _get_local_best_samples
from distrl.autofid.models.filter import _get_local_best_and_worst_samples


def get_policy_train_func(args):
    """Pick the policy-update routine that matches the run configuration.

    Args:
        args: Parsed arguments; ``flat_rollout`` is checked first and
            ``p_loss_for_each_img`` only when flat rollouts are disabled.

    Returns:
        ``_train_policy_func_flat`` for flat rollouts, otherwise
        ``_train_policy_func_eachimg`` when the per-image loss is requested,
        otherwise the default ``_train_policy_func``.
    """
    # Flat rollouts take precedence over the per-image loss variant.
    if args.flat_rollout:
        return _train_policy_func_flat
    return _train_policy_func_eachimg if args.p_loss_for_each_img else _train_policy_func

def get_rollout_func(args):
    """Pick the rollout-collection routine that matches the run configuration.

    Args:
        args: Parsed arguments; only ``flat_rollout`` is consulted.

    Returns:
        ``_collect_rollout_flat`` when ``args.flat_rollout`` is truthy,
        otherwise ``_collect_rollout``.
    """
    return _collect_rollout_flat if args.flat_rollout else _collect_rollout

def _load_prompt_list(args):
    """Return the prompts to train on, read from ``args.prompt_path``.

    The JSON file maps each prompt to a dict; when ``args.prompt_category`` is
    not ``"all"`` it is split on commas and only prompts whose ``"category"``
    entry is in that list are kept (file order is preserved).
    """
    with open(args.prompt_path, encoding="utf-8") as json_file:
        prompt_dict = json.load(json_file)

    if args.prompt_category == "all":
        return list(prompt_dict)

    wanted_categories = args.prompt_category.split(",")
    return [
        prompt
        for prompt, meta in prompt_dict.items()
        if meta["category"] in wanted_categories
    ]


def _find_latest_checkpoint(checkpoint_root):
    """Return the name of the highest-numbered ``checkpoint-N`` entry, or None."""
    candidates = [d for d in os.listdir(checkpoint_root) if d.startswith("checkpoint")]
    candidates = sorted(candidates, key=lambda name: int(name.split("-")[1]))
    return candidates[-1] if candidates else None


def main():
    """Main training function.

    Parses the arguments, builds the accelerator and the model, fills the
    image pool, then alternates rollout collection, sample selection and
    policy updates for ``args.max_train_steps // args.p_step`` iterations,
    saving checkpoints and models along the way.

    Raises:
        NotImplementedError: If ``args.model_type`` is neither ``"edm2"`` nor
            ``"sit"``, or if ``args.dataset_name`` is set (prompts can only be
            loaded from ``args.prompt_path``).
    """
    # Local import so this function brings its own dependency into scope.
    from contextlib import nullcontext

    args = parse_args()

    # Update output directory based on arguments
    _update_output_dir(args)
    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    # Initialize accelerator
    accelerator_project_config = ProjectConfiguration(
        logging_dir=logging_dir, total_limit=args.checkpoints_total_limit
    )
    accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(hours=1))
    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        project_config=accelerator_project_config,
        kwargs_handlers=[accelerator_kwargs]
    )

    # Initialize profiler after accelerator is available
    profiler = ProfilerTimer(
        enabled=bool(DISTRL_DEBUG_PROFILING),
        is_main_process=accelerator.is_local_main_process
    )

    # Collect environment variables with prefix DISTRL_.
    # NOTE(review): the collected dict is not consumed anywhere below.
    env_var_prefix = "DISTRL_"
    debug_env_vars = {
        key: value
        for key, value in os.environ.items()
        if key.startswith(env_var_prefix)
    }

    # Set random seed if provided
    if args.seed is not None:
        set_seed(args.seed, device_specific=True)

    # Create output directory
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

    # Setup mixed precision data type
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # * Initialize the model based on model_type
    if args.model_type.lower() == "edm2":
        # NOTE(review): this branch never binds `unet_copy`, yet the policy
        # update below passes it unconditionally, so an edm2 run fails with
        # UnboundLocalError at the first policy step. The right value is not
        # visible from this file -- confirm against init_edm2_model.
        pipe, unet, vae, text_encoder, lora_layers, optimizer_policy, lr_scheduler_policy, start_count = init_edm2_model(
            args, accelerator, weight_dtype
        )
    elif args.model_type.lower() == "sit":
        pipe, unet, unet_copy, vae, text_encoder, lora_layers, optimizer_policy, lr_scheduler_policy, start_count = init_sit_model(
            args, accelerator, weight_dtype, unet_copy=True
        )
    else:
        raise NotImplementedError

    # Scale learning rate if specified
    # NOTE(review): the optimizer is already returned by the init call above,
    # so this presumably only updates the value stored on args -- confirm.
    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate
            * args.gradient_accumulation_steps
            * args.train_batch_size
            * accelerator.num_processes
        )

    # Load the prompt list. The dataset branch never produced a prompt list,
    # so the run was guaranteed to crash later at pool filling; fail fast
    # with an explicit message instead.
    if args.dataset_name is not None:
        raise NotImplementedError(
            "Loading prompts from a hub dataset (--dataset_name) is not supported; "
            "provide a prompt JSON file instead."
        )
    prompt_list = _load_prompt_list(args)

    # Total batch size across processes and accumulation steps.
    # NOTE(review): not consumed anywhere below.
    total_batch_size = (
        args.train_batch_size
        * accelerator.num_processes
        * args.gradient_accumulation_steps
    )

    # Resume from checkpoint if specified
    global_step = 0
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            checkpoint_path = args.resume_from_checkpoint
            # normpath drops a trailing separator so the basename is never "".
            path = os.path.basename(os.path.normpath(checkpoint_path))
        else:
            # Get the most recent checkpoint
            path = _find_latest_checkpoint(args.output_dir)
            checkpoint_path = None if path is None else os.path.join(args.output_dir, path)

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting"
                " a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            # Load the resolved directory; previously "latest" was passed
            # through verbatim, which is not a checkpoint path.
            accelerator.load_state(checkpoint_path)
            global_step = int(path.split("-")[1])

    # Move pipe to CUDA explicitly
    try:
        pipe.to("cuda")
    except Exception as e:
        print(f"Error moving pipe to cuda: {e}")

    # Set up progress bar
    progress_bar = tqdm(
        range(global_step, args.max_train_steps),
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    # Initialize state dictionary for storing experiences
    state_dict = init_state_dict(weight_dtype)

    # Training loop preparation
    count = 0
    buffer_size = args.buffer_size
    policy_steps = args.gradient_accumulation_steps * args.p_step
    is_ddp = isinstance(unet, DistributedDataParallel)
    accelerator.print("model is parallel:", is_ddp)

    # Create and fill image pool
    profiler.start("Image Pool Creation & Filling")
    pool_output_dir = args.image_pool_dir if args.image_pool_dir is not None else os.path.join(args.output_dir, "image_pool")
    image_pool = ImagePool.from_args(args, accelerator, pool_output_dir)

    # Fill the pool with balanced images
    with accelerator.autocast():
        image_pool.fill_the_pool(
            pipe=pipe,
            prompts=prompt_list,
            num_inference_steps=args.num_inference_steps,
            guidance_scale=args.guidance_scale,
            seed=args.seed,
        )
    profiler.end("Image Pool Creation & Filling")

    # Load GT statistics
    image_pool.load_gt_fid_stats(args.gt_fid_stats)

    # Compute pool FID statistics (and the FID score when GT stats exist)
    if accelerator.is_main_process:
        profiler.start("FID Statistics Computation")
        try:
            # Compute pool statistics
            mu, sigma = image_pool.compute_fid_statistics()

            # Save pool statistics for later use
            fid_stats = {"mu": mu, "sigma": sigma}
            np.savez(os.path.join(args.output_dir, "fid_statistics.npz"), **fid_stats)

            # Compute FID score if GT stats are available
            if image_pool.gt_mu is not None:
                _ = image_pool.compute_fid()

        except Exception as exc:
            # Still best effort (training continues), but no longer silent.
            print(f"Warning: FID statistics computation failed: {exc}")
        profiler.end("FID Statistics Computation")

    model_output_dir = args.model_output_dir if args.model_output_dir is not None else os.path.join(args.output_dir)
    os.makedirs(model_output_dir, exist_ok=True)

    # Select the appropriate policy training function
    train_policy_func: Callable = get_policy_train_func(args)
    collect_rollout_func: Callable = get_rollout_func(args)

    # Loop invariants, hoisted once
    total_iterations = args.max_train_steps // args.p_step
    accum_steps = int(args.gradient_accumulation_steps)
    loop_label = f"Complete Training Loop ({total_iterations - start_count} steps)"

    # Main Training Loop
    profiler.start(loop_label)
    for count in range(start_count, total_iterations):
        if accelerator.is_main_process:
            print(f"[START] Step {count} / {total_iterations}")
        # Set model to evaluation mode for rollout collection
        if is_ddp:
            if hasattr(unet.module, 'eval'):
                unet.module.eval()
        else:
            try:
                unet.eval()
            except (AttributeError, TypeError):
                print("Warning: Could not set unet to eval mode. Continuing without it.")

        # 0. refill the pool
        if count % args.refill_interval == 0 and count > start_count:
            profiler.start(f"Pool Refill (step {count})")
            accelerator.print(f"[REFILL] Refilling the pool at step {count}")
            image_pool._reset_pool()
            accelerator.wait_for_everyone()
            # The seed is optional (see set_seed above); only offset it when set.
            refill_seed = None if args.seed is None else args.seed + count
            with accelerator.autocast():
                image_pool.fill_the_pool(
                    pipe=pipe,
                    prompts=prompt_list,
                    num_inference_steps=args.num_inference_steps,
                    guidance_scale=args.guidance_scale,
                    seed=refill_seed,
                    force_fill=True,
                )
            if image_pool.gt_mu is not None and accelerator.is_main_process:
                fid = image_pool.compute_fid()
                accelerator.print(f"FID score after refill: {fid:.2f}")
            profiler.end(f"Pool Refill (step {count})")

        # Fresh experience buffer for this iteration
        state_dict = init_state_dict(weight_dtype)

        # 1. Collect rollouts
        profiler.start(f"Rollout Collection (step {count})")
        with accelerator.autocast():
            rollout_results = collect_rollout_func(
                args=args,
                pipe=pipe,
                policy_model=unet,
                is_ddp=is_ddp,
                image_pool=image_pool,
                state_dict=state_dict,
                accelerator=accelerator,
                count=count,
            )
        profiler.end(f"Rollout Collection (step {count})")

        # Extract FID values from rollout results
        local_fid_list = rollout_results["local_fids"] if args.global_flag == 0 else rollout_results["fid"]
        if args.global_flag > 0:
            # Sign is flipped for the global modes before sorting.
            local_fid_list = [-x for x in local_fid_list]
        sorted_fids = sorted(local_fid_list)
        if args.global_flag == 2:
            selected_fids_to_print = sorted_fids[:args.num_best_samples] + sorted_fids[-args.num_best_samples:]
        elif args.global_flag == -1:
            selected_fids_to_print = sorted_fids
        else:
            selected_fids_to_print = sorted_fids[:args.num_best_samples]

        if args.global_flag in [0, 2]:
            # 2. Get locally best samples based on rewards (each process selects local best)
            profiler.start(f"Local Best Sample Selection (step {count})")
            if args.global_flag == 0:
                selected_samples = _get_local_best_samples(args, state_dict, accelerator, count, is_ddp)
            elif args.global_flag == 2:
                selected_samples = _get_local_best_and_worst_samples(args, state_dict, accelerator, count, is_ddp)
            profiler.end(f"Local Best Sample Selection (step {count})")

            selected_local_fids = selected_fids_to_print

            profiler.start(f"Sync Selected FIDs to All Processes (step {count})")
            # Sync selected FIDs to all processes for logging
            if len(selected_local_fids) > 0:
                selected_fid_tensor = torch.tensor(selected_local_fids, device=accelerator.device)

                if is_ddp and accelerator.num_processes > 1:
                    # Gather selected FID values from all processes
                    gathered_selected_fid = accelerator.gather(selected_fid_tensor)
                    if isinstance(gathered_selected_fid, torch.Tensor):
                        all_selected_fid_values = gathered_selected_fid.cpu().tolist()
                    else:
                        all_selected_fid_values = selected_local_fids
                else:
                    all_selected_fid_values = selected_local_fids
            else:
                all_selected_fid_values = []
            profiler.end(f"Sync Selected FIDs to All Processes (step {count})")
            # log all selected fids from all processes
            selected_fids_to_print = all_selected_fid_values
        elif args.global_flag == 1:
            # 2. Get globally best samples based on rewards
            profiler.start(f"Global Best Sample Selection (step {count})")
            selected_samples = _get_global_best_samples(args, state_dict, accelerator, count, is_ddp)
            profiler.end(f"Global Best Sample Selection (step {count})")

        # Shared info printing
        if accelerator.is_main_process:
            accelerator.print(
                f"[{count}/{total_iterations}] Selected FIDs: {selected_fids_to_print}"
            )
            if args.global_flag == 1:
                accelerator.print(f"global_fid_list: {sorted_fids}")

        _trim_buffer(buffer_size, state_dict)

        # 3. Train policy function (Actor)
        tpfdata = TrainPolicyFuncData()

        # Prepare non-repeating samples if specified
        if args.use_non_repeating_samples:
            prepare_policy_samples(state_dict, args)

        for p_step in range(args.p_step):
            profiler.start(f"Policy Training (step {count}, p_step {p_step})")
            optimizer_policy.zero_grad()
            for accum_step in range(accum_steps):
                # Skip gradient synchronisation on every micro-step except the
                # last one of the accumulation window.
                is_last_micro_step = accum_step >= accum_steps - 1
                sync_context = nullcontext() if is_last_micro_step else accelerator.no_sync(unet)
                with sync_context, accelerator.autocast():
                    train_policy_func(
                        args,
                        state_dict,
                        pipe,
                        unet_copy,
                        is_ddp,
                        count,
                        policy_steps,
                        accelerator,
                        tpfdata,
                    )

            # Gradient clipping
            norm = None
            if accelerator.sync_gradients:
                norm = accelerator.clip_grad_norm_(unet.parameters(), args.clip_norm)
            if norm is not None:
                tpfdata.tot_grad_norm += norm.item() / args.p_step

            # Update model parameters
            optimizer_policy.step()
            lr_scheduler_policy.step()
            profiler.end(f"Policy Training (step {count}, p_step {p_step})")

            # Compute mean reward with all_reduce to sync across processes
            reward_mean = torch.mean(state_dict["final_reward"])
            reward_mean = safe_distributed_all_reduce(reward_mean, accelerator.device)

            # Simple prints similar to previous logging
            if accelerator.is_main_process:
                print(f"count: [{count} / {total_iterations}]")
                print(f"train_reward: {reward_mean.item()}")
                print(f"grad_norm: {tpfdata.tot_grad_norm}, ratio: {tpfdata.tot_ratio}")
                print(f"kl: {tpfdata.tot_kl}, p_loss: {tpfdata.tot_p_loss}")
                print(f"adv_mean: {tpfdata.tot_adv_mean}, adv_std: {tpfdata.tot_adv_std}")

                # Clear lists after logging to avoid memory issues
                tpfdata.unclipped_ratio_values.clear()
                tpfdata.ratio_values.clear()
                tpfdata.adv_values.clear()

            # Clear GPU memory
            torch.cuda.empty_cache()

        # Save checkpoint periodically
        if accelerator.sync_gradients:
            global_step += 1
            if global_step % args.checkpointing_steps == 0:
                profiler.start(f"Checkpoint Save (step {count})")
                save_path = os.path.join(model_output_dir, f"checkpoint-{global_step}")
                accelerator.wait_for_everyone()
                if accelerator.is_main_process:
                    accelerator.save_state(output_dir=save_path)
                if accelerator.is_main_process and args.model_output_dir is not None:
                    create_symlinks(args.output_dir, model_output_dir)
                profiler.end(f"Checkpoint Save (step {count})")
            accelerator.print(">>>>>> global_step", global_step)

        # Save model periodically
        if count % args.save_interval == 0:
            profiler.start(f"Model Save (step {count})")
            accelerator.wait_for_everyone()
            if accelerator.is_main_process:
                _save_model(args, count, is_ddp, accelerator, unet, model_output_dir=model_output_dir)
                if args.model_output_dir is not None:
                    create_symlinks(args.output_dir, model_output_dir)
            profiler.end(f"Model Save (step {count})")

    profiler.end(loop_label)

    # Finish training, save final model
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        # Save final model
        _save_model(args, count, is_ddp, accelerator, unet, model_output_dir=model_output_dir)
        if args.model_output_dir is not None:
            create_symlinks(args.output_dir, model_output_dir)

    # End training with accelerator
    accelerator.end_training()


# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
