import argparse
import yaml
import math
import sys

import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch
import torchvision
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

from .utils.transform import GeneralizedRCNNTransform
from .wrapper import BaseModelWrapper
from .bd_faster_rcnn.faster_rcnn_attack import fasterrcnn_resnet50_fpn_attack

class FasterRCNNModelWrapper(BaseModelWrapper):
    """Model wrapper for the Faster R-CNN (ResNet-50 FPN) attack model.

    The wrapper reads hyper-parameters from a YAML file, builds the model
    through ``fasterrcnn_resnet50_fpn_attack``, creates the optimizer and the
    optional learning-rate scheduler, and provides the Albumentations
    pipelines used for training and evaluation.

    The methods read ``self.config_path``, ``self.dataset``, ``self.device``,
    ``self.distributed`` and ``self.local_rank``; none of them is assigned in
    this class (presumably ``BaseModelWrapper`` sets them -- verify there).
    """

    # Per-box label fields passed to Albumentations so they are filtered
    # together with their bounding boxes.
    _POISON_LABEL_FIELDS = ("category_ids", "poison_masks", "target_ids")
    # The 'ptsd' evaluation pipeline only carries class ids.
    _CLEAN_LABEL_FIELDS = ("category_ids",)

    def __load_config__(self):
        """Load hyper-parameters from ``self.config_path`` into ``self.args``.

        An ``argparse`` parser is used only to declare the set of recognised
        option names; no command line is parsed, so every option starts as
        ``None`` (``False`` for the ``store_true`` flags). YAML values are then
        copied over by name. Notes:

        * YAML keys that do not match a declared option are ignored.
        * The ``type=`` converters are NOT applied to YAML values; values keep
          whatever type the YAML loader produced.

        Sets ``self.args`` (the resulting namespace) and ``self.epochs``.

        Raises:
            ValueError: if the YAML document is not a mapping.
        """
        with open(self.config_path, "r", encoding="utf-8") as file:
            # An empty YAML document loads as None; treat it as "no overrides"
            # instead of crashing on ``None.items()`` below.
            config = yaml.safe_load(file) or {}

        if not isinstance(config, dict):
            raise ValueError(
                f"Config file '{self.config_path}' must contain a YAML mapping, "
                f"got {type(config).__name__}."
            )

        parser = argparse.ArgumentParser()

        # General parameters
        parser.add_argument("--epochs", type=int, metavar="N", help="number of total epochs to run")

        # Optimizer parameters
        parser.add_argument("--opt", type=str, help="optimizer (sgd or adamw)")
        parser.add_argument("--lr", type=float, help="initial learning rate")
        parser.add_argument("--momentum", type=float, metavar="M", help="momentum")
        parser.add_argument("--weight_decay", type=float, metavar="W", help="weight decay", dest="weight_decay")
        parser.add_argument("--lr_scheduler", type=str, help="scheduler name (multisteplr or cosineannealinglr)")
        parser.add_argument("--lr_step_size", type=int, help="decrease lr every step-size epochs")
        parser.add_argument("--lr_steps", nargs="+", type=int, help="epochs to decrease lr")
        parser.add_argument("--lr_gamma", type=float, help="decrease lr by a factor of lr-gamma")
        parser.add_argument("--decay_factor", type=float, help="decay factor for learning rate across layers")

        # Data parameters (the destination attribute is ``num_classes``)
        parser.add_argument("--num_classes", type=int, help="number of classes")

        # Model parameters
        parser.add_argument("--weight_path", type=str, default=None, help="path to the model weights")
        parser.add_argument("--change_head", action="store_true", help="change head of the model")
        parser.add_argument("--new_head_weights", action="store_true", help="use new head weights for the model")

        parser.add_argument("--weights_backbone", type=str, help="the backbone weights enum name to load")
        parser.add_argument("--trainable_backbone_layers", type=int, help="number of trainable backbone layers")
        parser.add_argument("--aspect_ratio_group_factor", type=int)
        parser.add_argument("--rpn_score_thresh", type=float, help="rpn score threshold for faster-rcnn")

        # Attack parameters
        parser.add_argument("--lambda_attack", type=float, help="lambda for attack loss")

        # Parse an empty argument list: this only materialises the defaults.
        args = parser.parse_args([])

        # Overlay YAML values on the declared options.
        for key, value in config.items():
            if hasattr(args, key):
                setattr(args, key, value)

        self.args = args
        self.epochs = args.epochs

    def __initialize_model__(self):
        """Build the model, optimizer and learning-rate scheduler.

        Steps:
            1. Create the model and move it to ``self.device``.
            2. When ``self.distributed`` is true, convert BatchNorm layers to
               SyncBatchNorm and wrap the model in ``DistributedDataParallel``.
            3. Create the optimizer, with per-module learning rates when a
               freshly initialised head is being fine-tuned.
            4. Create the scheduler named by ``args.lr_scheduler``, if any.

        ``self.model``, ``self.optimizer`` and ``self.scheduler`` are assigned
        on every successful path (``self.scheduler`` is ``None`` when no
        scheduler is configured).

        Returns:
            ``None`` when a scheduler is configured. When ``args.lr_scheduler``
            is ``None`` the tuple ``(model, optimizer, None, epochs)`` is
            returned; this return value is preserved from the earlier
            implementation for backward compatibility.

        Raises:
            RuntimeError: for an unsupported optimizer or scheduler name.
        """
        # 1. Load model. ``trainable_backbone_layers`` is forwarded only when
        #    configured, so the factory's own default applies otherwise.
        kwargs = {}
        if self.args.trainable_backbone_layers is not None:
            kwargs["trainable_backbone_layers"] = self.args.trainable_backbone_layers

        model = fasterrcnn_resnet50_fpn_attack(
            dataset=self.dataset,
            weights_path=self.args.weight_path,
            change_head=self.args.change_head,
            new_head_weights=self.args.new_head_weights,
            weights_backbone=self.args.weights_backbone,
            num_classes=self.args.num_classes,
            lambda_attack=self.args.lambda_attack,
            **kwargs
        )

        self.model = model
        self.model.to(self.device)

        # 2. Convert BatchNorm to SyncBatchNorm and wrap in DDP if distributed.
        if self.distributed:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = DDP(model, device_ids=[self.local_rank], output_device=self.local_rank)

        # 3. Optimizer and 4. LR scheduler. Both are built before any attribute
        #    is updated so that a configuration error leaves the wrapper's
        #    attributes as they were.
        optimizer = self._make_optimizer(self._make_param_groups(model))
        lr_scheduler = self._make_lr_scheduler(optimizer)

        # Publish the (possibly DDP-wrapped) model together with its optimizer
        # and scheduler. Previously these were skipped when no scheduler was
        # configured, leaving ``self.model`` unwrapped and ``self.optimizer``
        # unset.
        self.model = model
        self.optimizer = optimizer
        self.scheduler = lr_scheduler

        if lr_scheduler is None:
            return model, optimizer, None, self.args.epochs
        return None

    def _make_param_groups(self, model):
        """Return the parameters (or parameter groups) handed to the optimizer.

        When both ``change_head`` and ``new_head_weights`` are set, the model
        is being fine-tuned with a new head, so the learning rate is scaled by
        ``decay_factor`` per module: ``roi_heads`` uses the full ``lr`` and
        ``rpn``, ``backbone.fpn`` and ``backbone.body`` use ``lr`` multiplied
        by ``decay_factor`` to the power 2, 3 and 4 respectively.
        NOTE(review): parameters outside these four modules, if the model has
        any, are not given to the optimizer in this mode -- confirm intended.

        Otherwise every trainable parameter is returned in a flat list.
        """
        if not (self.args.change_head and self.args.new_head_weights):
            return [param for param in model.parameters() if param.requires_grad]

        # DDP hides the real model behind ``.module``.
        net = model.module if self.distributed else model
        base_lr = self.args.lr
        decay = self.args.decay_factor

        def trainable(module):
            return [param for param in module.parameters() if param.requires_grad]

        return [
            {"params": trainable(net.backbone.body), "lr": base_lr * (decay ** 4)},
            {"params": trainable(net.backbone.fpn), "lr": base_lr * (decay ** 3)},
            {"params": trainable(net.rpn), "lr": base_lr * (decay ** 2)},
            {"params": trainable(net.roi_heads), "lr": base_lr},
        ]

    def _make_optimizer(self, parameters):
        """Create the optimizer selected by ``args.opt``.

        Any name starting with ``"sgd"`` selects SGD, with Nesterov momentum
        enabled when the name contains ``"nesterov"``; ``"adamw"`` selects
        AdamW. The comparison is case-insensitive.

        Raises:
            RuntimeError: if the optimizer name is not supported.
        """
        opt_name = self.args.opt.lower()
        if opt_name.startswith("sgd"):
            return torch.optim.SGD(
                parameters,
                lr=self.args.lr,
                momentum=self.args.momentum,
                weight_decay=self.args.weight_decay,
                nesterov=("nesterov" in opt_name),
            )
        if opt_name == "adamw":
            return torch.optim.AdamW(parameters, lr=self.args.lr, weight_decay=self.args.weight_decay)
        raise RuntimeError(f"Invalid optimizer {self.args.opt}. Only SGD and AdamW are supported.")

    def _make_lr_scheduler(self, optimizer):
        """Create the scheduler selected by ``args.lr_scheduler``.

        Returns ``None`` when no scheduler is configured. Otherwise the name is
        lower-cased (and stored back on ``self.args``) and must be
        ``"multisteplr"`` or ``"cosineannealinglr"``.

        Raises:
            RuntimeError: if the scheduler name is not supported.
        """
        if self.args.lr_scheduler is None:
            return None

        self.args.lr_scheduler = self.args.lr_scheduler.lower()
        if self.args.lr_scheduler == "multisteplr":
            return torch.optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=self.args.lr_steps, gamma=self.args.lr_gamma
            )
        if self.args.lr_scheduler == "cosineannealinglr":
            return torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.args.epochs
            )
        raise RuntimeError(
            f"Invalid lr scheduler '{self.args.lr_scheduler}'. "
            "Only MultiStepLR and CosineAnnealingLR are supported."
        )

    def _make_bbox_pipeline(self, steps, bbox_input_format, label_fields):
        """Compose ``steps`` followed by the float and tensor conversions.

        Every pipeline of this wrapper ends with ``ToFloat(max_value=255.0)``
        and ``ToTensorV2()`` and uses the same bounding-box settings: boxes are
        clipped, invalid boxes are filtered, and boxes are checked after each
        transform.

        Args:
            steps: dataset-specific transforms applied first.
            bbox_input_format: Albumentations bounding-box format name.
            label_fields: names of the per-box label lists.
        """
        return A.Compose(
            [*steps, A.ToFloat(max_value=255.0), ToTensorV2()],
            bbox_params=A.BboxParams(
                format=bbox_input_format,
                label_fields=list(label_fields),
                check_each_transform=True,
                filter_invalid_bboxes=True,
                clip=True,
            ),
        )

    def transform_train(self, bbox_input_format):
        """Return the training pipeline for ``self.dataset``.

        Supported datasets: ``coco``, ``gtsdb``, ``mtsd`` and ``mtsd_meta``.
        All use a random horizontal flip; the traffic-sign datasets are first
        resized so that their longest side has a fixed length.

        Args:
            bbox_input_format: Albumentations bounding-box format name.

        Returns:
            A ``(transform, "pascal_voc")`` tuple. The second element is always
            the literal ``"pascal_voc"``, independent of ``bbox_input_format``
            (NOTE(review): its meaning is defined by the caller -- confirm).

        Raises:
            ValueError: if ``self.dataset`` is not supported for training.
        """
        if self.dataset == 'coco':
            steps = [A.HorizontalFlip(p=0.5)]
        elif self.dataset == 'gtsdb':
            # Resize to longest side = 1360, then random horizontal flip.
            steps = [A.LongestMaxSize(max_size=1360), A.HorizontalFlip(p=0.5)]
        elif self.dataset == 'mtsd' or self.dataset == 'mtsd_meta':
            # Resize to longest side = 2048, then random horizontal flip.
            steps = [A.LongestMaxSize(max_size=2048), A.HorizontalFlip(p=0.5)]
        else:
            raise ValueError(f"Unsupported dataset '{self.dataset}' for training transformation.")

        transform = self._make_bbox_pipeline(steps, bbox_input_format, self._POISON_LABEL_FIELDS)
        return transform, "pascal_voc"

    def transform_test(self, bbox_input_format):
        """Return the evaluation pipeline for ``self.dataset``.

        Supported datasets: ``coco``, ``gtsdb``, ``mtsd``, ``mtsd_meta`` and
        ``ptsd``. No augmentation is applied; the traffic-sign datasets are
        resized so that their longest side has a fixed length. The ``ptsd``
        pipeline only carries ``category_ids`` as a label field.

        Args:
            bbox_input_format: Albumentations bounding-box format name.

        Returns:
            A ``(transform, "pascal_voc")`` tuple. The second element is always
            the literal ``"pascal_voc"``, independent of ``bbox_input_format``
            (NOTE(review): its meaning is defined by the caller -- confirm).

        Raises:
            ValueError: if ``self.dataset`` is not supported for testing.
        """
        label_fields = self._POISON_LABEL_FIELDS

        if self.dataset == 'coco':
            # Explicit no-op placeholder kept from the original pipeline.
            steps = [A.NoOp()]
        elif self.dataset == 'gtsdb':
            # Resize to longest side = 1360.
            steps = [A.LongestMaxSize(max_size=1360)]
        elif self.dataset == 'mtsd' or self.dataset == 'mtsd_meta':
            # Resize to longest side = 2048.
            steps = [A.LongestMaxSize(max_size=2048)]
        elif self.dataset == 'ptsd':
            # Resize to longest side = 2048; only class ids accompany the boxes.
            steps = [A.LongestMaxSize(max_size=2048)]
            label_fields = self._CLEAN_LABEL_FIELDS
        else:
            raise ValueError(f"Unsupported dataset '{self.dataset}' for testing transformation.")

        transform = self._make_bbox_pipeline(steps, bbox_input_format, label_fields)
        return transform, "pascal_voc"