"""Cluster visualization & dashboard module.
Generates static plots, interactive dashboards, and enriched captions for cluster analysis.
"""
__author__ = 'XYZ'


import json

from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
import plotly.io as pio

from sklearn.manifold import TSNE

# Module-wide matplotlib styling, applied once when this module is imported.
plt.rcParams.update({
  # Open frame: hide the top and right spines.
  **dict.fromkeys(("axes.spines.top", "axes.spines.right"), False),
  # One dark-grey tone for axes, ticks, legend and text.
  **dict.fromkeys(
    (
      "axes.edgecolor",
      "axes.labelcolor",
      "xtick.color",
      "ytick.color",
      "legend.edgecolor",
      "legend.labelcolor",
      "text.color",
    ),
    "#333",
  ),
  "axes.titlesize": 12,  # safe default
})


def _plot_clusters(entries, cluster_ids, out_dir):
  """Save a sample-index vs. confidence scatter plot coloured by cluster.

  Args:
    entries: Records providing "conf" and "idx" values.
    cluster_ids: Values used to colour the points (passed as ``c=``).
    out_dir: Directory object supporting the ``/`` operator; the image is
      written to ``out_dir / "scatter.png"``.

  Returns:
    The path of the written PNG file.
  """
  y_vals = [rec["conf"] for rec in entries]
  x_vals = [rec["idx"] for rec in entries]

  # Explicit Figure/Axes handles instead of the implicit pyplot state.
  fig, ax = plt.subplots(figsize=(6, 4))
  points = ax.scatter(x_vals, y_vals, c=cluster_ids, cmap="tab10", alpha=0.7)
  fig.colorbar(points, ax=ax, label="Cluster ID")
  ax.set_xlabel("Sample Index")
  ax.set_ylabel("Confidence")

  target = out_dir / "scatter.png"
  fig.savefig(target, dpi=300, bbox_inches="tight")
  plt.close(fig)
  return target


def _plot_histogram(entries, cluster_ids, out_file):
  """Histogram of confidence distributions per cluster (sorted by cluster ID).

  Draws one stacked histogram (20 bins) with a legend entry per cluster and
  saves it to ``out_file``.

  Args:
    entries: Records providing a "conf" value.
    cluster_ids: Cluster ID for each entry, aligned with ``entries`` by
      position.
    out_file: Destination image path handed to ``plt.savefig``.

  Returns:
    ``out_file``, unchanged.
  """
  unique_clusters = sorted(set(cluster_ids))

  # Pre-seed the buckets in sorted order, then fill them in a single pass
  # instead of rescanning every sample once per cluster.
  buckets = {cid: [] for cid in unique_clusters}
  for entry, cid in zip(entries, cluster_ids):
    buckets[cid].append(entry["conf"])
  grouped = list(buckets.values())

  plt.figure(figsize=(6, 4))
  plt.hist(
    grouped,
    bins=20,
    stacked=True,
    label=[f"Cluster {cid}" for cid in unique_clusters],
  )
  plt.legend()
  plt.xlabel("Confidence")
  plt.ylabel("Count")
  plt.savefig(out_file, dpi=200, bbox_inches="tight")
  plt.close()
  return out_file


def _plot_tsne(logits, cluster_ids, out_file, perplexity=30):
  """2D t-SNE visualization of logits colored by cluster.

  Args:
    logits: Array-like handed to ``TSNE.fit_transform``; ``len(logits)`` is
      taken as the number of samples.
    cluster_ids: Values used to colour the embedded points.
    out_file: Destination image path handed to ``plt.savefig``.
    perplexity: Requested t-SNE perplexity. It is lowered automatically when
      the input has too few samples for the requested value.

  Returns:
    ``out_file``, unchanged.
  """
  n_samples = len(logits)
  # Recent scikit-learn versions reject perplexity >= n_samples, so clamp it
  # just below the sample count rather than failing on small inputs.
  effective_perplexity = min(perplexity, max(1, n_samples - 1))

  tsne = TSNE(n_components=2, random_state=42, perplexity=effective_perplexity)
  X_2d = tsne.fit_transform(logits)
  plt.figure(figsize=(6, 6))
  plt.scatter(X_2d[:, 0], X_2d[:, 1], c=cluster_ids, cmap="tab10", alpha=0.6)
  plt.savefig(out_file, dpi=200, bbox_inches="tight")
  plt.close()
  return out_file


