"

from __future__ import annotations

import ast
import os
import re
import uuid
from collections import Counter
from typing import List, Tuple, Sequence, Union

import pandas as pd

# ---------------------------------------------------------------------
#                           Helper Utilities
# ---------------------------------------------------------------------
def get_evi_numbers(xml_text: str) -> list[tuple[int, list[int]]]:
    """
    Extract (main_sentence, support_numbers) tuples from <main> … </main> blocks.

    A block has the form ``<main> sentence_number = N <support>a,b,c</support> </main>``.
    The support section may be omitted or written as the self-closing
    ``<support/>``; in both cases the support list is empty.  Matching is
    case-insensitive and tolerant of whitespace around the individual parts.

    Support tokens that are not integers are reported on stdout and skipped,
    so one bad token does not discard the rest of the block.

    Parameters
    ----------
    xml_text : str
        Text to scan; anything outside well-formed <main> blocks is ignored.

    Returns
    -------
    List[Tuple[int, List[int]]]
        e.g. [(2, []), (5, [1, 3, 4])]
    """
    pattern = re.compile(
        r'<main>\s*'                         # opening tag
        r'sentence_number\s*=\s*(\d+)\s*'    # main sentence number  (group 1)
        r'(?:'                               # --- optional support alternatives ---
        # Support text may not contain '<'.  This keeps a malformed block
        # (e.g. a missing </support>) from swallowing the blocks after it.
        r'<support\s*>([^<]*)</support\s*>'  #   normal tag (group 2)
        r'|'                                 #   -or-
        r'<support\s*/>'                     #   self-closing
        r')?'                                # support section itself optional
        r'\s*</main>',                       # closing </main>
        re.IGNORECASE
    )

    results = []
    for main_sentence, support in pattern.findall(xml_text):
        support_numbers = []
        # `support` is '' when the section is absent or self-closing.
        for token in support.split(','):
            token = token.strip()
            if not token:
                continue
            try:
                support_numbers.append(int(token))
            except ValueError as e:
                # Best-effort parsing: report the bad token and keep going.
                print(e)
        try:
            results.append((int(main_sentence), support_numbers))
        except ValueError as e:
            # Only reachable for digit strings that int() refuses to convert.
            print(e)

    return results

from collections import Counter
from typing import List, Tuple, Sequence, Union

# A sentence identifier: an int, or a str when IDs are not numeric.
Number = Union[int, str]                     # use str if your IDs aren't numeric
# One labelled entry: (main sentence ID, sequence of supporting sentence IDs).
Pred   = Tuple[Number, Sequence[Number]]     # e.g. (main, [supports ...])

def get_correctness_reward(
    predictions: List[Pred],
    ground_truth: List[Pred],
    dedup: bool = True            # False ⇒ duplicates are scored individually
) -> float:
    """
    Compute the total reward for one (predictions, ground_truth) pair.

    Both arguments are lists of ``(main, supports)`` entries.  All mains and
    all supports are pooled across entries before scoring, so the pairing of
    a support with its particular main is not taken into account.

    Rules
    -----
    1)  pred  main  ∈ gt main            →  +2
    2)  pred  main  ∈ gt support only    →  +1
    3)  pred  main  ∉ gt main∪sup        →  -2
    4)  pred  sup   ∈ gt main∪sup        →  +2
    5)  pred  sup   ∉ gt main∪sup        →  -2
    6)  gt    main  not predicted at all →  -2      (FN)
    7)  gt    sup   not predicted at all →  -1      (FN)

    "Not predicted at all" means the ID appears neither as a predicted main
    nor as a predicted support.

    NOTE(review): predicted supports are scored ±2 while a missed ground-truth
    support costs only 1 — confirm this asymmetry is intended.

    Parameters
    ----------
    predictions, ground_truth :
        Lists of ``(main_id, [support_ids...])``; IDs must be hashable.
    dedup :
        When True (default) every pooled ID list is reduced to its unique
        values first, so repeated IDs are scored once.  When False each
        occurrence is scored separately.

    Returns
    -------
    float
        Sum of all rewards and penalties.
    """
    # ---------------------------- helpers ----------------------------
    def _split(entries):
        """Pool the mains and the supports of all entries into two flat lists."""
        mains = [entry[0] for entry in entries]
        supports = [s for entry in entries for s in entry[1]]
        return mains, supports

    main_p, sup_p = _split(predictions)
    main_g, sup_g = _split(ground_truth)

    if dedup:
        # Order does not matter: the reward is a plain sum over the IDs.
        main_p, sup_p, main_g, sup_g = (
            list(set(ids)) for ids in (main_p, sup_p, main_g, sup_g)
        )

    # Sets give O(1) membership tests in the loops below.
    main_g_set, sup_g_set = set(main_g), set(sup_g)
    predicted = set(main_p) | set(sup_p)

    # ------------------- predicted mains -----------------------------
    reward = 0.0
    for m in main_p:
        if m in main_g_set:
            reward += 2
        elif m in sup_g_set:
            reward += 1
        else:
            reward -= 2

    # ------------------- predicted supports --------------------------
    for s in sup_p:
        if s in main_g_set or s in sup_g_set:
            reward += 2
        else:
            reward -= 2

    # ------------------- false negatives -----------------------------
    missed_main = sum(1 for m in main_g if m not in predicted)
    missed_sup = sum(1 for s in sup_g if s not in predicted)
    reward -= 2 * missed_main + missed_sup

    return reward

