"""Clustering module for confidence-based bucketization.

Clusters per-sample prediction confidences with an ``sklearn.cluster``
estimator. Inputs are tab-separated index files in a model output directory
(``correct.txt`` / ``incorrect.txt`` plus per-class
``<correctness>_per_class/c*.txt``). Outputs are per-cluster CSVs, a merged
``clusters.csv`` and JSON/CSV summaries under ``<model_dir>/clusters/``.
Dashboard generation is delegated to ``cluster_plots``.

Supports both standalone execution (``main`` and the CLI at the bottom of
this file) and orchestrator-based use. In the latter, the pipeline steps
(``load_predictions``, ``cluster_confidence``, ``save_clusters``,
``generate_dashboard``) are called with a ``context`` dict carrying a
``"to_path"`` key.
"""
__author__ = 'XYZ'


import argparse
import os
import json
import sys

from datetime import datetime
from pathlib import Path

import pandas as pd
import numpy as np

from sklearn import cluster as sk_cluster

from .core._log_ import logger
log = logger(__file__)  # module logger from the project's `.core._log_` factory (presumably keyed by file path — TODO confirm)


## Handle to this module object; `main` resolves the `--fn` function name on it via getattr
this = sys.modules[__name__]


def _round(val):
    return round(float(val), 2)

def _read_index_file(file_path):
  """Read index-style file (tab separated). Returns list of dicts with gt, pred, conf, line."""
  entries = []
  if not os.path.exists(file_path):
    return entries
  with open(file_path, "r") as f:
    for ln in f:
      parts = ln.strip().split("\t")
      if len(parts) < 5:
        continue
      idx, relpath, gt, pr, conf = parts
      entries.append({
        "line": ln.strip(),
        "idx": int(idx),
        "path": relpath,
        "gt": int(gt),
        "pr": int(pr),
        "conf": float(conf),
      })
  return entries


def _cluster_entries(entries, method="KMeans", **kwargs):
  """Cluster entries by their confidence score using an ``sklearn.cluster`` estimator.

  Args:
    entries: List of dicts as produced by ``_read_index_file``. Only the
      ``"conf"`` and ``"line"`` keys are used here.
    method: Name of a class in ``sklearn.cluster`` (e.g. ``KMeans``,
      ``DBSCAN``). ``"kmeans"`` in any letter case is normalized to ``KMeans``.
    **kwargs: Constructor arguments for the estimator. ``count`` is a
      backward-compatible alias for ``n_clusters``. It is applied only when
      the estimator accepts ``n_clusters`` and is dropped otherwise, so
      estimators without that parameter (e.g. DBSCAN) still work when
      callers pass the default ``count``.

  Returns:
    Dict mapping integer cluster label (as returned by ``fit_predict``) to
    the list of raw index lines in that cluster. Empty dict when ``entries``
    is empty.

  Raises:
    NotImplementedError: If ``method`` does not name a clustering estimator
      class in ``sklearn.cluster``.
  """
  import inspect

  if not entries:
    return {}

  ## One feature per sample: the prediction confidence
  X = np.array([[e["conf"]] for e in entries])

  method = method.strip()
  if method.lower() == "kmeans":
    method = "KMeans"

  ClusterCls = getattr(sk_cluster, method, None)
  if ClusterCls is None:
    raise NotImplementedError(f"Clustering method `{method}` not found in sklearn.cluster")
  ## Reject module attributes that are not estimator classes (e.g. the `k_means` function)
  if not (isinstance(ClusterCls, type) and hasattr(ClusterCls, "fit_predict")):
    raise NotImplementedError(f"`{method}` in sklearn.cluster is not a clustering estimator with fit_predict")

  ## Constructor parameter names; None if the signature cannot be introspected
  try:
    accepted = set(inspect.signature(ClusterCls).parameters)
  except (TypeError, ValueError):
    accepted = None

  ## Handle count vs n_clusters conflict (`count` is a legacy alias)
  count = kwargs.pop("count", None)
  if count is not None:
    if "n_clusters" in kwargs:
      log.warning("Both `count` and `n_clusters` provided; using `n_clusters` and ignoring `count`")
    elif accepted is None or "n_clusters" in accepted:
      kwargs["n_clusters"] = count
    else:
      log.info(f"`{method}` has no `n_clusters` parameter; ignoring `count={count}`")

  ## Auto-adjust cluster count if > samples (None is a valid value for some estimators)
  n_samples = len(X)
  n_clusters = kwargs.get("n_clusters")
  if isinstance(n_clusters, (int, np.integer)) and n_clusters > n_samples:
    log.warning(f"Reducing n_clusters from {n_clusters} to {n_samples} (only {n_samples} samples)")
    kwargs["n_clusters"] = n_samples

  clusterer = ClusterCls(**kwargs)
  cluster_ids = clusterer.fit_predict(X)

  ## Group raw lines by label; cast numpy labels to plain int for CSV/JSON friendliness
  clustered = {}
  for e, cid in zip(entries, cluster_ids):
    clustered.setdefault(int(cid), []).append(e["line"])
  return clustered


