#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from src.vlm.train.sft.trainer import CustomSeq2SeqTrainer
import torch
import torch.nn.functional as F
import logging
import os
import json
from typing import Dict, Any, Optional, Union, List
from transformers import Qwen2_5_VLForConditionalGeneration

class QuestionIDResolver:
    """Map per-sample ids to integer question ids used to look up HUD data.

    If ``full_vqa_path`` names an existing JSON file that loads successfully,
    the resolver is enabled and ``id_question_mapping`` is filled with
    ``{sample_id: {'question': <lower-cased question>, 'image': <image field>}}``.
    ``resolve_question_id`` does not consult that mapping; while enabled it
    only parses the sample id as an integer.
    """

    def __init__(self, full_vqa_path: Optional[str] = None):
        self.full_vqa_path = full_vqa_path
        self.id_question_mapping = {}
        self.enabled = False
        if full_vqa_path and os.path.exists(full_vqa_path):
            # Only enable when the file actually loaded; a corrupt file used to
            # leave the resolver "enabled" with an empty mapping.
            self.enabled = self._build_mapping()

    def _build_mapping(self) -> bool:
        """Fill ``id_question_mapping`` from the JSON file at ``full_vqa_path``.

        The file is expected to hold a list of records with ``id`` and
        ``conversations`` (and optionally ``image``) keys. Records that are not
        dicts, or that lack an id or a human turn, are skipped.

        Returns:
            True if the file was read and its top level is a list, else False.
        """
        try:
            with open(self.full_vqa_path, 'r', encoding='utf-8') as f:
                full_data = json.load(f)
        except (OSError, ValueError) as e:  # JSON/Unicode decode errors are ValueErrors
            logging.error(f"Error loading full VQA data: {e}")
            return False
        if not isinstance(full_data, list):
            logging.error(
                f"Error loading full VQA data: expected a JSON list, got {type(full_data).__name__}"
            )
            return False
        for item in full_data:
            if not isinstance(item, dict):
                continue
            sample_id = item.get('id', '')
            question = self._extract_question(item.get('conversations') or [])
            # Accept falsy-but-valid ids such as 0; reject only missing/empty ones.
            if sample_id not in (None, '') and question:
                self.id_question_mapping[sample_id] = {
                    'question': question.lower().strip(),
                    'image': item.get('image', ''),
                }
        return True

    def _extract_question(self, conversations):
        """Return the first human turn's text with ``<image>`` tags removed, or ''."""
        for conv in conversations:
            if isinstance(conv, dict) and conv.get('from') == 'human':
                return str(conv.get('value') or '').replace('<image>', '').strip()
        return ""

    def resolve_question_id(self, sample_id: str) -> int:
        """Parse ``sample_id`` as an int question id.

        Returns -1 when the resolver is disabled or the id cannot be parsed.
        """
        if not self.enabled:
            return -1
        try:
            return int(sample_id)
        except (ValueError, TypeError):
            return -1

