"""Utility functions for DPOK training."""

import os
import random
import copy
import torch
import numpy as np

def init_state_dict(weight_dtype):
    """Build the empty experience buffer used to collect training data.

    Every tensor entry starts out empty (zero elements) and lives on the CPU.
    According to the original notes for this buffer, each rollout appends
    num_groups * num_inference_steps rows: most tensors are expected to be
    shaped [num_groups * num_inference_steps, num_samples, ...], while
    ``timestep`` and ``final_reward`` are expected to be shaped
    [num_groups * num_inference_steps].

    Args:
        weight_dtype: torch dtype applied to the floating-point entries.

    Returns:
        A dict holding the empty tensors plus the bookkeeping fields
        ``policy_indices`` (None) and ``policy_index_position`` (0).
    """

    def _empty_float():
        # A fresh tensor per key so that no two entries share storage.
        return torch.FloatTensor().to(weight_dtype).cpu()

    def _empty_long():
        return torch.LongTensor().cpu()

    buffer = {}

    # Entries expected to be [num_groups * num_inference_steps, num_samples, ...]
    for float_key in (
        "state",
        "next_state",
        "unconditional_prompt_embeds",
        "guided_prompt_embeds",
        "log_prob",
    ):
        buffer[float_key] = _empty_float()
    buffer["pool_indices"] = _empty_long()

    # Entries expected to be [num_groups * num_inference_steps]
    buffer["timestep"] = _empty_long()
    buffer["final_reward"] = _empty_float()

    # Bookkeeping for non-repeating sampling (if enabled)
    buffer["policy_indices"] = None  # Sample indices for policy training
    buffer["policy_index_position"] = 0  # Current position in policy_indices
    return buffer

def _update_output_dir(args):
    """Append a run-specific sub-path to `args.output_dir` in place.

    The suffix is assembled from the reward type and filter, the
    initialization mode, the prompt source and the main hyper-parameters.
    The pieces are concatenated verbatim, so the resulting path contains
    doubled slashes (e.g. ``.../img_reward_0//pre_train/...``).

    Args:
        args: argparse.Namespace object. Reads `single_flag`, `single_prompt`
            or (`prompt_path`, `prompt_category`), `learning_rate`, `p_step`,
            `p_batch_size`, `gradient_accumulation_steps`, `lora_rank`,
            `kl_weight`, `reward_weight`, `kl_warmup`, `sft_initialization`,
            `reward_flag` and `reward_filter`; updates `output_dir`.
    """
    # Prompt source: one fixed prompt, or "<dataset dir>_<category>".
    if args.single_flag == 1:
        prompt_tag = args.single_prompt.replace(" ", "_")
        data_segment = f"single_prompt/{prompt_tag}/"
    else:
        # Second-to-last path component of the prompt file path.
        dataset_name = args.prompt_path.split("/")[-2]
        data_segment = f"{dataset_name}_{args.prompt_category}/"

    # Optimisation hyper-parameters.
    learning_segment = (
        f"p_lr{args.learning_rate}_s{args.p_step}"
        f"_b{args.p_batch_size}_g{args.gradient_accumulation_steps}"
        f"_l{args.lora_rank}"
    )

    # Loss coefficients; the warm-up tag only appears when warm-up is enabled.
    coeff_segment = f"_kl{args.kl_weight}_re{args.reward_weight}"
    if args.kl_warmup > 0:
        coeff_segment += f"_klw{args.kl_warmup}"

    init_segment = "/pre_train/" if args.sft_initialization == 0 else "/sft/"
    reward_prefix = "img_reward" if args.reward_flag == 0 else "prev_reward"

    args.output_dir += (
        f"/{reward_prefix}_{args.reward_filter}/"
        f"{init_segment}{data_segment}/{learning_segment}{coeff_segment}"
    )

def _get_batch(data_iter_loader, data_iterator, prompt_list, args, accelerator):
    """Fetch the next prompt batch and repeat each prompt `args.num_samples` times.

    Args:
        data_iter_loader: Iterator over prompt batches.
        data_iterator: Callable `(prompt_list, batch_size=...)` that builds a
            new batch iterator; only used when `data_iter_loader` is exhausted.
        prompt_list: List of prompts handed to `data_iterator`.
        args: Arguments; reads `g_batch_size`, `single_flag`, `single_prompt`
            and `num_samples`.
        accelerator: Accelerator object whose `prepare` wraps the new iterator.

    Returns:
        Flat list of prompts in which every prompt of the batch appears
        `args.num_samples` times consecutively.
    """
    prompts = next(data_iter_loader, None)
    if prompts is None:
        # The loader ran dry: build a fresh one and use its first batch. The
        # fresh iterator itself is not kept or handed back to the caller.
        fresh_loader = accelerator.prepare(
            data_iterator(prompt_list, batch_size=args.g_batch_size)
        )
        prompts = next(iter(fresh_loader))

    if args.single_flag == 1:
        # Single-prompt mode: overwrite every slot of the batch in place.
        for position, _ in enumerate(prompts):
            prompts[position] = args.single_prompt

    return [prompt for prompt in prompts for _ in range(args.num_samples)]

