# yolo.py – YOLOv5 fine‑tuning wrapper with layer‑wise LR decay (LLRD)
# ----------------------------------------------------------------------------------
# This is a *full* drop‑in replacement for the previous yolo.py.  The only
# behavioural change is that layer‑wise learning‑rate decay can now be enabled by
# setting `llrd: true` in the YAML config (the `--llrd` flag is declared, but the
# wrapper parses an empty argument list, so only YAML values take effect).  If it is
# left turned off, the optimiser behaves exactly as before.
# ----------------------------------------------------------------------------------

import argparse
import yaml
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

from .wrapper import BaseModelWrapper
from .bd_yolo.yolo_attack import YOLOAdapter
from .bd_yolo.yolov5.models.yolo import Detect

import albumentations as A
from albumentations.pytorch import ToTensorV2
import math
import sys
import os
import json
from tqdm import tqdm
import cv2
import numpy as np

from .bd_yolo.utils import yolo_gpu_augment

from matplotlib import pyplot as plt
import matplotlib.patches as patches

# ----------------------------------------------------------------------------------
# Utility helpers
# ----------------------------------------------------------------------------------

def one_cycle(y1: float = 0.0, y2: float = 1.0, steps: int = 1000):
    """Build a half‑cosine ramp that runs from ``y2`` down to ``y1``.

    The returned callable maps a step index ``x`` to a value that equals ``y2``
    at ``x == 0`` and ``y1`` at ``x == steps``, following half a cosine period
    in between.  Note the direction: it is the reverse of the argument order.

    Args:
        y1: Value reached at ``x == steps``.
        y2: Value returned at ``x == 0``.
        steps: Number of steps spanned by the half period.

    Returns:
        A function ``f(x) -> float``.
    """
    span = y2 - y1

    def schedule(x):
        # Cosine blend weight: 1.0 at x == 0, 0.0 at x == steps.
        blend = (1 + math.cos(math.pi * x / steps)) / 2
        return blend * span + y1

    return schedule


# ----------------------------------------------------------------------------------
# Optimiser + Scheduler with optional Layer‑wise LR Decay (LLRD)
# ----------------------------------------------------------------------------------

def _llrd_layers(model: nn.Module) -> list:
    """Return the ordered sub‑modules that LLRD treats as depth levels.

    Parallel wrappers (``DataParallel`` / ``DistributedDataParallel``) are
    stripped first.  Then the chain of ``.model`` attributes is followed until
    an ``nn.Sequential`` is found; its entries are the layers.  If no such
    container exists, the direct children of the unwrapped model are used.
    """
    parallel_wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    base = model
    while isinstance(base, parallel_wrappers):
        base = base.module

    node = base
    for _ in range(8):  # bounded so a cyclic `.model` chain cannot loop forever
        inner = getattr(node, "model", None)
        if isinstance(inner, nn.Sequential):
            return list(inner)
        if not isinstance(inner, nn.Module) or inner is node:
            break
        node = inner
    return list(base.children())