def load_predictions(context, **kwargs):
  """No-op flowplan step kept so the YAML pipeline definition resolves.

  Any keyword arguments are accepted and ignored. The incoming ``context``
  object itself is handed back unchanged.
  """
  del kwargs  # accepted for pipeline compatibility, intentionally unused
  return context


def _save_clustered_files(clustered, out_dir, prefix, correctness, class_id=None, as_csv=False):
  """Save clustered entries into structured CSV files (no txt dumps)."""
  out_dir.mkdir(parents=True, exist_ok=True)
  merged_records = []

  for cid, lines in clustered.items():
    records = []
    for ln in lines:
      parts = ln.strip().split("\t")
      if len(parts) >= 5:
        idx, relpath, gt, pr, conf = parts[:5]
        rec = {
          "idx": int(idx),
          "path": relpath,
          "gt": int(gt),
          "pr": int(pr),
          "conf": float(conf),
          "cluster": cid,
          "correctness": correctness,
          "class_id": class_id if class_id is not None else "overall",
        }
        records.append(rec)
        merged_records.append(rec)

    ## Save per-cluster CSV if requested
    if as_csv and records:
      df = pd.DataFrame(records)
      csv_file = out_dir / f"{prefix}_cluster{cid}.csv"
      df.to_csv(csv_file, index=False)

  return merged_records


def cluster_confidence(context, method="kmeans", count=4, **kwargs):
  """Cluster correct/incorrect predictions overall and per class, then save CSVs.

  For each of ``correct`` / ``incorrect``, reads ``<to_path>/<correctness>.txt``
  and any ``<to_path>/<correctness>_per_class/c*.txt`` files. Each file is
  clustered independently on its confidence scores. Per-cluster CSVs go to
  ``<to_path>/clusters/<correctness>/overall/`` and
  ``<to_path>/clusters/<correctness>/per_class/<cls>/``. All records are merged
  into ``<to_path>/clusters/clusters.csv`` for ``save_clusters``.

  Args:
    context: Mapping with a ``"to_path"`` key pointing at the model output dir.
    method: ``sklearn.cluster`` estimator name (see ``_cluster_entries``).
    count: Number of clusters, forwarded as the ``count`` alias of ``n_clusters``.
    **kwargs: Extra estimator parameters. Each is also recorded as a
      ``param_<name>`` column in the merged CSV.

  Returns:
    The unchanged ``context``.
  """
  model_out = Path(context["to_path"])
  clusters_dir = model_out / "clusters"
  clusters_dir.mkdir(exist_ok=True)

  all_records = []

  for correctness in ["correct", "incorrect"]:
    base_file = model_out / f"{correctness}.txt"
    if not base_file.exists():
      continue

    entries = _read_index_file(base_file)

    ## --- overall ---
    overall_out = clusters_dir / correctness / "overall"
    clustered = _cluster_entries(entries, method=method, count=count, **kwargs)
    all_records.extend(_save_clustered_files(
      clustered,
      overall_out,
      prefix=correctness,
      correctness=correctness,
      class_id="overall",
      as_csv=True
    ))

    ## --- per-class --- (sorted: glob order is filesystem-dependent)
    per_class_dir = clusters_dir / correctness / "per_class"
    for cls_file in sorted((model_out / f"{correctness}_per_class").glob("c*.txt")):
      cls_id = cls_file.stem  # e.g., "c0"
      cls_entries = _read_index_file(cls_file)
      clustered_cls = _cluster_entries(cls_entries, method=method, count=count, **kwargs)
      all_records.extend(_save_clustered_files(
        clustered_cls,
        per_class_dir / cls_id,  # dedicated subdir for class
        prefix=cls_id,
        correctness=correctness,
        class_id=cls_id,
        as_csv=True
      ))

  ## ---- Save merged CSV for downstream save_clusters ----
  merged_csv = clusters_dir / "clusters.csv"
  if not all_records:
    ## Nothing to merge; any clusters.csv from a previous run is left untouched
    log.warning(f"No records were clustered under {model_out}; {merged_csv} was not written")
  else:
    df_all = pd.DataFrame(all_records)
    df_all["method"] = method
    for k, v in kwargs.items():
      ## Lists/dicts (e.g. from --cluster-params JSON) cannot be broadcast into a
      ## column; store them as JSON text (default=str is deliberate: provenance label only)
      df_all[f"param_{k}"] = v if v is None or np.isscalar(v) else json.dumps(v, default=str)

    df_all.to_csv(merged_csv, index=False)

    log.info(f"Merged clusters.csv saved at {merged_csv}")

  log.info(f"Clustering completed. Results saved under {clusters_dir}")
  return context


