import torch
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm
from collections import defaultdict
import mlflow
import torchvision.models as models
from torchvision.models import resnet50, ResNet50_Weights
import random
import numpy as np
import os
from torchviz import make_dot
from utils import EarlyStopping, set_seed, ParamDict, WholeFish, ResNet50_MLP, normalize_to_distribution
from pytorch_msssim import SSIM, MS_SSIM
from geomloss import SamplesLoss
import time

# cuBLAS workspace configuration. PyTorch's reproducibility notes list this
# value (or ":16:8") as required for deterministic cuBLAS behaviour on
# CUDA >= 10.2 when deterministic algorithms are enabled.
# NOTE(review): presumably paired with deterministic settings inside
# utils.set_seed -- confirm. Set at module import so it is in the environment
# before any training code in this file runs.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

class ERMTrainer:
    def __init__(self, run_name, model, train_loader, val_loader, criterion, optimizer,
                 tuple_to_group, label_mode='binary', device='cuda', checkpoint_path='checkpoint.pth',
                 seed=42, early_stopping=False, **early_stopping_kwargs):
        """Set up a single-model trainer.

        Calls ``set_seed(seed)`` first, moves ``model`` to ``device`` and
        stores the loaders, loss, optimizer and best-validation bookkeeping.

        ``tuple_to_group`` maps tuples to group names: its keys are truncated
        to their first two elements for per-sample lookups, and its values (in
        insertion order) become ``self.groups``. ``early_stopping_kwargs`` are
        stored untouched; ``fit`` forwards them to ``EarlyStopping``.
        """
        set_seed(seed)

        # Identity / runtime configuration.
        self.run_name = run_name
        self.device = device
        self.mode = label_mode
        self.checkpoint_path = checkpoint_path

        # Model, data and optimisation objects.
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer

        # Best validation results seen so far (updated by record_log).
        self.best_val_aa = 0.0
        self.best_val_wga = 0.0
        self.best_val_wga_idx = 0
        self.best_val_wga_group = None

        # Group lookup: truncated tuple -> name, plus the ordered list of names.
        self.tuple_to_group = {key[:2]: name for key, name in tuple_to_group.items()}
        self.groups = list(tuple_to_group.values())

        # Early-stopping switch and the kwargs handed to EarlyStopping later.
        self.early_stopping = early_stopping
        self.early_stopping_kwargs = early_stopping_kwargs

    def train_epoch(self):
        """Run one optimisation pass over ``self.train_loader``.

        For each batch: forward pass, loss, backward pass and optimizer step.
        The batch loss (weighted by batch size) and per-group correct/total
        counts are accumulated, then handed to
        ``compute_train_epoch_metrics`` for aggregation.

        Returns:
            tuple: ``(epoch_metrics, label_group_accuracy_map)`` as produced by
            ``compute_train_epoch_metrics``.
        """
        self.model.train()

        total_loss_sum = 0.0
        num_samples = 0
        batch_times = []

        correct_labels_per_group = defaultdict(int)
        total_labels_per_group = defaultdict(int)

        epoch_start_time = time.perf_counter()
        for images, labels, groups in tqdm(self.train_loader, desc="Training"):
            batch_start_time = time.perf_counter()

            images = images.to(self.device, non_blocking=True)

            # Binary mode feeds float column targets; otherwise flat class indices.
            if self.mode == 'binary':
                labels_cast = labels.float().view(-1, 1).to(self.device, non_blocking=True)
            else:
                labels_cast = labels.long().view(-1).to(self.device, non_blocking=True)

            self.optimizer.zero_grad()
            outputs = self.model(images)
            total_loss = self.criterion(outputs, labels_cast)
            total_loss.backward()
            self.optimizer.step()

            batch_size = images.size(0)
            total_loss_sum += total_loss.item() * batch_size
            num_samples += batch_size

            # Binary mode thresholds the logit at 0; otherwise take the arg-max class.
            if self.mode == 'binary':
                predicted_labels = (outputs >= 0.0).long()
            else:
                predicted_labels = torch.argmax(outputs, dim=1)

            # Compare the whole batch at once and move the result to the host
            # in a single transfer, instead of one `.item()` sync per sample.
            hits = (predicted_labels.reshape(-1).cpu() == labels.reshape(-1).cpu()).tolist()

            for g_idx, hit in zip(groups.tolist(), hits):
                # Multi-column group rows are keyed by their first two entries.
                if isinstance(g_idx, list):
                    g_idx = tuple(g_idx[:2])
                correct_labels_per_group[g_idx] += hit
                total_labels_per_group[g_idx] += 1

            batch_times.append(time.perf_counter() - batch_start_time)

        return self.compute_train_epoch_metrics(
            total_loss_sum, num_samples,
            correct_labels_per_group, total_labels_per_group,
            batch_times, epoch_start_time)

    def compute_train_epoch_metrics(self, total_loss_sum, num_samples,
                                    correct_labels_per_group, total_labels_per_group,
                                    batch_times, epoch_start_time):
        epoch_time = time.perf_counter() - epoch_start_time
        epoch_metrics = {}

        epoch_metrics['avg_batch_time'] = np.mean(batch_times).item()
        epoch_metrics['std_batch_time'] = np.std(batch_times, ddof=1).item()
        epoch_metrics['epoch_time'] = epoch_time
        epoch_metrics['total_loss'] = total_loss_sum / num_samples
        epoch_metrics['label_loss'] = total_loss_sum / num_samples

        group_accuracies_label = []
        
        for t_idx in self.tuple_to_group:
            if total_labels_per_group[t_idx] > 0:
                group_accuracies_label.append(correct_labels_per_group[t_idx] / total_labels_per_group[t_idx])
            else:
                group_accuracies_label.append(0.0)

        epoch_metrics['label_group_accuracies'] = torch.tensor(group_accuracies_label)
        epoch_metrics['label_avg_acc'] = sum(correct_labels_per_group.values()) / sum(total_labels_per_group.values())
        epoch_metrics['label_worst_group_acc'] = torch.min(epoch_metrics['label_group_accuracies']).item()
        epoch_metrics['label_worst_group_idx'] = torch.argmin(epoch_metrics['label_group_accuracies']).item()
        epoch_metrics['label_worst_group'] = self.groups[epoch_metrics['label_worst_group_idx']]

        label_group_accuracy_map = dict(zip(self.groups, group_accuracies_label))

        return epoch_metrics, label_group_accuracy_map

    def save_checkpoint(self, epoch, val_wga, val_aa, val_wga_idx, val_wga_group):
        print(f"💾 Saving checkpoint... (val_wga: {val_wga:.4f}; val_aa: {val_aa:.4f})")
        
        torch.save({
            'epoch': epoch,
            'label_model_state_dict': self.model.state_dict(),
            'optimizer_label_state_dict': self.optimizer.state_dict(),
            'val_wga': val_wga,
            'val_aa': val_aa,
            'val_wga_idx': val_wga_idx,
            'val_wga_group': val_wga_group
        }, self.checkpoint_path)

    def validate_epoch(self):
        """Evaluate the model on ``self.val_loader`` without gradient tracking.

        Accumulates the batch-size-weighted loss and per-group correct/total
        counts, then delegates aggregation to ``compute_val_epoch_metrics``.

        Returns:
            tuple: ``(epoch_metrics, label_group_accuracy_map)`` as produced by
            ``compute_val_epoch_metrics``.
        """
        self.model.eval()

        label_loss_sum = 0.0
        num_samples = 0

        correct_labels_per_group = defaultdict(int)
        total_labels_per_group = defaultdict(int)

        epoch_start_time = time.perf_counter()
        with torch.no_grad():
            for images, labels, groups in tqdm(self.val_loader, desc="Validation"):
                images = images.to(self.device, non_blocking=True)

                # Binary mode feeds float column targets; otherwise flat class indices.
                if self.mode == 'binary':
                    labels_cast = labels.float().view(-1, 1).to(self.device, non_blocking=True)
                else:
                    labels_cast = labels.long().view(-1).to(self.device, non_blocking=True)

                outputs_labels = self.model(images)
                loss_label = self.criterion(outputs_labels, labels_cast)

                batch_size = images.size(0)
                label_loss_sum += loss_label.item() * batch_size
                num_samples += batch_size

                if self.mode == 'binary':
                    predicted_labels = (outputs_labels >= 0.0).long()
                else:
                    predicted_labels = torch.argmax(outputs_labels, dim=1)

                # Compare the whole batch at once and move the result to the
                # host in a single transfer, instead of one `.item()` sync
                # per sample.
                hits = (predicted_labels.reshape(-1).cpu() == labels.reshape(-1).cpu()).tolist()

                for g_idx, hit in zip(groups.tolist(), hits):
                    # Multi-column group rows are keyed by their first two entries.
                    if isinstance(g_idx, list):
                        g_idx = tuple(g_idx[:2])
                    correct_labels_per_group[g_idx] += hit
                    total_labels_per_group[g_idx] += 1

        return self.compute_val_epoch_metrics(
            label_loss_sum, num_samples,
            correct_labels_per_group, total_labels_per_group,
            epoch_start_time)

    def compute_val_epoch_metrics(self, label_loss_sum, num_samples,
                                  correct_labels_per_group, total_labels_per_group,
                                  epoch_start_time):
        epoch_time = time.perf_counter() - epoch_start_time
        epoch_metrics = {}

        epoch_metrics['label_loss'] = label_loss_sum / num_samples
        epoch_metrics['epoch_time'] = epoch_time

        group_accuracies_label = []
        
        for t_idx in self.tuple_to_group:
            if total_labels_per_group[t_idx] > 0:
                group_accuracies_label.append(correct_labels_per_group[t_idx] / total_labels_per_group[t_idx])
            else:
                group_accuracies_label.append(0.0)

        epoch_metrics['label_group_accuracies'] = torch.tensor(group_accuracies_label)
        epoch_metrics['label_avg_acc'] = sum(correct_labels_per_group.values()) / sum(total_labels_per_group.values()) 
        epoch_metrics['label_worst_group_acc'] = torch.min(epoch_metrics['label_group_accuracies']).item()
        epoch_metrics['label_worst_group_idx'] = torch.argmin(epoch_metrics['label_group_accuracies']).item()
        epoch_metrics['label_worst_group'] = self.groups[epoch_metrics['label_worst_group_idx']]

        label_group_accuracy_map = dict(zip(self.groups, group_accuracies_label))

        return epoch_metrics, label_group_accuracy_map  

    def fit(self, num_epochs):
        """Open a nested MLflow run, log hyper-parameters and start training.

        When early stopping is enabled an ``EarlyStopping`` object is built
        from the stored kwargs and its ``patience``, ``min_delta`` and
        ``mode`` entries are logged. The epoch loop itself lives in
        ``record_log``.
        """
        with mlflow.start_run(run_name=self.run_name, nested=True):
            early_stop = None
            if self.early_stopping:
                early_stop = EarlyStopping(**self.early_stopping_kwargs)
                for key in ("patience", "min_delta", "mode"):
                    mlflow.log_param(key, self.early_stopping_kwargs[key])
                mlflow.log_param("early_stopping_metric", "Average of WGA and AA")

            mlflow.log_param("num_epochs", num_epochs)

            # Only the optimizer's first param group is reported.
            first_group = self.optimizer.param_groups[0]
            mlflow.log_param("lr_label", first_group['lr'])
            mlflow.log_param("weight_decay_label", first_group['weight_decay'])
            print(f"\nlr_label={first_group['lr']}")
            print(f"wd_label={first_group['weight_decay']}")

            self.record_log(num_epochs, early_stop)

    def record_log(self, num_epochs, early_stop=None):
        """Run the train/validate loop with console and MLflow reporting.

        Each epoch trains, validates, prints a summary, logs per-group and
        aggregate metrics to MLflow, and saves a checkpoint whenever the mean
        of validation worst-group accuracy and average accuracy beats the
        best value so far. When ``self.early_stopping`` is set, the same mean
        is fed to ``early_stop`` and the loop ends once it signals a stop.
        After the loop the best validation figures are printed and logged.

        Args:
            num_epochs: Maximum number of epochs to run.
            early_stop: Callable early-stopping object exposing an
                ``early_stop`` flag; only used when ``self.early_stopping``
                is true.
        """
        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            train_epoch_metrics, train_label_map = self.train_epoch()
            val_epoch_metrics, val_label_map = self.validate_epoch()

            # Timings are printed with ".4f" (four decimals) like every other
            # figure here; the previous ":4f" spec only set a minimum width.
            print(f"🏋️‍♂️⏳ Train Epoch Time: {train_epoch_metrics['epoch_time']:.4f}")
            print(f"🏋️‍♂️📦⏳ Train Avg Batch Time: {train_epoch_metrics['avg_batch_time']:.4f} +- {train_epoch_metrics['std_batch_time']:.4f}")
            print(f"🏋️‍♂️💯 Train Total Loss: {train_epoch_metrics['total_loss']:.4f}")
            print(f"🏋️‍♂️🔖 Train Label Loss: {train_epoch_metrics['label_loss']:.4f}")
            print(f"🏋️‍♂️🔖⬇️ Worst Group Train Label Accuracy: {train_epoch_metrics['label_worst_group_acc']:.4f} (Group:  {train_epoch_metrics['label_worst_group']})")
            print(f"🏋️‍♂️🔖🎯 Average Train Label Accuracy: {train_epoch_metrics['label_avg_acc']:.4f}")

            print("\n")
            print(f"🕵⏳ Val Epoch Time: {val_epoch_metrics['epoch_time']:.4f}")
            print(f"🕵🔖 Val Label Loss: {val_epoch_metrics['label_loss']:.4f}")
            print(f"🕵🔖⬇️ Worst Group Val Label Accuracy: {val_epoch_metrics['label_worst_group_acc']:.4f} (Group: {val_epoch_metrics['label_worst_group']})")
            print(f"🕵🔖🎯 Average Val Label Accuracy: {val_epoch_metrics['label_avg_acc']:.4f}")

            print("\n----------------------------------------------------------------------------\n")

            print("📊 Train Label Group-wise Accuracy:")
            for g in sorted(train_label_map.keys()):
                acc = train_label_map[g]
                print(f"  Group {g}: {acc:.3f}")
                mlflow.log_metric(f"train_{g}_aa", acc, step=epoch)

            print("\n----------------------------------------------------------------------------\n")

            print("📊 Val Label Group-wise Accuracy:")
            for g in sorted(val_label_map.keys()):
                acc = val_label_map[g]
                print(f"  Group {g}: {acc:.3f}")
                mlflow.log_metric(f"val_{g}_aa", acc, step=epoch)

            print("\n----------------------------------------------------------------------------\n")

            mlflow.log_metrics({
                "train_total_loss": train_epoch_metrics['total_loss'],
                "train_label_loss": train_epoch_metrics['label_loss'],
                "train_wga_label": train_epoch_metrics['label_worst_group_acc'],
                "train_wga_label_idx": train_epoch_metrics['label_worst_group_idx'],
                "train_aa_label": train_epoch_metrics['label_avg_acc'],
                "train_avg_batch_time": train_epoch_metrics['avg_batch_time'],
                "train_std_batch_time": train_epoch_metrics['std_batch_time'],
                "train_epoch_time": train_epoch_metrics['epoch_time'],
                "val_epoch_time": val_epoch_metrics['epoch_time'],
                "val_label_loss": val_epoch_metrics['label_loss'],
                "val_wga_label": val_epoch_metrics['label_worst_group_acc'],
                "val_wga_label_idx": val_epoch_metrics['label_worst_group_idx'],
                "val_aa_label": val_epoch_metrics['label_avg_acc'],
            }, step=epoch)

            # Model-selection score: mean of worst-group and average accuracy.
            # Computed once and reused for checkpointing and early stopping.
            val_score = 0.5 * (val_epoch_metrics['label_worst_group_acc'] + val_epoch_metrics['label_avg_acc'])
            best_score = 0.5 * (self.best_val_wga + self.best_val_aa)

            if val_score > best_score:
                print("✅ Validation improved!")
                self.best_val_wga = val_epoch_metrics['label_worst_group_acc']
                self.best_val_aa = val_epoch_metrics['label_avg_acc']
                self.best_val_wga_idx = val_epoch_metrics['label_worst_group_idx']
                self.best_val_wga_group = val_epoch_metrics['label_worst_group']
                self.save_checkpoint(epoch + 1, self.best_val_wga, self.best_val_aa,
                                     self.best_val_wga_idx, self.best_val_wga_group)
            else:
                print("🔁 Validation did not improve.")

            if self.early_stopping:
                mlflow.log_metric("val_avg_early_stopping_metric", val_score, step=epoch)
                early_stop(val_score)
                if early_stop.early_stop:
                    print(f"⏹️ Early stopping triggered at epoch {epoch + 1} (val_WAA = {val_score:.4f})")
                    mlflow.log_param("early_stopped_epoch", epoch + 1)
                    break

        print(f"🥇⬇️ Highest Worst Group Val Label Accuracy: {self.best_val_wga:.4f} (Group: {self.best_val_wga_group})")
        print(f"🥇🎯 Highest Average Val Label Accuracy: {self.best_val_aa:.4f}")

        mlflow.log_metrics({
            "best_val_wga": self.best_val_wga,
            "best_val_aa": self.best_val_aa,
            "best_val_wga_idx": self.best_val_wga_idx
        })
        mlflow.log_param("val_wga_label_group", self.best_val_wga_group)

