import pandas as pd
import os
# import evaluate
import numpy as np
import json

# from utils.FineRadScore.gpt4_generations import *
# from RaTEScore import RaTEScore
from green_score import GREEN
from radgraph import F1RadGraph

def calculate_rate_score(paired_reports, path):
    """
    Score every candidate report against its reference report with RaTEScore.

    - paired_reports: DataFrame with one report pair per row and the columns
        - study_id: identifier of the study (not read by this function)
        - report_ref: reference ground truth report
        - report_cand: candidate predicted report to compare with reference
    - path: forwarded to RaTEScore as `visualization_path` (location for the
      structured NER results)

    Returns the same DataFrame object, modified in place with an added
    "RATEScore" column holding one score per row.

    NOTE(review): the `RaTEScore` import at the top of this file appears to be
    commented out; unless it is restored (or the name is provided elsewhere),
    calling this function raises NameError.
    """
    print("----- RUN RATESCORE CALCULATION -----")

    # Plain Python lists, candidates first, then references.
    candidates = list(paired_reports["report_cand"])
    references = list(paired_reports["report_ref"])

    scorer = RaTEScore(visualization_path=path, use_gpu=False)
    paired_reports["RATEScore"] = scorer.compute_score(candidates, references)
    return paired_reports

def calculate_green_score(paired_reports):
    """
    Score every candidate report against its reference report with GREEN.

    - paired_reports: DataFrame with one report pair per row and the columns
        - study_id: identifier of the study, copied into the output
        - report_ref: reference ground truth report
        - report_cand: candidate predicted report to compare with reference

    Returns the per-report DataFrame produced by the GREEN scorer, with an
    added "study_id" column aligned row-by-row with `paired_reports`.
    """
    print("----- RUN GREEN SCORE CALCULATION -----")

    candidates = list(paired_reports["report_cand"])
    references = list(paired_reports["report_ref"])

    scorer = GREEN("StanfordAIMI/GREEN-radllama2-7b", output_dir=".", cpu=False)
    # The scorer takes references first, then candidates. Of its five return
    # values only the per-report DataFrame is kept: mean/std/per-report scores
    # can be recomputed from its "green" column, and the summary is unused.
    _mean, _std, _scores, _summary, per_report = scorer(references, candidates)

    per_report["study_id"] = list(paired_reports["study_id"])
    return per_report

# Fields every per-sentence entry of a FineRadScore generation must contain.
_REQUIRED_RESPONSE_FIELDS = ("corrections", "clinical severity", "comments", "error category")


def _is_valid_response(response):
    """
    Return True if a FineRadScore generation has the expected structure.

    A valid response is a non-empty dict whose keys are sentence ids (digit
    strings, or the literal "None") and whose values are dicts containing every
    field in _REQUIRED_RESPONSE_FIELDS. Prints the reason when rejecting.
    """
    if not isinstance(response, dict) or len(response) == 0:
        return False

    for sentence_id, corrected_line in response.items():
        if not isinstance(sentence_id, str) or (
            not sentence_id.isdigit() and sentence_id != "None"
        ):
            print("Error: key is not a sentence id")
            return False
        if not isinstance(corrected_line, dict) or any(
            field not in corrected_line for field in _REQUIRED_RESPONSE_FIELDS
        ):
            print("Error: json object not formatted correctly")
            return False

    return True


