"""ddd dataset loader"""
__author__ = 'XYZ'


import os
import random

from collections import Counter, OrderedDict, defaultdict
from datetime import datetime
from pathlib import Path

try:
  import torch
  import torchvision
  from torchvision import transforms
  from torch.utils.data import DataLoader
except ImportError:
  print('torchvision is not installed')

import matplotlib.pyplot as plt

from PIL import Image, UnidentifiedImageError
from tqdm import tqdm

# ## Fix for PIL Image truncation errors
# from PIL import ImageFile
# ImageFile.LOAD_TRUNCATED_IMAGES = True

from .core._log_ import logger
log = logger(__file__)

from .core.fro import load_filecontent
from .core.fwo import write_json
from .core.parsers import string_to_list


def _tail_rel(root: str, depth: int = 2) -> str:
  """
  Return last <depth> directories of root as a relative prefix.
  No hardcoding of tokens like 'imgs' or split names.
  """
  parts = os.path.normpath(root).split(os.sep)
  if not parts:
    return ""
  k = depth if len(parts) >= depth else 1
  return os.path.join(*parts[-k:])


class PlainDataSet(torch.utils.data.Dataset):
  """
  Path-only dataset read from a split file; images are never opened.

  Each item is ``(absolute_path, label_index, relative_path)``. The class name
  is the first '/'-separated component of the relative path, and label indices
  are assigned in first-seen order.
  """

  def __init__(self, root, lists, sample_limit=None, *, prefix_depth: int = 2, seed=42):
    """
    Args:
      root: Images root directory; each relative path is joined onto it.
      lists: Path to the split file. Blank lines and lines whose first
        non-blank character is '#' are skipped. Every other line is expected
        to be ``"<id> <relative_path_from_images_root> [label?]"``; tokens
        after the path are ignored.
      sample_limit: If a positive int, keep at most this many samples PER
        CLASS, drawn with ``random.Random(seed)`` (balanced sampling).
      prefix_depth: Number of trailing components of ``root`` exposed as
        ``prefix_rel``.
      seed: Seed for the balanced-sampling RNG.

    Raises:
      ValueError: if a data line has fewer than two tokens.
    """
    with open(lists, 'r') as f:
      ## Strip BEFORE filtering: raw lines from a file are never empty strings,
      ## so testing the raw line would let blank lines through and miss
      ## indented '#' comments.
      lines = [s for s in (raw.strip() for raw in f) if s and not s.startswith("#")]

    imgs, labels, rels, unique_labels = [], [], [], OrderedDict()
    for line in lines:
      toks = line.split()
      if len(toks) < 2:
        raise ValueError(f"{lists}: expected '<id> <relative_path> [label]', got {line!r}")
      rel = toks[1]
      classname = rel.split('/')[0]
      if classname not in unique_labels:
        unique_labels[classname] = len(unique_labels)
      imgs.append(os.path.join(root, rel))
      labels.append(unique_labels[classname])
      rels.append(rel)

    ## Balanced sampling per class
    if sample_limit and sample_limit > 0:
      rng = random.Random(seed)
      class_to_indices = defaultdict(list)
      for idx, lbl in enumerate(labels):
        class_to_indices[lbl].append(idx)

      balanced_indices = []
      for idxs in class_to_indices.values():
        balanced_indices.extend(rng.sample(idxs, min(sample_limit, len(idxs))))

      ## Subsample every parallel list with the SAME indices so that
      ## __getitem__ keeps path, label and relative path aligned.
      imgs = [imgs[i] for i in balanced_indices]
      labels = [labels[i] for i in balanced_indices]
      rels = [rels[i] for i in balanced_indices]

    self.imgs = imgs
    self.labels = labels
    self.rels = rels
    self.unique_label_names = list(unique_labels.keys())
    self.label_to_index = dict(unique_labels)
    self.index_to_label = {v: k for k, v in unique_labels.items()}

    # expose bases the annotator needs
    self.images_root = os.path.normpath(root)
    self.prefix_rel = _tail_rel(self.images_root, depth=prefix_depth)

  def __getitem__(self, index):
    """Return ``(absolute_path, label_index, relative_path)``; the relative path is as given in the split file."""
    return self.imgs[index], self.labels[index], self.rels[index]

  def __len__(self):
    """Number of samples (after any balanced sampling)."""
    return len(self.imgs)


