"""
SF3D dataset preparation: split, statistics, and visualization.

Features:
- Dataset split (train/val/test) by driver or image.
- Per-class and per-driver statistics.
- Custom visualizations and mosaic generation.
"""
__author__ = 'XYZ'


import argparse
import logging
import os
import json
import random

from collections import Counter
from datetime import datetime
from pathlib import Path

import pandas as pd

from sklearn.model_selection import train_test_split
from tqdm import tqdm

try:
  import torch
  from torch.utils.data import DataLoader, Dataset
  from torch.utils.data.dataloader import default_collate
  from torchvision import transforms
  from PIL import Image
  import matplotlib.pyplot as plt
except ImportError:
  print("PyTorch, torchvision, and matplotlib are required.")

from ..core.yio import yml_dump

## Set up logger
## NOTE: basicConfig() runs at import time, so importing this module configures
## the root logger (INFO level, timestamped format).
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)

## Paths of images that could not be loaded; appended to by log_skipped_image().
skipped_images = []

def log_skipped_image(img_path, error):
  """Warn about an unreadable image and remember its path.

  Args:
    img_path: Path of the image that could not be loaded.
    error: The exception (or message) describing the failure.

  Side effects:
    Emits a warning through the module logger and appends ``img_path`` to
    the module-level ``skipped_images`` list.
  """
  message = f"Skipping missing or corrupted image: {img_path}. Error: {error}"
  logger.warning(message)
  ## The list is mutated in place, so no ``global`` declaration is required.
  skipped_images.append(img_path)

class SF3DDataset(Dataset):
  """Custom PyTorch Dataset for SF3D.

  Every sample is resolved as ``<root_dir>/<classname>/<img>`` from the
  ``classname`` and ``img`` fields of the corresponding metadata row.
  """
  def __init__(self, data, root_dir, transform=None):
    """Store the metadata table, the image root and the optional transform.

    Args:
      data: Table of samples; rows are fetched with ``data.iloc[idx]`` and
        must expose ``classname`` and ``img``.
      root_dir: Directory that contains one sub-directory per class label.
      transform: Optional callable applied to the loaded RGB image.
    """
    self.data = data
    self.root_dir = root_dir
    self.transform = transform

  def __len__(self):
    """Return the number of rows in the metadata table."""
    return len(self.data)

  def __getitem__(self, idx):
    """Load one sample.

    Returns:
      ``(image, label)`` where ``label`` is the row's ``classname`` value, or
      ``None`` when the image is missing or cannot be read (the failure is
      reported through ``log_skipped_image``).
    """
    row = self.data.iloc[idx]
    label = row["classname"]
    img_path = os.path.join(self.root_dir, label, row["img"])

    try:
      ## Context manager guarantees the underlying file handle is released;
      ## convert() returns an independent, fully loaded image.
      with Image.open(img_path) as source:
        image = source.convert("RGB")
      if self.transform is not None:
        image = self.transform(image)
    except OSError as e:
      ## FileNotFoundError is a subclass of OSError, so missing files are
      ## covered here too. Use the shared helper so every skipped image is
      ## reported (and recorded) the same way.
      ## NOTE(review): with DataLoader worker processes the recorded list
      ## presumably lives in each worker, not in the parent -- confirm before
      ## relying on it.
      log_skipped_image(img_path, e)
      return None

    return image, label

def generate_label_mappings_from_file(base_dir):
  """Build label <-> integer lookup tables from ``labels.csv`` in *base_dir*.

  Labels are numbered 0, 1, 2, ... in the order ``load_labels`` returns them.

  Returns:
    ``(label_to_numeric, numeric_to_label)``: two dicts that are inverses of
    each other.
  """
  label_to_numeric = {}
  numeric_to_label = {}
  ## Iterating the mapping yields its label ids (e.g., c0, c1, ...).
  for position, label_id in enumerate(load_labels(base_dir)):
    label_to_numeric[label_id] = position
    numeric_to_label[position] = label_id
  return label_to_numeric, numeric_to_label