class TwinModelTrainer:
    def __init__(self, run_name, label_model, group_model, train_loader, val_loader, criterion_label, criterion_group,
                 optimizer_label, optimizer_group, tuple_to_group, metric='softiou', lambda_group=1, lambda_metric=1, 
                 normalize=False, label_mode='binary', group_mode='binary', blur=0.05, scaling=0.9,
                 device='cuda', checkpoint_path='checkpoint.pth', seed=42, early_stopping=False, **early_stopping_kwargs):
        """Set up a trainer that optimises a label model and a group model jointly.

        Calls ``set_seed(seed)`` first, moves both models to ``device`` and
        stores loaders, losses, optimizers, the CAM comparison settings
        (``metric``, ``lambda_metric``, ``lambda_group``, ``normalize``,
        ``blur``, ``scaling``) and best-validation bookkeeping.

        ``tuple_to_group`` keys are truncated to their first two elements for
        per-sample lookups; its values (in insertion order) become
        ``self.groups``. ``early_stopping_kwargs`` are stored untouched.
        """
        set_seed(seed)

        # Identity / runtime configuration.
        self.run_name = run_name
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.label_mode = label_mode
        self.group_mode = group_mode

        # Models, data, losses and optimizers.
        self.label_model = label_model.to(device)
        self.group_model = group_model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion_label = criterion_label
        self.criterion_group = criterion_group
        self.optimizer_label = optimizer_label
        self.optimizer_group = optimizer_group

        # CAM comparison term and loss weights.
        self.metric = metric
        self.lambda_metric = lambda_metric
        self.lambda_group = lambda_group
        self.normalize = normalize
        self.blur = blur
        self.scaling = scaling

        # Best validation results seen so far.
        self.best_val_aa = 0.0
        self.best_val_wga = 0.0
        self.best_val_wga_idx = 0
        self.best_val_wga_group = None

        # Group lookup: truncated tuple -> name, plus the ordered list of names.
        self.tuple_to_group = {key[:2]: name for key, name in tuple_to_group.items()}
        self.groups = list(tuple_to_group.values())

        # Early-stopping switch and the kwargs kept for it.
        self.early_stopping = early_stopping
        self.early_stopping_kwargs = early_stopping_kwargs
    def softiou_loss(self, map1, map2, epsilon=1e-6):
        map1_flat = map1.view(map1.size(0), -1)
        map2_flat = map2.view(map2.size(0), -1)

        intersection = (map1_flat * map2_flat).sum(dim=1)
        union = map1_flat.sum(dim=1) + map2_flat.sum(dim=1) - intersection

        iou_per_item = (intersection + epsilon) / (union + epsilon)
        return iou_per_item.mean()

    def soft_dice_loss(self, map1, map2, epsilon=1e-6):
        map1_flat = map1.view(map1.size(0), -1)
        map2_flat = map2.view(map2.size(0), -1)

        intersection = (map1_flat * map2_flat).sum(dim=1)
        union = map1_flat.sum(dim=1) + map2_flat.sum(dim=1)

        dice_per_item = (2 * intersection + epsilon) / (union + epsilon)
        return dice_per_item.mean()

    def cosine_sim_loss(self, map1, map2):
        """Return the batch-mean cosine similarity of the two maps.

        Both maps go through ``normalize_to_distribution`` and are flattened
        per sample before the similarity is taken.
        """
        dist_a = normalize_to_distribution(map1)
        dist_b = normalize_to_distribution(map2)
        flat_a = dist_a.view(dist_a.size(0), -1)
        flat_b = dist_b.view(dist_b.size(0), -1)
        return F.cosine_similarity(flat_a, flat_b).mean()

    def mse_loss(self, map1, map2):
        return -(((map1.view(map1.size(0), -1) - map2.view(map2.size(0), -1)) ** 2).mean())

    def rmse_loss(self, map1, map2):
        return -torch.sqrt((((map1.view(map1.size(0), -1) - map2.view(map2.size(0), -1)) ** 2).mean()))
        
    def mae_loss(self, map1, map2):
        return -((torch.abs(map1.view(map1.size(0), -1) - map2.view(map2.size(0), -1))).mean())

    def ssim_loss(self, map1, map2):
        """Return the SSIM score of the two maps.

        3-D inputs get a channel axis inserted at position 1 before being
        passed to a single-channel ``SSIM`` module (window 11, data range 1,
        size-averaged).
        """
        lacks_channel_axis = map1.dim() == 3
        if lacks_channel_axis:
            map1, map2 = map1.unsqueeze(1), map2.unsqueeze(1)

        scorer = SSIM(data_range=1, size_average=True, win_size=11, channel=1)
        return scorer(map1, map2)
        
    def ms_ssim_loss(self, map1, map2):
        """Return the multi-scale SSIM score of the two maps.

        3-D inputs get a channel axis inserted at position 1 before being
        passed to a single-channel ``MS_SSIM`` module (window 11, data range
        1, size-averaged).
        """
        lacks_channel_axis = map1.dim() == 3
        if lacks_channel_axis:
            map1, map2 = map1.unsqueeze(1), map2.unsqueeze(1)

        scorer = MS_SSIM(data_range=1, size_average=True, win_size=11, channel=1)
        return scorer(map1, map2)

    def ncc_loss(self, map1, map2, eps=1e-8):
        mean_map1 = map1.mean(dim=(-2, -1), keepdim=True)
        mean_map2 = map2.mean(dim=(-2, -1), keepdim=True)
        map1_sub = map1 - mean_map1
        map2_sub = map2 - mean_map2
        num = torch.sum((map1_sub * map2_sub), dim=(-2, -1), keepdim=True)
        map1_denom = torch.sqrt(torch.sum((map1_sub ** 2), dim=(-2, -1), keepdim=True))
        map2_denom = torch.sqrt(torch.sum((map2_sub ** 2), dim=(-2, -1), keepdim=True))
        denom = (map1_denom * map2_denom) + eps
        return torch.mean(num / denom)

    def em_dist_loss(self, map1, map2):
        """Return the negated Sinkhorn (p=1) loss between the two maps.

        Both maps are passed through ``normalize_to_distribution`` and then
        compared with geomloss ``SamplesLoss``. The ``blur`` and ``scaling``
        values given to the constructor are used here; previously they were
        stored but the method hard-coded 0.05 and 0.9, so changing them had
        no effect. The constructor defaults are the same numbers, so default
        behaviour is unchanged.

        NOTE(review): how ``SamplesLoss`` interprets the tensor layout
        returned by ``normalize_to_distribution`` is not visible from this
        file -- confirm it matches the intended transport problem.
        """
        map1 = normalize_to_distribution(map1)
        map2 = normalize_to_distribution(map2)

        sinkhorn = SamplesLoss(
            loss="sinkhorn",
            p=1,
            blur=self.blur,
            reach=None,
            scaling=self.scaling,
            debias=False,
            backend="online",
            verbose=False,
        )
        # Negated so that, like the other distance-based terms, a larger
        # distance yields a smaller returned value.
        return -sinkhorn(map1, map2)

    def kl_div_loss(self, input_map, target_map):
        """Return the negated batch-mean KL divergence between the two maps.

        Both maps go through ``normalize_to_distribution``; the log of the
        first is passed as ``input`` and the second as ``target`` to
        ``F.kl_div`` with ``reduction='batchmean'``.

        NOTE(review): ``torch.log`` yields -inf for exact zeros; whether
        ``normalize_to_distribution`` guarantees strictly positive output is
        not visible here -- confirm.
        """
        predicted = normalize_to_distribution(input_map)
        reference = normalize_to_distribution(target_map)

        divergence = F.kl_div(torch.log(predicted), reference, reduction='batchmean')
        return -divergence

    def js_div_loss(self, map1, map2):
        """Return the negated Jensen-Shannon divergence between the two maps.

        Both maps are first turned into distributions P and Q with
        ``normalize_to_distribution``. The mixture is M = (P + Q) / 2 and the
        result is -(KL(P || M) + KL(Q || M)) / 2, using ``F.kl_div`` with
        ``reduction='batchmean'``.

        The mixture is built from the *normalised* maps. Previously it was
        ``normalize_to_distribution((map1 + map2) / 2)``, which weights each
        map by its total mass and therefore is not (P + Q) / 2 unless both
        raw maps happen to have the same sum.
        """
        map1 = normalize_to_distribution(map1)
        map2 = normalize_to_distribution(map2)

        # Average of two distributions is itself a distribution; no
        # re-normalisation needed.
        mixture = (map1 + map2) / 2
        log_mixture = torch.log(mixture)

        # F.kl_div(log_q, p) computes KL(p || q).
        kl_div_0 = F.kl_div(log_mixture, map1, reduction='batchmean')
        kl_div_1 = F.kl_div(log_mixture, map2, reduction='batchmean')
        return -((kl_div_0 + kl_div_1) / 2)

    def js_dist_loss(self, map1, map2):
        """Return the negated Jensen-Shannon distance (sqrt of the divergence).

        Both maps are first turned into distributions P and Q with
        ``normalize_to_distribution``. The mixture is M = (P + Q) / 2, the
        divergence is (KL(P || M) + KL(Q || M)) / 2 via ``F.kl_div`` with
        ``reduction='batchmean'``, and the result is the negated square root.

        Two fixes relative to the earlier version:
        * the mixture is built from the *normalised* maps; it used to be
          ``normalize_to_distribution((map1 + map2) / 2)``, which weights each
          map by its total mass and is not (P + Q) / 2 in general;
        * the divergence is clamped to a tiny positive floor before the
          square root, because rounding can make it zero or slightly
          negative, which gives NaN values or infinite gradients.
        """
        map1 = normalize_to_distribution(map1)
        map2 = normalize_to_distribution(map2)

        # Average of two distributions is itself a distribution; no
        # re-normalisation needed.
        mixture = (map1 + map2) / 2
        log_mixture = torch.log(mixture)

        # F.kl_div(log_q, p) computes KL(p || q).
        kl_div_0 = F.kl_div(log_mixture, map1, reduction='batchmean')
        kl_div_1 = F.kl_div(log_mixture, map2, reduction='batchmean')
        js_div = (kl_div_0 + kl_div_1) / 2

        # Floor keeps sqrt finite and its gradient defined when the maps
        # (nearly) coincide.
        return -torch.sqrt(torch.clamp(js_div, min=1e-12))
        
    def get_grad_cam_map(self, model, images, targets):
        activations = {}
        def hook_fn(module, input, output):
            if output.requires_grad:
                output.retain_grad() 
            activations['value'] = output
                                                                               
        hook_handle = model.target_layer.register_forward_hook(hook_fn)

        logits = model(images)
        
        if self.label_mode == 'binary' and model is self.label_model:
            target_logit_for_grad = torch.where(targets.bool(), logits, -logits)
        elif self.group_mode == 'binary' and model is self.group_model:
            target_logit_for_grad = torch.where(targets.bool(), logits, -logits)
        else:
            target_logit_for_grad = logits[torch.arange(logits.size(0)), targets].unsqueeze(1)

        gradients = torch.autograd.grad(outputs=target_logit_for_grad.sum(),
                                         inputs=activations['value'],
                                         retain_graph=True,
                                         create_graph=False)[0]

        hook_handle.remove()

        weights = torch.mean(gradients, dim=(2, 3), keepdim=True)

        cam = torch.relu((weights * activations['value']).sum(dim=1))

        if self.normalize or self.metric == 'ssim' or self.metric == 'ms_ssim':
            batch_min = cam.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[0]
            batch_max = cam.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0]
    
            normalized_cam = (cam - batch_min) / (batch_max - batch_min + 1e-6)
            return normalized_cam, logits
        else:
            return cam, logits

    def train_epoch(self):
        """Run one joint optimisation pass over ``self.train_loader``.

        For each batch the Grad-CAM maps and logits of both models are
        computed, the selected CAM comparison term (``self.metric``) is
        evaluated, and the total loss
        ``label + lambda_group * group + lambda_metric * metric`` is
        back-propagated before both optimizers step. Losses are accumulated
        weighted by batch size, and correct/total counts are kept per group
        for both the label and the group prediction.

        Returns:
            tuple: ``(epoch_metrics, label_group_accuracy_map,
            group_group_accuracy_map)``.

        Raises:
            ValueError: If ``self.metric`` is not a known metric name.
        """
        metric_fns = {
            'softiou': self.softiou_loss,
            'cosine': self.cosine_sim_loss,
            'mse': self.mse_loss,
            'mae': self.mae_loss,
            'rmse': self.rmse_loss,
            'kl_div': self.kl_div_loss,
            'js_div': self.js_div_loss,
            'js_dist': self.js_dist_loss,
            'em_dist': self.em_dist_loss,
            'ncc': self.ncc_loss,
            'ssim': self.ssim_loss,
            'ms_ssim': self.ms_ssim_loss,
            'soft_dice': self.soft_dice_loss,
        }
        if self.metric not in metric_fns:
            raise ValueError(f"Unknown metric: {self.metric}")
        metric_fn = metric_fns[self.metric]
        # These metrics are called as (group_cam, label_cam); every other one
        # as (label_cam, group_cam). The order matters for 'kl_div'.
        group_map_first = self.metric in ('softiou', 'cosine', 'mse', 'mae', 'rmse')

        self.label_model.train()
        self.group_model.train()

        total_loss_sum = 0.0
        label_loss_sum = 0.0
        group_loss_sum = 0.0
        metric_loss_sum = 0.0
        num_samples = 0
        batch_times = []

        correct_labels_per_group = defaultdict(int)
        total_labels_per_group = defaultdict(int)
        correct_groups_per_group = defaultdict(int)
        total_groups_per_group = defaultdict(int)

        epoch_start_time = time.perf_counter()
        for images, labels, groups in tqdm(self.train_loader, desc="Training"):
            batch_start_time = time.perf_counter()

            images = images.to(self.device, non_blocking=True)
            # The group model's target is the first column of `groups`.
            group_targets = groups[:, 0]

            if self.label_mode == 'binary':
                labels_cast = labels.float().view(-1, 1).to(self.device, non_blocking=True)
            else:
                labels_cast = labels.long().view(-1).to(self.device, non_blocking=True)

            if self.group_mode == 'binary':
                groups_cast = group_targets.float().view(-1, 1).to(self.device, non_blocking=True)
            else:
                groups_cast = group_targets.long().view(-1).to(self.device, non_blocking=True)

            self.optimizer_label.zero_grad()
            self.optimizer_group.zero_grad()

            # Pass targets in the same shape as the loss targets: (B, 1) in
            # binary mode, (B,) otherwise. Squeezing binary targets to (B,)
            # made them broadcast against (B, 1) logits inside
            # get_grad_cam_map, mixing samples in the Grad-CAM gradient.
            label_cam_map, outputs_labels = self.get_grad_cam_map(
                self.label_model, images, labels_cast.long())
            group_cam_map, outputs_groups = self.get_grad_cam_map(
                self.group_model, images, groups_cast.long())

            if group_map_first:
                loss_metric = metric_fn(group_cam_map, label_cam_map)
            else:
                loss_metric = metric_fn(label_cam_map, group_cam_map)

            loss_label = self.criterion_label(outputs_labels, labels_cast)
            loss_group = self.criterion_group(outputs_groups, groups_cast)

            total_loss = loss_label + self.lambda_group * loss_group + self.lambda_metric * loss_metric

            total_loss.backward()

            self.optimizer_label.step()
            self.optimizer_group.step()

            batch_size = images.size(0)
            total_loss_sum += total_loss.item() * batch_size
            label_loss_sum += loss_label.item() * batch_size
            group_loss_sum += loss_group.item() * batch_size
            metric_loss_sum += loss_metric.item() * batch_size
            num_samples += batch_size

            if self.label_mode == 'binary':
                predicted_labels = (outputs_labels >= 0.0).long()
            else:
                predicted_labels = torch.argmax(outputs_labels, dim=1)

            if self.group_mode == 'binary':
                predicted_groups = (outputs_groups >= 0.0).long()
            else:
                predicted_groups = torch.argmax(outputs_groups, dim=1)

            # Compare whole batches and move the results to the host once,
            # instead of one `.item()` sync per sample and per task.
            label_hits = (predicted_labels.reshape(-1).cpu() == labels.reshape(-1).cpu()).tolist()
            group_hits = (predicted_groups.reshape(-1).cpu() == group_targets.reshape(-1).cpu()).tolist()

            for g_idx, label_hit, group_hit in zip(groups.tolist(), label_hits, group_hits):
                # Multi-column group rows are keyed by their first two entries.
                if isinstance(g_idx, list):
                    g_idx = tuple(g_idx[:2])
                correct_labels_per_group[g_idx] += label_hit
                total_labels_per_group[g_idx] += 1
                correct_groups_per_group[g_idx] += group_hit
                total_groups_per_group[g_idx] += 1

            batch_times.append(time.perf_counter() - batch_start_time)

        epoch_time = time.perf_counter() - epoch_start_time
        epoch_metrics = {}

        epoch_metrics['total_loss'] = total_loss_sum / num_samples
        epoch_metrics['label_loss'] = label_loss_sum / num_samples
        epoch_metrics['group_loss'] = group_loss_sum / num_samples
        epoch_metrics['metric_loss'] = metric_loss_sum / num_samples
        epoch_metrics['avg_batch_time'] = np.mean(batch_times).item()
        epoch_metrics['std_batch_time'] = np.std(batch_times).item()
        epoch_metrics['epoch_time'] = epoch_time

        def summarize(prefix, correct, total):
            """Fill the ``prefix``-keyed accuracy entries; return per-group accuracies."""
            # One accuracy per key of self.tuple_to_group, in its iteration
            # order; a group with no samples is reported as 0.0.
            accuracies = [
                correct.get(key, 0) / total[key] if total.get(key, 0) > 0 else 0.0
                for key in self.tuple_to_group
            ]
            as_tensor = torch.tensor(accuracies)
            worst_idx = torch.argmin(as_tensor).item()
            epoch_metrics[f'{prefix}_group_accuracies'] = as_tensor
            epoch_metrics[f'{prefix}_avg_acc'] = sum(correct.values()) / sum(total.values())
            epoch_metrics[f'{prefix}_worst_group_acc'] = torch.min(as_tensor).item()
            epoch_metrics[f'{prefix}_worst_group_idx'] = worst_idx
            epoch_metrics[f'{prefix}_worst_group'] = self.groups[worst_idx]
            return accuracies

        group_accuracies_label = summarize('label', correct_labels_per_group, total_labels_per_group)
        group_accuracies_group = summarize('group', correct_groups_per_group, total_groups_per_group)

        label_group_accuracy_map = dict(zip(self.groups, group_accuracies_label))
        group_group_accuracy_map = dict(zip(self.groups, group_accuracies_group))

        return epoch_metrics, label_group_accuracy_map, group_group_accuracy_map

    def save_checkpoint(self, epoch, val_wga, val_aa, val_wga_idx, val_wga_group):
        print(f"💾 Saving checkpoint... (val_wga: {val_wga:.4f}; val_aa: {val_aa:.4f})")
        
        torch.save({
            'epoch': epoch,
            'label_model_state_dict': self.label_model.state_dict(),
            'group_model_state_dict': self.group_model.state_dict(),
            'optimizer_label_state_dict': self.optimizer_label.state_dict(),
            'optimizer_group_state_dict': self.optimizer_group.state_dict(),
            'val_wga': val_wga,
            'val_aa': val_aa,
            'val_wga_idx': val_wga_idx,
            'val_wga_group': val_wga_group
        }, self.checkpoint_path)

    def validate_epoch(self):
        """Evaluate the label and group models on ``self.val_loader``.

        Both models are switched to eval mode and run under ``torch.no_grad()``.
        Losses are averaged per sample. Correctness is tallied per group key,
        where a key is the first two entries of a row of ``groups``; the group
        model is scored against the first column of ``groups``.

        Returns:
            tuple: ``(epoch_metrics, label_group_accuracy_map,
            group_group_accuracy_map)``. ``epoch_metrics`` holds the mean
            losses, the wall-clock time and, for the label task and the group
            task, the per-group accuracy tensor, the overall accuracy and the
            worst group's accuracy, index and name. The two maps pair each
            entry of ``self.groups`` with its accuracy on the respective task.

        Raises:
            ZeroDivisionError: If the loader yields no samples.
        """
        self.label_model.eval()
        self.group_model.eval()

        label_loss_sum = 0.0
        group_loss_sum = 0.0
        num_samples = 0

        correct_labels_per_group = defaultdict(int)
        total_labels_per_group = defaultdict(int)
        correct_groups_per_group = defaultdict(int)
        total_groups_per_group = defaultdict(int)

        epoch_start_time = time.perf_counter()
        with torch.no_grad():
            for images, labels, groups in tqdm(self.val_loader, desc="Validation"):
                images = images.to(self.device, non_blocking=True)
                if self.label_mode == 'binary':
                    labels_cast = labels.float().view(-1, 1).to(self.device, non_blocking=True)
                else:
                    labels_cast = labels.long().view(-1).to(self.device, non_blocking=True)

                # The first column of `groups` is the target of the group model.
                group_targets = groups[:, 0]
                if self.group_mode == 'binary':
                    groups_cast = group_targets.float().view(-1, 1).to(self.device, non_blocking=True)
                else:
                    groups_cast = group_targets.long().view(-1).to(self.device, non_blocking=True)

                group_list = groups.tolist()

                outputs_labels = self.label_model(images)
                outputs_groups = self.group_model(images)

                loss_label = self.criterion_label(outputs_labels, labels_cast)
                loss_group = self.criterion_group(outputs_groups, groups_cast)

                batch_size = images.size(0)
                label_loss_sum += loss_label.item() * batch_size
                group_loss_sum += loss_group.item() * batch_size
                num_samples += batch_size

                # Binary heads emit a single logit: threshold at 0.
                if self.label_mode == 'binary':
                    predicted_labels = (outputs_labels >= 0.0).long()
                else:
                    predicted_labels = torch.argmax(outputs_labels, dim=1)

                if self.group_mode == 'binary':
                    predicted_groups = (outputs_groups >= 0.0).long()
                else:
                    predicted_groups = torch.argmax(outputs_groups, dim=1)

                # Compare whole batches at once so there is a single host
                # transfer per task instead of one `.item()` call per sample.
                label_hits = (predicted_labels.reshape(-1).cpu()
                              == labels.reshape(-1).cpu()).tolist()
                group_hits = (predicted_groups.reshape(-1).cpu()
                              == group_targets.reshape(-1).cpu()).tolist()

                for g_idx, label_hit, group_hit in zip(group_list, label_hits, group_hits):
                    if isinstance(g_idx, list):
                        g_idx = tuple(g_idx[:2])
                    correct_labels_per_group[g_idx] += label_hit
                    total_labels_per_group[g_idx] += 1
                    correct_groups_per_group[g_idx] += group_hit
                    total_groups_per_group[g_idx] += 1

        epoch_time = time.perf_counter() - epoch_start_time
        epoch_metrics = {}

        epoch_metrics['label_loss'] = label_loss_sum / num_samples
        epoch_metrics['group_loss'] = group_loss_sum / num_samples
        epoch_metrics['epoch_time'] = epoch_time

        def summarize(prefix, correct, total):
            """Store the accuracy summary of one task under ``prefix``-ed keys."""
            accuracies = []
            for t_idx in self.tuple_to_group:
                seen = total.get(t_idx, 0)
                # Groups without samples are reported as 0.0 accuracy.
                accuracies.append(correct.get(t_idx, 0) / seen if seen > 0 else 0.0)

            as_tensor = torch.tensor(accuracies)
            worst_idx = torch.argmin(as_tensor).item()
            epoch_metrics[f'{prefix}_group_accuracies'] = as_tensor
            epoch_metrics[f'{prefix}_avg_acc'] = sum(correct.values()) / sum(total.values())
            epoch_metrics[f'{prefix}_worst_group_acc'] = torch.min(as_tensor).item()
            epoch_metrics[f'{prefix}_worst_group_idx'] = worst_idx
            epoch_metrics[f'{prefix}_worst_group'] = self.groups[worst_idx]
            return accuracies

        group_accuracies_label = summarize('label', correct_labels_per_group, total_labels_per_group)
        group_accuracies_group = summarize('group', correct_groups_per_group, total_groups_per_group)

        label_group_accuracy_map = dict(zip(self.groups, group_accuracies_label))
        group_group_accuracy_map = dict(zip(self.groups, group_accuracies_group))

        return epoch_metrics, label_group_accuracy_map, group_group_accuracy_map

    def fit(self, num_epochs):
        """Train and validate for up to ``num_epochs`` epochs inside an MLflow run.

        Each epoch runs ``train_epoch`` and ``validate_epoch``, prints a report,
        logs the metrics to MLflow and saves a checkpoint whenever the mean of
        worst-group and average validation label accuracy strictly exceeds the
        best value seen so far. When early stopping is enabled the same mean is
        fed to the ``EarlyStopping`` instance and the loop stops once it fires.

        Args:
            num_epochs: Maximum number of epochs to run.
        """
        with mlflow.start_run(run_name=self.run_name, nested=True):
            if self.early_stopping:
                early_stop = EarlyStopping(**self.early_stopping_kwargs)
                mlflow.log_param("patience", self.early_stopping_kwargs['patience'])
                mlflow.log_param("min_delta", self.early_stopping_kwargs['min_delta'])
                mlflow.log_param("mode", self.early_stopping_kwargs['mode'])
                mlflow.log_param("early_stopping_metric", "Average of WGA and AA")

            # Hyper-parameters are read from the first param group of each optimizer.
            lr_label = self.optimizer_label.param_groups[0]['lr']
            lr_group = self.optimizer_group.param_groups[0]['lr']
            wd_label = self.optimizer_label.param_groups[0]['weight_decay']
            wd_group = self.optimizer_group.param_groups[0]['weight_decay']

            mlflow.log_param("num_epochs", num_epochs)
            mlflow.log_param("lambda_group", self.lambda_group)
            mlflow.log_param("metric", self.metric)
            mlflow.log_param("lambda_metric", self.lambda_metric)
            mlflow.log_param("lr_label", lr_label)
            mlflow.log_param("lr_group", lr_group)
            mlflow.log_param("weight_decay_label", wd_label)
            mlflow.log_param("weight_decay_group", wd_group)
            print(f"\nlr_label={lr_label}; lr_group={lr_group}")
            print(f"wd_label={wd_label}; wd_group={wd_group}")
            print(f"lambda_metric={self.lambda_metric}; lambda_group={self.lambda_group}")

            separator = "\n----------------------------------------------------------------------------\n"

            for epoch in range(num_epochs):
                print(f"\nEpoch {epoch+1}/{num_epochs}")
                train_epoch_metrics, train_label_map, train_group_map = self.train_epoch()
                val_epoch_metrics, val_label_map, val_group_map = self.validate_epoch()

                # Timings use `.4f` like every other metric (was `:4f`, i.e. a
                # minimum width of 4 with the default six decimals).
                print(f"🏋️‍♂️⏳ Train Epoch Time: {train_epoch_metrics['epoch_time']:.4f}")
                print(f"🏋️‍♂️📦⏳ Train Avg Batch Time: {train_epoch_metrics['avg_batch_time']:.4f} +- {train_epoch_metrics['std_batch_time']:.4f}")
                print(f"🏋️‍♂️💯 Train Total Loss: {train_epoch_metrics['total_loss']:.4f}")
                print(f"🏋️‍♂️🔖 Train Label Loss: {train_epoch_metrics['label_loss']:.4f}")
                print(f"🏋️‍♂️🌍 Train Group Loss: {train_epoch_metrics['group_loss']:.4f}")
                print(f"🏋️‍♂️🔀 Train Metric ({self.metric.title()}) Loss: {train_epoch_metrics['metric_loss']:.4f}")
                print(f"🏋️‍♂️🔖⬇️ Worst Group Train Label Accuracy: {train_epoch_metrics['label_worst_group_acc']:.4f} (Group:  {train_epoch_metrics['label_worst_group']})")
                print(f"🏋️‍♂️🔖🎯 Average Train Label Accuracy: {train_epoch_metrics['label_avg_acc']:.4f}")
                print(f"🏋️‍♂️🌍⬇️ Worst Group Train Group Accuracy: {train_epoch_metrics['group_worst_group_acc']:.4f} (Group: {train_epoch_metrics['group_worst_group']})")
                print(f"🏋️‍♂️🌍🎯 Average Train Group Accuracy: {train_epoch_metrics['group_avg_acc']:.4f}")

                print("\n")
                print(f"🕵⏳ Val Epoch Time: {val_epoch_metrics['epoch_time']:.4f}")
                print(f"🕵🔖 Val Label Loss: {val_epoch_metrics['label_loss']:.4f}")
                print(f"🕵🌍 Val Group Loss: {val_epoch_metrics['group_loss']:.4f}")
                print(f"🕵🔖⬇️ Worst Val Label Accuracy: {val_epoch_metrics['label_worst_group_acc']:.4f} (Group: {val_epoch_metrics['label_worst_group']})")
                print(f"🕵🔖🎯 Average Val Label Accuracy: {val_epoch_metrics['label_avg_acc']:.4f}")
                print(f"🕵🌍⬇️ Worst Group Val Group Accuracy: {val_epoch_metrics['group_worst_group_acc']:.4f} (Group:  {val_epoch_metrics['group_worst_group']})")
                print(f"🕵🌍🎯 Average Val Group Accuracy: {val_epoch_metrics['group_avg_acc']:.4f}")

                print(separator)

                # (title, per-group accuracies, MLflow metric name template)
                group_reports = (
                    ("📊 Train Label Group-wise Accuracy:", train_label_map, "train_{}_aa"),
                    ("📊 Val Label Group-wise Accuracy:", val_label_map, "val_{}_aa"),
                    ("📊 Train Group Group-wise Accuracy:", train_group_map, "train_{}_group_aa"),
                    ("📊 Val Group Group-wise Accuracy:", val_group_map, "val_{}_group_aa"),
                )
                for title, accuracy_map, metric_template in group_reports:
                    print(title)
                    for g in sorted(accuracy_map):
                        acc = accuracy_map[g]
                        print(f"  Group {g}: {acc:.3f}")
                        mlflow.log_metric(metric_template.format(g), acc, step=epoch)
                    print(separator)

                mlflow.log_metrics({
                    "train_total_loss": train_epoch_metrics['total_loss'],
                    "train_label_loss": train_epoch_metrics['label_loss'],
                    "train_group_loss": train_epoch_metrics['group_loss'],
                    "train_metric_loss": train_epoch_metrics['metric_loss'],
                    "train_wga_label": train_epoch_metrics['label_worst_group_acc'],
                    "train_wga_label_idx": train_epoch_metrics['label_worst_group_idx'],
                    "train_aa_label": train_epoch_metrics['label_avg_acc'],
                    "train_wga_group": train_epoch_metrics['group_worst_group_acc'],
                    "train_aa_group": train_epoch_metrics['group_avg_acc'],
                    "train_wga_group_idx": train_epoch_metrics['group_worst_group_idx'],
                    "train_avg_batch_time": train_epoch_metrics['avg_batch_time'],
                    "train_epoch_time": train_epoch_metrics['epoch_time'],
                    "train_std_batch_time": train_epoch_metrics['std_batch_time'],
                    "val_epoch_time": val_epoch_metrics['epoch_time'],
                    "val_label_loss": val_epoch_metrics['label_loss'],
                    "val_group_loss": val_epoch_metrics['group_loss'],
                    "val_wga_label": val_epoch_metrics['label_worst_group_acc'],
                    "val_wga_label_idx": val_epoch_metrics['label_worst_group_idx'],
                    "val_aa_label": val_epoch_metrics['label_avg_acc'],
                    "val_wga_group": val_epoch_metrics['group_worst_group_acc'],
                    "val_aa_group": val_epoch_metrics['group_avg_acc'],
                    "val_wga_group_idx": val_epoch_metrics['group_worst_group_idx']
                }, step=epoch)

                # Model selection score: mean of worst-group and average accuracy.
                val_score = 0.5 * (val_epoch_metrics['label_worst_group_acc'] + val_epoch_metrics['label_avg_acc'])
                best_score = 0.5 * (self.best_val_wga + self.best_val_aa)

                if val_score > best_score:
                    print("✅ Validation improved!")
                    self.best_val_wga = val_epoch_metrics['label_worst_group_acc']
                    self.best_val_aa = val_epoch_metrics['label_avg_acc']
                    self.best_val_wga_idx = val_epoch_metrics['label_worst_group_idx']
                    self.best_val_wga_group = val_epoch_metrics['label_worst_group']
                    self.save_checkpoint(
                        epoch + 1,
                        val_epoch_metrics['label_worst_group_acc'],
                        val_epoch_metrics['label_avg_acc'],
                        val_epoch_metrics['label_worst_group_idx'],
                        val_epoch_metrics['label_worst_group'])
                else:
                    print("🔁 Validation did not improve.")

                if self.early_stopping:
                    avg_early_stopping_metric = val_score
                    mlflow.log_metric("val_avg_early_stopping_metric", avg_early_stopping_metric, step=epoch)
                    early_stop(avg_early_stopping_metric)
                    if early_stop.early_stop:
                        print(f"⏹️ Early stopping triggered at epoch {epoch + 1} (val_WAA = {avg_early_stopping_metric:.4f})")
                        mlflow.log_param("early_stopped_epoch", epoch + 1)
                        break

            print(f"🥇⬇️ Highest Worst Group Val Label Accuracy: {self.best_val_wga:.4f} (Group: {self.best_val_wga_group})")
            print(f"🥇🎯 Highest Average Val Label Accuracy: {self.best_val_aa:.4f}")

            mlflow.log_metrics({
                "best_val_wga": self.best_val_wga,
                "best_val_aa": self.best_val_aa,
                "best_val_wga_idx": self.best_val_wga_idx
            })
            mlflow.log_param("val_wga_label_group", self.best_val_wga_group)

