import torch
import torch.nn.functional as F
import re
import numpy as np
import pandas as pd
import os
import glob
import sys
from collections import OrderedDict
from types import SimpleNamespace
from transformers import AutoModel, AutoTokenizer

# Register a stand-in ``config`` module when none has been imported yet.
# NOTE(review): presumably the wheel_metrics import further down reads
# ``config.EMOTION_WHEEL_ROOT`` — confirm against that module.
if "config" not in sys.modules:
    cfg = SimpleNamespace()
    # EMOTION_WHEEL_ROOT points at the "emotion_wheel" directory next to this file.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    cfg.EMOTION_WHEEL_ROOT = os.path.join(current_dir, "emotion_wheel")
    sys.modules["config"] = cfg

# Import the label helpers used by emotion_wheel_reward.
try:
    from merbench.affectgpt_local.wheel_metrics import _map_label, _normalize_words
except ImportError:
    # Retry once after appending this file's directory to sys.path.
    sys.path.append(os.path.dirname(os.path.abspath(__file__)))
    from merbench.affectgpt_local.wheel_metrics import _map_label, _normalize_words

def extract_answer(text):
    """Return the stripped content of the first ``<answer>...</answer>`` span.

    When ``text`` contains no such span, the whole text is returned stripped.
    """
    found = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL)
    payload = text if found is None else found.group(1)
    return payload.strip()

def calculate_f1(preds, gts):
    """Compute the set-based F1 score between predicted and ground-truth labels.

    Duplicates are ignored because both inputs are converted to sets.
    Returns 0.0 when either input is empty or the sets do not overlap.
    """
    if not (preds and gts):
        return 0.0

    predicted, expected = set(preds), set(gts)
    overlap = len(predicted & expected)
    if not overlap:
        # No shared label: precision and recall are both zero.
        return 0.0

    precision = overlap / len(predicted)
    recall = overlap / len(expected)
    # With a non-empty overlap both terms are positive, so the sum is non-zero.
    return 2 * (precision * recall) / (precision + recall)

def emotion_wheel_reward(completions, **kwargs):
    """Score each completion by its mean emotion-wheel F1 against the ground truth.

    Args:
        completions: Sequence of completions; each is indexed as
            ``completion[0]["content"]`` to obtain the generated text.
        **kwargs: ``openset`` holds the ground-truth label collection for each
            sample. When there are fewer ground truths than completions, each
            ground truth is repeated ``len(completions) // len(openset)`` times.

    Returns:
        A list with exactly one float per completion: the mean F1 over the five
        ``case3_wheel{1..5}_level1`` mappings. Completions that have no
        matching ground truth receive 0.0.
    """
    ground_truths = kwargs.get('openset')
    if ground_truths is None:
        ground_truths = []
    ground_truths = list(ground_truths)

    if len(completions) > len(ground_truths) and len(ground_truths) > 0:
        repeat_factor = len(completions) // len(ground_truths)
        ground_truths = [gt for gt in ground_truths for _ in range(repeat_factor)]

    metrics = [f"case3_wheel{i}_level1" for i in range(1, 6)]

    rewards = []
    for idx, completion in enumerate(completions):
        if idx >= len(ground_truths):
            # Keep the output aligned with `completions` even when a ground
            # truth is missing, instead of silently returning a shorter list.
            rewards.append(0.0)
            continue

        gt_raw_list = ground_truths[idx]
        content = completion[0]["content"]
        pred_text = extract_answer(content)
        pred_raw_list = [p.strip() for p in pred_text.split(',')]

        # Normalisation does not depend on the metric, so do it once per sample.
        gt_norm = list(_normalize_words(gt_raw_list))
        pred_norm = list(_normalize_words(pred_raw_list))

        f1_scores = []
        for metric in metrics:
            # Labels that the wheel cannot map (falsy result) are dropped.
            gt_mapped = [
                mapped
                for mapped in (_map_label(label, metric) for label in gt_norm)
                if mapped
            ]
            pred_mapped = [
                mapped
                for mapped in (_map_label(label, metric) for label in pred_norm)
                if mapped
            ]
            f1_scores.append(calculate_f1(pred_mapped, gt_mapped))

        rewards.append(float(np.mean(f1_scores)))

    return rewards

