import re
from typing import List
import numpy as np

# Module-level caches for the scorer objects; ``None`` means "not loaded yet"
# (the init_*_scorer() helpers test for ``is None`` before loading).
bleu_scorer = bleurt_scorer = rouge_scorer = bert_scorer = llm_scorer = None


def init_bleu_scorer():
    """Load the Hugging Face ``bleu`` metric once and cache it in ``bleu_scorer``.

    Subsequent calls are no-ops while the cache is populated.
    """
    global bleu_scorer
    if bleu_scorer is not None:
        return
    # Deferred import: ``evaluate`` is only needed when this metric is requested.
    import evaluate as hfeval
    bleu_scorer = hfeval.load("bleu")
    
def init_bleurt_scorer():
    """Load the ``bleurt`` metric (on CPU) once and cache it in ``bleurt_scorer``.

    Subsequent calls are no-ops while the cache is populated.
    """
    global bleurt_scorer
    if bleurt_scorer is not None:
        return
    # Deferred import: ``evaluate`` is only needed when this metric is requested.
    import evaluate as hfeval
    bleurt_scorer = hfeval.load("bleurt", module_type="metric", device='cpu')

def init_rouge_scorer():
    """Load the ``rouge`` metric once and cache it in ``rouge_scorer``.

    Subsequent calls are no-ops while the cache is populated.
    """
    global rouge_scorer
    if rouge_scorer is not None:
        return
    # Deferred import: ``evaluate`` is only needed when this metric is requested.
    import evaluate as hfeval
    rouge_scorer = hfeval.load('rouge')


def init_bert_scorer():
    """Load the ``bertscore`` metric (with ``lang='en'``) once and cache it in ``bert_scorer``.

    Subsequent calls are no-ops while the cache is populated.
    """
    global bert_scorer
    if bert_scorer is not None:
        return
    # Deferred import: ``evaluate`` is only needed when this metric is requested.
    import evaluate as hfeval
    bert_scorer = hfeval.load('bertscore', lang='en')


def init_llm_scorer():
    """Placeholder initializer for an LLM-based scorer cached in ``llm_scorer``.

    Currently it only imports ``openai`` while ``llm_scorer`` is unset (raising
    ImportError if the package is missing); no client is created, so
    ``llm_scorer`` remains ``None`` after the call.
    """
    # Must name ``llm_scorer`` (not ``bert_scorer``) so that a future
    # assignment below updates this initializer's own module-level cache.
    global llm_scorer
    if llm_scorer is None:
        import openai  # noqa: F401
        # TODO: initialize the model or api call and assign it to llm_scorer