def get_GPT4_response(row, max_retries, generate_fn=None):
    """
    Prompt GPT4 to generate FineRadScore output, regenerating on bad output.

    - row: mapping (e.g. a DataFrame row Series) with the keys
        - report_ref: reference ground truth report
        - report_cand: candidate predicted report to compare with reference
    - max_retries: number of extra attempts after the first one, so at most
      `max_retries + 1` generations are requested. Both raised exceptions and
      malformed generations consume an attempt.
    - generate_fn: optional callable `(pred_report, gt_report) -> (response, cost)`.
      Defaults to `generate_gpt4_response`, which must be in scope at module
      level (NOTE(review): its star import at the top of the file is currently
      commented out — confirm where it comes from).

    Returns `(result, total_cost)`. `result` is
    `{"pred": ..., "ground_truth": ..., "response": ...}`, where `response` is
    the first valid generation, or `{}` if every attempt failed. `total_cost`
    sums the cost of every generation that returned, valid or not.
    """
    if generate_fn is None:
        generate_fn = generate_gpt4_response

    pred_target = row["report_cand"]
    gt_target = row["report_ref"]
    total_cost = 0
    response = {}

    # Always make at least one attempt; every attempt (failed or not) counts,
    # so this loop is guaranteed to terminate.
    for _attempt in range(max(max_retries, 0) + 1):
        try:
            candidate, cost = generate_fn(pred_target, gt_target)
        except Exception as e:
            print("something's wrong!")
            print(e)
            continue

        total_cost += cost

        # bad generation: regenerate; good generation: keep it and stop
        if _is_valid_response(candidate):
            response = candidate
            break

    result = {"pred": pred_target, "ground_truth": gt_target, "response": response}

    return result, total_cost


# Numeric penalty for each "clinical severity" label in a FineRadScore
# generation. Labels are listed from least to most severe; keys are lowercase
# because calculate_fineradscore lowercases the label before looking it up.
SEVERITY_SCORE = {
    "no error": 0,
    "invalid comparison": 0,
    "not actionable": 1,
    "actionable nonurgent error": 2,
    "urgent error": 3,
    "emergent error": 4,
}

def calculate_fineradscore(paired_reports, path, max_retries=5):
    """
    Calculate FineRadScore for every report pair.

    - paired_reports: DataFrame with the following columns:
        - report_ref: reference ground truth report
        - report_cand: candidate predicted report to compare with reference
    - path: JSON-lines file to which the full response for each row is
      appended (one JSON object per line, flushed after every row)
    - max_retries: extra regeneration attempts per row, passed to
      get_GPT4_response (default 5, the previously hard-coded value)

    Per row, the per-sentence severity labels are mapped through
    SEVERITY_SCORE. Missing or unknown labels and malformed entries score 0.
    Two columns are written into `paired_reports` in place:
        - sum_score: sum of severities (original FineRadScore paper)
        - max_score: max of severities (ReXrank benchmark)
    Both are NaN when the response has no entries.

    Returns `(paired_reports, total_cost)`.
    """
    print("----- RUN FINERADSCORE CALCULATION -----")

    total_cost = 0

    with open(path, "a", encoding="utf-8") as f:
        for i, (idx, row) in enumerate(paired_reports.iterrows(), start=1):

            # get FineRadScore analysis
            result, cost = get_GPT4_response(row, max_retries)
            total_cost += cost

            # extract severity scores from response
            scores = []
            for entry in result["response"].values():
                try:
                    severity = entry["clinical severity"].lower()
                    scores.append(SEVERITY_SCORE[severity])
                except (KeyError, TypeError, AttributeError):
                    # Missing field, unknown label, non-dict entry, or
                    # non-string label: treat as "no error" (score 0).
                    scores.append(0)

            # store sum of severity scores (FineRadScore orig. paper) and max of severity scores (ReXrank benchmark)
            if scores:
                sum_score = sum(scores)
                max_score = max(scores)
            else:
                sum_score = float("nan")
                max_score = float("nan")

            paired_reports.at[idx, "sum_score"] = sum_score
            paired_reports.at[idx, "max_score"] = max_score
            result["sum_score"] = sum_score
            result["max_score"] = max_score
            if not scores:
                print("score is nan!")

            # store response for later use; flush so partial runs are preserved
            f.write(json.dumps(result) + "\n")
            f.flush()

            if i % 10 == 0:
                print(f"total cost after {i} reports: {total_cost}")

    print(f"Total cost after FineRadScore calculation: {total_cost}")

    return paired_reports, total_cost