def get_optimize_schedular(
    model: nn.Module,
    batch_size: int,
    nbs: int = 64,
    name: str = "AdamW",
    lr: float = 1e-3,
    momentum: float = 0.9,
    decay: float = 5e-4,
    lrf: float = 0.01,
    epochs: int = 100,
    cos_lr: bool = True,
    llrd: bool = False,
    llrd_decay: float = 0.8,
):
    """Return ``(optimizer, scheduler)`` with optional layer‑wise LR decay.

    Trainable parameters are classified into three buckets: biases, norm‑layer
    weights (both without weight decay) and everything else (with weight
    decay).  Every parameter ends up in exactly one param group.

    Without LLRD the optimiser gets three groups in the order
    bias / decay / no‑decay.  With LLRD each layer contributes up to three
    groups (same order), so biases and norm weights never inherit the weight
    decay of the weights they sit next to.  Layer ``i`` of ``n`` gets the LR
    multiplier ``llrd_decay ** (n - 1 - i)``; trainable parameters that belong
    to no layer are appended with multiplier ``1.0`` instead of being dropped.

    Each group carries the extra keys ``lr_mult``, ``initial_lr`` and (for bias
    groups) ``is_bias``, which the warm‑up code reads.

    Args:
        model: Module whose trainable parameters are optimised.
        batch_size: Actual batch size; scales the weight decay.
        nbs: Nominal batch size the ``decay`` value refers to.
        name: One of ``adam``, ``adamw``, ``rmsprop``, ``sgd`` (case‑insensitive).
        lr: Base learning rate (multiplier 1.0).
        momentum: Momentum, or ``beta1`` for Adam/AdamW.
        decay: Weight decay before batch‑size scaling.
        lrf: Final LR as a fraction of the initial LR.
        epochs: Number of scheduler steps over which the LR decays.
        cos_lr: Use the half‑cosine schedule instead of the linear one.
        llrd: Enable layer‑wise LR decay.
        llrd_decay: Per‑layer multiplicative LR factor (deeper = larger LR).

    Returns:
        Tuple ``(optimizer, scheduler)`` where the scheduler is a ``LambdaLR``.

    Raises:
        NotImplementedError: If ``name`` is not a supported optimiser.
    """
    # ------------------------------------------------------------------
    # 0.  Resolve the optimiser class first so a bad name fails fast
    # ------------------------------------------------------------------
    name_low = name.lower()
    if name_low == "adam":
        opt_cls = torch.optim.Adam
        opt_kwargs = {"lr": lr, "betas": (momentum, 0.999)}
    elif name_low == "adamw":
        opt_cls = torch.optim.AdamW
        opt_kwargs = {"lr": lr, "betas": (momentum, 0.999)}
    elif name_low == "rmsprop":
        opt_cls = torch.optim.RMSprop
        opt_kwargs = {"lr": lr, "momentum": momentum}
    elif name_low == "sgd":
        opt_cls = torch.optim.SGD
        opt_kwargs = {"lr": lr, "momentum": momentum, "nesterov": True}
    else:
        raise NotImplementedError(f"Optimizer {name} not supported.")

    # ------------------------------------------------------------------
    # 1.  Classify params into buckets   (bias | weight‑decay | no‑decay)
    # ------------------------------------------------------------------
    kinds = ("bias", "decay", "no_decay")
    norm_layers = tuple(
        v for k, v in nn.__dict__.items() if "Norm" in k and isinstance(v, type)
    )
    buckets = {kind: [] for kind in kinds}
    kind_of = {}  # id(param) -> bucket name; also de‑duplicates shared params

    for module in model.modules():
        for pname, param in module.named_parameters(recurse=False):
            if not param.requires_grad or id(param) in kind_of:
                continue
            if pname == "bias":
                kind = "bias"
            elif pname == "weight" and isinstance(module, norm_layers):
                kind = "no_decay"
            else:
                kind = "decay"
            kind_of[id(param)] = kind
            buckets[kind].append(param)

    # ------------------------------------------------------------------
    # 2.  Scale weight‑decay by effective batch size (Ultralytics behaviour)
    # ------------------------------------------------------------------
    scaled_decay = decay * (batch_size / max(nbs, batch_size))
    decay_for = {"bias": 0.0, "decay": scaled_decay, "no_decay": 0.0}

    # ------------------------------------------------------------------
    # 3.  Build param‑groups (depth‑wise if llrd=True)
    # ------------------------------------------------------------------
    if llrd:
        param_groups = []
        assigned = set()

        def add_groups(per_kind, mult):
            # One group per non‑empty bucket so decay / bias flags stay exact.
            for kind in kinds:
                if per_kind[kind]:
                    param_groups.append({
                        "params": per_kind[kind],
                        "weight_decay": decay_for[kind],
                        "lr_mult": mult,
                        "is_bias": kind == "bias",
                    })

        layers = _llrd_layers(model)
        n_layers = len(layers)
        for depth, layer in enumerate(layers):
            per_kind = {kind: [] for kind in kinds}
            for param in layer.parameters(recurse=True):
                pid = id(param)
                kind = kind_of.get(pid)
                if kind is None or pid in assigned:
                    continue  # frozen, or already owned by an earlier layer
                assigned.add(pid)
                per_kind[kind].append(param)
            add_groups(per_kind, llrd_decay ** (n_layers - 1 - depth))

        # Trainable parameters outside every layer keep the full base LR.
        leftovers = {
            kind: [p for p in buckets[kind] if id(p) not in assigned]
            for kind in kinds
        }
        add_groups(leftovers, 1.0)
    else:
        param_groups = [
            {"params": buckets["bias"], "weight_decay": 0.0, "lr_mult": 1.0, "is_bias": True},
            {"params": buckets["decay"], "weight_decay": scaled_decay, "lr_mult": 1.0},
            {"params": buckets["no_decay"], "weight_decay": 0.0, "lr_mult": 1.0},
        ]

    # ------------------------------------------------------------------
    # 4.  Create optimiser
    # ------------------------------------------------------------------
    optimizer = opt_cls(param_groups, **opt_kwargs)

    # Stamp *initial* LR per group (warm‑up and LambdaLR both use this field)
    for pg in optimizer.param_groups:
        pg["initial_lr"] = lr * pg.get("lr_mult", 1.0)
        pg["lr"] = pg["initial_lr"]

    # ------------------------------------------------------------------
    # 5.  Scheduler (keeps group ratios intact)
    # ------------------------------------------------------------------
    # LambdaLR multiplies each group's `initial_lr` by the lambda's value, so
    # the lambda must be a dimensionless factor running from 1.0 to `lrf`.
    if cos_lr:
        lr_factor = one_cycle(lrf, 1.0, epochs)
    else:
        lr_factor = lambda x: (1 - x / epochs) * (1 - lrf) + lrf

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_factor)
    return optimizer, scheduler


