from tqdm import tqdm
import torch

# --- PS Metric Helper Functions ---
def string2token(strings, tokenizer, device):
    """Encode every string on its own and move the resulting ids to ``device``.

    Args:
        strings: iterable of texts to encode (special tokens are added).
        tokenizer: object whose ``encode`` accepts ``add_special_tokens`` and
            ``return_tensors='pt'``.
        device: target passed to ``Tensor.to``.

    Returns:
        dict with ``'token'`` (list holding row 0 of each encoding, in input
        order) and ``'length'`` (list of each token tensor's ``size(0)``).
    """
    encoded_rows = []
    for text in strings:
        batch_ids = tokenizer.encode(text, add_special_tokens=True, return_tensors='pt')
        # Keep only the first row of the returned tensor.
        encoded_rows.append(batch_ids.to(device)[0])
    row_lengths = [row.size(0) for row in encoded_rows]
    return {'token': encoded_rows, 'length': row_lengths}

def token2string(tokens, tokenizer):
    """Decode each token sequence in ``tokens`` to text, dropping special tokens.

    Returns a list of strings in the same order as ``tokens``.
    """
    decoded = []
    for token_seq in tokens:
        decoded.append(tokenizer.decode(token_seq, skip_special_tokens=True))
    return decoded

def lcs(s1, s2):
    """Return the length of the longest common subsequence of ``s1`` and ``s2``.

    Bottom-up dynamic programming that keeps one row at a time
    (O(len(s1) * len(s2)) time, O(len(s2)) memory). It is iterative, so long
    inputs no longer hit Python's recursion limit the way a memoised
    recursion of depth ~len(s1) + len(s2) does.

    Args:
        s1, s2: indexable sequences (lists, tuples, strings). Objects that
            expose ``tolist()`` (e.g. 1-D tensors) are converted to lists
            first, so each comparison is a plain Python ``==`` rather than a
            per-element tensor op.

    Returns:
        int: the LCS length (0 if either input is empty).
    """
    if hasattr(s1, 'tolist'):
        s1 = s1.tolist()
    if hasattr(s2, 'tolist'):
        s2 = s2.tolist()

    # prev[j] == LCS length of the s1 prefix processed so far and s2[:j].
    prev = [0] * (len(s2) + 1)
    for a in s1:
        cur = [0]
        for j, b in enumerate(s2):
            if a == b:
                cur.append(prev[j] + 1)
            else:
                cur.append(max(prev[j + 1], cur[j]))
        prev = cur
    return prev[-1]