def calculate_bleu(paired_reports):
    """
    Calculate per-report BLEU (uses the Hugging Face `evaluate` library).

    - paired_reports: DataFrame with the following columns:
        - report_ref: reference ground truth report
        - report_cand: candidate predicted report to compare with reference

    Returns a copy of `paired_reports` with a float "bleu" column. Rows where
    either report is not a string (e.g. NaN or None for a missing report) get
    NaN. The input DataFrame is not modified.

    NOTE(review): `import evaluate` at the top of this file is commented out;
    scoring any row raises NameError unless that import is restored.
    """
    print("----- RUN BLEU SCORE CALCULATION -----")

    result_df = paired_reports.copy()

    scores = []
    bleu = None  # loaded lazily: nothing to load if every row is missing
    for cand, ref in zip(paired_reports["report_cand"], paired_reports["report_ref"]):
        if not (isinstance(cand, str) and isinstance(ref, str)):
            scores.append(np.nan)
            continue
        if bleu is None:
            bleu = evaluate.load("bleu")
        results = bleu.compute(predictions=[cand], references=[[ref]])
        scores.append(results["bleu"])

    # Assign as float64 in one step, rather than filling an int column with
    # floats/NaN row by row (silent upcasting, deprecated in pandas >= 2.1).
    result_df["bleu"] = np.array(scores, dtype=float)
    return result_df

def calculate_bertscore(paired_reports, model_type="distilroberta-base"):
    """
    Calculate per-report BERTScore F1 (uses the Hugging Face `evaluate` library).

    - paired_reports: DataFrame with the following columns:
        - report_ref: reference ground truth report
        - report_cand: candidate predicted report to compare with reference
    - model_type: encoder passed to BERTScore (default "distilroberta-base",
      the previously hard-coded value)

    Returns a copy of `paired_reports` with a float "bertscore" column holding
    the F1 score. Rows where either report is not a string (e.g. NaN or None
    for a missing report) get NaN, matching calculate_bleu. The input
    DataFrame is not modified.

    NOTE(review): `import evaluate` at the top of this file is commented out;
    scoring any row raises NameError unless that import is restored.
    """
    print("----- RUN BERTSCORE CALCULATION -----")

    result_df = paired_reports.copy()
    result_df["bertscore"] = np.nan

    # Positional boolean mask of rows with text on both sides. astype(bool)
    # keeps `&` valid even for empty columns.
    has_text = (
        paired_reports["report_cand"].map(lambda x: isinstance(x, str)).astype(bool)
        & paired_reports["report_ref"].map(lambda x: isinstance(x, str)).astype(bool)
    ).to_numpy()
    if not has_text.any():
        return result_df

    bertscore = evaluate.load("bertscore")
    results = bertscore.compute(
        predictions=paired_reports.loc[has_text, "report_cand"].to_list(),
        references=paired_reports.loc[has_text, "report_ref"].to_list(),
        model_type=model_type,
    )
    result_df.loc[has_text, "bertscore"] = results["f1"]
    return result_df
        
def calculate_radgraphF1(paired_reports):
    """
    Calculate RadGraph F1 with both the original RadGraph and RadGraph-XL models.

    - paired_reports: DataFrame with the following columns:
        - report_ref: reference ground truth report
        - report_cand: candidate predicted report to compare with reference

    Returns a copy of `paired_reports` with two added columns:
        - RadgraphF1: score from model_type="radgraph"
        - RadgraphF1-xl: score from model_type="radgraph-xl"
    Each column takes `reward_list[1]` from F1RadGraph(reward_level="all"),
    which holds the RG_ER variant. The input DataFrame is not modified.
    """
    print("----- RUN RADGRAPH CALCULATION -----")

    result_df = paired_reports.copy()
    hyps = list(paired_reports["report_cand"])
    refs = list(paired_reports["report_ref"])

    for model_type, column in (("radgraph", "RadgraphF1"), ("radgraph-xl", "RadgraphF1-xl")):
        f1radgraph = F1RadGraph(reward_level="all", model_type=model_type)
        _, reward_list, _, _ = f1radgraph(hyps=hyps, refs=refs)
        result_df[column] = reward_list[1]  # second value is RG_ER
        # Drop this model before the next one is constructed, so the two are
        # not held in memory at the same time.
        del f1radgraph

    return result_df