import numpy as np
from qwen_vl_utils import process_vision_info
from scipy.stats import spearmanr
from video_utils import shuffle_except_first
from rl_utils import get_gvl_content
import re


def extract_task_completion_percentages(text: str, maxlen: int):
    """Extract the percentages written as ``<number>%`` from *text*.

    Integer and decimal numbers (``"40%"``, ``"12.5%"``) are recognised; a
    leading sign is not part of the match. When more than ``maxlen`` values
    are present, only the last ``maxlen`` are kept.

    Args:
        text: Text to scan.
        maxlen: Maximum number of values to return. Zero or a negative
            number yields an empty list.

    Returns:
        A list of floats in order of appearance, at most ``maxlen`` long.
    """
    # Regex to match "<number>%"
    pattern = r"(\d+(?:\.\d+)?)%"
    values = [float(match) for match in re.findall(pattern, text)]

    # Guard explicitly: ``values[-0:]`` is the whole list, not an empty one.
    if maxlen <= 0:
        return []
    return values[-maxlen:]


def calculate_VOC(model, processor, dataset, n_frames=5, num_shuffles=100, n_frames_ref=-1):
    """Average a Spearman-based VOC score over repeated shuffled-frame queries.

    Each round fetches ``n_frames`` frames from ``dataset``, shuffles them
    with ``shuffle_except_first``, asks the model for one task-completion
    percentage per frame, and measures how well sorting the frames by the
    predicted percentage recovers their original order.

    Args:
        model: Generative model exposing ``generate`` and ``device``.
        processor: Processor exposing ``apply_chat_template``, ``__call__``
            and ``batch_decode``.
        dataset: Object exposing ``get_uniform_frames(n)``, which returns a
            pair whose first element is the frames (the second is presumably
            the matching completion percentages -- confirm against the
            dataset class).
        n_frames: Number of frames requested from the dataset per round.
        num_shuffles: Number of rounds to run.
        n_frames_ref: When positive, that many frames are fetched together
            with the second element returned by ``get_uniform_frames`` and
            passed to ``get_gvl_content`` as in-context references.

    Returns:
        A dict with
        ``'VOC'``: mean score over the counted rounds. A round whose output
        does not contain enough percentages scores ``-1``; a round whose
        correlation is undefined (NaN) is not counted. NaN is returned when
        no round was counted (e.g. ``num_shuffles == 0``).
        ``'cnt_sorted'``: number of parsed rounds in which the predictions,
        read in presentation order, were non-decreasing.
    """
    VOC_sum, cnt_voc, cnt_sorted = 0, 0, 0
    for _ in range(num_shuffles):
        sliced_frames, _ = dataset.get_uniform_frames(n_frames)

        if n_frames_ref > 0:
            frames_in_context, percent_in_context = dataset.get_uniform_frames(n_frames_ref)
            ref_in_context = list(zip(frames_in_context, percent_in_context))
        else:
            ref_in_context = []

        # Build the prompt from the shuffled frames (first frame presumably
        # stays in place, per the helper's name).
        shuffled_frames, shuffled_indices = shuffle_except_first(sliced_frames)
        messages, image_inputs = get_gvl_content(shuffled_frames, ref_in_context)

        # Preparation for inference
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = processor(
            text=[text],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(model.device)

        # Inference: greedy generation, then strip the prompt tokens so only
        # the newly generated continuation is decoded.
        generated_ids = model.generate(**inputs, max_new_tokens=2048, do_sample=False)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        # Parse the output. The leading -1.0 is a sentinel for the first
        # frame: parsed percentages are never negative, so it always sorts
        # first.
        unordered_values = [-1.0] + extract_task_completion_percentages(
            output_text[0], len(sliced_frames) - 1
        )
        if len(unordered_values) == len(sliced_frames):
            values = np.array(unordered_values)

            # Statistic only: did the model emit a non-decreasing sequence in
            # presentation order?
            if np.all(values[:-1] <= values[1:]):
                cnt_sorted += 1

            # Recover the frame order implied by the predicted percentages.
            values_order = np.argsort(unordered_values)
            order = np.array([0] + shuffled_indices)[values_order]
            true_order = np.arange(len(order))

            # Rank correlation between recovered and true order, with the
            # first frame left out.
            VOC, _ = spearmanr(order[1:] - 1, true_order[1:] - 1)
        else:
            # Too few percentages in the output: count as the worst score.
            VOC = -1

        # spearmanr yields NaN when the correlation is undefined; skip such
        # rounds so they do not turn the running sum into NaN.
        if not np.isnan(VOC):
            VOC_sum += VOC
            cnt_voc += 1

    return {
        # Avoid ZeroDivisionError when no round produced a usable score.
        'VOC': VOC_sum / cnt_voc if cnt_voc else float("nan"),
        'cnt_sorted': cnt_sorted
    }