import numpy as np
from tqdm import tqdm 
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics import YOLO
from ultralytics.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors
from ultralytics.utils.ops import xywh2xyxy
from ultralytics.utils.loss import BboxLoss #v8DetectionLoss
from ultralytics.utils.metrics import box_iou, DetMetrics
from ultralytics.utils import ops

class PatchAttentionFusion(nn.Module):
    """Patch-wise self-attention that yields a gating map with values in [0, 1].

    The input feature map is cut into non-overlapping ``patch_size`` x
    ``patch_size`` patches. The flattened patches attend to each other with a
    single-head scaled dot-product attention, and a small MLP ending in a
    sigmoid turns every attended patch into one coefficient per patch element.
    The coefficients are finally folded back onto the spatial layout of the
    input, so the output has the same shape as the input.

    Args:
        channels: Number of channels of the feature maps given to ``forward``.
        patch_size: Side length of the square patches. Height and width of
            the input must be multiples of this value.
    """

    def __init__(self, channels, patch_size=16):
        super().__init__()
        self.patch_size = patch_size
        self.channels = channels

        # Length of one flattened patch: channels * patch_size^2.
        self.patch_dim = patch_size * patch_size * channels

        # Q/K/V projections; the same weights are used for every input.
        self.q_proj = nn.Linear(self.patch_dim, self.patch_dim)
        self.k_proj = nn.Linear(self.patch_dim, self.patch_dim)
        self.v_proj = nn.Linear(self.patch_dim, self.patch_dim)

        # Bottleneck MLP mapping each attended patch to ``patch_dim``
        # coefficients (one per patch element), squashed to [0, 1].
        self.coeff_proj = nn.Sequential(
            nn.Linear(self.patch_dim, self.patch_dim // 4),
            nn.ReLU(),
            nn.Linear(self.patch_dim // 4, self.patch_dim),
            nn.Sigmoid(),
        )

    def extract_patches(self, x):
        """Split ``x`` into flattened non-overlapping patches.

        Args:
            x: Tensor of shape (B, C, H, W).

        Returns:
            A pair ``(patches, (num_patches_h, num_patches_w))`` where
            ``patches`` has shape (B, num_patches, C * patch_size^2). Patches
            are ordered row-major over the patch grid and each patch is
            flattened in (channel, row, column) order.
        """
        B, C, H, W = x.size()
        p = self.patch_size

        num_patches_h = H // p
        num_patches_w = W // p

        # (B, C, nh, nw, p, p): the last two dims index inside a patch.
        patches = x.unfold(2, p, p).unfold(3, p, p)
        # Bring the patch grid forward, then flatten grid and patch content.
        patches = patches.permute(0, 2, 3, 1, 4, 5)
        patches = patches.reshape(B, num_patches_h * num_patches_w, -1)

        return patches, (num_patches_h, num_patches_w)

    def reconstruct_coefficients(self, patch_coefficients, patch_grid_shape, original_shape):
        """Fold per-patch coefficients back to a (B, C, H, W) map.

        This is the exact inverse of :meth:`extract_patches`: the value at
        flat position ``c * p * p + i * p + j`` of patch ``(r, s)`` is written
        to ``out[b, c, r * p + i, s * p + j]``.

        Args:
            patch_coefficients: Tensor of shape (B, num_patches, patch_dim).
            patch_grid_shape: ``(num_patches_h, num_patches_w)``.
            original_shape: ``(B, C, H, W)`` of the tensor the patches came
                from.

        Returns:
            Tensor of shape (B, C, H, W).

        Raises:
            RuntimeError: If the element count does not match
                ``original_shape`` (e.g. H or W is not a multiple of the
                patch size).
        """
        B, C, H, W = original_shape
        num_patches_h, num_patches_w = patch_grid_shape
        p = self.patch_size

        # Undo the flattening done in extract_patches ...
        coeffs = patch_coefficients.reshape(B, num_patches_h, num_patches_w, C, p, p)
        # ... and interleave grid and in-patch axes: (B, C, nh, p, nw, p).
        coeffs = coeffs.permute(0, 3, 1, 4, 2, 5)

        return coeffs.reshape(B, C, H, W)

    def forward(self, scaled_inputs):
        """Compute the gating map for ``scaled_inputs``.

        Args:
            scaled_inputs: Tensor of shape (B, C, H, W) with
                ``C == self.channels`` and H, W multiples of the patch size.

        Returns:
            Tensor of shape (B, C, H, W) with values in [0, 1].
        """
        B, C, H, W = scaled_inputs.size()

        patches, patch_grid_shape = self.extract_patches(scaled_inputs)  # (B, N, patch_dim)

        Q = self.q_proj(patches)
        K = self.k_proj(patches)
        V = self.v_proj(patches)

        # Scaled dot-product attention between patches.
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.patch_dim ** 0.5)  # (B, N, N)
        attn_weights = F.softmax(attn_scores, dim=-1)
        x_attn = torch.matmul(attn_weights, V)  # (B, N, patch_dim)

        patch_coefficients = self.coeff_proj(x_attn)
        return self.reconstruct_coefficients(patch_coefficients, patch_grid_shape, (B, C, H, W))


class CustomYOLOPipeline(nn.Module):
    """Image-enhancement front end followed by a YOLO detection model.

    The input image goes through ``ampm``; its (detached) output feeds the
    three branches ``cat``, ``fcem`` and ``srisp``. A shared
    ``PatchAttentionFusion`` produces one gating map per branch, the maps are
    normalised with a softmax across the three branches, and the weighted sum
    of the branch outputs is passed to the YOLO model.

    Args:
        yolo_config_path: Model config/weights path forwarded to ultralytics.
        ampm: Module applied to the raw image.
        cat, fcem, srisp: Modules applied to the ``ampm`` output. Their
            outputs must share one shape (B, C, H, W).
    """

    def __init__(self, yolo_config_path, ampm, cat, fcem, srisp):
        super().__init__()

        self.ampm = ampm
        self.cat = cat
        self.fcem = fcem
        self.srisp = srisp
        self.epoch = 0

        # Gating network shared by the three branches; built for 3-channel maps.
        self.dynamic_weights = PatchAttentionFusion(3)
        # Detection model class looked up via the ultralytics task map, built
        # with a single class (nc=1).
        self.yolo_model = YOLO(yolo_config_path).task_map['detect']['model'](yolo_config_path, nc=1)

        # Every sub-module is marked trainable, including the YOLO model.
        for model in [self.ampm, self.cat, self.fcem, self.srisp, self.dynamic_weights]:
            for param in model.parameters():
                param.requires_grad = True
        self.yolo_model.requires_grad_(True)

    def forward(self, x, k=None):
        """Run the full pipeline on one batch.

        Args:
            x: Mapping with the image tensor under ``'img'``.
            k: Unused; kept so existing ``model(batch, idx)`` calls still work.

        Returns:
            Whatever ``self.yolo_model`` returns for the fused image.
        """
        img = x['img']
        x_ampm = self.ampm.forward(img)

        # NOTE(review): detaching here stops gradients from reaching ``ampm``
        # even though its parameters are marked trainable - confirm intent.
        x_ampm = x_ampm.detach()

        x_cat = self.cat.forward(x_ampm)
        x_fcem = self.fcem.forward(x_ampm)
        x_srisp = self.srisp.forward(x_ampm)

        # Stack the gating maps on a new "branch" axis (dim=2) so that the
        # softmax below normalises ACROSS the three branches. Stacking on
        # dim=1 would make dim=2 the channel axis and normalise over channels.
        gates = torch.stack(
            [
                self.dynamic_weights(x_cat),
                self.dynamic_weights(x_fcem),
                self.dynamic_weights(x_srisp),
            ],
            dim=2,
        )  # (B, C, 3, H, W)
        attentions = F.softmax(gates, dim=2)

        x_cat_attn = x_cat * attentions[:, :, 0]
        x_fcem_attn = x_fcem * attentions[:, :, 1]
        x_srisp_attn = x_srisp * attentions[:, :, 2]

        processed_x = x_cat_attn + x_fcem_attn + x_srisp_attn

        return self.yolo_model(processed_x)


class AverageMeter:
    """Track the most recent value and the running weighted mean of a scalar.

    Attributes are ``val`` (last value), ``sum`` (weighted sum), ``count``
    (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the mean.

        Args:
            val: The new value.
            n: Weight of the value (e.g. the batch size).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A total weight of zero (e.g. an empty first batch) must not raise.
        self.avg = self.sum / self.count if self.count else 0


class Validation:
    """Evaluate a detection model on a dataloader.

    One call to :meth:`forward` yields three groups of numbers: mAP@0.5,
    mAP@0.75 and mAP@0.5:0.95 (through ``DetMetrics``), coordinate regression
    errors (MAE / RMSE / MAPE) over matched prediction / ground-truth pairs,
    and the confidence-based "DS" scores computed in
    :meth:`_calculate_additional_metrics`.

    Args:
        args: Settings object; ``single_cls`` and ``max_det`` are read by
            :meth:`postprocess`.
        model: Object called as ``model(batch, 0)`` that also provides
            ``eval()``.
    """

    def __init__(self, args, model):
        self.args = args
        self.model = model
        self.metrics = DetMetrics(names={})
        # Initialised here as well so the helper methods are usable before
        # the first call to forward().
        self.interpolate_thresholds = torch.linspace(0.5, 0.95, 10)  # IoU thresholds
        self._reset_stats()

    def _reset_stats(self):
        """Start fresh accumulators for a new evaluation run."""
        self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
        self.regression_stats = dict(pred_values=[], target_values=[], conf_scores=[])

    def _prepare_batch(self, si, batch):
        """Collect the ground truth of image ``si`` from a collated batch.

        Args:
            si: Index of the image inside the batch.
            batch: Mapping with keys ``batch_idx``, ``cls``, ``bboxes``,
                ``ori_shape``, ``img`` and ``ratio_pad``.

        Returns:
            Dict with ``cls``, ``bbox`` (xyxy), ``ori_shape``, ``imgsz`` and
            ``ratio_pad`` for that image.
        """
        idx = batch["batch_idx"] == si
        cls = batch["cls"][idx].squeeze(-1)
        bbox = batch["bboxes"][idx]
        ori_shape = batch["ori_shape"][si]
        imgsz = batch["img"].shape[2:]
        ratio_pad = batch["ratio_pad"][si]
        if len(cls):
            # Boxes are scaled by (w, h, w, h) - presumably they arrive as
            # normalised xywh; verify against the dataset. The scale tensor is
            # created on the boxes' device so GPU-resident labels work too.
            scale = torch.tensor(imgsz, device=bbox.device)[[1, 0, 1, 0]]
            bbox = ops.xywh2xyxy(bbox) * scale
            # Return value ignored: relies on scale_boxes rescaling ``bbox``
            # in place.
            ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad)
        return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}

    def _prepare_pred(self, pred, pbatch):
        """Return a copy of ``pred`` whose boxes are rescaled to the original image."""
        predn = pred.clone()
        ops.scale_boxes(
            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
        )
        return predn

    def postprocess(self, preds):
        """Apply NMS (confidence >= 0.1, IoU threshold 0.8) to raw model output."""
        return ops.non_max_suppression(
            preds,
            conf_thres=0.1,
            iou_thres=0.8,
            labels=[],
            multi_label=True,
            agnostic=self.args.single_cls,
            max_det=self.args.max_det,
        )

    def _update_image_stats(self, predn, bbox, cls):
        """Accumulate the statistics of one image.

        Args:
            predn: Predictions, one row per detection with the box in columns
                0-3, the confidence in column 4 and the class in column 5.
            bbox: Ground-truth boxes (xyxy) of the image.
            cls: Ground-truth classes of the image.
        """
        npr, nl = len(predn), len(cls)
        stat = dict(
            conf=torch.zeros(0),
            pred_cls=torch.zeros(0),
            tp=torch.zeros(npr, self.interpolate_thresholds.numel(), dtype=torch.bool),
            target_cls=cls,
            target_img=cls.unique(),
        )

        if npr == 0:
            # No detections: the ground truth must still be counted, otherwise
            # every missed object silently disappears from recall / mAP.
            if nl:
                for k in self.stats.keys():
                    self.stats[k].append(stat[k].detach().cpu())
            return

        stat["conf"] = predn[:, 4]
        stat["pred_cls"] = predn[:, 5]

        if nl:
            stat["tp"] = self._process_batch(predn, bbox, cls)

            # Regression statistics need both predictions and ground truth.
            matched_gt_coords, matched_pred_coords = self._match_pred_target(predn, bbox, cls)
            if matched_gt_coords and matched_pred_coords:
                self.regression_stats['pred_values'].extend(matched_pred_coords)
                self.regression_stats['target_values'].extend(matched_gt_coords)

        self.regression_stats['conf_scores'].append(stat["conf"].detach().cpu())

        for k in self.stats.keys():
            self.stats[k].append(stat[k].detach().cpu())

    def forward(self, dataloader):
        """Evaluate ``self.model`` on every batch of ``dataloader``.

        Images are moved to CUDA and divided by 255 before the model call.

        Returns:
            Tuple ``(mAP50, mAP75, mAP50-95, MAE, RMSE, MAPE, DS@0.5,
            DS@0.75, mDS@0.5:0.95)``. Metrics that could not be computed are
            reported as ``0.0``.
        """
        self.model.eval()
        self._reset_stats()
        self.interpolate_thresholds = torch.linspace(0.5, 0.95, 10)  # IoU thresholds

        for batch in tqdm(dataloader, desc='Evaluating', total=len(dataloader)):
            batch['img'] = batch['img'].cuda(non_blocking=True).float() / 255.0

            with torch.no_grad():
                preds = self.model(batch, 0)

            preds = self.postprocess(preds)

            for si, pred in enumerate(preds):
                pbatch = self._prepare_batch(si, batch)
                cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
                predn = self._prepare_pred(pred, pbatch)
                self._update_image_stats(predn, bbox, cls)

        stats = {
            k: torch.cat(v, 0).cpu().numpy() if len(v) else np.zeros(0)
            for k, v in self.stats.items()
        }
        self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=1)
        self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=1)
        stats.pop("target_img", None)

        if stats["tp"].any():
            self.metrics.process(**stats)
            box = self.metrics.box
            map50, map75, map5095 = box.map50, box.map75, box.map
        else:
            # Without a single true positive ``self.metrics`` is not updated;
            # report zeros instead of values left over from an earlier run.
            map50 = map75 = map5095 = 0.0

        additional_metrics = self._calculate_additional_metrics()

        return (
            map50,
            map75,
            map5095,
            additional_metrics.get('MAE', 0.0),
            additional_metrics.get('RMSE', 0.0),
            additional_metrics.get('MAPE', 0.0),
            additional_metrics.get('DS@0.5', 0.0),
            additional_metrics.get('DS@0.75', 0.0),
            additional_metrics.get('mDS@0.5:0.95', 0.0),
        )

    def _calculate_iou(self, box1, box2):
        """Return the IoU of two ``(x1, y1, x2, y2)`` boxes (0.0 if disjoint)."""
        x1_max = max(box1[0], box2[0])
        y1_max = max(box1[1], box2[1])
        x2_min = min(box1[2], box2[2])
        y2_min = min(box1[3], box2[3])

        if x2_min <= x1_max or y2_min <= y1_max:
            return 0.0

        intersection = (x2_min - x1_max) * (y2_min - y1_max)

        area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
        area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
        union = area1 + area2 - intersection

        return intersection / union if union > 0 else 0.0

    def _match_pred_target(self, predictions, gt_boxes, gt_classes, iou_threshold=0.8):
        """Greedily pair predictions with same-class ground-truth boxes.

        Predictions are visited in the given order; each takes the not yet
        matched ground-truth box of its class with the highest IoU, and the
        pair is kept when that IoU reaches ``iou_threshold``.

        Args:
            predictions: Rows of ``(x1, y1, x2, y2, conf, cls)``.
            gt_boxes: Ground-truth boxes (xyxy).
            gt_classes: Ground-truth classes, aligned with ``gt_boxes``.
            iou_threshold: Minimum IoU for a pair to be accepted.

        Returns:
            ``(gt_coords, pred_coords)``: two equally long lists of
            4-element arrays; predicted coordinates are rounded.
        """
        if len(predictions) == 0 or len(gt_boxes) == 0:
            return [], []

        pred_boxes = [
            {
                'box': pred[:4].cpu().numpy(),
                'conf': pred[4].item(),
                'class': int(pred[5].item()),
            }
            for pred in predictions
        ]

        gt_box_list = [
            {
                'box': box.cpu().numpy() if torch.is_tensor(box) else box,
                'class': int(cls.item()) if torch.is_tensor(cls) else int(cls),
            }
            for box, cls in zip(gt_boxes, gt_classes)
        ]

        all_gt_coords = []
        all_pred_coords = []
        matched_gt = set()

        for pred in pred_boxes:
            best_iou = 0
            best_gt_idx = -1

            for gt_idx, gt_box in enumerate(gt_box_list):
                if gt_idx in matched_gt or pred['class'] != gt_box['class']:
                    continue
                iou = self._calculate_iou(pred['box'], gt_box['box'])
                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = gt_idx

            if best_gt_idx != -1 and best_iou >= iou_threshold:
                all_gt_coords.append(gt_box_list[best_gt_idx]['box'])
                all_pred_coords.append(np.round(pred['box']))
                matched_gt.add(best_gt_idx)

        return all_gt_coords, all_pred_coords

    def _calculate_additional_metrics(self):
        """Compute regression errors and DS scores from the accumulated stats.

        Returns:
            Dict that may contain ``MAE``, ``RMSE``, ``MAPE`` (only when
            matched pairs exist) and ``DS@0.5``, ``DS@0.75``,
            ``mDS@0.5:0.95`` (only when predictions exist).
        """
        metrics = {}

        pred_values = self.regression_stats['pred_values']
        target_values = self.regression_stats['target_values']

        # Regression metrics over all matched box coordinates.
        if pred_values and target_values:
            gt_flat = np.array(target_values).flatten()
            pred_flat = np.array(pred_values).flatten()

            metrics['MAE'] = np.mean(np.abs(pred_flat - gt_flat))
            metrics['RMSE'] = np.sqrt(np.mean((pred_flat - gt_flat) ** 2))

            # MAPE is undefined for zero targets, so they are left out.
            mask = gt_flat != 0
            if mask.sum() > 0:
                metrics['MAPE'] = np.mean(np.abs((gt_flat[mask] - pred_flat[mask]) / gt_flat[mask])) * 100
            else:
                metrics['MAPE'] = float('inf')

        if self.regression_stats['conf_scores']:
            all_conf_scores = torch.cat(self.regression_stats['conf_scores'])

            if len(all_conf_scores) > 0 and len(self.stats['tp']) > 0:
                tp = torch.cat([x.clone().detach() for x in self.stats['tp']], 0)
                conf = torch.cat([x.clone().detach() for x in self.stats['conf']], 0)
                # First column = true positives at the lowest IoU threshold.
                tp_05 = tp[:, 0] if tp.dim() > 1 else tp

                n_bins = 15
                bin_boundaries = torch.linspace(0, 1, n_bins + 1)

                # Pull each confidence half-way towards the accuracy observed
                # in its confidence bin.
                confidences = conf.clone()
                for i in range(n_bins):
                    bin_lower = bin_boundaries[i]
                    bin_upper = bin_boundaries[i + 1]
                    in_bin = (conf > bin_lower) & (conf <= bin_upper)

                    if in_bin.sum() > 0:
                        bin_accuracy = tp_05[in_bin].float().mean()
                        confidences[in_bin] = (conf[in_bin] + bin_accuracy) / 2

                # DS = percentage of detections whose adjusted confidence
                # reaches the threshold.
                metrics['DS@0.5'] = (confidences >= 0.5).float().mean().item() * 100
                metrics['DS@0.75'] = (confidences >= 0.75).float().mean().item() * 100

                ds_values = []
                for threshold in self.interpolate_thresholds:
                    ds_at_threshold = (confidences >= threshold).float().mean().item() * 100 if (confidences >= threshold).any() else 0.0
                    ds_values.append(ds_at_threshold)

                metrics['mDS@0.5:0.95'] = np.mean(ds_values)

        return metrics

    def _process_batch(self, detections, gt_bboxes, gt_cls):
        """Return the (num_detections, num_thresholds) true-positive matrix of one image."""
        detections = detections.cpu()
        gt_bboxes = gt_bboxes.cpu()
        gt_cls = gt_cls.cpu()
        iou = box_iou(gt_bboxes, detections[:, :4])
        return self.match_predictions(detections[:, 5], gt_cls, iou)

    def match_predictions(self, pred_classes, true_classes, iou):
        """Mark detections as correct for every IoU threshold.

        Args:
            pred_classes: Predicted class per detection, shape (D,).
            true_classes: Ground-truth class per object, shape (G,).
            iou: Pairwise IoU of shape (G, D).

        Returns:
            Bool tensor of shape (D, num_thresholds) on
            ``pred_classes.device``; within one threshold a detection and a
            ground-truth object are each used at most once.
        """
        correct = np.zeros((pred_classes.shape[0], self.interpolate_thresholds.shape[0])).astype(bool)
        correct_class = true_classes[:, None] == pred_classes
        iou = iou * correct_class  # zero out the wrong classes
        iou = iou.cpu().numpy()
        for i, threshold in enumerate(self.interpolate_thresholds.cpu().tolist()):
            matches = np.nonzero(iou >= threshold)  # IoU >= threshold and classes match
            matches = np.array(matches).T  # rows of (gt index, detection index)
            if matches.shape[0]:
                if matches.shape[0] > 1:
                    # Order by IoU (best first), then keep one row per
                    # detection and one row per ground-truth object.
                    matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
                    matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                    matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                correct[matches[:, 1].astype(int), i] = True
        return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
 

class CustomTrainer:
    """Drive one training epoch of the pipeline with the detection loss.

    Args:
        args: Settings object handed to ``v8DetectionLoss``.
        model: Pipeline called as ``model(batch, batch_idx)``; its
            ``yolo_model`` attribute is given to the loss.
    """

    def __init__(self, args, model):
        self.model = model
        self.loss_fn = v8DetectionLoss(args, self.model.yolo_model)

    def train_epoch(self, dataloader, epoch, optimizer, scheduler):
        """Train for one pass over ``dataloader``.

        Images are moved to CUDA and divided by 255, gradients are clipped to
        a norm of 5.0, and the scheduler is stepped once after the last batch.

        Returns:
            Sample-weighted averages ``(total, box, cls, dfl)`` of the losses.
        """
        self.model.train()

        bar = tqdm(
            dataloader,
            desc=f'Epoch {epoch+1}',
            total=len(dataloader),
            bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}',
        )

        total_meter = AverageMeter()
        box_meter = AverageMeter()
        cls_meter = AverageMeter()
        dfl_meter = AverageMeter()

        for step, batches in enumerate(bar):
            # uint8 image batch -> float in [0, 1] on the GPU
            batches['img'] = batches['img'].cuda().float() / 255.0

            optimizer.zero_grad()
            predictions = self.model(batches, step)
            loss, parts = self.loss_fn(predictions, batches)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)
            optimizer.step()

            # Read the scalars once and reuse them for meters and display.
            n_samples = batches['img'].size(0)
            total_value = loss.item()
            box_value = parts[0].item()
            cls_value = parts[1].item()
            dfl_value = parts[2].item()

            total_meter.update(total_value, n_samples)
            box_meter.update(box_value, n_samples)
            cls_meter.update(cls_value, n_samples)
            dfl_meter.update(dfl_value, n_samples)

            bar.set_postfix({
                'Loss': f'{total_value:.4f}',
                'Box Loss': f'{box_value:.4f}',
                'Cls Loss': f'{cls_value:.4f}',
                'DFL Loss': f'{dfl_value:.4f}',
            })

        scheduler.step()

        return total_meter.avg, box_meter.avg, cls_meter.avg, dfl_meter.avg
    

class v8DetectionLoss:
    """Box / classification / DFL training loss for a YOLO detection head.

    Args:
        args: Settings object providing ``device`` and the loss gains
            ``box``, ``cls`` and ``dfl``.
        model: Detection model; the last entry of ``model.model`` (presumably
            the Detect head) supplies ``stride``, ``nc`` and ``reg_max``.
    """

    def __init__(self, args, model):
        head = model.model[-1]
        reg_max = head.reg_max

        self.hyp = args
        self.device = args.device
        self.stride = head.stride
        self.nc = head.nc
        self.reg_max = reg_max
        # Channels per anchor: class scores plus 4 * reg_max distance bins.
        self.no = head.nc + reg_max * 4
        self.use_dfl = reg_max > 1

        self.bce = nn.BCEWithLogitsLoss(reduction="none")
        self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=1, beta=8.0)
        self.bbox_loss = BboxLoss(reg_max - 1, use_dfl=self.use_dfl).to(self.device)
        # Bin indices 0..reg_max-1 used to turn a distribution into a distance.
        self.proj = torch.arange(reg_max, dtype=torch.float, device=self.device)

    def preprocess(self, targets, batch_size, scale_tensor):
        """Pack flat targets into a zero-padded per-image tensor.

        Args:
            targets: Rows of ``(image_index, cls, x, y, w, h)``.
            batch_size: Number of images in the batch.
            scale_tensor: Factors multiplied onto the four box values before
                the xywh -> xyxy conversion.

        Returns:
            Tensor of shape (batch_size, max_targets_per_image, 5) holding
            ``(cls, x1, y1, x2, y2)``; unused rows stay zero.
        """
        if targets.shape[0] == 0:
            return torch.zeros(batch_size, 0, 5, device=self.device)

        image_index = targets[:, 0]
        _, per_image = image_index.unique(return_counts=True)
        per_image = per_image.to(dtype=torch.int32)

        out = torch.zeros(batch_size, per_image.max(), 5, device=self.device)
        for j in range(batch_size):
            rows = image_index == j
            n = rows.sum()
            if n:
                out[j, :n] = targets[rows, 1:]
        # Scale the boxes in place, then convert them to corner format.
        out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
        return out

    def bbox_decode(self, anchor_points, pred_dist):
        """Decode predicted distances into xyxy boxes around the anchor points."""
        if self.use_dfl:
            n_batch, n_anchors, n_channels = pred_dist.shape
            # Expected bin index per box side: softmax over the bins, then a
            # dot product with the bin indices.
            per_side = pred_dist.view(n_batch, n_anchors, 4, n_channels // 4).softmax(3)
            pred_dist = per_side.matmul(self.proj.type(pred_dist.dtype))
        return dist2bbox(pred_dist, anchor_points, xywh=False)

    def __call__(self, preds, batch):
        """Compute the loss for one batch.

        Args:
            preds: Per-level feature maps, or a tuple whose second item holds
                them.
            batch: Mapping with ``batch_idx``, ``cls`` and ``bboxes``.

        Returns:
            ``(total, parts)``: the summed loss multiplied by the batch size,
            and the detached ``(box, cls, dfl)`` components.
        """
        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
        feats = preds[1] if isinstance(preds, tuple) else preds

        n_images = feats[0].shape[0]
        flat = torch.cat([level.view(n_images, self.no, -1) for level in feats], 2)
        pred_distri, pred_scores = flat.split((self.reg_max * 4, self.nc), 1)

        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        batch_size = pred_scores.shape[0]
        # Image size (h, w) derived from the first feature map and its stride.
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Ground truth: (image index, class, box) rows -> padded per-image tensor.
        flat_targets = torch.cat(
            (batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]),
            1,
        )
        targets = self.preprocess(flat_targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        # Padding rows are all zero, so a positive coordinate sum marks real boxes.
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)

        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

        assigned = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )
        _, target_bboxes, target_scores, fg_mask, _ = assigned

        target_scores_sum = max(target_scores.sum(), 1)

        # Classification loss (BCE)
        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum

        # Box and DFL loss, only when at least one anchor is foreground
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss[0], loss[2] = self.bbox_loss(
                pred_distri,
                pred_bboxes,
                anchor_points,
                target_bboxes,
                target_scores,
                target_scores_sum,
                fg_mask,
            )

        loss[0] *= self.hyp.box
        loss[1] *= self.hyp.cls
        loss[2] *= self.hyp.dfl

        return loss.sum() * batch_size, loss.detach()