import json
from typing import Any, Dict, List, Union

import pandas as pd
import numpy as np
from dotenv import load_dotenv, find_dotenv
from sklearn.metrics import roc_auc_score, average_precision_score, brier_score_loss
from scipy.stats import pearsonr, spearmanr
from uqlm import BlackBoxUQ
from langchain_openai import AzureChatOpenAI
from langchain_google_vertexai import ChatVertexAI
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeElapsedColumn

from anonlib.longform.luq import UnitResponseScorer, MatchedUnitScorer
from anonlib.longform.benchmark import FactScoreGrader
from anonlib.longform.graph import ClaimMerger
from anonlib.longform.decomposition import ResponseDecomposer
from anonlib.longform.utils import claims_dicts_to_lists
from anonlib.scorers import LongTextQA, LongTextGraph

# Scorer names passed as the `scorers` argument when this module constructs
# BlackBoxUQ and LongTextQA.
BLACK_BOX_SCORERS = ["exact_match", "cosine_sim", "noncontradiction", "semantic_negentropy", "bert_score"]

def llm_name_to_rate_limit(llm_name: str) -> Union[float, None]:
    """Return the calls-per-minute budget associated with ``llm_name``.

    ``"gpt4o"`` maps to 175, ``"gemini_pro"`` and ``"gpt4o_mini"`` map to 1000,
    and any other name falls back to 2000.
    """
    if llm_name == "gpt4o":
        return 175
    if llm_name in ("gemini_pro", "gpt4o_mini"):
        return 1000
    return 2000


def load_llms():
    """Load the dotenv file found by ``find_dotenv`` and build four chat models.

    Returns:
        A 4-tuple ``(gpt4o, gpt4o_mini, gemini_pro, gemini_flash)``. Every
        model is constructed with ``temperature=0``.
    """
    load_dotenv(find_dotenv())

    # Settings common to both Azure deployments.
    azure_settings = {
        "openai_api_type": "azure",
        "openai_api_version": "2024-02-15-preview",
        "temperature": 0,
    }
    gpt4o = AzureChatOpenAI(deployment_name="gpt-4o", **azure_settings)
    gpt4o_mini = AzureChatOpenAI(deployment_name="gpt-4o-mini", **azure_settings)

    gemini_pro = ChatVertexAI(model_name="gemini-2.5-pro", temperature=0)
    gemini_flash = ChatVertexAI(model_name="gemini-2.5-flash", temperature=0)

    return gpt4o, gpt4o_mini, gemini_pro, gemini_flash

class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to built-in types."""

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        Args:
            obj: The object the base encoder could not serialize.

        Returns:
            ``bool`` for ``np.bool_``, ``int`` for NumPy integers, ``float``
            for NumPy floats, and a nested list for ``np.ndarray``.

        Raises:
            TypeError: Raised by the base class for any other type.
        """
        # np.bool_ is neither an np.integer nor a Python bool, so without this
        # branch a single NumPy boolean scalar makes json.dump fail.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

def create_progress_bar():
    """Build a rich ``Progress`` display, start it, and return it.

    Columns, in order: spinner, task description, bar, a
    ``completed/total`` counter, and elapsed time.
    """
    columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.completed}/{task.total}"),
        TimeElapsedColumn(),
    )
    bar = Progress(*columns)
    bar.start()
    return bar

async def generate_responses_and_samples(dataset: pd.DataFrame, llm_dict: dict[str, Any], write_path: str) -> None:
    """Generate and score responses for every LLM and save them as JSON.

    For each ``(llm_name, llm)`` pair a ``BlackBoxUQ`` is built with
    ``BLACK_BOX_SCORERS`` and ``sampling_temperature=1.0``; its
    ``generate_and_score`` is awaited with ``num_responses=10`` and the
    resulting ``.data`` is written to
    ``{write_path}/responses_{llm_name}.json``.

    Args:
        dataset: Frame whose ``"question"`` column supplies the prompts.
        llm_dict: Mapping from LLM name to LLM object.
        write_path: Directory prefix for the output files.
    """
    questions = dataset["question"].to_list()

    for llm_name, llm in llm_dict.items():
        uq = BlackBoxUQ(
            llm=llm,
            max_calls_per_min=llm_name_to_rate_limit(llm_name),
            scorers=BLACK_BOX_SCORERS,
            sampling_temperature=1.0,
        )
        result = await uq.generate_and_score(prompts=questions, num_responses=10)

        output_file = f"{write_path}/responses_{llm_name}.json"
        with open(output_file, "w") as handle:
            # NumpyEncoder converts NumPy values in the result to JSON types.
            json.dump(result.data, handle, cls=NumpyEncoder, indent=4)


