import argparse
import pandas as pd
from rouge_score import rouge_scorer
from statistics import mean, stdev
import bert_score
from .utils import OllamaClient, AzureOpenAIClient
from .eval_prompts import EVAL_PROMPT
from typing import Dict
from tqdm import tqdm


def get_llm_score(evaluation: Dict[str, str]) -> float:
    """Return the fraction of applicable judge answers that are "Yes".

    ``evaluation`` maps each judge question to the LLM's answer. Answers of
    "Not mentioned in reference" are excluded from the denominator; every
    other answer counts as positive only if it is "Yes".

    Matching ignores case, surrounding whitespace and a trailing period,
    because free-form LLM output often varies ("yes", " Yes", "Yes.") and an
    exact comparison would silently count those as negative answers.

    Returns:
        The yes-ratio rounded to 4 decimals, or 0.0 when no applicable
        answers remain.
    """
    normalized = [
        str(answer).strip().rstrip(".").casefold() for answer in evaluation.values()
    ]
    valid_answers = [
        answer for answer in normalized if answer != "not mentioned in reference"
    ]

    if not valid_answers:
        return 0.0

    ratio = valid_answers.count("yes") / len(valid_answers)
    return round(ratio, 4)


def calculate_bert_score(
    gt_texts,
    pred_texts,
    device="cuda:1",
    model_type="microsoft/deberta-xlarge-mnli",
):
    """Return per-pair BERTScore F1 values for predictions against references.

    Args:
        gt_texts: Reference texts.
        pred_texts: Candidate texts, aligned index-by-index with ``gt_texts``.
        device: Torch device string forwarded to ``bert_score.score``. The
            default keeps the previously hard-coded "cuda:1"; pass None to let
            bert_score pick a device itself.
        model_type: Model name forwarded as ``model_type`` to
            ``bert_score.score``.

    Returns:
        List of F1 floats, one per (prediction, reference) pair; an empty list
        when both inputs are empty.
    """
    print("\nCalculating BERT Score")
    # Nothing to score: avoid invoking the (heavy) model on empty batches.
    if not pred_texts and not gt_texts:
        return []
    _, _, f1 = bert_score.score(
        pred_texts,
        gt_texts,
        lang="en",
        verbose=False,
        device=device,
        model_type=model_type,
    )
    return f1.tolist()


def calculate_rouge_l(gt_texts, pred_texts):
    """Score each prediction against its reference with stemmed ROUGE-L.

    Returns the ROUGE-L F-measure for every (prediction, reference) pair, in
    input order.
    """
    rouge = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
    pairs = tqdm(
        zip(pred_texts, gt_texts),
        total=len(gt_texts),
        desc="Calculating ROUGE-L score",
    )
    return [
        rouge.score(reference, candidate)["rougeL"].fmeasure
        for candidate, reference in pairs
    ]


def calculate_llm_judge(gt_texts, pred_texts, llm_client, prompt_template):
    """Ask the LLM judge to grade every prediction and return per-row scores.

    ``prompt_template`` is filled via ``str.format`` with the ``ground_truth``
    and ``prediction`` fields; whatever ``llm_client.get_output`` returns is
    reduced to a number by ``get_llm_score``.
    """
    progress = tqdm(
        zip(gt_texts, pred_texts),
        desc="Calculating LLM-as-a-Judge score",
        total=len(pred_texts),
    )
    return [
        get_llm_score(
            llm_client.get_output(
                prompt_template.format(ground_truth=reference, prediction=candidate)
            )
        )
        for reference, candidate in progress
    ]


def _summarize(scores):
    """Bundle per-row scores with their mean and sample standard deviation.

    ``statistics.mean`` needs at least one value and ``statistics.stdev`` at
    least two, so degenerate inputs get defined values instead of raising
    ``StatisticsError``: an empty list yields NaN for both statistics, and a
    single score yields a standard deviation of 0.0.
    """
    if not scores:
        return {"scores": scores, "mean": float("nan"), "stdev": float("nan")}
    return {
        "scores": scores,
        "mean": mean(scores),
        "stdev": stdev(scores) if len(scores) > 1 else 0.0,
    }