def _trim_buffer(buffer_size, state_dict):
    """Drop the oldest rows so the buffer holds at most `buffer_size` entries.

    Nothing happens unless `state_dict["state"]` has more than `buffer_size`
    rows. Otherwise every buffered sequence is cut down to its last
    `buffer_size` rows and the non-repeating sampling bookkeeping is reset,
    because previously drawn indices no longer line up with the trimmed data.

    Args:
        buffer_size: Maximum number of rows to keep (expected non-negative).
            A value of 0 empties the buffer.
        state_dict: Dictionary containing collected experiences; modified in
            place.
    """
    if state_dict["state"].shape[0] <= buffer_size:
        return

    trimmed_keys = (
        # Entries with a per-sample dimension.
        "state",
        "next_state",
        "unconditional_prompt_embeds",
        "guided_prompt_embeds",
        "log_prob",
        "pool_indices",
        # Entries with one value per row.
        "timestep",
        "final_reward",
    )
    for key in trimmed_keys:
        entries = state_dict[key]
        # Slice from an explicit start offset. `entries[-buffer_size:]` would
        # keep *everything* when buffer_size is 0, because -0 == 0.
        start = max(len(entries) - buffer_size, 0)
        state_dict[key] = entries[start:]

    # Reset policy indices if they exist
    if state_dict["policy_indices"] is not None:
        state_dict["policy_indices"] = None
        state_dict["policy_index_position"] = 0

def _save_model(args, count, is_ddp, accelerator, policy_model, value_function=None, model_output_dir=None):
    """Write a checkpoint directory `save_<count>` for the policy (and value) model.

    The directory receives `policy_model.pt`, optionally `value_function.pt`,
    and a small `metadata.pt` dictionary.

    Args:
        args: Arguments; `output_dir` is used when `model_output_dir` is None,
            and `num_inference_steps` / `lora_rank` are recorded in the
            metadata (falling back to 50 and 4 when absent).
        count: Current training step; becomes part of the directory name.
        is_ddp: Whether the policy model is wrapped and must be unwrapped
            through `accelerator.unwrap_model` before saving.
        accelerator: Accelerator object used for unwrapping.
        policy_model: Policy model to save (AttnProcsLayers).
        value_function: Optional value function to save.
        model_output_dir: Optional directory that replaces `args.output_dir`
            as the parent of the checkpoint directory.
    """
    parent_dir = args.output_dir if model_output_dir is None else model_output_dir
    save_path = os.path.join(parent_dir, f"save_{count}")
    print(f"Saving model to {save_path}")
    os.makedirs(save_path, exist_ok=True)

    # Policy weights (unwrapped first when running under DDP).
    policy_to_save = accelerator.unwrap_model(policy_model) if is_ddp else policy_model
    torch.save(policy_to_save.state_dict(), os.path.join(save_path, "policy_model.pt"))

    # Value function weights; unlike the policy, wrapping is detected by type.
    if value_function is not None:
        value_path = os.path.join(save_path, "value_function.pt")
        is_wrapped = isinstance(value_function, torch.nn.parallel.DistributedDataParallel)
        value_to_save = accelerator.unwrap_model(value_function) if is_wrapped else value_function
        torch.save(value_to_save.state_dict(), value_path)

    # Minimal metadata describing the checkpoint.
    torch.save(
        {
            "train_iteration": count,
            "model_type": "dpok_policy",
            "num_inference_steps": getattr(args, 'num_inference_steps', 50),
            "lora_rank": getattr(args, 'lora_rank', 4),
        },
        os.path.join(save_path, "metadata.pt"),
    )

def get_random_indices(num_indices, sample_size):
    """Draw `sample_size` distinct indices without replacement.

    Sampling goes through NumPy's global random state (`np.random`), so the
    result is reproducible after `np.random.seed(...)`.

    Args:
        num_indices (int): The total number of indices to choose from; the
            candidates are 0 .. num_indices - 1.
        sample_size (int): The number of indices to choose.

    Returns:
        A numpy array of `sample_size` randomly chosen, non-repeating indices.
    """
    return np.random.choice(a=num_indices, size=sample_size, replace=False)

