import dspy
import pandas as pd
from pathlib import Path
from tqdm import tqdm
from dspy.teleprompt import BootstrapFewShot, MIPROv2, BootstrapFewShotWithRandomSearch
from azure.identity import AzureCliCredential, get_bearer_token_provider
from typing import List
from ..structures.dspy import (
    LLMScoreSignature,
    SummarySignature,
    SummarySignature_Legacy,
)


def metric(gold, pred, trace=None):
    """
    Score a single predicted summary against its reference with an LLM judge.

    Every output field of ``LLMScoreSignature`` is treated as one judged
    question. Answers equal to "Not mentioned in reference" are left out; the
    score is the share of the remaining answers that are exactly "Yes".

    Args:
        gold: Reference example; its ``summary`` attribute is the ground truth.
        pred: Prediction; its ``summary`` attribute is the candidate summary.
        trace: Part of DSPy's metric calling convention; not used here.

    Returns:
        A float in [0, 1], or 0.0 when no question was applicable.
    """
    judge = dspy.Predict(LLMScoreSignature)
    verdict = judge(ground_truth=gold.summary, prediction=pred.summary)

    # One answer per judged question, in the signature's output-field order.
    answers = [getattr(verdict, name) for name in LLMScoreSignature.output_fields]
    applicable = [a for a in answers if a != "Not mentioned in reference"]

    if len(applicable) == 0:
        return 0.0
    return applicable.count("Yes") / len(applicable)


def prepare_training_examples(df: pd.DataFrame, text_col: str, summary_col: str):
    """Turn every row of ``df`` into a ``dspy.Example`` for training.

    Each example carries ``text`` (taken from ``text_col``) and ``summary``
    (taken from ``summary_col``). Only ``text`` is marked as an input, so
    ``summary`` serves as the label.

    Returns:
        A list of examples, one per row, in DataFrame row order.
    """

    def _row_to_example(record):
        # Build the example first, then declare which field is the input.
        ex = dspy.Example(text=record[text_col], summary=record[summary_col])
        return ex.with_inputs("text")

    return [_row_to_example(record) for _, record in df.iterrows()]


def optimize_prompt(train_examples, method="bootstrap", legacy=False, lm=None):
    """Compile a DSPy summarization program with the chosen teleprompter.

    Args:
        train_examples: Training ``dspy.Example`` objects (with inputs marked).
        method: One of ``"bootstrap"``, ``"bootstrap_rs"`` or ``"mipro"``.
        legacy: When True, use ``SummarySignature_Legacy`` instead of
            ``SummarySignature``.
        lm: Language model given to MIPROv2 as both prompt and task model.
            Defaults to the globally configured ``dspy.settings.lm``. Not used
            by the bootstrap methods.

    Returns:
        The compiled (optimized) DSPy program.

    Raises:
        ValueError: If ``method`` is not a supported name.
    """
    signature = SummarySignature_Legacy if legacy else SummarySignature
    if method == "bootstrap":
        print("Using Bootstrap Few Shot")
        teleprompter = BootstrapFewShot(
            metric=metric, max_bootstrapped_demos=3, max_labeled_demos=6
        )
    elif method == "bootstrap_rs":
        print("Using Bootstrap Few Shot with Random Search")
        teleprompter = BootstrapFewShotWithRandomSearch(
            metric=metric,
            max_bootstrapped_demos=3,
            max_labeled_demos=6,
            num_candidate_programs=6,
        )
    elif method == "mipro":
        print("Using MIPROv2")
        # Resolve the model explicitly instead of reading a script-level
        # global, so this function also works when the module is imported.
        model = lm if lm is not None else dspy.settings.lm
        teleprompter = MIPROv2(
            metric=metric, auto="light", prompt_model=model, task_model=model
        )
    else:
        raise ValueError(
            f"Invalid optimization method {method!r}. "
            "Choose 'bootstrap', 'bootstrap_rs' or 'mipro'."
        )
    return teleprompter.compile(dspy.Predict(signature), trainset=train_examples)