def load_plain_dataset(root,
                       split_file,
                       sample_limit=None,
                       batch_size=1,
                       shuffle=False,
                       num_workers=0,
                       *,
                       prefix_depth: int = 2,
                       seed=42):
  """
  Build a PlainDataSet from ``split_file`` and wrap it in a DataLoader.

  Args:
    root: Images root directory.
    split_file: Split file passed to PlainDataSet as ``lists``.
    sample_limit: Optional per-class cap (balanced sampling).
    batch_size, shuffle, num_workers: Forwarded to ``get_plain_dataloader``.
    prefix_depth: Trailing components of ``root`` exposed as ``prefix_rel``.
    seed: RNG seed for balanced sampling (same default as PlainDataSet).

  Returns:
    ``(loader, dataset)``.
  """
  dataset = PlainDataSet(
    root=root,
    lists=split_file,
    sample_limit=sample_limit,
    prefix_depth=prefix_depth,
    seed=seed,
  )
  loader = get_plain_dataloader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
  return loader, dataset


def get_plain_dataloader(dataset, batch_size=1, shuffle=False, num_workers=0):
  """Return a DataLoader over ``dataset`` with the given batch size, shuffling and worker count."""
  loader_options = dict(batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
  return DataLoader(dataset, **loader_options)



class DataSet(torch.utils.data.Dataset):
  """
  Image dataset read from a split file of ``"<id> <relative_path> <label_index>"`` lines.

  Items are ``(image, label_index, absolute_path)``, where ``image`` is the RGB
  PIL image passed through ``self.transform``. Label indices are taken verbatim
  from the split file. A sample's class name is the first '/'-separated
  component of its relative path.
  """

  def __init__(self, root, lists, input_size=(224,224), transform=None, sample_limit=None, seed=42):
    """
    Initializes the dataset with image paths, labels, transformations, unique labels, and label mappings.

    Args:
      root: Images root directory; each relative path is joined onto it.
      lists: Path to the split file. Blank lines and '#' comment lines are skipped.
      input_size: Size given to ``Resize`` in the default transform.
      transform: Callable applied to each RGB PIL image. When falsy, a
        ``Resize(input_size)`` + ``ToTensor()`` pipeline is used.
      sample_limit: If a positive int, keep at most this many samples PER
        CLASS, drawn with ``random.Random(seed)`` (balanced sampling).
      seed: Seed for the balanced-sampling RNG.

    Raises:
      ValueError: if a data line has fewer than three tokens or a non-integer label.
    """
    imgs = []
    labels = []
    unique_labels = OrderedDict()  ## class name -> label index from the split file (first seen)
    self.errors = defaultdict(list)  ## lower-cased exception name -> failing image paths

    with open(lists, 'r') as f:
      for raw in f:
        line = raw.strip()
        ## Skip blank and comment lines instead of crashing on the token lookup below
        if not line or line.startswith('#'):
          continue
        toks = line.split()
        if len(toks) < 3:
          raise ValueError(f"{lists}: expected '<id> <relative_path> <label_index>', got {line!r}")
        rel = toks[1]
        label_text = rel.split('/')[0]  ## the top-level directory is the class name
        label_index = int(toks[2])

        imgs.append(os.path.join(root, rel))
        labels.append(label_index)
        if label_text not in unique_labels:
          unique_labels[label_text] = label_index

    ## Balanced sampling per class
    if sample_limit and sample_limit > 0:
      rng = random.Random(seed)
      class_to_indices = defaultdict(list)
      for idx, lbl in enumerate(labels):
        class_to_indices[lbl].append(idx)

      balanced_indices = []
      for idxs in class_to_indices.values():
        balanced_indices.extend(rng.sample(idxs, min(sample_limit, len(idxs))))

      imgs = [imgs[i] for i in balanced_indices]
      labels = [labels[i] for i in balanced_indices]

    self.imgs = imgs
    self.labels = labels
    self.input_size = input_size
    self.unique_label_names = list(unique_labels.keys())
    ## Keep the split file's own label indices. Re-enumerating the names as
    ## 0..n-1 in first-seen order would mislabel samples (or make
    ## index_to_label[label] raise KeyError) whenever the file's indices are
    ## out of order or non-contiguous.
    self.label_to_index = dict(unique_labels)
    self.index_to_label = {idx: name for name, idx in self.label_to_index.items()}

    ## Default transform if none provided
    self.transform = transform if transform else transforms.Compose([
      transforms.Resize(self.input_size),
      transforms.ToTensor()
    ])
    self.mean = None
    self.std = None

  def __getitem__(self, index):
    """
    Return ``(image, label_index, image_path)`` for ``index``.

    If an image cannot be loaded, the failure is recorded in ``self.errors`` and
    logged, and the next index (wrapping around) is tried instead. The returned
    label and path always belong to the image actually returned.
    NOTE: with DataLoader ``num_workers > 0`` this runs in worker processes, so
    the parent's ``self.errors`` is presumably not updated — confirm if relied on.

    Raises:
      RuntimeError: if every image in the dataset fails to load.
    """
    num_imgs = len(self.imgs)
    ## Bound the retries: the previous unbounded loop hung forever when no image was readable.
    for _ in range(num_imgs):
      img_path = self.imgs[index]
      label = self.labels[index]

      try:
        ## Close the file handle right away; convert() returns an independent image.
        with Image.open(img_path) as img:
          data = img.convert('RGB')
        data = self.transform(data) if self.transform else data
        return data, label, img_path
      except (OSError, UnidentifiedImageError) as e:
        error_key = type(e).__name__.replace(" ", "_").lower()
        self.errors[error_key].append(img_path)
        log.error(f"Error loading image {img_path}: {str(e)}")
        ## Move to the next index if there's an error
        index = (index + 1) % num_imgs

    raise RuntimeError(f"All {num_imgs} images in the dataset failed to load")

  def __len__(self):
    """Number of samples (after any balanced sampling)."""
    return len(self.imgs)


def plot_total_images_per_label(metadata, plots_dir, split_name):
  """
  Save a bar chart of image counts per label for one split.

  Bars follow the order of ``metadata["labels"]``. Labels missing from
  ``metadata["total_images_per_label"]`` are drawn with a count of 0. A faint
  red line connects the bar tops, and each bar is annotated with its count.

  Args:
    metadata: Split metadata dict providing "labels" and "total_images_per_label".
    plots_dir: Output directory (``str`` or ``Path``); it must already exist.
    split_name: Split name used in the title and in the file name.

  Returns:
    ``Path`` to ``<plots_dir>/<split_name>_total_images_per_label.png``.
  """
  labels = metadata["labels"]
  counts = [metadata["total_images_per_label"].get(label, 0) for label in labels]

  fig, ax = plt.subplots(figsize=(10, 6))
  try:
    bars = ax.bar(labels, counts, color="skyblue")
    bar_centers = [bar.get_x() + bar.get_width() / 2 for bar in bars]

    ## Connect bar centers with a semi-transparent line kept below the count labels
    ax.plot(bar_centers, counts, color="red", marker="o", alpha=0.4, zorder=1)

    ## Set axis labels and title
    ax.set_xlabel("Labels")
    ax.set_ylabel("Image Count")
    ax.set_title(f"Total Images per Label - {split_name.capitalize()} Split")
    ## Rotate the existing categorical tick labels. set_xticklabels() without
    ## set_xticks() triggers matplotlib's FixedFormatter/FixedLocator warning.
    ax.tick_params(axis="x", labelrotation=90)

    ## Place count labels a few screen points above each bar. A fixed offset in
    ## data units (+10) lands far outside the axes when counts are small.
    for x, count in zip(bar_centers, counts):
      ax.annotate(
        f"{count}",
        xy=(x, count),
        xytext=(0, 3),
        textcoords="offset points",
        ha="center",
        va="bottom",
        fontsize=8,
        color="black",
        zorder=2,  ## Ensure labels are on top of the line
      )

    plot_path = Path(plots_dir) / f"{split_name}_total_images_per_label.png"
    fig.savefig(plot_path, bbox_inches="tight")
  finally:
    ## Release the figure even if saving fails
    plt.close(fig)

  return plot_path


def plot_split_distribution(splits_metadata, plots_dir):
  """
  Save a donut chart of the total image count of each dataset split.

  Each wedge is labelled "<Split> (<count>)"; its percentage text is followed
  by the same count in a smaller font.

  Args:
    splits_metadata: Mapping of split name -> metadata dict containing
      "total_images_per_label".
    plots_dir: Output directory as a ``Path`` (joined with ``/``).

  Returns:
    Path to ``split_distribution_donut.png`` inside ``plots_dir``.
  """
  totals = {
    name: sum(split_md["total_images_per_label"].values())
    for name, split_md in splits_metadata.items()
  }
  wedge_labels = [f"{name.capitalize()} ({total})" for name, total in totals.items()]
  sizes = list(totals.values())

  fig, ax = plt.subplots(figsize=(6, 6))
  _, _, pct_texts = ax.pie(
    sizes, labels=wedge_labels, autopct='%1.1f%%', startangle=90, wedgeprops={'width': 0.3}
  )

  ## Append the absolute count to each percentage label and shrink the font
  for pct_text, total in zip(pct_texts, sizes):
    pct_text.set_text(f"{pct_text.get_text()} ({total})")
    pct_text.set_fontsize(8)

  ax.set_title("Dataset Split Distribution")

  out_path = plots_dir / "split_distribution_donut.png"
  fig.savefig(out_path, bbox_inches="tight")
  plt.close(fig)
  return out_path

def create_mosaic(plots, output_dir):
  """
  Combine plot images side by side into a single mosaic PNG.

  Images are placed left to right in the given order on a white RGB canvas.
  The canvas width is the sum of the image widths and its height is that of
  the tallest image. Images with an alpha channel are composited onto the
  white background.

  Args:
    plots: Iterable of image file paths.
    output_dir: Base output directory (``str`` or ``Path``). The mosaic is
      written to ``<output_dir>/plots/dataset_mosaic.png``, and the directory
      is created if needed.

  Returns:
    ``Path`` of the saved mosaic.

  Raises:
    ValueError: if ``plots`` is empty.
  """
  images = []
  for plot in plots:
    ## Image.open is lazy and keeps the file open. Copy the pixels into memory
    ## and let the context manager close the handle immediately.
    with Image.open(plot) as img:
      images.append(img.copy())

  if not images:
    raise ValueError("create_mosaic: no plot images to combine")

  ## Define mosaic size
  mosaic_width = sum(img.width for img in images)
  mosaic_height = max(img.height for img in images)
  mosaic_img = Image.new("RGB", (mosaic_width, mosaic_height), "white")

  ## Paste images side by side. The alpha channel (if any) is used as the mask,
  ## so transparent regions show the white background.
  x_offset = 0
  for img in images:
    mask = img if img.mode in ("RGBA", "LA") else None
    mosaic_img.paste(img, (x_offset, 0), mask)
    x_offset += img.width

  mosaic_path = Path(output_dir) / "plots" / "dataset_mosaic.png"
  mosaic_path.parent.mkdir(parents=True, exist_ok=True)  ## Ensure `plots` directory exists
  mosaic_img.save(mosaic_path)

  return mosaic_path

def calculate_mean_std(loader):
  """
  Compute per-channel mean and standard deviation over all pixels in ``loader``, with a progress bar.

  Sums and sums of squares are pooled over every pixel of every image
  (accumulated in float64). ``std`` is therefore the true dataset standard
  deviation, not an average of per-image standard deviations, which ignores
  between-image variation and underestimates it.

  Args:
    loader: Iterable of ``(images, _, _)`` batches. ``images`` is a tensor whose
      first two dims are (batch, channels); all remaining dims are treated as
      pixels. ``None`` batches are skipped.

  Returns:
    ``(mean, std)`` as float32 tensors of shape ``(channels,)``, rounded to four
    decimal places.

  Raises:
    ValueError: if the loader yields no images.
  """
  channel_sum = None
  channel_sq_sum = None
  pixel_count = 0

  for images, _, _ in tqdm(loader, desc="Calculating mean and std", unit="batch"):
    if images is None:  ## Skip corrupted images
      continue
    ## (batch, channels, ...) -> (batch, channels, pixels). reshape also handles
    ## non-contiguous tensors; float64 keeps the sum of squares accurate.
    flat = images.reshape(images.size(0), images.size(1), -1).to(torch.float64)
    if channel_sum is None:
      channel_sum = torch.zeros(flat.size(1), dtype=torch.float64, device=flat.device)
      channel_sq_sum = torch.zeros(flat.size(1), dtype=torch.float64, device=flat.device)
    channel_sum += flat.sum(dim=(0, 2))
    channel_sq_sum += flat.pow(2).sum(dim=(0, 2))
    pixel_count += flat.size(0) * flat.size(2)

  if pixel_count == 0:
    ## Dividing by zero would silently produce NaN statistics
    raise ValueError("calculate_mean_std: loader yielded no images")

  mean = channel_sum / pixel_count
  ## Var[X] = E[X^2] - E[X]^2; clamp tiny negatives caused by rounding error
  std = (channel_sq_sum / pixel_count - mean.pow(2)).clamp(min=0).sqrt()

  ## Round to four decimal places
  return mean.float().round(decimals=4), std.float().round(decimals=4)


def get_dataloader(dataset, batch_size=16, num_workers=2, shuffle=True):
  """Return a DataLoader over ``dataset`` (defaults: batches of 16, 2 workers, shuffled)."""
  return DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
  )