def _plot_confidence_band(entries, out_file, centroid, std):
  """Scatter of confidences with a centroid line and a ±1σ band around it.

  Args:
    entries: Records providing "conf" and "idx" values.
    out_file: Destination image path handed to ``plt.savefig``.
    centroid: Y position of the dashed horizontal reference line.
    std: Half-height of the shaded band drawn around ``centroid``.

  Returns:
    ``out_file``, unchanged.
  """
  confs = [e["conf"] for e in entries]
  idxs = [e["idx"] for e in entries]

  # fill_between joins the x values in the order given. With constant bounds
  # and unsorted indices the band would only cover the range between the first
  # and last index, so span the full min..max extent explicitly.
  band_x = [min(idxs), max(idxs)] if idxs else []

  plt.figure(figsize=(6, 4))
  plt.scatter(idxs, confs, color="#1f77b4", alpha=0.6, label="Samples")
  plt.axhline(y=centroid, color="red", linestyle="--", linewidth=1.0,
              label=f"Centroid {centroid:.2f}")
  plt.fill_between(band_x, centroid - std, centroid + std,
                   color="red", alpha=0.1, label=f"±1σ ({std:.2f})")
  plt.xlabel("Sample Index")
  plt.ylabel("Confidence")
  plt.legend(frameon=False)
  plt.savefig(out_file, dpi=300, bbox_inches="tight")
  plt.close()
  return out_file



def _generate_caption(entries, cluster_ids, out_file):
  """Write a markdown caption summarising the clustering to ``out_file``.

  Args:
    entries: Records providing a "conf" value.
    cluster_ids: Cluster IDs; only the number of distinct values is used.
    out_file: Destination path for the markdown text.

  Returns:
    ``out_file``, unchanged.
  """
  cluster_count = len(set(cluster_ids))
  sample_count = len(entries)
  # Divide by at least 1 so an empty input yields 0 instead of an error.
  mean_conf = sum(rec["conf"] for rec in entries) / (sample_count or 1)

  lines = [
    "# Cluster Analysis (Confidence)",
    "",
    f"- **Samples analyzed**: {sample_count}",
    f"- **Clusters formed**: {cluster_count}",
    f"- **Average confidence**: {mean_conf:.2f}",
    "",
    # The two trailing spaces below are markdown hard line breaks.
    "**Interpretation**:  ",
    "Confidence band and scatter show how predictions distribute across levels.  ",
    "Low-confidence clusters may indicate ambiguity or noise.  ",
    "Tight high-confidence clusters suggest strong model agreement.",
  ]
  Path(out_file).write_text("\n".join(lines))
  return out_file


def _interactive_dashboard(entries, cluster_ids, out_dir):
  """Build interactive Plotly scatter pages (2D and 3D) and save them as HTML.

  Args:
    entries: Records turned into a DataFrame; the figures read the "idx",
      "conf", "gt", "path" and "pr" columns.
    cluster_ids: Values stored in the "cluster" column and used for colouring.
    out_dir: Directory object supporting the ``/`` operator.

  Returns:
    A two-item list: the 2D page path followed by the 3D page path.
  """
  frame = pd.DataFrame(entries)
  frame["cluster"] = cluster_ids

  # Both axes of the 2D figure share the same boxed, dark-grey line style.
  axis_style = {"showline": True, "linewidth": 1, "linecolor": "#333",
                "mirror": True}

  ## 2D Scatter
  flat_fig = px.scatter(
      frame,
      x="idx",
      y="conf",
      color="cluster",
      hover_data=["path", "gt", "pr"],
      title="Confidence Scatter (Interactive)",
  )
  flat_fig.update_layout(
      template="simple_white",
      xaxis=dict(axis_style),
      yaxis=dict(axis_style),
      font={"color": "#333"},
      showlegend=True,
  )
  flat_path = out_dir / "scatter.html"
  pio.write_html(flat_fig, file=flat_path, auto_open=False)

  ## 3D Scatter (z = gt for now, replace with embeddings later)
  cube_fig = px.scatter_3d(
      frame,
      x="idx",
      y="conf",
      z="gt",
      color="cluster",
      hover_data=["path", "pr"],
      title="3D Confidence Clusters",
  )
  cube_path = out_dir / "scatter3d.html"
  pio.write_html(cube_fig, file=cube_path, auto_open=False)

  return [flat_path, cube_path]