async def decompose_claims(llm_dict: Dict[str, Any], path: str) -> None:
    """Add decomposed claims to each LLM's stored responses file.

    A single ``ResponseDecomposer`` built on ``llm_dict["gemini_flash"]`` is
    used for every LLM. Each ``{path}/responses_{llm_name}.json`` is read,
    given ``"claims"`` and ``"sampled_claims"`` entries, and rewritten in place.

    Args:
        llm_dict: Mapping from LLM name to LLM object; must contain
            ``"gemini_flash"``.
        path: Directory holding the responses files.
    """
    decomposer = ResponseDecomposer(llm_dict["gemini_flash"])

    for llm_name in llm_dict:
        responses_file = f"{path}/responses_{llm_name}.json"
        with open(responses_file) as handle:
            payload = json.load(handle)

        payload["claims"] = await decomposer.decompose_claims(responses=payload["responses"])
        payload["sampled_claims"] = await decomposer.decompose_candidate_claims(
            sampled_responses=payload["sampled_responses"]
        )

        with open(responses_file, "w") as handle:
            json.dump(payload, handle, indent=4)



def decompose_sentences(llm_dict: Dict[str, Any], path: str) -> None:
    """Split stored responses into sentences with spaCy and save the result.

    For every LLM name in ``llm_dict``, reads
    ``{path}/responses_{llm_name}.json``, adds ``"sentences"`` (one list of
    sentence strings per response) and ``"sampled_sentences"`` (the same, for
    each sampled response in each sample set), and rewrites the file in place.

    Args:
        llm_dict: Mapping whose keys are the LLM names to process.
        path: Directory holding the responses files.

    Raises:
        subprocess.CalledProcessError: If the model has to be downloaded and
            the download command exits with a non-zero status.
    """
    import spacy

    model_name = "en_core_web_sm"
    try:
        nlp = spacy.load(model_name)
    except OSError:
        import subprocess
        import sys

        # Each argument must be its own list element: with shell=False a single
        # space-separated string is taken as the executable's name and fails.
        # sys.executable keeps the download in the interpreter that runs this
        # code rather than whichever "python" happens to be first on PATH.
        subprocess.run([sys.executable, "-m", "spacy", "download", model_name], check=True)
        nlp = spacy.load(model_name)

    def split_sentences(text: str) -> List[str]:
        return [str(sent) for sent in nlp(text).sents]

    for llm_name in llm_dict:
        responses_file = f"{path}/responses_{llm_name}.json"
        with open(responses_file) as json_file:
            response_dict = json.load(json_file)

        response_dict["sentences"] = [split_sentences(text) for text in response_dict["responses"]]
        response_dict["sampled_sentences"] = [
            [split_sentences(text) for text in sampled_response_set]
            for sampled_response_set in response_dict["sampled_responses"]
        ]

        with open(responses_file, "w") as json_file:
            json.dump(response_dict, json_file, indent=4)
            
async def grade_claims(llm_dict: Dict[str, Any], dataset: pd.DataFrame, granularity: str, path: str) -> None:
    """Grade each LLM's units against the dataset answers and store the grades.

    Units are read from ``"claims"`` when ``granularity == "claim"`` and from
    ``"sentences"`` otherwise. Grades come from a ``FactScoreGrader`` built on
    ``llm_dict["gemini_flash"]`` and are stored under
    ``"{granularity}_grades"`` in ``{path}/responses_{llm_name}.json``.

    Args:
        llm_dict: Mapping from LLM name to LLM object; must contain
            ``"gemini_flash"``.
        dataset: Frame whose ``answer`` column supplies the reference answers.
        granularity: ``"claim"`` selects claims; any other value selects
            sentences.
        path: Directory holding the responses files.
    """
    units_key = "claims" if granularity == "claim" else "sentences"

    for llm_name in llm_dict:
        responses_file = f"{path}/responses_{llm_name}.json"
        with open(responses_file) as handle:
            payload = json.load(handle)

        grader = FactScoreGrader(llm_dict["gemini_flash"], max_calls_per_min=1500)
        payload[f"{granularity}_grades"] = await grader.grade_claims(
            claim_sets=payload[units_key],
            answers=dataset.answer.to_list(),
        )

        with open(responses_file, "w") as handle:
            json.dump(payload, handle, indent=4)
            
