# custom_trainer.py
from ultralytics.models.yolo.detect import DetectionTrainer
import torch
import random
import gc
import math
import os
import subprocess
import time
import warnings
from copy import deepcopy
from datetime import datetime, timedelta
from pathlib import Path

import numpy as np
import torch
from torch import distributed as dist
from torch import nn, optim

from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
from ultralytics.utils import (
    DEFAULT_CFG,
    LOGGER,
    RANK,
    TQDM,
    __version__,
    callbacks,
    clean_url,
    colorstr,
    emojis,
    yaml_save,
)
from ultralytics.utils.autobatch import check_train_batch_size
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
from ultralytics.utils.files import get_latest_run
from ultralytics.utils.torch_utils import (
    EarlyStopping,
    ModelEMA,
    convert_optimizer_state_dict_to_fp16,
    init_seeds,
    one_cycle,
    select_device,
    strip_optimizer,
    torch_distributed_zero_first,
)

# Global parameters for the differential-privacy-style Gaussian noise that
# CustomDetectionTrainer._add_input_noise adds to training images. The noise
# standard deviation is sigma = sensitivity * sqrt(2 * ln(1.25 / delta)) / epsilon.
# NOTE(review): that classical Gaussian-mechanism calibration is only proven for
# epsilon < 1; with epsilon = 500 it does not give a formal (epsilon, delta)
# guarantee — confirm the intended privacy accounting.
epsilon = 500.0  # Privacy budget; sigma scales as 1 / epsilon, so larger means less noise
delta = 1e-5  # Probability that the privacy guarantee fails to hold
sensitivity = 1.0  # Sensitivity (Delta) of the noised quantity, i.e. the input images (not a loss)

class CustomDetectionTrainer(DetectionTrainer):
    """YOLO detection trainer that adds Gaussian noise to every preprocessed training batch.

    Only ``preprocess_batch`` and its helper ``_add_input_noise`` are overridden here; all
    other training behaviour is inherited from ``DetectionTrainer``. The noise scale is
    derived from the module-level ``epsilon``, ``delta`` and ``sensitivity`` settings.
    """

    def _add_input_noise(self, images):
        """Return ``images`` perturbed with i.i.d. Gaussian noise (input-level DP noise).

        The standard deviation follows the classical Gaussian-mechanism calibration
        ``sigma = sensitivity * sqrt(2 * ln(1.25 / delta)) / epsilon``.

        NOTE(review): that calibration is only proven for ``epsilon < 1``; with the module
        default ``epsilon = 500`` it does not yield a formal (epsilon, delta) guarantee, so
        confirm the intended privacy accounting before relying on it.

        Args:
            images (torch.Tensor): Floating-point image batch.

        Returns:
            torch.Tensor: A new tensor ``images + noise`` with the same shape, dtype and
            device as ``images``; the input tensor itself is not modified.
        """
        sigma = sensitivity * math.sqrt(2 * math.log(1.25 / delta)) / epsilon
        # randn_like samples directly on images.device in images.dtype, avoiding a
        # full-size CPU allocation plus a host-to-device copy on every batch.
        noise = torch.randn_like(images) * sigma
        return images + noise

    def preprocess_batch(self, batch):
        """Move images to the device, scale to [0, 1], optionally multi-scale resize, add noise.

        Args:
            batch (dict): Training batch whose ``"img"`` entry holds the image tensor —
                presumably uint8 pixels shaped (B, C, H, W); the ``/ 255`` scaling and the
                2-D ``interpolate`` size below assume this.

        Returns:
            dict: The same ``batch`` object with ``batch["img"]`` replaced by the float,
            optionally resized, noised images. Values are not clipped, so after the noise
            is added they may fall slightly outside [0, 1].
        """
        imgs = batch["img"].to(self.device, non_blocking=True).float() / 255
        if self.args.multi_scale:
            stride = self.stride
            # Cast the bounds to int: random.randrange() rejects float arguments on
            # Python 3.12+, and non-integral floats (e.g. odd imgsz * 0.5) on every version.
            low = int(self.args.imgsz * 0.5)
            high = int(self.args.imgsz * 1.5 + stride)
            sz = random.randrange(low, high) // stride * stride  # random target size, stride multiple
            sf = sz / max(imgs.shape[2:])  # scale factor relative to the longest side
            if sf != 1:
                # New (H, W), each stretched up to a multiple of the stride.
                ns = [math.ceil(x * sf / stride) * stride for x in imgs.shape[2:]]
                imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)

        # Add differential-privacy noise after any resizing so the final input is perturbed.
        batch["img"] = self._add_input_noise(imgs)
        return batch