def load_labels(base_dir):
  """Read ``labels.csv`` from *base_dir* into a ``{label: classname}`` dict.

  The file is parsed without a header row; its two columns are taken as the
  label id and the class name, in that order.

  Raises:
    FileNotFoundError: If ``labels.csv`` does not exist under *base_dir*.
  """
  csv_file = os.path.join(base_dir, "labels.csv")
  if os.path.exists(csv_file):
    table = pd.read_csv(csv_file, header=None, names=["label", "classname"])
    mapping = {}
    for _, record in table.iterrows():
      mapping[record["label"]] = record["classname"]
    return mapping

  raise FileNotFoundError(f"Error: '{csv_file}' not found.")

def split_data_by_driver(df, split_ratio=(0.7, 0.2, 0.1), random_state=42):
  """Partition *df* into train/val/test so that each driver lands in one split.

  The unique values of the ``subject`` column are split first; rows then
  follow their driver.

  Args:
    df: Metadata frame with a ``subject`` column.
    split_ratio: ``(train, val, test)`` fractions.
    random_state: Seed forwarded to both ``train_test_split`` calls.

  Returns:
    ``(train_df, val_df, test_df)``.
  """
  subjects = df["subject"]

  ## First cut: train drivers versus everything else.
  train_ids, remaining_ids = train_test_split(
    subjects.unique(),
    test_size=(1 - split_ratio[0]),
    random_state=random_state,
  )

  ## Second cut: share of the remainder that goes to test.
  test_share = split_ratio[2] / (split_ratio[1] + split_ratio[2])
  val_ids, test_ids = train_test_split(
    remaining_ids, test_size=test_share, random_state=random_state
  )

  return tuple(df[subjects.isin(ids)] for ids in (train_ids, val_ids, test_ids))


def split_data_by_image(df, split_ratio=(0.7, 0.2, 0.1), random_state=42):
  """Partition *df* row-wise into train/val/test, stratified by ``classname``.

  Args:
    df: Metadata frame with a ``classname`` column.
    split_ratio: ``(train, val, test)`` fractions.
    random_state: Seed forwarded to both ``train_test_split`` calls.

  Returns:
    ``(train_df, val_df, test_df)``.
  """
  ## First cut: train rows versus everything else.
  holdout_share = 1 - split_ratio[0]
  train_df, remainder = train_test_split(
    df,
    test_size=holdout_share,
    random_state=random_state,
    stratify=df["classname"],
  )

  ## Second cut: share of the remainder that goes to test.
  test_share = split_ratio[2] / (split_ratio[1] + split_ratio[2])
  val_df, test_df = train_test_split(
    remainder,
    test_size=test_share,
    random_state=random_state,
    stratify=remainder["classname"],
  )
  return train_df, val_df, test_df

def collate_fn(batch):
  """Collate a batch while dropping samples that failed to load.

  Samples equal to ``None`` are removed before delegating to
  ``default_collate``. Returns ``None`` when nothing is left.
  """
  usable = list(filter(lambda sample: sample is not None, batch))
  if not usable:
    return None
  return default_collate(usable)

def setup_transforms(input_size, phase="train"):
  """Build the torchvision transform pipeline for a split.

  Every phase resizes, converts to a tensor and normalizes. The ``"train"``
  phase additionally applies a random horizontal flip and a random rotation
  of up to 15 degrees right after the resize.

  Args:
    input_size: Size forwarded to ``transforms.Resize``.
    phase: ``"train"`` enables augmentation; any other value disables it.
  """
  steps = [transforms.Resize(input_size)]

  if phase == "train":
    ## Augmentation is inserted between resize and tensor conversion.
    steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.RandomRotation(15))

  steps.append(transforms.ToTensor())
  ## NOTE(review): these constants look like the usual ImageNet statistics -- confirm.
  steps.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
  return transforms.Compose(steps)