def get_splits(dataset, datasetcfg, __dataset_root__):
  """
  Resolve the split-file paths for ``dataset`` from the dataset config.

  For each table yielded by ``load_filecontent(datasetcfg, ext='.yml')``, the
  row labelled ``dataset`` is read and every value is joined onto
  ``__dataset_root__``. If several tables are yielded, the last one wins.

  Args:
    dataset: Row label of the dataset in the config table.
    datasetcfg: Config path passed to ``load_filecontent``.
    __dataset_root__: Root directory prefixed to every configured path.

  Returns:
    Dict mapping each config key to its joined path.

  Raises:
    ValueError: if ``datasetcfg`` yields no config content.
  """
  splits = None
  for df, filepath, filename in load_filecontent(datasetcfg, ext='.yml'):
    log.info(f'df: {df}')
    log.info(f'filename: {filename}')
    log.info(f'filepath: {filepath}')
    log.debug(f'df.columns: {df.columns}')

    _datasetcfg = df.loc[dataset].to_dict()
    log.debug(f'_datasetcfg: {_datasetcfg}')

    splits = {key: os.path.join(__dataset_root__, value) for key, value in _datasetcfg.items()}

  ## Without this check an empty config surfaced as an UnboundLocalError
  if splits is None:
    raise ValueError(f"No dataset configuration content found in {datasetcfg!r}")
  return splits


