# prompts_tasks.py
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
import streamlit as st
from dataset_management import render_dataset_varcols_section  # for getting var_cols

from utils.constants import MODELS, RTA_MODEL
from utils.db_client import MongoDBClient, MongoDBConfig

# DB client initialization
# Module-level MongoDB client shared by every helper in this file; it is
# constructed once, at import time, for the "TrustGen" database.
config = MongoDBConfig(database="TrustGen")
db_client = MongoDBClient(config)

# Default answer-extraction pattern offered in the regexp section: group 1
# captures a 0/1 preceded only by non-word characters at the start of the
# text, group 2 captures a 0/1 followed only by non-word characters at the
# end. (Presumably applied to model responses — confirm with the consumer.)
DEFAULT_REGEX = r"(?:^\W*([01]).*)|(?:.*([01])\W*$)"


def show_all_prompts():
    """Render every stored prompt as a table, or a notice if there are none.

    Documents come from ``get_all_prompts`` (which yields ``[]`` when the
    ``prompt_storage`` collection does not exist); the Mongo ``_id`` column
    is dropped before display.
    """
    prompts = get_all_prompts()
    if not prompts:
        st.write("There are no prompts in the storage.")
        return
    # ``errors="ignore"`` keeps this safe if documents ever lack ``_id``.
    df = pd.DataFrame(prompts).drop(columns=["_id"], errors="ignore")
    st.write("Existing prompts (name, prompt):")
    st.dataframe(df)


def get_all_prompts() -> List[Dict[str, Any]]:
    """Return every document in ``prompt_storage`` (``[]`` if the collection is absent)."""
    storage = "prompt_storage"
    known_collections = db_client.list_collections()
    return (
        list(db_client.get_collection(storage).find({}))
        if storage in known_collections
        else []
    )


def prompt_exists(name: str) -> bool:
    """Tell whether a prompt called ``name`` is present in ``prompt_storage``."""
    storage = "prompt_storage"
    # Short-circuit: only query the collection when it actually exists.
    return (
        storage in db_client.list_collections()
        and db_client.get_collection(storage).find_one({"name": name}) is not None
    )


def insert_prompt_global(name: str, prompt: str):
    """Store a new ``{"name", "prompt"}`` document in ``prompt_storage``."""
    document = {"name": name, "prompt": prompt}
    db_client.get_collection("prompt_storage").insert_one(document)


def show_all_rta_prompts():
    """Display the prompt table for the RtA section.

    There is no separate RtA store here: this delegates to
    ``show_all_prompts``, which lists ``prompt_storage``.
    """
    return show_all_prompts()


def render_prompt_creation_section(
    var_cols: List[str], key_prefix: Optional[str] = None
) -> Optional[str]:
    """Render inputs for creating a new prompt or reusing one by name.

    The prompt text must contain a ``{col}`` placeholder for every column in
    ``var_cols``. "Add prompt to the database" stores a new prompt in
    ``prompt_storage``; if the name is already taken, the user may reuse the
    stored prompt instead.

    Args:
        var_cols: Dataset columns whose ``{col}`` placeholders are required.
        key_prefix: Prefix for this section's widget keys and session-state
            entry. By default it is derived from ``var_cols``, so the main
            prompt section and the RtA section (called with ``var_cols=[]``)
            never produce duplicate widget IDs on the same page.

    Returns:
        The confirmed prompt text, or ``None`` until the user confirms one.
        ``st.button`` is ``True`` only on the rerun triggered by the click,
        so the confirmed prompt is remembered in ``st.session_state`` and
        returned on later reruns for as long as the prompt name is unchanged.
    """
    if key_prefix is None:
        key_prefix = "prompt_creation_" + "_".join(var_cols)
    confirmed_key = f"{key_prefix}_confirmed"

    placeholder_list = ", ".join("{" + c + "}" for c in var_cols)
    st.write(placeholder_list)
    prompt_name = st.text_input("Enter new prompt name:", key=f"{key_prefix}_name")
    user_prompt = st.text_area(
        "Enter your prompt:",
        value=f"You can use any selected speakers.: {placeholder_list}.",
        key=f"{key_prefix}_text",
    )

    # The input widgets are rendered before any early return so they appear on
    # every run; a previously confirmed choice is then reused as-is.
    confirmed = st.session_state.get(confirmed_key)
    if confirmed is not None and confirmed["name"] == prompt_name:
        st.info(f"Using prompt '{prompt_name}'.")
        return confirmed["prompt"]

    if not (user_prompt and prompt_name):
        return None
    missing_cols = [c for c in var_cols if f"{{{c}}}" not in user_prompt]
    if missing_cols:
        st.error("Placeholders are missing: " + ", ".join(missing_cols))
        return None

    if prompt_exists(prompt_name):
        st.warning(f"A prompt named '{prompt_name}' already exists. You can use it.")
        if st.button("Use an existing prompt", key=f"{key_prefix}_use_existing"):
            stored = next(
                (p["prompt"] for p in get_all_prompts() if p["name"] == prompt_name),
                None,
            )
            if stored is not None:
                st.session_state[confirmed_key] = {"name": prompt_name, "prompt": stored}
            return stored
        return None

    if st.button("Add prompt to the database", key=f"{key_prefix}_add"):
        insert_prompt_global(prompt_name, user_prompt)
        st.success("Prompt has been added!")
        st.session_state[confirmed_key] = {"name": prompt_name, "prompt": user_prompt}
        return user_prompt
    return None