def train_and_predict(
    train_df: pd.DataFrame,
    test_df: pd.DataFrame,
    text_cols: List[str],
    summary_col: str,
    output_folder: Path,
    optimization_method="bootstrap",
    legacy=False,
):
    """Optimize one summarizer per text column and predict on the test set.

    For each column in ``text_cols``:
    - train on ``train_df`` against ``summary_col``;
    - save the optimized program as
      ``optimized_summary_generator_<col>.json`` in ``output_folder``;
    - write predictions to a new ``<col>_one_liner`` column of a copy of
      ``test_df``.

    The copy is finally written to ``test_predictions.csv``.

    Args:
        train_df: Training data containing ``text_cols`` and ``summary_col``.
        test_df: Test data containing ``text_cols``; it is not modified.
        text_cols: Input text columns; one program is optimized per column.
        summary_col: Column holding the reference summaries.
        output_folder: Destination folder (a ``str`` path is also accepted).
            It is created if it does not exist.
        optimization_method: Passed to ``optimize_prompt`` as ``method``.
        legacy: Passed to ``optimize_prompt``.

    Returns:
        ``output_folder`` as a ``Path``.

    Raises:
        KeyError: If a required column is missing from either DataFrame.
    """
    output_folder = Path(output_folder)
    # Create the folder up front so a missing directory does not surface
    # only after the (expensive) optimization, when the program is saved.
    output_folder.mkdir(parents=True, exist_ok=True)

    # Fail fast on missing columns rather than midway through a multi-column run.
    missing_train = [
        col for col in [*text_cols, summary_col] if col not in train_df.columns
    ]
    missing_test = [col for col in text_cols if col not in test_df.columns]
    if missing_train or missing_test:
        raise KeyError(
            f"Missing columns - train: {missing_train}, test: {missing_test}"
        )

    test_df_copy = test_df.copy()

    for text_col in text_cols:
        print(f"Training model for {text_col}")

        train_examples = prepare_training_examples(train_df, text_col, summary_col)

        # Optimize the summarization program
        optimized_generator = optimize_prompt(
            train_examples, method=optimization_method, legacy=legacy
        )

        # Save optimized model for this column
        model_path = output_folder / f"optimized_summary_generator_{text_col}.json"
        optimized_generator.save(str(model_path))

        # Generate predictions for this text column
        predictions = []
        for _, row in tqdm(
            test_df_copy.iterrows(), desc="Making Predictions", total=len(test_df_copy)
        ):
            prediction = optimized_generator(text=row[text_col])
            predictions.append(prediction.summary)

        # Add predictions column named after the text column
        test_df_copy[f"{text_col}_one_liner"] = predictions

    # Save test set with all predictions
    output_csv_path = output_folder / "test_predictions.csv"
    test_df_copy.to_csv(output_csv_path, index=False)

    print(f"Optimization complete. Results saved to: {output_folder}")
    print(f"Test predictions saved as: {output_csv_path}")

    return output_folder


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Train DSPy summary generator and generate predictions"
    )
    parser.add_argument("--train", required=True, help="Path to training CSV file")
    parser.add_argument("--test", required=True, help="Path to test CSV file")
    parser.add_argument(
        "--output_path",
        required=True,
        help="Folder for optimized programs and test predictions",
    )
    parser.add_argument(
        "--text_cols", required=True, nargs="+", help="Names of the text columns"
    )
    parser.add_argument(
        "--summary_col", required=True, help="Name of the summary column"
    )
    parser.add_argument(
        "--provider", default="ollama", help="LLM provider: ollama or azure"
    )
    parser.add_argument(
        "--method",
        choices=["bootstrap", "bootstrap_rs", "mipro"],
        default="bootstrap",
        help="Optimization method to use (bootstrap, bootstrap_rs or mipro)",
    )
    # train_and_predict() below reads args.legacy, so the flag must be defined.
    parser.add_argument(
        "--legacy",
        action="store_true",
        help="Use the legacy summary signature",
    )

    args = parser.parse_args()

    # ``lm`` is bound at module scope on purpose: code elsewhere in the module
    # may look it up as a global.
    if args.provider == "ollama":
        lm = dspy.LM(
            model="ollama_chat/gemma3:27b-it-qat",
            api_base="http://localhost:11434",
            cache=False,
        )
    elif args.provider == "azure":
        # NOTE(review): the AAD token is fetched once here; confirm that long
        # optimization runs do not outlive its lifetime.
        ad_token = get_bearer_token_provider(
            AzureCliCredential(), "https://cognitiveservices.azure.com/.default"
        )()
        lm = dspy.LM(
            "azure/o3-mini",
            # NOTE(review): api_base is empty; presumably it is filled in per
            # deployment. Confirm.
            api_base="",
            api_version="2024-12-01-preview",
            temperature=1.0,
            max_tokens=20000,
            azure_ad_token=ad_token,
        )
    else:
        # parser.error prints the message with usage and exits non-zero.
        # A bare exit() reported success (status 0).
        parser.error(f"Provider {args.provider} is not available.")

    dspy.configure(lm=lm, verbose=True)

    output_dir = Path(args.output_path)
    output_dir.mkdir(parents=True, exist_ok=True)

    train_df = pd.read_csv(Path(args.train))
    test_df = pd.read_csv(Path(args.test))

    train_and_predict(
        train_df=train_df,
        test_df=test_df,
        text_cols=args.text_cols,
        summary_col=args.summary_col,
        output_folder=output_dir,
        optimization_method=args.method,
        legacy=args.legacy,
    )
