import argparse
import yaml
import math
import sys

import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch
import torchvision
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

from .wrapper import BaseModelWrapper
from .bd_detr.detr_attack import DETR, SetCriterion, PostProcess
from .bd_detr.backbone import build_backbone
from .bd_detr.transformer import build_transformer
from .bd_detr.matcher import build_matcher

# Channel-wise normalization statistics, keyed by dataset name.
# The "default" entry holds the ImageNet mean/std commonly used with
# torchvision-pretrained backbones; it is the only entry defined here.
dataset_norms = {
    "default": dict(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
}

class DETRModelWrapper(BaseModelWrapper):

    def __load_config__(self):
        """Load hyper-parameters from the YAML file at ``self.config_path``.

        An ``argparse`` parser is used purely as a typed table of defaults
        (no command line is read). Every top-level YAML key that matches an
        argument destination overrides the default; unknown keys are ignored.
        Keys must use the destination name, e.g. ``aux_loss`` (not
        ``no_aux_loss``).

        YAML values are coerced through the argument's declared ``type`` so
        that, for example, a value PyYAML loads as the string ``"1e-4"``
        becomes a float just as it would on the command line.

        Sets ``self.args``, ``self.epochs`` and ``self.clip_max_norm``.

        Raises:
            TypeError: if the YAML document is not a mapping.
            ValueError: if a value cannot be converted to the declared type.
        """

        # 1. Read the YAML config. An empty file yields None -> treat as "no overrides".
        with open(self.config_path, "r", encoding="utf-8") as file:
            config = yaml.safe_load(file) or {}

        if not isinstance(config, dict):
            raise TypeError(
                f"Config file '{self.config_path}' must contain a mapping at the top level, "
                f"got {type(config).__name__}."
            )

        # 2. Declare all options with their defaults.
        parser = argparse.ArgumentParser('Set transformer detector', add_help=False)

        # dest -> declared ``type`` callable (None for flag-style arguments)
        declared_types = {}

        def add_argument(*flags, **options):
            # Register the option and remember its converter for the YAML merge below.
            action = parser.add_argument(*flags, **options)
            declared_types[action.dest] = action.type

        add_argument('--lr', default=1e-4, type=float)
        add_argument('--lr_backbone', default=1e-5, type=float)
        add_argument('--weight_decay', default=1e-4, type=float)
        add_argument('--epochs', default=300, type=int)
        add_argument('--lr_drop', default=200, type=int)
        add_argument('--clip_max_norm', default=0.1, type=float,
                     help='gradient clipping max norm')
        add_argument('--score_threshold', default=0.3, type=float,
                     help='score threshold for post-processing')

        # Model parameters
        add_argument('--frozen_weights', type=str, default=None,
                     help="Path to the pretrained model. If set, only the mask head will be trained")
        # * Backbone
        add_argument('--backbone', default='resnet50', type=str,
                     help="Name of the convolutional backbone to use")
        add_argument('--dilation', action='store_true',
                     help="If true, we replace stride with dilation in the last convolutional block (DC5)")
        add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                     help="Type of positional embedding to use on top of the image features")

        # * Transformer
        add_argument('--enc_layers', default=6, type=int,
                     help="Number of encoding layers in the transformer")
        add_argument('--dec_layers', default=6, type=int,
                     help="Number of decoding layers in the transformer")
        add_argument('--dim_feedforward', default=2048, type=int,
                     help="Intermediate size of the feedforward layers in the transformer blocks")
        add_argument('--hidden_dim', default=256, type=int,
                     help="Size of the embeddings (dimension of the transformer)")
        add_argument('--dropout', default=0.1, type=float,
                     help="Dropout applied in the transformer")
        add_argument('--nheads', default=8, type=int,
                     help="Number of attention heads inside the transformer's attentions")
        add_argument('--num_queries', default=100, type=int,
                     help="Number of query slots")
        add_argument('--pre_norm', action='store_true')

        # * Segmentation
        add_argument('--masks', action='store_true',
                     help="Train segmentation head if the flag is provided")

        # Loss
        add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
                     help="Disables auxiliary decoding losses (loss at each layer)")
        # * Matcher
        add_argument('--set_cost_class', default=1, type=float,
                     help="Class coefficient in the matching cost")
        add_argument('--set_cost_bbox', default=5, type=float,
                     help="L1 box coefficient in the matching cost")
        add_argument('--set_cost_giou', default=2, type=float,
                     help="giou box coefficient in the matching cost")
        # * Loss coefficients
        add_argument('--mask_loss_coef', default=1, type=float)
        add_argument('--dice_loss_coef', default=1, type=float)
        add_argument('--bbox_loss_coef', default=5, type=float)
        add_argument('--giou_loss_coef', default=2, type=float)
        add_argument('--eos_coef', default=0.1, type=float,
                     help="Relative classification weight of the no-object class")

        add_argument('--lambda_attack', default=1.0, type=float, help="Lambda for attack loss")
        add_argument('--num_classes', default=91, type=int, help="Number of classes in the dataset")
        add_argument('--weight_path', type=str, default=None, help="Path to pretrained weights")

        # 3. Materialize the defaults without touching the real command line.
        args = parser.parse_args([])

        # 4. Overlay YAML values on top of the defaults.
        for key, value in config.items():
            if not (isinstance(key, str) and hasattr(args, key)):
                continue

            caster = declared_types.get(key)
            already_typed = isinstance(caster, type) and isinstance(value, caster)
            if callable(caster) and value is not None and not already_typed:
                try:
                    value = caster(value)
                except (TypeError, ValueError) as err:
                    raise ValueError(
                        f"Invalid value {value!r} for config key '{key}' in '{self.config_path}'."
                    ) from err

            setattr(args, key, value)

        self.args = args
        self.epochs = args.epochs
        self.clip_max_norm = args.clip_max_norm

    def __initialize_model__(self):
        """Build the DETR model, its criterion, the optimizer and the LR scheduler.

        Reads ``self.args``, ``self.device``, ``self.distributed`` and (when
        distributed) ``self.local_rank``. Stores the results on ``self.model``,
        ``self.optimizer`` and ``self.scheduler``.

        If ``self.args.weight_path`` is set, the checkpoint is loaded
        non-strictly and any missing / unexpected keys are printed so that a
        mismatched checkpoint does not go unnoticed.
        """

        backbone = build_backbone(self.args)
        transformer = build_transformer(self.args)
        matcher = build_matcher(self.args)

        # Per-loss weights, including the attack loss coefficient.
        weight_dict = {
            'loss_ce': 1,
            'loss_bbox': self.args.bbox_loss_coef,
            'loss_attack': self.args.lambda_attack,
            'loss_giou': self.args.giou_loss_coef,
        }

        # TODO this is a hack
        # Replicate every weight with a "_<i>" suffix for the dec_layers - 1
        # auxiliary (intermediate decoder layer) outputs.
        if self.args.aux_loss:
            aux_weight_dict = {
                f'{name}_{layer}': weight
                for layer in range(self.args.dec_layers - 1)
                for name, weight in weight_dict.items()
            }
            weight_dict.update(aux_weight_dict)

        losses = ['labels', 'boxes', 'cardinality', 'attack']  # includes the attack loss
        criterion = SetCriterion(self.args.num_classes, matcher=matcher, weight_dict=weight_dict,
                                 eos_coef=self.args.eos_coef, losses=losses)
        criterion.to(self.device)

        # Post-processing uses the configured score threshold.
        postprocessors = {'bbox': PostProcess(self.args.score_threshold)}

        model = DETR(
            backbone,
            transformer,
            num_classes=self.args.num_classes,
            num_queries=self.args.num_queries,
            criterion=criterion,
            postprocessors=postprocessors,
            aux_loss=self.args.aux_loss,
        )

        model.to(self.device)

        # Optionally initialize from a checkpoint.
        if self.args.weight_path is not None:
            # NOTE(review): torch.load unpickles the file - only load checkpoints
            # from trusted sources (consider weights_only=True where supported).
            weights = torch.load(self.args.weight_path, map_location=self.device)

            # Accept checkpoints that wrap the state dict under a "model" key.
            if isinstance(weights, dict) and isinstance(weights.get('model'), dict):
                weights = weights['model']

            missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False)

            # strict=False hides mismatches; surface them instead of discarding.
            if missing_keys:
                print(f"[DETRModelWrapper] {len(missing_keys)} missing keys when loading "
                      f"'{self.args.weight_path}': {missing_keys}")
            if unexpected_keys:
                print(f"[DETRModelWrapper] {len(unexpected_keys)} unexpected keys when loading "
                      f"'{self.args.weight_path}': {unexpected_keys}")

        if self.distributed:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = DDP(model, device_ids=[self.local_rank], output_device=self.local_rank)
            model_without_ddp = model.module
        else:
            model_without_ddp = model

        # Two parameter groups: backbone parameters get their own learning rate.
        trainable = [(n, p) for n, p in model_without_ddp.named_parameters() if p.requires_grad]
        param_dicts = [
            {"params": [p for n, p in trainable if "backbone" not in n]},
            {
                "params": [p for n, p in trainable if "backbone" in n],
                "lr": self.args.lr_backbone,
            },
        ]

        optimizer = torch.optim.AdamW(param_dicts, lr=self.args.lr,
                                      weight_decay=self.args.weight_decay)

        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, self.args.lr_drop)

        self.model = model
        self.optimizer = optimizer
        self.scheduler = lr_scheduler

    def transform_train(self, bbox_input_format):
        """Build the Albumentations training pipeline for ``self.dataset``.

        Normalization statistics come from ``dataset_norms``; datasets without
        their own entry fall back to the ``"default"`` entry (same rule as
        ``transform_test``). Without this fallback every supported dataset
        raised ``KeyError`` because only ``"default"`` is defined.

        Args:
            bbox_input_format: bounding-box format string passed to
                ``A.BboxParams(format=...)``.

        Returns:
            A tuple ``(data_transform, "yolo")``. The second element is
            presumably the box format the caller should produce targets in -
            TODO confirm against the dataset code.

        Raises:
            ValueError: if the statistics are undefined or the dataset is not
                one of ``"coco"``, ``"mtsd"``, ``"mtsd_meta"``.
        """
        norms = dataset_norms.get(self.dataset, dataset_norms['default'])
        norm_mean = norms['mean']   # e.g. [0.485,0.456,0.406]
        norm_std = norms['std']     # e.g. [0.229,0.224,0.225]

        if norm_mean is None or norm_std is None:
            raise ValueError(f"Normalization mean and std for dataset '{self.dataset}' are not defined.")

        # 1) Dataset-specific geometric preprocessing
        if self.dataset == "coco":
            # Rescale by the longest side, then by the smallest side using one
            # of several candidate sizes.
            # NOTE(review): because SmallestMaxSize runs last, the long side is
            # presumably not capped at 1333 for elongated images - confirm.
            geometric = [
                A.LongestMaxSize(max_size=1333),
                A.SmallestMaxSize(max_size=[400,500,600]),
            ]
        elif self.dataset == 'mtsd' or self.dataset == 'mtsd_meta':
            # Random 1000x1000 crop
            geometric = [
                A.RandomCrop(width=1000, height=1000, p=1.0),
            ]
        else:
            raise ValueError(f"Unsupported dataset '{self.dataset}' for training transformation.")

        data_transform = A.Compose(geometric + [
            # 2) Random horizontal flip
            A.HorizontalFlip(p=0.5),

            # 3) Convert to float & normalize exactly once
            A.ToFloat(max_value=255.0),
            A.Normalize(mean=norm_mean, std=norm_std, max_pixel_value=1.0),

            # 4) To tensor
            ToTensorV2(),
        ], bbox_params=A.BboxParams(
            format=bbox_input_format,
            label_fields=['category_ids','poison_masks','target_ids'],
            check_each_transform=True,
            filter_invalid_bboxes=True,
            clip=True
        ))

        return data_transform, "yolo"

    def transform_test(self, bbox_input_format):
        """Build the Albumentations evaluation pipeline for ``self.dataset``.

        No augmentation is applied: images are rescaled, converted to float,
        normalized and turned into tensors.

        Args:
            bbox_input_format: bounding-box format string passed to
                ``A.BboxParams(format=...)``.

        Returns:
            A tuple ``(data_transform, "pascal_voc")``.

        Raises:
            ValueError: if the statistics are undefined or the dataset is not
                one of ``"coco"``, ``"mtsd"``, ``"mtsd_meta"``.
        """
        # Datasets without their own statistics use the "default" entry.
        stats_key = self.dataset if self.dataset in dataset_norms else 'default'
        stats = dataset_norms[stats_key]
        norm_mean, norm_std = stats['mean'], stats['std']

        if norm_mean is None or norm_std is None:
            raise ValueError(f"Normalization mean and std for dataset '{self.dataset}' are not defined.")

        # 1) Dataset-specific rescaling
        if self.dataset == "coco":
            # Longest side to 1333, then smallest side to 800 (in this order)
            resize_steps = [
                A.LongestMaxSize(max_size=1333),
                A.SmallestMaxSize(max_size=800),
            ]
        elif self.dataset in ('mtsd', 'mtsd_meta'):
            # Longest side to 2048
            resize_steps = [A.LongestMaxSize(max_size=2048)]
        else:
            raise ValueError(f"Unsupported dataset '{self.dataset}' for test transformation.")

        # 2) Normalize once, 3) convert to tensor
        common_steps = [
            A.ToFloat(max_value=255.0),
            A.Normalize(mean=norm_mean, std=norm_std, max_pixel_value=1.0),
            ToTensorV2(),
        ]

        box_params = A.BboxParams(
            format=bbox_input_format,
            label_fields=['category_ids','poison_masks','target_ids'],
            check_each_transform=True,
            filter_invalid_bboxes=True,
            clip=True
        )

        data_transform = A.Compose(resize_steps + common_steps, bbox_params=box_params)
        return data_transform, "pascal_voc"

    def train_one_epoch(self, dataloader, epoch):
        """Train the model for one pass over ``dataloader``.

        Args:
            dataloader: iterable yielding ``(images, targets)`` batches, where
                each target is a mapping with ``"boxes"``, ``"labels"``,
                ``"poison_masks"`` and ``"target_labels"`` tensors. In
                distributed mode its ``sampler`` must provide ``set_epoch``.
            epoch: epoch index forwarded to the sampler in distributed mode.

        Returns:
            dict mapping each loss name returned by the model to the sum of
            its per-batch values over the epoch.

        Side effects:
            Updates model parameters, increments ``self.current_epoch`` and
            terminates the process with exit code 1 on a non-finite loss.
            The LR scheduler is intentionally NOT stepped here.
        """

        # 1. Put the underlying model into training mode
        if self.distributed:
            # Tell the sampler which epoch this is before iterating.
            dataloader.sampler.set_epoch(epoch)
            raw_model = self.model.module
        else:
            raw_model = self.model
        raw_model.train()

        # 2. Iterate over the dataset
        epoch_losses = {}
        for images, targets in dataloader:

            # Move inputs to GPU/CPU
            images = [img.to(self.device) for img in images]

            # Build device-side targets with boxes normalized by image size
            train_targets = []
            for img, target in zip(images, targets):
                # assumes channel-first images (C, H, W) - TODO confirm
                height, width = img.shape[1:3]

                boxes = target["boxes"].to(self.device)
                if boxes.numel() == 0:
                    # Keep a well-formed (0, 4) tensor for images without boxes.
                    boxes = boxes.reshape(0, 4)

                # Out-of-place division: ``.to`` may return the very same tensor
                # (e.g. CPU training), and the caller's data must not be mutated.
                scale = boxes.new_tensor([width, height, width, height])

                train_targets.append({
                    "boxes": boxes / scale,
                    "labels": target["labels"].to(self.device),
                    "poison_masks": target["poison_masks"].to(self.device),
                    "target_labels": target["target_labels"].to(self.device),
                })

            loss_dict = self.model(images, train_targets)

            # NOTE(review): every returned entry is summed with weight 1; this
            # presumes the model already applies the criterion's weight_dict
            # and returns only trainable terms - confirm in DETR.forward.
            losses = sum(loss_dict.values())

            loss_value = losses.item()
            if not math.isfinite(loss_value):
                # NOTE(review): in distributed runs only this rank exits; the
                # others may block on the next collective - confirm handling.
                print(f"Loss is {loss_value}, stopping training.")
                print(loss_dict)
                sys.exit(1)

            self.optimizer.zero_grad()
            losses.backward()

            if self.clip_max_norm is not None:
                torch.nn.utils.clip_grad_norm_(raw_model.parameters(), self.clip_max_norm)

            self.optimizer.step()

            # Accumulate the losses for logging
            for key, value in loss_dict.items():
                epoch_losses[key] = epoch_losses.get(key, 0.0) + value.item()

        # 3. Update the epoch number
        self.current_epoch += 1

        # We do NOT step the main lr_scheduler here.
        # We'll do that outside this function, ensuring only rank 0 updates it.
        return epoch_losses