def compute_standard_scores(
    ref_answers: List[List[str]], 
    answer: List[str],
    question: List[str] = None,
    scorelist=('bleu', 'bleu_adapt', 'rouge'),
    **kwargs,
):
    """Compute the requested scores for ``answer`` against ``ref_answers`` and merge them.

    Args:
        ref_answers: reference answers passed as ``references`` to the metrics.
            Presumably one list of acceptable answers per prediction — confirm
            against callers.
        answer: predicted answers.
        question: optional questions, one per prediction; required for ``'judge'``.
        scorelist: names of scores to compute. Recognised names are ``'bleu'``,
            ``'bleu_adapt'``, ``'bleurt'``, ``'rouge'``, ``'bertscore'`` and
            ``'judge'``; any iterable supporting ``in`` works.
        **kwargs: ``kwargs['judge']`` (a dict) is forwarded as keyword arguments
            to ``evaluate_llm_as_a_judge_oai_interface``.

    Returns:
        dict of score name -> value. If computing a score raises, the error is
        printed and that score's keys are set to NaN instead.

    Raises:
        AssertionError: if ``question`` is given and its length differs from
            ``answer``. Errors raised while *loading* a metric (the
            ``init_*_scorer`` calls) are not caught either.
    """
    # compute all the scores and merge results
    if question is not None:
        assert len(question) == len(answer), f"The number of questions must match the number of proposed answers! {question} | {answer}"
    retdict = {}
    if 'bleu' in scorelist:
        init_bleu_scorer()
        try:
            retdict.update({'bleu': bleu_scorer.compute(
                predictions=answer,
                references=ref_answers,
            )['bleu']})
        except Exception as e:
            print(f"Exceptioning in bleu {e}")
            retdict.update({'bleu': np.nan})  # NaN marks a failed metric
    if 'bleu_adapt' in scorelist:
        init_bleu_scorer()
        try:
            # Cap the n-gram order by the longest answer / reference (in
            # space-separated words) so short but identical sequences do not
            # score 0 because they contain no 4-grams. Only the first reference
            # group is inspected.
            count_refs = ref_answers[0] if isinstance(ref_answers[0], list) else ref_answers
            max_order = min([max([len(a.split(' ')) for a in answer]), max([len(a.split(' ')) for a in count_refs]), 4])
            retdict.update({'bleu_adapt': bleu_scorer.compute(
                predictions=answer,
                references=ref_answers,
                max_order=max_order,
            )['bleu']})
        except Exception as e:
            print(f"Exceptioning in bleu_adapt {e}")
            retdict.update({'bleu_adapt': np.nan})  # NaN marks a failed metric
    if 'bleurt' in scorelist:
        init_bleurt_scorer()
        try:
            # BLEURT scores one prediction against one reference, so the single
            # prediction is paired with every acceptable reference and the
            # resulting scores are averaged.
            assert len(ref_answers)==1, "Must have one to one for bleurt"
            bleurt_ref_answers = [ra for ra in ref_answers[0]]
            bleurt_answer = [answer[0] for _ in bleurt_ref_answers]
            bscores = bleurt_scorer.compute(
                predictions=bleurt_answer,
                references=bleurt_ref_answers, 
            )['scores']
            bscores = sum(bscores)/len(bscores)
            retdict.update({'bleurt': bscores})
        except Exception as e:
            print(f"Exceptioning in bleurt {e}")
            retdict.update({'bleurt': np.nan})  # NaN marks a failed metric
    if 'rouge' in scorelist:
        init_rouge_scorer()
        try:
            retdict.update(rouge_scorer.compute(
                predictions=answer,
                references=ref_answers,
            ))
        except Exception as e:
            print(f"Exceptioning in rouge {e}")
            # Emit the same keys the rouge metric returns on success so result
            # rows keep a consistent schema.
            retdict.update({'rouge1': np.nan, 'rouge2': np.nan, 'rougeL': np.nan, 'rougeLsum': np.nan})
    if 'bertscore' in scorelist:
        init_bert_scorer()
        try:
            bscore_outs = bert_scorer.compute(
                predictions=answer,
                references=ref_answers,
                lang='en'
            )
            bscore_outs = {'bert_score_'+k: v for k, v in bscore_outs.items()}
            retdict.update(bscore_outs)
        except Exception as e:
            print(f"Exceptioning in bert score {e}")
            retdict.update({'bert_score_f1': np.nan, 'bert_score_precision': np.nan, 'bert_score_recall': np.nan})

    if 'judge' in scorelist:
        try:
            jkwargs = kwargs.get('judge', {})
            # The judge handles exactly one question/answer pair; a missing
            # question is reported through the same message instead of an
            # opaque ``len(None)`` TypeError.
            assert question is not None and len(answer)==1 and len(question)==1, f"Somethithing up with sizes: {question}, {answer}"
            judge_out = evaluate_llm_as_a_judge_oai_interface(
                predicted_answer=answer[0],
                correct_answers=ref_answers,
                question=question[0],
                **jkwargs
            )
            retdict.update(judge_out)
        except Exception as e:
            print(f"Exceptioning in llm as a judge {e}")
            # '_judge_model' matches the key used by every judge result;
            # 'judge_model' is kept for existing consumers.
            retdict.update({'judge_says': np.nan, 'judge_model': 'failed', '_judge_model': 'failed'})

    return retdict


