"""Inference + evaluation pipeline: run a classifier on a dataset split and save predictions, metrics, and reports."""
__author__ = 'XYZ'


import ast
import argparse
import json
import os
import sys
import time

from datetime import datetime
from pathlib import Path

from tqdm import tqdm

import torch
import torch.nn.functional as F
from torchvision import transforms

from .core._log_ import logger
log = logger(__file__)

from .dataset import DataSet, get_dataloader, get_splits
from .utils.torchapi import loadmodel, unloadmodel
from .utils.modelstats import (
  gather_system_info,
  model_summary,
  model_perfstats,
  save_metrics,
  save_model_summary,
  save_model_perfstats,
  save_report,
)
from .utils.classificationmetrics import (
  save_confusion_matrix_with_fp_fn_imagelist,
  plot_curves,
  plot_and_save_curves,
)


## -----------------------------------------------------------------
## Internal implementations (standalone functions)
## -----------------------------------------------------------------
def _load_split(args, splits, split="test"):
  """Build the dataset and an unshuffled dataloader for one split (internal).

  Images are resized to ``args.input_size`` and converted to tensors. The
  sample list is taken from ``splits[f"{split}loadertxt"]`` (e.g. the
  ``"testloadertxt"`` entry) and the dataset root from ``splits['loadertxt']``.
  NOTE(review): using the 'loadertxt' entry as the *root* looks unusual --
  confirm against ``get_splits`` that this is the intended key.

  Returns:
    (loader, dataset): the dataloader (``shuffle=False``) and its dataset.
  """
  preprocess = transforms.Compose([transforms.Resize(args.input_size), transforms.ToTensor()])
  list_key = f"{split}loadertxt"

  dataset = DataSet(
    root=splits['loadertxt'],
    lists=splits[list_key],
    input_size=args.input_size,
    transform=preprocess,
    sample_limit=args.sample_limit,
  )

  # Keep sample order stable so outputs line up with dataset.imgs downstream.
  return get_dataloader(
    dataset=dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    shuffle=False,
  ), dataset


def _predict(model, dataloader, device):
  """Run inference over ``dataloader`` and collect outputs (internal).

  Args:
    model: module to evaluate; it is switched to ``eval()`` mode.
    dataloader: iterable of ``(images, labels, _)`` batches (third item ignored).
    device: device the images are moved to, e.g. ``"cuda"`` or ``"cpu"``.

  Returns:
    (logits, labels, total_time): logits and labels concatenated along dim 0
    on the CPU, and the summed wall-clock seconds spent in forward passes.

  Raises:
    RuntimeError: if ``dataloader`` yields no batches.
  """
  all_logits, all_labels = [], []
  total_time = 0.0
  # CUDA kernels launch asynchronously: without synchronizing, the timer would
  # only capture launch overhead rather than the forward pass itself.
  sync_cuda = torch.device(device).type == "cuda"

  model.eval()
  with torch.no_grad():
    with tqdm(total=len(dataloader), desc="Inferencing", unit="batch") as pbar:
      for images, labels, _ in dataloader:
        images = images.to(device)

        if sync_cuda:
          torch.cuda.synchronize(device)
        start_time = time.perf_counter()
        logits = model(images)
        if sync_cuda:
          torch.cuda.synchronize(device)
        total_time += time.perf_counter() - start_time

        all_logits.append(logits.cpu())
        # Labels are never used on-device, so skip the host->device->host trip.
        all_labels.append(labels.cpu())
        pbar.update(1)

  if not all_logits:
    raise RuntimeError("Dataloader yielded no batches; nothing to predict.")
  return torch.cat(all_logits), torch.cat(all_labels), total_time


