"""
Train network using PyTorch for the 100-Driver dataset.

Adapted from: https://github.com/Shenqishaonv/100-Driver-Source
* also the conf directory, models directory and cutils.py module

@author: wenjing
@article{100-Driver,
author    = {Wang Jing, Li Wengjing, Li Fang, Zhang Jun, Wu Zhongcheng, Zhong Zhun and Sebe Nicu},
title     = {100-Driver: A Large-scale, Diverse Dataset for Distracted Driver Classification},
journal={IEEE Transactions on Intelligent Transportation Systems},
year      = {2023}
publisher={IEEE}}
"""
__author__ = 'XYZ'


import os
import json
import sys
import time

from collections import Counter
from datetime import datetime

import pandas as pd
from tqdm import tqdm

try:
  import torch
  import torch.nn as nn
  import torch.optim as optim

  ## clear memory
  torch.cuda.empty_cache()
  ## Adjust this as needed - not needed for dynamic growth
  # torch.cuda.set_per_process_memory_fraction(0.25, 0)
except ImportError:
  print('torch is not installed')


from .core._log_ import logger
log = logger(__file__)


from . import nnloss

from .dataset import get_splits, load_dataset
from .utils.lossmetrics import compute_loss_and_topk
from .utils.modelstats import (
  get_device,
  gather_system_info,
  model_summary,
  model_perfstats,
  save_model_summary,
  save_model_perfstats,
  save_model_architecture,
  save_report,
)
from .utils.classificationmetrics import (
  save_confusion_matrix_with_fp_fn_imagelist,
  plot_curves,
  plot_and_save_curves,
)
from .utils.tensorboard_logger import TensorBoardLogger
from .utils.torchapi import loadmodel
from .utils.sysutils import save_system_info, save_timings
from .conf import settings
from .cutils import WarmUpLR

def write_csv(to_path, data, header=None):
  """Append one row to a CSV file, emitting the header only for a new file.

  Args:
    to_path: destination CSV path; the file is opened in append mode.
    data: sequence of values forming a single row.
    header: optional column names for the row; pandas falls back to
      positional column labels when it is None.
  """
  row_frame = pd.DataFrame([data], columns=header)
  ## The header row is only needed when the file does not exist yet.
  is_new_file = not os.path.exists(to_path)
  row_frame.to_csv(to_path, mode='a', header=is_new_file, index=False)