class IRMTrainer(ERMTrainer):
    """ERM trainer with an additional IRM gradient penalty.

    Two training loaders are consumed in lockstep. Each step adds the mean
    penalty of the two batches to the mean criterion loss, weighted by
    ``irm_lambda`` once ``irm_penalty_anneal_iters`` updates have been made
    (the weight is 1.0 before that).
    """

    def __init__(self, run_name, model, train_loader, val_loader, criterion, optimizer,
                 tuple_to_group, irm_lambda, irm_penalty_anneal_iters, lr, weight_decay,
                 label_mode='binary', device='cuda', checkpoint_path='checkpoint.pth',
                 seed=42, early_stopping=False, **early_stopping_kwargs):
        """Set up the trainer.

        Args:
            train_loader: Indexable collection whose first two entries are the
                loaders iterated together during training.
            irm_lambda: Penalty weight used once annealing is over.
            irm_penalty_anneal_iters: Number of updates after which
                ``irm_lambda`` replaces the initial weight of 1.0.
            lr: Learning rate of the SGD optimizer created at the annealing
                boundary.
            weight_decay: Weight decay of that SGD optimizer.

        The remaining arguments are forwarded unchanged to ``ERMTrainer``.
        """
        # An epoch lasts as long as the shorter of the two loaders.
        self.train_loader_len = min(len(train_loader[0]), len(train_loader[1]))
        self.train_loader_copy = train_loader
        super().__init__(run_name, model, train_loader, val_loader, criterion, optimizer,
                         tuple_to_group, label_mode, device, checkpoint_path, seed, early_stopping,
                         **early_stopping_kwargs)
        self.update_count = 0
        self.lr = lr
        self.weight_decay = weight_decay
        self.irm_lambda = irm_lambda
        self.irm_penalty_anneal_iters = irm_penalty_anneal_iters
        # Dummy multiplier on the logits; the penalty is built from the
        # gradients of the loss with respect to it.
        self.scale = torch.tensor(1.).to(self.device).requires_grad_()

    def _irm_penalty(self, logits, y):
        """Return the penalty for one batch.

        The batch is split into even- and odd-indexed samples; the criterion
        is evaluated on each half with the logits multiplied by ``self.scale``
        and the two gradients with respect to ``self.scale`` are multiplied.
        """
        loss_1 = self.criterion(logits[::2] * self.scale, y[::2])
        loss_2 = self.criterion(logits[1::2] * self.scale, y[1::2])
        # create_graph=True keeps the penalty differentiable w.r.t. the model.
        grad_1 = torch.autograd.grad(loss_1, [self.scale], create_graph=True)[0]
        grad_2 = torch.autograd.grad(loss_2, [self.scale], create_graph=True)[0]
        result = torch.sum(grad_1 * grad_2)
        return result

    def train_epoch(self):
        """Run one training epoch over the two zipped loaders.

        Returns:
            tuple: ``(epoch_metrics, label_group_accuracy_map)`` as produced
            by ``compute_train_epoch_metrics``.
        """
        self.model.train()

        total_loss_sum = 0.0
        num_samples = 0
        batch_times = []

        correct_labels_per_group = defaultdict(int)
        total_labels_per_group = defaultdict(int)

        def to_targets(raw_labels):
            """Cast raw labels to the dtype/shape used with the criterion in this mode."""
            if self.mode == 'binary':
                return raw_labels.float().view(-1, 1).to(self.device, non_blocking=True)
            return raw_labels.long().view(-1).to(self.device, non_blocking=True)

        # A zip object is single-use, so it is rebuilt every epoch.
        self.train_loader = zip(*self.train_loader_copy)
        epoch_start_time = time.perf_counter()
        for (images_0, labels_0, groups_0), (images_1, labels_1, groups_1) in tqdm(self.train_loader, desc="Training", total=self.train_loader_len):
            batch_start_time = time.perf_counter()
            penalty_weight = (self.irm_lambda
                              if self.update_count >= self.irm_penalty_anneal_iters
                              else 1.0)

            images_0 = images_0.to(self.device, non_blocking=True)
            labels_0_cast = to_targets(labels_0)
            group_0_list = groups_0.tolist()

            images_1 = images_1.to(self.device, non_blocking=True)
            labels_1_cast = to_targets(labels_1)
            group_1_list = groups_1.tolist()

            logits_0 = self.model(images_0)
            logits_1 = self.model(images_1)

            # Both terms are averaged over the two batches.
            nll = (self.criterion(logits_0, labels_0_cast)
                   + self.criterion(logits_1, labels_1_cast)) / 2
            penalty = (self._irm_penalty(logits_0, labels_0_cast)
                       + self._irm_penalty(logits_1, labels_1_cast)) / 2
            loss = nll + (penalty_weight * penalty)

            # Start from a fresh optimizer at the step where the penalty
            # weight switches to irm_lambda.
            if self.update_count == self.irm_penalty_anneal_iters:
                self.optimizer = torch.optim.SGD(
                    self.model.parameters(),
                    lr=self.lr,
                    weight_decay=self.weight_decay)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            self.update_count += 1

            batch_size = images_0.size(0) + images_1.size(0)
            total_loss_sum += loss.item() * batch_size
            num_samples += batch_size

            outputs = torch.concat((logits_0, logits_1), 0)
            labels = torch.concat((labels_0, labels_1), 0)
            group_list = group_0_list + group_1_list

            if self.mode == 'binary':
                predicted_labels = (outputs >= 0.0).long()
            else:
                predicted_labels = torch.argmax(outputs, dim=1)

            # One batched comparison and host transfer instead of a
            # per-sample `.item()` call.
            hits = (predicted_labels.reshape(-1).cpu() == labels.reshape(-1).cpu()).tolist()
            for g_idx, hit in zip(group_list, hits):
                if isinstance(g_idx, list):
                    g_idx = tuple(g_idx[:2])
                correct_labels_per_group[g_idx] += hit
                total_labels_per_group[g_idx] += 1

            batch_times.append(time.perf_counter() - batch_start_time)

        epoch_metrics, label_group_accuracy_map = self.compute_train_epoch_metrics(
                total_loss_sum, num_samples,
                correct_labels_per_group, total_labels_per_group,
                batch_times, epoch_start_time)

        return epoch_metrics, label_group_accuracy_map

    def fit(self, num_epochs):
        """Log the hyper-parameters to MLflow and run ``record_log`` for ``num_epochs``."""
        with mlflow.start_run(run_name=self.run_name, nested=True):
            if self.early_stopping:
                early_stop = EarlyStopping(**self.early_stopping_kwargs)
                mlflow.log_param("patience", self.early_stopping_kwargs['patience'])
                mlflow.log_param("min_delta", self.early_stopping_kwargs['min_delta'])
                mlflow.log_param("mode", self.early_stopping_kwargs['mode'])
                mlflow.log_param("early_stopping_metric", "Average of WGA and AA")
            mlflow.log_param("irm_lambda", self.irm_lambda)
            mlflow.log_param("irm_penalty_anneal_iters", self.irm_penalty_anneal_iters)
            mlflow.log_param("num_epochs", num_epochs)
            mlflow.log_param("lr_label", self.optimizer.param_groups[0]['lr'])
            mlflow.log_param("weight_decay_label", self.optimizer.param_groups[0]['weight_decay'])
            print(f"\nirm_lambda={self.irm_lambda}")
            print(f"irm_penalty_anneal_iters={self.irm_penalty_anneal_iters}")
            print(f"lr_label={self.optimizer.param_groups[0]['lr']}")
            print(f"wd_label={self.optimizer.param_groups[0]['weight_decay']}")
            if self.early_stopping:
                self.record_log(num_epochs, early_stop)
            else:
                self.record_log(num_epochs)