def format_reward(completions, **kwargs):
    """Return 1.0 for each completion that starts with the expected tag layout.

    The layout is ``<think>...</think>`` followed by optional whitespace and
    ``<answer>...</answer>``, matched from the very first character of the text.
    """
    layout = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    bodies = [completion[0]["content"] for completion in completions]
    scores = []
    for body in bodies:
        scores.append(1.0 if layout.match(body) else 0.0)
    return scores

# Model location handed to AutoTokenizer/AutoModel.from_pretrained in _load_embedder.
# NOTE(review): "/path/to/data" looks like a placeholder — confirm the real path.
_DEFAULT_EMBEDDER_PATH = "/path/to/data"
# Default rubric-embedding file; _load_rubric_db lets RUBRIC_EMB_PATH override it.
_DEFAULT_RUBRIC_EMB_PATH = "/path/to/data"
# Lazily populated singletons (None until first use).
_RUBRIC_DB = None
_EMBEDDER = None
_EMBEDDER_TOKENIZER = None
# Text -> embedding cache used by _encode_texts; oldest entries are evicted first.
_TEXT_EMB_CACHE: OrderedDict[str, torch.Tensor] = OrderedDict()
# Entry count at which _encode_texts starts evicting from _TEXT_EMB_CACHE.
_TEXT_EMB_CACHE_SIZE = 2048
# Instruction prefix placed before every text sent to the embedder.
_TASK_INSTRUCT = os.getenv("RUBRIC_INSTRUCT", "Given a clause query, calculate the similarity between this clause and the key clue.")
# Set RUBRIC_DEBUG=1 to print diagnostics for the first sample of each rubric reward call.
_DEBUG_MODE = os.getenv("RUBRIC_DEBUG", "0") == "1"

# Similarity cut-off used by _column_max_scores; change it via set_score_threshold.
_SCORE_THRESHOLD = 0.5

def set_score_threshold(threshold: float):
    """Replace the module-wide similarity threshold and announce the new value."""
    global _SCORE_THRESHOLD
    _SCORE_THRESHOLD = threshold
    notice = f"[affect_reward] Score threshold set to {threshold}"
    print(notice)