def train(epoch, net, train_loader, optimizer, loss_function, warmup_scheduler, train_dataset, sample_limit, num_classes, args, description_text='Training', topk_values=(1, 3, 5), score_level=10):
  """Train the network for one epoch and return averaged metrics.

  Args:
    epoch: epoch number; shown in the progress bar and compared against
      ``args.warm`` to decide whether the warm-up scheduler is stepped.
    net: model to optimise; switched to training mode here.
    train_loader: iterable of batches whose first two items are
      ``(images, labels)``; any further items in a batch are ignored.
    optimizer: stepped once per batch whose loss and backward pass succeed.
    loss_function: forwarded to ``compute_loss_and_topk``.
    warmup_scheduler: stepped after every successful batch while
      ``epoch <= args.warm``.
    train_dataset: unused; kept for interface compatibility.
    sample_limit: when truthy, the epoch stops as soon as
      ``batch_index * len(images)`` reaches this value.
    num_classes: forwarded to ``compute_loss_and_topk``.
    args: namespace providing ``gpu`` and ``warm``.
    description_text: label used in the progress bar.
    topk_values: k values for which top-k accuracy is accumulated. Top-1 is
      always tracked as well because it is part of the return value.
    score_level: forwarded to ``compute_loss_and_topk``.

  Returns:
    Tuple ``(train_loss, top1_accuracy, topk_accuracy)``. ``train_loss`` is the
    mean loss over the batches that were actually processed (NaN when there
    were none) and ``topk_accuracy`` maps each tracked k to its accuracy
    (0.0 when no sample was processed).
  """
  net.train()
  loss_train = 0.0
  score_dis_train = 0.0
  score_train = 0.0
  total = 0
  num_batches = 0

  ## Top-1 is returned unconditionally, so make sure it is accumulated even
  ## when the caller did not list it; otherwise pass the caller's object as-is.
  tracked_topk = topk_values if 1 in topk_values else (1, *topk_values)
  correct_topk = {k: 0 for k in tracked_topk}

  for batch_index, (images, labels, *_) in enumerate(tqdm(train_loader, total=len(train_loader), desc=f"{description_text} Epoch {epoch}", leave=False)):
    if sample_limit and (batch_index * len(images)) >= sample_limit:
      break

    if args.gpu:
      labels = labels.cuda()
      images = images.cuda()

    outputs = net(images)
    ## Some models wrap their prediction in an object exposing `.logits`.
    logits = outputs.logits if hasattr(outputs, 'logits') else outputs

    optimizer.zero_grad()

    try:
      result = compute_loss_and_topk(loss_function, logits, labels, num_classes, topk_values=tracked_topk, score_level=score_level)

      total += result['total']
      loss_train += result['loss'].item()
      num_batches += 1
      score_dis_train += result['score_dis'].item() if result['score_dis'] is not None else 0.0
      score_train += result['score'].item() if result['score'] is not None else 0.0

      for k in tracked_topk:
        correct_topk[k] += result['correct_topk'][k]

      result['loss'].backward()
      optimizer.step()

    except Exception as e:
      ## Best effort: report the failing batch and keep training on the rest.
      print(f"[!] Training loss computation failed: {e}")
      continue

    if epoch <= args.warm:
      warmup_scheduler.step()

  if num_batches == 0:
    print(f"[!] {description_text} - Epoch {epoch}: no batch was processed; metrics are undefined")

  ## Average over the batches/samples that actually contributed, so that
  ## `sample_limit` and skipped batches do not deflate the reported values.
  train_loss = loss_train / num_batches if num_batches else float('nan')
  topk_accuracy = {k: (correct_topk[k] / total if total else 0.0) for k in tracked_topk}

  ## Print metrics
  topk_str = " | ".join([f"Top-{k}: {acc:.4f}" for k, acc in topk_accuracy.items()])
  print(f"Training - Epoch {epoch} | Loss: {train_loss:.4f} | {topk_str}")

  if score_dis_train > 0:
    print(f"         ScoreDis: {score_dis_train / num_batches:.4f} | Score: {score_train / num_batches:.4f}")

  return train_loss, topk_accuracy[1], topk_accuracy


