"""
Prepare 100-driver dataset downsized to sf3d-equivalent taxonomy (nc=10),
with clear stats & plots comparing:
  - Original 100-driver distribution (projected to sf3d) vs
  - Re-mapped (post-filter/balance) sf3d distribution
For: total, per-split, and per-class.

Outputs:
- New split files (sf3d nc=10)
- indices/*.csv (per-split + all)
- plots/*.png (per-split + total): per-class bars, vehicle bars, donuts, side-by-side comparisons
- summary.json (includes original and remapped stats)
- mosaic.png (one sample per sf3d class)
"""
__author__ = 'XYZ'


import csv
import json
import logging
import os
import random

from collections import Counter, defaultdict
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt

from PIL import Image

## Configure the root logger when this module is loaded so INFO-level
## messages are emitted; `log` is the module-level logger used below.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)


## ----------------------------------------------------------
## Mapping utilities
## ----------------------------------------------------------
def load_label_map(label_map_path: str):
  """
  Load the 100-driver → sf3d label mapping from a JSON file.

  The JSON document must provide three top-level keys:
    - "forward": 100-driver class → sf3d class
    - "reverse": sf3d class → list of 100-driver classes
    - "labels":  sf3d class → human-readable name

  Args:
    label_map_path: path of the JSON mapping file.

  Returns:
    Tuple (class_mapping, sf3d_sources, sf3d_labels, sf3d_order) holding the
    "forward", "reverse" and "labels" objects plus the keys of "labels"
    sorted by the integer that follows their first character
    (so "c2" sorts before "c10").

  Raises:
    KeyError: if one of the three top-level keys is missing.
    ValueError: if a "labels" key has no integer after its first character.
  """
  ## JSON text is UTF-8; do not depend on the platform's default encoding.
  with open(label_map_path, "r", encoding="utf-8") as f:
    label_map = json.load(f)

  class_mapping = label_map["forward"]
  sf3d_sources = label_map["reverse"]
  sf3d_labels = label_map["labels"]

  ## Numeric (not lexicographic) ordering of the sf3d class keys.
  sf3d_order = sorted(sf3d_labels, key=lambda key: int(key[1:]))

  return class_mapping, sf3d_sources, sf3d_labels, sf3d_order


## --------------------------
## helpers
## --------------------------
def parse_split_file(split_file):
  """
  Read a split file and return one token list per non-blank line.

  Every line is split on whitespace; blank / whitespace-only lines are
  skipped. Callers in this module unpack each row as
  (id, rel_path, class_id_str).
  """
  rows = []
  with open(split_file, "r") as handle:
    for raw_line in handle:
      tokens = raw_line.split()
      if tokens:
        rows.append(tokens)
  return rows

def filter_and_remap(entries, class_mapping):
  """
  Keep only entries whose original class is mapped, and attach the new label.

  The original class is the first "/"-separated component of the relative
  path. Returns a list of (id, rel_path, mapped_label, original_class)
  tuples in input order; entries with an unmapped class are dropped.
  """
  remapped = []
  for sample_id, rel_path, _ in entries:
    source_cls = rel_path.partition("/")[0]
    if source_cls not in class_mapping:
      continue
    remapped.append((sample_id, rel_path, class_mapping[source_cls], source_cls))
  return remapped

def write_split_file(sampled, out_path, label_to_numeric):
  """
  Write a tab-separated split file: "<row index>\t<rel_path>\t<numeric id>".

  Rows are renumbered from 0 in the order of `sampled`; the numeric id is
  looked up from each entry's sf3d label via `label_to_numeric`.
  """
  with open(out_path, "w") as sink:
    for row_no, record in enumerate(sampled):
      _, rel_path, sf3d_label, _ = record
      numeric_id = label_to_numeric[sf3d_label]
      sink.write(f"{row_no}\t{rel_path}\t{numeric_id}\n")