class GroupDROTrainer(ERMTrainer):
    """ERM trainer that re-weights per-group losses with adaptive weights ``q``.

    For every group present in a batch its weight is multiplied by
    ``exp(groupdro_eta * loss)``; the weights are then normalised to sum to
    one and the training loss is the weighted sum of the per-group losses.
    """

    def __init__(self, run_name, model, train_loader, val_loader, criterion, optimizer,
                 tuple_to_group, groupdro_eta, label_mode='binary', device='cuda', checkpoint_path='checkpoint.pth',
                 seed=42, early_stopping=False, **early_stopping_kwargs):
        """Set up the trainer.

        Args:
            groupdro_eta: Step size of the exponential update of the group
                weights.

        The remaining arguments are forwarded unchanged to ``ERMTrainer``.
        """
        super().__init__(run_name, model, train_loader, val_loader, criterion, optimizer,
                         tuple_to_group, label_mode, device, checkpoint_path, seed, early_stopping,
                         **early_stopping_kwargs)
        # Empty placeholder; the real weights are created on the first epoch.
        self.q = torch.Tensor()
        self.groupdro_eta = groupdro_eta

    def train_epoch(self):
        """Run one training epoch with group-weighted losses.

        Returns:
            tuple: ``(epoch_metrics, label_group_accuracy_map)`` as produced
            by ``compute_train_epoch_metrics``.
        """
        self.model.train()

        if not len(self.q):
            self.q = torch.ones(len(self.groups)).to(self.device)

        total_loss_sum = 0.0
        num_samples = 0
        batch_times = []

        correct_labels_per_group = defaultdict(int)
        total_labels_per_group = defaultdict(int)

        epoch_start_time = time.perf_counter()
        for images, labels, groups in tqdm(self.train_loader, desc="Training"):
            batch_start_time = time.perf_counter()

            losses = torch.zeros(len(self.groups)).to(self.device)
            images = images.to(self.device, non_blocking=True)

            if self.mode == 'binary':
                labels_cast = labels.float().view(-1, 1).to(self.device, non_blocking=True)
            else:
                labels_cast = labels.long().view(-1).to(self.device, non_blocking=True)

            # Group key of every sample, computed once per batch instead of
            # once per (group, sample) pair.
            batch_keys = [tuple(g[:2]) for g in groups.tolist()]

            self.optimizer.zero_grad()

            outputs = self.model(images)

            for i, group in enumerate(self.tuple_to_group):
                membership = [key == group for key in batch_keys]
                # Decided on the host: no mask tensor or device sync is needed
                # for groups that are absent from the batch.
                if not any(membership):
                    continue
                group_mask = torch.tensor(membership, dtype=torch.bool, device=self.device)
                losses[i] = self.criterion(outputs[group_mask], labels_cast[group_mask])
                # Weight update uses the loss value only, not its gradient.
                self.q[i] *= (self.groupdro_eta * losses[i].detach()).exp()

            self.q /= self.q.sum()
            total_loss = torch.dot(losses, self.q)

            total_loss.backward()

            self.optimizer.step()

            batch_size = images.size(0)
            total_loss_sum += total_loss.item() * batch_size
            num_samples += batch_size

            if self.mode == 'binary':
                predicted_labels = (outputs >= 0.0).long()
            else:
                predicted_labels = torch.argmax(outputs, dim=1)

            # One batched comparison and host transfer instead of a
            # per-sample `.item()` call.
            hits = (predicted_labels.reshape(-1).cpu() == labels.reshape(-1).cpu()).tolist()
            for g_idx, hit in zip(batch_keys, hits):
                correct_labels_per_group[g_idx] += hit
                total_labels_per_group[g_idx] += 1

            batch_times.append(time.perf_counter() - batch_start_time)

        epoch_metrics, label_group_accuracy_map = self.compute_train_epoch_metrics(
                total_loss_sum, num_samples,
                correct_labels_per_group, total_labels_per_group,
                batch_times, epoch_start_time)

        return epoch_metrics, label_group_accuracy_map

    def fit(self, num_epochs):
        """Log the hyper-parameters to MLflow and run ``record_log`` for ``num_epochs``."""
        with mlflow.start_run(run_name=self.run_name, nested=True):
            if self.early_stopping:
                early_stop = EarlyStopping(**self.early_stopping_kwargs)
                mlflow.log_param("patience", self.early_stopping_kwargs['patience'])
                mlflow.log_param("min_delta", self.early_stopping_kwargs['min_delta'])
                mlflow.log_param("mode", self.early_stopping_kwargs['mode'])
                mlflow.log_param("early_stopping_metric", "Average of WGA and AA")
            mlflow.log_param("num_epochs", num_epochs)
            mlflow.log_param("lr_label", self.optimizer.param_groups[0]['lr'])
            mlflow.log_param("groupdro_eta", self.groupdro_eta)
            mlflow.log_param("weight_decay_label", self.optimizer.param_groups[0]['weight_decay'])
            print(f"\ngroupdro_eta={self.groupdro_eta}")
            print(f"lr_label={self.optimizer.param_groups[0]['lr']}")
            print(f"wd_label={self.optimizer.param_groups[0]['weight_decay']}")
            if self.early_stopping:
                self.record_log(num_epochs, early_stop)
            else:
                self.record_log(num_epochs)

