import json
import logging
import os
from typing import Dict, Any, Optional, Union, List

import torch
import torch.nn.functional as F

from engine_for_finetuning import VQAHandler
import utils

class QuestionIDResolver:
    """Load a JSONL VQA file and convert raw sample ids to integer ids.

    When ``full_vqa_path`` names an existing path, every non-blank line is
    parsed as a JSON object and records with a truthy ``qid`` are stored in
    ``id_question_mapping`` (``qid`` -> ``image_path`` / ``text_segment``).
    The resolver is marked ``enabled`` as soon as the path exists, even if
    nothing could be read from it.
    """

    _log = logging.getLogger(__name__)

    def __init__(self, full_vqa_path: Optional[str] = None):
        """Store the path and, if it exists, build the id mapping.

        Args:
            full_vqa_path: Path to a JSONL file, or ``None`` to leave the
                resolver disabled.
        """
        self.full_vqa_path = full_vqa_path
        self.id_question_mapping = {}
        self.enabled = False
        if full_vqa_path and os.path.exists(full_vqa_path):
            self._build_mapping()
            self.enabled = True

    def _build_mapping(self):
        """Fill ``id_question_mapping`` from the JSONL file (best effort).

        Blank lines are skipped. Loading stops at the first unreadable or
        malformed record; entries read before that point are kept and the
        failure is logged instead of being silently discarded.
        """
        try:
            with open(self.full_vqa_path, 'r', encoding='utf-8') as f:
                for raw_line in f:
                    line = raw_line.strip()
                    if not line:
                        # A blank line is not valid JSON; previously it
                        # aborted the whole load without any notice.
                        continue
                    item = json.loads(line)
                    qid = item.get('qid')
                    if qid:
                        self.id_question_mapping[qid] = {
                            'image_path': item.get('image_path', ''),
                            'text_segment': item.get('text_segment', [])
                        }
        except (OSError, ValueError, AttributeError, TypeError) as exc:
            # ValueError covers json.JSONDecodeError and decoding errors;
            # AttributeError/TypeError cover lines that are not JSON objects
            # or carry an unhashable qid.
            self._log.warning(
                "Stopped loading VQA mapping from %s after %d entries: %s",
                self.full_vqa_path, len(self.id_question_mapping), exc
            )

    def resolve_question_id(self, sample_id):
        """Return ``sample_id`` as an ``int``, or ``-1`` when unavailable.

        ``-1`` is returned when the resolver is disabled, when ``sample_id``
        is falsy (``None``, ``""``, ``0``), or when it cannot be converted
        with ``int()``.
        """
        if not self.enabled:
            return -1
        try:
            return int(sample_id) if sample_id else -1
        except (ValueError, TypeError):
            return -1

class HUDProcessor:
    """Helpers that relate HUD answer strings to classifier label ids."""

    def __init__(self, tokenizer, ans2label):
        """Keep the tokenizer and the answer -> label-id lookup.

        Args:
            tokenizer: Stored as-is; not used by the methods of this class.
            ans2label: Mapping from normalized answer string to label id.
        """
        self.tokenizer = tokenizer
        self.ans2label = ans2label

    def tokenize_hud_answers(self, hud_answers, hud_scores):
        """Keep the HUD answers that map to a known label.

        Answers and scores are paired positionally. Each non-empty string
        answer is stripped, normalized with ``glossary.normalize_word`` and
        looked up in ``ans2label``; everything else is dropped.

        Returns:
            A tuple ``(answers, scores, labels)`` of three parallel lists:
            the stripped (un-normalized) answer strings, their scores, and
            their label ids.
        """
        # Imported once per call rather than once per answer. Kept local so
        # the module can still be imported without ``glossary`` available.
        from glossary import normalize_word

        tokenized_answers = []
        valid_scores = []
        valid_labels = []
        for answer, score in zip(hud_answers, hud_scores):
            if not isinstance(answer, str):
                continue
            stripped = answer.strip()
            if not stripped:
                continue
            normalized_answer = normalize_word(stripped)
            if normalized_answer not in self.ans2label:
                continue
            tokenized_answers.append(stripped)
            valid_scores.append(score)
            valid_labels.append(self.ans2label[normalized_answer])
        return tokenized_answers, valid_scores, valid_labels

    def compute_answer_probability(self, logits, answer_label):
        """Return the softmax probability assigned to ``answer_label``.

        The softmax and the lookup both use the last dimension of
        ``logits``, so a 1-D input yields a scalar tensor and a batched
        input yields one probability per leading index.

        Returns:
            A scalar ``0.0`` tensor on ``logits.device`` when
            ``answer_label`` is outside ``[0, logits.size(-1))``; negative
            labels are rejected rather than wrapping around to the end.
        """
        if not 0 <= answer_label < logits.size(-1):
            return torch.tensor(0.0, device=logits.device)
        probs = F.softmax(logits, dim=-1)
        return probs[..., answer_label]

