# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from transformers import PreTrainedTokenizer

from ... import DataProto
from ...utils.reward_score import math_compute_score, r1v_compute_score, tvg_compute_score, tvg_compute_score_confident
import re
import json
import numpy as np

def compute_iou_interval(a, b):
    """Return the intersection-over-union of two ``[start, end]`` intervals.

    The intersection length is clamped at zero. The denominator is the
    distance from the earliest start to the latest end of the two intervals.
    When that distance is not positive the result is ``0.0``.
    """
    latest_start = max(a[0], b[0])
    earliest_end = min(a[1], b[1])
    overlap = max(0, earliest_end - latest_start)

    extent = max(a[1], b[1]) - min(a[0], b[0])
    if extent > 0:
        return overlap / extent
    return 0.0

def compute_iou_consistency_matrix(candidates):
    """Build the N x N matrix of pairwise IoU values between candidates.

    Entry ``[r, c]`` holds ``compute_iou_interval(candidates[r], candidates[c])``
    for ``r != c``. The diagonal is left at zero, so a candidate's IoU with
    itself is not counted.
    """
    count = len(candidates)
    pairwise = np.zeros((count, count))
    for row, first in enumerate(candidates):
        for col, second in enumerate(candidates):
            if row == col:
                # Self-comparison is skipped; the diagonal stays 0.
                continue
            pairwise[row, col] = compute_iou_interval(first, second)
    return pairwise

def select_best_candidate(candidates, iou_matrix, drop_percentile=30):
    """Pick the candidate whose mean IoU against the others is highest.

    Row means of ``iou_matrix`` are computed first. Rows whose mean falls below
    the ``drop_percentile``-th percentile of those means are discarded; if
    nothing survives, every row is considered again. Among the remaining rows
    the first one with the largest mean wins.

    ``candidates`` is accepted for interface compatibility but is not read.

    Returns:
        ``(best_idx, mean_iou_of_best)``.
    """
    row_means = iou_matrix.mean(axis=1)
    cutoff = np.percentile(row_means, drop_percentile)

    survivors = []
    for position, score in enumerate(row_means):
        if score >= cutoff:
            survivors.append((position, score))

    if len(survivors) == 0:
        # Fallback: nothing passed the cutoff, so rank the full list.
        survivors = list(enumerate(row_means))

    winner, _ = max(survivors, key=lambda pair: pair[1])
    return winner, row_means[winner]

def compute_pseudo_labels_with_variance_confidence(all_candidates, num_generation, confidence_alpha=10.0):
    """Derive a pseudo label and a spread-based confidence for every sample.

    For each sample the valid candidates are averaged (start and end
    separately) to form the pseudo label, and the confidence is
    ``exp(-confidence_alpha * (std(starts) + std(ends)))``, so tightly
    clustered candidates score close to 1.

    A candidate is treated as invalid, and ignored, when it equals the
    sentinel ``[1.0, 1.0]`` or contains a non-finite value (NaN / +-inf).
    Without the finiteness check a single such candidate would turn the whole
    sample's label and confidence into NaN.

    Layout: candidate ``j`` of sample ``i`` is read from index
    ``i + j * batch_size`` where ``batch_size = len(all_candidates) //
    num_generation``. Outputs are written back to the same indices.

    Args:
        all_candidates: list of ``[start, end]`` pairs with length
            ``batch_size * num_generation``.
        num_generation: number of candidates generated per sample.
        confidence_alpha: scale applied to the summed standard deviations;
            larger values penalise spread more strongly.

    Returns:
        ``(labels, confidences)``, both the same length as ``all_candidates``.
        Every slot belonging to a sample gets that sample's label (as its own
        list object) and confidence. With no valid candidate the label is
        ``[1.0, 1.0]`` and the confidence ``0.0``; with exactly one valid
        candidate the label is that candidate and the confidence ``1.0``.

    Raises:
        AssertionError: if the input length is not divisible by
            ``num_generation``.
    """
    total = len(all_candidates)
    assert total % num_generation == 0, "Input length must be divisible by num_generation"
    batch_size = total // num_generation

    reordered_labels = [None] * total
    reordered_confidences = [None] * total

    for i in range(batch_size):
        # Indices of all generations that belong to sample i.
        slots = [i + j * batch_size for j in range(num_generation)]

        # Drop the invalid-marker candidate and anything with NaN / inf in it.
        valid = [
            all_candidates[idx]
            for idx in slots
            if all_candidates[idx] != [1.0, 1.0] and bool(np.all(np.isfinite(all_candidates[idx])))
        ]

        if not valid:
            label = [1.0, 1.0]
            confidence = 0.0
        elif len(valid) == 1:
            label = valid[0]
            confidence = 1.0
        else:
            starts = np.array([s for s, e in valid])
            ends = np.array([e for s, e in valid])
            label = [float(np.mean(starts)), float(np.mean(ends))]

            # Tighter clusters -> smaller spread -> confidence closer to 1.
            spread = np.std(starts) + np.std(ends)
            confidence = float(np.exp(-confidence_alpha * spread))

        # Write straight into the caller's layout. Each slot receives its own
        # copy so mutating one returned label cannot affect the others or the
        # input candidates.
        for idx in slots:
            reordered_labels[idx] = list(label)
            reordered_confidences[idx] = confidence

    return reordered_labels, reordered_confidences