import uuid
def correctness_reward(completions, answers,**kwargs) -> list[float]:
    """
    Score every completion against its ground-truth answer and log the batch.

    For each (answer, completion) pair the answer is parsed into a list of
    ``(main, [supports])`` entries, the completion is parsed with
    ``get_evi_numbers`` and the two are scored by ``get_correctness_reward``.

    Side effect: the batch (completions, answers, rewards) is written to
    ``rewards_24_04/<uuid4>.csv``; the directory is created if it is missing.

    Parameters
    ----------
    completions :
        Iterable of completion strings.
    answers :
        Iterable of ground-truth labels, each either the string form of a
        Python literal (e.g. ``"[(2, []), (5, [1, 3, 4])]"``) or an
        already-parsed list.
    **kwargs :
        Accepted and ignored.

    Returns
    -------
    list[float]
        One reward per (answer, completion) pair, in input order.
    """
    rewards = []
    for ans, completion in zip(answers, completions):
        # Labels are plain literals, so parse them with literal_eval instead
        # of eval(): this cannot execute arbitrary code embedded in the data.
        # NOTE(review): labels that are not pure literals now raise here.
        evi_sen_labels = ast.literal_eval(ans) if isinstance(ans, str) else ans
        evi_sen_preds = get_evi_numbers(completion)

        rewards.append(
            get_correctness_reward(predictions=evi_sen_preds, ground_truth=evi_sen_labels)
        )

    _id = str(uuid.uuid4())

    # to_csv does not create missing parent directories.
    os.makedirs("rewards_24_04", exist_ok=True)
    df = pd.DataFrame({"completions": completions, "answers": answers, "rewards": rewards})
    df.to_csv(f"rewards_24_04/{_id}.csv",index=False)

    return rewards


import re

def format_reward_func(completions, **kwargs) -> list[float]:
    """
    Grade each completion on how closely it follows the <main> block layout.

    Scores
    ------
     1.0  the completion contains ``<no_main_sentences/>``, or it consists
          solely of blocks in the exact line-by-line layout (strict pattern)
     0.5  it contains only well-formed blocks, but with free whitespace
          and/or surrounding text (soft pattern)
    -0.5  anything else

    Extra keyword arguments are accepted and ignored.
    """
    exact_layout = re.compile(
r"""
^                                   # ─── entire string must fit ───
(?:                                 # repeat one or more identical blocks
    <main>\r?\n                     # 1️⃣  <main>      on its own line
    sentence_number=(\d+)\r?\n      # 2️⃣  integer only, no quotes / spaces
    <support>                       # 3️⃣  start support tag
        (?:\d+(?:,\d+)*)?           #       0‑N ints separated by commas
    </support>\r?\n                 #       close support tag (same line)
    </main>\r?\n?                   # 4️⃣  </main> on its own line
)+                                  # at least one block is required
$                                   # ─── nothing else before/after ───
""",
    re.VERBOSE,)

    lenient_layout = re.compile(
r"""
^                                   # ── begin entire string ──
(?:                                 #    repeat for every block
    (?:                             #    1️⃣ any text that is *not* the start of <main>
        (?!<main>).                 #        negative‑look‑ahead prevents eating real blocks
    )*
    <main>\s*                       #    2️⃣ well‑formed <main> block
    sentence_number\s*=\s*\d+\s*
    (?:
          <support\s*/>                             # empty   <support/>
        | <support>\s*
              (?:\d+\s*(?:,\s*\d+\s*)*)?            # list of ints (optional)
          </support>
    )
    \s*</main>
)+                                  #    must have *at least one* block
(?:                                 #    3️⃣ trailing text with no more <main> tags
    (?!<main>).                     
)*                                  #        (same idea as 1️⃣)
$                                   # ── end entire string ──
""",
    re.VERBOSE | re.DOTALL,
    )

    def _grade(text):
        # Both patterns are evaluated up front, before the marker check.
        exact = exact_layout.fullmatch(text) is not None
        lenient = lenient_layout.search(text) is not None

        # The explicit "nothing to report" marker counts as a perfect format.
        if "<no_main_sentences/>" in text or exact:
            return 1.0
        return 0.5 if lenient else -0.5

    return [_grade(text) for text in completions]


import re
def repetation_reward(completions,  **kwargs) -> list[float]:
    """
    Penalise repeated sentence numbers inside each completion.

    Every completion is parsed with ``get_evi_numbers``.  Duplicates are
    counted separately within each block's support list and within the list
    of all main sentence numbers; each surplus occurrence (a value seen k
    times contributes k - 1) costs 2 points.  Repeats across different
    lists are not counted.

    Parameters
    ----------
    completions :
        Iterable of completion strings.
    **kwargs :
        Accepted and ignored.

    Returns
    -------
    list[float]
        One non-positive penalty per completion (0.0 when nothing repeats).
    """
    def get_repitation_num(lis):
        """Return the number of surplus occurrences in *lis*."""
        return sum(count - 1 for count in Counter(lis).values())

    rep_rewards = []
    for completion in completions:
        evi_sen_preds = get_evi_numbers(completion)

        # One group per support list, plus one group holding all mains.
        groups = [entry[1] for entry in evi_sen_preds]
        groups.append([entry[0] for entry in evi_sen_preds])

        surplus = sum(get_repitation_num(group) for group in groups)
        # Multiply as ints first so a clean completion yields 0.0, not -0.0.
        rep_rewards.append(float(-2 * surplus))

    return rep_rewards