def get_dataloader(df, root_dir, batch_size, num_workers, input_size, phase="train"):
  """Wrap one split in an ``SF3DDataset`` and return its DataLoader.

  Only the ``"train"`` phase is shuffled. Unreadable samples are filtered
  out by ``collate_fn``.
  """
  dataset = SF3DDataset(
    data=df,
    root_dir=root_dir,
    transform=setup_transforms(input_size, phase=phase),
  )
  is_training = phase == "train"
  return DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=is_training,
    num_workers=num_workers,
    collate_fn=collate_fn,
  )

def calculate_mean_std(loader):
  """Compute the per-channel mean and std over all images yielded by *loader*.

  Each batch is either ``None`` (skipped) or an ``(images, labels)`` pair.
  ``images`` is reduced over dims 0, 2 and 3, so it is treated as a 4-D
  tensor with channels on dim 1, and it is divided by 255 before being
  accumulated -- i.e. raw 0-255 pixel values are expected.

  NOTE(review): a loader whose transform already rescales or normalizes the
  pixels (e.g. ``ToTensor``/``Normalize``) will produce wrongly scaled
  statistics here -- confirm what the caller passes in.

  The statistics are accumulated as running sums of ``x`` and ``x**2`` so the
  result is the dataset-wide (population) std rather than an average of
  per-batch stds, which underestimates the spread.

  Args:
    loader: Iterable of batches as described above (three channels).

  Returns:
    ``(mean, std)``: two float32 tensors with one entry per channel.

  Raises:
    ValueError: If the loader yields no usable images.
  """
  ## float64 accumulators keep the sum-of-squares formula numerically stable.
  channel_sum = torch.zeros(3, dtype=torch.float64)
  channel_sq_sum = torch.zeros(3, dtype=torch.float64)
  pixels_per_channel = 0

  for batch in tqdm(loader, desc="Calculating mean and std"):
    if batch is None:
      continue

    ## Normalize pixel values to [0, 1]
    images, _ = batch
    images = images.to(torch.float64) / 255.0

    pixels_per_channel += images.size(0) * images.size(2) * images.size(3)
    channel_sum += images.sum(dim=(0, 2, 3))
    channel_sq_sum += images.pow(2).sum(dim=(0, 2, 3))

  if pixels_per_channel == 0:
    ## Previously this divided 0 by 0 and silently returned NaN.
    raise ValueError("Cannot compute mean/std: the loader produced no valid images.")

  mean = channel_sum / pixels_per_channel
  ## Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative round-off.
  variance = (channel_sq_sum / pixels_per_channel - mean.pow(2)).clamp(min=0)
  std = variance.sqrt()

  return mean.float(), std.float()

def generate_statistics(train_df, val_df, test_df, labels_mapping, shuffle, seed, split_method, mean, std, label_to_numeric, numeric_to_label, base_path=None):
  """Generate dataset statistics for splits and classes.

  Args:
    train_df, val_df, test_df: Split frames with ``classname`` and
      ``subject`` columns.
    labels_mapping: ``{label id: class name}`` mapping.
    shuffle, seed, split_method: Recorded verbatim in the summary.
    mean, std: Objects exposing ``tolist()`` or ``None`` when not computed.
    label_to_numeric, numeric_to_label: Label encodings, stored verbatim.
    base_path: Dataset root to record under ``"basepath"``. When omitted the
      function falls back to the script-level ``args.from_path`` (legacy
      behaviour), which only exists when the module runs as a script.

  Returns:
    A dict of plain Python values describing the dataset.
  """
  if base_path is None:
    ## Legacy fallback: ``args`` is a module global bound in the __main__ block.
    base_path = args.from_path

  frames = {"train": train_df, "val": val_df, "test": test_df}

  stats = {
    "name": "sf3d",
    "basepath": os.path.abspath(base_path),
    "shuffle": shuffle,
    "seed": seed,
    "split_by_method": split_method,
    "splits": list(frames),
    "total_images": sum(len(frame) for frame in frames.values()),
    "num_classes": len(labels_mapping),
    "total_images_per_split": {name: len(frame) for name, frame in frames.items()},
    "classes_per_split": {
      ## int() converts numpy integers to plain ints for serialization.
      name: {label: int(count) for label, count in frame["classname"].value_counts().items()}
      for name, frame in frames.items()
    },
    "total_drivers_per_split": {
      name: int(frame["subject"].nunique()) for name, frame in frames.items()
    },
    "labels": list(labels_mapping.keys()),  ## List of labels ["c0", "c1", ..., "c9"]
    "classes": labels_mapping,  ## Mapping of label ids to class names
    "mean": mean.tolist() if mean is not None else None,
    "std": std.tolist() if std is not None else None,
    "label_to_numeric": label_to_numeric,
    "numeric_to_label": numeric_to_label
  }
  return stats