class BEiT3ThreePartLossTrainer:
    """Compute a three-part training loss for a VQA model.

    ``total = part_a + beta * part_b + lambda_param * part_c`` where

    * part A is the supervised loss of the current model,
    * part B is a KL term between the current model and a frozen base model,
    * part C compares how well the current and the base model fit the
      per-question "HUD" answer distributions loaded from ``hud_data_path``.

    Parts B and C are only computed when a base model could be loaded;
    otherwise they stay at zero.
    """

    _log = logging.getLogger(__name__)

    def __init__(self,
                 model,
                 base_model_path: Optional[str] = None,
                 hud_data_path: Optional[str] = None,
                 full_vqa_path: Optional[str] = None,
                 ans2label_path: Optional[str] = None,
                 beta: float = 1.0,
                 lambda_param: float = 1.0,
                 device: str = 'cuda'):
        """Store hyper-parameters and load every optional resource.

        Args:
            model: The model being trained; called with ``image``,
                ``question`` and ``padding_mask`` keyword arguments.
            base_model_path: Checkpoint of the frozen reference model.
            hud_data_path: JSONL file with ``question_id``, ``hud_answers``
                and ``hud_scores`` per line.
            full_vqa_path: JSONL file handed to ``QuestionIDResolver``.
            ans2label_path: JSONL file with ``answer`` / ``label`` per line.
            beta: Weight of part B.
            lambda_param: Weight of part C.
            device: Device the base model is moved to.

        Any path that is ``None`` or missing simply leaves the matching
        resource unloaded.
        """
        self.model = model
        self.beta = beta
        self.lambda_param = lambda_param
        self.device = device
        self.base_model = None
        self.hud_mapping = {}
        self.ans2label = {}
        # The loaders receive the caller's paths (they used to get a
        # hard-coded "data/" directory, so nothing was ever loaded).
        # ans2label goes first: the base model is sized from it.
        self._load_ans2label(ans2label_path)
        self._init_base_model(base_model_path)
        self._load_hud_data(hud_data_path)
        self._init_question_resolver(full_vqa_path)
        self._init_hud_processor()

    def _load_ans2label(self, ans2label_path):
        """Read ``answer`` -> ``label`` pairs from a JSONL file (best effort).

        Blank lines are skipped. Loading stops at the first malformed
        record; pairs read before it are kept and the failure is logged.
        """
        if not ans2label_path or not os.path.exists(ans2label_path):
            return
        try:
            with open(ans2label_path, 'r', encoding='utf-8') as f:
                for raw_line in f:
                    line = raw_line.strip()
                    if not line:
                        continue
                    data = json.loads(line)
                    self.ans2label[data['answer']] = data['label']
        except (OSError, ValueError, KeyError, TypeError) as exc:
            self._log.warning(
                "Stopped loading ans2label from %s after %d entries: %s",
                ans2label_path, len(self.ans2label), exc
            )

    def _init_base_model(self, base_model_path):
        """Build the frozen base model and load its checkpoint.

        On any failure ``self.base_model`` is left as ``None`` (which
        disables parts B and C) and the reason is logged.
        """
        if not base_model_path or not os.path.exists(base_model_path):
            return
        try:
            from modeling_finetune import BEiT3ForVisualQuestionAnswering
            from modeling_utils import _get_base_config
            num_answers = len(self.ans2label) if self.ans2label else 3129
            # NOTE(review): vocab_size is set to the number of answers here;
            # if _get_base_config expects the text-tokenizer vocabulary size
            # this will not match the checkpoint - confirm against
            # modeling_utils.
            args = _get_base_config(vocab_size=num_answers)
            base_model = BEiT3ForVisualQuestionAnswering(
                args,
                num_classes=num_answers
            )
            # NOTE: torch.load unpickles the file; only load checkpoints
            # from trusted sources.
            checkpoint = torch.load(base_model_path, map_location='cpu')
            state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
            utils.load_state_dict(base_model, state_dict)
            base_model.to(self.device)
            base_model.eval()
            for param in base_model.parameters():
                param.requires_grad = False
            self.base_model = base_model
        except Exception as exc:
            # Deliberately broad: a missing module, a bad checkpoint or a
            # device error must not prevent training with part A alone.
            self._log.warning(
                "Could not initialise base model from %s; parts B and C are disabled: %s",
                base_model_path, exc
            )
            self.base_model = None

    def _load_hud_data(self, hud_data_path):
        """Read per-question HUD answers and scores from a JSONL file.

        Blank lines are skipped. The load is all-or-nothing: on any
        malformed record the mapping is reset to empty and the failure is
        logged.
        """
        if not hud_data_path or not os.path.exists(hud_data_path):
            return
        try:
            with open(hud_data_path, 'r', encoding='utf-8') as f:
                for raw_line in f:
                    line = raw_line.strip()
                    if not line:
                        continue
                    record = json.loads(line)
                    question_id = int(record['question_id'])
                    self.hud_mapping[question_id] = {
                        'hud_answers': record.get('hud_answers', []),
                        'hud_scores': record.get('hud_scores', [])
                    }
        except (OSError, ValueError, KeyError, TypeError, AttributeError) as exc:
            self._log.warning(
                "Failed to load HUD data from %s; part C is disabled: %s",
                hud_data_path, exc
            )
            self.hud_mapping = {}

    def _init_question_resolver(self, full_vqa_path):
        """Create the ``QuestionIDResolver`` for ``full_vqa_path``."""
        self.question_id_resolver = QuestionIDResolver(full_vqa_path)

    def _init_hud_processor(self):
        """Create the ``HUDProcessor`` (no tokenizer) over ``ans2label``."""
        self.hud_processor = HUDProcessor(None, self.ans2label)

    @staticmethod
    def _extract_logits(outputs):
        """Return the logits tensor carried by a model output.

        Tensors are returned unchanged, objects with a ``logits`` attribute
        yield that attribute, anything else is returned as-is.
        """
        if torch.is_tensor(outputs):
            return outputs
        return outputs.logits if hasattr(outputs, 'logits') else outputs

    def compute_part_a_loss(self, model_output, labels):
        """Return the supervised loss of the current model.

        Uses ``model_output.loss`` when present and not ``None``. Otherwise
        computes mean ``BCEWithLogitsLoss`` between the logits and
        ``labels`` and multiplies it by ``labels.shape[1]``.
        """
        if not hasattr(model_output, 'loss') or model_output.loss is None:
            logits = model_output if torch.is_tensor(model_output) else model_output.logits
            criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
            return criterion(logits.float(), labels.float()) * labels.shape[1]
        return model_output.loss

    def compute_part_b_loss(self, current_logits, base_logits):
        """Return the batch-mean KL divergence KL(current || base).

        ``F.kl_div`` takes log-probabilities as input and probabilities as
        target, so the base distribution is passed in log space and the
        current distribution as the target.
        """
        current_probs = F.softmax(current_logits, dim=-1)
        base_log_probs = F.log_softmax(base_logits, dim=-1)
        kl_div = F.kl_div(base_log_probs, current_probs, reduction='batchmean')
        return kl_div

    @staticmethod
    def _label_probability(probs, label_id):
        """Return ``probs[label_id]`` clamped to at least 1e-8.

        Labels outside the class range count as probability zero before
        clamping.
        """
        if 0 <= label_id < probs.size(-1):
            prob = probs[label_id]
        else:
            prob = torch.tensor(0.0, device=probs.device)
        return prob.clamp(min=1e-8)

    def _part_c_sample_loss(self, current_logits, base_logits, index, qid):
        """Return the part C term for one sample, or ``None`` to skip it.

        A sample is skipped when its id has no HUD record, when none of its
        HUD answers map to a label, when the scores do not sum to a positive
        value, or when the resulting term is not finite.
        """
        qid_int = int(qid.item()) if torch.is_tensor(qid) else int(qid)
        hud_data = self.hud_mapping.get(qid_int)
        if hud_data is None or qid_int == -1:
            return None
        hud_answers = hud_data['hud_answers']
        hud_scores = hud_data['hud_scores']
        if not hud_answers or not hud_scores:
            return None
        _, valid_scores, valid_labels = self.hud_processor.tokenize_hud_answers(
            hud_answers, hud_scores
        )
        if not valid_labels:
            return None
        valid_sum = sum(valid_scores)
        if valid_sum <= 0:
            return None
        # Renormalise the scores of the answers that survived the lookup.
        valid_probs = [s / valid_sum for s in valid_scores]
        # One softmax per sample and model instead of one per label.
        sample_current_probs = F.softmax(current_logits[index], dim=-1)
        sample_base_probs = F.softmax(base_logits[index], dim=-1)
        device = current_logits.device
        kl_current_total = None
        kl_base_total = None
        for label_id, hud_prob in zip(valid_labels, valid_probs):
            if hud_prob <= 0:
                continue
            target = torch.tensor(hud_prob, device=device)
            current_term = hud_prob * torch.log(
                target / self._label_probability(sample_current_probs, label_id)
            )
            base_term = hud_prob * torch.log(
                target / self._label_probability(sample_base_probs, label_id)
            )
            kl_current_total = current_term if kl_current_total is None else kl_current_total + current_term
            kl_base_total = base_term if kl_base_total is None else kl_base_total + base_term
        if kl_current_total is None:
            # No label had positive mass, so there is nothing to compare.
            return None
        relative_improvement = kl_current_total - kl_base_total
        if not torch.isfinite(relative_improvement):
            return None
        return relative_improvement

    def compute_part_c_loss(self, current_logits, base_logits, question_ids):
        """Return the mean HUD-relative term over the samples that have one.

        For each sample with usable HUD data the term is
        ``sum_a h(a) * log(h(a) / p_current(a)) - sum_a h(a) * log(h(a) / p_base(a))``
        with ``h`` the renormalised HUD scores. Samples that fail for any
        reason are logged and skipped. A zero tensor is returned when
        ``lambda_param`` is 0, no HUD data is loaded, or no sample
        contributed.
        """
        if self.lambda_param == 0 or not self.hud_mapping:
            return torch.tensor(0.0, device=current_logits.device)
        part_c_losses = []
        for i, qid in enumerate(question_ids):
            try:
                sample_loss = self._part_c_sample_loss(
                    current_logits, base_logits, i, qid
                )
            except Exception as exc:
                # Best effort per sample: one bad record must not abort the
                # training step, but it should be visible.
                self._log.warning("Skipping part C for sample %d: %s", i, exc)
                continue
            if sample_loss is not None:
                part_c_losses.append(sample_loss)
        if part_c_losses:
            return torch.stack(part_c_losses).mean()
        return torch.tensor(0.0, device=current_logits.device)

    def compute_loss(self, batch_data):
        """Run the model(s) on one batch and return all loss components.

        Args:
            batch_data: Dict with ``image``, ``language_tokens``,
                ``padding_mask`` and ``labels``; question ids are read from
                ``qid`` or, when that is absent or ``None``, from
                ``question_ids``.

        Returns:
            Dict with ``total_loss``, ``part_a_loss``, ``part_b_loss``,
            ``part_c_loss`` and ``logits`` (the current model's own logits,
            never the class-truncated copy used for parts B and C).
        """
        image = batch_data['image']
        language_tokens = batch_data['language_tokens']
        padding_mask = batch_data['padding_mask']
        labels = batch_data['labels']
        # A 'qid' key holding None must still fall back to 'question_ids'.
        question_ids = batch_data.get('qid')
        if question_ids is None:
            question_ids = batch_data.get('question_ids')
        current_outputs = self.model(
            image=image,
            question=language_tokens,
            padding_mask=padding_mask
        )
        part_a_loss = self.compute_part_a_loss(current_outputs, labels)
        part_b_loss = torch.tensor(0.0, device=image.device)
        part_c_loss = torch.tensor(0.0, device=image.device)
        current_logits = self._extract_logits(current_outputs)
        use_part_b = self.beta > 0
        use_part_c = self.lambda_param > 0 and question_ids is not None
        if self.base_model is not None and (use_part_b or use_part_c):
            try:
                with torch.no_grad():
                    base_outputs = self.base_model(
                        image=image,
                        question=language_tokens,
                        padding_mask=padding_mask
                    )
                base_logits = self._extract_logits(base_outputs)
                compared_logits = current_logits
                if compared_logits.shape != base_logits.shape:
                    # Compare only the classes both models share.
                    min_dim = min(compared_logits.size(-1), base_logits.size(-1))
                    compared_logits = compared_logits[..., :min_dim]
                    base_logits = base_logits[..., :min_dim]
                if use_part_b:
                    part_b_loss = self.compute_part_b_loss(compared_logits, base_logits)
                if use_part_c:
                    part_c_loss = self.compute_part_c_loss(compared_logits, base_logits, question_ids)
            except Exception as exc:
                # The auxiliary terms are best effort; training continues
                # with whatever was computed, but the failure is reported.
                self._log.warning("Base-model losses skipped for this batch: %s", exc)
        total_loss = part_a_loss + self.beta * part_b_loss + self.lambda_param * part_c_loss
        return {
            'total_loss': total_loss,
            'part_a_loss': part_a_loss,
            'part_b_loss': part_b_loss,
            'part_c_loss': part_c_loss,
            'logits': current_logits
        }