async def evaluate_objectivity(llm_dict: Dict[str, Any], granularity: str, path: str) -> None:
    """Evaluate the objectivity of each LLM's units and store the result.

    Units are read from ``"claims"`` when ``granularity == "claim"`` and from
    ``"sentences"`` otherwise. The output of
    ``FactScoreGrader.evaluate_claim_objectivity`` is stored under
    ``"{granularity}_objectivity"`` in ``{path}/responses_{llm_name}.json``.

    Args:
        llm_dict: Mapping from LLM name to LLM object; must contain
            ``"gemini_flash"``.
        granularity: ``"claim"`` selects claims; any other value selects
            sentences.
        path: Directory holding the responses files.
    """
    units_key = "claims" if granularity == "claim" else "sentences"

    for llm_name in llm_dict:
        responses_file = f"{path}/responses_{llm_name}.json"
        with open(responses_file) as handle:
            payload = json.load(handle)

        grader = FactScoreGrader(llm_dict["gemini_flash"], max_calls_per_min=1500)
        payload[f"{granularity}_objectivity"] = await grader.evaluate_claim_objectivity(
            claim_sets=payload[units_key],
        )

        with open(responses_file, "w") as handle:
            json.dump(payload, handle, indent=4)
            
async def score_unit_response(llm_dict: Dict[str, Any], granularity: str, path: str) -> None:
    """Score each LLM's units against its sampled responses and save to parquet.

    Reads ``"{granularity}s"`` and ``"sampled_responses"`` from
    ``{path}/responses_{llm_name}.json``, runs ``UnitResponseScorer.evaluate``
    on them, and writes the result to
    ``{path}/{llm_name}_{granularity}_response_scores.parquet``.

    Args:
        llm_dict: Mapping whose keys are the LLM names to process.
        granularity: Unit name used to build the input key and output file name.
        path: Directory holding the responses files and receiving the output.
    """
    units_key = f"{granularity}s"

    for llm_name in llm_dict:
        with open(f"{path}/responses_{llm_name}.json") as handle:
            payload = json.load(handle)

        units = payload[units_key]
        samples = payload["sampled_responses"]

        result = UnitResponseScorer().evaluate(claim_sets=units, sampled_responses=samples)

        table = pd.DataFrame(result.to_dict(return_all=True))
        table.to_parquet(f"{path}/{llm_name}_{granularity}_response_scores.parquet")
        

async def score_matched_unit(llm_dict: Dict[str, Any], granularity: str, path: str) -> None:
    """Score each LLM's units against its sampled units and save to parquet.

    Reads ``"{granularity}s"`` and ``"sampled_{granularity}s"`` from
    ``{path}/responses_{llm_name}.json``, runs ``MatchedUnitScorer.evaluate``
    on them, and writes the result to
    ``{path}/{llm_name}_matched_{granularity}_scores.parquet``.

    Args:
        llm_dict: Mapping whose keys are the LLM names to process.
        granularity: Unit name used to build the input keys and output file name.
        path: Directory holding the responses files and receiving the output.
    """
    units_key = f"{granularity}s"
    sampled_units_key = f"sampled_{granularity}s"

    for llm_name in llm_dict:
        with open(f"{path}/responses_{llm_name}.json") as handle:
            payload = json.load(handle)

        units = payload[units_key]
        sampled_units = payload[sampled_units_key]

        result = MatchedUnitScorer().evaluate(claim_sets=units, sampled_claim_sets=sampled_units)

        table = pd.DataFrame(result.to_dict(return_all=True))
        table.to_parquet(f"{path}/{llm_name}_matched_{granularity}_scores.parquet")
        
        