def create_bar_plot(data, title, xlabel, ylabel, save_path):
  """Draw a key-sorted bar chart with value labels and an overlaid trend line.

  Args:
    data: Mapping of category -> numeric value.
    title: Accepted for API symmetry; no title is currently drawn.
    xlabel, ylabel: Axis captions.
    save_path: Destination image file (saved at 300 dpi).
  """
  ink = '#333'
  sorted_data = dict(sorted(data.items()))
  categories = list(sorted_data.keys())
  heights = list(sorted_data.values())

  plt.figure(figsize=(10, 6))

  ## Bars, each annotated with its integer height
  rects = plt.bar(sorted_data.keys(), sorted_data.values(), color='skyblue', alpha=0.9, edgecolor=ink)
  for rect in rects:
    top = rect.get_height()
    centre = rect.get_x() + rect.get_width() / 2
    plt.text(centre, top, f"{int(top)}", ha='center', va='bottom', fontsize=8, color=ink)

  ## Line connecting the bar tops
  plt.plot(categories, heights, color='red', marker='o', linewidth=1.5, alpha=0.7)

  ## Axis captions and tick styling
  plt.xlabel(xlabel, fontsize=10, color=ink)
  plt.ylabel(ylabel, fontsize=10, color=ink)
  plt.xticks(rotation=45, ha='right', fontsize=8, color=ink)
  plt.yticks(fontsize=8, color=ink)

  ## Hide the top/right frame, tint the remaining spines
  ax = plt.gca()
  for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
  for side in ('left', 'bottom'):
    ax.spines[side].set_color(ink)

  plt.tight_layout()
  plt.savefig(save_path, dpi=300, bbox_inches='tight')
  plt.close()
  logger.info(f"Bar plot saved at {save_path}")

def create_scatter_plot(data, x_column, y_column, save_path):
  """Scatter *y_column* against *x_column* and highlight min, max and average.

  The min/max markers are the rows with the smallest/largest *y_column*
  value. The average marker sits at the column means.

  Args:
    data: Frame holding both columns.
    x_column, y_column: Column names, also used as axis captions.
    save_path: Destination image file (saved at 300 dpi).
  """
  ink = '#333'
  xs = data[x_column]
  ys = data[y_column]

  plt.figure(figsize=(10, 6))
  plt.scatter(xs, ys, color='skyblue', alpha=0.7, edgecolors='k', linewidth=0.5)

  ## Points to emphasise: extremes of y and the overall average
  lowest = data.loc[ys.idxmin()]
  highest = data.loc[ys.idxmax()]
  highlights = (
    (lowest[x_column], lowest[y_column], 'green', 'Min'),
    (highest[x_column], highest[y_column], 'red', 'Max'),
    (xs.mean(), ys.mean(), 'blue', 'Avg'),
  )
  for px, py, colour, tag in highlights:
    plt.scatter(px, py, color=colour, s=100, label=tag, edgecolor='k')

  plt.legend(frameon=False, fontsize=8)

  ## Axis captions
  plt.xlabel(x_column, fontsize=10, color=ink)
  plt.ylabel(y_column, fontsize=10, color=ink)

  ## Hide the top/right frame, tint the remaining spines
  ax = plt.gca()
  for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
  for side in ('left', 'bottom'):
    ax.spines[side].set_color(ink)

  plt.tight_layout()
  plt.savefig(save_path, dpi=300, bbox_inches='tight')
  plt.close()
  logger.info(f"Scatter plot saved at {save_path}")


