# -*- coding: utf-8 -*-
"""
LLMSynthor: A framework for synthesizing micro-records aligned with macro-statistics.

This script defines the main controller class `LLMSynthor` which orchestrates the
entire data synthesis process. It iteratively interacts with a Large Language Model (LLM)
to generate synthetic data that progressively matches target statistical distributions.
"""

from __future__ import annotations

import os
import json
import copy
from typing import Any, Dict, List, Sequence, Tuple

import pandas as pd
from tqdm import tqdm

from utils.args_utils import get_args
from utils.data import get_data, read_json, write_json
from utils.funcs import (
    analyze_bin_freqs, assign_ctrl_vars, df2json,
    check_marginal_data, check_plan_data,
    check_guide_data, check_joint_data,
    parse_batch_plan, parse_json, sample_from_plan,
    categorical_marginal_table, numerical_marginal_table,
    get_topk_gaps, topk_joint_diffs,
)
from utils.lmclient import _create_llm
from utils.prompt_builder import PromptBuilder
from utils.eval import Evaluator


# Names exported by a star-import of this module.
__all__ = ["LLMSynthor", "main"]


class LLMSynthor:
    """
    Controller for Large Language Model-based Synthetic Data Generation.

    This class manages an iterative process of generating synthetic data records.
    In each iteration, it analyzes the statistical discrepancy between the current
    synthetic dataset and the real data, prompts an LLM to generate a plan for
    a new batch of data that corrects this discrepancy, and then materializes
    the data from that plan.

    Parameters
    ----------
    args   : Namespace
        Parsed command-line arguments containing configuration for the run.
    data   : pd.DataFrame
        The real-world dataset used for statistical guidance and evaluation.
    config : dict
        A dictionary describing the variables (columns) in the data, including
        their types, categories, or ranges.
    llm    : object
        An LLM client instance used for generating responses. It must expose
        ``generate_response(prompts)`` returning one response per prompt.
    """
    DEFAULT_RETRIES = 10

    # region: Initialization & Setup
    # =================================================================================

    def __init__(self, args, data: pd.DataFrame, config: Dict[str, Any], llm):
        self.args = args
        self.data = data
        self.config = config
        self.llm = llm

        # --- Core settings from args (with fallbacks when an option is absent) ---
        self.max_retries: int = getattr(args, "max_retries", self.DEFAULT_RETRIES)
        self.debug: bool = getattr(args, "debug", False)
        self.batch_mode: bool = getattr(args, "batch_mode", False)
        self.n_plans: int = getattr(args, "n_plans", 3)
        self.batch_size: int = getattr(args, "batch_size", 20)

        # --- Paths and data containers ---
        os.makedirs(self.args.result_dir, exist_ok=True)
        self.result_path = os.path.join(self.args.result_dir, "synthetic_data.json")
        self.synthetic_data: Dict[str, Any] = {"metadata": {}, "data": {}, "result": {}}
        self.current_iter: int = 0
        # Real records keyed by the stringified DataFrame index.
        self.real_data = {str(i): r.to_dict() for i, r in data.iterrows()}

        # --- Helper components ---
        self.prompter = PromptBuilder(args, self.real_data, config)
        self.evaluator = Evaluator(config, self.args.result_dir)

        # --- Caches ---
        self.real_C: Dict[str, Any] | None = None  # Cached real control statistics

    # endregion

    # region: Utilities & State Management
    # =================================================================================

    def _dbg(self, label: str, prompt: str, resp: str, parsed: Any | None = None) -> None:
        """Prints detailed debugging information if the --debug flag is enabled."""
        if not self.debug:
            return
        print(f"\n{'='*20} DEBUG – {label} {'='*20}\nPrompt:\n{prompt}\n\nResponse:\n{resp}\n")
        if parsed is not None:
            print("Parsed value:")
            if isinstance(parsed, (dict, list)):
                # default=str: a value json cannot encode must not make a purely
                # diagnostic print raise (callers invoke this inside try blocks).
                print(json.dumps(parsed, indent=2, ensure_ascii=False, default=str))
            else:
                print(parsed)
        print("=" * 60 + "\n")

    def _save_results(self) -> None:
        """Saves the current state of synthetic_data to a JSON file (only when args.save is set)."""
        if self.args.save:
            # A deep copy is handed to the writer so it never sees the live state object.
            data_to_save = copy.deepcopy(self.synthetic_data)
            write_json(self.result_path, data_to_save)

    def _evaluate_and_save(self) -> None:
        """
        Saves the generated data, then (when args.eval is set) evaluates it
        against the real data and saves again with the evaluation results.
        """
        self._save_results()
        if self.args.eval:
            marg_res, joint_res = self.evaluator.evaluate(
                self.real_data, self.synthetic_data["data"]
            )
            self.synthetic_data["result"] = {"marginal_results": marg_res, "joint_results": joint_res}
            self._save_results()

    def _append_sample(self, idx: int, data: Dict[str, Any]) -> None:
        """Adds a new synthetic record to the dataset under the stringified index."""
        self.synthetic_data["data"][str(idx)] = data

    def _record_metadata(self, meta: Dict[str, Any]) -> None:
        """Records metadata for the current generation iteration and advances the iteration counter."""
        self.synthetic_data["metadata"][str(self.current_iter)] = meta
        self.current_iter += 1

    # endregion

    # region: Statistical Preparation
    # =================================================================================

    def prepare_control_variables(self) -> Tuple[Dict[str, Any], Dict[str, str], Dict[str, Any]]:
        """
        Calculates and prepares the statistical distributions for real and synthetic data.

        This function serves as the core of the feedback loop, providing the
        discrepancy information needed to guide the next generation step.

        Returns
        -------
        Tuple[Dict, Dict, Dict]
            - bin_freqs: Binned frequencies for numerical variables (None for
              categorical ones).
            - table_map: Per-variable tables comparing real vs. synthetic marginals.
            - dist_summary: Raw dictionaries of real vs. synthetic distributions.
        """
        # Step 1: Calculate control statistics for real data (if not cached) and synthetic data.
        if self.real_C is None:
            self.real_C = assign_ctrl_vars(self.args, df2json(self.data), self.config)
            self.synthetic_data["real_stat"] = self.real_C

        syn_json = list(self.synthetic_data["data"].values())
        syn_C = assign_ctrl_vars(self.args, syn_json, self.config)

        # Step 2: For numerical variables, compute binned frequencies to discretize them.
        bin_freqs: Dict[str, Any] = {}
        for var, desc in self.config.items():
            if "categories" in desc:  # Categorical variable
                bin_freqs[var] = None
                continue
            # Numerical variable
            real_vals = self.data[var].tolist()
            # With no synthetic records yet, the real values stand in for the
            # synthetic ones, so the first iteration reports no binned gap.
            syn_vals = [row.get(var) for row in syn_json] if syn_json else real_vals
            bin_freqs[var] = analyze_bin_freqs(
                real_vals, syn_vals, n_bins=self.args.n_bins,
                n_sub_bins=self.args.n_sub_bins, round_digits=1
            )

        # Step 3: Build marginal summary tables and distribution dictionaries.
        table_map: Dict[str, str] = {}
        dist_summary: Dict[str, Dict[str, Any]] = {"real": {}, "synthetic": {}}
        for var in self.config:
            if "categories" in self.config[var]:
                table = categorical_marginal_table(var, self.real_C[var], syn_C[var])
            else:
                table = numerical_marginal_table(var, self.real_C[var], syn_C[var], bin_freq=bin_freqs.get(var))
            table_map[var] = table
            dist_summary["real"][var] = self.real_C[var]
            dist_summary["synthetic"][var] = syn_C[var]

        return bin_freqs, table_map, dist_summary

    # endregion

    # region: LLM Interaction
    # =================================================================================

    def inference_with_retry(
        self,
        prompts: str | Sequence[str],
        check_types: str | Sequence[str] | None = None,
    ) -> Tuple[List[Any], List[str]]:
        """
        Sends prompts to the LLM and validates the response structure.

        This method acts as a unified interface for all LLM calls, handling
        JSON parsing and type-specific validation. Note: despite its name this
        method performs a single attempt; retries are handled in the calling
        method `_generate_batch`.

        Parameters
        ----------
        prompts     : str or Sequence[str]
            A single prompt or a list of prompts to send to the LLM.
        check_types : str or Sequence[str] or None
            The type of validation to perform on each response ('guide', 'batch',
            'joint', or None). A single value is applied to every prompt; any
            other value leaves the raw response unparsed.

        Returns
        -------
        Tuple[List[Any], List[str]]
            - A list of parsed data from each response.
            - A list of raw string responses from the LLM.

        Raises
        ------
        ValueError
            If a sequence of check types does not match the number of prompts.
        RuntimeError
            If the LLM returns a different number of responses than prompts, or
            if any response fails parsing/validation.
        """
        # Normalize inputs to lists for consistent processing
        prompts = [prompts] if isinstance(prompts, str) else list(prompts)
        if check_types is None or isinstance(check_types, str):
            check_types = [check_types] * len(prompts)
        else:
            check_types = list(check_types)
        if len(check_types) != len(prompts):
            raise ValueError(
                f"Got {len(check_types)} check types for {len(prompts)} prompts."
            )

        # Perform a single batch API call
        batch_resps: List[str] = list(self.llm.generate_response(prompts))
        if len(batch_resps) != len(prompts):
            # Without this check missing responses would surface as silent None results.
            raise RuntimeError(
                f"LLM returned {len(batch_resps)} responses for {len(prompts)} prompts."
            )

        results_data: List[Any] = [None] * len(prompts)
        results_resp: List[str | None] = [None] * len(prompts)
        errors: List[Tuple[int, Exception]] = []

        for idx, (prompt, ctype, resp) in enumerate(zip(prompts, check_types, batch_resps)):
            try:
                # Parse and validate the response based on its expected type
                parsed = parse_json(resp) if ctype in {"guide", "batch", "joint"} else resp
                if ctype == "guide":
                    valid, parsed = check_guide_data(parsed)
                    if not valid:
                        raise ValueError("Invalid guide data structure")
                elif ctype == "batch":
                    valid, parsed = check_plan_data(parsed, self.config)
                    if not valid:
                        raise ValueError("Invalid batch plan data structure")
                elif ctype == "joint":
                    valid, parsed = check_joint_data(parsed, self.config)
                    if not valid:
                        raise ValueError("Invalid joint data structure")

                self._dbg("LLM-CALL", prompt, resp, parsed)
                results_data[idx] = parsed
                results_resp[idx] = resp
            except Exception as exc:
                # Collect every failure so the raised error reports all of them at once.
                errors.append((idx, exc))

        if errors:
            failed_prompts = ", ".join(f"idx={i}: {e}" for i, e in errors)
            # Chain the first underlying error so its traceback is not lost.
            raise RuntimeError(
                f"LLM call failed for prompts with errors: {failed_prompts}"
            ) from errors[0][1]

        return results_data, results_resp

    # endregion

    # region: Guidance Generation
    # =================================================================================

    def process_marginal(
        self,
        var_desc: Dict[str, Any],
        dist_summary: Dict[str, Any],
        var_bin_freq: Dict[str, Any],
    ) -> Tuple[str, Dict[str, str], Dict[str, str]]:
        """
        Generates marginal guidance by identifying the largest discrepancies.

        This method greedily finds the category or bin for each variable where the
        synthetic data is most deficient compared to the real data and creates
        a textual guide for the LLM to focus on generating records with these features.
        A variable with no deficient bin/category (or with no real counts at all)
        contributes an empty guide.

        Parameters
        ----------
        var_desc     : dict
            Mapping keyed by variable name; only its keys are used here.
        dist_summary : dict
            ``{"real": {...}, "synthetic": {...}}`` per-variable distributions;
            read as count dictionaries for variables without binned frequencies.
        var_bin_freq : dict
            Per-variable binned frequency info (falsy for categorical variables).

        Returns
        -------
        Tuple[str, Dict, Dict]
            - A concatenated string of all marginal control guides.
            - A dictionary of prompts (currently unused, for future extension).
            - A dictionary of responses (currently unused, for future extension).
        """
        C, P, R = {}, {}, {}  # Control, Prompt, Response dictionaries

        for v in var_desc:
            freq_info = var_bin_freq.get(v)
            if freq_info:
                # Case 1: Numerical variable (using binned frequencies)
                real_f, syn_f = freq_info["full_real_freq"], freq_info["full_synthetic_freq"]
                gaps = {b: real_f[b] - syn_f.get(b, 0) for b in real_f}
            else:
                # Case 2: Categorical variable (using counts turned into proportions)
                r_cnt, s_cnt = dist_summary["real"][v], dist_summary["synthetic"][v]
                total_r = sum(r_cnt.values())
                if total_r:
                    # With no synthetic records every synthetic share is 0, so any
                    # non-zero denominator works; reuse the real total.
                    total_s = sum(s_cnt.values()) or total_r
                    gaps = {c: (r_cnt[c] / total_r) - (s_cnt.get(c, 0) / total_s) for c in r_cnt}
                else:
                    # No real observations: nothing to steer towards (and nothing to divide by).
                    gaps = {}

            # Keep only bins/categories where synthetic data falls short of the real data.
            deficits = {key: gap for key, gap in gaps.items() if gap > 0}
            if deficits:
                target = max(deficits, key=deficits.get)
                C[v] = self.prompter.build_greedy_marginal_guide(v, target)
            else:
                C[v] = ""
            P[v], R[v] = "", ""  # Placeholder for future prompting strategies

        return "".join(C.values()), P, R

    def process_joint(
        self,
        dist_summary: Dict[str, Any],
        var_bin_freq: Dict[str, Any]
    ) -> Tuple[str, Dict[str, Any]]:
        """
        Generates joint distribution guidance for the LLM.

        This involves a two-step process:
        1.  Copula Inference: Ask the LLM to identify groups of variables that are
            likely to be correlated.
        2.  Discrepancy Guidance: Find the specific combinations of values within
            these groups that are underrepresented in the synthetic data and
            create a textual guide.

        Note: ``dist_summary`` is accepted for interface symmetry with
        `process_marginal` but is not read by this method.

        Returns
        -------
        Tuple[str, Dict]
            - A string containing the joint distribution control guide.
            - A metadata dictionary logging the process.
        """
        # Step 1: Infer correlated variable groups from the LLM.
        copula_inf_p = self.prompter.build_copula_inference()
        [variable_groups], _ = self.inference_with_retry([copula_inf_p], ["joint"])

        # Step 2: Identify top-k joint discrepancies for these groups.
        joint_diff = topk_joint_diffs(
            self.real_data, self.synthetic_data["data"],
            variable_groups, var_bin_freq, k=self.args.topk_diff,
        )

        # Step 3: Build a textual guide based on these discrepancies.
        joint_C_str = self.prompter.build_greedy_copula_guide(variable_groups, joint_diff)
        joint_C = "".join(joint_C_str.values())

        meta = {
            "copula_inference_prompt": copula_inf_p,
            "copula_inference_response": variable_groups,
            "joint_C": joint_C,
        }
        return joint_C, meta

    # endregion

    # region: Core Generation Loop
    # =================================================================================

    def _generate_batch(
        self,
        start_idx: int,
        var_desc: Dict[str, Any],
        dist_summary: Dict[str, Any],
        var_bin_freq: Dict[str, Any],
    ) -> int:
        """
        Generates one batch of synthetic data.

        This method orchestrates the core logic for a single iteration:
        1.  Build marginal and joint guidance based on current statistics.
        2.  Prompt the LLM with this guidance to get a generation "plan".
        3.  Materialize synthetic records by sampling from the plan.
        4.  Record all metadata for the iteration.

        Returns
        -------
        int
            The index for the next sample to be generated.

        Raises
        ------
        ValueError
            If the computed batch size is not positive.
        Exception
            The error of the last attempt, once all retry attempts have failed.
        """
        def _apply_plans(
            start_idx: int, plans: List[Tuple[Dict[str, Any], int]], batch_target: int
        ) -> Tuple[int, List[Dict[str, Any]]]:
            """Materializes records from the LLM's plan and returns the next index."""
            idx, generated, ranges = start_idx, 0, []
            for plan_dict, num_to_generate in plans:
                segment_start_idx = idx
                for _ in range(num_to_generate):
                    # Never exceed the global sample budget or this batch's target.
                    if idx >= self.args.n_samples or generated >= batch_target:
                        break
                    self._append_sample(idx, sample_from_plan(plan_dict, self.config, self.data))
                    idx += 1
                    generated += 1
                ranges.append({"plan": plan_dict, "plan_num": num_to_generate, "idx_range": f"{segment_start_idx}_{idx - 1}"})
                if idx >= self.args.n_samples or generated >= batch_target:
                    break
            return idx, ranges

        # Step 1a: Marginal guidance and the batch size depend only on the current
        # statistics, not on the LLM, so they are computed once outside the retry loop.
        marg_C, marg_P, marg_R = self.process_marginal(var_desc, dist_summary, var_bin_freq)
        marginal_only = self.args.run_type == "marginal"
        total_remaining = self.args.n_samples - len(self.synthetic_data["data"])
        batch_target = min(self.batch_size, total_remaining)
        if batch_target <= 0:
            # A non-positive target can never produce a record; fail loudly
            # instead of spending LLM calls on an empty batch.
            raise ValueError(f"Nothing to generate: batch target is {batch_target}.")

        # --- LLM-dependent steps with retry mechanism ---
        # At least one attempt is always made, even if max_retries is configured as <= 0.
        attempts = max(1, self.max_retries)
        for attempt in range(1, attempts + 1):
            try:
                # Step 1b: Build joint guidance (requires an LLM call)
                joint_C, joint_meta = None, {}
                if not marginal_only:
                    joint_C, joint_meta = self.process_joint(dist_summary, var_bin_freq)

                # Step 2: Prompt LLM for a generation plan
                gen_prompt = self.prompter.build_batch_generation_prompt(
                    joint_C, marg_C, self.n_plans, batch_target, marginal_only
                )
                [plan_json], [raw_plan_text] = self.inference_with_retry([gen_prompt], ["batch"])
                plans = parse_batch_plan(plan_json)
                if not plans:
                    raise ValueError("LLM returned an empty plan list.")
                if not any(count > 0 for _, count in plans):
                    # A plan that requests zero records would make no progress.
                    raise ValueError("LLM plan requests no records.")
                break  # Success, exit retry loop
            except Exception as exc:
                print(f"[Batch Generation Attempt {attempt}/{attempts}] Failed: {exc!r}")
                if attempt == attempts:
                    raise  # Rethrow exception if all retries fail

        self._dbg("BATCH-PLAN", gen_prompt, raw_plan_text, plan_json)

        # Step 3: Materialize samples from the generated plan
        next_idx, plan_ranges = _apply_plans(start_idx, plans, batch_target)

        # Step 4: Record metadata for this iteration
        meta = {
            "marginal": {"var_desc": var_desc, "control": marg_C, "prompt": marg_P, "response": marg_R},
            "joint": {"var_stat": dist_summary, "control": joint_C, **joint_meta},
            "full": {"prompt": gen_prompt, "response": raw_plan_text, "idx_range": f"{start_idx}_{next_idx - 1}", "plans": plan_ranges},
        }
        self._record_metadata(meta)
        return next_idx

    def run(self) -> None:
        """
        Main entry point to start the data synthesis process.

        This method initializes the generation loop, handles resuming from a previous
        run, and orchestrates the overall progress until the desired number of samples
        is generated.

        Raises
        ------
        RuntimeError
            If an iteration finishes without adding any sample (guards against
            looping forever).
        """
        last_idx = 0
        # Check if a previous run exists and load it to resume generation.
        if os.path.exists(self.result_path):
            self.synthetic_data = read_json(self.result_path)
            # Continue right after the highest stored sample index (0 when no samples are stored).
            last_idx = max(map(int, self.synthetic_data["data"].keys()), default=-1) + 1
            self.current_iter = len(self.synthetic_data["metadata"])
            if last_idx >= self.args.n_samples:
                print("Target number of samples already generated. Exiting.")
                if self.args.eval:
                    self._evaluate_and_save()
                return

        # --- Main generation loop with progress bar ---
        with tqdm(total=self.args.n_samples, initial=last_idx, desc="Overall generation progress", unit="samples") as pbar:
            while last_idx < self.args.n_samples:
                print(f"\n--- Iteration {self.current_iter}: Generating batch starting at index {last_idx} ---")

                # Prepare statistical feedback for the LLM
                bin_freqs, table_map, dist_summary = self.prepare_control_variables()

                prev_idx = last_idx
                # The per-variable table map is what `_generate_batch` receives as `var_desc`.
                last_idx = self._generate_batch(
                    last_idx, table_map, dist_summary, bin_freqs
                )
                if last_idx <= prev_idx:
                    raise RuntimeError(
                        f"Iteration produced no new samples at index {prev_idx}; aborting to avoid an endless loop."
                    )

                pbar.update(last_idx - prev_idx)
                self._evaluate_and_save()

        print("Synthetic data generation finished successfully.")

    # endregion


def main() -> None:
    """
    Command-line entry point.

    Parses the CLI arguments, loads the dataset together with its variable
    configuration, builds the LLM client, and runs an `LLMSynthor` over them.
    """
    cli_args = get_args()
    dataset, var_config = get_data(cli_args)
    LLMSynthor(cli_args, dataset, var_config, _create_llm(cli_args)).run()


if __name__ == "__main__":
    main()