def validate(epoch, net, val_loader, loss_function, val_dataset, sample_limit, num_classes, args, description_text='Validation', topk_values=(1, 3, 5), score_level=10):
  """Evaluate the network on one data split without updating weights.

  Args:
    epoch: epoch number; used only for progress-bar and console labels.
    net: model to evaluate; switched to evaluation mode here.
    val_loader: iterable of batches whose first two items are
      ``(images, labels)``; any further items in a batch are ignored.
    loss_function: forwarded to ``compute_loss_and_topk``.
    val_dataset: unused; kept for interface compatibility.
    sample_limit: when truthy, evaluation stops as soon as
      ``batch_idx * len(images)`` reaches this value.
    num_classes: forwarded to ``compute_loss_and_topk``.
    args: namespace providing ``gpu``.
    description_text: label for the progress bar and the printed summary.
    topk_values: k values for which top-k accuracy is accumulated. Top-1 is
      always tracked as well because it is part of the return value.
    score_level: forwarded to ``compute_loss_and_topk``.

  Returns:
    Tuple ``(val_loss, top1_accuracy, topk_accuracy)``. ``val_loss`` is the
    mean loss over the batches that were actually processed (NaN when there
    were none) and ``topk_accuracy`` maps each tracked k to its accuracy
    (0.0 when no sample was processed).
  """
  net.eval()
  loss_sum = 0.0
  score_dis_val = 0.0
  score_val = 0.0
  total = 0
  num_batches = 0

  ## Top-1 is returned unconditionally, so make sure it is accumulated even
  ## when the caller did not list it; otherwise pass the caller's object as-is.
  tracked_topk = topk_values if 1 in topk_values else (1, *topk_values)
  correct_topk = {k: 0 for k in tracked_topk}

  with torch.no_grad():
    for batch_idx, (images, labels, *_) in enumerate(tqdm(val_loader, total=len(val_loader), desc=f"{description_text} Epoch {epoch}", leave=False)):
      if sample_limit and (batch_idx * len(images)) >= sample_limit:
        break

      if args.gpu:
        images = images.cuda()
        labels = labels.cuda()

      outputs = net(images)
      ## Some models wrap their prediction in an object exposing `.logits`.
      logits = outputs.logits if hasattr(outputs, 'logits') else outputs

      try:
        result = compute_loss_and_topk(loss_function, logits, labels, num_classes, topk_values=tracked_topk, score_level=score_level)

        total += result['total']
        loss_sum += result['loss'].item()
        num_batches += 1
        score_dis_val += result['score_dis'].item() if result['score_dis'] is not None else 0.0
        score_val += result['score'].item() if result['score'] is not None else 0.0

        for k in tracked_topk:
          correct_topk[k] += result['correct_topk'][k]

      except Exception as e:
        ## Best effort: report the failing batch and keep evaluating the rest.
        print(f"[!] Validation loss computation failed: {e}")
        continue

  if num_batches == 0:
    print(f"[!] {description_text} - Epoch {epoch}: no batch was processed; metrics are undefined")

  ## Average over the batches/samples that actually contributed, so that
  ## `sample_limit` and skipped batches do not deflate the reported values.
  val_loss = loss_sum / num_batches if num_batches else float('nan')
  topk_accuracy = {k: (correct_topk[k] / total if total else 0.0) for k in tracked_topk}

  ## Print metrics
  topk_str = " | ".join([f"Top-{k}: {acc:.4f}" for k, acc in topk_accuracy.items()])
  print(f"{description_text} - Epoch {epoch} | Loss: {val_loss:.4f} | {topk_str}")

  if score_dis_val > 0:
    print(f"         ScoreDis: {score_dis_val / num_batches:.4f} | Score: {score_val / num_batches:.4f}")

  return val_loss, topk_accuracy[1], topk_accuracy


def evaluate_model(model, dataloader, device):
  """Run inference over a dataloader and collect logits, labels and timing.

  Args:
    model: network to evaluate; switched to evaluation mode here.
    dataloader: iterable of batches whose first two items are
      ``(images, labels)``; any further items in a batch are ignored.
    device: target passed to ``Tensor.to`` for images and labels.

  Returns:
    Tuple ``(logits, labels, total_inference_time)`` where the first two are
    the per-batch results concatenated on the CPU and the last is the summed
    wall-clock time, in seconds, spent in the forward passes only.
  """
  all_logits, all_labels = [], []
  total_inference_time = 0.0

  model.eval()
  with torch.no_grad():
    with tqdm(total=len(dataloader), desc="Evaluating", unit="batch") as pbar:
      for images, labels, *_ in dataloader:
        images, labels = images.to(device), labels.to(device)

        ## CUDA work is queued asynchronously: synchronise before and after
        ## the forward pass so the timer measures the computation itself.
        on_cuda = images.is_cuda
        if on_cuda:
          torch.cuda.synchronize()
        start_time = time.perf_counter()
        outputs = model(images)
        if on_cuda:
          torch.cuda.synchronize()
        total_inference_time += time.perf_counter() - start_time

        ## Same unwrapping as in train()/validate(): some models return an
        ## object exposing `.logits` instead of a bare tensor.
        logits = outputs.logits if hasattr(outputs, 'logits') else outputs

        all_logits.append(logits.cpu())
        all_labels.append(labels.cpu())
        pbar.update(1)

  return torch.cat(all_logits), torch.cat(all_labels), total_inference_time


