#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from qwen_vl_utils import process_vision_info
from tqdm import tqdm
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration


class Qwen25VLHUDAnalyzer:
    """Pick extra VQA training samples by comparing model and HUD answer distributions.

    The analyzer loads a Qwen2.5-VL checkpoint and, per question, computes the
    KL divergence between the normalized top-two HUD scores and a softmax over
    the model's top-two next-token logits. The mean divergences of the
    ``hu_low`` and ``hu_mid`` question sets define an interval; questions
    outside those sets whose divergence lies inside the interval are appended
    to an existing training file.

    NOTE(review): ``override_hud_mapping`` is populated in ``__init__`` but no
    method in this class reads it -- confirm whether the override scores were
    meant to replace ``hud_scores``.
    """

    _logger = logging.getLogger(__name__)
    # Added to numerator and denominator of the log ratio so neither can be zero.
    _KL_EPSILON = 1e-8

    def __init__(self, sft_checkpoint_path: str, device: str = 'cuda', batch_size: int = 8,
                 override_hud_jsonl: Optional[str] = None):
        """Load the model and, optionally, an override score file.

        Args:
            sft_checkpoint_path: Checkpoint location handed to ``from_pretrained``.
            device: Stored on the instance only; model placement is decided by
                ``device_map="auto"`` in ``_init_model``.
            batch_size: Stored on the instance only; inference in this class
                runs one sample per forward pass.
            override_hud_jsonl: Optional JSONL file parsed into
                ``self.override_hud_mapping``.
        """
        self.device = device
        self.batch_size = batch_size
        self.dataset = None
        self.qid_to_idx = None
        self._init_model(sft_checkpoint_path)
        self.vqa_prompt_template = "{question}"
        self.override_hud_mapping = {}
        if override_hud_jsonl:
            self.override_hud_mapping = self._load_override_hud_jsonl(override_hud_jsonl)

    def _init_model(self, sft_checkpoint_path: str):
        """Load the model and processor from the checkpoint and switch to eval mode."""
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            sft_checkpoint_path,
            torch_dtype="auto",
            device_map="auto"
        )
        self.processor = AutoProcessor.from_pretrained(sft_checkpoint_path)
        self.model.eval()

    def _load_override_hud_jsonl(self, jsonl_path: str) -> Dict[int, List[float]]:
        """Parse a JSONL file into ``{question_id: [p1, p2]}``.

        Each record supplies ``top2_probs`` or, failing that, the first two
        entries of ``top4_probs``. A single probability is padded with ``0.0``.
        Lines that are not valid JSON, lack a usable integer ``question_id``,
        or carry no probability list are skipped.
        """
        override_mapping: Dict[int, List[float]] = {}
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    record = json.loads(line.strip())
                    question_id = int(record['question_id'])
                    probs = record.get('top2_probs', record.get('top4_probs', [])[:2])
                except (json.JSONDecodeError, KeyError, TypeError, ValueError):
                    # ValueError covers ids such as "abc" that int() rejects.
                    continue
                if not isinstance(probs, list) or not probs:
                    continue
                # Build a new list so the parsed record itself is not mutated.
                override_mapping[question_id] = (probs + [0.0, 0.0])[:2]
        return override_mapping

    def _get_hud_scores_and_answers(self, question_id: int, hud_mapping: Dict[int, Dict]) -> Optional[
        Tuple[List[float], List[str]]]:
        """Return the first two ``hud_scores`` and ``hud_answers`` for a question.

        Returns ``None`` when the question is absent from ``hud_mapping`` or
        either list has fewer than two entries.
        """
        hud_record = hud_mapping.get(int(question_id))
        if hud_record is None:
            return None
        scores = hud_record.get('hud_scores', [])
        answers = hud_record.get('hud_answers', [])
        if len(scores) >= 2 and len(answers) >= 2:
            return scores[:2], answers[:2]
        return None

    def get_top_two_logits(self, image_path: str, question: str, data_path: str = "") -> Optional[List[float]]:
        """Run one forward pass and return the two largest last-position logits.

        Args:
            image_path: Image file; joined onto ``data_path`` unless absolute.
            question: Text inserted into ``self.vqa_prompt_template``.
            data_path: Optional directory prefix for relative image paths.

        Returns:
            The two largest logit values at the final sequence position, or
            ``None`` if the image is missing or any step fails. Failures are
            logged rather than raised so a bad sample does not abort a run.
        """
        try:
            if data_path and not os.path.isabs(image_path):
                full_image_path = os.path.join(data_path, image_path)
            else:
                full_image_path = image_path
            if not os.path.exists(full_image_path):
                self._logger.warning("Image not found, skipping: %s", full_image_path)
                return None
            # The context manager closes the file handle once the pixels have
            # been copied into the converted image.
            with Image.open(full_image_path) as raw_image:
                img = raw_image.convert('RGB')
            formatted_question = self.vqa_prompt_template.format(question=question)
            messages = [{"role": "user",
                         "content": [{"type": "image", "image": img}, {"type": "text", "text": formatted_question}]}]
            text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = self.processor(text=[text], images=image_inputs, videos=video_inputs, padding=True,
                                    return_tensors="pt").to(self.model.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
            last_token_logits = outputs.logits[0, -1, :]
            return torch.topk(last_token_logits, 2).values.tolist()
        except Exception as exc:
            self._logger.warning("Inference failed for image %r: %s", image_path, exc)
            return None

    def calculate_answer_distribution(self, sample: dict, data_path: str = "") -> Optional[List[float]]:
        """Return a two-way softmax over the model's top-two logits for a sample.

        The question is the first ``human`` turn in ``sample['conversations']``
        with the ``<image>`` placeholder removed. Returns ``None`` when there
        is no non-empty question, inference yields nothing, or the sample is
        malformed (logged, not raised).
        """
        try:
            question = None
            for conv in sample.get('conversations', []):
                if conv['from'] == 'human':
                    question = conv['value'].replace('<image>', '').strip()
                    break
            if not question:
                return None
            top_two_logits = self.get_top_two_logits(sample.get('image', ''), question, data_path)
            if top_two_logits is None:
                return None
            return F.softmax(torch.tensor(top_two_logits), dim=-1).tolist()
        except Exception as exc:
            self._logger.warning("Could not compute answer distribution: %s", exc)
            return None

    def load_classification_files(self, hu_low_path: str, hu_mid_path: str) -> Tuple[Dict[str, List[int]], set]:
        """Load the ``hu_low`` and ``hu_mid`` question-id files.

        Each file is either a dict with an ``annotations`` list of records
        carrying ``question_id``, or a plain list of ids. Any other shape
        yields an empty id list.

        Returns:
            ``(classifications, training_set)``: ids per class name, and the
            union of all ids.
        """
        classifications, training_set = {}, set()
        file_mapping = {'hu_low': hu_low_path, 'hu_mid': hu_mid_path}
        for name, path in file_mapping.items():
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if isinstance(data, dict) and 'annotations' in data:
                question_ids = [int(ann['question_id']) for ann in data['annotations']]
            elif isinstance(data, list):
                question_ids = [int(qid) for qid in data]
            else:
                question_ids = []
            classifications[name] = question_ids
            training_set.update(question_ids)
        return classifications, training_set

    def load_hud_jsonl(self, hud_jsonl_path: str) -> Dict[int, Dict]:
        """Read a JSONL file into ``{question_id: record}``.

        Blank lines (including a trailing newline at end of file) are skipped.
        Malformed JSON or a missing ``question_id`` still raises, because this
        file is a required input.
        """
        hud_mapping: Dict[int, Dict] = {}
        with open(hud_jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                record = json.loads(line)
                hud_mapping[int(record['question_id'])] = record
        return hud_mapping

    def load_vqa_dataset(self, data_path: str) -> Tuple[List[Dict], Dict[int, int]]:
        """Load the dataset JSON and index it by question id.

        Items without ``question_id`` are keyed by their list position.
        """
        with open(data_path, 'r', encoding='utf-8') as f:
            dataset = json.load(f)
        qid_to_idx = {int(item.get('question_id', i)): i for i, item in enumerate(dataset)}
        return dataset, qid_to_idx

    def _collect_hud_samples(self, qids: List[int], qid_to_idx: Dict[int, int],
                             hud_mapping: Dict[int, Dict]) -> List[Tuple[int, List[float], List[str]]]:
        """Return ``(qid, scores, answers)`` for ids present in both maps with usable HUD data."""
        samples = []
        for qid in qids:
            if qid not in qid_to_idx or qid not in hud_mapping:
                continue
            hud_data = self._get_hud_scores_and_answers(qid, hud_mapping)
            if hud_data is None:
                continue
            samples.append((qid, *hud_data))
        return samples

    @staticmethod
    def _normalize_hud_scores(hud_scores: List[float]) -> Optional[List[float]]:
        """Scale scores to sum to one; return ``None`` when they sum to zero."""
        total = sum(hud_scores)
        if total == 0:
            return None
        return [score / total for score in hud_scores]

    @classmethod
    def _kl_divergence(cls, hud_dist: List[float], model_dist: List[float]) -> Optional[float]:
        """Return KL(hud || model) over paired entries, or ``None`` if not finite."""
        eps = cls._KL_EPSILON
        kl = sum(p * np.log((p + eps) / (q + eps)) for p, q in zip(hud_dist, model_dist))
        return float(kl) if np.isfinite(kl) else None

    def process_classification_questions(self, dataset: List[Dict], qid_to_idx: Dict[int, int],
                                         classifications: Dict[str, List[int]], hud_mapping: Dict[int, Dict],
                                         data_path: str = "") -> Dict:
        """Compute the mean KL divergence for each classification set.

        Returns:
            ``{'kl_thresholds': {class_name: mean_kl}}``. A class is omitted
            when none of its questions produces a finite divergence.
        """
        results = {'kl_thresholds': {}}
        for class_name, qids in classifications.items():
            valid_samples = self._collect_hud_samples(qids, qid_to_idx, hud_mapping)
            if not valid_samples:
                continue
            kl_divergences = []
            for qid, hud_scores, _ in tqdm(valid_samples, desc=f"Computing KL for {class_name}"):
                # Check the HUD scores first: a zero-sum sample is discarded
                # regardless, so there is no point running the model on it.
                hud_dist = self._normalize_hud_scores(hud_scores)
                if hud_dist is None:
                    continue
                model_dist = self.calculate_answer_distribution(dataset[qid_to_idx[qid]], data_path)
                if model_dist is None:
                    continue
                kl = self._kl_divergence(hud_dist, model_dist)
                if kl is not None:
                    kl_divergences.append(kl)
            if kl_divergences:
                results['kl_thresholds'][class_name] = float(np.mean(kl_divergences))
        return results

    def find_remaining_questions(self, training_set: set, qid_to_idx: Dict[int, int]) -> List[int]:
        """Return dataset question ids that are not in ``training_set``.

        The order is that of set iteration; together with ``target_limit`` it
        determines which candidates are examined first.
        """
        return list(set(qid_to_idx.keys()) - training_set)

    def process_remaining_questions_targeted(self, dataset: List[Dict], qid_to_idx: Dict[int, int],
                                             remaining_qids: List[int], hud_mapping: Dict[int, Dict],
                                             data_path: str, kl_interval: Tuple[float, float],
                                             target_limit: int,
                                             output_kl_results_path: Optional[str] = None) -> Dict:
        """Collect up to ``target_limit`` questions whose KL lies inside ``kl_interval``.

        Args:
            dataset: Loaded dataset items.
            qid_to_idx: Question id to dataset index.
            remaining_qids: Candidate ids, examined in the given order.
            hud_mapping: Question id to HUD record.
            data_path: Directory prefix for relative image paths.
            kl_interval: Inclusive ``(low, high)`` bounds on the divergence.
            target_limit: Stop once this many questions have been selected.
            output_kl_results_path: If set, every scored candidate (selected
                or not) is written there as one JSON object per line.

        Returns:
            ``{'target_questions': [...]}`` with one dict per selected question.
        """
        results = {'target_questions': []}
        selected = results['target_questions']
        valid_samples = self._collect_hud_samples(remaining_qids, qid_to_idx, hud_mapping)

        # With no results path the writes are skipped; the null device only
        # keeps the ``with`` block uniform.
        sink_path = output_kl_results_path if output_kl_results_path else os.devnull
        with open(sink_path, 'w', encoding='utf-8') as kl_file:
            for qid, hud_scores, hud_answers in tqdm(valid_samples, desc="Filtering questions"):
                if len(selected) >= target_limit:
                    break
                hud_dist = self._normalize_hud_scores(hud_scores)
                if hud_dist is None:
                    continue
                sample = dataset[qid_to_idx[qid]]
                model_dist = self.calculate_answer_distribution(sample, data_path)
                if model_dist is None:
                    continue
                kl = self._kl_divergence(hud_dist, model_dist)
                if kl is None:
                    continue
                if output_kl_results_path:
                    gt_answer = next((c.get('value', "N/A") for c in sample.get('conversations', [])
                                      if c.get('from') == 'gpt'), "N/A")
                    kl_file.write(json.dumps(
                        {"question_id": qid, "hud_answers": hud_answers, "hud_dist": hud_dist, "model_dist": model_dist,
                         "kl_divergence": kl, "ground_truth": gt_answer}) + '\n')
                if kl_interval[0] <= kl <= kl_interval[1]:
                    selected.append(
                        {'question_id': qid, 'image_path': sample.get('image'), 'top_answer': hud_answers[0],
                         'kl_divergence': kl})
        return results

    @staticmethod
    def _summary_path_for(output_file: str) -> str:
        """Derive the summary path by replacing only a trailing ``.json``."""
        suffix = '.json'
        stem = output_file[:-len(suffix)] if output_file.endswith(suffix) else output_file
        return f"{stem}_analysis_summary.json"

    def run_full_analysis(self, data_path: str, hu_low_path: str, hu_mid_path: str,
                          hud_jsonl_path: str, training_json_path: str, output_file: str,
                          target_limit: int, output_kl_results_path: Optional[str] = None):
        """Run the whole pipeline and write the combined training set.

        Steps: load inputs, compute the ``hu_low`` / ``hu_mid`` mean KL
        thresholds, select remaining questions inside that interval, then (if
        ``output_file`` is set) append them to the data in
        ``training_json_path`` and write both the combined file and a
        ``*_analysis_summary.json`` next to it.

        When both thresholds are 0.0 (including the case where neither could
        be computed) the interval falls back to ``(0.0, 1.0)``.
        """
        classifications, training_set = self.load_classification_files(hu_low_path, hu_mid_path)
        hud_mapping = self.load_hud_jsonl(hud_jsonl_path)
        dataset, qid_to_idx = self.load_vqa_dataset(data_path)
        self.dataset, self.qid_to_idx = dataset, qid_to_idx
        # Relative image paths are resolved against the dataset file's directory.
        data_dir = os.path.dirname(data_path) if data_path.endswith('.json') else data_path
        classification_results = self.process_classification_questions(dataset, qid_to_idx, classifications,
                                                                       hud_mapping, data_dir)
        thresholds = classification_results.get('kl_thresholds', {})
        kl1 = thresholds.get('hu_low', 0.0)
        kl2 = thresholds.get('hu_mid', 0.0)
        kl_interval = (min(kl1, kl2), max(kl1, kl2)) if kl1 != 0.0 or kl2 != 0.0 else (0.0, 1.0)
        remaining_qids = self.find_remaining_questions(training_set, qid_to_idx)
        remaining_results = self.process_remaining_questions_targeted(dataset, qid_to_idx, remaining_qids, hud_mapping,
                                                                      data_dir, kl_interval, target_limit,
                                                                      output_kl_results_path)

        if output_file:
            with open(training_json_path, 'r', encoding='utf-8') as f:
                existing_training_data = json.load(f)
            new_training_samples = []
            for item in remaining_results['target_questions']:
                qid = item['question_id']
                if qid in qid_to_idx:
                    original_sample = dataset[qid_to_idx[qid]]
                    new_training_samples.append({"image": original_sample.get('image', ''),
                                                 "conversations": original_sample.get('conversations', [])})

            combined_training_data = existing_training_data + new_training_samples
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(combined_training_data, f, indent=2)
            print(f"Saved combined training data to: {output_file}")
            print(
                f"Total samples: {len(combined_training_data)} (original: {len(existing_training_data)} + new: {len(new_training_samples)})")

            # Strip only the trailing extension; str.replace would also remove
            # ".json" from directory names earlier in the path.
            stats_file = self._summary_path_for(output_file)
            final_results = {
                'kl_thresholds': {'hu_low': kl1, 'hu_mid': kl2}, 'kl_interval': kl_interval,
                'original_training_count': len(existing_training_data),
                'new_samples_count': len(new_training_samples), 'total_count': len(combined_training_data)
            }
            with open(stats_file, 'w', encoding='utf-8') as f:
                json.dump(final_results, f, indent=2)
            print(f"Saved analysis summary to: {stats_file}")

        print("\n" + "=" * 80)
        print("Analysis Complete!")
        print(f"KL Thresholds: hu_low={kl1:.6f}, hu_mid={kl2:.6f}")
        print(f"Collected {len(remaining_results['target_questions'])} questions")
        print("=" * 80)


def main():
    """Parse command-line options, build the analyzer, and run the full analysis."""
    cli = argparse.ArgumentParser(description='Analysis')

    # Mandatory string options, registered in the order they appear in --help.
    mandatory_flags = (
        '--sft_checkpoint',
        '--data_path',
        '--hu_low_file',
        '--hu_mid_file',
        '--hud_jsonl',
        '--training_json_path',
    )
    for flag in mandatory_flags:
        cli.add_argument(flag, type=str, required=True)

    # Optional options as (flag, type, default).
    optional_flags = (
        ('--override_hud_jsonl', str, None),
        ('--output_file', str, 'analysis.json'),
        ('--target_limit', int, 4437),
        ('--batch_size', int, 1),
        ('--device', str, 'cuda'),
        ('--output_kl_results', str, None),
    )
    for flag, value_type, fallback in optional_flags:
        cli.add_argument(flag, type=value_type, default=fallback)

    opts = cli.parse_args()

    analyzer = Qwen25VLHUDAnalyzer(
        sft_checkpoint_path=opts.sft_checkpoint,
        device=opts.device,
        batch_size=opts.batch_size,
        override_hud_jsonl=opts.override_hud_jsonl,
    )
    analyzer.run_full_analysis(
        data_path=opts.data_path,
        hu_low_path=opts.hu_low_file,
        hu_mid_path=opts.hu_mid_file,
        hud_jsonl_path=opts.hud_jsonl,
        training_json_path=opts.training_json_path,
        output_file=opts.output_file,
        target_limit=opts.target_limit,
        output_kl_results_path=opts.output_kl_results,
    )


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()