def load_dataset(args, splits, flag='train', shuffle=True, timestamp=None):
  """
  Build the dataset and loader for one split, normalised with that split's own statistics.

  First pass: build ``DataSet`` with a plain Resize + ToTensor transform and
  compute per-channel mean/std over it with ``calculate_mean_std``.
  Second pass: install a transform that normalises with those statistics
  ('train' additionally gets RandomRotation(30) before ToTensor and
  RandomErasing after Normalize), then rebuild the loader.

  NOTE(review): every split is normalised with its OWN mean/std, so val/test
  do not use the train statistics — confirm this is intended.

  Args:
    args: Parsed CLI namespace (input_size, sample_limit, batch_size,
      num_workers, dataset).
    splits: Mapping containing 'loadertxt' and '<flag>loadertxt'.
    flag: Split name, e.g. 'train', 'val' or 'test'.
    shuffle: Whether both loaders shuffle.
    timestamp: Metadata timestamp; defaults to now, formatted '%d%m%y_%H%M%S'.

  Returns:
    ``(loader, dataset, label_counts, metadata)``.
  """
  log.info(f"splits:: {splits}")
  log.info(f"flag:: {flag}")

  loadertxt = splits['loadertxt']
  flagloadertxt = splits[f'{flag}loadertxt']
  input_size = args.input_size
  is_train = flag == 'train'

  def make_loader(ds):
    ## Both passes share the same loader configuration
    return get_dataloader(
      dataset=ds,
      batch_size=args.batch_size,
      num_workers=args.num_workers,
      shuffle=shuffle,
    )

  ## Pass 1: un-normalised tensors, used only to measure mean/std
  dataset = DataSet(
    root=loadertxt,
    lists=flagloadertxt,
    input_size=input_size,
    transform=transforms.Compose([transforms.Resize(input_size), transforms.ToTensor()]),
    sample_limit=args.sample_limit
  )
  mean, std = calculate_mean_std(make_loader(dataset))

  ## Pass 2: normalised pipeline, with augmentation for the training split only.
  ## TODO: use better augmentation strategy
  pipeline = [transforms.Resize(input_size)]
  if is_train:
    pipeline.append(transforms.RandomRotation(30))
  pipeline += [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]
  if is_train:
    pipeline.append(transforms.RandomErasing())
  dataset.transform = transforms.Compose(pipeline)
  loader = make_loader(dataset)

  label_counts = Counter(dataset.labels)
  name_of = dataset.index_to_label

  metadata = {
    "creator": __author__,
    "timestamp": timestamp or datetime.now().strftime("%d%m%y_%H%M%S"),
    "dataset_name": args.dataset,
    "num_samples": len(dataset),
    "num_classes": len(name_of),
    "batch_size": args.batch_size,
    "split": flag,
    "loadertxt": loadertxt,
    "flagloadertxt": flagloadertxt,
    "input_size": input_size,
    "num_unique_labels": len(dataset.unique_label_names),
    "label_counts": dict(label_counts),
    "total_images_per_label": {name_of[idx]: n for idx, n in label_counts.items()},
    "mean": mean.tolist(),
    "std": std.tolist(),
    "label_to_index": dict(dataset.label_to_index),
    "labels": list(dataset.label_to_index),
    "label_ids": list(dataset.label_to_index.values()),
    "index_to_label": dict(name_of),
  }
  return loader, dataset, label_counts, metadata