def main(args):
  """Run the complete train / validate / test / report pipeline.

  Reads from ``args``: ``sample_limit``, ``loss``, ``topk_values`` (optional),
  ``net``, ``to_path``, ``gpu``, ``dataset``, ``datasetcfg``, ``batch_size``,
  ``score_level``, ``lr``, ``warm``, ``epochs``, ``resume``, ``resume_epoch``,
  ``num_iterations``, ``input_size`` and ``weights_path``. ``args.num_class``
  is overwritten with the class count found in the training split metadata.

  Everything is written below the output directory (``args.to_path`` or a
  timestamped ``logs/train-.../<net>`` folder): system info, TensorBoard
  events, one CSV per split, checkpoints, model architecture/summary/perf
  statistics, the final report and the timing statistics.

  Raises:
    ValueError: if ``topk_values`` is not a tuple/list of positive integers.
    SystemExit: with status 1 if the loss function cannot be initialised.
  """
  sample_limit = args.sample_limit
  loss_fname = args.loss

  topk_values = getattr(args, 'topk_values', (1, 3, 5))
  ## Validate explicitly; a bare `assert` would be stripped under `python -O`.
  if not isinstance(topk_values, (tuple, list)) or not all(isinstance(k, int) and k > 0 for k in topk_values):
    raise ValueError(f"topk_values must be a tuple of positive integers, got {topk_values!r}")
  ## Top-1 feeds the 'accuracy' CSV column and the TensorBoard scalars below,
  ## so it is always tracked even if the caller left it out.
  topk_values = sorted(set(topk_values) | {1})


  to_path_default = os.path.join('logs', f"train-{datetime.now().strftime('%d%m%y_%H%M%S')}", args.net)
  to_path = args.to_path or to_path_default
  os.makedirs(to_path, exist_ok=True)

  device = 'cuda' if args.gpu and torch.cuda.is_available() else 'cpu'
  print(torch.version.cuda)
  print(f'loadmodel::device: {device}')
  print(f'loadmodel::args.gpu: {args.gpu}')

  ## Record system information
  save_system_info(to_path)

  ## NOTE(review): the model is built before args.num_class is patched with the
  ## dataset's class count further down - confirm loadmodel() does not rely on it.
  net = loadmodel(args)

  ## Initialize TensorBoard Logger also initialise directory creation
  tensorlog = TensorBoardLogger(log_dir=to_path)
  log.info(f"TensorBoard logs at: {tensorlog.log_dir}")

  ## Dataset Preparation
  __dataset_root__ = os.getenv('__DATASET_ROOT__')
  splits = get_splits(args.dataset, args.datasetcfg, __dataset_root__)

  train_loader, train_dataset, train_label_counts, train_dataset_info = load_dataset(args, splits, flag='train')
  val_loader, val_dataset, val_label_counts, val_dataset_info = load_dataset(args, splits, flag='val')
  test_loader, test_dataset, test_label_counts, test_dataset_info = load_dataset(args, splits, flag='test')

  ## Derive number of classes from dataset metadata
  actual_num_classes = len(train_dataset_info['label_to_index'])
  batch_size = args.batch_size
  args.num_class = actual_num_classes  ## patch args to ensure downstream consistency
  score_level = args.score_level

  ## Loss, Optimizer, and Schedulers
  ## Prepare loss config from dataset info instead of args
  loss_cfg = {
    'batch_size': batch_size,
    'classes': actual_num_classes,
    'device': device,
    'dtype': torch.float32,
    'score_level': score_level,
  }

  try:
    loss_function = nnloss.get_loss_function(loss_fname, config=loss_cfg)
  except Exception as e:
    log.error(f"Failed to initialize loss function '{loss_fname}': {e}")
    sys.exit(1)

  optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
  train_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=settings.MILESTONES, gamma=0.2)
  warmup_scheduler = WarmUpLR(optimizer, len(train_loader) * args.warm)
  best_val_accuracy = 0.0

  ## Set up CSV file paths and headers
  train_csv_filepath = os.path.join(to_path, f"{args.net}-{args.dataset}-train.csv")
  val_csv_filepath = os.path.join(to_path, f"{args.net}-{args.dataset}-val.csv")
  test_csv_filepath = os.path.join(to_path, f"{args.net}-{args.dataset}-test.csv")
  ## 'accuracy' holds top-1; every other tracked k gets its own column.
  header = ['start', 'end', 'duration_in_seconds', 'epoch', 'loss', 'accuracy'] + [f'top{k}_accuracy' for k in topk_values if k != 1]

  ## Record training start time
  training_start = time.time()

  ## Training and Validation Loop
  ## The logger is closed in `finally` so it is released even when an epoch raises.
  try:
    for epoch in range(1, args.epochs + 1):
      start_time_epoch = time.time()

      if epoch > args.warm:
        train_scheduler.step(epoch)
      ## When resuming, already-completed epochs are skipped after the scheduler step above.
      if args.resume and epoch <= args.resume_epoch:
        continue

      ## train
      train_loss, train_accuracy, train_topk_accuracy = train(epoch, net, train_loader, optimizer, loss_function, warmup_scheduler, train_dataset, sample_limit, actual_num_classes, args, description_text='Training', topk_values=topk_values, score_level=score_level)

      end_time_epoch = time.time()
      duration_epoch = end_time_epoch - start_time_epoch

      tensorlog.train_metrics(train_loss, train_accuracy, epoch)
      tensorlog.topk_metrics("Train", train_topk_accuracy, epoch)

      train_row = [start_time_epoch, end_time_epoch, duration_epoch, epoch, train_loss, train_accuracy] + [train_topk_accuracy[k] for k in topk_values if k != 1]
      write_csv(train_csv_filepath, train_row, header)

      ## val (rows reuse the timestamps of the training pass of this epoch)
      val_loss, val_accuracy, val_topk_accuracy = validate(epoch, net, val_loader, loss_function, val_dataset, sample_limit, actual_num_classes, args, description_text='Validation', topk_values=topk_values, score_level=score_level)

      tensorlog.val_metrics(val_loss, val_accuracy, epoch)
      tensorlog.topk_metrics("Val", val_topk_accuracy, epoch)
      val_row = [start_time_epoch, end_time_epoch, duration_epoch, epoch, val_loss, val_accuracy] + [val_topk_accuracy[k] for k in topk_values if k != 1]
      write_csv(val_csv_filepath, val_row, header)

      ## test
      test_loss, test_accuracy, test_topk_accuracy = validate(epoch, net, test_loader, loss_function, test_dataset, sample_limit, actual_num_classes, args, description_text='Testing', topk_values=topk_values, score_level=score_level)

      tensorlog.test_metrics(test_loss, test_accuracy, epoch)
      tensorlog.topk_metrics("Test", test_topk_accuracy, epoch)

      test_row = [start_time_epoch, end_time_epoch, duration_epoch, epoch, test_loss, test_accuracy] + [test_topk_accuracy[k] for k in topk_values if k != 1]
      write_csv(test_csv_filepath, test_row, header)

      tensorlog.increment_epoch()

      ## Save Best Model Checkpoint
      ## Only epochs after the second milestone are considered for "best".
      if epoch % 1 == 0 and epoch > settings.MILESTONES[1] and val_accuracy > best_val_accuracy:
        log.info("New best model after milestone! Saving checkpoint...")
        torch.save(net.state_dict(), os.path.join(to_path, f"{args.net}-epoch_{epoch}-best.pth"))
        best_val_accuracy = val_accuracy

    ## Save Final Model in PyTorch format
    torch.save(net.state_dict(), os.path.join(to_path, f"{args.net}-final.pth"))

    ## Record training end time
    training_end = time.time()
  finally:
    ## Close TensorBoard Logger
    tensorlog.close()

  # ## Save Final Model in ONNX format
  # d_input = torch.randn(1, 3, *args.input_size).to(device)
  # torch.onnx.export(net, d_input, os.path.join(to_path, f"{args.net}-final.onnx"), verbose=True)

  ## ----------------------------------------------------------
  ## Record evaluation start time
  ## ----------------------------------------------------------
  eval_start = time.time()

  save_dir = to_path
  ## NOTE(review): `net` still holds the last-epoch weights at this point; no
  ## "-best.pth" checkpoint is reloaded before the evaluation below.
  log.info("Evaluating best model...on the test set::")

  save_model_architecture(net, save_dir, args.net)

  log.info(f"Running inference for model: {args.net}...")


  ## Get class names from dataset_info
  index_to_label = test_dataset_info['index_to_label']
  class_names = list(index_to_label.values())

  logits, labels, _ = evaluate_model(net, test_loader, get_device(args))

  ## TODO: fix for SSoftmax
  metrics = []
  # ## Visualization
  # metrics = save_confusion_matrix_with_fp_fn_imagelist(logits, labels, class_names, save_dir, test_dataset)
  # plot_curves(logits, labels, class_names, save_dir)
  # plot_and_save_curves(logits, labels, class_names, save_dir)

  num_iterations = args.num_iterations
  ## Generate Model Summary based on final trained model
  log.info("Generating model summary...")
  key_stats, summary_info = model_summary(
    model=net,
    input_size=(3, *args.input_size),  ## Use args.input_size for summary
    device=device,
    verbose=False,
    num_iterations=num_iterations,
    weights_path=args.weights_path or None,
    dnnarch=args.net,
    num_class=len(class_names),
    depth=10,
  )
  save_model_summary(to_path, key_stats, summary_info)
  log.info(f"Model summary saved at: {to_path}")

  ## Generate Model stats based on final trained model
  log.info("Generating model information...")
  perfstats = model_perfstats(
    model=net,
    input_size=(3, *args.input_size),
    device=device,
    verbose=False,
    num_iterations=num_iterations,
    weights_path=args.weights_path or None,
    dnnarch=args.net,
    num_class=len(class_names),
  )
  save_model_perfstats(to_path, perfstats)
  log.info(f"Model perfstats: {perfstats}")

  ## Save report using actual class names
  report_data = {
    "system_info": gather_system_info(),
    "metrics": metrics,
    "dataset_info": {
      "train": train_dataset_info,
      "val": val_dataset_info,
      "test": test_dataset_info,
    },
    "model_info": {
      "architecture": net.__class__.__name__,
      "total_layers": len(list(net.modules())),
      "weights_path": args.weights_path or "N/A",
      "input_size": args.input_size,
      "creator": __author__,
    },
    "modelsummary": key_stats,
    "modelperfstats": perfstats,
    "score_level": score_level,
    "loss_fname": loss_fname,
  }

  save_report(report_data, save_dir)

  ## Record evaluation end time
  eval_end = time.time()

  ## Save timing stats
  save_timings(to_path, training_start, training_end, eval_start, eval_end)