def calculate_metrics(df, gt_col, pred_cols, llm_client=None, prompt_template=None):
    """Compute ROUGE-L, BERTScore and optionally LLM-as-a-judge scores.

    Args:
        df: DataFrame holding the ground-truth column and every prediction
            column.
        gt_col: Name of the ground-truth column.
        pred_cols: Names of the prediction columns to evaluate against
            ``gt_col``.
        llm_client: Object exposing ``get_output(prompt)``. LLM-as-a-judge runs
            only when both this and ``prompt_template`` are provided.
        prompt_template: Format string with ``{ground_truth}`` and
            ``{prediction}`` fields.

    Returns:
        ``{pred_col: {metric_name: {"scores": [...], "mean": float,
        "stdev": float}}}`` for every column in ``pred_cols``.
    """
    # astype(str) turns missing cells into the literal text "nan", which is
    # then scored like any other string.
    gt_texts = df[gt_col].astype(str).tolist()
    results = {}
    for position, pred_col in enumerate(pred_cols, start=1):
        print(f"\n\nHandling column {position}/{len(pred_cols)}: {pred_col} \n\n")
        pred_texts = df[pred_col].astype(str).tolist()

        col_results = {
            "ROUGE-L": _summarize(calculate_rouge_l(gt_texts, pred_texts)),
            "BERTScore": _summarize(calculate_bert_score(gt_texts, pred_texts)),
        }

        if llm_client and prompt_template:
            llm_scores = calculate_llm_judge(
                gt_texts, pred_texts, llm_client, prompt_template
            )
            col_results["LLM-Judge"] = _summarize(llm_scores)
        else:
            print("No llm or prompt template provided, skipping LLM-as-a-judge")
        results[pred_col] = col_results

    return results


def print_report(metrics):
    """Print a human-readable "mean ± stdev" table for every scored column.

    ``metrics`` is the nested dict produced by ``calculate_metrics``: column
    name -> metric name -> dict containing at least "mean" and "stdev".
    """
    rule = "=" * 40
    print(rule)
    print("METRICS REPORT")
    print(rule)
    for column, per_metric in metrics.items():
        print(f"\n\nScores for {column}")
        for label, summary in per_metric.items():
            mu, sigma = summary["mean"], summary["stdev"]
            print(f"{label:12}: {mu:.4f} ± {sigma:.4f}")
        print("___" * 20)

    print(rule)


def main():
    """Command-line entry point: score prediction columns of a CSV file.

    Reads ``--input``, checks that ``--gt_col`` and every ``--pred_cols``
    column exist, optionally builds an LLM-judge client, then computes and
    prints the metrics report.
    """
    parser = argparse.ArgumentParser(description="Calculate text metrics")
    parser.add_argument("--input", required=True, help="Input CSV file path")
    parser.add_argument("--gt_col", required=True, help="Ground truth column name")
    parser.add_argument(
        "--pred_cols", nargs="+", required=True, help="Prediction column names"
    )
    parser.add_argument(
        "--llm-judge", choices=["ollama", "azure"], help="LLM judge type"
    )
    parser.add_argument(
        "--model-name",
        default="gemma3:27b-it-qat",
        help="Model name passed to the selected LLM judge client "
        "(the default is an Ollama model tag)",
    )
    parser.add_argument(
        "--ollama-host",
        default="http://localhost:11434",
        help="Base URL of the Ollama server (used with --llm-judge ollama)",
    )

    args = parser.parse_args()

    df = pd.read_csv(args.input)

    # Fail fast with a clear CLI error rather than a KeyError mid-run,
    # possibly after earlier prediction columns were already scored.
    missing = [col for col in (args.gt_col, *args.pred_cols) if col not in df.columns]
    if missing:
        parser.error(f"column(s) not found in {args.input}: {', '.join(missing)}")

    llm_client = None
    prompt_template = EVAL_PROMPT

    if args.llm_judge == "ollama":
        llm_client = OllamaClient(args.model_name, host=args.ollama_host)
    elif args.llm_judge == "azure":
        llm_client = AzureOpenAIClient(args.model_name)

    print("Calculating metrics...")
    metrics = calculate_metrics(
        df, args.gt_col, args.pred_cols, llm_client, prompt_template
    )
    print_report(metrics)


# Run the CLI only when executed directly, not when imported. The relative
# imports at the top of the file mean this must be launched as a module
# (``python -m <package>.<module>``), not as a bare script path.
if __name__ == "__main__":
    main()
