#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import json
import logging
import torch
import torch.nn.functional as F
from typing import Dict, Optional, Any
from llava.train.llava_trainer import LLaVATrainer
from llava.model import *

class QuestionIDResolver:
    """Translate dataset sample ids into integer question ids.

    When ``full_vqa_path`` names an existing file, the file is parsed as JSON
    and an ``id -> {'question', 'image'}`` lookup is stored in
    ``id_question_mapping``.  The resolver is ``enabled`` whenever that file
    exists, even if it could not be parsed.  Resolution itself only converts
    the sample id with ``int()``; it does not consult the lookup.
    """

    def __init__(self, full_vqa_path: Optional[str] = None):
        """Build the lookup if ``full_vqa_path`` points at an existing file.

        Args:
            full_vqa_path: Path to the full VQA JSON file, or ``None`` to
                leave the resolver disabled.
        """
        self.full_vqa_path = full_vqa_path
        self.id_question_mapping = {}
        self.enabled = False
        if full_vqa_path and os.path.exists(full_vqa_path):
            self._build_mapping()
            self.enabled = True

    def _build_mapping(self):
        """Populate ``id_question_mapping`` from the JSON file.

        Load problems are logged and leave the mapping empty.  Entries that
        are not dicts, or that lack an id or a human question, are skipped
        instead of aborting construction.
        """
        try:
            with open(self.full_vqa_path, 'r', encoding='utf-8') as f:
                full_data = json.load(f)
        except (OSError, ValueError) as e:
            # json.JSONDecodeError and UnicodeDecodeError are ValueErrors.
            logging.error(f"Error loading full VQA data: {e}")
            return
        if not isinstance(full_data, list):
            logging.error(
                "Error loading full VQA data: expected a JSON list, got %s",
                type(full_data).__name__,
            )
            return
        for item in full_data:
            if not isinstance(item, dict):
                continue
            sample_id = item.get('id', '')
            question = self._extract_question(item.get('conversations', []))
            if sample_id and question:
                self.id_question_mapping[sample_id] = {
                    'question': question.lower().strip(),
                    'image': item.get('image', '')
                }

    def _extract_question(self, conversations):
        """Return the first human turn's text with ``<image>`` removed.

        Returns an empty string when ``conversations`` is not a list/tuple or
        contains no human turn with a string ``value``.
        """
        if not isinstance(conversations, (list, tuple)):
            return ""
        for conv in conversations:
            if not isinstance(conv, dict) or conv.get('from') != 'human':
                continue
            value = conv.get('value', '')
            if isinstance(value, str):
                return value.replace('<image>', '').strip()
        return ""

    def resolve_question_id(self, sample_id: str) -> int:
        """Convert ``sample_id`` to an int question id.

        Returns:
            ``int(sample_id)``, or ``-1`` when the resolver is disabled or the
            id cannot be converted.
        """
        if not self.enabled:
            return -1
        try:
            return int(sample_id)
        except (ValueError, TypeError, OverflowError):
            return -1