def render_prompt_selection_section(var_cols: List[str]) -> Optional[str]:
    """Let the user pick a stored prompt or create a new one.

    Args:
        var_cols: Dataset columns expected as ``{col}`` placeholders.

    Returns:
        The chosen prompt text, or ``None`` if nothing is selected/confirmed.
        A stored prompt lacking some placeholders is still returned, with one
        warning per missing column.
    """
    with st.expander("Selecting or creating a prompt", expanded=False):
        show_all_prompts()
        mode = st.radio("Prompt:", ("Select from db", "Enter your own"))
        if mode != "Select from db":
            return render_prompt_creation_section(var_cols)

        # Only query the stored prompts when the user is picking from them.
        all_prompt_docs = get_all_prompts()
        if not all_prompt_docs:
            st.write("There are no prompts available. Enter your own.")
            return None

        names = [p["name"] for p in all_prompt_docs]
        selected_name = st.selectbox(
            "Select the prompt by name:", names, key="prompt_selectbox"
        )
        selected_prompt = next(
            (p["prompt"] for p in all_prompt_docs if p["name"] == selected_name),
            None,
        )
        if selected_prompt:
            for c in var_cols:
                if f"{{{c}}}" not in selected_prompt:
                    st.warning(
                        f"The placeholder for column '{c}' was not found in the prompt."
                    )
        return selected_prompt


def show_existing_regexp(metric: str):
    """Render the regexps stored for ``metric`` as a table, or a notice if none exist."""
    patterns = get_all_regexps_for_metric(metric)
    if not patterns:
        st.write("There are no regexps for this metric.")
        return
    table = pd.DataFrame(patterns)
    if "_id" in table.columns:
        table = table.drop(columns=["_id"])
    st.write("Existing regexps (name, pattern, metric):")
    st.dataframe(table)


def get_all_regexps_for_metric(metric: str) -> List[Dict[str, Any]]:
    """Return every regexp document in ``regexp_<metric>`` (``[]`` if it doesn't exist)."""
    collection_name = f"regexp_{metric}"
    available = db_client.list_collections()
    if collection_name in available:
        cursor = db_client.get_collection(collection_name).find({})
        return list(cursor)
    return []


def insert_regexp_global(name: str, pattern: str, metric: str):
    """Persist a named regexp in the per-metric ``regexp_<metric>`` collection."""
    record = {"name": name, "pattern": pattern, "metric": metric}
    db_client.get_collection(f"regexp_{metric}").insert_one(record)