def visualize_classes_per_split(stats, to_path):
  """Write one class-distribution bar plot per split.

  Plots are saved as ``classes_per_split_<split>.png`` inside
  ``<to_path>/plots-datasets`` (created if needed).
  """
  plots_dir = os.path.join(to_path, "plots-datasets")
  os.makedirs(plots_dir, exist_ok=True)

  per_split = stats["classes_per_split"]
  for split in per_split:
    create_bar_plot(
      per_split[split],
      f"Class Distribution ({split})",
      "Class (c0, c1, ...)",
      "Count",
      os.path.join(plots_dir, f"classes_per_split_{split}.png"),
    )

def visualize_total_images_per_split(stats, to_path):
  """Write a bar plot of the image count of every split.

  The plot is saved as ``total_images_per_split.png`` inside
  ``<to_path>/plots-datasets`` (created if needed).
  """
  plots_dir = os.path.join(to_path, "plots-datasets")
  os.makedirs(plots_dir, exist_ok=True)

  create_bar_plot(
    stats["total_images_per_split"],
    "Total Images Per Split",
    "Split (Train/Val/Test)",
    "Count",
    os.path.join(plots_dir, "total_images_per_split.png"),
  )

def create_mosaic(base_path, output_path, random_seed, labels_mapping, img_width=150, img_height=150):
  """Create a mosaic of one sample per class.

  One image is drawn per class from ``<base_path>/imgs/train/<class>`` and
  the thumbnails are laid out five per row in ``<output_path>/mosaic.png``.

  Args:
    base_path: Dataset root containing ``imgs/train``.
    output_path: Directory for ``mosaic.png`` (created if missing).
    random_seed: Seed for the per-class image choice.
    labels_mapping: Mapping whose keys are the class directory names.
    img_width, img_height: Thumbnail size in pixels.

  Raises:
    FileNotFoundError: If the train directory or a class directory is
      missing, or a class directory contains no files.
  """
  train_dir = os.path.join(base_path, 'imgs', 'train')

  ## Check if the train directory exists
  if not os.path.exists(train_dir):
    raise FileNotFoundError(f"Train directory not found: {train_dir}")

  ## A private generator keeps the choice reproducible without reseeding
  ## the process-wide ``random`` state as a side effect.
  rng = random.Random(random_seed)

  ## Select one random image per class
  selected_images = []
  for cls in sorted(labels_mapping.keys()):
    class_dir = os.path.join(train_dir, cls)
    ## os.listdir() returns entries in arbitrary order, so sort them to make
    ## the seeded pick deterministic. Skip anything that is not a file.
    images = sorted(
      name for name in os.listdir(class_dir)
      if os.path.isfile(os.path.join(class_dir, name))
    )
    if not images:
      raise FileNotFoundError(f"No images found in class directory: {class_dir}")
    selected_images.append((cls, os.path.join(class_dir, rng.choice(images))))

  ## Mosaic layout
  n_cols = 5
  n_rows = (len(selected_images) + n_cols - 1) // n_cols

  plt.figure(figsize=(n_cols * (img_width / 100), n_rows * (img_height / 100)))
  for idx, (cls, img_path) in enumerate(selected_images):
    ## resize() returns an independent image, so the file can be closed at once.
    with Image.open(img_path) as source:
      img = source.resize((img_width, img_height))
    plt.subplot(n_rows, n_cols, idx + 1)
    plt.imshow(img)
    plt.axis('off')
    plt.title(cls, fontsize=8, pad=5, color='#333')

  os.makedirs(output_path, exist_ok=True)
  output_file = os.path.join(output_path, "mosaic.png")
  plt.tight_layout()
  plt.savefig(output_file, dpi=300, bbox_inches='tight')
  plt.close()
  logger.info(f"Saved mosaic at {output_file}")