def parse_args(argv=None, **kwargs):
  """Build the command-line parser and return the validated arguments.

  Args:
    argv: optional list of argument strings to parse; ``sys.argv[1:]`` is
      used when it is None.
    **kwargs: accepted for backward compatibility and ignored.

  Returns:
    argparse.Namespace in which ``input_size`` has been converted from its
    string form to a tuple of two positive integers.

  Exits with status 1 when ``--input_size`` is not a pair of positive
  integers; argparse itself exits with status 2 on any other invalid input.
  """
  import argparse
  import ast

  from argparse import RawTextHelpFormatter

  def _topk(text):
    """Convert a comma-separated string such as '1,3,5' to a tuple of ints."""
    try:
      values = tuple(int(part) for part in text.split(','))
    except ValueError:
      raise argparse.ArgumentTypeError(f"expected comma-separated integers, got {text!r}") from None
    if any(k <= 0 for k in values):
      raise argparse.ArgumentTypeError(f"top-k values must be positive, got {text!r}")
    return values

  parser = argparse.ArgumentParser(description='Input parser', formatter_class=RawTextHelpFormatter)
  parser.add_argument('--net', type=str, default='mobilenet_v2', required=True, help='net type')
  parser.add_argument('--loss', type=str, default='CrossEntropyLoss', help='define the lossfunction to be used')
  parser.add_argument('--score_level', type=int, default=10, help='score_level used in scored lossfunction')
  parser.add_argument('--topk', dest='topk_values', type=_topk, default=(1, 3, 5), help='Comma-separated top-k values, e.g., 1,3,5')
  parser.add_argument('--gpu', action='store_true', default=True, help='use gpu or not')
  ## `--gpu` defaults to True, so on its own it can never switch the GPU off;
  ## this companion flag writes False to the same destination.
  parser.add_argument('--no_gpu', dest='gpu', action='store_false', help='run on cpu (disables --gpu)')
  parser.add_argument('--batch_size', type=int, default=64, help='batch size for dataloader')
  parser.add_argument('--warm', type=int, default=2, help='warm up training phase')
  parser.add_argument('--lr', type=float, default=0.01, help='initial learning rate')
  parser.add_argument('--resume', action='store_true', default=False, help='resume training')
  parser.add_argument('--resume_epoch', type=int, default=0)
  parser.add_argument('--pretrain', action='store_true', default=False, help='whether the pretrain model is used')
  parser.add_argument("--dataset", default='100-driver-day-cam1', type=str)
  parser.add_argument("--num_class", default=22, type=int)
  parser.add_argument('--epochs', type=int, default=100)
  parser.add_argument('--weights_path', type=str)
  parser.add_argument('--checkpoint_path', type=str)
  parser.add_argument('--to', type=str, dest='to_path', default=None, help='Output directory for saving results')
  parser.add_argument('--datasetcfg', default="data/ddd-datasets.yml", help='dataset configuration file path')
  parser.add_argument('--sample_limit', type=int, help='Limit number of samples per epoch for quick testing')
  parser.add_argument('--input_size', type=str, default='(224,224)', help="input size for the DNN")
  parser.add_argument('--num_iterations', type=int, default=100, help='total number of iterations for FPS calculations')

  ## DataLoader configuration
  parser.add_argument('--num_workers', type=int, default=4, help='number of worker processes for data loading')

  args = parser.parse_args(argv)

  ## Convert the --input_size argument from a string to a tuple
  try:
    input_size = ast.literal_eval(args.input_size)
  except (ValueError, SyntaxError, TypeError):
    input_size = None

  ## bool is a subclass of int, so it is excluded explicitly.
  is_valid = (
    isinstance(input_size, (tuple, list))
    and len(input_size) == 2
    and all(isinstance(v, int) and not isinstance(v, bool) and v > 0 for v in input_size)
  )
  if not is_valid:
    print("Error: --input_size should be a tuple of two integers, e.g., '(224,224)'.")
    sys.exit(1)
  args.input_size = tuple(input_size)

  print(f'parse_args:: type(args.input_size), args.input_size: {type(args.input_size), args.input_size}')
  return args


def print_args(args):
  """Print every attribute of ``args`` as a ``name: value`` line."""
  print("Arguments:")
  for name, value in vars(args).items():
    print(f"{name}: {value}")


if __name__ == '__main__':
  ## Script entry point: parse the CLI, echo the effective settings, then run the pipeline.
  cli_args = parse_args()
  print_args(cli_args)
  main(cli_args)