def get_ps_score_tofu(model, tokenizer, loader, ps_type='exact', max_length=200):
    """
    Calculates the Prediction Similarity (PS) score on TOFU-style QA batches.

    Each sample is rendered as ``"### Question: {q}\\n ### Answer: {a}"`` and
    tokenized. A per-sample binary search over prefix lengths, between the
    tokenized question length and the full length, looks for the shortest
    prefix from which greedy generation matches the remaining reference tokens.
    The per-sample score is
    ``1 - (found_prefix_len - question_len) / (full_len - question_len)``.
    The mean over all samples is returned.

    Args:
        model: object exposing ``eval()``, ``device`` and ``generate``.
        tokenizer: object exposing ``encode``, ``decode``, ``batch_decode``,
            ``batch_encode_plus`` and ``eos_token_id``.
        loader: iterable of dict batches with an ``'answer'`` list and a
            ``'question'`` list (``'paraphrased_question'`` when
            ``ps_type == 'perturb'``).
        ps_type: ``'similar'`` uses LCS-based matching; any other value
            (e.g. ``'exact'``, ``'perturb'``) uses exact prefix matching.
        max_length: ``max_length`` forwarded to ``model.generate``.

    Returns:
        float: mean PS score, or 0.0 when the loader yields no samples.
    """
    ps_list = []
    model.eval()
    device = model.device
    question_key = 'paraphrased_question' if ps_type == 'perturb' else 'question'

    for batch in tqdm(loader, desc=f"Calculating PS ({ps_type})"):
        ques = batch[question_key]
        anws = batch['answer']
        fuls = [f"### Question: {que}\n ### Answer: {ans}" for que, ans in zip(ques, anws)]

        # Only the question lengths are needed: they are the lower search bound.
        ques_tks_lens = string2token(ques, tokenizer, device)['length']
        _fuls_tks_and_lens = string2token(fuls, tokenizer, device)
        fuls_tks, fuls_tks_lens = _fuls_tks_and_lens['token'], _fuls_tks_and_lens['length']

        # Per-sample search interval: [question length, full length].
        left_bar, right_bar = list(ques_tks_lens), list(fuls_tks_lens)
        max_diff = max((b - a for a, b in zip(left_bar, right_bar)), default=0)

        old_mid_bar = []
        for attempt in range(max_diff):
            mid_bar = [(a + b) // 2 for a, b in zip(left_bar, right_bar)]

            # Converged: no sample's midpoint moved since the previous round.
            if attempt != 0 and mid_bar == old_mid_bar:
                break

            # Re-encode the decoded prefixes so the batch is padded consistently.
            prefixes = token2string([tk[:cur] for cur, tk in zip(mid_bar, fuls_tks)], tokenizer)
            inputs = tokenizer.batch_encode_plus(prefixes, add_special_tokens=True, return_tensors='pt', padding=True).to(device)

            with torch.no_grad():
                preds_tks = model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_length=max_length,
                    do_sample=False,
                    use_cache=True,
                    pad_token_id=tokenizer.eos_token_id,
                )

            pred_tks = string2token(tokenizer.batch_decode(preds_tks, skip_special_tokens=True), tokenizer, device)['token']

            match = []
            for p_tk, f_tk, m in zip(pred_tks, fuls_tks, mid_bar):
                # Tokens past the probed prefix; the clamp keeps a short prediction in range.
                pred = p_tk[min(m, len(p_tk)):].tolist()
                ref = f_tk[m:].tolist()
                if not pred:
                    # An empty continuation never counts as reproducing the reference,
                    # in either mode (lcs would otherwise give 0 >= 0.5 * 0 -> True).
                    match.append(False)
                elif ps_type == 'similar':
                    match.append(lcs(pred, ref) >= 0.5 * len(pred))
                else:
                    # Exact: the prediction must be a prefix of the remaining reference.
                    match.append(pred == ref[:len(pred)])

            # Match -> try a shorter prefix (shrink right); miss -> need a longer one.
            left_bar = [left if m else mid for m, left, mid in zip(match, left_bar, mid_bar)]
            right_bar = [mid if m else right for m, right, mid in zip(match, right_bar, mid_bar)]
            old_mid_bar = mid_bar

        # 1e-9 guards against a zero-width interval (full length == question length).
        ps_list.extend(1 - (m - l) / (r - l + 1e-9) for l, m, r in zip(ques_tks_lens, right_bar, fuls_tks_lens))

    return sum(ps_list) / len(ps_list) if ps_list else 0.0