def create_data_iterator(data, batch_size):
    """Yield shuffled, consecutive slices of `data`.

    This is a generator: `data` is shuffled in place (the caller's list is
    reordered) only once iteration starts. The final batch may hold fewer
    than `batch_size` items.

    Args:
        data: List of data items; shuffled in place.
        batch_size: Batch size.

    Returns:
        Iterator over batches.
    """
    # Randomise the order before slicing into batches.
    random.shuffle(data)

    yield from (
        data[offset : offset + batch_size]
        for offset in range(0, len(data), batch_size)
    )

def get_test_prompts(flag="default"):
    """Return the fixed list of evaluation prompts for a named prompt suite.

    Args:
        flag: Name of the suite: "drawbench", "partiprompt", "coco" or
            "paintskill". Any other value selects the default prompts.

    Returns:
        A new list of eight prompt strings.
    """
    suites = {
        "drawbench": [
            "A pink colored giraffe.",
            (
                "An emoji of a baby panda wearing a red hat, green gloves, red"
                " shirt, and green pants."
            ),
            "A blue bird and a brown bear.",
            "A yellow book and a red vase.",
            "Three dogs on the street.",
            "Two cats and one dog sitting on the grass.",
            "A wine glass on top of a dog.",
            "A cube made of denim. A cube with the texture of denim.",
        ],
        "partiprompt": [
            "a panda bear with aviator glasses on its head",
            "Times Square during the day",
            "the skyline of New York City",
            "square red apples on a tree with circular green leaves",
            "a map of Italy",
            "a sketch of a horse",
            "the word 'START' on a blue t-shirt",
            "a dolphin in an astronaut suit on saturn",
        ],
        "coco": [
            "A Christmas tree with lights and teddy bear",
            "A group of planes near a large wall of windows.",
            "three men riding horses through a grassy field",
            "A man and a woman posing in front of a motorcycle.",
            "A man sitting on a motorcycle smoking a cigarette.",
            "A pear, orange, and two bananas in a wooden bowl.",
            "Some people posting in front of a camera for a picture.",
            "Some very big furry brown bears in a big grass field.",
        ],
        "paintskill": [
            "a photo of blue bear",
            "a photo of blue fire hydrant",
            "a photo of bike and skateboard; skateboard is left to bike",
            "a photo of bed and human; human is right to bed",
            "a photo of suitcase and bench; bench is left to suitcase",
            "a photo of bed and stop sign; stop sign is above bed",
            (
                "a photo of dining table and traffic light; traffic light is below"
                " dining table"
            ),
            "a photo of bear and bus; bus is above bear",
        ],
    }

    # Prompts used for any flag that does not name one of the suites above.
    fallback = [
        "A dog and a cat.",
        "A cat and a dog.",
        "Two dogs in the park.",
        "Three dogs in the park.",
        "Four dogs in the park.",
        "A blue colored rabbit.",
        "A red colored rabbit.",
        "A green colored rabbit.",
    ]

    return suites.get(flag, fallback)

def safe_distributed_all_reduce(tensor, device):
    """Average a tensor across all processes, falling back to the local value.

    When `torch.distributed` is available and initialized, the tensor is
    all-reduced and divided by the world size. Note that `dist.all_reduce`
    works in place, so if `tensor` already lives on `device` the caller's
    tensor is modified as well. When distributed is not in use, or the
    reduction fails, the (possibly moved) input tensor is returned unchanged.

    Args:
        tensor: Tensor to reduce.
        device: Device to perform reduction on.

    Returns:
        The averaged tensor, or the local tensor when no reduction happened.
    """
    import warnings

    if tensor.device != device:
        tensor = tensor.to(device)

    try:
        import torch.distributed as dist
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(tensor)
            tensor = tensor / dist.get_world_size()
    except Exception as exc:
        # Best-effort fallback: keep the local value, but say so instead of
        # hiding the failure. Catching `Exception` (not a bare `except`) lets
        # KeyboardInterrupt and SystemExit propagate.
        warnings.warn(
            f"Distributed all_reduce failed; using the local value instead: {exc}",
            RuntimeWarning,
        )

    return tensor

