import logging
import os
import time
from typing import Any, Dict, List, Tuple

import numpy as np
import pandas as pd
from pymongo import MongoClient
from pymongo.database import Database

from utils.constants import MONGO_URI

# Target MongoDB database; can be overridden with the MONGO_DB environment variable.
MONGO_DB = os.getenv("MONGO_DB", "TrustGen")

# Root logging configuration (INFO level, timestamped lines) and this module's logger.
logging.basicConfig(format="%(asctime)s %(levelname)s: %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# Metric names this module knows about.
METRICS = "accuracy correlation RtA include_exclude".split()


def get_mongo_client() -> MongoClient:
    """
    Create a MongoClient for ``MONGO_URI`` and verify the server is reachable.

    The MongoClient constructor returns without waiting for a connection, so a
    ``ping`` command is issued before success is logged. Connection problems
    therefore surface here (as a pymongo error, e.g. ServerSelectionTimeoutError)
    instead of being reported as a successful connection.
    """
    client = MongoClient(MONGO_URI)
    # Round-trip to the server so the log line below is actually true.
    client.admin.command("ping")
    logger.info("Successful connection to  MongoDB.")
    return client


def get_db() -> Database:
    """Open a new client via ``get_mongo_client`` and return its ``MONGO_DB`` database."""
    client = get_mongo_client()
    database = client[MONGO_DB]
    logger.info(f"Database in use: {MONGO_DB}")
    return database


def extract_errors(
    df: pd.DataFrame,
    condition: pd.Series,
    input_col: str = "input",
    k: int = 10,
    random_state: Any = None,
) -> List[Dict[str, Any]]:
    """
    Take a random sample of up to ``k`` rows where ``condition`` is True.

    Args:
        df: frame containing ``input_col``, ``"pred"`` and ``"target"`` columns.
        condition: boolean mask used to select candidate rows (a Series aligned
            with ``df`` or a boolean array of length ``len(df)``).
        input_col: name of the column copied into each record as the input.
        k: maximum number of rows returned; ``k <= 0`` yields an empty list.
        random_state: forwarded to ``DataFrame.sample``; pass a seed for a
            reproducible sample. The default (None) keeps sampling unseeded.

    Returns:
        A list of dicts with keys ``input_col``, ``"pred"`` and ``"target"``;
        empty when no row satisfies ``condition``.
    """
    df_err = df[condition]
    # DataFrame.sample raises for a negative n, so short-circuit non-positive k.
    if df_err.empty or k <= 0:
        return []
    sample = df_err.sample(n=min(len(df_err), k), random_state=random_state)
    return sample[[input_col, "pred", "target"]].to_dict(orient="records")


def compute_tfnr(df: pd.DataFrame) -> Tuple[float, List[Dict[str, Any]]]:
    """
    Compute the TFN rate: the fraction of rows whose ``pred`` equals ``"TFN"``.

    Returns ``(nan, [])`` for a frame with no rows; otherwise the rate together
    with a random sample (via ``extract_errors``) of the TFN rows.
    """
    n_rows = len(df)
    if not n_rows:
        return np.nan, []
    is_tfn = df["pred"].eq("TFN")
    rate = is_tfn.sum() / n_rows
    return rate, extract_errors(df, is_tfn)


def compute_accuracy(df: pd.DataFrame) -> Tuple[float, List[Dict[str, Any]]]:
    """
    Exact-match accuracy over the rows whose ``pred`` is not ``"TFN"``.

    ``pred`` and ``target`` are compared after conversion to ``str``. Returns
    ``(nan, [])`` when no non-TFN row exists; otherwise the accuracy and a random
    sample of the mismatching rows.
    """
    scored = df.loc[df["pred"] != "TFN"]
    if scored.empty:
        return np.nan, []
    pred_text = scored["pred"].astype(str)
    target_text = scored["target"].astype(str)
    hits = pred_text == target_text
    misses = ~hits
    return hits.mean(), extract_errors(scored, misses)


def compute_correlation(df: pd.DataFrame) -> Tuple[float, List[Dict[str, Any]]]:
    """
    Pearson correlation between numeric predictions and targets.

    Rows predicted as ``"TFN"`` and rows whose ``pred`` or ``target`` cannot be
    parsed as a number are ignored. Returns ``(nan, [])`` when fewer than two
    usable rows remain.

    Errors are sampled (via ``extract_errors``) from the rows with the 10 largest
    absolute deviations, including ties at the cutoff. Rows where ``pred`` equals
    ``target`` are never reported as errors.
    """
    df_valid = df[df["pred"] != "TFN"].copy()
    df_valid["pred"] = pd.to_numeric(df_valid["pred"], errors="coerce")
    df_valid["target"] = pd.to_numeric(df_valid["target"], errors="coerce")
    df_valid = df_valid.dropna(subset=["pred", "target"])
    if len(df_valid) < 2:
        return np.nan, []
    value = df_valid["pred"].corr(df_valid["target"])
    diffs = (df_valid["pred"] - df_valid["target"]).abs()
    # Only genuine mismatches are errors. Without this filter, small groups would
    # report exact matches because they fall inside the "top 10" deviations.
    mismatches = diffs[diffs > 0]
    if mismatches.empty:
        return value, []
    # nlargest already caps at the Series length; its minimum is the cutoff.
    threshold = mismatches.nlargest(10).min()
    cond = diffs >= threshold
    errors = extract_errors(df_valid, cond)
    return value, errors


def compute_include_exclude(df: pd.DataFrame) -> Tuple[float, List[Dict[str, Any]]]:
    """
    Score each prediction against its ``include_list`` / ``exclude_list`` terms.

    Matching is a case-insensitive substring test. Per row:
      * the score starts at 1.0 if any include term occurs in ``pred``, else 0.0;
      * if every exclude term occurs, the score becomes 0.0;
      * otherwise the fraction of exclude terms found is subtracted (floor 0.0).

    Returns ``(nan, [])`` for an empty frame. Otherwise returns the mean score and a
    random sample of one row that scored below 1.0.
    """
    if df.empty:
        return np.nan, []
    scores = []
    idx_err = []
    for idx, row in df.iterrows():
        raw_pred = row.get("pred", "")
        # A missing prediction must not become the text "None"/"nan", which could
        # spuriously contain include/exclude terms such as "no" or "an".
        if pd.api.types.is_scalar(raw_pred) and pd.isna(raw_pred):
            pred = ""
        else:
            pred = str(raw_pred).lower()
        inc = row.get("include_list") or []
        exc = row.get("exclude_list") or []
        # str() guards against non-string terms (e.g. numbers) in the lists.
        pos_scores = [1.0 if str(s).lower() in pred else 0.0 for s in inc]
        score = max(pos_scores) if pos_scores else 0.0
        neg_count = sum(1 for s in exc if str(s).lower() in pred)
        if exc and neg_count == len(exc):
            score = 0.0
        elif exc:
            score = max(0.0, score - neg_count / len(exc))
        scores.append(score)
        if score < 1.0:
            idx_err.append(idx)
    value = float(np.mean(scores))
    cond = df.index.isin(idx_err)
    errors = extract_errors(df, cond, k=1)
    return value, errors


def fetch_extracted_tasks(db: Database, prefix: str) -> pd.DataFrame:
    """
    Collect every ``status == "extracted"`` document from collections named ``prefix*``.

    With ``prefix == "queue_"``, the ``queue_rta_*`` collections and documents whose
    ``metric`` is ``"RtA"`` are skipped; those are expected to be read with the
    ``"queue_rta_"`` prefix instead.

    Each document becomes one row with columns task_name, dataset_name, model,
    metric, input, pred, target, include_list and exclude_list:
      * ``input`` is the prompt formatted with the document's ``variables``, or the
        raw prompt if formatting fails;
      * ``model`` is taken from ``init_model`` for ``queue_rta_*`` collections and
        from ``model`` otherwise;
      * ``target`` is the include list for the ``include_exclude`` metric.

    Returns a DataFrame with no columns when nothing matches.
    """
    cols = [c for c in db.list_collection_names() if c.startswith(prefix)]
    # The filter does not depend on the collection, so build it once.
    query: Dict[str, Any] = {"status": "extracted"}
    if prefix == "queue_":
        cols = [c for c in cols if not c.startswith("queue_rta_")]
        query["metric"] = {"$ne": "RtA"}
    rows: List[Dict[str, Any]] = []
    for coll_name in cols:
        coll = db[coll_name]
        is_rta = coll_name.startswith("queue_rta_")
        # Default task name: the collection name with only the leading prefix removed.
        default_task = coll_name[len(prefix):]

        logger.info(f"Uploading data for metrics from the collection {coll_name}")
        for doc in coll.find(query):
            prompt = doc.get("prompt") or ""
            vars_ = doc.get("variables", {}) or {}
            try:
                inp = prompt.format(**vars_)
            except (KeyError, IndexError, ValueError, TypeError) as exc:
                # Prompts can contain literal braces (e.g. JSON examples) or reference
                # variables the document lacks; keep the raw prompt instead of
                # aborting the whole fetch.
                logger.warning(
                    f"Could not format prompt of document {doc.get('_id')} "
                    f"in {coll_name}: {exc!r}; using the raw prompt."
                )
                inp = prompt
            inc_list = doc.get("include_list", []) or []
            exc_list = doc.get("exclude_list", []) or []

            # Guarantee that include_list and exclude_list are of type list
            if isinstance(inc_list, str):
                inc_list = [inc_list]
            if isinstance(exc_list, str):
                exc_list = [exc_list]

            metric = doc.get("metric")
            target_val = inc_list if metric == "include_exclude" else doc.get("target")
            rows.append(
                {
                    "task_name": doc.get("task_name", default_task),
                    "dataset_name": doc.get("dataset_name"),
                    "model": doc.get("init_model") if is_rta else doc.get("model"),
                    "metric": metric,
                    "input": inp,
                    "pred": doc.get("pred"),
                    "target": target_val,
                    "include_list": inc_list,
                    "exclude_list": exc_list,
                }
            )
    df = pd.DataFrame(rows)
    logger.info(f"Retrieved {len(df)} entries from the queue'{prefix}'.")
    return df


def clear_old_results(db: Database, collection_name: str, df: pd.DataFrame):
    """
    Delete stored documents for every (task_name, model) pair present in ``df``.

    Issues one ``delete_many`` per distinct pair, in first-occurrence order.
    Does nothing, and does not touch ``db``, when ``df`` is empty.
    """
    if not df.empty:
        target = db[collection_name]
        pairs = df[["task_name", "model"]].drop_duplicates()
        for task_name, model_name in pairs.itertuples(index=False, name=None):
            target.delete_many({"task_name": task_name, "model": model_name})


def insert_results(db: Database, collection_name: str, results: List[Dict[str, Any]]):
    """
    Replace previously stored results in ``collection_name`` with ``results``.

    Documents for every (task_name, model) pair present in ``results`` are deleted
    first via ``clear_old_results``; then all rows are inserted, after a round trip
    through a DataFrame. No-op when ``results`` is empty.
    """
    if results:
        results_frame = pd.DataFrame(results)
        clear_old_results(db, collection_name, results_frame)
        target = db[collection_name]
        target.insert_many(results_frame.to_dict(orient="records"))


# Columns identifying one metric group in the frames built by fetch_extracted_tasks.
_GROUP_KEYS = ["task_name", "dataset_name", "model", "metric"]


def _result_row(
    task: Any, ds: Any, model: Any, value: float, errors: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """Build one result document in the shape written to the metric collections."""
    return {
        "task_name": task,
        "dataset_name": ds,
        "model": model,
        "value": value,
        "errors": errors,
    }


def _store_simple_metrics(db: Database, df: pd.DataFrame) -> None:
    """Compute TFNR plus each group's declared metric for ``queue_*`` rows and store them."""
    metric_fns = {
        "accuracy": compute_accuracy,
        "correlation": compute_correlation,
        "include_exclude": compute_include_exclude,
    }
    tfnr_res: List[Dict[str, Any]] = []
    per_metric: Dict[str, List[Dict[str, Any]]] = {name: [] for name in metric_fns}
    # dropna=False: by default groupby silently drops every group whose
    # dataset_name, model or metric is missing, losing those tasks entirely.
    for (task, ds, model, metric), g in df.groupby(_GROUP_KEYS, dropna=False):
        val_tfnr, errs_tfnr = compute_tfnr(g)
        tfnr_res.append(_result_row(task, ds, model, val_tfnr, errs_tfnr))
        fn = metric_fns.get(metric)
        if fn is not None:
            val, errs = fn(g)
            per_metric[metric].append(_result_row(task, ds, model, val, errs))

    insert_results(db, "TFNR", tfnr_res)
    insert_results(db, "Accuracy", per_metric["accuracy"])
    insert_results(db, "Correlation", per_metric["correlation"])
    insert_results(db, "IncludeExclude", per_metric["include_exclude"])


def _store_rta_metrics(db: Database, df_rta: pd.DataFrame) -> None:
    """Compute accuracy per group for ``queue_rta_*`` rows and store it in ``RtAR``."""
    rta_res = [
        _result_row(task, ds, model, *compute_accuracy(g))
        for (task, ds, model, _), g in df_rta.groupby(_GROUP_KEYS, dropna=False)
    ]
    insert_results(db, "RtAR", rta_res)


def _run_metrics_cycle(db: Database) -> None:
    """Fetch all extracted tasks once, then recompute and store every metric."""
    df = fetch_extracted_tasks(db, prefix="queue_")
    df_rta = fetch_extracted_tasks(db, prefix="queue_rta_")

    # Simple queues
    if not df.empty:
        _store_simple_metrics(db, df)

    # RTA queues
    if not df_rta.empty:
        _store_rta_metrics(db, df_rta)


def compute_and_store_metrics(db: Database, interval: int = 30):
    """
    Recompute all metrics forever, sleeping ``interval`` seconds between cycles.

    A failure in one cycle (e.g. a transient MongoDB error) is logged with its
    traceback, and the loop continues with the next cycle instead of terminating
    the process.
    """
    while True:
        try:
            _run_metrics_cycle(db)
        except Exception:
            # Top-level boundary of a long-running loop: log and keep going.
            logger.exception(f"Metric cycle failed; retrying in {interval} seconds.")
        else:
            logger.info("The metrics have been updated, and we are waiting for the next cycle.")
        time.sleep(interval)


def main():
    """Entry point: open the configured database and run the metric loop every 120 seconds."""
    compute_and_store_metrics(get_db(), interval=120)


if __name__ == "__main__":
    main()