class AbstractMMDTrainer(ERMTrainer):
    """ERM trainer with a feature-distribution penalty between two loaders.

    Features from ``model.backbone`` for the two batches are compared either
    with a sum-of-Gaussians kernel (``gaussian=True``) or through their means
    and covariances (``gaussian=False``); the result is added to the
    classification loss with weight ``mmd_gamma``.
    """

    def __init__(self, run_name, model, train_loader, val_loader, criterion, optimizer,
                 tuple_to_group, mmd_gamma, gaussian=False, label_mode='binary', device='cuda', checkpoint_path='checkpoint.pth',
                 seed=42, early_stopping=False, **early_stopping_kwargs):
        """Set up the trainer.

        Args:
            train_loader: Indexable collection whose first two entries are the
                loaders iterated together during training.
            mmd_gamma: Weight of the penalty in the total loss.
            gaussian: Select the Gaussian-kernel penalty when ``True``,
                otherwise the mean/covariance penalty.

        The remaining arguments are forwarded unchanged to ``ERMTrainer``.
        """
        # An epoch lasts as long as the shorter of the two loaders.
        self.train_loader_len = min(len(train_loader[0]), len(train_loader[1]))
        self.train_loader_copy = train_loader
        super().__init__(run_name, model, train_loader, val_loader, criterion, optimizer,
                         tuple_to_group, label_mode, device, checkpoint_path, seed, early_stopping,
                         **early_stopping_kwargs)
        self.gaussian = gaussian
        if self.gaussian is True:
            self.kernel_type = 'gaussian'
        else:
            self.kernel_type = 'mean_cov'
        self.mmd_gamma = mmd_gamma
        # Number of loaders; used to normalise the objective and the penalty.
        self.nmb = len(train_loader)

    def my_cdist(self, x1, x2):
        """Return pairwise squared Euclidean distances, clamped to at least 1e-30."""
        x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
        x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
        # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2
        res = torch.addmm(x2_norm.transpose(-2, -1),
                          x1,
                          x2.transpose(-2, -1), alpha=-2).add_(x1_norm)
        return res.clamp_min_(1e-30)

    def gaussian_kernel(self, x, y, gamma=(0.001, 0.01, 0.1, 1, 10, 100,
                                           1000)):
        """Return the sum over ``gamma`` of ``exp(-g * squared_distance(x, y))``.

        ``gamma`` is only iterated; its default is an immutable tuple so the
        value cannot be altered between calls.
        """
        D = self.my_cdist(x, y)
        K = torch.zeros_like(D)

        for g in gamma:
            K.add_(torch.exp(D.mul(-g)))

        return K

    def mmd(self, x, y):
        """Return the discrepancy between feature batches ``x`` and ``y``.

        With the Gaussian kernel this is ``mean(Kxx) + mean(Kyy) -
        2 * mean(Kxy)``; otherwise it is the mean squared difference of the
        feature means plus that of the feature covariances.
        """
        if self.kernel_type == "gaussian":
            Kxx = self.gaussian_kernel(x, x).mean()
            Kyy = self.gaussian_kernel(y, y).mean()
            Kxy = self.gaussian_kernel(x, y).mean()
            return Kxx + Kyy - 2 * Kxy
        else:
            mean_x = x.mean(0, keepdim=True)
            mean_y = y.mean(0, keepdim=True)
            cent_x = x - mean_x
            cent_y = y - mean_y
            cova_x = (cent_x.t() @ cent_x) / (len(x) - 1)
            cova_y = (cent_y.t() @ cent_y) / (len(y) - 1)

            mean_diff = (mean_x - mean_y).pow(2).mean()
            cova_diff = (cova_x - cova_y).pow(2).mean()

            return mean_diff + cova_diff

    def train_epoch(self):
        """Run one training epoch over the two zipped loaders.

        Returns:
            tuple: ``(epoch_metrics, label_group_accuracy_map)`` as produced
            by ``compute_train_epoch_metrics``.
        """
        self.model.train()

        total_loss_sum = 0.0
        num_samples = 0
        batch_times = []

        correct_labels_per_group = defaultdict(int)
        total_labels_per_group = defaultdict(int)

        def to_targets(raw_labels):
            """Cast raw labels to the dtype/shape used with the criterion in this mode."""
            if self.mode == 'binary':
                return raw_labels.float().view(-1, 1).to(self.device, non_blocking=True)
            return raw_labels.long().view(-1).to(self.device, non_blocking=True)

        # A zip object is single-use, so it is rebuilt every epoch.
        self.train_loader = zip(*self.train_loader_copy)

        epoch_start_time = time.perf_counter()
        for (images_0, labels_0, groups_0), (images_1, labels_1, groups_1) in tqdm(self.train_loader, desc="Training", total=self.train_loader_len):
            batch_start_time = time.perf_counter()

            images_0 = images_0.to(self.device, non_blocking=True)
            labels_0_cast = to_targets(labels_0)
            group_0_list = groups_0.tolist()

            images_1 = images_1.to(self.device, non_blocking=True)
            labels_1_cast = to_targets(labels_1)
            group_1_list = groups_1.tolist()

            # The penalty acts on backbone features, the criterion on logits.
            features_0 = self.model.backbone(images_0)
            features_1 = self.model.backbone(images_1)

            logits_0 = self.model.classifier(features_0)
            logits_1 = self.model.classifier(features_1)

            objective = (self.criterion(logits_0, labels_0_cast)
                         + self.criterion(logits_1, labels_1_cast)) / self.nmb

            penalty = self.mmd(features_0, features_1)
            if self.nmb > 1:
                # Normalise by the number of loader pairs.
                penalty = penalty / (self.nmb * (self.nmb - 1) / 2)

            self.optimizer.zero_grad()

            loss = (objective + (self.mmd_gamma * penalty))

            loss.backward()

            self.optimizer.step()

            batch_size = images_0.size(0) + images_1.size(0)
            total_loss_sum += loss.item() * batch_size
            num_samples += batch_size

            outputs = torch.concat((logits_0, logits_1), 0)
            labels = torch.concat((labels_0, labels_1), 0)
            group_list = group_0_list + group_1_list

            if self.mode == 'binary':
                predicted_labels = (outputs >= 0.0).long()
            else:
                predicted_labels = torch.argmax(outputs, dim=1)

            # One batched comparison and host transfer instead of a
            # per-sample `.item()` call.
            hits = (predicted_labels.reshape(-1).cpu() == labels.reshape(-1).cpu()).tolist()
            for g_idx, hit in zip(group_list, hits):
                if isinstance(g_idx, list):
                    g_idx = tuple(g_idx[:2])
                correct_labels_per_group[g_idx] += hit
                total_labels_per_group[g_idx] += 1

            batch_times.append(time.perf_counter() - batch_start_time)

        epoch_metrics, label_group_accuracy_map = self.compute_train_epoch_metrics(
                total_loss_sum, num_samples,
                correct_labels_per_group, total_labels_per_group,
                batch_times, epoch_start_time)

        return epoch_metrics, label_group_accuracy_map

    def fit(self, num_epochs):
        """Log the hyper-parameters to MLflow and run ``record_log`` for ``num_epochs``."""
        with mlflow.start_run(run_name=self.run_name, nested=True):
            if self.early_stopping:
                early_stop = EarlyStopping(**self.early_stopping_kwargs)
                mlflow.log_param("patience", self.early_stopping_kwargs['patience'])
                mlflow.log_param("min_delta", self.early_stopping_kwargs['min_delta'])
                mlflow.log_param("mode", self.early_stopping_kwargs['mode'])
                mlflow.log_param("early_stopping_metric", "Average of WGA and AA")
            mlflow.log_param("num_epochs", num_epochs)
            mlflow.log_param("lr_label", self.optimizer.param_groups[0]['lr'])
            mlflow.log_param("mmd_gamma", self.mmd_gamma)
            mlflow.log_param("weight_decay_label", self.optimizer.param_groups[0]['weight_decay'])
            print(f"\nmmd_gamma={self.mmd_gamma}")
            print(f"lr_label={self.optimizer.param_groups[0]['lr']}")
            print(f"wd_label={self.optimizer.param_groups[0]['weight_decay']}")
            if self.early_stopping:
                self.record_log(num_epochs, early_stop)
            else:
                self.record_log(num_epochs)