def get_ps_score_arxiv(model, tokenizer, loader, ps_type='exact', prompt_len=100, max_length=512):
    """
    Calculates the Prediction Similarity (PS) score on pre-tokenized arXiv batches.

    Each batch is unpacked as ``(token_ids, attention_mask, index)``, and
    ``token_ids`` is sliced as a 2-D ``(batch, length)`` tensor. The first
    ``prompt_len`` tokens of every row form the fixed lower bound. A
    per-sample binary search over longer prefixes looks for the shortest
    prefix from which greedy generation matches the remaining reference tokens.
    The per-sample score is
    ``1 - (found_prefix_len - lower_bound) / (row_len - lower_bound)``, where
    ``lower_bound = min(prompt_len, row_len)``. The mean over all samples is
    returned.

    Args:
        model: object exposing ``eval()``, ``device`` and ``generate``.
        tokenizer: object exposing ``decode``, ``batch_decode``,
            ``batch_encode_plus``, ``encode`` and ``eos_token_id``.
        loader: iterable of ``(token_ids, attention_mask, index)`` batches.
        ps_type: ``'similar'`` uses LCS-based matching; any other value except
            ``'perturb'`` uses exact prefix matching.
        prompt_len: number of leading tokens always given to the model.
        max_length: ``max_length`` forwarded to ``model.generate``.

    Returns:
        float: mean PS score, or 0.0 when the loader yields no samples.

    Raises:
        NotImplementedError: if ``ps_type == 'perturb'``.
    """
    # Fail fast rather than after the data loop has already started.
    if ps_type == 'perturb':
        raise NotImplementedError("Perturbation-based PS not implemented for arXiv dataset.")

    ps_list = []
    model.eval()
    device = model.device

    for batch in tqdm(loader, desc=f"Calculating PS ({ps_type})"):
        # NOTE(review): attention_mask is unused, so any padded positions count towards
        # the row length and the reference tokens -- confirm the arXiv rows are unpadded.
        fuls_tks, _attention_mask, _index = batch
        ques_tks_lens = [row.size(0) for row in fuls_tks[:, :prompt_len]]
        fuls_tks_lens = [row.size(0) for row in fuls_tks]

        # Per-sample search interval: [prompt length, full row length].
        left_bar, right_bar = list(ques_tks_lens), list(fuls_tks_lens)
        max_diff = max((b - a for a, b in zip(left_bar, right_bar)), default=0)

        old_mid_bar = []
        for attempt in range(max_diff):
            mid_bar = [(a + b) // 2 for a, b in zip(left_bar, right_bar)]

            # Converged: no sample's midpoint moved since the previous round.
            if attempt != 0 and mid_bar == old_mid_bar:
                break

            # Re-encode the decoded prefixes so the batch is padded consistently.
            prefixes = token2string([tk[:cur] for cur, tk in zip(mid_bar, fuls_tks)], tokenizer)
            inputs = tokenizer.batch_encode_plus(prefixes, add_special_tokens=True, return_tensors='pt', padding=True).to(device)

            with torch.no_grad():
                preds_tks = model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_length=max_length,
                    do_sample=False,
                    use_cache=True,
                    pad_token_id=tokenizer.eos_token_id,
                )

            pred_tks = string2token(tokenizer.batch_decode(preds_tks, skip_special_tokens=True), tokenizer, device)['token']

            match = []
            for p_tk, f_tk, m in zip(pred_tks, fuls_tks, mid_bar):
                # Compare as Python lists: works even if the batch tensor lives on a
                # different device / dtype than the freshly tokenized predictions.
                pred = p_tk[min(m, len(p_tk)):].tolist()
                ref = f_tk[m:].tolist()
                if not pred:
                    # An empty continuation never counts as reproducing the reference,
                    # in either mode (lcs would otherwise give 0 >= 0.5 * 0 -> True).
                    match.append(False)
                elif ps_type == 'similar':
                    match.append(lcs(pred, ref) >= 0.5 * len(pred))
                else:
                    # Exact: the prediction must be a prefix of the remaining reference.
                    match.append(pred == ref[:len(pred)])

            # Match -> try a shorter prefix (shrink right); miss -> need a longer one.
            left_bar = [left if m else mid for m, left, mid in zip(match, left_bar, mid_bar)]
            right_bar = [mid if m else right for m, right, mid in zip(match, right_bar, mid_bar)]
            old_mid_bar = mid_bar

        # 1e-9 guards against a zero-width interval (row length == prompt length).
        ps_list.extend(1 - (m - l) / (r - l + 1e-9) for l, m, r in zip(ques_tks_lens, right_bar, fuls_tks_lens))

    return sum(ps_list) / len(ps_list) if ps_list else 0.0