def _cluster_stats(df_c, cid):
  """Summarize one cluster's confidence values: id, size, centroid and min/max/avg/std."""
  confs = df_c["conf"].to_numpy(dtype=float)
  return {
    "id": int(cid),
    "count": int(len(df_c)),
    ## Reported as the mean confidence of the cluster's members
    "centroid": _round(confs.mean()),
    "confidence": {
      "min": _round(confs.min()),
      "max": _round(confs.max()),
      "avg": _round(confs.mean()),
      "std": _round(confs.std()),  # population std (ddof=0), same as np.std
    },
  }


def _dump_json(path, payload):
  """Write ``payload`` to ``path`` as 2-space-indented JSON."""
  with open(path, "w") as f:
    json.dump(payload, f, indent=2)


def save_clusters(context, **kwargs):
  """Write cluster statistics derived from ``clusters/clusters.csv`` into ``clusters/analysis/``.

  Outputs (all under ``<to_path>/clusters/analysis/``):
    - ``clusters.all.{csv,json}``: every merged record plus run metadata.
    - ``<correctness>/clusters.<correctness>.{csv,json}``: the split's records,
      plus stats for its *overall* clustering. Each cluster's ``classes``
      breakdown is taken from the per-class rows, joined on ``idx``.
    - ``<correctness>/per_class/<class_id>/clusters.<correctness>-<class_id>.{csv,json}``:
      records and stats for each individual clustering run (including ``overall``).

  When per-class files exist, each sample appears twice in the merged CSV:
  once in the ``"overall"`` run and once in its per-class run. Cluster ids
  from those runs are unrelated. Sample totals and split-level stats are
  therefore computed from the ``"overall"`` rows only, falling back to all
  rows when there are none.

  Args:
    context: Mapping with a ``"to_path"`` key (the model output dir).
    **kwargs: Recorded verbatim as ``meta.params``. ``method`` fills
      ``meta.method``. If it is absent, the CSV's ``method`` column is used,
      then ``"kmeans"``.

  Returns:
    The unchanged ``context``.
  """
  clusters_dir = Path(context["to_path"]) / "clusters"
  analysis_dir = clusters_dir / "analysis"
  analysis_dir.mkdir(parents=True, exist_ok=True)

  # Load merged CSV created earlier
  merged_csv = clusters_dir / "clusters.csv"
  if not merged_csv.exists():
    log.warning(f"No merged clusters.csv found under {clusters_dir}")
    return context

  df = pd.read_csv(merged_csv)
  method = kwargs.get("method")
  if method is None:
    method = str(df["method"].iloc[0]) if "method" in df.columns and not df.empty else "kmeans"
  params = dict(kwargs)
  timestamp = datetime.now().strftime("%d%m%y_%H%M%S")

  def _meta(total, **extra):
    return {"method": method, "params": params, "timestamp": timestamp,
            "total_samples": int(total), **extra}

  def _overall_rows(frame):
    mask = frame["class_id"].astype(str) == "overall"
    return (frame[mask] if mask.any() else frame), frame[~mask]

  # ---- Save aggregate (all samples) ----
  df_overall_all, _ = _overall_rows(df)
  df.to_csv(analysis_dir / "clusters.all.csv", index=False)
  _dump_json(analysis_dir / "clusters.all.json", {
    "meta": _meta(len(df_overall_all)),
    "interpretation": "All samples (correct + incorrect) clustering summary."
  })

  # ---- Per correctness (correct, incorrect) ----
  for correctness in ["correct", "incorrect"]:
    split_dir = analysis_dir / correctness
    split_dir.mkdir(parents=True, exist_ok=True)

    df_split = df[df["correctness"] == correctness]
    if df_split.empty:
      continue

    ## Save split-level CSV (all rows) + JSON (overall clustering only)
    df_split.to_csv(split_dir / f"clusters.{correctness}.csv", index=False)

    df_overall, df_per_cls = _overall_rows(df_split)
    ## sample idx -> per-class label; assumes `idx` identifies a sample
    ## consistently across the overall and per-class index files (TODO confirm)
    idx_to_class = df_per_cls.drop_duplicates("idx").set_index("idx")["class_id"].astype(str)

    clusters_info = []
    for cid, df_c in df_overall.groupby("cluster"):
      stats = _cluster_stats(df_c, cid)
      members = df_c["idx"].map(idx_to_class).dropna().value_counts()
      stats["classes"] = {str(k): int(v) for k, v in members.items()}
      clusters_info.append(stats)

    _dump_json(split_dir / f"clusters.{correctness}.json", {
      "meta": _meta(len(df_overall)),
      "clusters": clusters_info,
      "interpretation": (
        f"{correctness.capitalize()} predictions clustered into "
        f"{len(clusters_info)} groups. "
        "Low-confidence clusters may indicate ambiguity or noise."
      )
    })

    ## ---- Per class inside correctness (each group is one clustering run) ----
    per_class_dir = split_dir / "per_class"
    per_class_dir.mkdir(parents=True, exist_ok=True)

    for cls_id, df_cls in df_split.groupby("class_id"):
      cls_id = str(cls_id)
      class_dir = per_class_dir / cls_id
      class_dir.mkdir(parents=True, exist_ok=True)

      df_cls.to_csv(class_dir / f"clusters.{correctness}-{cls_id}.csv", index=False)

      cls_clusters = [_cluster_stats(df_c, cid) for cid, df_c in df_cls.groupby("cluster")]
      _dump_json(class_dir / f"clusters.{correctness}-{cls_id}.json", {
        "meta": _meta(len(df_cls), class_id=cls_id),
        "clusters": cls_clusters,
        "interpretation": (
          f"Class {cls_id} ({correctness}) predictions clustered into "
          f"{len(cls_clusters)} groups."
        )
      })

  log.info(f"Enhanced cluster outputs saved under {analysis_dir}")
  return context