def clean_up_special_chars(s):
    """Return *s* with every ``<|...|>`` token (no spaces inside) removed."""
    special_token = re.compile(r'<\|[^ ]*?\|>')
    cleaned = special_token.sub('', s)
    return cleaned


import requests, os, time

# Default OpenRouter API key for the LLM judge; empty string when the
# ORO_API_KEY environment variable is absent.
oai_key_default = os.environ.get('ORO_API_KEY', '')
if 'ORO_API_KEY' not in os.environ:
    print("ORO_API_KEY not set, might fail at correctness evaluation with llm as a judge")

def _build_judge_prompt(predicted_answer, correct_answers, question, prompt_for_qa):
    """Assemble the yes/no prompt used by ``evaluate_llm_as_a_judge_oai_interface``.

    A one-element list or a plain string is phrased as a single expected/example
    answer; anything else is rendered through ``str()`` as a collection.
    """
    if isinstance(correct_answers, list) and len(correct_answers) == 1:
        single, shown = True, correct_answers[0]
    elif isinstance(correct_answers, str):
        single, shown = True, correct_answers
    else:
        single, shown = False, correct_answers

    # The question line is optional; without it the prompt starts at the references.
    prompt = ''
    if question is not None:
        prompt += f'We are assessing the quality of answers to the following question: {question}\n'

    if prompt_for_qa:
        if single:
            prompt += f"The expected answer is: {shown}.\n"
        else:
            prompt += f"The following are expected answers to this question: {shown}.\n"
    else:
        if single:
            prompt += f"The following is an example answer: {shown}.\n"
        else:
            prompt += f"The following are example answers: {shown}.\n"

    prompt += f"The proposed answer is: {predicted_answer}\n"

    if prompt_for_qa:
        # The closing question must use the same singular/plural phrasing as
        # the reference line above.
        if single:
            prompt += "Within the context of the question, does the proposed answer mean the same as the expected answer?"
        else:
            prompt += "Within the context of the question, does the proposed answer mean the same as any of the expected answers?"
    else:
        prompt += "Within the context of the question and example answer, is the proposed answer correct?"

    prompt += " Respond only with yes or no.\nResponse:"
    return prompt


def _post_judge_request(req_json, oai_key, timeout):
    """Send one chat-completion request to OpenRouter without raising on failure.

    Returns ``(payload, status_code)``. ``payload`` is the decoded JSON body,
    the raw response text when the body is not JSON, or the ``requests``
    exception (with ``status_code`` ``None``) when the request itself failed.
    """
    try:
        resp = requests.post(
            url='https://openrouter.ai/api/v1/chat/completions',
            headers={
                "Authorization": f"Bearer {oai_key}",
                "Content-Type": "application/json",
            },
            json=req_json,
            timeout=timeout,
        )
    except requests.RequestException as e:
        return e, None
    try:
        return resp.json(), resp.status_code
    except ValueError:
        # e.g. an HTML error page from a gateway; surfaced in the retry log
        return resp.text, resp.status_code