def generate_dashboard(context, **kwargs):
  """Generate static plots + dashboards + captions under <to_path>/clusters/plots/.

  Processes ``clusters/clusters.csv`` (overall), the per-correctness CSVs under
  ``clusters/analysis/{correct,incorrect}/`` and any per-class CSVs below
  them. Each usable CSV gets its own output directory containing the plots, a
  caption, the interactive pages and a ``plots.json`` index of those files.

  Args:
    context: Mapping with a "to_path" entry naming the model output directory.
    **kwargs: Accepted and ignored.

  Returns:
    The ``context`` object that was passed in.
  """
  model_out = Path(context["to_path"])
  clusters_dir = model_out / "clusters"

  def _process_csv(csv_file, out_dir):
    """Render all artifacts for one CSV; return the index dict or None if skipped."""
    if not csv_file.exists():
      return None
    try:
      df = pd.read_csv(csv_file)
    except pd.errors.EmptyDataError:
      # Zero-byte file: nothing to plot, skip it like a missing file.
      return None
    # A CSV without cluster assignments or without any rows cannot be plotted;
    # skipping it keeps one degenerate file from aborting the whole run.
    if "cluster" not in df.columns or df.empty:
      return None

    entries = df.to_dict("records")
    cluster_ids = df["cluster"].tolist()

    out_dir.mkdir(parents=True, exist_ok=True)

    scatter_file = _plot_clusters(entries, cluster_ids, out_dir)
    band_file = _plot_confidence_band(entries, out_dir / "band.png",
                                      centroid=df["conf"].mean(),
                                      std=df["conf"].std())
    hist_file = _plot_histogram(entries, cluster_ids, out_dir / "hist.png")
    caption_file = _generate_caption(entries, cluster_ids, out_dir / "band_caption.md")
    html_files = _interactive_dashboard(entries, cluster_ids, out_dir)

    # Index of everything written for this CSV, stored next to the artifacts.
    meta = {
      "scatter": str(scatter_file),
      "band_plot": str(band_file),
      "histogram": str(hist_file),
      "caption": str(caption_file),
      "interactive": [str(f) for f in html_files]
    }
    with open(out_dir / "plots.json", "w", encoding="utf-8") as f:
      json.dump(meta, f, indent=2)
    return meta

  # ---- Overall (merged CSV)
  overall_csv = clusters_dir / "clusters.csv"
  _process_csv(overall_csv, clusters_dir / "plots/overall")

  # ---- Per correctness + per-class
  for correctness in ["correct", "incorrect"]:
    split_csv = clusters_dir / "analysis" / correctness / f"clusters.{correctness}.csv"
    _process_csv(split_csv, clusters_dir / f"plots/{correctness}")

    per_class_dir = clusters_dir / "analysis" / correctness / "per_class"
    if per_class_dir.exists():
      for cls_dir in per_class_dir.iterdir():
        if not cls_dir.is_dir():
          continue
        # The class ID is the directory name and also appears in the CSV name.
        cls_id = cls_dir.name
        cls_csv = cls_dir / f"clusters.{correctness}-{cls_id}.csv"
        _process_csv(cls_csv, clusters_dir / f"plots/{correctness}/per_class/{cls_id}")

  return context