def render_regexp_section(metric: str) -> Optional[str]:
    """Let the user pick a stored regexp for ``metric`` or save a custom one.

    Args:
        metric: Metric name; its regexps live in ``regexp_<metric>``.

    Returns:
        The selected or saved pattern, or ``None`` if none is chosen yet.
        A custom pattern is returned only after it has been saved via
        "Add regexp to DB". Because ``st.button`` is ``True`` only on the
        rerun triggered by the click, the saved (name, pattern) pair is kept
        in ``st.session_state`` per metric and reused on later reruns while
        both inputs stay unchanged.
    """
    confirmed_key = f"regexp_confirmed_{metric}"
    with st.expander("Selecting or creating a regexp for the metric", expanded=False):
        show_existing_regexp(metric)
        use_existing_regexp = st.radio("Regexp:", ("Existing", "Custom"))

        if use_existing_regexp == "Existing":
            regexps = get_all_regexps_for_metric(metric)
            if not regexps:
                st.write("There are no available regexps for this metric.")
                return None
            names = [r["name"] for r in regexps]
            selected_name = st.selectbox(
                "Select regexp by name:", names, key="regexp_selectbox"
            )
            return next(
                (r["pattern"] for r in regexps if r["name"] == selected_name), None
            )

        st.write(f"Default regexp: {DEFAULT_REGEX}")
        custom_regexp = st.text_input("Enter custom regexp:", value=DEFAULT_REGEX)
        if not custom_regexp:
            return None
        if not db_client.validate_regex(custom_regexp):
            st.error("Invalid regular expression!")
            return None

        regexp_name = st.text_input("Enter the regexp name:")
        # Reuse a pattern saved on an earlier run instead of losing it.
        if st.session_state.get(confirmed_key) == (regexp_name, custom_regexp):
            st.info(f"Using regexp '{regexp_name}'.")
            return custom_regexp
        if regexp_name and st.button("Add regexp to DB"):
            insert_regexp_global(regexp_name, custom_regexp, metric)
            st.success("Regexp has been added!")
            st.session_state[confirmed_key] = (regexp_name, custom_regexp)
            return custom_regexp
        return None


def render_rta_prompt_section() -> Tuple[Optional[str], Optional[str], Any]:
    """Collect the extra settings required by the RtA metric.

    RtA prompts are read from the same ``prompt_storage`` collection as task
    prompts; a new one can also be created (with no required placeholders).

    Returns:
        ``(rta_prompt, rta_model, rta_target)``: the chosen prompt text (or
        ``None``), the model picked from ``MODELS`` (pre-selected to
        ``RTA_MODEL`` when it is listed), and the target value exactly as
        typed into the text input (``"1"`` by default).
    """
    with st.expander("Selecting or creating an RTA prompt", expanded=False):
        show_all_rta_prompts()
        st.write("The RtA metric is selected. An RTA prompt is required.")
        use_rta_existing = st.radio("RTA prompt:", ("Select from DB", "Enter custom one"))
        rta_prompt_selected = None
        if use_rta_existing == "Select from DB":
            all_prompt_docs = get_all_prompts()
            if all_prompt_docs:
                names = [p["name"] for p in all_prompt_docs]
                selected_name = st.selectbox(
                    "Select RTA prompt by name:", names, key="rta_prompt_selectbox"
                )
                rta_prompt_selected = next(
                    (rp["prompt"] for rp in all_prompt_docs if rp["name"] == selected_name),
                    None,
                )
            else:
                st.write("There are no RTA prompts available. Enter your own.")
        else:
            rta_prompt_selected = render_prompt_creation_section(var_cols=[])

        rta_target = st.text_input("Target value for RtA:", value="1")
        rta_model = st.selectbox(
            "Model for RTA:",
            MODELS,
            index=MODELS.index(RTA_MODEL) if RTA_MODEL in MODELS else 0,
            key="rta_model_selectbox",
        )
        return rta_prompt_selected, rta_model, rta_target


def render_models_section() -> List[str]:
    """Show a multiselect over ``MODELS`` and return the user's picks."""
    with st.expander("Task model selection", expanded=False):
        return st.multiselect("Select model:", MODELS)


def _render_prompt_examples(dataset_name: str, var_cols: List[str], prompt: str) -> None:
    """Show up to 5 random rows (from the first 100 fetched) substituted into ``prompt``."""
    df_head = db_client.get_dataset_head(dataset_name, limit=100)
    if df_head.empty:
        return
    sample_size = min(5, len(df_head))
    rows = df_head[var_cols].sample(n=sample_size).to_dict(orient="records")
    for i, row in enumerate(rows, start=1):
        filled_prompt = prompt
        for column, value in row.items():
            filled_prompt = filled_prompt.replace(f"{{{column}}}", str(value))
        st.write(f"**Example {i}:** {filled_prompt}")