def _write_index_files(logits, labels, dataset, to_path):
  """Write correct/incorrect index files, overall and split by ground-truth class.

  Each line has five tab-separated fields:
    sample index, relative path, ground-truth id, predicted id, confidence.
  The predicted id is the argmax of ``softmax(logits, dim=1)``. The confidence
  is that maximum probability, formatted to 4 decimals.

  Files written under ``to_path``, which is created if needed:
    correct.txt, incorrect.txt
    correct_per_class/c<gt>.txt, incorrect_per_class/c<gt>.txt
  Per-class files are keyed by ground-truth id. They are only written for ids
  that have at least one line.

  Args:
    logits: per-sample class scores; softmax/max are taken over dim 1.
    labels: ground-truth class ids, one per row of ``logits``.
    dataset: exposes ``imgs`` (sample paths, in the same order as ``labels``)
      and optionally ``root``. When ``root`` is set, paths are made relative to it.
    to_path: output directory.
  """
  save_dir = Path(to_path)
  save_dir.mkdir(parents=True, exist_ok=True)

  ## compute predictions and confidence
  probs = F.softmax(logits, dim=1)
  conf_vals, pred_ids = torch.max(probs, dim=1)

  # Convert to Python values once; per-element .item() calls are slow on large splits.
  gt_ids = [int(v) for v in labels.reshape(-1).tolist()]
  pr_ids = [int(v) for v in pred_ids.tolist()]
  scores = [float(v) for v in conf_vals.tolist()]

  ## relative paths (best effort: fall back to the path as given)
  root = getattr(dataset, "root", None)
  rel_paths = []
  for p in dataset.imgs:
    rel_p = p
    if root:
      try:
        rel_p = os.path.relpath(p, start=root)
      except (TypeError, ValueError, OSError):
        # e.g. a different drive on Windows, or an entry that is not a path.
        rel_p = p
    rel_paths.append(rel_p)

  ## build lines
  correct_lines, incorrect_lines = [], []
  correct_per_class, incorrect_per_class = {}, {}

  for i, (gt, pr, sc) in enumerate(zip(gt_ids, pr_ids, scores)):
    ln = f"{i}\t{rel_paths[i]}\t{gt}\t{pr}\t{sc:.4f}"
    if pr == gt:
      correct_lines.append(ln)
      correct_per_class.setdefault(gt, []).append(ln)
    else:
      incorrect_lines.append(ln)
      incorrect_per_class.setdefault(gt, []).append(ln)

  # Explicit UTF-8: image paths may contain non-ASCII characters that the
  # platform default encoding (e.g. cp1252 on Windows) cannot encode.
  ## global files
  (save_dir / "correct.txt").write_text("\n".join(correct_lines), encoding="utf-8")
  (save_dir / "incorrect.txt").write_text("\n".join(incorrect_lines), encoding="utf-8")

  ## per-class files
  per_class_outputs = (
    (save_dir / "correct_per_class", correct_per_class),
    (save_dir / "incorrect_per_class", incorrect_per_class),
  )
  for out_dir, groups in per_class_outputs:
    out_dir.mkdir(exist_ok=True)
    for cls_id, lines in groups.items():
      (out_dir / f"c{cls_id}.txt").write_text("\n".join(lines), encoding="utf-8")

  log.info(f"Wrote index files with predictions to {save_dir}")


def _save_predictions(logits, labels, dataset, to_path, split):
  """Persist raw model outputs for one split (internal).

  Written into ``to_path`` (created if needed):
    logits.pt, labels.pt -- the tensors as given, via ``torch.save``
    preds.json           -- argmax class id per sample (dim 1 of ``logits``)
    labels.json          -- ``labels`` as a JSON list
    metadata.json        -- split name, sample/class counts, label maps,
                            ``dataset.imgs`` and the logits shape
  It also writes the correct/incorrect index files via ``_write_index_files``.
  """
  out_dir = Path(to_path)
  out_dir.mkdir(parents=True, exist_ok=True)

  predicted = logits.argmax(dim=1)

  ## Save tensors
  for filename, tensor in (("logits.pt", logits), ("labels.pt", labels)):
    torch.save(tensor, out_dir / filename)

  def _dump_json(filename, payload):
    with open(out_dir / filename, "w") as fh:
      json.dump(payload, fh, indent=2)

  ## Save JSONs
  predicted_ids = predicted.tolist()
  label_ids = labels.tolist()
  _dump_json("preds.json", predicted_ids)
  _dump_json("labels.json", label_ids)

  ## Metadata
  idx_to_label = dataset.index_to_label
  _dump_json("metadata.json", {
    "split": split,
    "num_samples": len(dataset),
    "num_classes": len(idx_to_label),
    "class_names": list(idx_to_label.values()),
    "index_to_label": idx_to_label,
    "label_to_index": dataset.label_to_index,
    "paths": dataset.imgs,
    "logits_shape": list(logits.shape),
  })

  ## index-style files with confidence
  _write_index_files(logits, labels, dataset, out_dir)

  log.info(f"Predictions + logits saved to {out_dir}")