class HUDProcessor:
    """Tokenize HUD answer candidates and score them against model logits."""

    def __init__(self, tokenizer):
        """Store the tokenizer; it must provide ``encode(text, add_special_tokens=...)``."""
        self.tokenizer = tokenizer

    def tokenize_hud_answers(self, hud_answers, hud_scores):
        """Tokenize each non-empty string answer, keeping scores aligned.

        Answers that are not strings, are blank after stripping, or tokenize
        to nothing are dropped together with their score.  Pairing uses
        ``zip``, so surplus items in the longer input are ignored.

        Returns:
            ``(tokenized_answers, valid_scores)`` as two parallel lists.
        """
        tokenized_answers = []
        valid_scores = []
        for answer, score in zip(hud_answers, hud_scores):
            if not isinstance(answer, str):
                continue
            text = answer.strip()
            if not text:
                continue
            tokens = self.tokenizer.encode(text, add_special_tokens=False)
            if len(tokens) > 0:
                tokenized_answers.append(tokens)
                valid_scores.append(score)
        return tokenized_answers, valid_scores

    def compute_answer_sequence_probability(self, logits, answer_tokens, seq_len=None):
        """Return the summed log-probability of ``answer_tokens``.

        Despite the name, the result is a *log*-probability.  The answer is
        scored on the last ``len(answer_tokens)`` positions before
        ``seq_len`` (or before the end of ``logits`` when ``seq_len`` is
        ``None``): token ``i`` is read from the softmax at row
        ``start + i``.  Rows beyond the end of ``logits`` are ignored.

        Args:
            logits: 2-D tensor indexed as ``[position, vocab_id]``.
            answer_tokens: Sequence of token ids.
            seq_len: Optional effective sequence length.

        Returns:
            A 0-dim tensor on ``logits.device``; ``0.0`` when there are no
            answer tokens or no scored position lies inside ``logits``.
        """
        if not answer_tokens:
            return torch.tensor(0.0, device=logits.device)
        num_rows = len(logits)
        # int() so that a float or 0-dim tensor length still yields valid indices.
        effective_len = int(seq_len) if seq_len is not None else num_rows
        start_pos = max(0, effective_len - len(answer_tokens))
        end_pos = min(start_pos + len(answer_tokens), num_rows)
        if end_pos <= start_pos:
            # Always hand back a tensor so callers can apply torch ops to it.
            return torch.tensor(0.0, device=logits.device)
        count = end_pos - start_pos
        token_ids = torch.tensor(
            list(answer_tokens[:count]), dtype=torch.long, device=logits.device
        )
        # Softmax is row-wise, so normalising only the rows that are read
        # gives the same values as normalising the whole sequence.
        probs = F.softmax(logits[start_pos:end_pos], dim=-1)
        rows = torch.arange(count, device=logits.device)
        token_probs = probs[rows, token_ids]
        # NOTE(review): in float16 the 1e-8 epsilon rounds to zero, so an
        # underflowed probability would still give -inf — confirm dtype in use.
        return torch.log(token_probs + 1e-8).sum()