def prepare_dataloaders(base_dir, split_ratio, batch_size, num_workers, input_size, split_method, shuffle, seed, calculate_mean_std_flag):
  """Prepares DataLoaders and statistics for dataset.

  Args:
    base_dir: Dataset root containing ``driver_imgs_list.csv``, ``labels.csv``
      and ``imgs/train``.
    split_ratio: ``(train, val, test)`` fractions.
    batch_size, num_workers, input_size: Forwarded to the DataLoaders.
    split_method: ``"by-driver"`` or ``"by-image"``.
    shuffle: Accepted for interface compatibility; not used by the split
      helpers called here.
    seed: Random state used for both splitting steps.
    calculate_mean_std_flag: When true, compute mean/std of the train split.

  Returns:
    ``(loaders, labels_mapping, split_frames, mean, std)`` where ``mean`` and
    ``std`` are ``None`` unless ``calculate_mean_std_flag`` is set.

  Raises:
    ValueError: If ``split_method`` is not one of the supported values.
  """
  ## Load metadata
  df = pd.read_csv(os.path.join(base_dir, "driver_imgs_list.csv"))

  ## Map 'classname' to human-readable names only for the summary
  labels_mapping = load_labels(base_dir)

  ## Split data; the seed is forwarded so that --seed actually controls the
  ## split instead of the helpers' built-in default.
  if split_method == "by-driver":
    train_df, val_df, test_df = split_data_by_driver(df, split_ratio=split_ratio, random_state=seed)
  elif split_method == "by-image":
    train_df, val_df, test_df = split_data_by_image(df, split_ratio=split_ratio, random_state=seed)
  else:
    raise ValueError("Invalid split_method. Choose 'by-driver' or 'by-image'.")

  logger.info(f"Train: {len(train_df)} | Val: {len(val_df)} | Test: {len(test_df)}")

  ## Create DataLoaders (every split reads its files from imgs/train)
  images_root = os.path.join(base_dir, "imgs/train")
  frames = {"train": train_df, "val": val_df, "test": test_df}
  loaders = {
    phase: get_dataloader(frame, images_root, batch_size, num_workers, input_size, phase=phase)
    for phase, frame in frames.items()
  }

  ## Calculate mean and std for normalization
  mean, std = None, None
  if calculate_mean_std_flag:
    ## The training loader augments and normalizes its images, while
    ## calculate_mean_std() divides its input by 255. Feed it a dedicated,
    ## deterministic loader that yields raw 0-255 pixels instead
    ## (PILToTensor converts without rescaling).
    raw_pixels = transforms.Compose([
      transforms.Resize(input_size),
      transforms.PILToTensor(),
    ])
    stats_loader = DataLoader(
      SF3DDataset(data=train_df, root_dir=images_root, transform=raw_pixels),
      batch_size=batch_size,
      shuffle=False,
      num_workers=num_workers,
      collate_fn=collate_fn,
    )
    mean, std = calculate_mean_std(stats_loader)

  return loaders, labels_mapping, frames, mean, std


