import json
from typing import Any, Dict, List

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go  
import streamlit as st

from utils.db_client import MongoDBClient, MongoDBConfig

# DB client initialization. These statements execute when the module is
# imported; `db_client` is the module-level handle that the metrics rendering
# functions in this module use to read collections from the "TrustGen" database.
config = MongoDBConfig(database="TrustGen")
db_client = MongoDBClient(config)


def _format_errors_table(filtered_df: pd.DataFrame) -> pd.DataFrame:
    """Build a table of JSON-formatted ``errors`` indexed by (task_name, model).

    Rows whose ``errors`` value is missing are dropped, and only the first row
    of each ``(task_name, model)`` pair is kept. The result may be empty.

    Args:
        filtered_df: Frame containing ``task_name``, ``model`` and ``errors``.

    Returns:
        A DataFrame with a single ``errors`` column of pretty-printed JSON
        strings and a ``(task_name, model)`` MultiIndex.
    """
    df_err = (
        filtered_df[["task_name", "model", "errors"]]
        .dropna(subset=["errors"])
        .drop_duplicates(subset=["task_name", "model"])
        # Explicit copy: the column is reassigned below, and the frame derives
        # from a filtered slice (avoids SettingWithCopyWarning).
        .copy()
    )
    df_err["errors"] = df_err["errors"].apply(
        lambda errs: json.dumps(errs, ensure_ascii=False, indent=2)
    )
    return df_err.set_index(["task_name", "model"])


def visualize_metrics(results_data: List[Dict[str, Any]], collection_name: str):
    """Render filters, a pivot table, a bar chart and stored errors for a collection.

    Args:
        results_data: Raw records of one metrics collection. Each record must
            provide ``task_name``, ``model`` and ``value``. ``_id`` is ignored,
            and an optional ``errors`` field is shown in an expander.
        collection_name: Collection name, used to give the filter widgets
            unique Streamlit keys per collection.
    """
    results_df = pd.DataFrame(results_data)
    if "_id" in results_df.columns:
        results_df = results_df.drop(columns=["_id"])
    required_cols = {"task_name", "model", "value"}
    if not required_cols.issubset(results_df.columns):
        st.error("The required fields are missing in the data (task_name, model, value).")
        return

    # Selection by tasks and models (everything is selected by default).
    tasks = results_df["task_name"].unique()
    models = results_df["model"].unique()
    selected_tasks = st.multiselect(
        "Select task(s):",
        options=tasks,
        default=list(tasks),
        key=f"metrics_tasks_{collection_name}",
    )
    selected_models = st.multiselect(
        "Select models:",
        options=models,
        default=list(models),
        key=f"metrics_models_{collection_name}",
    )
    filtered_df = results_df[
        (results_df["task_name"].isin(selected_tasks))
        & (results_df["model"].isin(selected_models))
    ]
    if filtered_df.empty:
        st.info("There is no data to display with the selected filters.")
        return

    # Tabular and graphical representation: mean value per (model, task).
    pivot_table = filtered_df.pivot_table(
        index="model", columns="task_name", values="value", aggfunc="mean"
    )
    st.subheader("The metric table by tasks and models")
    st.dataframe(pivot_table)
    st.subheader("Metric visualization")
    st.bar_chart(pivot_table)

    with st.expander("View the top 10 errors for selected tasks and models"):
        if "errors" not in filtered_df.columns:
            st.info("There are no saved errors for this metric.")
        else:
            errors_table = _format_errors_table(filtered_df)
            if errors_table.empty:
                st.info("No errors")
            else:
                st.dataframe(errors_table)


def _load_comparison_frames(collections: List[str]) -> List[pd.DataFrame]:
    """Load the metrics collections used by the cross-task comparison.

    ``TFNR`` is skipped. Each remaining collection is read in full through
    ``db_client``. ``_id`` is dropped, and collections that are empty or lack
    any of ``task_name``/``model``/``value`` are omitted.
    """
    frames: List[pd.DataFrame] = []
    for coll in collections:
        # NOTE(review): TFNR is excluded from the comparison; the reason is not
        # visible in this module — confirm with the metric's owner.
        if coll == "TFNR":
            continue
        recs = list(db_client.get_collection(coll).find())
        if not recs:
            continue
        df = pd.DataFrame(recs)
        if "_id" in df.columns:
            df = df.drop(columns=["_id"])
        if {"task_name", "model", "value"}.issubset(df.columns):
            frames.append(df)
    return frames