def process_dataset(args, splits, to_path, flag, timestamp=None):
  """
  Load one split via ``load_dataset`` and write its metadata to JSON.

  The file is written to ``<to_path>/<args.dataset>.<flag>.metadata.json``.

  Returns:
    The metadata dict produced by ``load_dataset``.
  """
  *_, metadata = load_dataset(args, splits, flag=flag, timestamp=timestamp)

  metadata_filepath = os.path.join(to_path, f"{args.dataset}.{flag}.metadata.json")
  write_json(metadata_filepath, metadata)
  log.info(f"File saved to {metadata_filepath}")
  return metadata


def main(args):
  """
  Perform analysis on the dataset, including statistics calculation and visualization.

  For each split in ``args.split``, the metadata JSON and a per-label bar chart
  are written. Then a split-distribution donut chart and a side-by-side mosaic
  of all plots are produced. Output goes to ``args.to_path``, or to
  ``logs/<timestamp>`` when that is unset.

  Raises:
    RuntimeError: if the ``__DATASET_ROOT__`` environment variable is not set.
  """
  __dataset_root__ = os.getenv('__DATASET_ROOT__')
  if __dataset_root__ is None:
    ## get_splits joins every configured path onto this root; fail early with a
    ## clear message instead of a TypeError from os.path.join(None, ...).
    raise RuntimeError("Environment variable __DATASET_ROOT__ is not set")

  dataset_name = args.dataset
  split_names = args.split
  if not split_names:
    ## Nothing to plot: the donut chart and the mosaic would fail on empty input
    log.info("No dataset splits requested; nothing to do")
    return

  ## create directories
  timestamp = datetime.now().strftime("%d%m%y_%H%M%S")
  output_dir = Path(args.to_path or f"logs/{timestamp}")
  plots_dir = output_dir / "plots"
  plots_dir.mkdir(parents=True, exist_ok=True)

  splits = get_splits(dataset_name, args.datasetcfg, __dataset_root__)

  splits_metadata = {}  ## split name -> metadata returned by process_dataset
  plot_paths = []
  for split_name in split_names:
    metadata = process_dataset(args, splits, output_dir, flag=split_name, timestamp=timestamp)
    splits_metadata[split_name] = metadata
    plot_paths.append(plot_total_images_per_label(metadata, plots_dir, split_name))

  ## Generate and add split distribution plot to the mosaic
  plot_paths.append(plot_split_distribution(splits_metadata, plots_dir))

  ## Create final mosaic
  mosaic_path = create_mosaic(plot_paths, output_dir)
  log.info(f"Mosaic saved to {mosaic_path}")


