
from abc import ABC, abstractmethod
import math
import sys
import os
import torch
from tqdm import tqdm

class BaseModelWrapper(ABC):
    """Shared checkpointing and training-loop scaffold for model wrappers.

    Subclasses provide configuration loading, model construction and the
    dataset transforms. This base class persists/restores checkpoints and runs
    one training epoch at a time.

    ``__init__`` calls ``__load_config__`` and ``__initialize_model__`` and then,
    if ``checkpoint.pth`` exists, immediately restores it into ``self.model``,
    ``self.optimizer`` and ``self.scheduler`` -- so subclasses must have set
    those attributes by the end of ``__initialize_model__``.
    """

    # Upper bound on the number of iterations of the first-epoch linear LR warmup.
    warmup_max_iters = 1000
    # LR multiplier at the first warmup iteration; LinearLR ramps it linearly to 1.0.
    warmup_start_factor = 0.1 / 1000
    # Per-target tensors moved to ``self.device`` before each forward pass.
    _TARGET_KEYS = ("boxes", "labels", "poison_masks", "target_labels")

    def __init__(self, dataset, config_path, device, checkpoint_path, distributed=False, local_rank=0):
        """Store settings, build the model and resume from a checkpoint if present.

        Args:
            dataset: Stored on ``self.dataset``; not used by this base class.
            config_path: Configuration path, consumed by ``__load_config__``.
            device: Device that inputs are moved to and that ``torch.load``
                maps checkpoint tensors onto.
            checkpoint_path (str): Directory that holds (or will hold)
                ``checkpoint.pth``.
            distributed (bool): If True, ``self.model`` is expected to expose the
                real module as ``self.model.module`` and the dataloader's
                sampler must support ``set_epoch``.
            local_rank (int): Used here only as a prefix in log messages.

        Raises:
            ValueError: If ``checkpoint_path`` is None.
        """
        self.dataset = dataset
        self.config_path = config_path
        self.device = device
        self.distributed = distributed
        self.local_rank = local_rank

        if checkpoint_path is None:
            raise ValueError("Checkpoint path cannot be None. Please provide a valid path to save or load the model.")

        self.checkpoint_path = os.path.join(checkpoint_path, "checkpoint.pth")

        # These attributes will be initialized in the subclasses
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.current_epoch = 0

        print(f"[Rank {self.local_rank}] [BaseModelWrapper] Initializing model with config: {self.config_path}")

        self.__load_config__()

        # Note: This is needed for Faster R-CNN and FCOS models
        # They have internal transforms that need to be applied
        self.__initialize_model__()

        # Resume automatically if a checkpoint.pth exists in the checkpoint directory.
        if os.path.exists(self.checkpoint_path):
            print(f"[Rank {self.local_rank}] [BaseModelWrapper] Loading model from checkpoint: {checkpoint_path}")
            self.__load_model__()

    @abstractmethod
    def __load_config__(self):
        """
        Load the model configuration from the specified path.
        This method should be implemented by subclasses to load their specific configurations.
        """
        pass

    @abstractmethod
    def __initialize_model__(self, train_transform=None, test_transform=None):
        """
        Initialize the model based on the loaded configuration.
        This method should be implemented by subclasses to set up their specific models
        (``self.model``, ``self.optimizer`` and optionally ``self.scheduler``).
        """
        pass

    @abstractmethod
    def transform_test(self, bbox_input_format):
        """
        Transform the test dataset based on the bounding box input format.
        This method should be implemented by subclasses to apply necessary transformations.

        Args:
            bbox_input_format (str): The format of the bounding boxes (e.g., 'pascal_voc', 'yolo').
        """
        pass

    @abstractmethod
    def transform_train(self, bbox_input_format):
        """
        Transform the training dataset based on the bounding box input format.
        This method should be implemented by subclasses to apply necessary transformations.

        Args:
            bbox_input_format (str): The format of the bounding boxes (e.g., 'pascal_voc', 'yolo').
        """
        pass

    def _unwrap_model(self):
        """Return ``self.model.module`` if the model is wrapped (e.g. DDP), else ``self.model``."""
        return self.model.module if hasattr(self.model, "module") else self.model

    def __save_model__(self):
        """Write model, optimizer, scheduler and epoch state to ``self.checkpoint_path``.

        The unwrapped model's state is saved, so the checkpoint can be loaded
        into both wrapped and unwrapped models. A missing optimizer or scheduler
        is stored as ``None``. The checkpoint directory is created if needed.
        """
        optimizer_state = self.optimizer.state_dict() if self.optimizer is not None else None
        scheduler_state = self.scheduler.state_dict() if self.scheduler is not None else None

        ckpt = {
            "current_epoch": self.current_epoch,
            "model": self._unwrap_model().state_dict(),
            "optimizer": optimizer_state,
            "scheduler": scheduler_state,
        }

        checkpoint_dir = os.path.dirname(self.checkpoint_path)
        if checkpoint_dir:
            # torch.save does not create missing parent directories.
            os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(ckpt, self.checkpoint_path)

    def __load_model__(self):
        """Restore the state written by ``__save_model__`` from ``self.checkpoint_path``.

        Tensors are mapped onto ``self.device``. Optimizer and scheduler state are
        restored only when both the live object and the saved entry exist.
        """
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        ckpt = torch.load(self.checkpoint_path, map_location=self.device)

        self._unwrap_model().load_state_dict(ckpt["model"])

        if self.optimizer is not None and ckpt.get("optimizer") is not None:
            self.optimizer.load_state_dict(ckpt["optimizer"])
        if self.scheduler is not None and ckpt.get("scheduler") is not None:
            self.scheduler.load_state_dict(ckpt["scheduler"])
        self.current_epoch = ckpt["current_epoch"]

    def train_one_epoch(self, dataloader, epoch):
        """Run one optimization pass over ``dataloader``.

        Args:
            dataloader: Sized iterable of ``(images, targets)`` batches. ``images``
                is a sequence of tensors. Each target is a dict containing at
                least the keys in ``_TARGET_KEYS``. In distributed mode,
                ``dataloader.sampler.set_epoch`` is called.
            epoch (int): Zero-based epoch index. On epoch 0, a per-iteration
                linear LR warmup is applied.

        Returns:
            dict: For each loss name returned by the model, the sum of its
            values over all iterations of the epoch (not averaged).

        Raises:
            ValueError: If any target contains a negative class label.

        Note:
            A non-finite total loss prints the loss dict and terminates the
            process via ``sys.exit(1)``.
        """
        if self.distributed:
            # Seed the (presumably DistributedSampler) shuffle with the epoch so
            # all ranks agree on this epoch's ordering.
            dataloader.sampler.set_epoch(epoch)
            self.model.module.train()
        else:
            self.model.train()

        # First-epoch warmup, stepped once per iteration and discarded afterwards.
        warmup_scheduler = None
        if epoch == 0:
            # Finish the ramp before the epoch's last iteration so that it runs
            # at the full LR. LinearLR scales the LR by start_factor as soon as
            # it is constructed; with total_iters <= 0 (fewer than two batches)
            # it never ramps back, leaving the LR stuck at start_factor * lr.
            # So warmup is skipped in that case.
            warmup_iters = min(self.warmup_max_iters, len(dataloader) - 1)
            if warmup_iters > 0:
                warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
                    self.optimizer,
                    start_factor=self.warmup_start_factor,
                    total_iters=warmup_iters,
                )

        epoch_losses = {}
        for images, targets in dataloader:
            images = [img.to(self.device, non_blocking=True) for img in images]

            train_targets = []
            for target in targets:
                new_target = {
                    key: target[key].to(self.device, non_blocking=True)
                    for key in self._TARGET_KEYS
                }
                # Only the class labels are validated; target_labels pass through unchecked.
                if (new_target["labels"] < 0).any():
                    raise ValueError(f"Negative labels found in target: {new_target['labels']}")
                train_targets.append(new_target)

            # In train mode the model is expected to return a dict of scalar losses.
            loss_dict = self.model(images, train_targets)
            losses = sum(loss_dict.values())

            loss_value = float(losses.detach())
            if not math.isfinite(loss_value):
                print(f"Loss is {loss_value}, stopping training.")
                print(loss_dict)
                sys.exit(1)

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()

            if warmup_scheduler is not None:
                warmup_scheduler.step()

            # Accumulate per-loss totals for logging.
            for key, value in loss_dict.items():
                if key in epoch_losses:
                    epoch_losses[key] += value.item()
                else:
                    epoch_losses[key] = value.item()

        self.current_epoch += 1

        # The main scheduler is stepped here, once per epoch, by every caller of
        # this method (including every rank in distributed mode).
        if self.scheduler is not None:
            self.scheduler.step()

        return epoch_losses