def gather_and_log_rewards(accelerator, rewards, cfg_scale, output_dir):
    """Write this process's rewards to `<output_dir>/pernode/<cfg_scale>`.

    The file name is `cfg_scale` with commas replaced by underscores, and the
    content is the rewards array in NumPy's `.npy` format (no extension is
    added). Combining results on the main process is not implemented yet.

    Args:
        accelerator: Accelerator object; `process_index` and
            `is_main_process` are read.
        rewards: Rewards tensor.
        cfg_scale: CFG scale value as a string.
        output_dir: Output directory.
    """
    os.makedirs(output_dir, exist_ok=True)

    # NOTE(review): the process index is read but never used. The file name
    # below does not include it, so every process targets the same path;
    # confirm whether per-rank files were intended.
    process_index = accelerator.process_index

    pernode_dir = os.path.join(output_dir, "pernode")
    os.makedirs(pernode_dir, exist_ok=True)

    # Commas in the scale string are replaced to form the file name.
    file_name = f"{cfg_scale.replace(',', '_')}"
    rewards_path = os.path.join(pernode_dir, file_name)

    # Convert before opening the file so a failed conversion leaves no file.
    rewards_np = rewards.cpu().numpy()
    with open(rewards_path, 'wb') as sink:
        np.save(sink, rewards_np)

    if not accelerator.is_main_process:
        return
    # Main process: logic to combine results would go here.

def load_policy_model(args, saved_model_path, policy_model, weight_dtype, device):
    """Load policy model from a saved checkpoint.

    Two on-disk layouts are handled: a single `policy_model.pt` state dict
    (new format), or individual `<name>.safetensors` files (old format).

    Args:
        args: Arguments (not used by this function).
        saved_model_path: Path to the checkpoint directory. The step count is
            parsed from the last `save_<number>` occurrence in this path.
        policy_model: Policy model to load weights into (AttnProcsLayers).
        weight_dtype: Data type for model weights.
        device: Device to load model on.

    Returns:
        Tuple of (loaded policy model, start count). The start count is 0 when
        the path contains no `save_<number>` component.
    """
    import re

    # Use the LAST match so that a parent directory such as ".../save_3/..."
    # cannot override the step of the checkpoint directory itself.
    step_matches = re.findall(r'save_(\d+)', saved_model_path)
    start_count = int(step_matches[-1]) if step_matches else 0

    # Check if we have the new format (policy_model.pt) or old format (individual .safetensors files)
    policy_path = os.path.join(saved_model_path, "policy_model.pt")

    if os.path.exists(policy_path):
        # New format: load from policy_model.pt
        print(f"Loading policy model from {policy_path}")
        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        state_dict = torch.load(policy_path, map_location=device)
        policy_model.load_state_dict(state_dict)
    else:
        # Old format: load from individual .safetensors files
        print(f"Loading policy model from LoRA safetensors in {saved_model_path}")

        # NOTE(review): this branch relies on `policy_model` exposing
        # `attn_processors` and `load_attn_procs`, and passes a mapping of
        # processor name -> file path; confirm the model type actually
        # supports this.
        lora_state_dict = {}
        for name in policy_model.attn_processors.keys():
            filename = os.path.join(saved_model_path, f"{name}.safetensors")
            if os.path.exists(filename):
                lora_state_dict[name] = filename

        if lora_state_dict:
            # Load the LoRA attention processors
            policy_model.load_attn_procs(lora_state_dict)
        else:
            print(f"Warning: No LoRA files found in {saved_model_path}")

    # Move to device and cast to dtype
    policy_model.to(device, dtype=weight_dtype)

    print(f"Successfully loaded policy model from {saved_model_path}, starting from count {start_count}")

    return policy_model, start_count

def create_symlinks(main_dir, shadow_dir):
    """
    Mirror the directory layout of shadow_dir inside main_dir, placing a
    symlink in main_dir for every file found under shadow_dir.

    Directories are created for real; only files become symlinks, and each
    link points at the absolute path of its source file. Anything that
    already exists at a link location (including a dangling symlink) is left
    untouched.

    Args:
        main_dir: Target directory where the directory tree and symlinks will be created
        shadow_dir: Source directory whose structure will be mirrored
    """
    os.makedirs(main_dir, exist_ok=True)

    for current_dir, _subdirs, filenames in os.walk(shadow_dir):
        # Position of the visited directory relative to the mirrored root.
        relative = os.path.relpath(current_dir, shadow_dir)
        mirror_dir = main_dir if relative == '.' else os.path.join(main_dir, relative)
        os.makedirs(mirror_dir, exist_ok=True)

        for filename in filenames:
            source_path = os.path.join(current_dir, filename)
            link_path = os.path.join(mirror_dir, filename)

            # `islink` also catches dangling symlinks, which `exists` misses.
            already_present = os.path.exists(link_path) or os.path.islink(link_path)
            if not already_present:
                # Absolute target so the link resolves from any working directory.
                os.symlink(os.path.abspath(source_path), link_path)