class MMDTrainer(AbstractMMDTrainer):
    """``AbstractMMDTrainer`` fixed to the Gaussian-kernel penalty."""

    def __init__(self, run_name, model, train_loader, val_loader, criterion, optimizer,
             tuple_to_group, mmd_gamma, label_mode='binary', device='cuda', checkpoint_path='checkpoint.pth',
             seed=42, early_stopping=False, **early_stopping_kwargs):
        """Forward every argument to the parent and force ``gaussian=True``."""
        # Required arguments go positionally in the parent's declared order;
        # the optional ones stay keyword-based.
        super().__init__(
            run_name, model, train_loader, val_loader,
            criterion, optimizer, tuple_to_group, mmd_gamma,
            gaussian=True,
            label_mode=label_mode,
            device=device,
            checkpoint_path=checkpoint_path,
            seed=seed,
            early_stopping=early_stopping,
            **early_stopping_kwargs,
        )

class CORALTrainer(AbstractMMDTrainer):
    """``AbstractMMDTrainer`` with the ``gaussian`` flag fixed to ``False``.

    Every constructor argument is forwarded unchanged to the base class; the
    only thing this subclass contributes is ``gaussian=False``. How the flag is
    interpreted is decided by ``AbstractMMDTrainer`` (presumably it switches
    the penalty to a CORAL-style statistic -- confirm against the base class).
    """

    def __init__(self, run_name, model, train_loader, val_loader, criterion, optimizer,
                 tuple_to_group, mmd_gamma, label_mode='binary', device='cuda', checkpoint_path='checkpoint.pth',
                 seed=42, early_stopping=False, **early_stopping_kwargs):
        # Collect the pass-through arguments once so the variant-specific
        # setting (gaussian) is easy to spot in the call below.
        forwarded = dict(
            run_name=run_name,
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            tuple_to_group=tuple_to_group,
            mmd_gamma=mmd_gamma,
            label_mode=label_mode,
            device=device,
            checkpoint_path=checkpoint_path,
            seed=seed,
            early_stopping=early_stopping,
        )
        super().__init__(gaussian=False, **forwarded, **early_stopping_kwargs)

class FishTrainer(ERMTrainer):
    """Trainer that performs a Fish-style meta update over two training loaders.

    For every pair of batches (one from each loader) a fresh inner model is
    cloned from the current meta model, stepped once with SGD on each batch in
    turn, and the meta model's weights are then moved toward the inner weights
    by a factor of ``meta_lr``.

    Args:
        run_name: name given to the nested MLflow run opened by ``fit``.
        model: meta model; it must provide ``state_dict()`` and
            ``reset_weights(weights)``.
        train_loader: indexable pair of loaders; element 0 and element 1 are
            iterated in lockstep, each yielding ``(images, labels, groups)``.
        val_loader, criterion, optimizer, tuple_to_group, label_mode, device,
        checkpoint_path, seed, early_stopping, **early_stopping_kwargs:
            forwarded unchanged to ``ERMTrainer``.
        meta_lr: interpolation factor applied in ``fish``.
        num_classes: forwarded to ``WholeFish`` when building the inner clone.
    """

    def __init__(self, run_name, model, train_loader, val_loader, criterion, optimizer,
                 tuple_to_group, meta_lr, num_classes, label_mode='binary', device='cuda', checkpoint_path='checkpoint.pth',
                 seed=42, early_stopping=False, **early_stopping_kwargs):
        # An epoch is as long as the shorter of the two loaders because they
        # are zipped together in train_epoch.
        self.train_loader_len = min(len(train_loader[0]), len(train_loader[1]))
        # Keep the original pair: self.train_loader is replaced by a one-shot
        # zip iterator at the start of every epoch.
        self.train_loader_copy = train_loader
        super().__init__(run_name, model, train_loader, val_loader, criterion, optimizer,
                         tuple_to_group, label_mode, device, checkpoint_path, seed, early_stopping,
                         **early_stopping_kwargs)
        self.meta_lr = meta_lr
        self.num_classes = num_classes
        # The inner optimizer reuses the hyper-parameters of the first param
        # group of the optimizer handed to this trainer.
        self.lr = self.optimizer.param_groups[0]['lr']
        self.weight_decay = self.optimizer.param_groups[0]['weight_decay']
        self.optimizer_inner_state = None
        self.create_clone()

    def create_clone(self):
        """Rebuild the inner model/optimizer from the current meta weights.

        The inner optimizer's state from the previous call (if any) is
        restored so it carries over between clones.
        """
        self.model_inner = WholeFish(self.num_classes,
                                     weights=self.model.state_dict()).to(self.device)
        self.optimizer_inner = torch.optim.SGD(
            self.model_inner.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay
        )
        if self.optimizer_inner_state is not None:
            self.optimizer_inner.load_state_dict(self.optimizer_inner_state)

    def fish(self, meta_weights, inner_weights, lr_meta):
        """Return ``meta + lr_meta * (inner - meta)`` computed on ``ParamDict``s."""
        meta_weights = ParamDict(meta_weights)
        inner_weights = ParamDict(inner_weights)
        meta_weights += lr_meta * (inner_weights - meta_weights)
        return meta_weights

    def _cast_label_batch(self, labels):
        """Move ``labels`` to the device in the dtype/shape the criterion is fed.

        Binary mode uses float labels of shape (N, 1); any other mode uses
        long labels of shape (N,).
        """
        if self.mode == 'binary':
            return labels.float().view(-1, 1).to(self.device, non_blocking=True)
        return labels.long().view(-1).to(self.device, non_blocking=True)

    def _inner_step(self, images, labels_cast):
        """Run one optimisation step of the inner model; return (logits, loss)."""
        logits = self.model_inner(images)
        self.optimizer_inner.zero_grad()
        loss = self.criterion(logits, labels_cast)
        loss.backward()
        self.optimizer_inner.step()
        return logits, loss

    def train_epoch(self):
        """Train for one epoch over the zipped loader pair.

        Returns:
            The ``(epoch_metrics, label_group_accuracy_map)`` pair produced by
            ``compute_train_epoch_metrics``.
        """
        self.model.train()
        self.model_inner.train()

        total_loss_sum = 0.0
        num_samples = 0
        batch_times = []

        correct_labels_per_group = defaultdict(int)
        total_labels_per_group = defaultdict(int)

        # zip() is exhausted after one pass, so it is rebuilt every epoch.
        self.train_loader = zip(*self.train_loader_copy)

        epoch_start_time = time.perf_counter()
        for (images_0, labels_0, groups_0), (images_1, labels_1, groups_1) in tqdm(
                self.train_loader, desc="Training", total=self.train_loader_len):
            batch_start_time = time.perf_counter()

            # Start every meta step from the current meta weights.
            self.create_clone()

            images_0 = images_0.to(self.device, non_blocking=True)
            labels_0_cast = self._cast_label_batch(labels_0)
            images_1 = images_1.to(self.device, non_blocking=True)
            # The original passed a bare `non_blocking` name here, which
            # raised NameError in non-binary mode; the shared helper fixes it.
            labels_1_cast = self._cast_label_batch(labels_1)

            # Sequential inner updates: first loader's batch, then second's.
            logits_0, loss_0 = self._inner_step(images_0, labels_0_cast)
            logits_1, loss_1 = self._inner_step(images_1, labels_1_cast)

            # Saved so the next clone's optimizer can pick it up.
            self.optimizer_inner_state = self.optimizer_inner.state_dict()

            meta_weights = self.fish(
                meta_weights=self.model.state_dict(),
                inner_weights=self.model_inner.state_dict(),
                lr_meta=self.meta_lr
            )
            self.model.reset_weights(meta_weights)

            size_0, size_1 = images_0.size(0), images_1.size(0)
            total_loss_sum += loss_0.item() * size_0 + loss_1.item() * size_1
            num_samples += size_0 + size_1

            outputs = torch.concat((logits_0, logits_1), 0)
            labels = torch.concat((labels_0, labels_1), 0)
            group_list = groups_0.tolist() + groups_1.tolist()

            if self.mode == 'binary':
                # Logit >= 0 corresponds to the positive class.
                predicted_labels = (outputs >= 0.0).long()
            else:
                predicted_labels = torch.argmax(outputs, dim=1)

            for i, g_idx in enumerate(group_list):
                if isinstance(g_idx, list):
                    # List-valued group entries are keyed by their first two items.
                    g_idx = tuple(g_idx[:2])
                correct_labels_per_group[g_idx] += (predicted_labels[i] == labels[i]).item()
                total_labels_per_group[g_idx] += 1

            batch_times.append(time.perf_counter() - batch_start_time)

        epoch_metrics, label_group_accuracy_map = self.compute_train_epoch_metrics(
            total_loss_sum, num_samples,
            correct_labels_per_group, total_labels_per_group,
            batch_times, epoch_start_time)

        return epoch_metrics, label_group_accuracy_map

    def fit(self, num_epochs):
        """Log hyper-parameters to a nested MLflow run and run ``record_log``.

        Args:
            num_epochs: number of epochs, logged and passed to ``record_log``.
        """
        with mlflow.start_run(run_name=self.run_name, nested=True):
            early_stop = None
            if self.early_stopping:
                early_stop = EarlyStopping(**self.early_stopping_kwargs)
                mlflow.log_param("patience", self.early_stopping_kwargs['patience'])
                mlflow.log_param("min_delta", self.early_stopping_kwargs['min_delta'])
                mlflow.log_param("mode", self.early_stopping_kwargs['mode'])
                mlflow.log_param("early_stopping_metric", "Average of WGA and AA")
            mlflow.log_param("num_epochs", num_epochs)
            mlflow.log_param("lr_label", self.optimizer.param_groups[0]['lr'])
            mlflow.log_param("meta_lr", self.meta_lr)
            mlflow.log_param("weight_decay_label", self.optimizer.param_groups[0]['weight_decay'])
            print(f"\nmeta_lr={self.meta_lr}")
            print(f"lr_label={self.optimizer.param_groups[0]['lr']}")
            print(f"wd_label={self.optimizer.param_groups[0]['weight_decay']}")
            # Only hand over the stopper when one was created; previously this
            # referenced an unbound local when early stopping was disabled.
            if self.early_stopping:
                self.record_log(num_epochs, early_stop)
            else:
                self.record_log(num_epochs)