def write_index_csv(sampled, out_path, label_to_numeric):
  """
  Write a CSV index with one row per sampled entry.

  Columns: index (renumbered from 0), rel_path, class_id (numeric id looked
  up from the sf3d label), sf3d_label, orig_label.
  """
  header = ["index", "rel_path", "class_id", "sf3d_label", "orig_label"]
  with open(out_path, "w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(header)
    writer.writerows(
      [pos, rel, label_to_numeric[sf3d_label], sf3d_label, orig_cls]
      for pos, (_, rel, sf3d_label, orig_cls) in enumerate(sampled)
    )

## ----------------------------------------------------------
## Statistics
## ----------------------------------------------------------
def collect_original_stats(entries):
  """
  Summarise the original split entries.

  The class is the first "/"-separated component of each relative path.
  For file names that start with "P" and contain at least three
  "_"-separated tokens, the first three tokens are recorded as driver,
  vehicle and scene ids respectively.

  Returns a dict with total_images, total_drivers, total_vehicles,
  total_scenes (distinct counts) and per_class (class → image count).
  """
  drivers, vehicles, scenes = set(), set(), set()
  per_class = Counter()
  n_images = 0

  for _, rel_path, _ in entries:
    path_parts = rel_path.split("/")
    n_images += 1
    per_class[path_parts[0]] += 1

    fname = path_parts[-1]
    if not fname.startswith("P"):
      continue
    tokens = fname.split("_")
    if len(tokens) < 3:
      continue
    drivers.add(tokens[0])
    vehicles.add(tokens[1])
    scenes.add(tokens[2])

  return {
    "total_images": n_images,
    "total_drivers": len(drivers),
    "total_vehicles": len(vehicles),
    "total_scenes": len(scenes),
    "per_class": dict(per_class),
  }

## ----------------------------------------------------------
## Balancing
## ----------------------------------------------------------
def balance_grouped_counts(grouped):
  """
  Cap every source subgroup of a multi-source class at the average size.

  Input: dict[sf3d_cls] -> list[(idx, rel, sf3d_label, orig_cls)]
  Classes fed by a single original class are passed through untouched.
  Otherwise each original-class subgroup is randomly downsampled (global
  `random` state) to at most floor(total / number_of_subgroups) items.

  NOTE(review): a function with this same name is defined again later in
  this module; that later definition is the one bound at import time.
  """
  result = {}
  for target_cls, members in grouped.items():
    by_source = defaultdict(list)
    for member in members:
      by_source[member[3]].append(member)

    if len(by_source) == 1:
      result[target_cls] = members
      continue

    cap = sum(map(len, by_source.values())) // len(by_source)
    kept = []
    for bucket in by_source.values():
      kept += random.sample(bucket, min(len(bucket), cap))
    result[target_cls] = kept
  return result

## --------------------------
## visual style
## --------------------------
def style_axis(ax):
  """Apply the shared look: no top/right spines, dark-grey (#333) frame and ticks."""
  for hidden_side in ("top", "right"):
    ax.spines[hidden_side].set_visible(False)
  for visible_side in ("left", "bottom"):
    ax.spines[visible_side].set_color("#333")
  ax.tick_params(colors="#333")

def create_bar_plot(data, xlabel, ylabel, save_path, title=None):
  """
  Draw a single-series bar chart and save it to `save_path` (dpi=300).

  Keys of `data` go on the x-axis in sorted order, values on the y-axis.
  Each bar is annotated with its value (floats rounded to 2 decimals).
  The figure is closed after saving.
  """
  plt.figure(figsize=(10, 3))
  ax = plt.gca()

  ordered_keys = sorted(data)
  heights = [data[key] for key in ordered_keys]

  rects = ax.bar(ordered_keys, heights, color="#333", width=0.5)
  for rect, value in zip(rects, heights):
    if isinstance(value, float):
      caption = str(round(value, 2))
    else:
      caption = str(value)
    ## value label centred on top of the bar
    ax.text(rect.get_x() + rect.get_width()/2, rect.get_height(), caption,
            ha="center", va="bottom", fontsize=7, color="#333")

  if title:
    ax.set_title(title, color="#333", pad=6)
  ax.set_xlabel(xlabel, color="#333")
  ax.set_ylabel(ylabel, color="#333")
  style_axis(ax)

  plt.tight_layout()
  plt.savefig(save_path, dpi=300)
  plt.close()

def create_side_by_side_bars(series_a, series_b, x_order, labels, save_path, title=None, ylabel="Count"):
  """
  Draw two series as paired bars over a shared x-axis and save the figure.

  - series_a / series_b: mappings key → value (missing keys count as 0)
  - x_order: list of keys fixing the left-to-right order
  - labels: (legend_a, legend_b)
  The figure is written to `save_path` at dpi=300 and then closed.
  """
  heights_a = [series_a.get(key, 0) for key in x_order]
  heights_b = [series_b.get(key, 0) for key in x_order]

  positions = list(range(len(x_order)))
  bar_w = 0.38
  half_w = bar_w/2

  plt.figure(figsize=(11, 3.2))
  ax = plt.gca()
  ## series A sits left of each tick (lighter), series B right (darker)
  ax.bar([pos - half_w for pos in positions], heights_a,
         width=bar_w, color="#333", alpha=0.55, label=labels[0])
  ax.bar([pos + half_w for pos in positions], heights_b,
         width=bar_w, color="#333", alpha=0.9, label=labels[1])

  ax.set_xticks(positions)
  ax.set_xticklabels(x_order, rotation=0, color="#333")
  ax.set_ylabel(ylabel, color="#333")
  if title:
    ax.set_title(title, color="#333", pad=6)
  ax.legend(frameon=False)
  style_axis(ax)

  plt.tight_layout()
  plt.savefig(save_path, dpi=300)
  plt.close()

def create_vehicle_plot(entries, save_path):
  """
  Plot the number of images per vehicle id.

  The vehicle id is the second "_"-separated token of file names that start
  with "P"; other file names are ignored.
  """
  per_vehicle = Counter()
  for _, rel_path, _, _ in entries:
    fname = rel_path.rsplit("/", 1)[-1]
    if not fname.startswith("P"):
      continue
    tokens = fname.split("_")
    if len(tokens) >= 2:
      per_vehicle[tokens[1]] += 1
  create_bar_plot(per_vehicle, "Vehicle", "#Images", save_path, title="Per-Vehicle Distribution")

def create_mosaic(sampled, base_dir, output_dir, sf3d_labels, sf3d_order, img_w=150, img_h=150):
  """
  Save "mosaic.png": one randomly chosen sample image per sf3d class.

  Args:
    sampled: iterable of (idx, rel_path, sf3d_label, orig_cls) tuples.
    base_dir: directory the relative image paths are joined onto.
    output_dir: directory that receives mosaic.png (created if missing).
    sf3d_labels: sf3d class → human-readable name, used in tile titles.
    sf3d_order: sf3d classes in display order (row-major, 5 per row).
    img_w, img_h: size each sample is resized to before drawing.

  Classes without samples, and images that cannot be loaded (logged as a
  warning), leave a blank cell. Samples are drawn with the global `random`
  state, one draw per non-empty class in `sf3d_order` order.
  """
  by_class = defaultdict(list)
  for _, rel, sf3d_label, _ in sampled:
    by_class[sf3d_label].append(rel)

  ## 5 columns; keep the 5x2 layout for up to 10 classes and add rows beyond
  ## that instead of indexing past the grid.
  n_cols = 5
  n_rows = max(2, -(-len(sf3d_order) // n_cols))
  fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols*2.5, n_rows*2.2),
                          gridspec_kw={'wspace':0.05, 'hspace':0.2},
                          squeeze=False)
  try:
    ## Hide every cell up front so unused cells never show an empty frame.
    for ax in axs.flat:
      ax.axis("off")

    for i, sf3d_cls in enumerate(sf3d_order):
      candidates = by_class.get(sf3d_cls)
      if not candidates:
        continue
      row, col = divmod(i, n_cols)
      ax = axs[row, col]
      rel = random.choice(candidates)
      img_path = os.path.join(base_dir, rel)
      try:
        ## resize() returns a new image, so the source handle can be
        ## released as soon as the tile exists.
        with Image.open(img_path) as img:
          tile = img.resize((img_w, img_h))
        ax.imshow(tile)
        ax.set_title(f"{sf3d_cls}: {sf3d_labels[sf3d_cls]}", fontsize=7, pad=3, color="#111")
      except Exception as e:
        ## best-effort: a broken sample must not abort the whole mosaic
        log.warning(f"Could not load {img_path}: {e}")

    fig.tight_layout()
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    fig.savefig(os.path.join(output_dir, "mosaic.png"), dpi=300, bbox_inches="tight")
  finally:
    ## always release the figure, even if saving fails
    plt.close(fig)

def create_split_donut(splits_summary, save_path):
  """
  Save a donut chart of split sizes.

  `splits_summary` maps split name → dict with a "size" entry; each wedge is
  labelled with the split name and its percentage of the total.
  """
  split_names = list(splits_summary.keys())
  split_sizes = [splits_summary[name]["size"] for name in split_names]
  palette = ["#1f77b4", "#ff7f0e", "#2ca02c"]

  _, ax = plt.subplots(figsize=(4, 4))
  ## wedge width < 1 leaves the hole that turns the pie into a donut
  ax.pie(split_sizes, labels=split_names, autopct="%1.1f%%", startangle=90,
         colors=palette, wedgeprops={"width": 0.35})
  plt.tight_layout()
  plt.savefig(save_path, dpi=300)
  plt.close()

## --------------------------
## balancing logic (unchanged for data build; duplicates the earlier "Balancing"
## section -- this later definition is the one in effect)
## --------------------------
def balance_grouped_counts(grouped):
  """
  Downsample the source subgroups of each multi-source sf3d class.

  Input: dict[sf3d_cls] -> list[(idx, rel, sf3d_label, orig_cls)]
  Output: dict with the same keys. A class whose samples all share one
  orig_cls keeps its list as-is; otherwise every orig_cls subgroup is
  reduced with random.sample to at most the integer average subgroup size.

  NOTE(review): this redefines a same-named function that appears earlier
  in this module; being later, this is the definition that takes effect.
  """
  out = {}
  for cls_key, items in grouped.items():
    buckets = {}
    for entry in items:
      buckets.setdefault(entry[3], []).append(entry)

    n_sources = len(buckets)
    if n_sources != 1:
      target = sum(len(bucket) for bucket in buckets.values()) // n_sources
      picked = []
      for bucket in buckets.values():
        take = min(len(bucket), target)
        picked.extend(random.sample(bucket, take))
      out[cls_key] = picked
    else:
      out[cls_key] = items
  return out

## --------------------------
## projections & summaries
## --------------------------
def project_original_to_sf3d(orig_per_class_22, sf3d_sources, sf3d_order):
  """
  Project original per-class counts onto the sf3d classes.

  For every sf3d class in `sf3d_sources`, sums the counts of its source
  classes (missing sources count as 0). Returns (raw, weighted):
    - raw: plain sum of the source counts
    - weighted: the same sum with each count divided by the number of
      sources of that sf3d class (i.e. the per-source average)
  Both dicts are keyed by `sf3d_order`; classes without sources stay at 0.
  """
  raw_totals = dict.fromkeys(sf3d_order, 0)
  weighted_totals = dict.fromkeys(sf3d_order, 0.0)

  for target_cls in sf3d_sources:
    members = sf3d_sources[target_cls]
    share = len(members)  ## > 0 whenever the inner loop runs
    for member in members:
      count = orig_per_class_22.get(member, 0)
      raw_totals[target_cls] += count
      weighted_totals[target_cls] += count / share

  return raw_totals, weighted_totals

def counts_from_sampled(sampled, sf3d_order):
  """
  Count sampled entries per sf3d class.

  `sampled` holds (idx, rel, sf3d_label, orig_cls) tuples. The result has
  exactly the keys of `sf3d_order` (in that order); classes with no samples
  map to 0 and labels outside `sf3d_order` are left out.
  """
  tally = Counter(label for _, _, label, _ in sampled)
  return {cls: tally[cls] for cls in sf3d_order}

def add_dict(a, b):
  """
  Return a new dict with the values of `b` added onto those of `a`.

  Keys only in `b` start from 0; neither input dict is modified.
  """
  merged = dict(a)
  for key in b:
    current = merged[key] if key in merged else 0
    merged[key] = current + b[key]
  return merged

## --------------------------
## main
## --------------------------
def main(args):
  """
  Build the sf3d (nc=10) version of the 100-driver splits with stats/plots.

  Reads from `args`: from_path, sf3d_summary, train_split, val_split,
  test_split, label_map, to_path, seed.

  For each of the train/val/test splits:
    1. parse the split file and collect the original (22-class) stats
    2. filter to mapped classes, remap to sf3d and balance source subgroups
    3. write the new split file and its index CSV
    4. write per-split plots and record the split in the summary
  Afterwards writes the combined index, mosaic, split donut, TOTAL plots
  and summary.json under `to_path`.

  Raises:
    KeyError: if the sf3d summary lacks "label_to_numeric"/"classes", or a
      remapped sf3d label has no entry in "label_to_numeric".
  """
  from_path = args.from_path
  splits = {"train": args.train_split, "val": args.val_split, "test": args.test_split}

  ## Seed once: balancing and the mosaic both draw from the global RNG.
  random.seed(args.seed)

  ## JSON text is UTF-8; do not depend on the platform's default encoding.
  with open(args.sf3d_summary, "r", encoding="utf-8") as f:
    summary_dict = json.load(f)
  label_to_numeric = summary_dict["label_to_numeric"]  ## from sf3d reference summary

  ## load mapping
  class_mapping, sf3d_sources, sf3d_labels, sf3d_order = load_label_map(args.label_map)

  out_dir = Path(args.to_path)
  indices_dir = out_dir/"indices"
  plots_dir = out_dir/"plots"
  for d in (out_dir, indices_dir, plots_dir):
    d.mkdir(parents=True, exist_ok=True)

  new_summary = {
    "name": "100-driver-sf3d",
    "classes": summary_dict["classes"],
    "splits": {},
    "mapping_used": {
      "forward": class_mapping,
      "reverse": sf3d_sources,
      "labels": sf3d_labels,
    },
  }

  ## aggregators (total over splits)
  total_orig_22 = Counter()
  total_remapped_counts = Counter()
  total_weighted_proj = Counter()
  total_raw_proj = Counter()
  aggregated_sampled = []

  for split, split_file in splits.items():
    log.info(f"Processing {split}")
    entries = parse_split_file(split_file)

    ## Original stats (22-class space)
    orig_stats = collect_original_stats(entries)
    orig22 = orig_stats["per_class"]
    total_orig_22.update(orig22)

    ## Filter + remap to sf3d, grouped by sf3d class
    grouped = defaultdict(list)
    for item in filter_and_remap(entries, class_mapping):
      grouped[item[2]].append(item)

    ## Only source subgroups within a class are balanced here; inter-class
    ## balancing (downsampling every class to the smallest one) is not applied.
    balanced = balance_grouped_counts(grouped)
    balanced_flat = [s for arr in balanced.values() for s in arr]

    ## Write new split + index
    split_out = out_dir/f"{Path(split_file).stem}_sf3d_nc10.txt"
    write_split_file(balanced_flat, split_out, label_to_numeric)
    write_index_csv(balanced_flat, indices_dir/f"{split}_index.csv", label_to_numeric)

    ## Per-class (remapped)
    remapped_counts = counts_from_sampled(balanced_flat, sf3d_order)

    ## Project original→sf3d (raw & weighted)
    raw_proj, weighted_proj = project_original_to_sf3d(orig22, sf3d_sources, sf3d_order)

    ## Per-split plots
    create_bar_plot(remapped_counts, "sf3d class", "Count",
                    plots_dir/f"{split}_per_class_remapped.png",
                    title=f"{split.upper()}: Remapped sf3d distribution")
    create_vehicle_plot(balanced_flat, plots_dir/f"{split}_per_vehicle.png")

    ## Side-by-side comparisons (original→sf3d vs remapped)
    create_side_by_side_bars(
      raw_proj, remapped_counts, sf3d_order,
      labels=("Original→sf3d (raw sum)", "Remapped (post-balance)"),
      save_path=plots_dir/f"{split}_orig_vs_remapped_raw.png",
      title=f"{split.upper()}: Original(→sf3d raw) vs Remapped"
    )
    create_side_by_side_bars(
      weighted_proj, remapped_counts, sf3d_order,
      labels=("Original→sf3d (weighted)", "Remapped (post-balance)"),
      save_path=plots_dir/f"{split}_orig_vs_remapped_weighted.png",
      title=f"{split.upper()}: Original(→sf3d weighted) vs Remapped"
    )

    ## Update summary
    new_summary["splits"][split] = {
      "file": str(split_out),
      "size": len(balanced_flat),
      "original_stats": orig_stats,               ## 22-class
      "orig_projected_sf3d_raw": raw_proj,        ## 10-class
      "orig_projected_sf3d_weighted": weighted_proj,
      "per_class_remapped": remapped_counts       ## 10-class
    }

    ## totals (across splits)
    aggregated_sampled.extend(balanced_flat)
    for k in sf3d_order:
      total_remapped_counts[k] += remapped_counts.get(k, 0)
      total_raw_proj[k] += raw_proj.get(k, 0)
      total_weighted_proj[k] += weighted_proj.get(k, 0)

  ## Write all-indices + mosaic + split donut
  write_index_csv(aggregated_sampled, indices_dir/"all_index.csv", label_to_numeric)
  create_mosaic(aggregated_sampled, from_path, str(out_dir), sf3d_labels, sf3d_order)
  create_split_donut(new_summary["splits"], plots_dir/"split_distribution.png")

  ## Totals keyed in sf3d order; built once and shared by plots and summary
  totals_raw = {k: total_raw_proj[k] for k in sf3d_order}
  totals_weighted = {k: float(total_weighted_proj[k]) for k in sf3d_order}
  totals_remapped = {k: total_remapped_counts[k] for k in sf3d_order}

  ## Total plots (across splits)
  create_side_by_side_bars(
    totals_raw, totals_remapped, sf3d_order,
    labels=("Original→sf3d (raw sum)", "Remapped (post-balance)"),
    save_path=plots_dir/"TOTAL_orig_vs_remapped_raw.png",
    title="TOTAL: Original(→sf3d raw) vs Remapped"
  )
  create_side_by_side_bars(
    totals_weighted, totals_remapped, sf3d_order,
    labels=("Original→sf3d (weighted)", "Remapped (post-balance)"),
    save_path=plots_dir/"TOTAL_orig_vs_remapped_weighted.png",
    title="TOTAL: Original(→sf3d weighted) vs Remapped"
  )
  create_bar_plot(
    totals_remapped,
    "sf3d class", "Count",
    plots_dir/"TOTAL_per_class_remapped.png",
    title="TOTAL: Remapped sf3d distribution"
  )

  ## Persist rolled-up totals in summary.json
  new_summary["totals"] = {
    "original_22_per_class": dict(total_orig_22),
    "orig_projected_sf3d_raw": totals_raw,
    "orig_projected_sf3d_weighted": totals_weighted,
    "remapped_sf3d_per_class": totals_remapped
  }

  ## Save summary
  with open(out_dir/"summary.json", "w", encoding="utf-8") as f:
    json.dump(new_summary, f, indent=2)

  print(f"Prepared dataset with stats written to {out_dir}")

def parse_args():
  """
  Parse the command line.

  Required: --from, --from-sf3d-summary, --train-split, --val-split,
  --test-split, --label-map. Optional: --to (defaults to a timestamped
  directory under logs/) and --seed (int, default 42).
  """
  import argparse
  parser = argparse.ArgumentParser(description="Prepare 100-driver → sf3d (nc=10) with comparative stats/plots.")

  ## (flag, dest, help) for every mandatory option, in declaration order
  required_options = (
    ('--from', 'from_path', "Base path to 100-driver images (e.g., Day/Cam2)"),
    ('--from-sf3d-summary', 'sf3d_summary', "Path to sf3d summary.json"),
    ('--train-split', 'train_split', None),
    ('--val-split', 'val_split', None),
    ('--test-split', 'test_split', None),
    ('--label-map', 'label_map', "JSON with forward/reverse/labels mapping"),
  )
  for flag, dest, help_text in required_options:
    parser.add_argument(flag, dest=dest, required=True, help=help_text)

  stamp = datetime.now().strftime("%d%m%y_%H%M%S")
  parser.add_argument('--to', dest='to_path', default=f'logs/100driver-sf3d-{stamp}')
  parser.add_argument('--seed', type=int, default=42)
  return parser.parse_args()

## Script entry point: parse the command line and run the pipeline.
if __name__ == '__main__':
  main(parse_args())
