#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import sys
import json
import torch
import argparse
import numpy as np
from pathlib import Path
from tqdm import tqdm
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
# The BEiT3 helper modules are imported from data/beit3; the path is relative
# to the current working directory, so run the script from the directory that
# contains data/.
sys.path.insert(0, 'data/beit3')
try:
    from modeling_finetune import BEiT3ForVisualQuestionAnswering
    from modeling_utils import _get_base_config, _get_large_config
    from datasets import VQAv2Dataset, build_transform
    from transformers import XLMRobertaTokenizer
    import utils
except ImportError as e:
    # sys.exit with a string prints it to stderr and exits with status 1, so
    # the failing import is reported instead of the script exiting silently.
    sys.exit(f"Failed to import BEiT3 dependencies (search path 'data/beit3'): {e}")
class BEiT3HUDAnalyzer:
    """Compare a fine-tuned BEiT3 VQA model's top-2 answers with HUD scores.

    The analysis driven by ``run_full_analysis``:

    1. For the question ids listed in the ``hu_low`` and ``hu_mid``
       classification files, compute the mean top-2 KL(HUD || model).
    2. Use the two means as an interval. From the remaining questions of the
       VQAv2 ``train`` split, collect those whose KL falls inside it, up to
       ``target_limit``.

    HUD scores come from a HUD JSONL file. An optional override JSONL takes
    precedence per question id.
    """

    def __init__(self, sft_checkpoint_path: str, sentencepiece_path: str,
                 model_size: str = 'base', input_size: int = 480,
                 device: str = 'cuda', batch_size: int = 8,
                 override_hud_jsonl: Optional[str] = None):
        """Load the model, tokenizer, eval transform and optional HUD overrides.

        Args:
            sft_checkpoint_path: Fine-tuned checkpoint. Its classifier head
                determines ``num_classes``.
            sentencepiece_path: SentencePiece model for ``XLMRobertaTokenizer``.
            model_size: ``'base'`` selects the base config; any other value
                selects the large config.
            input_size: Image size passed to the model config and transform.
            device: Torch device used for inference.
            batch_size: Number of questions per forward pass (must be >= 1).
            override_hud_jsonl: Optional JSONL of per-question HUD
                probabilities that take precedence over the main HUD file.

        Raises:
            ValueError: If ``batch_size`` < 1 or the checkpoint head cannot be
                inspected.
        """
        if batch_size < 1:
            # Checked up front: range(..., step=0) would otherwise fail only
            # after the (expensive) model load.
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.device = device
        self.model_size = model_size
        self.input_size = input_size
        self.batch_size = batch_size
        # (path, state_dict) shared by _detect_sft_output_dim and _init_model
        # so the checkpoint file is read from disk only once.
        self._state_dict_cache = None
        self.num_classes = self._detect_sft_output_dim(sft_checkpoint_path)
        self._init_model(sft_checkpoint_path)
        self.tokenizer = XLMRobertaTokenizer(sentencepiece_path)
        self.transform = build_transform(is_train=False,
                                         args=argparse.Namespace(
                                             input_size=input_size,
                                             task='vqav2'
                                         ))
        self.override_hud_mapping: Dict[int, List[float]] = {}
        if override_hud_jsonl:
            self.override_hud_mapping = self._load_override_hud_jsonl(override_hud_jsonl)

    def _load_override_hud_jsonl(self, jsonl_path: str) -> Dict[int, List[float]]:
        """Read per-question override HUD probabilities from a JSONL file.

        Each record needs ``question_id`` and either ``top2_probs`` or
        ``top4_probs`` (only the first two entries of the latter are used).
        Lists shorter than two entries are right-padded with ``0.0``.

        Blank lines, malformed JSON, non-numeric ids or probabilities, and
        records without probabilities are skipped.

        Returns:
            Mapping ``question_id -> [p1, p2]``.
        """
        override_mapping: Dict[int, List[float]] = {}
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                    question_id = int(record['question_id'])
                    probs = record.get('top2_probs', record.get('top4_probs', [])[:2])
                    if not probs:
                        continue
                    # float() validates the values and builds a fresh list, so
                    # the padding below never mutates the parsed record.
                    top2_probs = [float(p) for p in probs[:2]]
                except (ValueError, KeyError, TypeError):
                    # ValueError also covers json.JSONDecodeError and a
                    # non-numeric question_id.
                    continue
                top2_probs += [0.0] * (2 - len(top2_probs))
                override_mapping[question_id] = top2_probs
        return override_mapping

    @staticmethod
    def _extract_state_dict(checkpoint) -> Dict[str, torch.Tensor]:
        """Pull the parameter dict out of a loaded checkpoint object.

        Accepts a checkpoint wrapped under ``'model'`` or ``'state_dict'``, or
        a bare state dict. If every key starts with ``module.`` (as in
        checkpoints saved from a DistributedDataParallel wrapper), the prefix
        is stripped so the keys match the unwrapped model.
        """
        state_dict = checkpoint.get('model', checkpoint.get('state_dict', checkpoint))
        prefix = 'module.'
        if state_dict and all(key.startswith(prefix) for key in state_dict):
            state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
        return state_dict

    def _load_state_dict(self, checkpoint_path: str) -> Dict[str, torch.Tensor]:
        """Load (or reuse the cached) state dict for *checkpoint_path* on CPU."""
        cached = getattr(self, '_state_dict_cache', None)
        if cached is not None and cached[0] == checkpoint_path:
            return cached[1]
        # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        state_dict = self._extract_state_dict(checkpoint)
        self._state_dict_cache = (checkpoint_path, state_dict)
        return state_dict

    def _detect_sft_output_dim(self, checkpoint_path: str) -> int:
        """Infer the number of answer classes from the checkpoint's classifier head.

        Considers 2-D tensors whose key contains a ``head.<index>`` component
        and ends in ``weight`` (e.g. ``head.3.weight`` or
        ``module.head.3.weight``). Returns ``shape[0]`` of the one with the
        highest index, presumably the head's final projection.

        Raises:
            ValueError: If the checkpoint cannot be read or has no such tensor.
        """
        try:
            state_dict = self._load_state_dict(checkpoint_path)
            final_head_index = -1
            final_head_key = None
            for key, weight in state_dict.items():
                parts = key.split('.')
                if parts[-1] != 'weight' or 'head' not in parts:
                    continue
                position = parts.index('head')
                if position + 1 >= len(parts) or not parts[position + 1].isdigit():
                    continue
                if not isinstance(weight, torch.Tensor) or weight.dim() != 2:
                    continue
                head_index = int(parts[position + 1])
                if head_index > final_head_index:
                    final_head_index, final_head_key = head_index, key
            if final_head_key is None:
                raise ValueError("Unable to detect output dimensions from checkpoint")
            return state_dict[final_head_key].shape[0]
        except Exception as e:
            raise ValueError(f"Model output dimension detection failed: {e}") from e

    def _init_model(self, checkpoint_path: str):
        """Build the BEiT3 VQA model, load the checkpoint weights and switch to eval mode."""
        if self.model_size == 'base':
            config = _get_base_config(img_size=self.input_size)
        else:
            config = _get_large_config(img_size=self.input_size)
        config.normalize_output = False
        self.model = BEiT3ForVisualQuestionAnswering(config, num_classes=self.num_classes)
        state_dict = self._load_state_dict(checkpoint_path)
        # strict=False: checkpoint keys that do not match the model are ignored
        # without error. NOTE(review): consider checking the returned
        # missing/unexpected keys.
        self.model.load_state_dict(state_dict, strict=False)
        # The weights now live in the model; drop the cached copy.
        self._state_dict_cache = None
        self.model.to(self.device)
        self.model.eval()

    def _get_hud_scores_and_answers(self, question_id: int, hud_mapping: Dict[int, Dict]) -> Optional[
        Tuple[List[float], List[str]]]:
        """Return the top-2 HUD scores and answers for *question_id*, or ``None``.

        An entry in ``self.override_hud_mapping`` takes precedence. Its
        answers are not known and are reported as ``"N/A"``.

        Otherwise the *hud_mapping* record is used, but only if it has at
        least two ``hud_scores`` and two ``hud_answers``.
        """
        question_id = int(question_id)
        if question_id in self.override_hud_mapping:
            return self.override_hud_mapping[question_id], ["N/A", "N/A"]
        hud_record = hud_mapping.get(question_id)
        if hud_record is None:
            return None
        scores = hud_record.get('hud_scores', [])
        answers = hud_record.get('hud_answers', [])
        if len(scores) >= 2 and len(answers) >= 2:
            return scores[:2], answers[:2]
        return None

    @staticmethod
    def _read_question_ids(path: str) -> List[int]:
        """Read integer question ids from a JSON file.

        Accepts a VQA-style object with an ``annotations`` list (ids are
        taken from each ``question_id``) or a plain list of ids. Any other
        layout yields an empty list.
        """
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        if isinstance(data, dict) and 'annotations' in data:
            return [int(ann['question_id']) for ann in data['annotations']]
        if isinstance(data, list):
            return [int(qid) for qid in data]
        return []

    def load_classification_files(self, hu_low_path: str, hu_mid_path: str) -> Tuple[Dict[str, List[int]], set]:
        """Load the hu_low / hu_mid question-id lists.

        Returns:
            ``({'hu_low': ids, 'hu_mid': ids}, union_of_both_as_set)``.
        """
        classifications = {
            'hu_low': self._read_question_ids(hu_low_path),
            'hu_mid': self._read_question_ids(hu_mid_path),
        }
        training_set = set()
        for qids in classifications.values():
            training_set.update(qids)
        return classifications, training_set

    def load_hud_jsonl(self, hud_jsonl_path: str) -> Dict[int, Dict]:
        """Map ``int(question_id)`` to its full HUD record; blank lines are ignored."""
        hud_mapping = {}
        with open(hud_jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                record = json.loads(line)
                hud_mapping[int(record['question_id'])] = record
        return hud_mapping

    def load_vqa_dataset(self, data_path: str) -> 'VQAv2Dataset':
        """Build the VQAv2 ``train`` split with this analyzer's transform and tokenizer."""
        return VQAv2Dataset(
            data_path=data_path,
            split='train',
            transform=self.transform,
            tokenizer=self.tokenizer,
            num_max_bpe_tokens=64,
            task='vqav2'
        )

    def load_answer_mapping(self, data_path: str) -> Tuple[Dict[str, int], Dict[int, str]]:
        """Read ``answer2label.txt`` (JSON lines with ``answer``/``label``) from *data_path*.

        Returns:
            ``(ans2label, label2ans)``.

        Raises:
            FileNotFoundError: If the file is missing.
            ValueError: If ``max(label) + 1`` differs from the model's ``num_classes``.
        """
        answer2label_path = os.path.join(data_path, "answer2label.txt")
        if not os.path.exists(answer2label_path):
            raise FileNotFoundError(f"answer2label.txt file not found: {answer2label_path}")
        ans2label = {}
        label2ans = {}
        with open(answer2label_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    data = json.loads(line)
                    answer = data["answer"]
                    label = int(data["label"])
                    ans2label[answer] = label
                    label2ans[label] = answer
        expected_classes = max(label2ans.keys()) + 1 if label2ans else 0
        if expected_classes != self.num_classes:
            raise ValueError(
                f"Dimension mismatch: model {self.num_classes} vs answer mapping {expected_classes}")
        return ans2label, label2ans

    def create_qid_to_idx_mapping(self, dataset: 'VQAv2Dataset') -> Dict[int, int]:
        """Map each dataset item's ``qid`` to its index; items without a ``qid`` are skipped."""
        # NOTE(review): ids are used as stored in the dataset items; lookups
        # elsewhere use int ids, so this assumes the dataset stores ints.
        qid_to_idx = {}
        for i, item in enumerate(dataset.items):
            qid = item.get('qid')
            if qid is not None:
                qid_to_idx[qid] = i
        return qid_to_idx

    def calculate_batch_predictions(self, batch_samples: List[dict]) -> List[Optional[torch.Tensor]]:
        """Run the model on *batch_samples* and return softmax probabilities.

        The result is aligned 1:1 with *batch_samples*. An entry is ``None``
        when that sample could not be prepared (non-tensor image, missing
        key, or unconvertible tokens).

        Every entry is ``None`` if stacking or the forward pass raises, the
        logits contain NaN/Inf, or the logit width differs from
        ``num_classes``.
        """
        results: List[Optional[torch.Tensor]] = [None] * len(batch_samples)
        images, tokens, masks, valid_indices = [], [], [], []
        for i, sample in enumerate(batch_samples):
            try:
                image = sample['image']
                if not isinstance(image, torch.Tensor):
                    continue
                language_tokens = sample['language_tokens']
                if isinstance(language_tokens, list):
                    language_tokens = torch.tensor(language_tokens, dtype=torch.long)
                padding_mask = sample['padding_mask']
                if isinstance(padding_mask, list):
                    padding_mask = torch.tensor(padding_mask, dtype=torch.long)
            except Exception:
                # Best effort: an unusable sample just yields None.
                continue
            images.append(image)
            tokens.append(language_tokens)
            masks.append(padding_mask)
            valid_indices.append(i)
        if not valid_indices:
            return results
        try:
            batch_images = torch.stack(images).to(self.device)
            batch_tokens = torch.stack(tokens).to(self.device)
            batch_masks = torch.stack(masks).to(self.device)
            with torch.no_grad():
                batch_logits = self.model(image=batch_images, question=batch_tokens,
                                          padding_mask=batch_masks.bool())
            if torch.isnan(batch_logits).any() or torch.isinf(batch_logits).any():
                return [None] * len(batch_samples)
            if batch_logits.shape[-1] != self.num_classes:
                return [None] * len(batch_samples)
            batch_probabilities = F.softmax(batch_logits, dim=-1)
        except Exception:
            return [None] * len(batch_samples)
        # Row r of the batch belongs to the r-th valid sample.
        for row, sample_index in enumerate(valid_indices):
            results[sample_index] = batch_probabilities[row]
        return results

    @staticmethod
    def _top2_kl(hud_scores: List[float], model_top2: List[float],
                 epsilon: float = 1e-8) -> Optional[Tuple[float, List[float], List[float]]]:
        """Return ``(kl, hud_dist, model_dist)`` for two top-2 score vectors.

        Both vectors are renormalized to sum to 1; the model vector is first
        zero-padded to length 2. The result is
        ``kl = sum(h * log((h + eps) / (m + eps)))``, i.e. KL(HUD || model)
        with *epsilon* smoothing.

        Returns ``None`` when either vector sums to 0 or the result is not finite.
        """
        model_dist = list(model_top2[:2])
        model_dist += [0.0] * (2 - len(model_dist))
        hud_sum = sum(hud_scores)
        if hud_sum == 0:
            return None
        hud_dist = [s / hud_sum for s in hud_scores]
        model_sum = sum(model_dist)
        if model_sum == 0:
            return None
        model_dist = [d / model_sum for d in model_dist]
        kl = sum(hud_dist[i] * np.log((hud_dist[i] + epsilon) / (model_dist[i] + epsilon))
                 for i in range(2))
        if not np.isfinite(kl):
            return None
        return float(kl), hud_dist, model_dist

    def _collect_hud_samples(self, qids: List[int], qid_to_idx: Dict[int, int],
                             hud_mapping: Dict[int, Dict]) -> List[Tuple[int, List[float], List[str]]]:
        """Return ``(qid, hud_scores, hud_answers)`` for ids in the dataset that have HUD data."""
        samples = []
        for qid in qids:
            if qid not in qid_to_idx:
                continue
            hud_data = self._get_hud_scores_and_answers(qid, hud_mapping)
            if hud_data is not None:
                samples.append((qid, *hud_data))
        return samples

    @staticmethod
    def _fetch_dataset_samples(dataset, qid_to_idx: Dict[int, int], entries: list) -> Tuple[list, list]:
        """Load dataset items for *entries* (tuples whose first field is a qid).

        Entries whose item fails to load are dropped. The returned
        ``(samples, kept_entries)`` lists stay aligned with each other.
        """
        samples, kept = [], []
        for entry in entries:
            try:
                sample = dataset[qid_to_idx[entry[0]]]
            except Exception:
                continue
            samples.append(sample)
            kept.append(entry)
        return samples, kept

    def process_classification_questions(self, dataset: 'VQAv2Dataset', qid_to_idx: Dict[int, int],
                                         classifications: Dict[str, List[int]], hud_mapping: Dict[int, Dict],
                                         ans2label: Dict[str, int], label2ans: Dict[int, str]) -> Dict:
        """Compute the mean top-2 KL(HUD || model) for each classification group.

        ``ans2label`` and ``label2ans`` are accepted for interface
        compatibility and are not used.

        Returns:
            ``{'kl_thresholds': {group_name: mean_kl}}``. Groups without any
            usable question are omitted.
        """
        results = {'kl_thresholds': {}}
        for class_name, qids in classifications.items():
            valid_samples = self._collect_hud_samples(qids, qid_to_idx, hud_mapping)
            if not valid_samples:
                continue
            kl_divergences = []
            for batch_start in tqdm(range(0, len(valid_samples), self.batch_size),
                                    desc=f"Computing KL for {class_name}"):
                batch_entries = valid_samples[batch_start:batch_start + self.batch_size]
                batch_samples, kept = self._fetch_dataset_samples(dataset, qid_to_idx, batch_entries)
                if not batch_samples:
                    continue
                batch_probs = self.calculate_batch_predictions(batch_samples)
                for probs, (_qid, hud_scores, _answers) in zip(batch_probs, kept):
                    if probs is None:
                        continue
                    top2_probs, _ = torch.topk(probs, k=min(2, len(probs)))
                    kl_result = self._top2_kl(hud_scores, top2_probs.tolist())
                    if kl_result is not None:
                        kl_divergences.append(kl_result[0])
            if kl_divergences:
                results['kl_thresholds'][class_name] = float(np.mean(kl_divergences))
        return results

    def find_remaining_questions(self, training_set: set, qid_to_idx: Dict[int, int]) -> List[int]:
        """Return the dataset question ids that are not in *training_set*."""
        return list(set(qid_to_idx.keys()) - training_set)

    def process_remaining_questions_targeted(self, dataset: 'VQAv2Dataset', qid_to_idx: Dict[int, int],
                                             remaining_qids: List[int], hud_mapping: Dict[int, Dict],
                                             kl_interval: Tuple[float, float],
                                             ans2label: Dict[str, int], label2ans: Dict[int, str],
                                             target_limit: int, output_kl_results_path: Optional[str] = None) -> Dict:
        """Collect remaining questions whose top-2 KL lies within *kl_interval*.

        Stops once *target_limit* questions have been collected. If
        *output_kl_results_path* is given, one JSON line per question with a
        finite KL is written there, whether or not it falls inside the
        interval.

        Returns:
            ``{'target_questions': [record, ...]}``.
        """
        results = {'target_questions': []}
        valid_samples = self._collect_hud_samples(remaining_qids, qid_to_idx, hud_mapping)
        collected_count = 0
        # os.devnull stands in when no KL log is requested; writes are still guarded below.
        with (open(output_kl_results_path, 'w', encoding='utf-8') if output_kl_results_path
              else open(os.devnull, 'w')) as kl_file:
            for batch_start in tqdm(range(0, len(valid_samples), self.batch_size), desc="Filtering questions"):
                if collected_count >= target_limit:
                    break
                batch_entries = valid_samples[batch_start:batch_start + self.batch_size]
                batch_samples, kept = self._fetch_dataset_samples(dataset, qid_to_idx, batch_entries)
                if not batch_samples:
                    continue
                batch_probs = self.calculate_batch_predictions(batch_samples)
                for probs, (qid, hud_scores, hud_answers) in zip(batch_probs, kept):
                    if collected_count >= target_limit:
                        break
                    if probs is None:
                        continue
                    top2_probs, top2_indices = torch.topk(probs, k=min(2, len(probs)))
                    top2_answers = [label2ans[idx.item()] for idx in top2_indices]
                    kl_result = self._top2_kl(hud_scores, top2_probs.tolist())
                    if kl_result is None:
                        continue
                    kl, hud_dist, model_dist = kl_result
                    if output_kl_results_path:
                        kl_file.write(json.dumps({"question_id": qid, "hud_answers": hud_answers, "hud_dist": hud_dist, "model_dist": model_dist, "kl_divergence": kl, "top2_answers": top2_answers}) + '\n')
                    if kl_interval[0] <= kl <= kl_interval[1]:
                        results['target_questions'].append({
                            'question_id': qid, 'top2_indices': top2_indices.tolist(),
                            'top2_answers': top2_answers, 'top2_probs': top2_probs.tolist(),
                            'kl_divergence': kl, 'used_override_hud': qid in self.override_hud_mapping
                        })
                        collected_count += 1
        return results

    def run_full_analysis(self, data_path: str, hu_low_path: str, hu_mid_path: str,
                          hud_jsonl_path: str, training_json_path: str, output_file: str,
                          target_limit: int, output_kl_results_path: Optional[str] = None) -> Dict:
        """Run the whole pipeline and return the summary dict.

        When *output_file* is non-empty, two files are written:
        ``<output_file>_target_questions.jsonl`` and
        ``<output_file>_analysis_summary.json``.

        The summary is returned either way. ``training_json_path`` is accepted
        for interface compatibility and is not used here.
        """
        classifications, training_set = self.load_classification_files(hu_low_path, hu_mid_path)
        hud_mapping = self.load_hud_jsonl(hud_jsonl_path)
        dataset = self.load_vqa_dataset(data_path)
        ans2label, label2ans = self.load_answer_mapping(data_path)
        qid_to_idx = self.create_qid_to_idx_mapping(dataset)
        classification_results = self.process_classification_questions(
            dataset, qid_to_idx, classifications, hud_mapping, ans2label, label2ans)
        kl1 = classification_results.get('kl_thresholds', {}).get('hu_low', 0.0)
        kl2 = classification_results.get('kl_thresholds', {}).get('hu_mid', 0.0)
        # A missing group counts as 0.0. If neither group produced a value,
        # fall back to the interval [0, 1].
        kl_interval = (min(kl1, kl2), max(kl1, kl2)) if kl1 != 0.0 or kl2 != 0.0 else (0.0, 1.0)
        remaining_qids = self.find_remaining_questions(training_set, qid_to_idx)
        remaining_results = self.process_remaining_questions_targeted(
            dataset, qid_to_idx, remaining_qids, hud_mapping, kl_interval,
            ans2label, label2ans, target_limit, output_kl_results_path)
        # Built unconditionally so the return below is valid even without output files.
        final_results = {
            'kl_thresholds': {'hu_low': kl1, 'hu_mid': kl2}, 'kl_interval': kl_interval,
            'statistics': {
                'total_questions_in_dataset': len(qid_to_idx), 'training_set_size': len(training_set),
                'remaining_questions_size': len(remaining_qids), 'target_questions_collected': len(remaining_results['target_questions']),
                'override_hud_questions': len(self.override_hud_mapping)
            }
        }
        if output_file:
            target_file = f"{output_file}_target_questions.jsonl"
            with open(target_file, 'w', encoding='utf-8') as f:
                for question_result in remaining_results['target_questions']:
                    f.write(json.dumps(question_result, ensure_ascii=False) + '\n')
            stats_file = f"{output_file}_analysis_summary.json"
            with open(stats_file, 'w', encoding='utf-8') as f:
                json.dump(final_results, f, indent=2, ensure_ascii=False)
        self._print_final_summary(kl1, kl2, len(remaining_results['target_questions']))
        return final_results

    def _print_final_summary(self, kl1: float, kl2: float, collected_count: int):
        """Hook for reporting the final summary; currently prints nothing."""
        pass
def main():
    """Command-line entry point: parse arguments, validate inputs, run the analysis.

    Returns:
        0 on success, or 1 if building the analyzer or running the analysis
        raised (the traceback is printed to stderr).

    Raises:
        FileNotFoundError: If a required input path does not exist.
        SystemExit: From argparse on invalid command-line arguments,
            including a non-positive ``--batch_size``.
    """
    parser = argparse.ArgumentParser(description='BEiT3 HUD Analysis')
    parser.add_argument('--sft_checkpoint', type=str, required=True)
    parser.add_argument('--sentencepiece_path', type=str, required=True)
    parser.add_argument('--data_path', type=str, required=True)
    parser.add_argument('--hu_low_file', type=str, required=True)
    parser.add_argument('--hu_mid_file', type=str, required=True)
    parser.add_argument('--hud_jsonl', type=str, required=True)
    parser.add_argument('--training_json_path', type=str, required=True)
    parser.add_argument('--override_hud_jsonl', type=str, default=None)
    parser.add_argument('--model_size', type=str, default='base', choices=['base', 'large'])
    parser.add_argument('--input_size', type=int, default=480, choices=[384, 480, 768])
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--output_file', type=str, default='beit3_hud_analysis')
    parser.add_argument('--target_limit', type=int, default=4437)
    parser.add_argument('--output_kl_results', type=str, default=None)
    args = parser.parse_args()
    # Reject an unusable batch size before any file checks or model loading.
    if args.batch_size < 1:
        parser.error('--batch_size must be a positive integer')
    required_files = [
        (args.sft_checkpoint, 'SFT checkpoint'),
        (args.sentencepiece_path, 'Tokenizer'),
        (args.data_path, 'Data path'),
        (args.hu_low_file, 'hu_low classification file'),
        (args.hu_mid_file, 'hu_mid classification file'),
        (args.hud_jsonl, 'HUD JSONL file'),
        (args.training_json_path, 'Training JSON file')
    ]
    for path, name in required_files:
        if not os.path.exists(path):
            raise FileNotFoundError(f"File not found: {name} - {path}")
    if args.override_hud_jsonl and not os.path.exists(args.override_hud_jsonl):
        raise FileNotFoundError(f"Override HUD file not found: {args.override_hud_jsonl}")
    try:
        analyzer = BEiT3HUDAnalyzer(
            sft_checkpoint_path=args.sft_checkpoint,
            sentencepiece_path=args.sentencepiece_path,
            model_size=args.model_size,
            input_size=args.input_size,
            device=args.device,
            batch_size=args.batch_size,
            override_hud_jsonl=args.override_hud_jsonl
        )
        analyzer.run_full_analysis(
            data_path=args.data_path,
            hu_low_path=args.hu_low_file,
            hu_mid_path=args.hu_mid_file,
            hud_jsonl_path=args.hud_jsonl,
            training_json_path=args.training_json_path,
            output_file=args.output_file,
            target_limit=args.target_limit,
            output_kl_results_path=args.output_kl_results
        )
        return 0
    except Exception:
        # Top-level boundary: report the failure and signal it via the exit status.
        import traceback
        traceback.print_exc()
        return 1
if __name__ == "__main__":
    # sys.exit is always available; the bare `exit` builtin is injected by the
    # site module and is missing under `python -S` or in frozen executables.
    sys.exit(main())