class HUDProcessor:
    """Tokenize HUD candidate answers and score them against per-sample logits."""

    def __init__(self, tokenizer):
        # Anything exposing ``encode(text, add_special_tokens=False)``.
        self.tokenizer = tokenizer

    def tokenize_hud_answers(self, hud_answers, hud_scores):
        """Tokenize candidate answers, dropping non-strings and empty results.

        Answers and scores are paired positionally (``zip`` stops at the shorter
        list). A score is kept only if its answer produced at least one token.

        Returns:
            ``(tokenized_answers, valid_scores)``, two lists of equal length.
        """
        tokenized_answers = []
        valid_scores = []
        for answer, score in zip(hud_answers, hud_scores):
            if not isinstance(answer, str):
                continue
            text = answer.strip()
            if not text:
                continue
            tokens = self.tokenizer.encode(text, add_special_tokens=False)
            if len(tokens) > 0:
                tokenized_answers.append(tokens)
                valid_scores.append(score)
        return tokenized_answers, valid_scores

    def compute_answer_sequence_probability(self, logits, answer_tokens, labels):
        """Return the summed log-probability of ``answer_tokens`` under ``logits``.

        Args:
            logits: per-sample logits indexed as ``logits[position, vocab_id]``.
            answer_tokens: token ids of the candidate answer.
            labels: per-sample labels aligned with the input ids; ``-100`` marks
                positions excluded from the loss.

        The candidate is aligned to the first supervised label position. A
        causal LM's logit at position ``t`` predicts the token at ``t + 1``.
        This is the same shift the model applies when it computes
        ``outputs.loss`` from these labels. So the label at ``pos`` is scored
        with logit row ``pos - 1``, and a label at position 0 has no predicting
        logit and is skipped. Scoring stops at the end of the sequence or at the
        first unsupervised label, so candidates longer than the supervised span
        are truncated.

        NOTE(review): the logits come from a forward pass teacher-forced on the
        ground-truth labels, not on the candidate's own tokens. For multi-token
        candidates this is therefore only an approximation of the candidate's
        sequence probability.

        Returns:
            A 0-d float tensor on ``logits.device``; 0.0 if nothing was scored.
        """
        zero = torch.tensor(0.0, device=logits.device)
        if len(answer_tokens) == 0:
            return zero
        seq_len = min(len(logits), len(labels))
        # One host transfer instead of a device sync per token comparison.
        label_list = labels[:seq_len].tolist()
        answer_start = next((p for p, lab in enumerate(label_list) if lab != -100), None)
        if answer_start is None:
            return zero

        logit_rows = []
        token_ids = []
        for offset, token_id in enumerate(answer_tokens):
            pos = answer_start + offset
            if pos >= seq_len or label_list[pos] == -100:
                break
            if pos == 0:
                continue
            logit_rows.append(pos - 1)
            token_ids.append(token_id)
        if not token_ids:
            return zero

        rows = torch.tensor(logit_rows, device=logits.device)
        ids = torch.tensor(token_ids, device=logits.device)
        # Normalize only the rows we need, in float32, with a stable log_softmax.
        log_probs = F.log_softmax(logits[rows].float(), dim=-1)
        return log_probs.gather(1, ids.unsqueeze(1)).sum()