# ----------------------------------------------------------------------------------
#   YOLO Model Wrapper
# ----------------------------------------------------------------------------------

class YOLOModelWrapper(BaseModelWrapper):
    """Training wrapper that drives a `YOLOAdapter` model.

    Reads hyper‑parameters from a YAML file, builds the adapter together with
    its optimiser, LR scheduler and gradient scaler, and implements
    checkpointing, the Albumentations pipelines and a one‑epoch training loop
    with per‑iteration LR / momentum warm‑up.

    Attributes such as `config_path`, `dataset`, `device`, `distributed`,
    `local_rank`, `checkpoint_path` and `current_epoch` are read here but not
    defined in this class — presumably they are provided by
    `BaseModelWrapper`; verify against the base class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ni = 0               # global iteration counter
        self.warmup_steps = None  # filled on the first train_one_epoch call

    # ------------------------------------------------------------------
    #  Small shared helpers
    # ------------------------------------------------------------------
    def _img_size(self):
        """Return the square input resolution selected by `args.small`."""
        return 640 if self.args.small else 1280

    def _apply_warmup(self):
        """Linearly ramp LR (and momentum, if present) during warm‑up.

        For iteration `self.ni` in `[0, warmup_steps]`, every param group's LR
        is interpolated from its start value (`warmup_bias_lr` for groups
        flagged `is_bias`, else 0) to the group's `initial_lr`.  Does nothing
        when warm‑up is disabled (`warmup_steps == 0`) or already finished.
        """
        # A zero‑length warm‑up would make the interpolation range [0, 0]
        # degenerate, so skip it entirely.
        if not self.warmup_steps or self.ni > self.warmup_steps:
            return

        xi = [0, self.warmup_steps]
        for pg in self.optimizer.param_groups:
            start_lr = self.args.warmup_bias_lr if pg.get("is_bias", False) else 0.0
            pg["lr"] = np.interp(self.ni, xi, [start_lr, pg["initial_lr"]])
            if "momentum" in pg:
                pg["momentum"] = np.interp(
                    self.ni, xi, [self.args.warmup_momentum, self.args.momentum]
                )

    # ------------------------------------------------------------------
    #  Configuration & CLI parsing
    # ------------------------------------------------------------------
    def __load_config__(self):
        """Load hyper‑parameters from `self.config_path` into `self.args`.

        An argparse parser supplies names and defaults; it is run on an empty
        argument list, so the command line is ignored.  YAML keys that match a
        declared argument override the default; unknown keys are ignored.
        When `use_mapping` is set, the dataset's id‑mapping JSON is loaded
        into `self.id_mapping`.

        Raises:
            FileNotFoundError: If `use_mapping` is set and the mapping file
                does not exist.
        """
        # 1) Parse YAML config (an empty file parses to None -> no overrides)
        with open(self.config_path, "r") as f:
            config = yaml.safe_load(f) or {}

        parser = argparse.ArgumentParser()

        # -------------------- core args -------------------- #
        parser.add_argument("--use_mapping",      action="store_true")
        parser.add_argument("--cfg",             type=str)
        parser.add_argument("--conf_thres",      type=float, default=0.25)
        parser.add_argument("--iou_thres",       type=float, default=0.45)
        parser.add_argument("--weight_path",     type=str)
        parser.add_argument("--new_head_weights",action="store_true")
        parser.add_argument("--small",           action="store_true")

        # -------------------- optimiser -------------------- #
        parser.add_argument("--epochs",          type=int,   default=100)
        parser.add_argument("--batch_size",      type=int,   default=64)
        parser.add_argument("--opt",             type=str,   default="sgd")
        parser.add_argument("--cos_lr",          action="store_true")
        parser.add_argument("--lr0",             type=float, default=0.01)
        parser.add_argument("--lrf",             type=float, default=0.1)
        parser.add_argument("--momentum",        type=float, default=0.937)
        parser.add_argument("--weight_decay",    type=float, default=5e-4)

        # LLRD
        parser.add_argument("--llrd",            action="store_true")
        parser.add_argument("--llrd_decay",      type=float, default=0.8)

        # -------------------- warm‑up ---------------------- #
        parser.add_argument("--warmup_epochs",   type=int,   default=3)
        parser.add_argument("--warmup_momentum", type=float, default=0.8)
        parser.add_argument("--warmup_bias_lr",  type=float, default=0.1)

        # -------------------- loss hypers ------------------ #
        parser.add_argument("--hyp_label_smoothing", type=float, default=0.0)
        parser.add_argument("--hyp_fl_gamma",        type=float, default=0.0)
        parser.add_argument("--hyp_anchor_t",        type=float, default=4.0)
        parser.add_argument("--hyp_box",             type=float, default=0.05)
        parser.add_argument("--hyp_obj",             type=float, default=1.0)
        parser.add_argument("--hyp_cls",             type=float, default=0.5)
        parser.add_argument("--hyp_cls_pw",          type=float, default=1.0)
        parser.add_argument("--hyp_obj_pw",          type=float, default=1.0)
        parser.add_argument("--hyp_attack",          type=float)

        # Parse without CLI: only the defaults above are produced here
        args = parser.parse_args([])

        # Merge YAML → argparse namespace (values are taken as‑is, no coercion)
        for k, v in config.items():
            if hasattr(args, k):
                setattr(args, k, v)

        self.args = args
        self.epochs = args.epochs

        # Mapping (optional)
        if args.use_mapping:
            current_dir = os.path.dirname(os.path.abspath(__file__))
            mapping_path = os.path.join(current_dir, "bd_yolo", "mappings",
                                       f"{self.dataset}_id_mapping.json")
            if not os.path.exists(mapping_path):
                raise FileNotFoundError(mapping_path)
            with open(mapping_path, "r") as f:
                self.id_mapping = {int(k): int(v) for k, v in json.load(f).items()}

    # ------------------------------------------------------------------
    #  Model & optimiser init
    # ------------------------------------------------------------------
    def __initialize_model__(self):
        """Build the adapter, optimiser, scheduler and gradient scaler.

        Sets `self.model` (wrapped in SyncBatchNorm + DDP when
        `self.distributed` is true), `self.optimizer`, `self.scheduler` and
        `self.scaler`.
        """
        a = self.args

        # hyper‑config dict passed into YOLOAdapter
        hyper = {
            "hyp_label_smoothing": a.hyp_label_smoothing,
            "hyp_fl_gamma":        a.hyp_fl_gamma,
            "hyp_anchor_t":        a.hyp_anchor_t,
            "hyp_box":             a.hyp_box,
            "hyp_obj":             a.hyp_obj,
            "hyp_cls":             a.hyp_cls,
            "hyp_cls_pw":          a.hyp_cls_pw,
            "hyp_obj_pw":          a.hyp_obj_pw,
            "hyp_attack":          a.hyp_attack,
        }

        # Build YOLO backbone+head through the adapter
        adapter = YOLOAdapter(
            device=self.device,
            cfg=a.cfg,
            img_size=self._img_size(),
            hyper_config=hyper,
            weights_path=a.weight_path,
            new_head_weights=a.new_head_weights,
            conf_thres=a.conf_thres,
            iou_thres=a.iou_thres,
            # id_mapping only exists when __load_config__ loaded one
            id_mapping=getattr(self, "id_mapping", None),
            decrement_ids=not a.use_mapping,
        )

        if self.distributed:
            adapter = nn.SyncBatchNorm.convert_sync_batchnorm(adapter)
            adapter = DDP(adapter, device_ids=[self.local_rank], output_device=self.local_rank)

        self.model = adapter

        # Optimiser + scheduler ------------------------------------------------
        self.optimizer, self.scheduler = get_optimize_schedular(
            model=adapter,
            batch_size=a.batch_size,
            nbs=64,
            name=a.opt,
            lr=a.lr0,
            momentum=a.momentum,
            decay=a.weight_decay,
            lrf=a.lrf,
            epochs=a.epochs,
            cos_lr=a.cos_lr,
            llrd=a.llrd,
            llrd_decay=a.llrd_decay,
        )

        self.scaler = torch.amp.GradScaler()

    # ------------------------------------------------------------------
    #  Checkpoint helpers
    # ------------------------------------------------------------------
    def __save_model__(self):
        """Write epoch, model, EMA, optimiser and scheduler state to disk.

        NOTE(review): `self.ni` and the GradScaler state are not part of the
        checkpoint, so after a resume warm‑up starts again from iteration 0 —
        confirm whether that is intended.
        """
        adapter = self.model.module if isinstance(self.model, DDP) else self.model
        ckpt = {
            "epoch": self.current_epoch,
            "model_state": adapter.model.state_dict(),  # For resuming training
            "ema_state": adapter.ema.ema.state_dict(),  # For final inference
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
        }
        torch.save(ckpt, self.checkpoint_path)

    def __load_model__(self):
        """Restore the state written by `__save_model__`.

        The checkpoint is un‑pickled with `weights_only=False`, so it must
        come from a trusted source.
        """
        ckpt = torch.load(self.checkpoint_path, map_location=self.device, weights_only=False)
        adapter = self.model.module if isinstance(self.model, DDP) else self.model

        adapter.model.load_state_dict(ckpt["model_state"])
        adapter.ema.ema.load_state_dict(ckpt["ema_state"])

        self.optimizer.load_state_dict(ckpt["optimizer"])
        self.scheduler.load_state_dict(ckpt["scheduler"])
        self.current_epoch = ckpt["epoch"]

    # ------------------------------------------------------------------
    #  Albumentations pipelines
    # ------------------------------------------------------------------
    def transform_train(self, bbox_input_format):
        """Return `(pipeline, "yolo")` for training data.

        The pipeline resizes/pads to the model resolution, applies colour
        jitter, 90° rotation and horizontal flip, and converts to a float
        tensor.  `bbox_input_format` is forwarded to `A.BboxParams`.
        """
        size = self._img_size()
        bbox_params = A.BboxParams(
            format=bbox_input_format,
            label_fields=["category_ids", "poison_masks", "target_ids"],
            check_each_transform=True,
            filter_invalid_bboxes=True,
            clip=True,
        )
        steps = [
            A.Resize(size, size, interpolation=cv2.INTER_LINEAR),
            A.PadIfNeeded(size, size, border_mode=cv2.BORDER_CONSTANT, value=114),
            A.RandomBrightnessContrast(p=0.5),
            A.HueSaturationValue(p=0.5),
            A.RandomRotate90(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.ToFloat(max_value=255.0),
            ToTensorV2(),
        ]
        return A.Compose(steps, bbox_params=bbox_params), "yolo"

    def transform_test(self, bbox_input_format):
        """Return `(pipeline, "pascal_voc")` for evaluation data.

        Only deterministic resize/pad/convert steps are applied.  For the
        `"ptsd"` dataset the only label field is `category_ids`; every other
        dataset also carries `poison_masks` and `target_ids`.
        """
        size = self._img_size()

        if self.dataset == "ptsd":
            label_fields = ["category_ids"]
        else:
            label_fields = ["category_ids", "poison_masks", "target_ids"]

        t = A.Compose([
            A.Resize(size, size, interpolation=cv2.INTER_LINEAR),
            A.PadIfNeeded(size, size, border_mode=cv2.BORDER_CONSTANT, value=114),
            A.ToFloat(max_value=255.0),
            ToTensorV2(),
        ], bbox_params=A.BboxParams(
            format=bbox_input_format,
            label_fields=label_fields,
            check_each_transform=True,
            filter_invalid_bboxes=True,
            clip=True))
        return t, "pascal_voc"

    # ------------------------------------------------------------------
    #  Training loop – warm‑up is group‑agnostic
    # ------------------------------------------------------------------
    def train_one_epoch(self, dataloader, epoch):
        """Run one pass over `dataloader` and return the summed losses.

        Args:
            dataloader: Yields `(images, targets)` batches; each target is a
                dict with `boxes`, `labels`, `poison_masks` and
                `target_labels` tensors.
            epoch: Epoch index, forwarded to the sampler in distributed mode.

        Returns:
            Dict mapping each loss name to its sum over all batches.

        Raises:
            RuntimeError: If no parameter received a gradient.
            SystemExit: If the total loss is NaN or infinite.
        """
        # 1) set warm‑up length on the first call
        if self.warmup_steps is None:
            self.warmup_steps = self.args.warmup_epochs * len(dataloader)

            for pg in self.optimizer.param_groups:
                pg.setdefault("initial_lr", pg["lr"])

        # 2) prep model
        if self.distributed:
            dataloader.sampler.set_epoch(epoch)
            self.model.module.train()
        else:
            self.model.train()

        epoch_losses = {}
        for images, targets in dataloader:
            images = [img.to(self.device) for img in images]

            # Normalise box coordinates by image width/height
            for img, tgt in zip(images, targets):
                h, w = img.shape[1:3]

                tgt["boxes"] = tgt["boxes"].to(self.device) / torch.tensor([w, h, w, h], device=self.device)
                tgt["labels"] = tgt["labels"].to(self.device)
                tgt["poison_masks"] = tgt["poison_masks"].to(self.device)
                tgt["target_labels"] = tgt["target_labels"].to(self.device)

            # GPU‑side augmentation is currently disabled:
            # images, targets = yolo_gpu_augment(
            #     images, targets,
            #     dataset=dataloader.dataset,
            #     device=self.device,
            #     img_size=640 if self.args.small else 1280,
            # )

            # ----------------------------- warm‑up -----------------------------
            self._apply_warmup()

            # forward + backward
            loss_dict = self.model(images, targets=targets)

            losses = sum(loss for loss in loss_dict.values())
            loss_val = losses.item()
            if not math.isfinite(loss_val):
                print("Non‑finite loss, aborting!", loss_dict)
                sys.exit(1)

            self.optimizer.zero_grad()
            self.scaler.scale(losses).backward()

            # Gradients are still multiplied by the loss scale here; unscale
            # them first so max_norm applies to the true gradient magnitude.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)

            # Check if the model has any gradients
            if not any(p.grad is not None for p in self.model.parameters()):
                raise RuntimeError("Model has no gradients, check your loss function and inputs.")

            self.scaler.step(self.optimizer)
            self.scaler.update()

            model_to_update = self.model.module if isinstance(self.model, DDP) else self.model
            model_to_update.update_ema()

            # bookkeeping
            for k, v in loss_dict.items():
                epoch_losses[k] = epoch_losses.get(k, 0.0) + v.item()
            self.ni += 1

        self.current_epoch += 1
        return epoch_losses