def parse_args():
  """Parse and validate the command-line options.

  ``--input-size`` and ``--split-ratio`` are given as Python literals (for
  example ``"(224,224)"``). They are parsed with ``ast.literal_eval``, which
  accepts literals only, instead of ``eval``, which would execute arbitrary
  code supplied on the command line.

  Returns:
    The parsed namespace, with ``input_size`` and ``split_ratio`` converted
    from strings to Python values.

  Raises:
    SystemExit: Via ``parser.error`` when an option is malformed.
  """
  import ast  ## local import: only needed for parsing the literal options

  parser = argparse.ArgumentParser(description="SF3D Dataset Preparation")
  parser.add_argument("--from", dest="from_path", required=True, help="Base dataset directory")
  parser.add_argument("--to", dest="to_path", default=f"logs/sf3d-{datetime.now().strftime('%d%m%y_%H%M%S')}", help="Output directory")
  parser.add_argument("--split-method", choices=["by-driver", "by-image"], default="by-image",
                      help="Splitting method: 'by-driver' or 'by-image'")
  parser.add_argument("--split-ratio", type=str, default="(0.7,0.2,0.1)", help="Split ratio for train, val, test")
  parser.add_argument("--batch-size", type=int, default=32, help="Batch size for dataloaders")
  parser.add_argument("--num-workers", type=int, default=4, help="Number of workers for data loading")
  parser.add_argument("--input-size", type=str, default="(224,224)", help="Input size for resizing images")
  ## NOTE(review): store_true with default=True means this flag can never be
  ## switched off from the command line.
  parser.add_argument("--shuffle", action="store_true", default=True, help="Shuffle the dataset before splitting")
  parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
  parser.add_argument("--calculate-mean-std", action="store_true", default=False, help="Calculate mean and std for normalization")

  args = parser.parse_args()

  def _literal(option, raw):
    """Parse *raw* as a Python literal or abort with a usage error."""
    try:
      return ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError):
      parser.error(f"{option} must be a Python literal, got {raw!r}")

  def _is_number(value, kinds):
    """True for instances of *kinds*, excluding bool (a subclass of int)."""
    return isinstance(value, kinds) and not isinstance(value, bool)

  args.input_size = _literal("--input-size", args.input_size)
  args.split_ratio = _literal("--split-ratio", args.split_ratio)

  ## input size: a positive int, or a (height, width) pair of positive ints
  size = args.input_size
  dims = size if isinstance(size, (tuple, list)) else (size,)
  if len(dims) not in (1, 2) or not all(_is_number(d, int) and d > 0 for d in dims):
    parser.error("--input-size must be a positive int or a pair such as '(224,224)'")

  ## split ratio: exactly three numeric fractions (train, val, test)
  ratio = args.split_ratio
  if not (isinstance(ratio, (tuple, list)) and len(ratio) == 3
          and all(_is_number(r, (int, float)) for r in ratio)):
    parser.error("--split-ratio must be three numbers such as '(0.7,0.2,0.1)'")

  return args

def generate_split_files(df, split_name, root_dir, save_dir, label_to_numeric):
  """Write ``<save_dir>/<split_name>.txt`` listing every sample of a split.

  Each line is ``<row index>\\t<classname>/<img path relative to root_dir>\\t<numeric class id>``.

  Args:
    df: Split frame with ``classname`` and ``img`` columns.
    split_name: Base name of the output file.
    root_dir: Image root the written paths are made relative to.
    save_dir: Output directory (created if needed).
    label_to_numeric: Mapping from ``classname`` to its numeric id.
  """
  os.makedirs(save_dir, exist_ok=True)
  output_path = os.path.join(save_dir, f"{split_name}.txt")

  with open(output_path, "w") as handle:
    for row_index, record in df.iterrows():
      class_name = record["classname"]
      numeric_id = label_to_numeric[class_name]  ## numeric encoding of the class
      image_file = os.path.join(root_dir, class_name, record["img"])
      fields = (
        str(row_index),
        os.path.relpath(image_file, start=root_dir),
        str(numeric_id),
      )
      handle.write("\t".join(fields) + "\n")
  logger.info(f"Generated split file: {output_path}")