def evaluate_llm_as_a_judge_oai_interface(
    predicted_answer,
    correct_answers,
    question=None,
    oai_key=oai_key_default,
    rec_depth = 0,
    max_rec_depth = 6,
    provider = ("Lambda",),
    use_model = "meta-llama/llama-3.3-70b-instruct",
    prompt_for_qa = True,
    use_max_tokens = 1,
    temperature = 1.,
    reasoning_config=None,
    request_timeout=120,
):
    """Ask an LLM (via OpenRouter) whether ``predicted_answer`` is correct.

    The model is prompted to answer yes/no; a reply containing "yes" scores 1.0,
    otherwise one containing "no" scores 0.0. Provider failures (no usable
    ``choices`` in the response, transport errors, non-JSON bodies) are retried
    after a linearly growing sleep of ``3 * (attempt + 1)`` seconds; replies
    containing neither word are retried immediately.

    Args:
        predicted_answer: the answer to judge.
        correct_answers: a string, or a list of acceptable answers.
        question: optional question text included in the prompt.
        oai_key: OpenRouter API key; must be non-empty.
        rec_depth: index of the first attempt (kept from the former recursive
            implementation); attempts run from ``rec_depth`` to
            ``max_rec_depth - 1``.
        max_rec_depth: attempt index at which to give up.
        provider: provider order sent to OpenRouter (fallbacks disabled).
        use_model: model identifier sent to OpenRouter.
        prompt_for_qa: True for the "expected answer" phrasing ('qa'), False
            for the "example answer" phrasing ('gen').
        use_max_tokens: ``max_tokens`` for the completion.
        temperature: sampling temperature.
        reasoning_config: optional value sent as the request's ``reasoning`` field.
        request_timeout: seconds passed to ``requests.post(timeout=...)``;
            ``None`` waits indefinitely.

    Returns:
        dict with ``judge_says`` (1.0, 0.0, or NaN when all attempts failed)
        and bookkeeping keys ``_judge_model``, ``_prompt_style``,
        ``_use_max_tokens``, ``_temperature``, ``_full_prompt``, ``_full_response``.

    Raises:
        AssertionError: if ``oai_key`` is empty.
    """
    assert oai_key != '', "Must set the ORO_API_KEY before launch! Must have a key for llm judge!"

    # Defined up front so every return path (including give-up) can report it.
    prompt_version = 'qa' if prompt_for_qa else 'gen'

    def _record(judge_says, full_prompt, full_response):
        return {
            'judge_says': judge_says,
            '_judge_model': use_model,
            '_prompt_style': prompt_version,
            '_use_max_tokens': use_max_tokens,
            '_temperature': temperature,
            '_full_prompt': full_prompt,
            '_full_response': full_response,
        }

    prompt = _build_judge_prompt(predicted_answer, correct_answers, question, prompt_for_qa)

    req_json = {
        "model": use_model,
        "messages": [
            {
                "role": "user",
                "content": prompt,
            }
        ],
        "max_tokens": use_max_tokens,
        "provider": {
            "order": provider,  # a tuple serialises as a JSON array
            "allow_fallbacks": False,
        },
        "temperature": temperature,
    }
    if reasoning_config is not None:
        req_json['reasoning'] = reasoning_config

    # Iterative retry loop; every attempt reuses the same request, so
    # prompt_for_qa / reasoning_config are honoured on retries too.
    for depth in range(rec_depth, max_rec_depth):
        resp_json, status_code = _post_judge_request(req_json, oai_key, request_timeout)
        if not isinstance(resp_json, dict) or not resp_json.get('choices'):
            wait = 3. * (depth + 1)
            print(f"Failed to get the completion from the provider, retrying in {wait} sec.")
            print(f"{resp_json}")
            print(status_code)
            time.sleep(wait)
            continue

        # A null/missing content is treated like an unparseable reply.
        response = resp_json['choices'][0]['message'].get('content') or ''
        verdict = response.lower()
        if 'yes' in verdict:
            return _record(1.0, prompt, response)
        if 'no' in verdict:
            return _record(0.0, prompt, response)
        print(f'{verdict}; Redo llm check.')

    print("Max depth achieved, too much failure!")
    print(question, predicted_answer, correct_answers)
    return _record(np.nan, '', '')


import re

def clean_up_answer_for_qa(a: str):
    """Normalise an answer for QA comparison.

    Strips runs of non-alphanumeric characters (underscore included) from both
    ends of *a*, then lowercases the result.
    """
    edge_junk = re.compile(r'^[\W_]+|[\W_]+$')
    trimmed = edge_junk.sub('', a)
    return trimmed.lower()
