import re

from PIL import Image
import tqdm
import torch
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
import accelerate

from src import utils

# Per-channel mean and standard deviation handed to T.Normalize in
# build_transform. Three values each; presumably in RGB order, since
# build_transform converts every image to RGB before normalizing.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    """Build the preprocessing pipeline applied to each image tile.

    The returned transform, in order: converts any non-RGB image to RGB,
    resizes it to ``input_size`` x ``input_size`` with bicubic interpolation,
    turns it into a tensor, and normalizes it with ``IMAGENET_MEAN`` /
    ``IMAGENET_STD``.

    Args:
        input_size: Edge length (in pixels) of the square output.

    Returns:
        A ``torchvision.transforms.Compose`` instance.
    """
    def _ensure_rgb(img):
        # Leave images that are already RGB untouched.
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the candidate grid whose width/height ratio best matches the image.

    Candidates are visited in the order given. A candidate replaces the
    current choice when it is strictly closer to ``aspect_ratio``. When it is
    exactly as close as the current best, it replaces the choice only if the
    image area exceeds half of the pixel area the candidate grid would cover
    (``0.5 * image_size**2 * ratio[0] * ratio[1]``).

    Args:
        aspect_ratio: Width divided by height of the source image.
        target_ratios: Iterable of ``(ratio[0], ratio[1])`` pairs to choose from.
        width: Source image width.
        height: Source image height.
        image_size: Edge length of one square tile.

    Returns:
        The selected pair, or ``(1, 1)`` when ``target_ratios`` is empty.
    """
    image_area = width * height
    half_tile_area = 0.5 * image_size * image_size
    chosen = (1, 1)
    smallest_gap = float('inf')

    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
        elif gap == smallest_gap and image_area > half_tile_area * candidate[0] * candidate[1]:
            # Equally good match: prefer this grid only for large enough images.
            chosen = candidate

    return chosen

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Split ``image`` into a grid of square ``image_size`` tiles.

    Every ``(i, j)`` grid whose tile count ``i * j`` lies in
    ``[min_num, max_num]`` is a candidate; ``find_closest_aspect_ratio``
    chooses the one to use based on the image's width/height ratio. The image
    is resized to ``(image_size * i, image_size * j)`` and cropped into tiles
    row by row.

    Args:
        image: Object exposing ``size``, ``resize`` and ``crop`` (a PIL image
            in this file).
        min_num: Minimum number of tiles in a candidate grid.
        max_num: Maximum number of tiles in a candidate grid.
        image_size: Edge length of each square tile.
        use_thumbnail: When true and more than one tile was produced, append
            the whole image resized to ``image_size`` x ``image_size``.

    Returns:
        List of tile images in row-major order, optionally followed by the
        thumbnail.
    """
    src_w, src_h = image.size
    src_ratio = src_w / src_h

    # Enumerate all candidate grids, then order them by tile count.
    candidate_grids = {
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if min_num <= cols * rows <= max_num
    }
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    # Choose the grid and derive the size the image must be resized to.
    grid = find_closest_aspect_ratio(
        src_ratio, candidate_grids, src_w, src_h, image_size)
    grid_w = image_size * grid[0]
    grid_h = image_size * grid[1]
    tile_count = grid[0] * grid[1]

    canvas = image.resize((grid_w, grid_h))
    tiles_per_row = grid_w // image_size

    tiles = []
    for index in range(tile_count):
        row, col = divmod(index, tiles_per_row)
        # (left, upper, right, lower) corners of this tile on the canvas.
        tile_box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )
        tiles.append(canvas.crop(tile_box))
    assert len(tiles) == tile_count

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles

def load_image(image_file, input_size=448, max_num=12):
    """Load an image and return its tiles as one stacked tensor.

    Args:
        image_file: Either a ``PIL.Image.Image`` (copied, so the caller's
            object is not touched) or anything ``Image.open`` accepts (opened
            and converted to RGB).
        input_size: Edge length of each square tile.
        max_num: Maximum number of tiles passed to ``dynamic_preprocess``.

    Returns:
        ``torch.stack`` of the transformed tiles; ``dynamic_preprocess`` is
        called with ``use_thumbnail=True``.
    """
    if isinstance(image_file, Image.Image):
        source = image_file.copy()
    else:
        source = Image.open(image_file).convert('RGB')

    to_tensor = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(
        source, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([to_tensor(tile) for tile in tiles])


def _extract_option_key(answer_template):
    """Return the name inside the first ``{...}`` placeholder of ``answer_template``."""
    match = re.search(r'\{(.*?)\}', answer_template)
    if match is None:
        raise ValueError(
            f"Answer template {answer_template!r} contains no '{{...}}' placeholder."
        )
    return match.group(1)


def _raise_first_token_conflict(options_list, first_token_of_options):
    """Raise a ValueError listing the options that share a first token."""
    token_to_options = {}
    for option, token in zip(options_list, first_token_of_options):
        token_to_options.setdefault(token, []).append(option)

    # str() each option so non-string options cannot break the join.
    conflict_details = "\n".join(
        f"  - Token {token} is shared by options: {', '.join(map(str, options))}"
        for token, options in token_to_options.items()
        if len(options) > 1
    )
    raise ValueError(
        "First tokens of the options are not unique. Fast ranking requires uniqueness.\n"
        "Conflicts found:\n" + conflict_details
    )


def _shard_mapping(mapping, num_processes, process_index):
    """Return this process's equally sized slice of ``mapping``.

    The list is padded by repeating its own entries until its length is a
    non-zero multiple of ``num_processes``. Padding is done on a copy so the
    caller's list is never modified. An empty ``mapping`` yields an empty slice.
    """
    padded = list(mapping)
    if not padded:
        return padded

    # Make sure every process gets at least one sample.
    if len(padded) < num_processes:
        repeat_factor = (num_processes // len(padded)) + 1
        padded.extend(padded * repeat_factor)

    # Top up to the next multiple of num_processes.
    remainder = len(padded) % num_processes
    if remainder:
        padded.extend(padded[:num_processes - remainder])

    per_process = len(padded) // num_processes
    start = process_index * per_process
    return padded[start:start + per_process]


def eval_rank(
    model,
    tokenizer,
    data_config,
    per_device_eval_batch_size = 100,
    tqdm_disable = True,
) -> dict:
    """Rank the ground-truth option of every sample by next-token logits.

    For each template in ``data_config["templates"]`` the prompt
    (``template[0]``) is filled in per sample, the model is queried through
    ``model.batch_chat(..., logits_only=True)``, and the logits at the last
    position are restricted to the first token of every candidate answer
    (``template[1]`` formatted with each distinct option). Options are sorted
    by those logits; the sample's rank is the 0-based position of its
    ground-truth option.

    Samples are split evenly across the processes reported by
    ``accelerate.PartialState``; padding needed for an even split is done on a
    private copy, so ``data_config`` is left unmodified. Results from all
    processes are gathered and de-duplicated by ``path`` (first occurrence
    wins), which also drops the padding duplicates.

    Args:
        model: Model exposing ``batch_chat``; its output is indexed with
            ``["logits"]``.
        tokenizer: Tokenizer passed to ``batch_chat`` and used to encode the
            formatted options.
        data_config: Dict with ``"templates"`` (sequence of
            ``(prompt_template, answer_template)`` pairs) and ``"mapping"``
            (list of per-sample dicts holding at least ``"name"``, the option
            key named in the answer template, optionally ``"path"``, and
            ``"image"`` when the prompt keeps an ``{image_prefix}`` marker).
        per_device_eval_batch_size: Number of samples per forward pass.
        tqdm_disable: Disable the progress bar (it is always disabled on
            non-main processes).

    Returns:
        Dict keyed by the template index as a string. Each value holds
        ``"template"``, ``"rank"`` (mean rank over the merged samples, or
        ``None`` when there are none) and ``"meta"`` (per-sample records).

    Raises:
        ValueError: If an answer template has no ``{...}`` placeholder, or if
            two options share the same first token.
    """
    state = accelerate.PartialState()
    # The shard does not depend on the template, so compute it once.
    mappings_this_device = _shard_mapping(
        data_config["mapping"], state.num_processes, state.process_index)
    generation_config = dict(max_new_tokens=1024, do_sample=False)

    results = {}

    for template_key, template in enumerate(data_config["templates"]):
        template_key = str(template_key)

        results[template_key] = {
            "template": template,
            "rank": None,
            "meta": [],
        }

        option_key = _extract_option_key(template[1])
        # WARNING: deduplication. dict.fromkeys keeps first-appearance order,
        # so every process sees the options in the same order (a plain set of
        # strings may iterate differently per process).
        options_list = list(dict.fromkeys(
            mapping[option_key] for mapping in data_config["mapping"]))
        option_to_index = {option: idx for idx, option in enumerate(options_list)}

        first_token_of_options = [
            tokenizer.encode(
                template[1].format(**{option_key: option}),
                add_special_tokens=False,
            )[0]
            for option in options_list
        ]
        if len(set(first_token_of_options)) != len(first_token_of_options):
            _raise_first_token_conflict(options_list, first_token_of_options)

        for i in tqdm.tqdm(
            range(0, len(mappings_this_device), per_device_eval_batch_size),
            disable=tqdm_disable or not state.is_main_process
        ):
            batch_mappings = mappings_this_device[i:i + per_device_eval_batch_size]
            prompts = [
                template[0].format_map(utils.SafeDict(mapping))
                for mapping in batch_mappings
            ]

            # NOTE(review): presumably utils.SafeDict leaves unknown fields
            # such as {image_prefix} in place, marking an image prompt — confirm.
            if "{image_prefix}" in prompts[0]:
                # set the max number of tiles in `max_num`
                pixel_values = [
                    load_image(mapping['image'], max_num=12).to(torch.bfloat16).cuda()
                    for mapping in batch_mappings
                ]
                num_patches_list = [pixel_value.size(0) for pixel_value in pixel_values]
                pixel_values = torch.cat(pixel_values, dim=0)
            else:
                pixel_values = None
                num_patches_list = list(range(len(prompts)))

            # Evaluation only: no autograd graph is needed for the logits.
            with torch.no_grad():
                outputs = model.batch_chat(
                    tokenizer, pixel_values, num_patches_list=num_patches_list,
                    questions=prompts, generation_config=generation_config,
                    logits_only=True)

            # Logits at the last position of every sequence, restricted to the
            # first token of each candidate option.
            pooled_logits = outputs["logits"][:, -1]
            pooled_logits = pooled_logits[:, first_token_of_options]

            ranked_args_list = torch.argsort(
                pooled_logits, dim=1, descending=True).cpu().tolist()

            for mapping, ranked_args in zip(batch_mappings, ranked_args_list):
                ranked_options = [options_list[arg] for arg in ranked_args]
                rank = ranked_args.index(option_to_index[mapping[option_key]])

                result = {
                    "name": mapping["name"],
                    "path": mapping.get("path", mapping["name"]),
                    "gt": mapping[option_key],
                    "rank": rank,
                    "ranked_options": ranked_options
                }
                results[template_key]["meta"].append(result)

    # After all templates processed, gather and merge results from all devices
    gathered_results = accelerate.utils.gather_object([results])
    merged_results = {}

    for template_key, first_result in gathered_results[0].items():
        # Deduplicate based on `path`, keep the first occurrence
        merged_meta = []
        seen_paths = set()
        for partial_result in gathered_results:
            for entry in partial_result[template_key]["meta"]:
                if entry["path"] not in seen_paths:
                    seen_paths.add(entry["path"])
                    merged_meta.append(entry)

        if merged_meta:
            mean_rank = sum(entry["rank"] for entry in merged_meta) / len(merged_meta)
        else:
            # No samples at all: keep the "not computed" placeholder.
            mean_rank = None

        merged_results[template_key] = {
            "template": first_result["template"],
            "rank": mean_rank,
            "meta": merged_meta,
        }

    return merged_results


if __name__ == "__main__":
    # Command-line entry point: load the model, run eval_rank on every train
    # and eval config, print a summary and optionally save the results.
    from dataclasses import dataclass
    import json
    import os
    import pprint
    from typing import Optional

    import transformers
    import tyro

    from src.utils import GpuTimer

    @dataclass
    class ScriptArguments:
        model_name_or_path: str = "/mnt/lustrenew/mllm_safety-shared/models/huggingface/OpenGVLab/InternVL3-8B"
        data_config_path: str = "data/food_v3/config_image.yaml"
        data_overwrite_args: str = "" # e.g. --data_overwrite_args "data.train[0].images_dirs[0]=/new/path/to/images,..."
        per_device_eval_batch_size: int = 8
        # When set, the main process writes the full results to this path as JSON.
        save_path: Optional[str] = None

    script_args = tyro.cli(ScriptArguments)

    ################
    # Model, Tokenizer
    ################
    print("Eval_rank for internvl testing...")
    print(f"evaluating {script_args.model_name_or_path}...")
    model = transformers.AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        # Place the whole model on this process's local device index.
        device_map={"": accelerate.PartialState().local_process_index}
    ).eval()
    tokenizer = transformers.AutoTokenizer.from_pretrained(script_args.model_name_or_path, trust_remote_code=True, use_fast=False)
    # NOTE: loading extra safetensors weights via a `weight_path` option was
    # sketched here before but is not implemented (no such argument exists).

    data_config = utils.parse_data_config(script_args.data_config_path, script_args.data_overwrite_args)
    train_configs, eval_configs = utils.parse_train_and_eval_config(data_config)

    with GpuTimer():
        results = {}
        # Train configs first, then eval configs; keys look like
        # "train-<idx>.<template_key>" / "eval-<idx>.<template_key>".
        for split_name, split_configs in (("train", train_configs), ("eval", eval_configs)):
            for config_idx, split_config in enumerate(split_configs):
                if split_config is None:
                    continue  # skip missing configs
                partial_results = eval_rank(
                    model,
                    tokenizer,
                    split_config,
                    script_args.per_device_eval_batch_size,
                    tqdm_disable=False
                )
                for template_key, template_result in partial_results.items():
                    results[f"{split_name}-{config_idx}.{template_key}"] = template_result

    pprint.pprint(results, depth=2, width=500)

    # Every process holds the gathered results; only the main one writes them.
    if script_args.save_path and accelerate.PartialState().is_main_process:
        os.makedirs(os.path.dirname(script_args.save_path) or ".", exist_ok=True)
        with open(script_args.save_path, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=4)
        print(f"Result has been saved to {script_args.save_path}")