def generate_dataset_yml(output_dir, splits_dir, splits, stats):
  """Generate dataset yml in the destination directory.

  Args:
    output_dir: Directory that receives ``dataset.yml``.
    splits_dir: Path prefix written in front of the split file names.
    splits: Unused; kept for interface compatibility.
    stats: Summary dict. ``"classes"`` is required; ``"mean"``/``"std"`` may
      be ``None`` (not computed); ``"basepath"`` is used when present.

  NOTE(review): input size, augmentations, batch size, worker count, shuffle
  and seed are written as fixed values, independent of the CLI options --
  confirm that this is intended.
  """
  def _rounded(values):
    """Round to 6 decimals to avoid very long floats; keep None as None."""
    if values is None:
      return None
    return [round(v, 6) for v in values]

  ## Prefer the absolute path already recorded in the summary; fall back to
  ## the script-level ``args`` global only if it is missing.
  base_path = stats.get("basepath")
  if base_path is None:
    base_path = os.path.abspath(args.from_path)

  yml_structure = {
    "name": "sf3d",
    "description": "Configuration for the State Farm Distracted Driver Detection (SF3D) dataset.",
    "basepath": base_path,  ## Absolute path
    "splits": {
      "train": f"{splits_dir}/train.txt",
      "val": f"{splits_dir}/val.txt",
      "test": f"{splits_dir}/test.txt",
    },
    "classes": stats["classes"],
    "input_size": [224, 224],
    "augmentations": {
      "resize": [224, 224],
      "random_rotation": 30,
      "random_crop": True,
      "random_erasing": True,
      "horizontal_flip": True,
      "normalize": True,
    },
    ## mean/std are None unless they were computed; iterating None used to
    ## crash the default run with a TypeError.
    "mean": _rounded(stats.get("mean")),
    "std": _rounded(stats.get("std")),
    "batch_size": 16,
    "num_workers": 4,
    "shuffle": True,
    "seed": 42,
    "license": "",
    "extra": None,
    "version": "0.0.1",
    "url": None,
  }

  yml_path = os.path.join(output_dir, "dataset.yml")
  yml_dump(yml_path, yml_structure)
  logger.info(f"Generated YAML configuration at: {yml_path}")


def main(args):
  """Run the full preparation pipeline for the parsed CLI *args*.

  Steps: split the metadata and build loaders, derive label encodings, write
  ``summary.json``, the per-split text files and ``dataset.yml``, then render
  the distribution plots and the class mosaic. All outputs go to
  ``<args.to_path>/data-splits/<split method>``.
  """
  ## Prepare DataLoaders
  loaders, labels_mapping, splits, mean, std = prepare_dataloaders(
    base_dir=args.from_path,
    split_ratio=args.split_ratio,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    input_size=args.input_size,
    split_method=args.split_method,
    shuffle=args.shuffle,
    seed=args.seed,
    calculate_mean_std_flag=args.calculate_mean_std,
  )

  ## Generate label mappings
  label_to_numeric, numeric_to_label = generate_label_mappings_from_file(args.from_path)

  ## Output location
  splits_dir = os.path.join("data-splits", args.split_method)
  to_path = os.path.join(args.to_path, splits_dir)
  os.makedirs(to_path, exist_ok=True)

  ## Generate statistics
  stats = generate_statistics(
    train_df=splits["train"],
    val_df=splits["val"],
    test_df=splits["test"],
    labels_mapping=labels_mapping,
    shuffle=args.shuffle,
    seed=args.seed,
    split_method=args.split_method,
    mean=mean,
    std=std,
    label_to_numeric=label_to_numeric,
    numeric_to_label=numeric_to_label,
  )

  ## Save statistics to summary.json
  summary_path = os.path.join(to_path, "summary.json")
  with open(summary_path, "w") as f:
    json.dump(stats, f, indent=2)
  logger.info(f"Saved dataset statistics to {summary_path}")

  ## Generate split files in the 100-driver annotation format
  images_root = os.path.join(args.from_path, "imgs/train")
  for split_name in ("train", "val", "test"):
    generate_split_files(splits[split_name], split_name, images_root, to_path, label_to_numeric)

  ## Generate dataset yml configuration file
  generate_dataset_yml(to_path, splits_dir, splits, stats)

  ## Visualize and save dataset distributions
  for render in (visualize_classes_per_split, visualize_total_images_per_split):
    render(stats, to_path)

  ## Create mosaic of one image per class
  logger.info("Generating mosaic for one image per class...")
  create_mosaic(args.from_path, to_path, args.seed, labels_mapping, img_width=150, img_height=150)
  logger.info("Mosaic creation complete.")

  logger.info("All tasks completed successfully!")

if __name__ == "__main__":
  ## NOTE: ``args`` is bound at module level here; generate_statistics() and
  ## generate_dataset_yml() reference ``args.from_path`` as a global, so this
  ## name must not be renamed or moved into a function.
  args = parse_args()
  main(args)
