from transformers import EvalPrediction, PreTrainedTokenizer
from Levenshtein import ratio
import numpy as np
import os
import re

def collapse_brackets(inp, tokenizer):
    """Replace every ``[`` digits ``]`` token run with one integer id.

    Token ids are scanned left to right. Ids outside brackets are copied
    through unchanged. Between a ``[`` and the next ``]`` the token strings
    are concatenated, parsed with ``int`` and emitted as ``value + 10000``.

    Args:
        inp: Iterable of token ids.
        tokenizer: Object providing ``convert_tokens_to_ids`` and
            ``convert_ids_to_tokens``.

    Returns:
        ``np.ndarray`` holding the collapsed id sequence.

    Raises:
        ValueError: If the text collected before a ``]`` is not a valid
            integer literal (including the empty string).
    """
    open_id = tokenizer.convert_tokens_to_ids('[')
    close_id = tokenizer.convert_tokens_to_ids(']')

    collapsed = []
    pieces = []      # token strings gathered since the last '['
    inside = False   # True while between '[' and the matching ']'
    for token_id in inp:
        if token_id == open_id:
            inside = True
            continue
        if token_id == close_id:
            inside = False
            collapsed.append(int("".join(pieces)) + 10000)
            pieces = []
            continue
        if inside:
            pieces.append(tokenizer.convert_ids_to_tokens([token_id])[0])
        else:
            collapsed.append(token_id)
    # Tokens collected after an unterminated '[' are never emitted.
    return np.array(collapsed)

def _print_samples(tokenizer, inputs, pred, labels, limit=5):
    """Print the first ``limit`` decoded prompt/prediction/label triples."""
    prompts = tokenizer.batch_decode(inputs[:limit])
    guesses = tokenizer.batch_decode(pred[:limit])
    golds = tokenizer.batch_decode(labels[:limit])
    for prompt, guess, gold in zip(prompts, guesses, golds):
        print("=" * 80)
        print(f"Prompt: {repr(prompt)}")
        print(f"Pred  : {repr(guess)}")
        print(f"Label : {repr(gold)}")


def _maze_adjacency(inp, colon_id, comma_id, question_id, x_id):
    """Parse ``node : neighbour ... ,`` runs of a prompt into ``{node: set}``."""
    adj = {}
    from_node = None
    for j, tok in enumerate(inp):
        if tok == colon_id:
            # The token right before ':' names the node whose neighbours follow.
            from_node = inp[j - 1]
            adj[from_node] = set()
        elif tok == comma_id or tok == question_id:
            from_node = None
        elif from_node is not None and tok != x_id:
            adj[from_node].add(tok)
    return adj


def _maze_metrics(tokenizer, inputs, pred, labels):
    """Score predicted paths against the graph encoded in each prompt."""
    pad_id = tokenizer.pad_token_id
    eos_id = tokenizer.eos_token_id
    colon_id = tokenizer.convert_tokens_to_ids(':')
    comma_id = tokenizer.convert_tokens_to_ids(',')
    x_id = tokenizer.convert_tokens_to_ids('[X]')
    question_id = tokenizer.convert_tokens_to_ids('?')
    back_id = tokenizer.convert_tokens_to_ids(';')

    legal_paths = 0
    jumps = 0.0
    exact_match = 0
    contains = 0
    correct = 0
    for inp, path, gold in zip(inputs, pred, labels):
        inp = inp[(inp != pad_id) & (inp != eos_id)]
        path = path[(path != pad_id) & (path != eos_id)]
        gold = gold[(gold != pad_id) & (gold != eos_id)]

        # An empty prediction scores zero on every metric but still counts
        # in the denominators below.
        if len(path) == 0:
            continue

        adj = _maze_adjacency(inp, colon_id, comma_id, question_id, x_id)

        illegal_moves = 0
        for src, dst in zip(path[:-1], path[1:]):
            # Steps that touch the ';' token are never counted as illegal.
            if src == back_id or dst == back_id:
                continue
            if dst not in adj.get(src, ()):
                illegal_moves += 1
        legal = illegal_moves == 0
        jumps += illegal_moves / len(path)

        if legal:
            legal_paths += 1
            # "correct" = legal path sharing first and last token with the label.
            if len(gold) > 0 and path[0] == gold[0] and path[-1] == gold[-1]:
                correct += 1
        if np.array_equal(path, gold):
            if not legal:
                # A label that is itself an illegal path means the data or
                # the parsing above is inconsistent; fail loudly.
                raise ValueError("Exact match but not legal")
            exact_match += 1
        if set(path) == set(gold):
            contains += 1

    total = len(pred)
    return {
        'legal': legal_paths / total,
        'correct': correct / total,
        'exact_match': exact_match / total,
        'jumps': jumps / total,
        'contains': contains / total,
    }