def _last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Pick, for every row, the hidden state of its last attended token.

    When every row's final mask entry is set (the batch is left-padded), the
    last position is returned directly. Otherwise each row is indexed at
    ``attention_mask.sum(dim=1) - 1``.
    """
    row_count = attention_mask.shape[0]
    if attention_mask[:, -1].sum() == row_count:
        return last_hidden_states[:, -1]

    last_positions = attention_mask.sum(dim=1) - 1
    rows = torch.arange(last_hidden_states.size(0), device=last_hidden_states.device)
    return last_hidden_states[rows, last_positions]

def _get_device():
    """Return the device string for the embedder.

    A non-empty ``RUBRIC_EMB_DEVICE`` environment variable wins; otherwise
    "cuda" is used when available and "cpu" when not.
    """
    override = os.getenv("RUBRIC_EMB_DEVICE")
    if not override:
        return "cuda" if torch.cuda.is_available() else "cpu"
    return override

def _load_embedder():
    """Lazily create the module-wide embedding model and its tokenizer.

    Does nothing when the model is already loaded. Otherwise loads both from
    ``_DEFAULT_EMBEDDER_PATH``, moves the model to the device chosen by
    ``_get_device``, puts it in eval mode with gradients disabled, and sets
    the tokenizer to left padding.
    """
    global _EMBEDDER, _EMBEDDER_TOKENIZER
    if _EMBEDDER is not None:
        # An earlier call already initialised the model.
        return

    checkpoint = _DEFAULT_EMBEDDER_PATH
    target_device = _get_device()
    _EMBEDDER_TOKENIZER = AutoTokenizer.from_pretrained(checkpoint)

    # Plain float32 load with no device map, followed by an explicit move.
    load_options = dict(
        torch_dtype=torch.float32,
        device_map=None,
        low_cpu_mem_usage=False,
    )
    _EMBEDDER = AutoModel.from_pretrained(checkpoint, **load_options)
    _EMBEDDER = _EMBEDDER.to(target_device)
    _EMBEDDER.eval()

    # The embedder is only used for inference; freeze every weight.
    for weight in _EMBEDDER.parameters():
        weight.requires_grad = False

    _EMBEDDER_TOKENIZER.padding_side = "left"

def _encode_texts(texts):
    """Embed ``texts`` with the shared embedder, reusing the text cache.

    Args:
        texts: Sequence of strings. Duplicates are allowed and yield
            duplicate rows.

    Returns:
        ``None`` for an empty (or ``None``) input; otherwise a CPU float tensor
        with one L2-normalised row per entry of ``texts``, in input order.
    """
    if texts is None or len(texts) == 0:
        return None
    _load_embedder()
    device = _get_device()

    # Every embedding this call needs is held locally, so evicting cache
    # entries below can never drop a vector before it is stacked into the result.
    batch_embs = {}
    pending = []  # uncached texts, each listed once, in first-seen order
    for text in texts:
        if text in batch_embs:
            continue
        if text in _TEXT_EMB_CACHE:
            # Refresh recency so this entry is evicted last.
            _TEXT_EMB_CACHE.move_to_end(text)
            batch_embs[text] = _TEXT_EMB_CACHE[text]
        else:
            batch_embs[text] = None
            pending.append(text)

    if pending:
        with torch.inference_mode():
            inputs = _EMBEDDER_TOKENIZER(
                [f"Instruct: {_TASK_INSTRUCT}\nQuery: {t}" for t in pending],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            ).to(device)

            if 'deepspeed' in sys.modules:
                import deepspeed

                # Gather the parameters for the duration of the forward pass
                # when DeepSpeed has been imported by the host process.
                with deepspeed.zero.GatheredParameters(list(_EMBEDDER.parameters()), modifier_rank=None):
                    outputs = _EMBEDDER(**inputs)
            else:
                outputs = _EMBEDDER(**inputs)

            emb = _last_token_pool(outputs.last_hidden_state, inputs["attention_mask"])
            emb = F.normalize(emb, p=2, dim=1)
            emb = emb.float()

        for text, vector in zip(pending, emb):
            vector = vector.detach().cpu()
            batch_embs[text] = vector
            # Evict the least recently used entry once the cache is full.
            if _TEXT_EMB_CACHE and len(_TEXT_EMB_CACHE) >= _TEXT_EMB_CACHE_SIZE:
                _TEXT_EMB_CACHE.popitem(last=False)
            _TEXT_EMB_CACHE[text] = vector

    return torch.stack([batch_embs[text] for text in texts], dim=0)

def _load_rubric_db():
    """Return the rubric database, loading it on first use.

    The file is taken from the ``RUBRIC_EMB_PATH`` environment variable,
    falling back to ``_DEFAULT_RUBRIC_EMB_PATH``. If the path does not exist
    or cannot be loaded, an empty dict is used, so rubrics are built on the
    fly by ``_get_rubric``. The result is memoised in ``_RUBRIC_DB``.
    """
    global _RUBRIC_DB
    if _RUBRIC_DB is not None:
        return _RUBRIC_DB

    rubric_path = os.getenv("RUBRIC_EMB_PATH", _DEFAULT_RUBRIC_EMB_PATH)
    loaded = {}
    if rubric_path and os.path.exists(rubric_path):
        try:
            # NOTE(security): torch.load unpickles the file; only point
            # RUBRIC_EMB_PATH at files from a trusted source.
            loaded = torch.load(rubric_path, map_location="cpu")
        except Exception as exc:
            # Best-effort by design: keep running with an empty DB, but say so
            # instead of hiding a corrupt or mis-specified file.
            print(
                f"[affect_reward] Failed to load rubric DB from {rubric_path!r}: "
                f"{exc!r}; falling back to an empty DB"
            )
            loaded = {}
    _RUBRIC_DB = loaded
    return _RUBRIC_DB

def _extract_think_text(text: str) -> str:
    """Return the stripped body of the first ``<think>...</think>`` span.

    Non-string input yields ``""``; text without the tags is returned stripped.
    """
    if not isinstance(text, str):
        return ""
    found = re.search(r"<think>(.*?)</think>", text, flags=re.DOTALL)
    if found is None:
        return text.strip()
    return found.group(1).strip()

def _split_sentences(text: str):
    """Split ``text`` into non-empty, stripped clauses.

    Boundaries are runs of ``!``, ``?``, ``.`` or ``,`` (with any trailing
    whitespace) and runs of newlines. Empty input yields an empty list.
    """
    if not text:
        return []

    clauses = []
    for fragment in re.split(r"[!?\.,]+\s*|\n+", text):
        fragment = fragment.strip()
        if fragment:
            clauses.append(fragment)
    return clauses

def _build_rubric_from_clues(clues: dict):
    """Encode the clue lists of ``clues`` into a rubric of embedding tensors.

    Reads ``visual_clues``, ``audio_clues`` and ``reasoning_emotions`` and
    returns a dict with keys ``visual``, ``audio`` and ``logic``. A missing or
    empty list is encoded from ``[]``.
    """
    sources = (
        ("visual", "visual_clues"),
        ("audio", "audio_clues"),
        ("logic", "reasoning_emotions"),
    )
    # Collect all clue lists first, then encode them in the order above.
    texts_by_name = {name: clues.get(key) or [] for name, key in sources}
    return {name: _encode_texts(texts) for name, texts in texts_by_name.items()}

def _get_rubric(path: str, extracted_clues: dict):
    """Fetch the rubric for ``path``, building and storing it when needed.

    Lookup order: the rubric DB entry for ``path``; otherwise a rubric built
    from ``extracted_clues`` (stored in the DB under ``path`` when ``path`` is
    truthy); otherwise a rubric whose three entries are all ``None``.
    """
    store = _load_rubric_db()
    if path and path in store:
        return store[path]

    if not extracted_clues:
        return {"visual": None, "audio": None, "logic": None}

    built = _build_rubric_from_clues(extracted_clues)
    if path:
        # Remember the freshly built rubric for later calls with this path.
        store[path] = built
    return built

def _column_max_scores(pred_embs: torch.Tensor, gt_embs: torch.Tensor, return_sims=False):
    """Score each rubric row by its best cosine match among the prediction rows.

    Returns ``(None, None)`` when either input is ``None`` or ``gt_embs`` is
    empty. Otherwise returns ``(scores, sims)``: ``scores`` holds, per rubric
    row, ``relu(best_similarity - threshold) * scale`` using the module
    threshold, and ``sims`` is the full similarity matrix only when
    ``return_sims`` is true (else ``None``).
    """
    inputs_missing = pred_embs is None or gt_embs is None
    if inputs_missing or gt_embs.numel() == 0:
        return None, None

    # Cast predictions to the rubric's device and dtype before comparing.
    aligned_pred = pred_embs.to(device=gt_embs.device, dtype=gt_embs.dtype)
    pred_unit = F.normalize(aligned_pred, p=2, dim=1)
    gt_unit = F.normalize(gt_embs, p=2, dim=1)
    sims = pred_unit @ gt_unit.T
    best_per_gt = sims.max(dim=0).values

    threshold = _SCORE_THRESHOLD
    if threshold < 1.0:
        # Rescale so that a similarity of 1.0 maps to a score of 1.0.
        scale = 1.0 / (1.0 - threshold)
    else:
        scale = 2.0
    scores = torch.relu(best_per_gt - threshold) * scale
    return scores, (sims if return_sims else None)

def _compute_sent_embs(completion_text: str):
    """Embed the clauses of a completion's reasoning text.

    The reasoning text is split into clauses. If splitting yields nothing but
    the reasoning text is non-empty, the whole text is embedded as one clause.
    """
    reasoning = _extract_think_text(completion_text)
    clauses = _split_sentences(reasoning)
    if not clauses and reasoning:
        clauses = [reasoning]
    return _encode_texts(clauses)

def rubric_perc_reward(completions, **kwargs):
    """Reward how well each completion's reasoning covers the visual and audio clues.

    Args:
        completions: Sequence of completions; the text is read from
            ``completion[0]["content"]``.
        **kwargs: ``path`` (rubric DB key per sample) and ``extracted_clues``
            (clue dict per sample). A non-list value is reused for every
            completion; a missing value defaults to ``None`` per completion.

    Returns:
        One float per (completion, path, clues) triple: the mean of the
        per-modality scores over the modalities that have a rubric, or 0.0
        when neither modality could be scored.
    """
    rewards = []
    paths = kwargs.get("path") or [None] * len(completions)
    clues_list = kwargs.get("extracted_clues") or [None] * len(completions)
    if not isinstance(paths, list):
        paths = [paths] * len(completions)
    if not isinstance(clues_list, list):
        clues_list = [clues_list] * len(completions)

    for idx, (completion, path, clues) in enumerate(zip(completions, paths, clues_list)):
        # Diagnostics are printed for the first sample only.
        debug = _DEBUG_MODE and idx == 0
        text = completion[0]["content"]
        pred_embs = _compute_sent_embs(text)
        rubric = _get_rubric(path, clues or {})

        if debug:
            think_text = _extract_think_text(text)
            sentences = _split_sentences(think_text)
            if not sentences:
                sentences = [think_text] if think_text else []

            print("\n" + "="*80)
            print("[RUBRIC_DEBUG] rubric_perc_reward - Sample 0")
            print("="*80)
            print(f"Generated sentences ({len(sentences)}):")
            for i, sent in enumerate(sentences):
                print(f"  [{i}] {sent[:100]}..." if len(sent) > 100 else f"  [{i}] {sent}")

            # `or []` also covers a key that is present with a None value,
            # which the scoring path already tolerates.
            visual_clues = (clues or {}).get("visual_clues") or []
            print(f"\nVisual clues ({len(visual_clues)}):")
            for i, clue in enumerate(visual_clues):
                print(f"  [{i}] {clue}")

            audio_clues = (clues or {}).get("audio_clues") or []
            print(f"\nAudio clues ({len(audio_clues)}):")
            for i, clue in enumerate(audio_clues):
                print(f"  [{i}] {clue}")

        s_v, sims_v = _column_max_scores(pred_embs, rubric.get("visual"), return_sims=debug)
        s_a, sims_a = _column_max_scores(pred_embs, rubric.get("audio"), return_sims=debug)

        if debug:
            if sims_v is not None:
                print(f"\nVisual similarity matrix shape: {sims_v.shape}")
                print(f"Visual similarity matrix (first 5x5):\n{sims_v[:5, :5].cpu().numpy()}")
                print(f"Visual column max: {sims_v.max(dim=0).values.cpu().numpy()}")
                print(f"Visual soft-threshold scores: {s_v.cpu().numpy()}")
            if sims_a is not None:
                print(f"\nAudio similarity matrix shape: {sims_a.shape}")
                print(f"Audio similarity matrix (first 5x5):\n{sims_a[:5, :5].cpu().numpy()}")
                print(f"Audio column max: {sims_a.max(dim=0).values.cpu().numpy()}")
                print(f"Audio soft-threshold scores: {s_a.cpu().numpy()}")

        # Indicator per modality: 1 when that modality produced scores.
        Iv = 1 if s_v is not None else 0
        Ia = 1 if s_a is not None else 0
        denom = Iv + Ia
        if denom == 0:
            rewards.append(0.0)
            if debug:
                print(f"\nFinal R_perc: 0.0 (no valid clues)")
                print("="*80 + "\n")
            continue

        Sv = s_v.mean().item() if s_v is not None else 0.0
        Sa = s_a.mean().item() if s_a is not None else 0.0
        r = (Iv * Sv + Ia * Sa) / denom
        rewards.append(float(r))

        if debug:
            print(f"\nIv={Iv}, Ia={Ia}, Sv={Sv:.4f}, Sa={Sa:.4f}")
            print(f"Final R_perc: {r:.4f}")
            print("="*80 + "\n")

    return rewards

def rubric_coh_reward(completions, **kwargs):
    """Reward how well each completion's reasoning covers the reasoning-emotion clues.

    Args:
        completions: Sequence of completions; the text is read from
            ``completion[0]["content"]``.
        **kwargs: ``path`` (rubric DB key per sample) and ``extracted_clues``
            (clue dict per sample). A non-list value is reused for every
            completion; a missing value defaults to ``None`` per completion.

    Returns:
        One float per (completion, path, clues) triple: the mean score over
        the ``logic`` rubric rows, or 0.0 when no score could be computed.
    """
    rewards = []
    paths = kwargs.get("path") or [None] * len(completions)
    clues_list = kwargs.get("extracted_clues") or [None] * len(completions)
    if not isinstance(paths, list):
        paths = [paths] * len(completions)
    if not isinstance(clues_list, list):
        clues_list = [clues_list] * len(completions)

    for idx, (completion, path, clues) in enumerate(zip(completions, paths, clues_list)):
        # Diagnostics are printed for the first sample only.
        debug = _DEBUG_MODE and idx == 0
        text = completion[0]["content"]
        pred_embs = _compute_sent_embs(text)
        rubric = _get_rubric(path, clues or {})

        if debug:
            think_text = _extract_think_text(text)
            sentences = _split_sentences(think_text)
            if not sentences:
                sentences = [think_text] if think_text else []

            print("\n" + "="*80)
            print("[RUBRIC_DEBUG] rubric_coh_reward - Sample 0")
            print("="*80)
            print(f"Generated sentences ({len(sentences)}):")
            for i, sent in enumerate(sentences):
                print(f"  [{i}] {sent[:100]}..." if len(sent) > 100 else f"  [{i}] {sent}")

            # `or []` also covers a key that is present with a None value,
            # which the scoring path already tolerates.
            logic_clues = (clues or {}).get("reasoning_emotions") or []
            print(f"\nReasoning emotion clues ({len(logic_clues)}):")
            for i, clue in enumerate(logic_clues):
                print(f"  [{i}] {clue}")

        s_l, sims_l = _column_max_scores(pred_embs, rubric.get("logic"), return_sims=debug)

        if debug:
            if sims_l is not None:
                print(f"\nLogic similarity matrix shape: {sims_l.shape}")
                print(f"Logic similarity matrix (first 5x5):\n{sims_l[:5, :5].cpu().numpy()}")
                print(f"Logic column max: {sims_l.max(dim=0).values.cpu().numpy()}")
                print(f"Logic soft-threshold scores: {s_l.cpu().numpy()}")

        if s_l is None:
            rewards.append(0.0)
            if debug:
                print(f"\nFinal R_coh: 0.0 (no valid clues)")
                print("="*80 + "\n")
        else:
            r = float(s_l.mean().item())
            rewards.append(r)
            if debug:
                print(f"\nFinal R_coh: {r:.4f}")
                print("="*80 + "\n")

    return rewards

def compute_similarity_matrices_for_papo(completions, **kwargs):
    """Compute clause-vs-clue similarity matrices for every completion.

    Args:
        completions: Sequence of completions; the text is read from
            ``completion[0]["content"]``.
        **kwargs: ``path`` and ``extracted_clues``, handled like the rubric
            reward functions (a non-list value is reused for every completion).

    Returns:
        A list of dicts with keys ``sentences``, ``sim_matrix_v``,
        ``sim_matrix_a``, ``score_v`` and ``score_a``. Each matrix holds cosine
        similarities (clauses x clues) and each score holds the per-clause
        maximum over clues. Entries stay ``None`` when clauses or the
        modality's rubric are unavailable.
    """
    n_items = len(completions)
    paths = kwargs.get("path") or [None] * n_items
    clues_list = kwargs.get("extracted_clues") or [None] * n_items

    if not isinstance(paths, list):
        paths = [paths] * n_items
    if not isinstance(clues_list, list):
        clues_list = [clues_list] * n_items

    # (rubric key, similarity-matrix key, per-clause score key) per modality.
    modalities = (
        ("visual", 'sim_matrix_v', 'score_v'),
        ("audio", 'sim_matrix_a', 'score_a'),
    )

    results = []
    for completion, path, clues in zip(completions, paths, clues_list):
        think_text = _extract_think_text(completion[0]["content"])
        sentences = _split_sentences(think_text)
        if not sentences and think_text:
            sentences = [think_text]

        pred_embs = _encode_texts(sentences) if sentences else None
        rubric = _get_rubric(path, clues or {})

        entry = {
            'sentences': sentences,
            'sim_matrix_v': None,
            'sim_matrix_a': None,
            'score_v': None,
            'score_a': None,
        }

        if pred_embs is not None:
            for rubric_key, sim_key, score_key in modalities:
                clue_embs = rubric.get(rubric_key)
                if clue_embs is None or clue_embs.numel() == 0:
                    continue
                aligned = pred_embs.to(device=clue_embs.device, dtype=clue_embs.dtype)
                pred_unit = F.normalize(aligned, p=2, dim=1)
                clue_unit = F.normalize(clue_embs, p=2, dim=1)
                sims = pred_unit @ clue_unit.T
                entry[sim_key] = sims
                # Best-matching clue for each clause.
                entry[score_key] = sims.max(dim=-1).values

        results.append(entry)

    return results

def compute_modality_token_masks(
    sim_results: list,
    completion_ids: torch.Tensor,
    tokenizer,
    threshold: float = 0.5,
):
    """Turn per-clause modality scores into per-token visual/audio masks.

    Args:
        sim_results: Per-sample dicts with ``sentences``, ``score_v`` and
            ``score_a`` entries.
        completion_ids: 2-D tensor; only its shape and device are used.
        tokenizer: Object whose ``encode(text, add_special_tokens=False)``
            returns a sized token sequence.
        threshold: A clause belongs to a modality when that modality's score
            exceeds this value and also exceeds the other modality's score.

    Returns:
        Dict with ``visual_token_mask`` and ``audio_token_mask``, each of
        shape ``completion_ids.shape``. Rows beyond ``len(sim_results)`` are
        all zeros; rows without clauses or without any scores are all ones.
    """
    batch_size, seq_len = completion_ids.shape
    device = completion_ids.device

    visual_rows, audio_rows = [], []

    for row in range(batch_size):
        if row >= len(sim_results):
            # No similarity result for this row: nothing is marked.
            visual_rows.append(torch.zeros(seq_len, device=device))
            audio_rows.append(torch.zeros(seq_len, device=device))
            continue

        entry = sim_results[row]
        sentences = entry.get('sentences', [])
        score_v = entry.get('score_v')
        score_a = entry.get('score_a')

        if not sentences or (score_v is None and score_a is None):
            # Nothing to attribute: every token is marked for both modalities.
            visual_rows.append(torch.ones(seq_len, device=device))
            audio_rows.append(torch.ones(seq_len, device=device))
            continue

        # A missing modality gets -inf scores so that it can never win.
        if score_v is None:
            score_v = torch.full((len(sentences),), float('-inf'), device=device)
        else:
            score_v = score_v.to(device)
        if score_a is None:
            score_a = torch.full((len(sentences),), float('-inf'), device=device)
        else:
            score_a = score_a.to(device)

        is_visual = (score_v > threshold) & (score_v > score_a)
        is_audio = (score_a > threshold) & (score_a > score_v)

        token_counts = [
            len(tokenizer.encode(sentence, add_special_tokens=False))
            for sentence in sentences
        ]
        total_tokens = sum(token_counts)

        visual_mask = torch.zeros(seq_len, device=device)
        audio_mask = torch.zeros(seq_len, device=device)

        if total_tokens > 0:
            # Give each clause a slice of the sequence proportional to its
            # token count; the final clause absorbs any rounding remainder.
            cursor = 0
            final_index = len(sentences) - 1
            for i, count in enumerate(token_counts):
                if i >= len(is_visual):
                    break

                span = int(round(seq_len * count / total_tokens))
                begin = cursor
                stop = seq_len if i == final_index else min(cursor + span, seq_len)

                if begin < stop:
                    if is_visual[i]:
                        visual_mask[begin:stop] = 1.0
                    if is_audio[i]:
                        audio_mask[begin:stop] = 1.0

                cursor = stop
        else:
            # No clause produced tokens: fall back to equal-width slices.
            stride = seq_len // max(len(sentences), 1)
            for i, (vis, aud) in enumerate(zip(is_visual, is_audio)):
                begin = i * stride
                stop = min((i + 1) * stride, seq_len)
                if vis:
                    visual_mask[begin:stop] = 1.0
                if aud:
                    audio_mask[begin:stop] = 1.0

        visual_rows.append(visual_mask)
        audio_rows.append(audio_mask)

    return {
        'visual_token_mask': torch.stack(visual_rows, dim=0),
        'audio_token_mask': torch.stack(audio_rows, dim=0),
    }

def rubric_perc_reward_with_matrices(completions, **kwargs):
    """Return the perception rewards together with the similarity results.

    Both are computed from the same ``completions`` and keyword arguments,
    rewards first.
    """
    return (
        rubric_perc_reward(completions, **kwargs),
        compute_similarity_matrices_for_papo(completions, **kwargs),
    )