async def score_unit_qa(llm_dict: Dict[str, Any], dataset: pd.DataFrame, granularity: str, path: str) -> None:
    """Run ``LongTextQA`` scoring on each LLM's units and save to parquet.

    For every ``(llm_name, llm)`` pair, the units stored under
    ``"{granularity}s"`` are scored with a ``LongTextQA`` that uses ``llm`` as
    the main model and ``llm_dict["gemini_flash"]`` as the question generator.
    Three questions are requested when ``granularity == "sentence"``, otherwise
    one. The ``"claims_data"`` of the result is passed through
    ``claims_dicts_to_lists`` and written to
    ``{path}/{llm_name}_{granularity}_qa_scores.parquet``.

    Args:
        llm_dict: Mapping from LLM name to LLM object; must contain
            ``"gemini_flash"``.
        dataset: Frame whose ``question`` column supplies the prompts.
        granularity: Unit name used to build the input key and output file name.
        path: Directory holding the responses files and receiving the output.
    """
    num_questions = 3 if granularity == "sentence" else 1
    units_key = f"{granularity}s"

    for llm_name, llm in llm_dict.items():
        rate_limit = llm_name_to_rate_limit(llm_name)

        with open(f"{path}/responses_{llm_name}.json") as handle:
            payload = json.load(handle)
        units = payload[units_key]

        qa_scorer = LongTextQA(
            llm=llm,
            scorers=BLACK_BOX_SCORERS,
            question_generator_llm=llm_dict["gemini_flash"],
            max_calls_per_min=rate_limit,
        )
        qa_result = await qa_scorer._score_from_decomposed(
            claim_sets=units,
            prompts=dataset.question.tolist(),
            num_questions=num_questions,
        )

        table = pd.DataFrame(claims_dicts_to_lists(qa_result.data["claims_data"]))
        table.to_parquet(f"{path}/{llm_name}_{granularity}_qa_scores.parquet")

    
async def merge_claims(llm_dict: Dict[str, Any], path: str) -> None:
    """Merge original and sampled claims for each LLM and store the result.

    Reads ``"claims"`` and ``"sampled_claims"`` from
    ``{path}/responses_{llm_name}.json``, passes them to
    ``ClaimMerger.merge_claims`` (built on ``llm_dict["gemini_flash"]``), and
    rewrites the file with the output stored under ``"merged_claims"``.

    Args:
        llm_dict: Mapping from LLM name to LLM object; must contain
            ``"gemini_flash"``.
        path: Directory holding the responses files.
    """
    for llm_name in llm_dict:
        responses_file = f"{path}/responses_{llm_name}.json"
        with open(responses_file) as handle:
            payload = json.load(handle)

        original_claims = payload["claims"]
        sampled_claims = payload["sampled_claims"]

        merger = ClaimMerger(llm_dict["gemini_flash"])
        payload["merged_claims"] = await merger.merge_claims(
            original_claim_sets=original_claims,
            sampled_claim_sets=sampled_claims,
        )

        with open(responses_file, "w") as handle:
            json.dump(payload, handle, indent=4)


async def score_graph_uq(llm_dict: Dict[str, Any], dataset: pd.DataFrame, path: str) -> None:
    """Run ``LongTextGraph`` scoring for each LLM and save the claim scores.

    For every ``(llm_name, llm)`` pair, each original response is prepended to
    its sampled responses to form one response set. Those sets, together with
    the stored ``"claims"`` and ``"merged_claims"``, are passed to
    ``LongTextGraph._score_from_decomposed``. Only the first element of the
    returned triple is written, to
    ``{path}/{llm_name}_graph_uq_scores.parquet``.

    Args:
        llm_dict: Mapping from LLM name to LLM object; must contain
            ``"gemini_flash"``.
        dataset: Not read by this function.
        path: Directory holding the responses files and receiving the output.
    """
    for llm_name, llm in llm_dict.items():
        rate_limit = llm_name_to_rate_limit(llm_name)

        with open(f"{path}/responses_{llm_name}.json") as handle:
            payload = json.load(handle)

        originals = payload["responses"]
        samples = payload["sampled_responses"]
        response_sets = [[original] + sampled for original, sampled in zip(originals, samples)]
        claim_sets = payload["claims"]
        merged_claim_sets = payload["merged_claims"]

        graph_scorer = LongTextGraph(
            llm=llm,
            claim_decomposition_llm=llm_dict["gemini_flash"],
            max_calls_per_min=rate_limit,
            scorers=[
                "degree_centrality",
                "betweenness_centrality",
                "closeness_centrality",
                "page_rank",
                "laplacian_centrality",
                "harmonic_centrality",
            ],
        )
        # NOTE(review): this call is not awaited, whereas score_unit_qa awaits
        # LongTextQA._score_from_decomposed. Confirm that the LongTextGraph
        # method is synchronous.
        original_claim_scores, _master_claim_scores, _graph_result = graph_scorer._score_from_decomposed(
            original_claim_sets=claim_sets,
            master_claim_sets=merged_claim_sets,
            response_sets=response_sets,
        )

        table = pd.DataFrame(original_claim_scores)
        table.to_parquet(f"{path}/{llm_name}_graph_uq_scores.parquet")
        
        