def _save_evaluations(args, model, logits, labels, dataset, to_path, split):
  """Compute and persist evaluation artefacts for one split (internal).

  Steps:
    1. Produce the confusion matrix with FP/FN image lists and the metric curves.
    2. Run the model-summary and perf-stat helpers on ``"cuda"`` when
       ``args.gpu`` is set and CUDA is available, otherwise on ``"cpu"``.
    3. Bundle system info, metrics, dataset info and model info into one report.
    4. Write the metrics, summaries, perf stats and report under ``to_path``.

  Returns:
    The report dict handed to ``save_report``.
  """
  use_cuda = args.gpu and torch.cuda.is_available()
  device = "cuda" if use_cuda else "cpu"
  names = list(dataset.index_to_label.values())

  ## Metrics & visualizations
  metrics = save_confusion_matrix_with_fp_fn_imagelist(logits, labels, names, to_path, dataset)
  for plotter in (plot_curves, plot_and_save_curves):
    plotter(logits, labels, names, to_path)

  ## Model summaries & perf stats share the same configuration
  stats_kwargs = dict(
    model=model,
    input_size=(3, *args.input_size),
    device=device,
    verbose=False,
    num_iterations=args.num_iterations,
    weights_path=args.weights_path,
    dnnarch=args.net,
    num_class=len(names),
  )
  key_stats, summary_info = model_summary(**stats_kwargs)
  perfstats = model_perfstats(**stats_kwargs)

  ## Bundle
  report_data = {
    "system_info": gather_system_info(),
    "metrics": metrics,
    "dataset_info": {
      split: {"num_samples": len(dataset), "num_classes": len(names), "class_names": names},
    },
    "model_info": {
      "architecture": model.__class__.__name__,
      "weights_path": args.weights_path,
      "input_size": args.input_size,
      "creator": __author__,
    },
    "modelsummary": key_stats,
    "modelperfstats": perfstats,
  }

  ## Persist
  save_metrics(to_path, metrics)
  save_model_summary(to_path, key_stats, summary_info)
  save_model_perfstats(to_path, perfstats)
  save_report(report_data, to_path)

  log.info(f"Evaluations completed. Results saved to {to_path}")
  return report_data


## -----------------------------------------------------------------
## Orchestrator-friendly wrappers (context-based)
## -----------------------------------------------------------------
def load_split(context, **kwargs):
  """Wrapper for orchestrator: load dataset split.

  Reads ``context["args"]`` and ``context["splits"]``. The split name comes
  from ``kwargs["split"]``. It falls back to ``args.split`` (then ``"test"``),
  so the loaded data matches the split name the save steps record.

  Sets ``context["dataloader"]``, ``context["dataset"]`` and
  ``context["split"]``, and returns ``context``.
  """
  args = context["args"]
  splits = context["splits"]
  split = kwargs.get("split", getattr(args, "split", "test"))
  dataloader, dataset = _load_split(args, splits, split)
  context["dataloader"] = dataloader
  context["dataset"] = dataset
  # Recorded so later steps label their outputs with the split actually loaded.
  context["split"] = split
  return context


def predict(context, **kwargs):
  """Orchestrator step: run ``_predict`` on the model/dataloader/device in ``context``.

  Stores the concatenated ``logits`` and ``labels`` and the summed forward-pass
  time (``total_time``) back into ``context``, then returns it.
  """
  outputs = _predict(context["model"], context["dataloader"], context["device"])
  for key, value in zip(("logits", "labels", "total_time"), outputs):
    context[key] = value
  return context


def save_predictions(context, **kwargs):
  """Wrapper for orchestrator: save predictions.

  Output directory resolution:
    1. ``context["to_path"]``, if an earlier step set it;
    2. otherwise ``args.to_path``;
    3. otherwise ``logs/inference-<ddmmyy_HHMMSS>``, the same fallback ``main`` uses.
  The resolved directory is stored back in ``context["to_path"]`` so later
  steps write to the same place.

  The split name is ``context["split"]`` when present, else ``args.split``.
  Returns ``context``.
  """
  args = context["args"]
  to_path = context.get("to_path") or args.to_path
  if not to_path:
    # --to is optional (defaults to None) and Path(None) would raise TypeError.
    timestamp = datetime.now().strftime("%d%m%y_%H%M%S")
    to_path = os.path.join("logs", f"inference-{timestamp}")
  context["to_path"] = to_path

  _save_predictions(
    context["logits"],
    context["labels"],
    context["dataset"],
    to_path,
    context.get("split", args.split),
  )
  return context


def save_evaluations(context, **kwargs):
  """Wrapper for orchestrator: save evaluations.

  Output directory resolution:
    1. ``context["to_path"]``, if an earlier step set it;
    2. otherwise ``args.to_path``;
    3. otherwise ``logs/inference-<ddmmyy_HHMMSS>``, the same fallback ``main`` uses.
  The resolved directory is stored back in ``context["to_path"]``.

  The split name is ``context["split"]`` when present, else ``args.split``.
  The report returned by ``_save_evaluations`` is stored in ``context["report"]``.
  Returns ``context``.
  """
  args = context["args"]
  to_path = context.get("to_path") or args.to_path
  if not to_path:
    # --to is optional (defaults to None) and Path(None) would raise TypeError.
    timestamp = datetime.now().strftime("%d%m%y_%H%M%S")
    to_path = os.path.join("logs", f"inference-{timestamp}")
  context["to_path"] = to_path

  context["report"] = _save_evaluations(
    args,
    context["model"],
    context["logits"],
    context["labels"],
    context["dataset"],
    to_path,
    context.get("split", args.split),
  )
  return context