class AbstractDANNTrainer(ERMTrainer):
    """Domain-adversarial trainer shared by the DANN and CDANN variants.

    A featurizer feeds both a label classifier (``self.model``) and a domain
    discriminator. Batches are drawn in lockstep from two training loaders;
    samples from loader 0 get domain label 0 and samples from loader 1 get
    domain label 1. An update counter decides, per batch pair, whether the
    discriminator optimizer steps or ``self.optimizer`` steps on the label
    loss minus ``lambda_`` times the discriminator loss.

    Args:
        run_name: name of the nested MLflow run opened by ``fit``.
        classifier: label head, stored as ``self.model``; its ``in_features``
            attribute sizes the class embeddings.
        featurizer: module mapping images to features.
        train_loader: indexable pair of loaders, iterated in lockstep.
        val_loader, criterion, optimizer, tuple_to_group, label_mode, device,
        checkpoint_path, seed, early_stopping, **early_stopping_kwargs:
            forwarded to ``ERMTrainer``. ``optimizer`` is the one stepped on
            generator updates (presumably holding featurizer and classifier
            parameters -- it is built by the caller).
        num_domains: forwarded to ``ResNet50_MLP``; ``1`` selects the binary
            cross-entropy-with-logits domain loss, anything else selects
            ``F.cross_entropy``.
        num_classes: number of rows of the class-embedding table (2 rows are
            allocated when it is 1).
        class_balance: weight each sample's discriminator loss by the inverse
            frequency of its class label within the batch.
        conditional: add a learned class embedding to the features before
            they are passed to the discriminator.
        lr_d, weight_decay_d: SGD hyper-parameters of the discriminator
            optimizer (which also owns the class embeddings).
        grad_penalty: weight of the squared input-gradient penalty added to
            the discriminator loss.
        d_steps_per_g: number of discriminator updates per generator update.
        lambda_: weight of the negated discriminator loss in the generator loss.
    """

    def __init__(self, run_name, classifier, featurizer, train_loader, val_loader, criterion,
                 optimizer, tuple_to_group, num_domains, num_classes, class_balance, conditional,
                 lr_d, grad_penalty, d_steps_per_g, lambda_, weight_decay_d, label_mode='binary', device='cuda',
                 checkpoint_path='checkpoint.pth', seed=42, early_stopping=False, **early_stopping_kwargs):
        # The loaders are zipped in train_epoch, so the shorter one bounds the epoch.
        self.train_loader_len = min(len(train_loader[0]), len(train_loader[1]))
        # Keep the pair: self.train_loader becomes a one-shot zip iterator each epoch.
        self.train_loader_copy = train_loader
        super().__init__(run_name=run_name,
                         model=classifier,
                         train_loader=train_loader,
                         val_loader=val_loader,
                         criterion=criterion,
                         optimizer=optimizer,
                         tuple_to_group=tuple_to_group,
                         label_mode=label_mode,
                         device=device,
                         checkpoint_path=checkpoint_path,
                         seed=seed,
                         early_stopping=early_stopping,
                         **early_stopping_kwargs)
        self.featurizer = featurizer.to(self.device)
        # Counts batch pairs across epochs; drives the D/G update schedule.
        self.update_count = torch.tensor([0])
        self.conditional = conditional
        self.class_balance = class_balance
        self.lr_d = lr_d
        self.grad_penalty = grad_penalty
        self.d_steps_per_g = d_steps_per_g
        self.lambda_ = lambda_
        self.discriminator = ResNet50_MLP(num_domains).to(self.device)
        self.num_domains = num_domains
        self.weight_decay_d = weight_decay_d
        # With a single output unit the labels still take the values 0 and 1,
        # so the embedding table needs two rows.
        embedding_rows = num_classes + 1 if num_classes == 1 else num_classes
        self.class_embeddings = nn.Embedding(embedding_rows, self.model.in_features).to(self.device)

        self.optimizer_d = torch.optim.SGD(
            (list(self.discriminator.parameters()) +
             list(self.class_embeddings.parameters())),
            lr=self.lr_d,
            weight_decay=self.weight_decay_d)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _ratio(total, count, empty=float('nan')):
        """Return ``total / count``, or ``empty`` when ``count`` is zero."""
        return total / count if count else empty

    def _cast_label_batch(self, labels):
        """Move ``labels`` to the device in the dtype/shape the criterion is fed.

        Binary mode uses float labels of shape (N, 1); any other mode uses
        long labels of shape (N,).
        """
        if self.mode == 'binary':
            return labels.float().view(-1, 1).to(self.device, non_blocking=True)
        return labels.long().view(-1).to(self.device, non_blocking=True)

    def _domain_loss(self, disc_out, disc_labels, reduction='mean'):
        """Discriminator loss for float domain labels of shape (N, 1).

        With ``num_domains == 1`` the labels are used directly as BCE targets.
        Otherwise they are converted to the long (N,) class-index form that
        ``F.cross_entropy`` expects (assumes the discriminator then emits one
        logit per domain -- confirm against ``ResNet50_MLP``).
        """
        if self.num_domains == 1:
            return F.binary_cross_entropy_with_logits(disc_out, disc_labels, reduction=reduction)
        return F.cross_entropy(disc_out, disc_labels.view(-1).long(), reduction=reduction)

    def _predict(self, outputs):
        """Hard label predictions: logit >= 0 in binary mode, argmax otherwise."""
        if self.mode == 'binary':
            return (outputs >= 0.0).long()
        return torch.argmax(outputs, dim=1)

    @staticmethod
    def _accumulate_group_hits(predicted_labels, labels, group_list,
                               correct_labels_per_group, total_labels_per_group):
        """Add per-sample hit/total counts to the per-group counters."""
        for i, g_idx in enumerate(group_list):
            if isinstance(g_idx, list):
                # List-valued group entries are keyed by their first two items.
                g_idx = tuple(g_idx[:2])
            correct_labels_per_group[g_idx] += (predicted_labels[i] == labels[i]).item()
            total_labels_per_group[g_idx] += 1

    def _fill_group_metrics(self, epoch_metrics, correct_labels_per_group, total_labels_per_group):
        """Write the accuracy entries into ``epoch_metrics``; return the group map.

        Groups without any counted sample get accuracy 0.0, and the average
        accuracy is 0.0 when nothing was counted at all.
        """
        group_accuracies_label = []
        for t_idx in self.tuple_to_group:
            seen = total_labels_per_group.get(t_idx, 0)
            if seen > 0:
                group_accuracies_label.append(correct_labels_per_group[t_idx] / seen)
            else:
                group_accuracies_label.append(0.0)

        epoch_metrics['label_group_accuracies'] = torch.tensor(group_accuracies_label)
        epoch_metrics['label_avg_acc'] = self._ratio(
            sum(correct_labels_per_group.values()),
            sum(total_labels_per_group.values()),
            empty=0.0)
        epoch_metrics['label_worst_group_acc'] = torch.min(epoch_metrics['label_group_accuracies']).item()
        epoch_metrics['label_worst_group_idx'] = torch.argmin(epoch_metrics['label_group_accuracies']).item()
        epoch_metrics['label_worst_group'] = self.groups[epoch_metrics['label_worst_group_idx']]

        return dict(zip(self.groups, group_accuracies_label))

    # ----------------------------------------------------------------- training

    def train_epoch(self):
        """Run one adversarial training epoch over the zipped loader pair.

        Returns:
            ``(epoch_metrics, label_group_accuracy_map)``. Loss averages are
            NaN when the corresponding kind of update never ran this epoch.
            Accuracies are accumulated on generator steps only.
        """
        self.model.train()
        self.featurizer.train()
        self.discriminator.train()

        total_loss_sum = 0.0
        disc_loss_sum = 0.0
        label_loss_sum = 0.0
        num_samples_d = 0
        num_samples_g = 0
        batch_times = []

        correct_labels_per_group = defaultdict(int)
        total_labels_per_group = defaultdict(int)

        # zip() is exhausted after one pass, so it is rebuilt every epoch.
        self.train_loader = zip(*self.train_loader_copy)

        epoch_start_time = time.perf_counter()
        for (images_0, labels_0, groups_0), (images_1, labels_1, groups_1) in tqdm(
                self.train_loader, desc="Training", total=self.train_loader_len):
            batch_start_time = time.perf_counter()

            self.update_count += 1

            images_0 = images_0.to(self.device, non_blocking=True)
            labels_0_cast = self._cast_label_batch(labels_0)
            group_0_list = groups_0.tolist()

            images_1 = images_1.to(self.device, non_blocking=True)
            labels_1_cast = self._cast_label_batch(labels_1)
            group_1_list = groups_1.tolist()

            features_0 = self.featurizer(images_0)
            features_1 = self.featurizer(images_1)

            if self.conditional:
                # Condition the discriminator on the class via an additive embedding.
                disc_input_0 = features_0 + self.class_embeddings(labels_0.view(-1).long().to(self.device))
                disc_input_1 = features_1 + self.class_embeddings(labels_1.view(-1).long().to(self.device))
            else:
                disc_input_0 = features_0
                disc_input_1 = features_1

            disc_input = torch.cat((disc_input_0, disc_input_1), dim=0)

            # Domain label = index of the loader the sample came from.
            disc_labels_0 = torch.zeros(images_0.shape[0], dtype=torch.float32, device=self.device).view(-1, 1)
            disc_labels_1 = torch.ones(images_1.shape[0], dtype=torch.float32, device=self.device).view(-1, 1)
            disc_labels = torch.cat((disc_labels_0, disc_labels_1), dim=0)

            disc_out = self.discriminator(disc_input)
            labels = torch.cat((labels_0, labels_1), dim=0)

            if self.class_balance:
                # F.one_hot and tensor indexing both need int64 class indices.
                class_idx = labels.view(-1).long()
                y_counts = F.one_hot(class_idx).sum(dim=0)
                weights = 1. / (y_counts[class_idx] * y_counts.shape[0]).float().to(self.device)
                # Flatten the per-sample loss to (N,): the BCE variant returns
                # (N, 1), which would broadcast against the (N,) weights into
                # an (N, N) matrix and corrupt the weighted sum.
                per_sample_loss = self._domain_loss(disc_out, disc_labels, reduction='none').view(-1)
                disc_loss = (weights * per_sample_loss).sum()
            else:
                disc_loss = self._domain_loss(disc_out, disc_labels)

            # Penalise the squared gradient of the summed domain loss with
            # respect to the discriminator input.
            input_grad = torch.autograd.grad(
                self._domain_loss(disc_out, disc_labels, reduction='sum'),
                [disc_input], create_graph=True)[0]
            grad_penalty = (input_grad ** 2).sum(dim=1).mean(dim=0)
            disc_loss = disc_loss + self.grad_penalty * grad_penalty

            batch_size = images_0.size(0) + images_1.size(0)

            # Within each cycle of (1 + d_steps_per_g) updates, the first
            # d_steps_per_g positions are discriminator steps.
            if (self.update_count.item() % (1 + self.d_steps_per_g) < self.d_steps_per_g):
                self.optimizer_d.zero_grad()
                disc_loss.backward()
                self.optimizer_d.step()
                disc_loss_sum += disc_loss.item() * batch_size
                num_samples_d += batch_size
            else:
                logits_0 = self.model(features_0)
                logits_1 = self.model(features_1)
                loss_0 = self.criterion(logits_0, labels_0_cast)
                loss_1 = self.criterion(logits_1, labels_1_cast)
                label_loss = loss_0 + loss_1
                # The generator minimises the label loss while maximising the
                # discriminator loss.
                gen_loss = label_loss + (self.lambda_ * -disc_loss)
                self.optimizer_d.zero_grad()
                self.optimizer.zero_grad()
                gen_loss.backward()
                self.optimizer.step()
                total_loss_sum += gen_loss.item() * batch_size
                label_loss_sum += label_loss.item() * batch_size
                num_samples_g += batch_size

                outputs = torch.concat((logits_0, logits_1), 0)
                self._accumulate_group_hits(
                    self._predict(outputs), labels, group_0_list + group_1_list,
                    correct_labels_per_group, total_labels_per_group)

            batch_times.append(time.perf_counter() - batch_start_time)

        epoch_time = time.perf_counter() - epoch_start_time
        epoch_metrics = {}

        # An epoch may contain only one kind of update (e.g. d_steps_per_g == 0
        # or a very short loader); report NaN instead of dividing by zero.
        epoch_metrics['total_loss'] = self._ratio(total_loss_sum, num_samples_g)
        epoch_metrics['disc_loss'] = self._ratio(disc_loss_sum, num_samples_d)
        epoch_metrics['label_loss'] = self._ratio(label_loss_sum, num_samples_g)
        epoch_metrics['epoch_time'] = epoch_time
        epoch_metrics['avg_batch_time'] = np.mean(batch_times).item()
        epoch_metrics['std_batch_time'] = np.std(batch_times, ddof=1).item()

        label_group_accuracy_map = self._fill_group_metrics(
            epoch_metrics, correct_labels_per_group, total_labels_per_group)

        return epoch_metrics, label_group_accuracy_map

    def fit(self, num_epochs):
        """Train for up to ``num_epochs`` epochs inside a nested MLflow run.

        Each epoch trains, validates, prints and logs metrics, and saves a
        checkpoint whenever the mean of validation worst-group accuracy and
        average accuracy beats the best value seen so far. When early stopping
        is enabled, the same mean is passed to the ``EarlyStopping`` instance.
        """
        with mlflow.start_run(run_name=self.run_name, nested=True):
            if self.early_stopping:
                early_stop = EarlyStopping(**self.early_stopping_kwargs)
                mlflow.log_param("patience", self.early_stopping_kwargs['patience'])
                mlflow.log_param("min_delta", self.early_stopping_kwargs['min_delta'])
                mlflow.log_param("mode", self.early_stopping_kwargs['mode'])
                mlflow.log_param("early_stopping_metric", "Average of WGA and AA")
            mlflow.log_param("num_epochs", num_epochs)
            mlflow.log_param("lr_label", self.optimizer.param_groups[0]['lr'])
            mlflow.log_param("lr_d", self.lr_d)
            mlflow.log_param("grad_penalty", self.grad_penalty)
            mlflow.log_param("d_steps_per_g", self.d_steps_per_g)
            mlflow.log_param("lambda_", self.lambda_)
            mlflow.log_param("weight_decay_label", self.optimizer.param_groups[0]['weight_decay'])
            mlflow.log_param("weight_decay_disc", self.optimizer_d.param_groups[0]['weight_decay'])
            print(f"\ngrad_penalty={self.grad_penalty}")
            print(f"d_steps_per_g={self.d_steps_per_g}")
            print(f"lambda_={self.lambda_}")
            print(f"lr_label={self.optimizer.param_groups[0]['lr']}, lr_disc={self.lr_d}")
            print(f"wd_label={self.optimizer.param_groups[0]['weight_decay']}, wd_disc={self.optimizer_d.param_groups[0]['weight_decay']}")
            for epoch in range(num_epochs):
                print(f"\nEpoch {epoch+1}/{num_epochs}")
                train_epoch_metrics, train_label_map = self.train_epoch()
                val_epoch_metrics, val_label_map = self.validate_epoch()

                print(f"🏋️‍♂️⏳ Train Epoch Time: {train_epoch_metrics['epoch_time']:4f}")
                print(f"🏋️‍♂️📦⏳ Train Avg Batch Time: {train_epoch_metrics['avg_batch_time']:4f} +- {train_epoch_metrics['std_batch_time']:4f}")
                print(f"🏋️‍♂️💯 Train Total Loss: {train_epoch_metrics['total_loss']:.4f}")
                print(f"🏋️‍♂️🔖 Train Label Loss: {train_epoch_metrics['label_loss']:.4f}")
                print(f"🏋️‍♂️🙉 Train Discriminator Loss: {train_epoch_metrics['disc_loss']:.4f}")
                print(f"🏋️‍♂️🔖⬇️ Worst Group Train Label Accuracy: {train_epoch_metrics['label_worst_group_acc']:.4f} (Group:  {train_epoch_metrics['label_worst_group']})")
                print(f"🏋️‍♂️🔖🎯 Average Train Label Accuracy: {train_epoch_metrics['label_avg_acc']:.4f}")

                print('\n')
                print(f"🕵⏳ Val Epoch Time: {val_epoch_metrics['epoch_time']:4f}")
                print(f"🕵🔖 Val Label Loss: {val_epoch_metrics['label_loss']:.4f}")
                print(f"🕵🔖⬇️ Worst Group Val Label Accuracy: {val_epoch_metrics['label_worst_group_acc']:.4f} (Group: {val_epoch_metrics['label_worst_group']})")
                print(f"🕵🔖🎯 Average Val Label Accuracy: {val_epoch_metrics['label_avg_acc']:.4f}")

                print("\n----------------------------------------------------------------------------\n")

                print("📊 Train Label Group-wise Accuracy:")
                for g in sorted(train_label_map.keys()):
                    acc = train_label_map[g]
                    print(f"  Group {g}: {acc:.3f}")
                    mlflow.log_metric(f"train_{g}_aa", acc, step=epoch)

                print("\n----------------------------------------------------------------------------\n")

                print("📊 Val Label Group-wise Accuracy:")
                for g in sorted(val_label_map.keys()):
                    acc = val_label_map[g]
                    print(f"  Group {g}: {acc:.3f}")
                    mlflow.log_metric(f"val_{g}_aa", acc, step=epoch)

                print("\n----------------------------------------------------------------------------\n")

                mlflow.log_metrics({
                    "train_total_loss": train_epoch_metrics['total_loss'],
                    "train_disc_loss": train_epoch_metrics['disc_loss'],
                    "train_label_loss": train_epoch_metrics['label_loss'],
                    "train_wga_label": train_epoch_metrics['label_worst_group_acc'],
                    "train_wga_label_idx": train_epoch_metrics['label_worst_group_idx'],
                    "train_aa_label": train_epoch_metrics['label_avg_acc'],
                    "train_avg_batch_time": train_epoch_metrics['avg_batch_time'],
                    "train_std_batch_time": train_epoch_metrics['std_batch_time'],
                    "train_epoch_time": train_epoch_metrics['epoch_time'],
                    "val_label_loss": val_epoch_metrics['label_loss'],
                    "val_wga_label": val_epoch_metrics['label_worst_group_acc'],
                    "val_wga_label_idx": val_epoch_metrics['label_worst_group_idx'],
                    "val_aa_label": val_epoch_metrics['label_avg_acc'],
                    "val_epoch_time": val_epoch_metrics['epoch_time']
                }, step=epoch)

                # Model-selection score: mean of worst-group and average accuracy.
                val_score = 0.5 * (val_epoch_metrics['label_worst_group_acc'] + val_epoch_metrics['label_avg_acc'])
                if val_score > 0.5 * (self.best_val_wga + self.best_val_aa):
                    print("✅ Validation improved!")
                    self.best_val_wga = val_epoch_metrics['label_worst_group_acc']
                    self.best_val_aa = val_epoch_metrics['label_avg_acc']
                    self.best_val_wga_idx = val_epoch_metrics['label_worst_group_idx']
                    self.best_val_wga_group = val_epoch_metrics['label_worst_group']
                    self.save_checkpoint(epoch + 1, val_epoch_metrics['label_worst_group_acc'],
                                         val_epoch_metrics['label_avg_acc'], val_epoch_metrics['label_worst_group_idx'],
                                         val_epoch_metrics['label_worst_group'])

                else:
                    print("🔁 Validation did not improve.")

                if self.early_stopping:
                    avg_early_stopping_metric = val_score
                    mlflow.log_metric("val_avg_early_stopping_metric", avg_early_stopping_metric, step=epoch)
                    early_stop(avg_early_stopping_metric)
                    if early_stop.early_stop:
                        print(f"⏹️ Early stopping triggered at epoch {epoch + 1} (val_WAA = {avg_early_stopping_metric:.4f})")
                        mlflow.log_param("early_stopped_epoch", epoch + 1)
                        break

            print(f"🥇⬇️ Highest Worst Group Val Label Accuracy: {self.best_val_wga:.4f} (Group: {self.best_val_wga_group})")
            print(f"🥇🎯 Highest Average Val Label Accuracy: {self.best_val_aa:.4f}")

            mlflow.log_metrics({
                "best_val_wga": self.best_val_wga,
                "best_val_aa": self.best_val_aa,
                "best_val_wga_idx": self.best_val_wga_idx
            })
            mlflow.log_param("val_wga_label_group", self.best_val_wga_group)

    def save_checkpoint(self, epoch, val_wga, val_aa, val_wga_idx, val_wga_group):
        """Write model/optimizer state and validation scores to ``checkpoint_path``."""
        print(f"💾 Saving checkpoint... (val_wga: {val_wga:.4f}; val_aa: {val_aa:.4f})")

        torch.save({
            'epoch': epoch,
            'classifier_state_dict': self.model.state_dict(),
            'featurizer_state_dict': self.featurizer.state_dict(),
            'discriminator_state_dict': self.discriminator.state_dict(),
            'optimizer_label_state_dict': self.optimizer.state_dict(),
            'optimizer_disc_state_dict': self.optimizer_d.state_dict(),
            'update_count': self.update_count,
            # NOTE(review): this entry pickles the nn.Embedding module itself.
            # It is kept so existing loaders keep working; the state_dict entry
            # below is the portable form.
            'class_embeddings': self.class_embeddings,
            'class_embeddings_state_dict': self.class_embeddings.state_dict(),
            'val_wga': val_wga,
            'val_aa': val_aa,
            'val_wga_idx': val_wga_idx,
            'val_wga_group': val_wga_group
        }, self.checkpoint_path)

    def validate_epoch(self):
        """Evaluate featurizer + classifier on ``val_loader`` without gradients.

        Returns:
            ``(epoch_metrics, label_group_accuracy_map)``; the label loss is
            NaN when the validation loader yields no samples.
        """
        self.model.eval()
        self.featurizer.eval()

        label_loss_sum = 0.0
        num_samples = 0

        correct_labels_per_group = defaultdict(int)
        total_labels_per_group = defaultdict(int)

        epoch_start_time = time.perf_counter()
        with torch.no_grad():
            for images, labels, groups in tqdm(self.val_loader, desc="Validation"):
                images = images.to(self.device, non_blocking=True)
                labels_cast = self._cast_label_batch(labels)
                group_list = groups.tolist()

                features_labels = self.featurizer(images)
                outputs_labels = self.model(features_labels)

                loss_label = self.criterion(outputs_labels, labels_cast)

                batch_size = images.size(0)
                label_loss_sum += loss_label.item() * batch_size
                num_samples += batch_size

                self._accumulate_group_hits(
                    self._predict(outputs_labels), labels, group_list,
                    correct_labels_per_group, total_labels_per_group)

        epoch_time = time.perf_counter() - epoch_start_time
        epoch_metrics = {}

        epoch_metrics['label_loss'] = self._ratio(label_loss_sum, num_samples)
        epoch_metrics['epoch_time'] = epoch_time

        label_group_accuracy_map = self._fill_group_metrics(
            epoch_metrics, correct_labels_per_group, total_labels_per_group)

        return epoch_metrics, label_group_accuracy_map