def expected_calibration_error(y_true: np.ndarray, y_proba: np.ndarray, n_bins: int = 10) -> float:
    """
    Compute the Expected Calibration Error (ECE).

    Predictions are grouped into ``n_bins`` equal-width bins over [0, 1]. For
    each non-empty bin the absolute gap between the mean label and the mean
    predicted probability is weighted by the fraction of samples in that bin,
    and the weighted gaps are summed.

    Args:
        y_true: Ground truth labels (0 or 1)
        y_proba: Predicted probabilities
        n_bins: Number of equal-width bins

    Returns:
        ECE score (lower is better); 0.0 when no bin contains a sample
    """
    labels = np.array(y_true)
    probs = np.array(y_proba)

    # Bin by the left edges only, then clip so out-of-range values (including
    # a probability of exactly 1) land in the first or last bin.
    edges = np.linspace(0, 1, n_bins + 1)
    assigned = np.clip(np.digitize(probs, edges[:-1]) - 1, 0, n_bins - 1)

    total = len(labels)
    ece = 0.0
    for bin_id in range(n_bins):
        members = assigned == bin_id
        count = members.sum()
        if count <= 0:
            continue
        gap = np.abs(labels[members].mean() - probs[members].mean())
        ece += (count / total) * gap

    return ece


def _aggregate_unit_scores(scores, grades, objectivity_bools):
    """Collect unit-level and response-level scores and grades for one scorer.

    Returns ``(unit_scores_flat, unit_grades_flat, response_scores,
    response_grades)``. Units whose objectivity flag is falsy are dropped, and
    responses left with no units are skipped at the response level.
    """
    unit_scores_flat, unit_grades_flat = [], []
    response_scores, response_grades = [], []

    for unit_scores, unit_grades, unit_objectivity in zip(scores, grades, objectivity_bools):
        kept_scores, kept_grades = [], []
        for j, (unit_score, unit_grade) in enumerate(zip(unit_scores, unit_grades)):
            if unit_objectivity[j]:
                kept_scores.append(np.mean(unit_score))
                kept_grades.append(unit_grade)

        if not kept_scores:
            # Averaging an empty list yields NaN, and a single NaN makes the
            # response-level Pearson/Spearman statistics NaN as well.
            continue

        unit_scores_flat.extend(kept_scores)
        unit_grades_flat.extend(kept_grades)
        response_scores.append(np.mean(kept_scores))
        response_grades.append(np.mean(kept_grades))

    return unit_scores_flat, unit_grades_flat, response_scores, response_grades