def _boxed_answer_metrics(tokenizer, pred, labels):
    """Compare the first ``\\boxed{...}`` span of each prediction to its label."""
    pred_str = tokenizer.batch_decode(pred, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
    boxed_re = re.compile(r'\\boxed{(.*?)}')

    correct = 0
    not_found = 0
    for guess, gold in zip(pred_str, label_str):
        boxed = boxed_re.search(guess)
        if not boxed:
            not_found += 1
        elif boxed.group(1) == gold:
            correct += 1

    # Count of non-pad tokens in every prediction row.
    len_dist = np.array([len(row[row != tokenizer.pad_token_id]) for row in pred])
    total = len(pred)
    return {'accuracy': correct / total, 'not_found': not_found / total, 'len_dist': len_dist}


def _sequence_match_metrics(tokenizer, inputs, pred, labels):
    """Row-wise exact match and mean ``ratio`` similarity of token ids."""
    pad_id = tokenizer.pad_token_id
    # Pad whichever side is narrower; padding only ``pred`` made np.pad raise
    # on a negative width whenever predictions were longer than the labels.
    width = max(pred.shape[1], labels.shape[1])
    pred = np.pad(pred, ((0, 0), (0, width - pred.shape[1])), constant_values=pad_id)
    labels = np.pad(labels, ((0, 0), (0, width - labels.shape[1])), constant_values=pad_id)

    accuracy = (pred == labels).all(axis=1).mean()
    dist = np.array([ratio(p, l) for p, l in zip(pred, labels)]).mean()
    return {
        'accuracy': accuracy,
        'input_length': inputs.shape[1],
        'distance': dist,
    }


def compute_metrics(pred_obj: "EvalPrediction", tokenizer: "PreTrainedTokenizer | None" = None, task=None):
    """Compute evaluation metrics for one batch of predictions.

    Every ``-100`` entry of ``pred_obj.inputs``, ``pred_obj.label_ids`` and
    ``pred_obj.predictions`` is overwritten IN PLACE with
    ``tokenizer.pad_token_id``. When the ``LOCAL_RANK`` environment variable
    is unset or equal to ``"0"``, the first five decoded examples are printed.

    Args:
        pred_obj: Object exposing ``inputs``, ``label_ids`` and
            ``predictions`` as 2-D integer arrays of token ids.
        tokenizer: Tokenizer used for pad/eos ids, token lookups and
            decoding. Required despite the ``None`` default.
        task: Selects the metric set.
            ``'maze'`` returns ``legal``, ``correct``, ``exact_match``,
            ``jumps`` and ``contains`` rates.
            ``'synthetic_COT_mult'`` returns ``accuracy``, ``not_found`` and
            ``len_dist`` (array of per-row non-pad prediction lengths).
            Any other value returns ``accuracy``, ``input_length`` and
            ``distance``.

    Returns:
        Dict mapping metric names to values, as listed under ``task``.

    Raises:
        ValueError: If ``tokenizer`` is ``None``, or (maze task) if a
            prediction equals its label exactly but is not a legal path in
            the graph parsed from the prompt.
    """
    if tokenizer is None:
        raise ValueError("compute_metrics requires a tokenizer")

    inputs = pred_obj.inputs
    labels = pred_obj.label_ids
    pred = pred_obj.predictions
    # -100 marks ignored positions; map it to pad so the ids can be decoded.
    pred[pred == -100] = tokenizer.pad_token_id
    inputs[inputs == -100] = tokenizer.pad_token_id
    labels[labels == -100] = tokenizer.pad_token_id

    # Only one process prints samples when running under a distributed launcher.
    if os.environ.get("LOCAL_RANK", "0") == "0":
        _print_samples(tokenizer, inputs, pred, labels)

    if task == 'maze':
        return _maze_metrics(tokenizer, inputs, pred, labels)
    if task == 'synthetic_COT_mult':
        return _boxed_answer_metrics(tokenizer, pred, labels)
    return _sequence_match_metrics(tokenizer, inputs, pred, labels)

if __name__ == "__main__":
    # Manual smoke check: tokenize '[1][33][2]' one character at a time with
    # the pretrained "gpt2" tokenizer and print the collapsed id sequence.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    sample_ids = tokenizer.convert_tokens_to_ids(list('[1][33][2]'))
    print(collapse_brackets(sample_ids, tokenizer))