## -----------------------------------------------------------------
## CLI entrypoint
## -----------------------------------------------------------------
def main(args):
  """Main pipeline: load_model → load_split → predict → save_predictions → save_evaluations.

  Results go to ``args.to_path`` or, if unset, ``logs/inference-<ddmmyy_HHMMSS>``.
  The dataset root is read from the ``__DATASET_ROOT__`` environment variable.
  Exits with status 1 if ``args.weights_path`` does not exist. Once the model
  is loaded, ``unloadmodel`` always runs, even if a later stage raises.
  """
  dataset_root = os.getenv("__DATASET_ROOT__")
  splits = get_splits(args.dataset, args.datasetcfg, dataset_root)

  timestamp = datetime.now().strftime("%d%m%y_%H%M%S")
  save_dir = Path(args.to_path or os.path.join("logs", f"inference-{timestamp}"))

  device = "cuda" if args.gpu and torch.cuda.is_available() else "cpu"

  if not os.path.exists(args.weights_path):
    log.error(f"Weights not found at {args.weights_path}")
    sys.exit(1)

  model = loadmodel(args).to(device)
  try:
    dataloader, dataset = _load_split(args, splits, split=args.split)
    logits, labels, total_time = _predict(model, dataloader, device)
    log.info(f"Forward passes over {len(dataset)} samples took {total_time:.3f}s")

    _save_predictions(logits, labels, dataset, save_dir, args.split)
    _save_evaluations(args, model, logits, labels, dataset, save_dir, args.split)
  finally:
    # Release the model even if loading data, inference or saving fails.
    unloadmodel(model)


def parse_args(argv=None):
  """Parse command-line arguments for inference + evaluation.

  Args:
    argv: argument list to parse. ``None`` (the default) parses
      ``sys.argv[1:]`` as before; passing a list is useful for tests and
      programmatic callers.

  Returns:
    argparse.Namespace whose ``input_size`` is a 2-tuple of positive ints.

  Exits with status 1 (after logging) when ``--input_size`` is malformed.
  """
  parser = argparse.ArgumentParser(description="Inference + evaluation module")
  parser.add_argument('--dataset', type=str, required=True)
  parser.add_argument('--datasetcfg', type=str, default="data/ddd-datasets.yml")
  parser.add_argument('--split', type=str, default="test")
  parser.add_argument('--weights_path', type=str, required=True)
  parser.add_argument('--net', type=str, required=True)
  parser.add_argument('--pretrain', action='store_true', default=False)
  parser.add_argument('--num_class', type=int, required=True)
  parser.add_argument('--loss', type=str, default="CrossEntropyLoss")
  parser.add_argument('--input_size', type=str, default="(224,224)")
  parser.add_argument('--batch_size', type=int, default=64)
  parser.add_argument('--num_workers', type=int, default=4)
  parser.add_argument('--sample_limit', type=int)
  # GPU is the default. '--gpu' is kept for backward compatibility (as a
  # store_true flag with default=True it could never turn the GPU off), and
  # '--cpu' / '--no-gpu' now force CPU execution.
  parser.add_argument('--gpu', dest='gpu', action='store_true', default=True)
  parser.add_argument('--cpu', '--no-gpu', dest='gpu', action='store_false',
                      help="run on CPU even if CUDA is available")
  parser.add_argument('--num_iterations', type=int, default=100)
  parser.add_argument('--to', dest='to_path', type=str)

  args = parser.parse_args(argv)

  try:
    input_size = ast.literal_eval(args.input_size)
  except (ValueError, TypeError, SyntaxError):
    input_size = None
  # literal_eval accepts any literal, so check shape and element types too
  # (bool is excluded explicitly because it is a subclass of int).
  is_valid = (
    isinstance(input_size, tuple)
    and len(input_size) == 2
    and all(isinstance(v, int) and not isinstance(v, bool) and v > 0 for v in input_size)
  )
  if not is_valid:
    log.error("Error: --input_size should be a tuple of two integers, e.g., '(224,224)'")
    sys.exit(1)
  args.input_size = input_size

  return args


def print_args(args):
  """Print an ``Arguments:`` header followed by one ``name: value`` line per parsed argument."""
  report = ["Arguments:"]
  report.extend(f"{name}: {value}" for name, value in vars(args).items())
  print("\n".join(report))


if __name__ == "__main__":
  ## Script entry: parse the CLI, echo the arguments, then run the pipeline.
  cli_args = parse_args()
  print_args(cli_args)
  main(cli_args)