def compute_metrics(
    llm_dict: Dict[str, Any],
    dataset: Any,
    path: str,
    objective_only: bool = False,
    granularity: str = "claim",
) -> Dict[str, Dict[str, Any]]:
    """Compute evaluation metrics for every LLM, scorer type, and scorer.

    For each LLM name, grades are read from ``"{granularity}_grades"`` in
    ``{path}/responses_{llm_name}.json`` and scores from
    ``{path}/{llm_name}_{scorer_type}_scores.parquet``. At the unit level the
    function computes AUROC, AUPRC, Brier score, and ECE of the scores against
    the grades. At the response level it computes Pearson and Spearman
    correlation between the mean grade and the mean score of each response.
    After each LLM is processed, the results dict is written to
    ``{path}/metrics_{llm_name}.json``.

    Args:
        llm_dict: Mapping whose keys are the LLM names to evaluate.
        dataset: Not read by this function; kept so existing callers keep
            working.
        path: Directory holding the input files and receiving the metrics.
        objective_only: When true, only units flagged in
            ``"{granularity}_objectivity"`` are used; otherwise all units are.
        granularity: ``"claim"`` or ``"sentence"``; selects the scorer types.

    Returns:
        ``{llm_name: {"average_grade": ..., "auroc": {...}, "auprc": {...},
        "brier_score": {...}, "ece": {...}, "pearson": {...},
        "spearman": {...}}}``, where each inner dict is keyed by
        ``"{scorer_type}_{scorer}"``.

    Raises:
        KeyError: If ``granularity`` is not ``"claim"`` or ``"sentence"``, or
            an expected key is missing from a responses file.
    """
    granularity_to_scorer_types = {
        "claim": ["claim_response", "claim_qa", "graph_uq"],
        "sentence": ["sentence_response", "matched_sentence", "sentence_qa"],
    }

    scorer_type_to_score_names = {
        "claim_response": ["entailment", "noncontradiction", "contrasted_entailment"],
        "sentence_response": ["entailment", "noncontradiction", "contrasted_entailment"],
        "matched_sentence": ["entailment", "noncontradiction", "contrasted_entailment", "cosine_sim", "bert_score"],
        "claim_qa": ["exact_match", "semantic_negentropy", "noncontradiction", "bert_score", "cosine_sim"],
        "sentence_qa": ["exact_match", "semantic_negentropy", "noncontradiction", "bert_score", "cosine_sim"],
        "graph_uq": ['betweenness_centrality', 'closeness_centrality', 'harmonic_centrality', 'page_rank', 'laplacian_centrality'],
    }

    full_results = {
        llm_name: {"average_grade": None, "auroc": {}, "auprc": {}, "brier_score": {}, "ece": {}, "pearson": {}, "spearman": {}}
        for llm_name in llm_dict
    }

    for llm_name in llm_dict:
        llm_results = full_results[llm_name]

        with open(f"{path}/responses_{llm_name}.json") as json_file:
            response_dict = json.load(json_file)

        grades = response_dict[f"{granularity}_grades"]

        if objective_only:
            unit_objectivity_bools = response_dict[f"{granularity}_objectivity"]
        else:
            unit_objectivity_bools = [[True] * len(grade_set) for grade_set in grades]

        for scorer_type in granularity_to_scorer_types[granularity]:
            scores_df = pd.read_parquet(f"{path}/{llm_name}_{scorer_type}_scores.parquet")

            for scorer in scorer_type_to_score_names[scorer_type]:
                metric_key = scorer_type + "_" + scorer
                unit_scores_flat, unit_grades_flat, response_scores, response_grades = _aggregate_unit_scores(
                    scores_df[scorer], grades, unit_objectivity_bools
                )

                # Clip because numpy means can land slightly above 1.
                unit_scores_flat = np.clip(unit_scores_flat, 0, 1)
                unit_grades_flat = np.clip(unit_grades_flat, 0, 1)
                response_grades = np.clip(response_grades, 0, 1)
                response_scores = np.clip(response_scores, 0, 1)

                llm_results["auroc"][metric_key] = roc_auc_score(y_score=unit_scores_flat, y_true=unit_grades_flat)
                llm_results["auprc"][metric_key] = average_precision_score(y_score=unit_scores_flat, y_true=unit_grades_flat)
                llm_results["brier_score"][metric_key] = brier_score_loss(y_proba=unit_scores_flat, y_true=unit_grades_flat)
                llm_results["ece"][metric_key] = expected_calibration_error(y_proba=unit_scores_flat, y_true=unit_grades_flat)
                llm_results["pearson"][metric_key] = pearsonr(x=response_grades, y=response_scores).statistic
                llm_results["spearman"][metric_key] = spearmanr(a=response_grades, b=response_scores).statistic

                # Compare against None so a legitimate average of 0.0 counts
                # as already set.
                if llm_results["average_grade"] is None:
                    llm_results["average_grade"] = np.mean(unit_grades_flat)

        # NOTE(review): each per-LLM file receives the whole `full_results`
        # dict, including still-empty entries for LLMs not yet processed.
        # Left as is because readers may index the file by LLM name; confirm
        # whether only `full_results[llm_name]` was intended.
        with open(f"{path}/metrics_{llm_name}.json", "w") as json_file:
            json.dump(full_results, json_file, indent=4)

    # The signature has always promised the results dict, but the function
    # previously fell off the end and returned None.
    return full_results
    