def _collect_model_dirs(base_path, include_all):
  """Return a list of model directories to process."""
  ### Normalize base directory
  base = Path(base_path)

  ### If --all is set, iterate subdirectories under base
  if include_all:
    dirs = []
    for d in base.iterdir():
      #### Consider only directories that look like model outputs
      if d.is_dir():
        has_any = (d / "correct.txt").exists() or (d / "incorrect.txt").exists()
        if has_any:
          dirs.append(d)
    return dirs

  ### Otherwise treat base as a single model directory
  return [base]


def generate_dashboard(context, **kwargs):
  """Wrapper: call ``cluster_plots.generate_dashboard`` with the same arguments.

  ``cluster_plots`` is imported at call time rather than at module import,
  presumably so plotting dependencies are only needed when a dashboard is
  requested (TODO confirm).

  Returns:
    Whatever ``cluster_plots.generate_dashboard`` returns.
  """
  from . import cluster_plots

  return cluster_plots.generate_dashboard(context, **kwargs)


def main(args):
  """Main entrypoint for standalone clustering (generic, no branching per mode).

  For each selected model directory, runs ``<--fn>(context, method=..., **params)``,
  then ``save_clusters`` and ``generate_dashboard``.

  Args:
    args: ``argparse.Namespace`` from ``parse_args``.

  Exits with status 1 in any of these cases:
    - ``--fn`` does not name a callable in this module.
    - ``--cluster-params`` is not a JSON object.
    - No model directories are found.
  """
  fn_name = getattr(args, "fn", "cluster_confidence")

  fn_cluster = getattr(this, fn_name, None)
  if not callable(fn_cluster):
    log.error(f"Clustering function `{fn_name}` not found in {__name__}")
    sys.exit(1)

  ## Parse extra clustering params (optional JSON via --cluster-params)
  params = {}
  raw_params = getattr(args, "cluster_params", None)
  if raw_params:
    try:
      params = json.loads(raw_params)
    except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
      log.error(f"Failed to parse --cluster-params JSON: {e}")
      sys.exit(1)
    if not isinstance(params, dict):
      log.error(f"--cluster-params must be a JSON object, got {type(params).__name__}")
      sys.exit(1)
    ## `method` is passed explicitly below; a duplicate would raise TypeError
    if "method" in params:
      log.warning("Ignoring `method` in --cluster-params; use --method instead")
      params.pop("method")

  ## Always pass `count` separately (backward compat)
  if getattr(args, "count", None) is not None:
    params["count"] = args.count

  model_dirs = _collect_model_dirs(args.from_path, getattr(args, "all", False))
  if not model_dirs:
    log.error(f"No valid model directories found under: {args.from_path}")
    sys.exit(1)

  for mdir in model_dirs:
    context = {"to_path": str(mdir)}
    fn_cluster(context, method=args.method, **params)
    ## Forward params so the analysis JSON `meta.params` records what was used
    save_clusters(context, method=args.method, **params)
    generate_dashboard(context, method=args.method)