class ThreePartLossTrainer(CustomSeq2SeqTrainer):
    """Trainer whose loss is ``part_a + beta * part_b + lambda_param * part_c``.

    * part_a: the trained model's own loss (``outputs.loss``).
    * part_b: token-level KL(current || base) over supervised positions, measured
      against a frozen copy of the model loaded from ``base_model_path``.
    * part_c: for samples with a known question id and HUD data, how much better
      the current model explains the HUD answer distribution than the base model.

    Parts B and C require the base model; when it is not configured or fails to
    load they stay zero.
    """

    def __init__(self,
                 beta: float = 1.0,
                 lambda_param: float = 1.0,
                 base_model_path: Optional[str] = None,
                 hud_data_path: Optional[str] = None,
                 full_vqa_path: Optional[str] = None,
                 **kwargs):
        super().__init__(**kwargs)
        self.beta = beta
        self.lambda_param = lambda_param
        self.base_model_path = base_model_path
        self.hud_data_path = hud_data_path
        self.full_vqa_path = full_vqa_path
        self.base_model = None
        self.hud_mapping = {}
        self._init_base_model()
        self._load_hud_data()
        self._init_question_resolver()
        self._init_hud_processor()

    def _init_base_model(self):
        """Load the frozen reference model used by parts B and C.

        Leaves ``self.base_model`` as ``None`` when no path is configured, the
        path does not exist locally, or loading fails.
        """
        if not self.base_model_path:
            return
        if not os.path.exists(self.base_model_path):
            logging.warning(f"Base model path not found, parts B and C disabled: {self.base_model_path}")
            return
        if self.args.bf16:
            dtype = torch.bfloat16
        elif self.args.fp16:
            dtype = torch.float16
        else:
            dtype = torch.float32
        try:
            base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                self.base_model_path,
                torch_dtype=dtype,
                device_map=None,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
            # With device_map=None the weights stay on CPU, and the Trainer only
            # moves self.model. Without this move, every base forward in
            # compute_loss raised a device mismatch that was swallowed, so
            # parts B and C were silently zero.
            # NOTE(review): assumes args.device is the device the batch lives on.
            base_model.to(self.args.device)
            base_model.eval()
            base_model.requires_grad_(False)
        except Exception as e:
            logging.error(f"Failed to load base model: {e}")
            self.base_model = None
            return
        self.base_model = base_model

    def _load_hud_data(self):
        """Read the HUD JSONL file into ``self.hud_mapping`` keyed by int question id.

        Blank lines are ignored. Malformed lines are skipped individually and
        counted, rather than discarding everything loaded so far. If the file
        cannot be read, the mapping is reset to empty.
        """
        if not self.hud_data_path or not os.path.exists(self.hud_data_path):
            return
        skipped = 0
        try:
            with open(self.hud_data_path, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        record = json.loads(line)
                        question_id = int(record['question_id'])
                    except (ValueError, KeyError, TypeError):
                        skipped += 1
                        continue
                    self.hud_mapping[question_id] = {
                        'hud_answers': record.get('hud_answers', []),
                        'hud_scores': record.get('hud_scores', []),
                    }
        except OSError as e:
            logging.error(f"Error loading HUD data: {e}")
            self.hud_mapping = {}
            return
        if skipped:
            logging.warning(f"Skipped {skipped} malformed HUD records in {self.hud_data_path}")

    def _init_question_resolver(self):
        self.question_id_resolver = QuestionIDResolver(self.full_vqa_path)

    def _init_hud_processor(self):
        self.hud_processor = HUDProcessor(self.tokenizer)

    def compute_part_b_loss(self, current_logits, base_logits, labels):
        """Mean token-level KL(current || base) over supervised positions.

        Logits at position ``t`` predict the label at ``t + 1``. This is the
        shift the model applies to its own loss. So the mask comes from
        ``labels[:, 1:]`` and is applied to ``logits[:, :-1]``. Only masked rows
        are normalized, which avoids softmaxing the full ``(batch, seq, vocab)``
        tensor. The math is done in float32.
        """
        seq_len = min(current_logits.size(1), base_logits.size(1), labels.size(1))
        mask = labels[:, 1:seq_len] != -100
        if not mask.any():
            return torch.tensor(0.0, device=current_logits.device)
        current_log_probs = F.log_softmax(current_logits[:, :seq_len - 1][mask].float(), dim=-1)
        base_log_probs = F.log_softmax(base_logits[:, :seq_len - 1][mask].float(), dim=-1)
        # kl_div(input, target, log_target=True) = exp(target) * (target - input),
        # i.e. KL(current || base) when target is the current model.
        kl_div = F.kl_div(base_log_probs, current_log_probs, reduction='none', log_target=True).sum(-1)
        return kl_div.mean()

    def compute_part_c_loss(self, current_logits, base_logits, labels, question_ids):
        """HUD term: KL(p || q_current) - KL(p || q_base), averaged over samples.

        For each sample whose question id has HUD data, the HUD scores are
        normalized into a distribution ``p`` over candidate answers. Then

            sum_k p_k * (log q_base_k - log q_cur_k)

        is computed. The ``p_k * log p_k`` terms of the two KLs cancel, so this
        equals the KL difference. Here ``q_*_k`` is the candidate's log-space
        sequence score from ``self.hud_processor``. Working in log space avoids
        the former ``exp(...).clamp(min=1e-8)`` round-trip, which zeroed the
        gradient whenever a candidate's probability fell below 1e-8. Negative
        values mean the current model fits ``p`` better than the base model.
        Samples without usable HUD data, or whose computation fails, are
        skipped.
        """
        if self.lambda_param == 0 or not self.hud_mapping or labels is None:
            return torch.tensor(0.0, device=current_logits.device)
        part_c_losses = []
        for i, qid in enumerate(question_ids):
            try:
                qid_int = int(qid.item()) if torch.is_tensor(qid) else int(qid)
                if qid_int not in self.hud_mapping or qid_int == -1:
                    continue
                hud_data = self.hud_mapping[qid_int]
                hud_answers = hud_data['hud_answers']
                hud_scores = hud_data['hud_scores']
                if not hud_answers or not hud_scores:
                    continue
                if sum(hud_scores) <= 0:
                    continue
                tokenized_answers, valid_scores = self.hud_processor.tokenize_hud_answers(
                    hud_answers, hud_scores
                )
                if not tokenized_answers:
                    continue
                valid_sum = sum(valid_scores)
                if valid_sum <= 0:
                    continue
                valid_probs = [s / valid_sum for s in valid_scores]
                sample_current_logits = current_logits[i]
                sample_base_logits = base_logits[i]
                sample_labels = labels[i]
                relative_improvement = None
                for answer_tokens, hud_prob in zip(tokenized_answers, valid_probs):
                    if hud_prob <= 0:
                        continue
                    current_log_prob = self.hud_processor.compute_answer_sequence_probability(
                        sample_current_logits, answer_tokens, sample_labels
                    )
                    base_log_prob = self.hud_processor.compute_answer_sequence_probability(
                        sample_base_logits, answer_tokens, sample_labels
                    )
                    term = hud_prob * (base_log_prob - current_log_prob)
                    relative_improvement = term if relative_improvement is None else relative_improvement + term
                if relative_improvement is not None and torch.isfinite(relative_improvement):
                    part_c_losses.append(relative_improvement)
            except Exception as e:
                logging.warning(f"Error computing Part C loss for qid {qid}: {e}")
                continue
        if part_c_losses:
            return torch.stack(part_c_losses).mean()
        return torch.tensor(0.0, device=current_logits.device)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return ``part_a + beta * part_b + lambda_param * part_c``.

        ``question_ids`` and ``sample_ids`` are popped from ``inputs`` before the
        forward pass. When only ``sample_ids`` are present, they are resolved
        through ``self.question_id_resolver``. ``num_items_in_batch`` is
        accepted for Trainer API compatibility but not used. Failures while
        computing the base-model terms are logged, and those terms fall back to
        zero.
        """
        question_ids = inputs.pop('question_ids', None)
        sample_ids = inputs.pop('sample_ids', None)
        if question_ids is None and sample_ids is not None:
            question_ids = [self.question_id_resolver.resolve_question_id(sid) for sid in sample_ids]

        outputs = model(**inputs)
        device = self.args.device
        part_a_loss = outputs.loss if outputs.loss is not None else torch.tensor(0.0, device=device)
        part_b_loss = torch.tensor(0.0, device=device)
        part_c_loss = torch.tensor(0.0, device=device)

        labels = inputs.get('labels', None)
        # Both base-model terms need labels to locate the supervised positions.
        want_part_b = self.beta > 0 and labels is not None
        want_part_c = self.lambda_param > 0 and question_ids is not None and labels is not None
        if self.base_model is not None and (want_part_b or want_part_c):
            try:
                # The reference model only needs logits; dropping labels skips
                # an unused loss computation inside it.
                base_inputs = {k: v for k, v in inputs.items() if k != 'labels'}
                with torch.no_grad():
                    base_outputs = self.base_model(**base_inputs)
                current_logits = getattr(outputs, 'logits', None)
                base_logits = getattr(base_outputs, 'logits', None)
                if current_logits is not None and base_logits is not None:
                    if current_logits.shape != base_logits.shape:
                        min_length = min(current_logits.size(1), base_logits.size(1))
                        current_logits = current_logits[:, :min_length, :]
                        base_logits = base_logits[:, :min_length, :]
                        if labels.size(1) > min_length:
                            labels = labels[:, :min_length]
                    if want_part_b:
                        part_b_loss = self.compute_part_b_loss(current_logits, base_logits, labels)
                    if want_part_c:
                        part_c_loss = self.compute_part_c_loss(current_logits, base_logits, labels, question_ids)
            except Exception as e:
                logging.warning(f"Error computing base model outputs: {e}")

        total_loss = part_a_loss + self.beta * part_b_loss + self.lambda_param * part_c_loss
        outputs.loss = total_loss
        return (total_loss, outputs) if return_outputs else total_loss