def parse_args(**kwargs):
  """
  Comprehensive argument parser for dataset analysis.

  Keyword Args:
    argv: Optional list of argument strings to parse instead of
      ``sys.argv[1:]`` (e.g. for tests or programmatic use). Other keyword
      arguments are ignored.

  Returns:
    argparse.Namespace in which ``input_size`` is a tuple of two positive ints
    and ``split`` has been converted with ``string_to_list``.

  Exits with status 1 (after logging an error) when ``--input_size`` is invalid.
  """
  import argparse
  import ast
  parser = argparse.ArgumentParser(description='Input parser', formatter_class=argparse.RawTextHelpFormatter)

  parser.add_argument('--batch_size', type=int, default=1, help='batch size for dataloader')
  parser.add_argument('--dataset', default='100-driver-day-cam1', type=str, help='dataset name or path')
  parser.add_argument('--datasetcfg', default="data/100-driver-distracted-driving-datasets.yml", help='dataset configuration file path')
  parser.add_argument('--sample_limit', type=int, help='max samples per class (balanced sampling), for quick testing')
  parser.add_argument('--input_size', type=str, default='(224,224)', help="input size for the DNN")
  parser.add_argument('--split', type=str, default='test,val,train', help="Comma separated dataset splits to be used for evaluating on.")

  ## Output path argument
  parser.add_argument('--to', dest='to_path', type=str, help='Output directory for saving results')

  ## DataLoader configuration
  parser.add_argument('--num_workers', type=int, default=4, help='number of worker processes for data loading')

  ## argv=None makes argparse fall back to sys.argv[1:]
  args = parser.parse_args(kwargs.get('argv'))

  ## Validate and parse the input_size argument. literal_eval only evaluates
  ## Python literals, so it is safe on user input.
  try:
    input_size = ast.literal_eval(args.input_size)
  except (ValueError, SyntaxError, TypeError):
    input_size = None
  is_valid = (
    isinstance(input_size, (tuple, list))
    and len(input_size) == 2
    and all(isinstance(v, int) and not isinstance(v, bool) and v > 0 for v in input_size)
  )
  if not is_valid:
    log.error("Error: --input_size should be a tuple of two integers, e.g., '(224,224)'")
    ## SystemExit does not depend on the `site`-provided exit() builtin
    raise SystemExit(1)
  args.input_size = tuple(input_size)

  args.split = string_to_list(args.split)
  return args


def print_args(args):
  """Log a header line, then every parsed argument as ``name: value``."""
  log.info("Arguments:")
  for name, value in vars(args).items():
    log.info(f"{name}: {value}")


if __name__ == '__main__':
  ## Script entry point: parse CLI arguments, echo them to the log, run the analysis
  cli_args = parse_args()
  print_args(cli_args)
  main(cli_args)
