import argparse
import contextlib
import copy
import os

import yaml
from PIL import Image

from model.interpolator import ScoreInterpolator


def _str2bool(value: str) -> bool:
    """Parse a command-line boolean ('true'/'false', 'yes'/'no', '1'/'0', ...).

    ``argparse`` with ``type=bool`` would call ``bool("False")``, which is ``True``
    for every non-empty string, so an explicit parser is required.
    """
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if normalized in {"false", "f", "no", "n", "0", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_main_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line arguments of the interpolation driver.

    Every optional override (except ``--c``, ``--num_splits`` and ``--split_index``)
    defaults to ``None``, meaning "keep the value from the YAML config".

    Args:
        argv: Argument list to parse; ``None`` (the default) parses ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset", type=str, help="Path to the dataset directory.")
    parser.add_argument("--c", type=str, default="configs/config.yaml")
    parser.add_argument("--noise_level", type=float)
    parser.add_argument("--iters", type=int)
    parser.add_argument("--out_dir", type=str)
    # Strict boolean parsing so that "--out_start_imgs False" really yields False.
    parser.add_argument("--out_start_imgs", type=_str2bool)
    parser.add_argument("--lr", type=float)
    parser.add_argument("--lr_scheduler", type=str)
    parser.add_argument("--grad_weight_type", type=str)
    parser.add_argument("--grad_sample_type", type=str)
    parser.add_argument("--num_splits", type=int, default=1)
    parser.add_argument("--split_index", type=int, default=0)
    return parser.parse_args(argv)


def process_pair(id: str, paths: dict, config: dict, pipe):
    """Run score interpolation for a single image pair.

    Builds a deep copy of ``config`` for this pair, so the caller's config is never
    modified. The copy gets the pair's image paths, a shared prompt for both ends,
    ``test_name`` set to ``id``, and ``tv_args.tv_ckpt_folder`` scoped to a
    subfolder named ``id``. Both images are then opened and
    ``ScoreInterpolator(...).solve()`` is run.

    Args:
        id: Pair identifier; used as ``test_name`` and as the checkpoint subfolder name.
        paths: Mapping with ``"img0_path"``, ``"img1_path"`` and ``"prompt"`` keys.
        config: Base configuration (needs ``tv_args.tv_ckpt_folder`` and
            ``output_args.out_dir``); deep-copied, never mutated.
        pipe: Forwarded as the first positional argument to ``ScoreInterpolator``.

    If either image file does not exist, an error is printed and the pair is skipped.
    Both image files are closed once interpolation finishes or the pair is skipped.
    """
    print(f"--- Processing pair: {id} ---")

    pair_config = copy.deepcopy(config)
    pair_config["test_name"] = id
    pair_config["pathA"] = paths["img0_path"]
    pair_config["pathB"] = paths["img1_path"]
    pair_config["promptA"] = paths["prompt"]
    pair_config["promptB"] = paths["prompt"]

    tv_ckpt_base = pair_config["tv_args"]["tv_ckpt_folder"]
    os.makedirs(tv_ckpt_base, exist_ok=True)
    pair_config["tv_args"]["tv_ckpt_folder"] = os.path.join(tv_ckpt_base, id)

    output_dir = pair_config["output_args"]["out_dir"]
    os.makedirs(output_dir, exist_ok=True)

    # ExitStack closes whichever images were opened, including imgA when opening
    # imgB fails, and keeps them open (PIL loads lazily) until solve() is done.
    with contextlib.ExitStack() as stack:
        try:
            imgA = stack.enter_context(Image.open(pair_config["pathA"]))
            imgB = stack.enter_context(Image.open(pair_config["pathB"]))
        except FileNotFoundError as e:
            print(f"Error: Image not found for pair {id}. {e}")
            return

        bvp_solver = ScoreInterpolator(pipe, imgA=imgA, imgB=imgB, **pair_config)
        print("Start interpolation:", id)
        bvp_solver.solve()
    print(f"Finished interpolation: {id}\n")


def override_config_from_args(config: dict, args: argparse.Namespace) -> dict:
    """Copy every CLI override that was actually supplied into ``config``.

    ``config`` is modified in place and also returned for convenience. An argument
    whose value is ``None`` (i.e. not given on the command line) leaves the config
    untouched. Each override targets either a top-level key or a key inside one of
    the ``opt_args`` / ``output_args`` / ``grad_args`` sections.
    """
    # (attribute on args, config section or None for top level, key in that section)
    override_table = (
        ("noise_level", None, "noise_level"),
        ("iters", "opt_args", "iter_num"),
        ("out_dir", "output_args", "out_dir"),
        ("out_start_imgs", "output_args", "output_start_images"),
        ("lr", "opt_args", "lr"),
        ("lr_scheduler", "opt_args", "lr_scheduler"),
        ("grad_weight_type", "grad_args", "grad_weight_type"),
        ("grad_sample_type", "grad_args", "grad_sample_type"),
    )
    for attr_name, section_name, config_key in override_table:
        new_value = getattr(args, attr_name)
        if new_value is None:
            continue
        destination = config if section_name is None else config[section_name]
        destination[config_key] = new_value
    return config


def get_split_items(all_items: list, num_splits: int, split_index: int) -> list:
    """Return the contiguous slice of ``all_items`` assigned to one split.

    Items are distributed as evenly as possible: every split gets
    ``len(all_items) // num_splits`` items, and the first
    ``len(all_items) % num_splits`` splits get one extra. Split sizes therefore
    differ by at most one, and concatenating all splits in index order
    reproduces ``all_items`` exactly.

    Args:
        all_items: Sequence of work items to partition.
        num_splits: Total number of splits; must be positive.
        split_index: Zero-based index of the split to return, in ``[0, num_splits)``.

    Returns:
        The items for this split. ``[]`` if the arguments are invalid or the split
        is empty (more splits than items).
    """
    if num_splits <= 0:
        print(f"Warning: num_splits must be positive, got {num_splits}. No items to process.")
        return []
    if not 0 <= split_index < num_splits:
        print(
            f"Warning: split_index {split_index} is out of range for {num_splits} splits. No items to process."
        )
        return []

    base_size, remainder = divmod(len(all_items), num_splits)
    # Each earlier split that received an extra item shifts this split's start by one.
    start_index = split_index * base_size + min(split_index, remainder)
    end_index = start_index + base_size + (1 if split_index < remainder else 0)

    if start_index == end_index:
        print(
            f"Split {split_index + 1}/{num_splits} has no items to process ({len(all_items)} items in total)."
        )
        return []

    target_items = all_items[start_index:end_index]
    print(
        f"Processing split {split_index + 1}/{num_splits}: {len(target_items)} items (from index {start_index} to {end_index - 1})."
    )
    return target_items


def create_data_path_dict_from(dataset_config_name: str, data_dir: str = "data") -> dict:
    """Load a dataset description YAML and index its entries by job name.

    Reads ``<data_dir>/<dataset_config_name>.yaml``. The file is expected to hold a
    list of mappings, each with a ``job_name`` key. Its ``img0_path``, ``img1_path``
    and ``prompt`` keys are optional; missing ones become ``None``.

    Args:
        dataset_config_name: File stem of the dataset YAML.
        data_dir: Directory containing the dataset YAML files (default ``"data"``).

    Returns:
        ``{job_name: {"img0_path": ..., "img1_path": ..., "prompt": ...}}`` in file
        order. An empty YAML file yields ``{}``. If a ``job_name`` occurs more than
        once, a warning is printed and the later entry's values win.

    Raises:
        FileNotFoundError: If the YAML file does not exist.
        ValueError: If an entry has no ``job_name``.
    """
    dataset_config_path = os.path.join(data_dir, f"{dataset_config_name}.yaml")

    with open(dataset_config_path, "r", encoding="utf-8") as file:
        dataset_config = yaml.safe_load(file)

    # safe_load returns None for an empty document; treat that as "no items".
    if dataset_config is None:
        return {}

    data_path_dict = {}
    for position, item in enumerate(dataset_config):
        item_id = item.get("job_name")
        if item_id is None:
            # A None key would be used as a pair id / path component downstream.
            raise ValueError(
                f"Entry #{position} in {dataset_config_path} has no 'job_name'."
            )
        if item_id in data_path_dict:
            print(
                f"Warning: duplicate job_name {item_id!r} in {dataset_config_path}; "
                "the later entry overrides the earlier one."
            )
        data_path_dict[item_id] = {
            "img0_path": item.get("img0_path"),
            "img1_path": item.get("img1_path"),
            "prompt": item.get("prompt"),
        }

    return data_path_dict