class LLaVAThreePartTrainer(LLaVATrainer):
    """LLaVA trainer whose loss is ``A + beta * B + lambda_param * C``.

    * Part A: the loss returned by the model being trained.
    * Part B: label-masked KL term between the trained model and a frozen
      base model (``compute_part_b_loss``).
    * Part C: HUD answer-distribution term measured relative to the base
      model (``compute_part_c_loss``).

    Settings are read from the ``data_args`` keyword argument: ``beta``,
    ``lambda_param``, ``base_model_path``, ``hud_data_path`` and
    ``full_vqa_path``.  Parts B and C are only computed when a base model
    was loaded successfully.
    """

    def __init__(self, *args, **kwargs):
        """Pop ``data_args`` from ``kwargs``, init the parent, then load helpers."""
        self.data_args = kwargs.pop('data_args', None)
        # A falsy data_args is treated like a missing one; getattr's default
        # then supplies every fallback value.
        cfg = self.data_args if self.data_args else None
        self.beta = getattr(cfg, 'beta', 0.0)
        self.lambda_param = getattr(cfg, 'lambda_param', 0.0)
        self.base_model_path = getattr(cfg, 'base_model_path', None)
        self.hud_data_path = getattr(cfg, 'hud_data_path', None)
        self.full_vqa_path = getattr(cfg, 'full_vqa_path', None)
        # The parent must be initialised first: the helpers below read
        # self.args and self.tokenizer.
        super().__init__(*args, **kwargs)
        self._init_base_model()
        self._load_hud_data()
        self._init_question_resolver()
        self._init_hud_processor()
        self.loss_logs = []

    def _init_base_model(self):
        """Load the frozen reference model, or leave ``base_model`` as ``None``.

        Nothing is loaded when ``base_model_path`` is unset or missing on
        disk.  A load failure is logged and also results in ``None``.
        """
        self.base_model = None
        if not self.base_model_path or not os.path.exists(self.base_model_path):
            return
        try:
            # NOTE(review): device_map="auto" lets the loader pick placement;
            # confirm this suits the distributed setup in use.
            self.base_model = LlavaMistralForCausalLM.from_pretrained(
                self.base_model_path,
                torch_dtype=torch.bfloat16 if self.args.bf16 else torch.float16,
                device_map="auto"
            )
            self.base_model.eval()
            for param in self.base_model.parameters():
                param.requires_grad = False
        except Exception as e:
            logging.error(f"Failed to load base model: {e}")
            self.base_model = None

    def _load_hud_data(self):
        """Read the HUD JSONL file into ``hud_mapping`` keyed by int question id.

        Each non-blank line is one JSON record with ``question_id`` and
        optional ``hud_answers`` / ``hud_scores``.  Any parsing error is
        logged and resets the mapping to empty.
        """
        self.hud_mapping = {}
        if not self.hud_data_path or not os.path.exists(self.hud_data_path):
            return
        try:
            with open(self.hud_data_path, 'r', encoding='utf-8') as f:
                for raw_line in f:
                    line = raw_line.strip()
                    if not line:
                        # A stray blank line must not discard the whole file.
                        continue
                    record = json.loads(line)
                    question_id = int(record['question_id'])
                    self.hud_mapping[question_id] = {
                        'hud_answers': record.get('hud_answers', []),
                        'hud_scores': record.get('hud_scores', [])
                    }
        except Exception as e:
            logging.error(f"Error loading HUD data: {e}")
            self.hud_mapping = {}

    def _init_question_resolver(self):
        """Create the sample-id to question-id resolver."""
        self.question_id_resolver = QuestionIDResolver(self.full_vqa_path)

    def _init_hud_processor(self):
        """Create the HUD answer tokenizer/scorer around ``self.tokenizer``."""
        self.hud_processor = HUDProcessor(self.tokenizer)

    def compute_part_b_loss(self, current_logits, base_logits, labels):
        """Mean per-position KL(current || base) over labelled positions.

        ``F.kl_div(input, target)`` with ``log_target=False`` evaluates
        ``target * (log(target) - input)``; here ``input`` is the base
        model's log-probabilities and ``target`` the trained model's
        probabilities.  Positions whose label is ``-100`` are excluded.

        Returns:
            A scalar tensor; ``0.0`` when no position is labelled.
        """
        # NOTE(review): the mask uses `labels` at the same positions as the
        # logits (no one-token shift) and needs labels and logits to share a
        # sequence length — confirm against the model's output layout.
        mask = (labels != -100).float()
        if mask.sum() == 0:
            return torch.tensor(0.0, device=current_logits.device)
        current_probs = F.softmax(current_logits, dim=-1)
        base_log_probs = F.log_softmax(base_logits, dim=-1)
        kl_div = F.kl_div(base_log_probs, current_probs, reduction='none', log_target=False).sum(-1)
        masked_kl = kl_div * mask
        return masked_kl.sum() / mask.sum()

    def compute_part_c_loss(self, current_logits, base_logits, question_ids, inputs: Dict[str, Any]):
        """Average, over samples with HUD data, of KL_current - KL_base.

        For each sample whose question id is in ``hud_mapping``, the HUD
        scores are normalised into a distribution over the tokenizable
        answers.  Each answer is scored under both models and
        ``sum_a p(a) * (log p(a) - log q(a))`` is accumulated for the trained
        and the base model; their difference is the sample's loss.

        Samples that raise or give a non-finite value are skipped, with a
        warning logged for exceptions.

        Returns:
            A scalar tensor; ``0.0`` when ``lambda_param`` is 0, no HUD data
            is loaded, or no sample contributed.
        """
        device = current_logits.device
        if self.lambda_param == 0 or not self.hud_mapping:
            return torch.tensor(0.0, device=device)
        part_c_losses = []
        attention_mask = inputs.get('attention_mask')
        # Plain Python numbers in both branches, so the per-sample lookup
        # below works with or without an attention mask.
        if attention_mask is not None:
            seq_lengths = attention_mask.sum(dim=1).tolist()
        else:
            seq_lengths = [len(sample_logits) for sample_logits in current_logits]
        for i, qid in enumerate(question_ids):
            try:
                qid_int = int(qid.item()) if torch.is_tensor(qid) else int(qid)
                if qid_int == -1 or qid_int not in self.hud_mapping:
                    continue
                hud_data = self.hud_mapping[qid_int]
                if not hud_data.get('hud_answers') or not hud_data.get('hud_scores'):
                    continue
                tokenized_answers, valid_scores = self.hud_processor.tokenize_hud_answers(
                    hud_data['hud_answers'], hud_data['hud_scores']
                )
                if not tokenized_answers:
                    continue
                valid_sum = sum(valid_scores)
                if valid_sum <= 0:
                    continue
                hud_dist = [s / valid_sum for s in valid_scores]
                sample_current_logits = current_logits[i]
                sample_base_logits = base_logits[i]
                actual_seq_len = int(seq_lengths[i])
                kl_current_total = 0.0
                kl_base_total = 0.0
                for answer_tokens, hud_prob in zip(tokenized_answers, hud_dist):
                    if hud_prob <= 0:
                        continue
                    current_log_prob = self.hud_processor.compute_answer_sequence_probability(
                        sample_current_logits, answer_tokens, actual_seq_len
                    )
                    base_log_prob = self.hud_processor.compute_answer_sequence_probability(
                        sample_base_logits, answer_tokens, actual_seq_len
                    )
                    # Floor the sequence probability so the log stays finite.
                    current_prob = torch.exp(current_log_prob).clamp(min=1e-8)
                    base_prob = torch.exp(base_log_prob).clamp(min=1e-8)
                    log_hud_prob = torch.log(torch.tensor(hud_prob, device=device))
                    kl_current_total += hud_prob * (log_hud_prob - torch.log(current_prob))
                    kl_base_total += hud_prob * (log_hud_prob - torch.log(base_prob))
                relative_improvement = kl_current_total - kl_base_total
                if torch.isfinite(relative_improvement):
                    part_c_losses.append(relative_improvement)
            except Exception as e:
                # Best effort: one bad sample must not abort the training step.
                logging.warning(f"Error computing Part C loss for qid {qid}: {e}")
                continue
        if not part_c_losses:
            return torch.tensor(0.0, device=device)
        return torch.stack(part_c_losses).mean()

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute ``A + beta * B + lambda_param * C`` for one batch.

        ``sample_ids`` is removed from ``inputs`` before the forward pass
        and resolved to question ids for Part C.  The base model runs under
        ``no_grad`` only when Part B or Part C is active.  Any failure in
        that branch is logged and the affected parts stay at zero.

        Args:
            model: The model being trained.
            inputs: Batch dict; may carry ``sample_ids`` and ``labels``.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted so that Trainer versions which pass
                this keyword can call the override; not used here.

        Returns:
            The total loss, or ``(total_loss, outputs)``.
        """
        sample_ids = inputs.pop('sample_ids', None)
        labels = inputs.get('labels')
        outputs = model(**inputs)
        part_a_loss = outputs.loss if outputs.loss is not None else torch.tensor(0.0, device=self.args.device)
        part_b_loss = torch.tensor(0.0, device=self.args.device)
        part_c_loss = torch.tensor(0.0, device=self.args.device)
        # Explicit None/length test: truthiness of a multi-element tensor raises.
        has_sample_ids = sample_ids is not None and len(sample_ids) > 0
        question_ids = (
            [self.question_id_resolver.resolve_question_id(sid) for sid in sample_ids]
            if has_sample_ids else None
        )
        need_part_b = self.beta > 0
        need_part_c = self.lambda_param > 0 and bool(question_ids)
        if self.base_model is not None and (need_part_b or need_part_c):
            try:
                with torch.no_grad():
                    base_outputs = self.base_model(**inputs)
                if hasattr(outputs, 'logits') and hasattr(base_outputs, 'logits'):
                    current_logits = outputs.logits
                    # The base model may have been placed on another device.
                    base_logits = base_outputs.logits.to(current_logits.device)
                    if current_logits.shape != base_logits.shape:
                        # Compare only the sequence prefix both models produced.
                        min_len = min(current_logits.size(1), base_logits.size(1))
                        current_logits = current_logits[:, :min_len, :]
                        base_logits = base_logits[:, :min_len, :]
                        if labels is not None:
                            labels = labels[:, :min_len]
                    if self.beta > 0 and labels is not None:
                        part_b_loss = self.compute_part_b_loss(current_logits, base_logits, labels)
                    if self.lambda_param > 0 and question_ids is not None:
                        part_c_loss = self.compute_part_c_loss(current_logits, base_logits, question_ids, inputs)
            except Exception as e:
                logging.error(f"Error during base model inference or Part B/C loss calculation: {e}", exc_info=True)
        total_loss = part_a_loss + self.beta * part_b_loss + self.lambda_param * part_c_loss
        return (total_loss, outputs) if return_outputs else total_loss