#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import json
import torch
import argparse
import numpy as np
import sys
from pathlib import Path
from tqdm import tqdm
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
from PIL import Image
from transformers import AutoTokenizer, BitsAndBytesConfig
import warnings

# Suppress every Python warning for the whole process (affects all libraries imported below).
warnings.filterwarnings("ignore")
# Prepend /data to the import path so the `llava` imports below resolve; presumably a local
# LLaVA checkout lives there — TODO confirm / make configurable.
sys.path.insert(0, "/data")

from llava.model import LlavaMistralForCausalLM
from llava.conversation import conv_templates, SeparatorStyle
from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

class LLaVA16HUDAnalyzer:
    """Select VQA questions whose HUD-vs-model KL divergence lies in a learned interval.

    Pipeline (driven by ``run_full_analysis``):
      1. For the ``hu_low`` and ``hu_mid`` question-id groups, compute the mean KL
         divergence between each question's normalized top-2 HUD scores and the
         model's softmax over its two largest next-token logits.
      2. Use ``[min, max]`` of those two means as an interval and scan the dataset
         questions not in either group, keeping those whose KL falls inside it
         (at most ``target_limit`` of them).
      3. Append the kept samples to an existing training JSON and write a summary.
    """

    def __init__(self, model_path: str, device: str = 'cuda', load_in_8bit: bool = False,
                 override_hud_jsonl: Optional[str] = None):
        """Load the model and, optionally, an override HUD JSONL.

        Args:
            model_path: Checkpoint path passed to ``from_pretrained``.
            device: Device the prompt token ids are moved to before the forward pass.
            load_in_8bit: Load the language model with bitsandbytes 8-bit quantization.
            override_hud_jsonl: Optional JSONL with per-question ``top2_probs`` /
                ``top4_probs`` (see ``_load_override_hud_jsonl``).
        """
        self.device = device
        self.load_in_8bit = load_in_8bit
        self._init_model(model_path)
        self.vqa_prompt_template = "{question}"
        # NOTE(review): this mapping is loaded but no method in this class reads it;
        # HUD scores always come from the ``hud_mapping`` argument. Confirm intent.
        self.override_hud_mapping = {}
        if override_hud_jsonl:
            self.override_hud_mapping = self._load_override_hud_jsonl(override_hud_jsonl)
        print("LLaVA-v1.6 Analyzer initialized")

    def _init_model(self, model_path: str):
        """Load model, tokenizer and vision tower, and pick a conversation template."""
        load_kwargs = dict(low_cpu_mem_usage=True, device_map="auto", torch_dtype=torch.bfloat16)
        if getattr(self, 'load_in_8bit', False):
            # Honour the constructor flag (it used to be stored and silently ignored).
            load_kwargs['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
        self.model = LlavaMistralForCausalLM.from_pretrained(model_path, **load_kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        vision_tower = self.model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model()
        # Keep the vision tower where device_map placed it; image tensors are sent there.
        self.vision_device = next(vision_tower.parameters()).device
        vision_tower.to(device=self.vision_device, dtype=torch.bfloat16)
        self.image_processor = vision_tower.image_processor
        # First available template in preference order, else any registered template.
        self.conv_template = None
        for mode in ["mistral_instruct", "vicuna_v1", "llava_v1"]:
            if mode in conv_templates:
                self.conv_template = conv_templates[mode]
                break
        if self.conv_template is None:
            self.conv_template = list(conv_templates.values())[0]

    def _load_override_hud_jsonl(self, jsonl_path: str) -> Dict[int, List[float]]:
        """Load ``{question_id: [p1, p2]}`` from a JSONL file, skipping unusable lines.

        Each record needs ``question_id`` and either ``top2_probs`` or, as a fallback,
        ``top4_probs`` (only the first two entries are used). A single probability is
        right-padded with ``0.0``. Blank, malformed, empty or non-numeric records are
        skipped rather than aborting the load.
        """
        override_mapping = {}
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                    question_id = int(record['question_id'])
                    # Only fall back to top4_probs when top2_probs is genuinely absent.
                    if 'top2_probs' in record:
                        probs = record['top2_probs']
                    else:
                        probs = record.get('top4_probs', [])
                    probs = [float(p) for p in probs[:2]]
                except (KeyError, TypeError, ValueError, AttributeError):
                    # ValueError also covers json.JSONDecodeError and int()/float() failures.
                    continue
                if not probs:
                    continue
                probs.extend([0.0] * (2 - len(probs)))
                override_mapping[question_id] = probs
        return override_mapping

    def _get_hud_scores_and_answers(self, question_id: int, hud_mapping: Dict[int, Dict]) -> Optional[
        Tuple[List[float], List[str]]]:
        """Return the first two ``hud_scores`` and ``hud_answers`` for a question.

        Returns None if the question is missing or either list has fewer than two items.
        """
        question_id = int(question_id)
        if question_id in hud_mapping:
            hud_record = hud_mapping[question_id]
            original_scores = hud_record.get('hud_scores', [])
            original_answers = hud_record.get('hud_answers', [])
            if len(original_scores) >= 2 and len(original_answers) >= 2:
                return original_scores[:2], original_answers[:2]
        return None

    @staticmethod
    def _warn(message: str):
        """Print a non-fatal warning to stderr without corrupting tqdm progress bars."""
        tqdm.write(f"[warn] {message}", file=sys.stderr)

    @staticmethod
    def _resolve_image_path(image_path: str, data_path: str = "") -> str:
        """Map a dataset image name to a filesystem path.

        Absolute paths, or any path when ``data_path`` is empty, are returned unchanged.
        COCO-2014 file names are routed into ``train2014``/``val2014`` sub-folders of
        ``data_path``; everything else is joined directly onto ``data_path``.
        """
        if not data_path or os.path.isabs(image_path):
            return image_path
        if image_path.startswith('COCO_train2014_'):
            return os.path.join(data_path, 'train2014', image_path)
        if image_path.startswith('COCO_val2014_'):
            return os.path.join(data_path, 'val2014', image_path)
        return os.path.join(data_path, image_path)

    def get_top_two_logits(self, image_path: str, question: str, data_path: str = "") -> Optional[List[float]]:
        """Return the two largest logits of the first answer token for (image, question).

        Returns None when the image file does not exist, or when preprocessing / the
        forward pass raises. In the latter case a warning is printed, so systematic
        failures (e.g. CUDA OOM) are visible instead of silently yielding no results.
        """
        try:
            full_image_path = self._resolve_image_path(image_path, data_path)
            if not os.path.exists(full_image_path):
                return None
            # Context manager closes the underlying file; convert() returns a loaded copy.
            with Image.open(full_image_path) as raw_image:
                image = raw_image.convert('RGB')
            # NOTE(review): LLaVA-v1.6 is usually fed anyres tiles via
            # llava.mm_utils.process_images (+ image_sizes); a single preprocess()
            # call yields one resized view. Confirm this is intended.
            img_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'].to(
                dtype=torch.bfloat16, device=self.vision_device)
            formatted_question = self.vqa_prompt_template.format(question=question)
            # Only wrap the image placeholder with start/end tokens if the model was
            # configured (and trained) to use them; otherwise they become stray text.
            if getattr(self.model.config, 'mm_use_im_start_end', False):
                image_prefix = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
            else:
                image_prefix = DEFAULT_IMAGE_TOKEN
            conv = self.conv_template.copy()
            roles = conv.roles
            conv.append_message(roles[0], image_prefix + '\n' + formatted_question)
            conv.append_message(roles[1], None)  # open assistant turn -> score next token
            raw_prompt = conv.get_prompt()
            input_ids = tokenizer_image_token(
                raw_prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
            ).unsqueeze(0).to(self.device)
            with torch.inference_mode():
                outputs = self.model(input_ids=input_ids, images=img_tensor, return_dict=True)
            last_token_logits = outputs.logits[0, -1, :]
            return torch.topk(last_token_logits, 2).values.tolist()
        except Exception as exc:  # best-effort per sample; one failure must not abort the scan
            self._warn(f"failed to score image {image_path!r}: {exc!r}")
            return None

    @staticmethod
    def _extract_question(sample: dict) -> Optional[str]:
        """Return the first human turn's text with ``<image>`` removed, or None if empty/absent."""
        for turn in sample.get('conversations', []):
            if turn['from'] == 'human':
                question = turn['value'].replace('<image>', '').strip()
                return question or None
        return None

    def calculate_answer_distribution(self, sample: dict, data_path: str = "") -> Optional[List[float]]:
        """Return the softmax over the model's top-2 first-token logits for a sample, or None."""
        try:
            question = self._extract_question(sample)
            if not question:
                return None
            top_two_logits = self.get_top_two_logits(sample.get('image', ''), question, data_path)
            if top_two_logits is None:
                return None
            return F.softmax(torch.tensor(top_two_logits), dim=-1).tolist()
        except Exception as exc:  # malformed sample: skip it but report why
            self._warn(f"could not build answer distribution: {exc!r}")
            return None

    @staticmethod
    def _kl_divergence(p: List[float], q: List[float], epsilon: float = 1e-8) -> float:
        """Return KL(p || q) with ``epsilon`` added inside the log ratio to avoid log(0)."""
        return sum(pi * np.log((pi + epsilon) / (qi + epsilon)) for pi, qi in zip(p, q))

    def _collect_hud_samples(self, qids: List[int], qid_to_idx: Dict[int, int],
                             hud_mapping: Dict[int, Dict]) -> List[Tuple[int, List[float], List[str]]]:
        """Return ``(qid, hud_scores, hud_answers)`` for qids present in both dataset and HUD mapping."""
        samples = []
        for qid in qids:
            if qid in qid_to_idx and qid in hud_mapping:
                hud_data = self._get_hud_scores_and_answers(qid, hud_mapping)
                if hud_data:
                    samples.append((qid, hud_data[0], hud_data[1]))
        return samples

    def _score_sample(self, sample: dict, hud_scores: List[float], data_path: str
                      ) -> Optional[Tuple[List[float], List[float], float]]:
        """Return ``(hud_dist, model_dist, kl)`` for one sample, or None if it cannot be scored."""
        hud_sum = sum(hud_scores)
        if hud_sum == 0:
            # Cannot normalize; skip before paying for a model forward pass.
            return None
        model_dist = self.calculate_answer_distribution(sample, data_path)
        if model_dist is None:
            return None
        hud_dist = [s / hud_sum for s in hud_scores]
        kl = self._kl_divergence(hud_dist, model_dist)
        if not np.isfinite(kl):
            return None
        return hud_dist, model_dist, kl

    def load_classification_files(self, hu_low_path: str, hu_mid_path: str) -> Tuple[Dict[str, List[int]], set]:
        """Load the ``hu_low``/``hu_mid`` question-id lists and their union.

        Each file is either ``{"annotations": [{"question_id": ...}, ...]}`` or a plain
        JSON list of ids; any other shape yields an empty group.
        """
        classifications, training_set = {}, set()
        file_mapping = {'hu_low': hu_low_path, 'hu_mid': hu_mid_path}
        for name, path in file_mapping.items():
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if isinstance(data, dict) and 'annotations' in data:
                question_ids = [int(ann['question_id']) for ann in data['annotations']]
            else:
                question_ids = [int(qid) for qid in data] if isinstance(data, list) else []
            classifications[name] = question_ids
            training_set.update(question_ids)
        return classifications, training_set

    def load_hud_jsonl(self, hud_jsonl_path: str) -> Dict[int, Dict]:
        """Load HUD records keyed by integer ``question_id`` (later duplicates win; blank lines skipped)."""
        hud_mapping = {}
        with open(hud_jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                record = json.loads(line)
                hud_mapping[int(record['question_id'])] = record
        return hud_mapping

    def load_llava_dataset(self, data_path: str) -> Tuple[List[Dict], Dict[int, int]]:
        """Load the dataset JSON list and map each item's integer id (or index) to its position."""
        with open(data_path, 'r', encoding='utf-8') as f:
            dataset = json.load(f)
        qid_to_idx = {int(item.get('id', i)): i for i, item in enumerate(dataset)}
        return dataset, qid_to_idx

    def process_classification_questions(self, dataset: List[Dict], qid_to_idx: Dict[int, int],
                                         classifications: Dict[str, List[int]], hud_mapping: Dict[int, Dict],
                                         data_path: str = "") -> Dict:
        """Compute the mean HUD-vs-model KL divergence per classification group.

        Returns:
            ``{'kl_thresholds': {class_name: mean_kl}}``; groups with no scorable
            question are omitted.
        """
        results = {'kl_thresholds': {}}
        for class_name, qids in classifications.items():
            valid_samples = self._collect_hud_samples(qids, qid_to_idx, hud_mapping)
            if not valid_samples:
                continue
            kl_divergences = []
            for qid, hud_scores, _answers in tqdm(valid_samples, desc=f"Computing KL for {class_name}"):
                scored = self._score_sample(dataset[qid_to_idx[qid]], hud_scores, data_path)
                if scored is not None:
                    kl_divergences.append(scored[2])
            if kl_divergences:
                results['kl_thresholds'][class_name] = np.mean(kl_divergences)
        return results

    def find_remaining_questions(self, training_set: set, qid_to_idx: Dict[int, int]) -> List[int]:
        """Return dataset qids not in ``training_set``, in dataset order.

        Dataset order (rather than arbitrary set order) makes the subset picked under
        ``target_limit`` reproducible and explainable.
        """
        return [qid for qid in qid_to_idx if qid not in training_set]

    def process_remaining_questions_targeted(self, dataset: List[Dict], qid_to_idx: Dict[int, int],
                                             remaining_qids: List[int], hud_mapping: Dict[int, Dict],
                                             data_path: str, kl_interval: Tuple[float, float],
                                             target_limit: int, output_kl_results_path: str = None) -> Dict:
        """Collect up to ``target_limit`` questions whose KL lies within ``kl_interval`` (inclusive).

        If ``output_kl_results_path`` is set, one JSON line per scored question is
        written there, including questions outside the interval.
        """
        results = {'target_questions': []}
        targets = results['target_questions']
        valid_samples = self._collect_hud_samples(remaining_qids, qid_to_idx, hud_mapping)

        with (open(output_kl_results_path, 'w', encoding='utf-8') if output_kl_results_path
              else open(os.devnull, 'w')) as kl_file:
            for qid, hud_scores, hud_answers in tqdm(valid_samples, desc="Filtering questions"):
                if len(targets) >= target_limit:
                    break
                sample = dataset[qid_to_idx[qid]]
                scored = self._score_sample(sample, hud_scores, data_path)
                if scored is None:
                    continue
                hud_dist, model_dist, kl = scored
                if output_kl_results_path:
                    gt_answer = next((c['value'] for c in sample.get('conversations', []) if c['from'] == 'gpt'), "N/A")
                    kl_file.write(json.dumps({"question_id": qid, "hud_answers": hud_answers, "hud_dist": hud_dist, "model_dist": model_dist, "kl_divergence": kl, "ground_truth": gt_answer}) + '\n')
                if kl_interval[0] <= kl <= kl_interval[1]:
                    targets.append({'question_id': qid, 'image_path': sample.get('image'), 'top_answer': hud_answers[0], 'kl_divergence': kl})
        return results

    @staticmethod
    def _summary_path(output_file: str) -> str:
        """Return ``<output_file without a trailing .json>_analysis_summary.json``."""
        base_name = output_file[:-len('.json')] if output_file.endswith('.json') else output_file
        return f"{base_name}_analysis_summary.json"

    def run_full_analysis(self, data_path: str, hu_low_path: str, hu_mid_path: str,
                          hud_jsonl_path: str, training_json_path: str, output_file: str,
                          target_limit: int, output_kl_results_path: str = None):
        """Run the full pipeline and, if ``output_file`` is set, write summary and combined training data."""
        classifications, training_set = self.load_classification_files(hu_low_path, hu_mid_path)
        hud_mapping = self.load_hud_jsonl(hud_jsonl_path)
        dataset, qid_to_idx = self.load_llava_dataset(data_path)
        # Images are resolved relative to the directory containing the dataset file.
        data_dir = os.path.dirname(os.path.abspath(data_path)) if os.path.isfile(data_path) else data_path
        classification_results = self.process_classification_questions(dataset, qid_to_idx, classifications, hud_mapping, data_dir)
        kl1 = classification_results.get('kl_thresholds', {}).get('hu_low', 0.0)
        kl2 = classification_results.get('kl_thresholds', {}).get('hu_mid', 0.0)
        # A group with no scorable question contributes 0.0; if both are missing use [0, 1].
        kl_interval = (min(kl1, kl2), max(kl1, kl2)) if kl1 != 0.0 or kl2 != 0.0 else (0.0, 1.0)

        remaining_qids = self.find_remaining_questions(training_set, qid_to_idx)
        remaining_results = self.process_remaining_questions_targeted(
            dataset, qid_to_idx, remaining_qids, hud_mapping, data_dir, kl_interval, target_limit, output_kl_results_path
        )

        if output_file:
            stats_file = self._summary_path(output_file)
            final_results = {'kl_thresholds': {'hu_low': kl1, 'hu_mid': kl2}, 'kl_interval': kl_interval, 'target_count': len(remaining_results['target_questions'])}
            with open(stats_file, 'w', encoding='utf-8') as f:
                json.dump(final_results, f, indent=2)
            print(f"\nSaved summary to: {stats_file}")

            with open(training_json_path, 'r', encoding='utf-8') as f:
                existing_training_data = json.load(f)
            new_training_samples = []
            for item in remaining_results['target_questions']:
                qid = item['question_id']
                if qid in qid_to_idx:
                    original_sample = dataset[qid_to_idx[qid]]
                    new_training_samples.append({"id": str(qid), "image": original_sample.get('image', ''), "conversations": original_sample.get('conversations', [])})

            combined_training_data = existing_training_data + new_training_samples
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(combined_training_data, f, indent=2)
            print(f"Saved combined training data to: {output_file}")
            print(f"Total samples: {len(combined_training_data)} (original: {len(existing_training_data)} + new: {len(new_training_samples)})")

        print("\n" + "=" * 80)
        print("LLaVA-v1.6 Analysis Complete")
        print(f"KL Thresholds: hu_low={kl1:.6f}, hu_mid={kl2:.6f}")
        print(f"Collected {len(remaining_results['target_questions'])} questions")
        print("=" * 80)

def main():
    """Command-line entry point: parse arguments and run the LLaVA-v1.6 HUD analysis."""
    cli = argparse.ArgumentParser(description='LLaVA-v1.6 VQA HUD Analysis')
    # Every input path is a mandatory string option.
    for flag in ('--model_path', '--data_path', '--hu_low_file', '--hu_mid_file',
                 '--hud_jsonl', '--training_json_path'):
        cli.add_argument(flag, type=str, required=True)
    cli.add_argument('--output_file', type=str, default='llava_hud_analysis')
    cli.add_argument('--target_limit', type=int, default=4437)
    cli.add_argument('--output_kl_results', type=str, default=None)
    opts = cli.parse_args()

    analyzer = LLaVA16HUDAnalyzer(model_path=opts.model_path, device='cuda', load_in_8bit=False)
    pipeline_kwargs = dict(
        data_path=opts.data_path,
        hu_low_path=opts.hu_low_file,
        hu_mid_path=opts.hu_mid_file,
        hud_jsonl_path=opts.hud_jsonl,
        training_json_path=opts.training_json_path,
        output_file=opts.output_file,
        target_limit=opts.target_limit,
        output_kl_results_path=opts.output_kl_results,
    )
    analyzer.run_full_analysis(**pipeline_kwargs)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()