class CustomRewardManager:
    """Callable that scores decoded responses and returns a token-level reward tensor.

    When the first entry of ``data.non_tensor_batch["text_type"]`` is
    ``"adaptation"`` the whole batch is scored against pseudo labels built from
    the batch's own responses (see
    ``compute_pseudo_labels_with_variance_confidence``); otherwise each sample
    is scored against its ``ground_truth`` field.

    Args:
        tokenizer: tokenizer used to decode prompt and response token ids.
        num_examine: maximum number of samples per call whose prompt, response,
            ground truth, video length and score are printed.
        compute_score: ``"math"`` or ``"r1v"``; selects the scorer used for
            samples whose ``problem_type`` is not ``"tvg"``.
        num_generation: number of responses generated per sample, used to group
            candidates in adaptation mode. Defaults to 8, the value that was
            previously hard-coded. The batch length must be divisible by it.

    Raises:
        NotImplementedError: if ``compute_score`` is not a supported name.
    """

    # Extracts the payload between <answer> tags; DOTALL lets it span lines.
    _ANSWER_PATTERN = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        num_examine: int,
        compute_score: str,
        num_generation: int = 8,
    ):
        self.tokenizer = tokenizer
        self.num_examine = num_examine
        self.num_generation = num_generation
        if compute_score == "math":
            self.compute_score = math_compute_score
        elif compute_score == "r1v":
            self.compute_score = r1v_compute_score
        else:
            raise NotImplementedError()

    def _decode_response(self, data_item):
        """Return ``(response_str, valid_response_length)`` for one batch item.

        The valid length is the sum of the attention mask over the positions
        that follow the prompt.
        """
        prompt_length = data_item.batch["prompts"].shape[-1]
        response_ids = data_item.batch["responses"]
        valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
        valid_response_ids = response_ids[:valid_response_length]
        response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)
        return response_str, valid_response_length

    def _parse_candidate(self, response_str, data_item):
        """Parse the ``<answer>`` JSON into ``[start, end]`` divided by the video length.

        Any failure (missing tag, invalid JSON, missing keys, bad numbers,
        zero video length, ...) yields the sentinel ``[1.0, 1.0]``, which the
        pseudo-label helper treats as an invalid candidate.
        """
        try:
            answer_match = self._ANSWER_PATTERN.search(response_str)
            if not answer_match:
                raise ValueError("No <answer>...</answer> tag found.")

            answer_data = json.loads(answer_match.group(1).strip())
            video_length = data_item.non_tensor_batch["video_length"]
            return [
                float(answer_data["start_time"]) / video_length,
                float(answer_data["end_time"]) / video_length,
            ]
        except Exception:
            # Deliberate best-effort: an unparsable response simply does not
            # contribute to the pseudo label.
            return [1.0, 1.0]

    def __call__(self, data: DataProto) -> torch.Tensor:
        """Score every sample in ``data``.

        Returns:
            A float32 tensor shaped like ``data.batch["responses"]`` that is
            zero everywhere except, for each row ``i``, at column
            ``valid_response_length - 1``, which holds that row's score.
        """
        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
        already_print = 0
        # The mode is decided once from the first entry and applied to the whole batch.
        is_adaptation = data.non_tensor_batch["text_type"][0] == "adaptation"

        items = [data[i] for i in range(len(data))]
        # Decode each response once; both the pseudo-label pass and the scoring
        # pass reuse these results.
        decoded = [self._decode_response(item) for item in items]

        pseudo_gt_list, confidence_list = [], []
        if is_adaptation:
            all_candidate = [
                self._parse_candidate(response_str, item)
                for item, (response_str, _) in zip(items, decoded)
            ]
            pseudo_gt_list, confidence_list = compute_pseudo_labels_with_variance_confidence(
                all_candidate, self.num_generation, 10.0
            )

        for i, (data_item, (response_str, valid_response_length)) in enumerate(zip(items, decoded)):
            if is_adaptation:
                ground_truth_real = data_item.non_tensor_batch["ground_truth"]
                ground_truth = pseudo_gt_list[i]
                confidence_sample = confidence_list[i]
                # NOTE(review): appends one line per sample to a file in the
                # current working directory; looks like debug output -- confirm
                # it is still wanted.
                output_path = "all_candidate.txt"
                with open(output_path, "a") as f:
                    f.write("pseudo_ground_truth_sample:" + str(ground_truth) + "ground_truth_sample:" + str(ground_truth_real) + "confidence_sample:" + str(confidence_sample) + "\n")
            else:
                ground_truth = data_item.non_tensor_batch["ground_truth"]

            # Only "tvg" samples carry a video length; keep the name bound for
            # the examine printout below so other problem types do not raise
            # UnboundLocalError or show a stale value from an earlier row.
            video_length = None
            problem_type = data_item.non_tensor_batch.get("problem_type", "")
            if problem_type == "tvg":
                video_length = data_item.non_tensor_batch["video_length"]
                if is_adaptation:
                    score = tvg_compute_score_confident(response_str, ground_truth, video_length, confidence_sample)
                else:
                    score = tvg_compute_score(response_str, ground_truth, video_length)
            else:
                score = self.compute_score(response_str, ground_truth)
            reward_tensor[i, valid_response_length - 1] = score

            if already_print < self.num_examine:
                already_print += 1
                # The prompt is decoded only for the samples that are printed.
                prompt_ids = data_item.batch["prompts"]
                prompt_length = prompt_ids.shape[-1]
                valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
                valid_prompt_ids = prompt_ids[-valid_prompt_length:]
                prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)

                print("[prompt]", prompt_str)
                print("[response]", response_str)
                print("[ground_truth]", ground_truth)
                print("[video_length]", video_length)
                print("[score]", score)

        return reward_tensor
