import argparse
import yaml
import math
import sys

import albumentations as A
from albumentations.pytorch import ToTensorV2
import random

import copy
import math
from typing import List
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

from .wrapper import BaseModelWrapper
from .bd_dino.backbone import build_backbone
from .bd_dino.matcher import build_matcher
from .bd_dino.deformable_transformer import build_deformable_transformer
from .bd_dino.deformable_transformer import build_deformable_transformer
from .bd_dino.dino_attack import DINOAttack, PostProcess, SetCriterion
from .bd_dino.util.get_param_dicts import get_param_dict

# Per-dataset RGB normalisation statistics (channel mean / std on a 0-1 pixel
# scale). The transform builders in this file fall back to the "default" entry
# (the widely used ImageNet statistics) for datasets that have no entry here.
dataset_norms = dict(
    default=dict(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
)

class DINOModelWrapper(BaseModelWrapper):

    def __load_config__(self):
        """Load the YAML file at ``self.config_path`` into an argparse namespace.

        Every option declared by ``_build_arg_parser`` starts at its default; a
        top-level YAML key with the same name overrides it. YAML keys that do not
        correspond to a declared option are silently ignored.

        Sets ``self.args`` (the merged namespace), ``self.epochs`` and
        ``self.clip_max_norm``.

        Raises:
            ValueError: if the YAML document is neither empty nor a mapping.
        """
        with open(self.config_path, "r", encoding="utf-8") as file:
            config = yaml.safe_load(file)

        # An empty YAML file parses to None: treat it as "no overrides" rather
        # than crashing on ``None.items()``.
        if config is None:
            config = {}
        elif not isinstance(config, dict):
            raise ValueError(
                f"Config file '{self.config_path}' must contain a YAML mapping, "
                f"got {type(config).__name__}."
            )

        # Parse an empty argv so only the declared defaults are used; actual
        # values come from the YAML merge below, never from the command line.
        args = self._build_arg_parser().parse_args([])

        for key, value in config.items():
            if hasattr(args, key):
                setattr(args, key, value)

        self.args = args
        self.epochs = args.epochs
        self.clip_max_norm = args.clip_max_norm

    @staticmethod
    def _build_arg_parser():
        """Return an ArgumentParser declaring every supported option and its default."""
        # NOTE: several options use ``type=bool``, which would misparse CLI
        # strings (bool("False") is True). ``__load_config__`` never parses CLI
        # input, so only these defaults and YAML-provided values matter.
        parser = argparse.ArgumentParser('Set transformer detector', add_help=False)

        # Model parameters
        parser.add_argument('--frozen_weights', type=str, default=None)
        parser.add_argument('--backbone', type=str, default='resnet50')
        parser.add_argument('--dilation', type=bool, default=False)
        parser.add_argument('--position_embedding', type=str, default='sine')
        parser.add_argument('--pe_temperatureH', type=float, default=20)
        parser.add_argument('--pe_temperatureW', type=float, default=20)
        parser.add_argument('--return_interm_indices', nargs='+', type=int, default=[1, 2, 3])
        parser.add_argument('--backbone_freeze_keywords', nargs='+', type=str, default=None)
        parser.add_argument('--enc_layers', type=int, default=6)
        parser.add_argument('--dec_layers', type=int, default=6)
        parser.add_argument('--unic_layers', type=int, default=0)
        parser.add_argument('--pre_norm', type=bool, default=False)
        parser.add_argument('--dim_feedforward', type=int, default=2048)
        parser.add_argument('--hidden_dim', type=int, default=256)
        parser.add_argument('--dropout', type=float, default=0.0)
        parser.add_argument('--nheads', type=int, default=8)
        parser.add_argument('--num_queries', type=int, default=900)
        parser.add_argument('--query_dim', type=int, default=4)
        parser.add_argument('--num_patterns', type=int, default=0)
        parser.add_argument('--pdetr3_bbox_embed_diff_each_layer', type=bool, default=False)
        parser.add_argument('--pdetr3_refHW', type=int, default=-1)
        parser.add_argument('--random_refpoints_xy', type=bool, default=False)
        parser.add_argument('--fix_refpoints_hw', type=int, default=-1)
        parser.add_argument('--dabdetr_yolo_like_anchor_update', type=bool, default=False)
        parser.add_argument('--dabdetr_deformable_encoder', type=bool, default=False)
        parser.add_argument('--dabdetr_deformable_decoder', type=bool, default=False)
        parser.add_argument('--use_deformable_box_attn', type=bool, default=False)
        parser.add_argument('--box_attn_type', type=str, default='roi_align')
        parser.add_argument('--dec_layer_number', type=int, default=None)
        parser.add_argument('--num_feature_levels', type=int, default=4)
        parser.add_argument('--enc_n_points', type=int, default=4)
        parser.add_argument('--dec_n_points', type=int, default=4)
        parser.add_argument('--decoder_layer_noise', type=bool, default=False)
        parser.add_argument('--dln_xy_noise', type=float, default=0.2)
        parser.add_argument('--dln_hw_noise', type=float, default=0.2)
        parser.add_argument('--add_channel_attention', type=bool, default=False)
        parser.add_argument('--add_pos_value', type=bool, default=False)
        parser.add_argument('--two_stage_type', type=str, default='standard')
        parser.add_argument('--two_stage_pat_embed', type=int, default=0)
        parser.add_argument('--two_stage_add_query_num', type=int, default=0)
        parser.add_argument('--two_stage_bbox_embed_share', type=bool, default=False)
        parser.add_argument('--two_stage_class_embed_share', type=bool, default=False)
        parser.add_argument('--two_stage_learn_wh', type=bool, default=False)
        parser.add_argument('--two_stage_default_hw', type=float, default=0.05)
        parser.add_argument('--two_stage_keep_all_tokens', type=bool, default=False)
        parser.add_argument('--num_select', type=int, default=300)
        parser.add_argument('--transformer_activation', type=str, default='relu')
        parser.add_argument('--batch_norm_type', type=str, default='FrozenBatchNorm2d')
        parser.add_argument('--masks', type=bool, default=False)
        parser.add_argument('--aux_loss', type=bool, default=True)

        # Learning rate and optimization
        parser.add_argument('--lr', type=float, default=0.0001)
        parser.add_argument('--param_dict_type', type=str, default='default')
        parser.add_argument('--lr_backbone', type=float, default=1e-05)
        parser.add_argument('--lr_backbone_names', nargs='+', type=str, default=['backbone.0'])
        parser.add_argument('--lr_linear_proj_names', nargs='+', type=str, default=['reference_points', 'sampling_offsets'])
        parser.add_argument('--lr_linear_proj_mult', type=float, default=0.1)
        parser.add_argument('--ddetr_lr_param', type=bool, default=False)
        parser.add_argument('--batch_size', type=int, default=2)
        parser.add_argument('--weight_decay', type=float, default=0.0001)
        parser.add_argument('--epochs', type=int, default=12)
        parser.add_argument('--lr_drop', type=int, default=11)
        parser.add_argument('--save_checkpoint_interval', type=int, default=1)
        parser.add_argument('--clip_max_norm', type=float, default=0.1)
        parser.add_argument('--multi_step_lr', type=bool, default=False)
        parser.add_argument('--lr_drop_list', nargs='+', type=int, default=[33, 45])

        # Cost coefficients
        parser.add_argument('--set_cost_class', type=float, default=2.0)
        parser.add_argument('--set_cost_bbox', type=float, default=5.0)
        parser.add_argument('--set_cost_giou', type=float, default=2.0)
        parser.add_argument('--cls_loss_coef', type=float, default=1.0)
        parser.add_argument('--mask_loss_coef', type=float, default=1.0)
        parser.add_argument('--dice_loss_coef', type=float, default=1.0)
        parser.add_argument('--bbox_loss_coef', type=float, default=5.0)
        parser.add_argument('--giou_loss_coef', type=float, default=2.0)
        parser.add_argument('--enc_loss_coef', type=float, default=1.0)
        parser.add_argument('--interm_loss_coef', type=float, default=1.0)
        parser.add_argument('--no_interm_box_loss', type=bool, default=False)
        parser.add_argument('--focal_alpha', type=float, default=0.25)

        # Decoder and matcher
        parser.add_argument('--decoder_sa_type', type=str, default='sa')  # ['sa', 'ca_label', 'ca_content']
        parser.add_argument('--matcher_type', type=str, default='HungarianMatcher')  # or SimpleMinsumMatcher
        parser.add_argument('--decoder_module_seq', nargs='+', type=str, default=['sa', 'ca', 'ffn'])
        parser.add_argument('--nms_iou_threshold', type=float, default=-1)
        parser.add_argument('--dec_pred_bbox_embed_share', type=bool, default=True)
        parser.add_argument('--dec_pred_class_embed_share', type=bool, default=True)

        # Denoising
        parser.add_argument('--use_dn', type=bool, default=True)
        parser.add_argument('--dn_number', type=int, default=100)
        parser.add_argument('--dn_box_noise_scale', type=float, default=0.4)
        parser.add_argument('--dn_label_noise_ratio', type=float, default=0.5)
        parser.add_argument('--embed_init_tgt', type=bool, default=True)
        parser.add_argument('--dn_labelbook_size', type=int, default=91)
        parser.add_argument('--match_unstable_error', type=bool, default=True)

        # EMA
        parser.add_argument('--use_ema', type=bool, default=False)
        parser.add_argument('--ema_decay', type=float, default=0.9997)
        parser.add_argument('--ema_epoch', type=int, default=0)

        parser.add_argument('--use_detached_boxes_dec_out', type=bool, default=False)

        # Important arguments for training
        parser.add_argument('--lambda_attack', default=1.0, type=float, help="Lambda for attack loss")
        parser.add_argument('--score_threshold', default=0.3, type=float, help="Score threshold for attack")
        parser.add_argument('--num_classes', default=91, type=int, help="Number of classes in the dataset")
        parser.add_argument('--weight_path', type=str, default=None, help="Path to pretrained weights")
        parser.add_argument("--change_head", action="store_true", help="change head of the model")
        parser.add_argument("--new_head_weights", action="store_true", help="use new head weights for the model")

        return parser

    def __initialize_model__(self):
        """Build the DINO attack model, its criterion, optimizer and LR scheduler.

        Reads hyper-parameters from ``self.args`` (populated by
        ``__load_config__``) and sets:

        * ``self.model``: a ``DINOAttack`` on ``self.device``, wrapped in
          ``DistributedDataParallel`` when ``self.distributed``;
        * ``self.optimizer``: AdamW over ``get_param_dict`` parameter groups;
        * ``self.scheduler``: ``MultiStepLR`` if ``args.multi_step_lr`` else ``StepLR``.

        Raises:
            ValueError: if ``args.weight_path`` is set but the checkpoint is not a
                dict containing a ``'model'`` entry.
        """
        args = self.args

        backbone = build_backbone(args)
        transformer = build_deformable_transformer(args)

        # The parser always declares these options; the getattr fallbacks only
        # matter if ``self.args`` was built some other way.
        dn_labelbook_size = getattr(args, 'dn_labelbook_size', args.num_classes)
        dec_pred_class_embed_share = getattr(args, 'dec_pred_class_embed_share', True)
        dec_pred_bbox_embed_share = getattr(args, 'dec_pred_bbox_embed_share', True)

        matcher = build_matcher(args)
        weight_dict = self._build_weight_dict()

        losses = ['labels', 'boxes', 'cardinality', 'attack']  # 'attack' enables the attack loss term
        criterion = SetCriterion(args.num_classes, matcher=matcher, weight_dict=weight_dict,
                                 focal_alpha=args.focal_alpha, losses=losses)
        criterion.to(self.device)

        # Post-processing also receives the attack score threshold.
        postprocessors = {'bbox': PostProcess(score_threshold=args.score_threshold,
                                              num_select=args.num_select,
                                              nms_iou_threshold=args.nms_iou_threshold)}

        model = DINOAttack(
            backbone,
            transformer,
            args.num_classes,
            args.num_queries,
            criterion=criterion,
            postprocessors=postprocessors,
            aux_loss=True,
            iter_update=True,
            query_dim=4,
            random_refpoints_xy=args.random_refpoints_xy,
            fix_refpoints_hw=args.fix_refpoints_hw,
            num_feature_levels=args.num_feature_levels,
            nheads=args.nheads,
            dec_pred_class_embed_share=dec_pred_class_embed_share,
            dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
            # two stage
            two_stage_type=args.two_stage_type,
            # box_share
            two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
            two_stage_class_embed_share=args.two_stage_class_embed_share,
            decoder_sa_type=args.decoder_sa_type,
            num_patterns=args.num_patterns,
            dn_number=args.dn_number if args.use_dn else 0,
            dn_box_noise_scale=args.dn_box_noise_scale,
            dn_label_noise_ratio=args.dn_label_noise_ratio,
            dn_labelbook_size=dn_labelbook_size,
        )
        model.to(self.device)

        weights = self._load_checkpoint_weights()
        if weights is not None:
            if not args.change_head and not args.new_head_weights:
                model.load_state_dict(weights, strict=True)
            else:
                # Drop every key containing "class_embed" so the classification
                # heads keep their fresh initialisation (e.g. when num_classes
                # differs from the checkpoint's).
                filtered = {k: v for k, v in weights.items() if "class_embed" not in k}
                model.load_state_dict(filtered, strict=False)

        if self.distributed:
            model = DDP(model, device_ids=[self.local_rank], output_device=self.local_rank)
            model_without_ddp = model.module
        else:
            model_without_ddp = model

        param_dicts = get_param_dict(args, model_without_ddp)
        optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                      weight_decay=args.weight_decay)

        if args.multi_step_lr:
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_drop_list)
        else:
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

        self.model = model
        self.optimizer = optimizer
        self.scheduler = lr_scheduler

    def _build_weight_dict(self):
        """Return the loss-name -> coefficient mapping handed to ``SetCriterion``.

        Contents (driven by ``self.args``):

        * base losses ``loss_ce``, ``loss_bbox``, ``loss_attack``, ``loss_giou``;
        * ``*_dn`` copies of those four when ``use_dn``;
        * ``loss_mask`` / ``loss_dice`` when ``masks``;
        * ``*_0`` .. ``*_{dec_layers-2}`` copies of all of the above when ``aux_loss``;
        * ``*_interm`` copies of the four base losses when ``two_stage_type != 'no'``.
        """
        args = self.args

        weight_dict = {
            'loss_ce': args.cls_loss_coef,
            'loss_bbox': args.bbox_loss_coef,
            'loss_attack': args.lambda_attack,  # attack loss coefficient
        }
        weight_dict['loss_giou'] = args.giou_loss_coef
        base_weight_dict = dict(weight_dict)  # the four base losses, before DN/mask terms

        if args.use_dn:
            weight_dict['loss_ce_dn'] = args.cls_loss_coef
            weight_dict['loss_bbox_dn'] = args.bbox_loss_coef
            weight_dict['loss_giou_dn'] = args.giou_loss_coef
            weight_dict['loss_attack_dn'] = args.lambda_attack  # attack loss coefficient for DN

        if args.masks:
            weight_dict["loss_mask"] = args.mask_loss_coef
            weight_dict["loss_dice"] = args.dice_loss_coef

        # One suffixed copy per intermediate decoder layer.
        if args.aux_loss:
            per_layer = dict(weight_dict)
            for i in range(args.dec_layers - 1):
                weight_dict.update({f'{k}_{i}': v for k, v in per_layer.items()})

        if args.two_stage_type != 'no':
            no_interm_box_loss = getattr(args, 'no_interm_box_loss', False)
            interm_loss_coef = getattr(args, 'interm_loss_coef', 1.0)
            # These are 0/1 masks applied on top of the base coefficients. The
            # attack entry previously held ``lambda_attack`` here, which
            # multiplied the already-lambda-weighted base value again
            # (effective weight lambda_attack ** 2).
            box_mask = 0.0 if no_interm_box_loss else 1.0
            interm_mask = {
                'loss_ce': 1.0,
                'loss_bbox': box_mask,
                'loss_giou': box_mask,
                'loss_attack': box_mask,
            }
            weight_dict.update({
                f'{k}_interm': v * interm_loss_coef * interm_mask[k]
                for k, v in base_weight_dict.items()
            })

        return weight_dict

    def _load_checkpoint_weights(self):
        """Return the ``'model'`` state dict stored at ``args.weight_path``, or None if unset.

        Raises:
            ValueError: if the checkpoint is not a dict with a ``'model'`` entry.
        """
        if self.args.weight_path is None:
            return None

        # weights_only=False unpickles arbitrary Python objects, so weight_path
        # must point at a trusted checkpoint. map_location="cpu" stops every
        # rank from materialising the tensors on the device they were saved
        # from; load_state_dict copies them onto the model's device afterwards.
        checkpoint = torch.load(self.args.weight_path, map_location="cpu", weights_only=False)

        if not isinstance(checkpoint, dict) or 'model' not in checkpoint:
            raise ValueError("The provided weight file is not compatible with this model wrapper. Please ensure it contains the correct keys.")

        return checkpoint['model']

    def transform_train(self, bbox_input_format):
        """Build the training-time albumentations pipeline.

        Two pipelines exist:

        * ``bbox_input_format == "coco"``: horizontal flip, then either a random
          short-side resize to one of 480..800 px, or a short-side resize to
          400/500/600 px, a random square crop resized to 600x600, and a second
          random short-side resize to 480..800 px.
        * ``self.dataset in ('mtsd', 'mtsd_meta')``: horizontal flip, then either
          a random short-side resize to one of 800..1280 px, or a short-side
          resize to 800/1000/1200 px followed by a random square crop resized to
          896x896. The result is zero-padded to at least 896x896.

        Both pipelines scale pixels to [0, 1], normalise with the dataset's
        ``dataset_norms`` entry (falling back to ``"default"``) and convert to a
        tensor. Boxes carry ``category_ids``, ``poison_masks`` and ``target_ids``.
        They are clipped, checked after every transform, and dropped when invalid.

        Args:
            bbox_input_format: albumentations bbox format of the incoming boxes.

        Returns:
            ``(transform, "yolo")``: the pipeline and the bbox format name
            returned alongside it.

        Raises:
            ValueError: if no pipeline matches the format/dataset combination, or
                the normalisation statistics are missing.
        """
        # ------------- normalisation stats ------------------------------------
        key = self.dataset if self.dataset in dataset_norms else "default"
        norm_mean = dataset_norms[key]["mean"]
        norm_std = dataset_norms[key]["std"]
        if norm_mean is None or norm_std is None:
            raise ValueError(f"Normalization mean/std for '{self.dataset}' not defined.")

        def _make_bbox_params():
            # Box handling shared by every training pipeline (a fresh object
            # per Compose so no state is shared between pipelines).
            return A.BboxParams(
                format=bbox_input_format,
                label_fields=["category_ids", "poison_masks", "target_ids"],
                check_each_transform=True,
                filter_invalid_bboxes=True,
                clip=True,
            )

        # NOTE(review): this branch is selected by the bbox *format* while the
        # next one is selected by the *dataset* (transform_test branches on the
        # dataset only). Confirm this is intentional for non-COCO datasets.
        if bbox_input_format == "coco":

            long_scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
            short_scales = [400, 500, 600]

            # Path A: random short-side resize.
            resize_long = A.OneOf(
                [A.SmallestMaxSize(max_size=s) for s in long_scales], p=1.0
            )

            # Path B: resize, random square crop, then random short-side resize.
            resize_crop = A.Compose([
                A.OneOf(
                    [A.SmallestMaxSize(max_size=s) for s in short_scales], p=1.0
                ),
                A.RandomResizedCrop(
                    size=(600, 600),     # output is resized to 600 x 600
                    scale=(0.64, 1.0),   # 384/600 = 0.64
                    ratio=(1.0, 1.0),    # force square aspect
                    p=1.0,
                ),
                A.OneOf(
                    [A.SmallestMaxSize(max_size=s) for s in long_scales], p=1.0
                ),
            ])

            multiscale = A.OneOf([resize_long, resize_crop], p=1.0)

            transform = A.Compose([
                A.HorizontalFlip(p=0.5),                  # 1) flip (same as DINO)
                multiscale,                               # 2) multi-scale jitter
                A.ToFloat(max_value=255.0),               # 3) to float [0-1]
                A.Normalize(mean=norm_mean,               # 4) normalise
                            std=norm_std,
                            max_pixel_value=1.0),
                ToTensorV2(),                             # 5) to tensor
            ], bbox_params=_make_bbox_params())

        elif self.dataset in ('mtsd', 'mtsd_meta'):

            # Larger scales than the COCO pipeline, for higher-resolution images.
            long_scales = [800, 896, 992, 1088, 1184, 1280]
            short_scales = [800, 1000, 1200]
            crop_size = 896

            # Path A: short side set to a value drawn from long_scales.
            # NOTE: nothing here caps the longest side (transform_test caps it
            # at 2048 via LongestMaxSize), so very wide images stay very wide.
            resize_long = A.SmallestMaxSize(max_size=long_scales, p=1.0)

            # Path B: resize, then take a large random square crop.
            resize_crop = A.Compose([
                A.SmallestMaxSize(max_size=short_scales, p=1.0),
                A.RandomResizedCrop(
                    size=(crop_size, crop_size),  # output is resized to 896 x 896
                    scale=(0.5, 1.0),             # crop 50% to 100% of the image area
                    ratio=(1.0, 1.0),             # force a square crop
                    p=1.0,
                ),
            ])

            # Each image randomly takes Path A or Path B.
            multiscale = A.OneOf([resize_long, resize_crop], p=1.0)

            transform = A.Compose([
                A.HorizontalFlip(p=0.5),
                multiscale,
                # border_mode=0 is cv2.BORDER_CONSTANT; the fill value defaults
                # to 0 (black), so the deprecated ``value=0`` argument is omitted.
                A.PadIfNeeded(min_height=crop_size, min_width=crop_size, border_mode=0),
                A.ToFloat(max_value=255.0),
                A.Normalize(mean=norm_mean, std=norm_std, max_pixel_value=1.0),
                ToTensorV2(),
            ], bbox_params=_make_bbox_params())

        else:
            raise ValueError(f"Unsupported dataset '{self.dataset}' for training transformation.")

        return transform, "yolo"


    def transform_test(self, bbox_input_format):
        """Build the evaluation-time albumentations pipeline (resize + normalise only).

        Resizing per ``self.dataset``:

        * ``"coco"``: ``LongestMaxSize(1333)`` then ``SmallestMaxSize(800)``;
        * ``"gtsdb"``: ``LongestMaxSize(1360)``;
        * ``"mtsd"``, ``"mtsd_meta"``, ``"ptsd"``: ``LongestMaxSize(2048)``.

        Pixels are then scaled to [0, 1], normalised with the dataset's
        ``dataset_norms`` entry (falling back to ``"default"``) and converted to a
        tensor. Boxes are clipped, checked after every transform, and dropped
        when invalid.

        Args:
            bbox_input_format: albumentations bbox format of the incoming boxes.

        Returns:
            ``(transform, "pascal_voc")``: the pipeline and the bbox format name
            returned alongside it.

        Raises:
            ValueError: if ``self.dataset`` has no test pipeline, or the
                normalisation statistics are missing.
        """
        key = self.dataset if self.dataset in dataset_norms else 'default'
        norm_mean = dataset_norms[key]['mean']
        norm_std = dataset_norms[key]['std']

        if norm_mean is None or norm_std is None:
            raise ValueError(f"Normalization mean and std for dataset '{self.dataset}' are not defined.")

        if self.dataset == "coco":
            # NOTE(review): SmallestMaxSize runs last, so the short side always
            # ends at 800 px and the long side is NOT bounded by 1333 (a 2:1
            # image becomes 1600x800). Confirm this matches the intended
            # DINO-style "short 800 / long <= 1333" evaluation resize.
            resize = [A.LongestMaxSize(max_size=1333), A.SmallestMaxSize(max_size=800)]
        elif self.dataset == 'gtsdb':
            resize = [A.LongestMaxSize(max_size=1360)]
        elif self.dataset in ('mtsd', 'mtsd_meta', 'ptsd'):
            resize = [A.LongestMaxSize(max_size=2048)]
        else:
            # Previously this fell through and failed with UnboundLocalError.
            raise ValueError(f"Unsupported dataset '{self.dataset}' for test transformation.")

        # The 'ptsd' pipeline declares only class ids as a bbox label field; the
        # other datasets also declare the poisoning fields.
        if self.dataset == 'ptsd':
            label_fields = ['category_ids']
        else:
            label_fields = ['category_ids', 'poison_masks', 'target_ids']

        transform = A.Compose([
            *resize,
            # No data augmentation: just scale to [0, 1] and normalise.
            A.ToFloat(max_value=255.0),
            A.Normalize(mean=norm_mean, std=norm_std, max_pixel_value=1.0),
            ToTensorV2(),
        ], bbox_params=A.BboxParams(
            format=bbox_input_format,
            label_fields=label_fields,
            check_each_transform=True,
            filter_invalid_bboxes=True,
            clip=True
        ))

        return transform, "pascal_voc"

    def train_one_epoch(self, dataloader, epoch):
        """Run one training epoch and return the summed per-key losses.

        Args:
            dataloader: yields ``(images, targets)`` batches. Each image is
                indexed as (C, H, W); this is assumed from ``shape[1:3]`` being
                read as (height, width). Each target dict provides ``boxes``,
                ``labels``, ``poison_masks`` and ``target_labels``. In
                distributed mode its sampler must provide ``set_epoch``.
            epoch: epoch index forwarded to the distributed sampler.

        Returns:
            dict mapping every loss name produced by the model to the sum of its
            values over all batches of this epoch.

        Side effects:
            Updates model parameters and increments ``self.current_epoch``. Does
            NOT step the LR scheduler; the caller does that. Exits the process
            with status 1 if the total loss becomes non-finite.
        """
        if self.distributed:
            # Keep the distributed shuffling in sync across ranks.
            dataloader.sampler.set_epoch(epoch)
            self.model.module.train()
        else:
            self.model.train()

        # A clip_max_norm of None or <= 0 disables clipping (the DETR/DINO
        # convention). Passing 0 to clip_grad_norm_ would zero every gradient,
        # and a negative value would flip the sign of every gradient.
        clip_enabled = self.clip_max_norm is not None and self.clip_max_norm > 0

        epoch_losses = {}
        for images, targets in dataloader:

            images = [img.to(self.device) for img in images]
            train_targets = [
                self._prepare_train_target(img, target)
                for img, target in zip(images, targets)
            ]

            loss_dict = self.model(images, train_targets)

            # Total loss is the plain sum of every entry the model returns.
            losses = sum(loss for loss in loss_dict.values())

            loss_value = losses.item()
            if not math.isfinite(loss_value):
                print(f"Loss is {loss_value}, stopping training.")
                print(loss_dict)
                sys.exit(1)

            self.optimizer.zero_grad()
            losses.backward()

            if clip_enabled:
                # A DDP wrapper exposes the wrapped module's parameters, so this
                # covers both the distributed and the local case.
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_max_norm)

            self.optimizer.step()

            for key, value in loss_dict.items():
                epoch_losses[key] = epoch_losses.get(key, 0.0) + value.item()

        self.current_epoch += 1

        # The main lr_scheduler is deliberately NOT stepped here; the caller
        # handles it (so that only rank 0 updates it).
        return epoch_losses

    def _prepare_train_target(self, image, target):
        """Return a copy of ``target`` on ``self.device`` with boxes scaled by ``image``'s size.

        Box columns 0 and 2 are divided by the image width and columns 1 and 3
        by the height. The division is out-of-place, so the caller's tensors are
        never modified.
        """
        height, width = image.shape[1:3]

        boxes = target["boxes"].to(self.device)
        if boxes.numel() == 0:
            boxes = boxes.reshape(0, 4)  # empty targets may arrive as shape (0,)
        if not torch.is_floating_point(boxes):
            boxes = boxes.float()  # in-place true division on integer boxes used to raise
        boxes = boxes / boxes.new_tensor([width, height, width, height])

        return {
            "boxes": boxes,
            "labels": target["labels"].to(self.device),
            "poison_masks": target["poison_masks"].to(self.device),
            "target_labels": target["target_labels"].to(self.device),
            # Size of the transformed image, kept for post-processing.
            "orig_size": torch.as_tensor([height, width], device=self.device),
        }