class BEiT3VQAHandlerWithThreePartLoss(VQAHandler):
    """VQA handler whose training step delegates to a three-part loss trainer."""

    # Batch entries that must be supplied to train_batch, in lookup order.
    _REQUIRED_KEYS = ('image', 'language_tokens', 'padding_mask', 'labels')

    def __init__(self,
                 three_part_trainer: BEiT3ThreePartLossTrainer):
        """Initialise the parent handler and keep the loss trainer."""
        super().__init__()
        self.three_part_trainer = three_part_trainer

    def train_batch(self, model, **kwargs):
        """Compute the losses for one batch via the attached trainer.

        ``model`` is accepted for interface compatibility but not used; the
        trainer runs its own model. ``kwargs`` must contain ``image``,
        ``language_tokens``, ``padding_mask`` and ``labels`` (``KeyError``
        otherwise) and may contain ``qid``.

        Returns:
            Dict with the total loss under ``loss`` plus the three
            individual parts.
        """
        batch = {key: kwargs[key] for key in self._REQUIRED_KEYS}
        batch['qid'] = kwargs.get('qid')
        losses = self.three_part_trainer.compute_loss(batch)
        result = {'loss': losses['total_loss']}
        for part in ('part_a_loss', 'part_b_loss', 'part_c_loss'):
            result[part] = losses[part]
        return result