def _pivot_by_task(frames: List[pd.DataFrame], tasks: List[str]) -> pd.DataFrame:
    """Build a ``model`` x ``task_name`` table of mean ``value`` for ``tasks``.

    Rows for the requested tasks are gathered from *every* frame, so tasks
    stored in different collections can be compared. Models missing a value
    for any resulting task column are dropped. Returns an empty DataFrame
    when nothing matches.
    """
    parts = [
        df.loc[df["task_name"].isin(tasks), ["task_name", "model", "value"]]
        for df in frames
    ]
    combined = pd.concat(parts) if parts else pd.DataFrame()
    if combined.empty:
        return pd.DataFrame()
    return combined.pivot_table(
        index="model", columns="task_name", values="value"
    ).dropna()


def _render_task_scatter(frames: List[pd.DataFrame], task_options: List[str]) -> None:
    """Let the user pick two tasks and scatter-plot models by their values."""
    sel = st.multiselect(
        "Select two tasks for the scatter plot:",
        task_options,
        # The plot uses sel[0] (x) and sel[1] (y), so allow exactly two tasks.
        max_selections=2,
        key="compare_task_names",
    )
    if len(sel) != 2:
        return
    pivot = _pivot_by_task(frames, sel)
    # Both task columns must survive pivoting, and at least one model must
    # have values for both of them.
    if pivot.shape[1] != 2 or pivot.empty:
        st.warning("There is not enough data for the scatter plot.")
        return
    st.subheader("Interactive graph: comparison of metrics")
    st.dataframe(pivot)
    fig = px.scatter(
        pivot,
        x=sel[0],
        y=sel[1],
        text=pivot.index,
        labels={sel[0]: sel[0], sel[1]: sel[1]},
        title="Comparison of models by selected metrics",
    )
    fig.update_traces(textposition="top center")
    fig.update_layout(height=600)
    st.plotly_chart(fig, use_container_width=True)


def _render_task_correlation(frames: List[pd.DataFrame], task_options: List[str]) -> None:
    """Let the user pick tasks and show their correlation as a table and heatmap."""
    corr_sel = st.multiselect(
        "Select tasks for correlation analysis:",
        task_options,
        key="correlation_tasks",
    )
    if len(corr_sel) < 2:
        return
    pivot_corr = _pivot_by_task(frames, corr_sel)
    # A correlation needs at least two models (rows); with one model every
    # coefficient would be NaN.
    if len(pivot_corr) < 2:
        st.warning("There is not enough data to build a correlation matrix.")
        return
    corr_matrix = pivot_corr.corr()
    st.subheader("Correlation matrix of tasks")
    st.dataframe(corr_matrix.round(2))
    fig = go.Figure(
        data=go.Heatmap(
            z=corr_matrix.values,
            x=corr_matrix.columns,
            y=corr_matrix.columns,
            colorscale="RdBu",
            zmin=-1,
            zmax=1,
            colorbar=dict(title="Correlation"),
            hovertemplate="Tasks: %{y} and %{x}<br>Value: %{z:.2f}<extra></extra>",
        )
    )
    fig.update_layout(
        title="Interactive correlation matrix of tasks",
        xaxis=dict(title=""),
        yaxis=dict(title="", autorange="reversed"),
        height=600,
    )
    st.plotly_chart(fig, use_container_width=True)


def render_metrics_tab():
    """Render the "Model metrics" tab.

    Shows a per-collection metric view (delegated to ``visualize_metrics``),
    followed by an expander with a two-task scatter plot and a task
    correlation heatmap. Both are built from all metrics collections except
    ``TFNR``.
    """
    st.header("Model metrics")
    results_collections = ["RtAR", "TFNR", "Accuracy", "Correlation", "IncludeExclude"]
    if results_collections:
        selected_results_collection = st.selectbox(
            "Select a collection with metrics",
            options=results_collections,
            key="metrics_collection_selection",
        )
        results_collection = db_client.get_collection(selected_results_collection)
        results_data = list(results_collection.find())
        if results_data:
            visualize_metrics(results_data, selected_results_collection)
        else:
            st.info(f"There is no data in the collection '{selected_results_collection}'.")
    else:
        st.info("There are no available collections with metrics.")

    # Interactive comparison and correlation across collections.
    with st.expander("Comparison of metrics and correlation between tasks"):
        frames = _load_comparison_frames(results_collections)
        task_options = sorted(
            {task for frame in frames for task in frame["task_name"].unique()}
        )
        _render_task_scatter(frames, task_options)
        _render_task_correlation(frames, task_options)