def parse_args():
  """Build the standalone-clustering CLI parser and parse ``sys.argv``.

  Returns:
    ``argparse.Namespace`` with ``from_path``, ``to_path``, ``method``,
    ``fn``, ``cluster_params``, ``count`` and ``all``.
  """
  cli = argparse.ArgumentParser(description="Clustering module")
  add = cli.add_argument

  ## Input / output locations
  add('--from', dest='from_path', required=True,
      help="Path to inference outputs (model dir or base dir with multiple models)")
  add('--to', dest='to_path',
      help="Optional output directory (unused in orchestrator mode, clusters go inside model dirs)")

  ## Clustering backend configuration
  add('--method', type=str, default="KMeans",
      help="Clustering method from sklearn.cluster (e.g., KMeans, DBSCAN)")
  add('--fn', type=str, default='cluster_confidence',
      help="Backend function to call (e.g., cluster_confidence)")
  add('--cluster-params', type=str,
      help="JSON string of extra sklearn cluster params, e.g. '{\"n_clusters\":6}'")
  add('--count', type=int, default=4, help="Number of clusters")

  ## Batch mode
  add('--all', action='store_true',
      help="If set, treat --from as base dir containing multiple model subdirs")

  return cli.parse_args()


def print_args(args):
  """Echo each parsed option as ``name: value`` lines under an ``Arguments:`` header."""
  lines = ["Arguments:"] + [f"{name}: {value}" for name, value in vars(args).items()]
  print("\n".join(lines))


if __name__ == "__main__":
  ## Script entry: parse the CLI, echo the options, then run the pipeline
  cli_args = parse_args()
  print_args(cli_args)
  main(cli_args)
