import os
import pandas as pd
from typing import List, Tuple, Dict, Any
from utils.funcs import transform_structure
from prompts.prompts import (
    GREEDY_COPULA_GUIDE_PROMPT,
    GENERATION_COPULA_GUIDE_C_COT_BATCH_PROMPT,
    GREEDY_MARGINAL_CAT_GUIDE_PROMPT,
    GREEDY_MARGINAL_NUM_GUIDE_PROMPT,
    GREEDY_MARGINAL_GUIDE_PROMPT,
    GENERATION_MARGINAL_GUIDE_C_COT_BATCH_PROMPT,
    COPULA_INFERENCE_PROMPT,
)


class PromptBuilder:
    """Assembles prompt strings from templates; performs no statistical or bin-frequency calculations.

    Attributes:
        args: Run arguments; only ``args.n_joints`` is read in this class.
        data: Dataset container; only ``data["0"]`` is read in this class, as
            the example record embedded in generation prompts.
        config: Mapping of variable name -> metadata dict. Each metadata dict
            must contain ``"desc"``; the presence of a ``"categories"`` key is
            what marks a variable as categorical.
        variable_desc_string: Cached description of all variables, built once
            at construction time.
    """

    def __init__(self, args, data, config):
        self.args = args
        self.data = data
        self.config = config
        # Built eagerly so every prompt shares the same description text.
        self.variable_desc_string = self._get_variable_desc_string()

    # ---------- Common Methods ----------

    def _get_variable_desc_string(self) -> str:
        """Return a "; "-joined description of every variable in ``self.config``.

        Each entry is rendered as ``"<name>: <desc>"``; variables that have a
        ``"categories"`` key additionally get `` (Categories: <categories>)``.
        """
        parts = []
        for name, meta in self.config.items():
            entry = f"{name}: {meta['desc']}"
            if "categories" in meta:
                entry += f" (Categories: {meta['categories']})"
            parts.append(entry)
        return "; ".join(parts)

    def build_greedy_marginal_guide(self, v, best):
        """Format the per-variable greedy marginal guide for variable ``v``.

        The categorical template is used when ``self.config[v]`` has a
        ``"categories"`` key, the numerical template otherwise. Both templates
        are formatted with the ``best`` and ``v`` placeholders.

        Raises:
            KeyError: If ``v`` is not a key of ``self.config``.
        """
        if "categories" in self.config[v]:
            template = GREEDY_MARGINAL_CAT_GUIDE_PROMPT
        else:
            template = GREEDY_MARGINAL_NUM_GUIDE_PROMPT
        return template.format(best=best, v=v)

    def build_marginal_prompt(self, top_diff) -> Dict[str, str]:
        """Build one marginal guide string per variable from real-vs-synthetic gaps.

        Args:
            top_diff: Mapping of variable name -> iterable of
                ``(label, real_prob, syn_prob)`` records.

        Returns:
            dict: Variable name -> guide string. Every key of ``top_diff`` is
            present in the result. The guide lists each label with its share
            (in percent, one decimal) of the variable's summed
            ``real_prob - syn_prob`` gap. Variables with no records, or whose
            summed gap is zero, map to ``""``.
        """
        var_guide: Dict[str, str] = {}
        for dim, records in top_diff.items():
            total_gap = sum(real_prob - syn_prob for _, real_prob, syn_prob in records)
            if total_gap == 0:
                # Covers both "no records" (empty sum) and gaps that cancel
                # out. Percentages are undefined here, so emit an empty guide
                # rather than dropping the key: callers can then index the
                # result by any variable name they passed in.
                var_guide[dim] = ""
                continue

            desc = ", ".join(
                f"{label} ({(real_prob - syn_prob) / total_gap * 100:.1f}%)"
                for label, real_prob, syn_prob in records
            )
            var_guide[dim] = GREEDY_MARGINAL_GUIDE_PROMPT.format(dim=dim, desc=desc)
        return var_guide

    def build_copula_inference(self):
        """Format the copula inference prompt with the variable descriptions and ``args.n_joints``."""
        return COPULA_INFERENCE_PROMPT.format(
            data_desc=self.variable_desc_string,
            n_joints=self.args.n_joints,
        )

    def build_batch_generation_prompt(
        self,
        joint_guide: str,
        marginal_guide: str,
        n_plans: int,
        n_samples: int,
        is_marginal: bool = False
    ) -> str:
        """
        Builds a batch generation prompt, either for marginal or copula generation.

        Args:
            joint_guide (str): Guide for joint distribution. Ignored when
                ``is_marginal`` is True.
            marginal_guide (str): Guide for marginal distribution.
            n_plans (int): Number of plans to generate.
            n_samples (int): Number of samples to generate.
            is_marginal (bool): Whether it's for marginal distribution (default is False).

        Returns:
            str: The assembled batch generation prompt.
        """
        # A variable is categorical iff its config entry has "categories".
        cat_vars = [name for name, meta in self.config.items() if "categories" in meta]
        num_vars = [name for name, meta in self.config.items() if "categories" not in meta]

        # Placeholders shared by both templates.
        params = {
            # The record stored under key "0" serves as the real-data example.
            "real_data_example": transform_structure(self.data["0"]),
            "cat_var": ", ".join(cat_vars) if cat_vars else "**No categorical variables for this dataset**",
            "num_var": ", ".join(num_vars) if num_vars else "**No numerical variables for this dataset**",
            "data_desc": self.variable_desc_string,
            "n_plans": n_plans,
            "n_samples": n_samples,
            "marginal_guide": marginal_guide,
        }

        if is_marginal:
            template = GENERATION_MARGINAL_GUIDE_C_COT_BATCH_PROMPT
        else:
            template = GENERATION_COPULA_GUIDE_C_COT_BATCH_PROMPT
            # Only the copula template receives the joint guide.
            params["joint_guide"] = joint_guide

        return template.format(**params)

    def build_greedy_copula_guide(
        self,
        variable_groups: Dict[str, List[str]],
        joint_diff: Dict[str, List[Tuple[Tuple[str, ...], float]]]
    ) -> Dict[str, str]:
        """
        For each variable group, turns the joint-diff combos into a single guide prompt.

        Args:
            variable_groups (dict): Group key -> list of variable names in that group.
            joint_diff (dict): Group key -> list of ``(combo, delta)`` pairs, where
                ``combo`` holds one value per variable of the group. ``delta`` is
                not used when building the prompt.

        Returns:
            dict: A dictionary mapping each group key of ``joint_diff`` to one guide
            prompt string (the formatted template followed by ``"; "``).

        Raises:
            KeyError: If a group key of ``joint_diff`` is missing from ``variable_groups``.
        """
        guides: Dict[str, str] = {}
        for grp_key, combos in joint_diff.items():
            vars_in_grp = variable_groups[grp_key]
            # zip() caps each combo at the number of variables in the group.
            rendered = [
                "[" + ", ".join(f"{val}" for _, val in zip(vars_in_grp, combo)) + "]"
                for combo, _ in combos
            ]
            body = GREEDY_COPULA_GUIDE_PROMPT.format(
                var_comb=vars_in_grp, comb=", ".join(rendered)
            )
            guides[grp_key] = f"{body}; "

        return guides