def render_preview_and_save_task(
    dataset_name: str,
    var_cols: List[str],
    selected_prompt: str,
    selected_regexp: str,
    target_value: Any,
    selected_models: List[str],
    metric: str,
    rta_prompt_selected: Optional[str],
    rta_model: Optional[str],
    include_column: Optional[str],
    exclude_column: Optional[str],
):
    """Preview the assembled task and insert it via ``db_client.insert_task``.

    The preview is shown only when a prompt, a regexp and at least one model
    are set, and a target is present (the target is not required for the
    ``RtA`` and ``include_exclude`` metrics). Metric-specific fields are added
    to the task document: RtA prompt/model/target for ``RtA``,
    include/exclude columns for ``include_exclude``, and the target otherwise.
    """
    with st.expander("Previewing and saving a task", expanded=False):
        ready = (
            selected_prompt
            and selected_regexp
            and selected_models
            and (target_value or metric in ["RtA", "include_exclude"])
        )
        if not ready:
            st.info("Select a prompt, a regexp and at least one model to preview the task.")
            return

        group_name = st.text_input("Task group:", value="default")
        task_name = st.text_input("Task name:", value=str(dataset_name))
        st.subheader("Preview 5 random examples:")
        _render_prompt_examples(dataset_name, var_cols, selected_prompt)

        st.write("**The structure of the task in the database:**")
        task_data: Dict[str, Any] = {
            "task_name": task_name,
            "dataset_name": dataset_name,
            "prompt": selected_prompt,
            "variables_cols": var_cols,
            "models": selected_models,
            "metric": metric,
            "regexp": selected_regexp,
            "group": group_name,
        }
        if metric == "RtA":
            task_data["rta_prompt"] = rta_prompt_selected
            task_data["rta_model"] = rta_model
            task_data["target"] = target_value
        elif metric == "include_exclude":
            task_data["include_column"] = include_column
            task_data["exclude_column"] = exclude_column
        else:
            task_data["target"] = target_value

        st.json(task_data, expanded=False)
        if st.button("Upload a task to the database"):
            # Refuse to store a task that cannot be identified by name.
            if not task_name.strip():
                st.error("Task name must not be empty.")
            else:
                db_client.insert_task(task_data)
                st.success("The task was successfully added!")


def render_create_task_tab():
    """Render the "Create new task" tab.

    Steps appear in order, and each step is shown only once the previous one
    has produced a value:
    1. dataset;
    2. variable columns and metric;
    3. prompt;
    4. regexp (skipped for ``include_exclude``);
    5. RtA settings (``RtA`` only);
    6. models;
    7. preview and save.
    """
    st.header("Create new task")
    # Build a new list rather than calling ``.remove`` on the one returned by
    # db_client, so that object is never mutated.
    # NOTE(review): "regestry" looks like a misspelling of "registry"; it must
    # match the bookkeeping collection name db_client reports — confirm.
    datasets = sorted(
        name for name in db_client.get_all_datasets() if name != "regestry"
    )
    selected_dataset = st.selectbox(
        "Select dataset:", datasets, key="create_task_selectbox"
    )
    if not selected_dataset:
        return

    var_cols, metric, target_column, include_column, exclude_column = (
        render_dataset_varcols_section(selected_dataset)
    )
    if not var_cols or metric is None:
        return

    selected_prompt = render_prompt_selection_section(var_cols)
    if not selected_prompt:
        return

    if metric == "include_exclude":
        selected_regexp = "Metric include_exclude doesn't use regexp."
    else:
        selected_regexp = render_regexp_section(metric)
    if not selected_regexp:
        return

    rta_prompt_selected = None
    rta_model = None
    rta_target_value = None
    if metric == "RtA":
        rta_prompt_selected, rta_model, rta_target_value = render_rta_prompt_section()

    selected_models = render_models_section()
    final_target = rta_target_value if metric == "RtA" else target_column
    render_preview_and_save_task(
        dataset_name=selected_dataset,
        var_cols=var_cols,
        selected_prompt=selected_prompt,
        selected_regexp=selected_regexp,
        target_value=final_target,
        selected_models=selected_models,
        metric=metric,
        rta_prompt_selected=rta_prompt_selected,
        rta_model=rta_model,
        include_column=include_column,
        exclude_column=exclude_column,
    )
