import torch
import torch.nn as nn
from copy import deepcopy

from .yolov5.models.yolo import Model
from .yolov5.utils.general import non_max_suppression
from .loss import ComputeLoss


class ModelEMA:
    """Keep an exponential moving average (EMA) copy of a model's state.

    A deep copy of ``model`` is stored in ``self.ema``, switched to eval mode
    and has gradients disabled on all of its parameters. Every call to
    :meth:`update` blends the current floating-point state of the live model
    into that copy.

    Attributes:
        ema: The averaged (shadow) copy of the model.
        updates: Number of EMA updates applied so far.
        decay: Callable mapping an update count to the decay factor used for
            that update; it ramps from 0 towards ``decay`` as the count grows.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        """Create the shadow copy.

        Args:
            model: Model to copy; it is not modified.
            decay: Asymptotic decay factor of the moving average.
            updates: Initial value of the update counter.
        """

        def ramped_decay(step):
            # Exponential warm-up: early updates use a small decay so the
            # average follows the live weights closely, then the factor
            # approaches ``decay``.
            warmup = 1 - torch.exp(torch.tensor(-step / 2000, dtype=torch.float32))
            return decay * warmup

        shadow = deepcopy(model)
        shadow.eval()
        for param in shadow.parameters():
            param.requires_grad_(False)

        self.ema = shadow
        self.updates = updates
        self.decay = ramped_decay

    def update(self, model):
        """Blend the current state of ``model`` into the averaged copy.

        Only floating-point entries of the state dict are averaged; all other
        entries of the shadow copy are left untouched.

        Args:
            model: Live model whose ``state_dict()`` keys match the copy's.
        """
        with torch.no_grad():
            self.updates += 1
            weight = self.decay(self.updates)
            source = model.state_dict()
            for name, averaged in self.ema.state_dict().items():
                if not averaged.dtype.is_floating_point:
                    continue
                # In-place ops so the shadow model's own tensors are updated.
                averaged *= weight
                averaged += (1.0 - weight) * source[name].detach()


class YOLOAdapter(nn.Module):
    """Adapter exposing a YOLOv5 ``Model`` through a train/inference ``forward``.

    In training mode (with targets) ``forward`` returns the result of
    ``ComputeLoss``; otherwise it runs the EMA copy of the model, applies
    non-maximum suppression and returns one dict per image with the keys
    ``'boxes'``, ``'scores'`` and ``'labels'``.
    """

    def __init__(
        self,
        device,
        cfg: str,
        img_size: int = 1280,
        hyper_config: dict = None,
        weights_path: str = None,
        new_head_weights: bool = False,
        conf_thres: float = 0.25,
        iou_thres: float = 0.45,
        id_mapping: dict = None,
        decrement_ids: bool = True
    ):
        """Build the model, optionally load weights, and set up EMA and loss.

        Args:
            device: Device the model is moved to; also used as
                ``map_location`` when loading a checkpoint and as the device
                of empty / remapped output tensors.
            cfg: Model configuration passed to ``Model(cfg, ch=3)``.
            img_size: Stored on ``self.img_size``; not otherwise used by this
                class.
            hyper_config: Forwarded to ``ComputeLoss`` as ``hypers``.
            weights_path: Optional checkpoint path. When ``None`` the model
                keeps its random initialization.
            new_head_weights: If True, load only checkpoint entries whose name
                and shape match the freshly built model (non-strict load);
                otherwise the checkpoint is loaded strictly.
            conf_thres: Confidence threshold passed to NMS.
            iou_thres: IoU threshold passed to NMS.
            id_mapping: Forwarded to ``ComputeLoss``; its inverse is used at
                inference to translate predicted class indices.
                Presumably maps dataset ids to model class indices -- verify
                against ``ComputeLoss``.
            decrement_ids: Forwarded to ``ComputeLoss``. At inference, when no
                ``id_mapping`` is given and this is True, predicted class
                indices are shifted by +1.
        """
        super().__init__()

        self.device = device
        self.conf_thres = conf_thres
        self.iou_thres = iou_thres
        self.id_mapping = id_mapping
        self.decrement_ids = decrement_ids
        self.img_size = img_size

        if id_mapping is not None:
            # for inference reverse-lookup
            self.id_reverse_mapping = {v: k for k, v in id_mapping.items()}

        # load model
        self.model = Model(cfg, ch=3).to(device)

        if weights_path is not None:
            self._load_weights(weights_path, device, new_head_weights)
        else:
            print("No weights provided, using random initialization")

        # The EMA copy is taken after loading so it starts from the loaded
        # weights; it is what inference uses.
        self.ema = ModelEMA(self.model)

        self.loss_fn = ComputeLoss(
            self.model,
            hypers=hyper_config,
            id_mapping=id_mapping,
            decrement_ids=decrement_ids
        )

    @staticmethod
    def _extract_state_dict(ckpt):
        """Return the state dict held by a loaded checkpoint object.

        Accepts a dict wrapping the payload under ``'model'`` or
        ``'model_state'``, a bare state dict, or an ``nn.Module``. A wrapped
        payload may itself be a module or a state dict.
        """
        model_or_sd = ckpt
        if isinstance(ckpt, dict):
            for key in ('model', 'model_state'):
                if key in ckpt:
                    model_or_sd = ckpt[key]
                    break
            # Without a wrapper key, the dict is taken to be a state_dict.

        if isinstance(model_or_sd, nn.Module):
            return model_or_sd.state_dict()
        return model_or_sd

    def _load_weights(self, weights_path, device, new_head_weights):
        """Load checkpoint weights from ``weights_path`` into ``self.model``."""
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects, which can execute code. Only load checkpoints from trusted
        # sources.
        ckpt = torch.load(weights_path, map_location=device, weights_only=False)
        sd = self._extract_state_dict(ckpt)

        if not new_head_weights:
            self.model.load_state_dict(sd, strict=True)
            return

        # Keep only tensors that exist in the new model with an identical
        # shape, so layers whose shape changed keep their fresh initialization.
        model_sd = self.model.state_dict()
        filtered_sd = {
            k: v for k, v in sd.items()
            if k in model_sd and model_sd[k].shape == v.shape
        }

        print(f'Loading {len(filtered_sd)} weights from the state_dict')

        self.model.load_state_dict(filtered_sd, strict=False)

    def update_ema(self):
        """Fold the current model weights into the EMA copy."""
        # This will be called from the training loop in the wrapper
        self.ema.update(self.model)

    @staticmethod
    def _to_pixel_xyxy(boxes, img_height, img_width):
        """Convert normalized ``(cx, cy, w, h)`` rows to pixel ``(x1, y1, x2, y2)``.

        Returns a new tensor; ``boxes`` is not modified.
        """
        half_wh = boxes[:, 2:] / 2
        xyxy = torch.cat((boxes[:, :2] - half_wh, boxes[:, :2] + half_wh), dim=-1)

        # Unnormalize boxes to real image size: x columns by width, y by height.
        xyxy[:, 0::2] *= img_width
        xyxy[:, 1::2] *= img_height
        return xyxy

    def _prepare_targets(self, images_tensor, targets):
        """Attach pixel-space xyxy boxes to every target dict.

        NOTE: each dict in ``targets`` is modified in place (the key
        ``'unnormalized_boxes'`` is added) and the same dict objects are
        returned in a new list.
        """
        # All images in the batch tensor share the same spatial size.
        img_height, img_width = images_tensor.shape[-2:]

        train_targets = []
        for target in targets:
            target['unnormalized_boxes'] = self._to_pixel_xyxy(
                target['boxes'], img_height, img_width
            )
            train_targets.append(target)
        return train_targets

    def _remap_labels(self, classes):
        """Translate predicted class indices into output labels."""
        if self.id_mapping is not None:
            # One bulk transfer via tolist() instead of a sync per element.
            return torch.tensor(
                [self.id_reverse_mapping[c] for c in classes.tolist()],
                dtype=torch.long,
                device=self.device,
            )
        if self.decrement_ids:
            return (classes + 1).clamp(min=0)
        return classes

    def _format_detection(self, det):
        """Turn one per-image NMS result into a boxes/scores/labels dict."""
        if det is None or not len(det):
            return {
                'boxes':  torch.zeros((0, 4), device=self.device),
                'scores': torch.zeros((0,), device=self.device),
                'labels': torch.zeros((0,), dtype=torch.long, device=self.device)
            }

        # Columns as indexed here: 0-3 box, 4 score, 5 class index.
        return {
            'boxes': det[:, :4],
            'scores': det[:, 4],
            'labels': self._remap_labels(det[:, 5].long()),
        }

    def forward(self, images, targets=None):
        """Compute the training loss or run inference.

        Args:
            images: List/tuple of Tensor[C,H,W] (stacked into a batch) or a
                Tensor[B,C,H,W], *already* in [0,1].
            targets: Training targets, one dict per image, whose ``'boxes'``
                are normalized ``(cx, cy, w, h)``; or None.

        Returns:
            In training mode with targets: the result of ``self.loss_fn``.
            Otherwise: a list with one dict per image holding ``'boxes'``,
            ``'scores'`` and ``'labels'``. Boxes are returned exactly as
            produced by NMS; no coordinate rescaling is applied here.

        NOTE(review): in training mode with ``targets=None`` the call falls
        through to the NMS path on the training-mode outputs -- confirm this
        is intended.
        """
        # 1) Get images
        if isinstance(images, (list, tuple)):
            images_tensor = torch.stack(list(images), dim=0)
        else:
            images_tensor = images

        # 2) forward (select model or ema depending on training)
        model_to_use = self.model if self.training else self.ema.ema
        preds = model_to_use(images_tensor)

        # 3) training
        if self.training and targets is not None:
            return self.loss_fn(preds, self._prepare_targets(images_tensor, targets))

        # 4) inference: NMS + class remap
        dets = non_max_suppression(
            preds,
            conf_thres=self.conf_thres,
            iou_thres=self.iou_thres
        )

        return [self._format_detection(det) for det in dets]