class DANNTrainer(AbstractDANNTrainer):
    """``AbstractDANNTrainer`` with class balancing and conditioning both off.

    All constructor arguments are forwarded unchanged; this subclass only
    fixes ``class_balance=False`` and ``conditional=False``.
    """

    def __init__(self, run_name, classifier, featurizer, train_loader, val_loader, criterion, optimizer,
                 tuple_to_group, num_domains, num_classes,
                 lr_d, grad_penalty, d_steps_per_g, lambda_, weight_decay_d, label_mode='binary', device='cuda',
                 checkpoint_path='checkpoint.pth', seed=42, early_stopping=False, **early_stopping_kwargs):
        # Arguments handed straight through to the base trainer.
        forwarded = dict(
            run_name=run_name,
            classifier=classifier,
            featurizer=featurizer,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            tuple_to_group=tuple_to_group,
            num_domains=num_domains,
            num_classes=num_classes,
            lr_d=lr_d,
            grad_penalty=grad_penalty,
            d_steps_per_g=d_steps_per_g,
            lambda_=lambda_,
            weight_decay_d=weight_decay_d,
            label_mode=label_mode,
            device=device,
            checkpoint_path=checkpoint_path,
            seed=seed,
            early_stopping=early_stopping,
        )
        # The two switches below are what define this variant.
        super().__init__(class_balance=False,
                         conditional=False,
                         **forwarded,
                         **early_stopping_kwargs)

class CDANNTrainer(AbstractDANNTrainer):
    """``AbstractDANNTrainer`` with class balancing and conditioning both on.

    All constructor arguments are forwarded unchanged; this subclass only
    fixes ``class_balance=True`` and ``conditional=True``.
    """

    def __init__(self, run_name, classifier, featurizer, train_loader, val_loader, criterion, optimizer,
                 tuple_to_group, num_domains, num_classes,
                 lr_d, grad_penalty, d_steps_per_g, lambda_, weight_decay_d, label_mode='binary', device='cuda',
                 checkpoint_path='checkpoint.pth', seed=42, early_stopping=False, **early_stopping_kwargs):
        # Arguments handed straight through to the base trainer.
        forwarded = dict(
            run_name=run_name,
            classifier=classifier,
            featurizer=featurizer,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            tuple_to_group=tuple_to_group,
            num_domains=num_domains,
            num_classes=num_classes,
            lr_d=lr_d,
            grad_penalty=grad_penalty,
            d_steps_per_g=d_steps_per_g,
            lambda_=lambda_,
            weight_decay_d=weight_decay_d,
            label_mode=label_mode,
            device=device,
            checkpoint_path=checkpoint_path,
            seed=seed,
            early_stopping=early_stopping,
        )
        # The two switches below are what define this variant.
        super().__init__(class_balance=True,
                         conditional=True,
                         **forwarded,
                         **early_stopping_kwargs)