# main.py
import argparse
import os
import datetime
import random
import sys

import numpy as np
import torch

# Training entrypoint; called below as train(norm, args) and expected to return the best mAP.
from train import train


def str2bool(v):
    """Parse a command-line boolean flag value.

    Intended for use as an argparse ``type=`` callable, e.g.
    ``--use_head false``.

    Args:
        v: A ``bool`` (returned unchanged, so non-string argparse defaults such
            as ``default=True`` pass straight through) or any value whose
            ``str()`` is one of the recognized spellings below. Matching is
            case-insensitive and ignores surrounding whitespace.

    Returns:
        True for "1", "true", "t", "yes", "y", "on";
        False for "0", "false", "f", "no", "n", "off".

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    if isinstance(v, bool):
        return v
    text = str(v).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    # Reject everything else instead of defaulting to False, so a typo such as
    # "ture" is reported by argparse rather than silently disabling the flag.
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (true/false, yes/no, 1/0, on/off), got {v!r}"
    )


def parse_args(argv=None):
    """Build the command-line parser and parse options for one CGMN training run.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests or programmatic launches).
            The default ``None`` keeps the normal command-line behavior.

    Returns:
        argparse.Namespace with one attribute per option defined below.

    Exits via ``parser.error`` (SystemExit, status 2) when ``--batch_size``
    is not positive or ``--loss_ratio`` lies outside [0, 1].
    """
    parser = argparse.ArgumentParser(
        description="CGMN training (single-run, no Optuna).",
        # The --ablate help below relies on explicit line breaks; the default
        # HelpFormatter would re-wrap the text and discard them.
        formatter_class=argparse.RawTextHelpFormatter,
    )

    # Runtime / IO
    parser.add_argument("--gpu", type=int, default=0, help="GPU id to use")
    parser.add_argument("--data_path", type=str, default="./data", help="Path to preprocessed npy/csv data")
    parser.add_argument("--label_dir", type=str, default="", help="Optional directory for label files (emo_def.npy, etc.)")
    parser.add_argument("--output_dir", type=str, default="./outputs", help="Directory to save logs/metrics")

    # Training
    parser.add_argument("--epochs", type=int, default=30, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=52, help="Mini-batch size")
    parser.add_argument("--seed", type=int, default=921, help="Random seed")
    parser.add_argument("--use_head", type=str2bool, default=True, help="Whether to use head modality")
    parser.add_argument(
        "--ablate",
        type=str,
        default="none",
        choices=[
            "none",
            "drop_head", "drop_context", "drop_body", "drop_object", "drop_depth",
            "wodasor", "wohegr", "womlp",
            "wo_sem_1024", "wo_cooccur_1024",
        ],
        help=(
            "Ablation setting:\n"
            "  Modal ablations: drop_head | drop_context | drop_body | drop_object | drop_depth\n"
            "  Structural ablations:\n"
            "    wodasor (w/o distance-weighted object pooling / DASOR)\n"
            "    wohegr (no graph, MLP only)\n"
            "    womlp (no MLP, graph only)\n"
            "  Special:\n"
            "    wo_sem_1024 (semantic-only GCN, 1024-d, no label loss)\n"
            "    wo_cooccur_1024 (co-occur-only GCN, 1024-d, no label loss)"
        ),
    )

    # Loss
    parser.add_argument("--loss_ratio", type=float, default=0.8, help="Weight for classification loss (rest for label loss)")

    # Optimizer / LR
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate for CGMN head")
    parser.add_argument("--lr_ratio", type=float, default=1e-2, help="Backbone LR = lr * lr_ratio")
    parser.add_argument("--wd", type=float, default=1e-4, help="Weight decay")

    # Scheduler
    parser.add_argument("--scheduler_type", type=str, default="step", choices=["step", "cosine", "cosine_restart"])
    parser.add_argument("--step_size", type=int, default=7, help="StepLR: step_size")
    parser.add_argument("--gamma", type=float, default=0.5, help="StepLR: gamma")
    parser.add_argument("--T_max", type=int, default=20, help="CosineAnnealingLR: T_max")
    parser.add_argument("--T_0", type=int, default=10, help="CosineAnnealingWarmRestarts: T_0")
    parser.add_argument("--T_mult", type=int, default=2, help="CosineAnnealingWarmRestarts: T_mult")
    parser.add_argument("--eta_min", type=float, default=1e-6, help="Cosine schedulers: eta_min")

    # Label graph thresholds
    parser.add_argument("--t_sem", type=float, default=0.8, help="Threshold for semantic similarity adjacency")
    parser.add_argument("--t_cooccur", type=float, default=0.3, help="Threshold for co-occurrence adjacency")
    parser.add_argument("--p", type=float, default=0.5, help="Sparsification power / normalization factor")

    # Augmentation
    parser.add_argument(
        "--augmentation_strategy",
        type=str,
        default="standard",
        choices=["standard", "aggressive", "emotion_focused", "minimal"],
        help="Data augmentation preset",
    )

    args = parser.parse_args(argv)

    # Fail fast with a usage message instead of deep inside training.
    if args.batch_size < 1:
        parser.error(f"--batch_size must be >= 1, got {args.batch_size}")
    # The label loss gets the remaining (1 - loss_ratio) weight, so values
    # outside [0, 1] would give one of the two losses a negative weight.
    if not 0.0 <= args.loss_ratio <= 1.0:
        parser.error(f"--loss_ratio must be in [0, 1], got {args.loss_ratio}")

    return args


def set_all_seeds(seed: int):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG, PyTorch's CPU
    generator, and PyTorch's CUDA generators on all devices with the same
    ``seed``. This only seeds generators; it does not switch cuDNN or other
    kernels into deterministic mode.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # The generators are independent, so the seeding order is irrelevant.
    for seeder in seeders:
        seeder(seed)


if __name__ == "__main__":
    # Script entry: parse options, prepare the output dir, seed RNGs, log the
    # run configuration, then run a single training job via train(norm, args).
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    set_all_seeds(args.seed)

    # Normalization stats for each modality (mean/std triplets).
    # Keep aligned with dataset preprocessing.
    context_mean = [0.47, 0.44, 0.41]
    context_std  = [0.25, 0.24, 0.24]
    body_mean    = [0.44, 0.40, 0.37]
    body_std     = [0.25, 0.24, 0.23]
    head_mean    = [0.44, 0.40, 0.37]
    head_std     = [0.25, 0.24, 0.23]
    depth_mean   = [15.13, 15.13, 15.13]
    depth_std    = [9.54, 9.54, 9.54]

    context_norm = [context_mean, context_std]
    body_norm    = [body_mean, body_std]
    head_norm    = [head_mean, head_std]
    depth_norm   = [depth_mean, depth_std]
    # NOTE(review): order is context, body, depth, head (not the definition
    # order above); presumably this is the order train() expects — confirm
    # against train.py before reordering.
    norm = [context_norm, body_norm, depth_norm, head_norm]

    print("=" * 80)
    print("CGMN training (single run)")
    print("=" * 80)
    print(f"Start time: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"GPU: {args.gpu}")
    print(f"Data path: {args.data_path}")
    print(f"Label dir: {args.label_dir or '(inherit from data_path / current dir)'}")
    print(f"Output dir: {args.output_dir}")
    print(f"Seed: {args.seed}")
    print(f"Use head: {args.use_head}")
    print(f"Ablation: {args.ablate}")
    print(f"Epochs: {args.epochs} | Batch size: {args.batch_size}")
    print(f"LR: {args.lr} | LR ratio: {args.lr_ratio} | WD: {args.wd}")
    print(f"Loss ratio: {args.loss_ratio}")
    print(f"Scheduler: {args.scheduler_type}")
    print(f"t_sem: {args.t_sem} | t_cooccur: {args.t_cooccur} | p: {args.p}")
    print(f"Augmentation: {args.augmentation_strategy}")
    print("=" * 80)

    try:
        best_map = train(norm, args)
    except Exception as e:
        # Top-level boundary: leave a clear marker on stderr, then re-raise so
        # the full traceback and a non-zero exit status are preserved.
        print(f"[Error] Training failed: {e}", file=sys.stderr)
        raise

    # Reporting stays outside the try so a problem formatting the result is
    # not misreported as a training failure.
    print("\n" + "=" * 80)
    print("Training finished")
    print("=" * 80)
    print(f"End time: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    if best_map is None:
        print("Best mAP: (not reported by train)")
    else:
        print(f"Best mAP: {best_map:.6f}")
