import torch
import torch.nn as nn
import numpy as np
import itertools

from tqdm import tqdm
from models.models import *
from models.loss import MMD_loss, PositiveSupCon_loss, PositiveSupCon_loss2, PositiveSupCon_loss3, get_loss_class
from models.loss import CLDA_InterDomainContrastiveLoss, CLDA_InstanceContrastiveLoss, CLDA_ContrastiveLoss
from torch.optim.lr_scheduler import StepLR
from copy import deepcopy
import torch.nn.functional as F

### For Baselines
from models.loss import CDAC_loss, BCE_softlabels, sigmoid_rampup, CrossEntropyWLogits, AdaMatch_loss, univ_ssda_loss, dst_loss, CLDA_ContrastiveLoss
from utils import weights_init
from utils import adjust_learning_rate
from torch.cuda.amp import GradScaler


def get_algorithm_class(algorithm_name):
    """Look up an algorithm class by name among this module's globals.

    Args:
        algorithm_name: Name of a module-level object (normally one of the
            algorithm classes defined in this file).

    Returns:
        The object bound to ``algorithm_name`` in this module.

    Raises:
        NotImplementedError: If no module-level name matches.
    """
    registry = globals()
    if algorithm_name in registry:
        return registry[algorithm_name]
    raise NotImplementedError("Algorithm not found: {}".format(algorithm_name))

class Algorithm(torch.nn.Module):
    """
    Base class for a domain adaptation algorithm.

    The constructor builds a feature extractor, a classifier and a
    ``network`` that chains the two. Subclasses should implement the
    training_epoch() method; update() drives it for ``num_epochs`` epochs.
    """

    def __init__(self, configs, backbone, num_epochs, post_epochs):
        """
        Args:
            configs: Configuration object forwarded to ``backbone`` and
                ``classifier``.
            backbone: Callable that builds the feature extractor from
                ``configs``.
            num_epochs: Number of epochs run by update().
            post_epochs: Stored on the instance; not used by this base class.
        """
        super(Algorithm, self).__init__()
        self.configs = configs
        self.num_epochs = num_epochs
        self.post_epochs = post_epochs
        self.cross_entropy = nn.CrossEntropyLoss()
        self.feature_extractor = backbone(configs)
        self.classifier = classifier(configs) # models.classifier
        self.network = nn.Sequential(self.feature_extractor, self.classifier)

    def update(self, src_loader, trg_loader, avg_meter, logger):
        """Run the training loop shared by the single-phase algorithms.

        Args:
            src_loader: Source-domain loader, passed to training_epoch().
            trg_loader: Target-domain loader, passed to training_epoch().
            avg_meter: Mapping from loss name to a meter exposing ``.avg``;
                the ``'Src_cls_loss'`` entry drives checkpoint selection.
            logger: Logger whose ``debug`` method receives per-epoch metrics.

        Returns:
            Tuple ``(last_head, best_head, last_classifier, best_classifier)``.
            The ``last_*`` entries are the state dicts at the end of
            training. The ``best_*`` entries are deep copies taken at the
            best checkpoint epoch, or ``None`` if no checkpoint epoch
            improved on the initial risk (for example when ``num_epochs``
            is too small for the checkpoint condition to fire).
        """
        best_src_risk = float('inf')
        # Initialised up front so the return below never hits an unbound
        # local when the checkpoint condition is never satisfied.
        best_head = None
        best_classifier = None

        for epoch in range(1, self.num_epochs+1):
            # training Loop
            self.training_epoch(src_loader, trg_loader, avg_meter, epoch)

            # saving the best model based on src risk
            # NOTE(review): epochs are 1-based, so this fires at epochs
            # 9, 19, 29, ... rather than 10, 20, 30 - confirm that is intended.
            if (epoch+1) % 10 == 0 and avg_meter['Src_cls_loss'].avg < best_src_risk:
                best_src_risk = avg_meter['Src_cls_loss'].avg
                best_head = deepcopy(self.network.state_dict())
                best_classifier = deepcopy(self.classifier.state_dict())

            logger.debug(f'[ Epoch : {epoch}/{self.num_epochs}]')
            for key, val in avg_meter.items():
                logger.debug(f'{key}\t: {val.avg:2.4f}')
            logger.debug(f'-----------------------------------------------------')

        last_head = self.network.state_dict()
        last_classifier = self.classifier.state_dict()
        return last_head, best_head, last_classifier, best_classifier


    def training_epoch(self, *args, **kwargs):
        """Run one training epoch; the loop varies from one method to another.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


class NO_ADAPT(Algorithm):
    """
    Lower bound : train on source and test on target.

    Only ``src_loader`` is consumed during training; ``trg_loader`` is
    accepted to keep the common training_epoch() signature.
    """
    def __init__(self, backbone, configs, num_epochs, post_epochs, hparams, device):
        """
        Args:
            backbone: Callable that builds the feature extractor from ``configs``.
            configs: Configuration object forwarded to the base constructor.
            num_epochs: Number of training epochs.
            post_epochs: Forwarded to the base constructor.
            hparams: Mapping providing ``learning_rate``, ``weight_decay``,
                ``step_size`` and ``lr_decay``.
            device: Device the model is moved to and batches are sent to.
        """
        super().__init__(configs, backbone, num_epochs, post_epochs)

        # optimizer and scheduler
        self.optimizer = torch.optim.Adam(
            self.network.parameters(),
            lr=hparams["learning_rate"],
            weight_decay=hparams["weight_decay"]
        )
        self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
        self.hparams = hparams
        self.device = device
        self.to(self.device)

    def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
        """Train for one epoch with cross-entropy on the source loader.

        Updates ``avg_meter['Src_cls_loss']`` after every batch and steps
        the learning-rate scheduler once at the end of the epoch.
        """
        for src_x, src_y, _ in tqdm(src_loader):

            # Send the batch to the same device the model was moved to in
            # __init__, instead of whatever the default CUDA device is.
            src_x = src_x.to(self.device, non_blocking=True)
            src_y = src_y.to(self.device, non_blocking=True)

            src_pred = self.classifier(self.feature_extractor(src_x))
            src_cls_loss = self.cross_entropy(src_pred, src_y)

            self.optimizer.zero_grad()
            src_cls_loss.backward()
            self.optimizer.step()

            # NOTE(review): 32 is a hard-coded meter weight, presumably the
            # batch size - confirm against the loader configuration.
            avg_meter['Src_cls_loss'].update(src_cls_loss.item(), 32)

        self.lr_scheduler.step()

        

class TARGET_ONLY(Algorithm):
    """
    Upper bound: train on target and test on target

    Only ``trg_loader`` is consumed during training; ``src_loader`` is
    accepted to keep the common training_epoch() signature.
    """

    def __init__(self, backbone, configs, num_epochs, post_epochs, hparams, device):
        """
        Args:
            backbone: Callable that builds the feature extractor from ``configs``.
            configs: Configuration object forwarded to the base constructor.
            num_epochs: Number of training epochs.
            post_epochs: Forwarded to the base constructor.
            hparams: Mapping providing ``learning_rate``, ``weight_decay``,
                ``step_size`` and ``lr_decay``.
            device: Device the model is moved to and batches are sent to.
        """
        super().__init__(configs, backbone, num_epochs, post_epochs)
        # optimizer and scheduler
        self.optimizer = torch.optim.Adam(
            self.network.parameters(),
            lr=hparams['learning_rate'],
            weight_decay = hparams['weight_decay']
        )
        self.lr_scheduler = StepLR(self.optimizer, step_size = hparams['step_size'], gamma=hparams['lr_decay'])
        self.hparams = hparams
        self.device = device
        # Same placement as NO_ADAPT so model and batches share a device.
        self.to(self.device)

    def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
        """Train for one epoch with cross-entropy on the target loader.

        Updates ``avg_meter['Trg_cls_loss']`` after every batch and steps
        the learning-rate scheduler once at the end of the epoch.
        """
        for trg_x, trg_y, _ in tqdm(trg_loader):

            # Use the configured device rather than the default CUDA device.
            trg_x = trg_x.to(self.device, non_blocking=True)
            trg_y = trg_y.to(self.device, non_blocking=True)

            trg_pred = self.classifier(self.feature_extractor(trg_x))
            trg_cls_loss = self.cross_entropy(trg_pred, trg_y)

            self.optimizer.zero_grad()
            trg_cls_loss.backward()
            self.optimizer.step()

            # NOTE(review): 32 is a hard-coded meter weight, presumably the
            # batch size - confirm against the loader configuration.
            avg_meter['Trg_cls_loss'].update(trg_cls_loss.item(), 32)

        self.lr_scheduler.step()

class LABELED_ONLY(Algorithm):
    """
    Supervised baseline: each epoch trains with cross-entropy on every
    batch of the source loader, then on every batch of the target loader.
    """

    def __init__(self, backbone, configs, num_epochs, post_epochs, hparams, device):
        """
        Args:
            backbone: Callable that builds the feature extractor from ``configs``.
            configs: Configuration object forwarded to the base constructor.
            num_epochs: Number of training epochs.
            post_epochs: Forwarded to the base constructor.
            hparams: Mapping providing ``learning_rate``, ``weight_decay``,
                ``step_size`` and ``lr_decay``.
            device: Device the model is moved to and batches are sent to.
        """
        super().__init__(configs, backbone, num_epochs, post_epochs)
        # optimizer and scheduler
        self.optimizer = torch.optim.Adam(
            self.network.parameters(),
            lr=hparams['learning_rate'],
            weight_decay = hparams['weight_decay']
        )
        self.lr_scheduler = StepLR(self.optimizer, step_size = hparams['step_size'], gamma=hparams['lr_decay'])
        self.hparams = hparams
        self.device = device
        # Same placement as NO_ADAPT so model and batches share a device.
        self.to(self.device)

    def _train_on_loader(self, loader, loss_key, avg_meter):
        """Run one supervised pass over ``loader``, logging under ``loss_key``."""
        for x, y, _ in tqdm(loader):
            # Use the configured device rather than the default CUDA device.
            x = x.to(self.device, non_blocking=True)
            y = y.to(self.device, non_blocking=True)

            cls_loss = self.cross_entropy(self.classifier(self.feature_extractor(x)), y)

            self.optimizer.zero_grad()
            cls_loss.backward()
            self.optimizer.step()

            # NOTE(review): 32 is a hard-coded meter weight, presumably the
            # batch size - confirm against the loader configuration.
            avg_meter[loss_key].update(cls_loss.item(), 32)

    def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
        """Train for one epoch on the source loader, then the target loader.

        Updates ``avg_meter['Src_cls_loss']`` and ``avg_meter['Trg_cls_loss']``
        and steps the learning-rate scheduler exactly once per epoch.
        """
        self._train_on_loader(src_loader, 'Src_cls_loss', avg_meter)
        self._train_on_loader(trg_loader, 'Trg_cls_loss', avg_meter)

        # Step once per epoch: stepping after each loader would make StepLR
        # decay twice as fast as the configured ``step_size`` implies.
        self.lr_scheduler.step()

class MoSSDA_target(Algorithm):
    """
    Two-phase MoSSDA variant; phase 2 uses target data.

    Phase 1 optimises the feature extractor and projection head with a
    weighted sum of an MMD loss and a contrastive loss. Phase 2 freezes
    that network and trains ``self.classifier`` with cross-entropy on
    projected target features.
    """

    def __init__(self, backbone, configs, num_epochs, post_epochs, hparams, device):
        """
        Args:
            backbone: Callable that builds the feature extractor from ``configs``.
            configs: Configuration object; ``features_len``,
                ``final_out_channels``, ``kernel_type`` and ``mix_type`` are read.
            num_epochs: Number of phase-1 epochs.
            post_epochs: Number of phase-2 epochs.
            hparams: Mapping providing ``learning_rate``, ``weight_decay``,
                ``step_size``, ``lr_decay``, ``mmd_weight`` and ``ctr_weight``.
            device: Device that input batches are sent to.
        """
        super().__init__(configs, backbone, num_epochs, post_epochs)

        # Backbone (replaces the instance built by the base constructor)
        self.feature_extractor = backbone(configs)
        # Projection head for contrastive learning
        self.projection_head = nn.Sequential(
            nn.Linear(configs.features_len * configs.final_out_channels, configs.final_out_channels),
            nn.ReLU(),
            nn.Linear(configs.final_out_channels, configs.final_out_channels)
        )

        # Momentum Encoder applied over the projection head
        self.momentum_encoder = MomentumEncoder(self.projection_head)

        # Network
        self.network = nn.Sequential(self.feature_extractor, self.projection_head)

        # target classifier (post training)
        self.classifier = SSDA_classifier(configs,hparams)

        # Optimizer and scheduler
        self.optimizer = torch.optim.Adam(
            self.parameters(),
            lr = hparams['learning_rate'],
            weight_decay = hparams['weight_decay']
        )
        self.lr_scheduler = StepLR(self.optimizer, step_size = hparams['step_size'], gamma=hparams['lr_decay'])

        self.hparams = hparams
        self.device = device

        self.kernel_type = configs.kernel_type
        self.mix_type = configs.mix_type

        # losses
        self.ce_loss = nn.CrossEntropyLoss()
        self.mmd_loss = MMD_loss(kernel_type=self.kernel_type)
        self.ctr_loss = get_loss_class(self.mix_type,device)


    def forward(self, src_data, trg_data):
        """Compute the phase-1 losses for one source batch and one target batch.

        Args:
            src_data: ``(x, y, is_labeled)`` triple for the source domain.
            trg_data: ``(x, y, is_labeled)`` triple for the target domain.

        Returns:
            ``(src_prj_feat, trg_prj_feat, total_loss, losses)`` where
            ``losses`` maps ``'MMD_loss'``, ``'Ctr_loss'`` and
            ``'total_loss'`` to their values. ``'Ctr_loss'`` is the int 0
            when either batch has no labeled sample.
        """
        # Send batches to the configured device rather than the default
        # CUDA device.
        src_x, src_y, src_is_labeled = (t.to(self.device, non_blocking=True) for t in src_data)
        trg_x, trg_y, trg_is_labeled = (t.to(self.device, non_blocking=True) for t in trg_data)

        # Forward pass through the feature extractor
        src_FE_feat = self.feature_extractor(src_x)
        trg_FE_feat = self.feature_extractor(trg_x)

        # Projection features for contrastive learning
        src_prj_feat = self.projection_head(src_FE_feat)
        trg_prj_feat = self.projection_head(trg_FE_feat)

        loss_mmd = self.mmd_loss(src_FE_feat, trg_FE_feat)
        loss_ctr = 0

        # The contrastive term only uses labeled samples and needs at
        # least one from each domain.
        if src_is_labeled.any() and trg_is_labeled.any():
            src_mask = src_is_labeled == 1
            trg_mask = trg_is_labeled == 1
            loss_ctr = self.ctr_loss(src_prj_feat[src_mask], trg_prj_feat[trg_mask],
                                     src_y[src_mask], trg_y[trg_mask])

        total_loss = (loss_mmd*self.hparams['mmd_weight'] + loss_ctr*self.hparams['ctr_weight'])

        losses = {'MMD_loss'  : loss_mmd, 'Ctr_loss': loss_ctr, 'total_loss':total_loss}

        return src_prj_feat, trg_prj_feat, total_loss, losses

    def update(self, src_loader, trg_loader, avg_meter, logger):
        """Run phase 1 (representation) then phase 2 (classifier) training.

        Returns:
            Tuple ``(last_head, best_head, last_classifier, best_classifier)``.
            ``last_head`` is the network state dict after phase 1 and
            ``last_classifier`` the classifier state dict after phase 2.
            ``best_head`` / ``best_classifier`` are deep copies from the
            best checkpoint epoch of the respective phase, or ``None`` when
            no checkpoint epoch improved on the initial risk.
        """
        best_pretrain_risk = float('inf')
        best_posttrain_risk = float('inf')
        best_head = None
        # Initialised like best_head so the return cannot hit an unbound
        # local when phase 2 never reaches a checkpoint.
        best_classifier = None

        pretrain_epoch = self.num_epochs
        posttrain_epoch = self.post_epochs

        # Training Phase 1
        for epoch in range(1, pretrain_epoch+1):

            self.training_epoch(src_loader, trg_loader, avg_meter, epoch)
            # save the best model
            # NOTE(review): epochs are 1-based, so this fires at epochs
            # 4, 9, 14, ... - confirm that is intended.
            if (epoch+1)%5 == 0 and avg_meter['total_loss'].avg < best_pretrain_risk:
                best_pretrain_risk = avg_meter['total_loss'].avg
                best_head = deepcopy(self.network.state_dict())

            logger.debug(f'[Phase1-Epoch : {epoch}/{pretrain_epoch}]')

            for key, val in avg_meter.items():
                logger.debug(f'{key}\t : {val.avg:2.4f}')
            logger.debug(f'---------------------------------------------------------')

        last_head = self.network.state_dict()

        # Training Phase 2
        for epoch in range(1, posttrain_epoch+1):

            self.post_training_epoch(trg_loader, avg_meter, epoch)

            # save the best model
            if (epoch+1)%5 == 0 and avg_meter['target_loss'].avg < best_posttrain_risk:
                best_posttrain_risk = avg_meter['target_loss'].avg
                best_classifier = deepcopy(self.classifier.state_dict())
            logger.debug(f'[Phase2-Epoch : {epoch}/{posttrain_epoch}]')

            for key, val in avg_meter.items():
                logger.debug(f'{key}\t: {val.avg:2.4f}')

            logger.debug(f'-----------------------------------------------------------')
        last_classifier = self.classifier.state_dict()

        return last_head, best_head, last_classifier, best_classifier

    def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
        """Run one phase-1 epoch over the source loader.

        The target loader is cycled so every source batch is paired with a
        target batch. The scheduler is stepped once at the end.
        """
        for src_data, trg_data in tqdm(zip(src_loader, itertools.cycle(trg_loader))):

            # Forward pass
            _, _, total_loss, losses = self.forward(src_data, trg_data)

            # Backpropagation
            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()
            self.momentum_encoder.update_momentum_encoder()

            # Update average meters with plain floats: the values returned
            # by forward() are still attached to the autograd graph, and
            # 'Ctr_loss' may be the int 0.
            # NOTE(review): 32 is a hard-coded meter weight, presumably the
            # batch size - confirm against the loader configuration.
            for key, val in losses.items():
                avg_meter[key].update(float(val), 32)

        self.lr_scheduler.step()

    def post_training_epoch(self, trg_loader, avg_meter, epoch):
        """Run one phase-2 epoch: train the classifier on target batches.

        The feature extractor and projection head are frozen; only the
        classifier is updated, through the optimizer and scheduler that
        the classifier object itself exposes.
        """
        self.network.requires_grad_(False)
        self.classifier.requires_grad_(True)

        # NOTE(review): the is-labeled flag is not used here, so every
        # sample in the batch contributes to the loss - confirm intent.
        for trg_x, trg_y, _ in trg_loader:
            trg_x = trg_x.to(self.device, non_blocking=True)
            trg_y = trg_y.to(self.device, non_blocking=True)

            trg_FE_feat = self.feature_extractor(trg_x)
            trg_prj_feat = self.projection_head(trg_FE_feat.detach())

            trg_out = self.classifier(trg_prj_feat)
            target_loss = self.cross_entropy(trg_out, trg_y)

            self.classifier.optimizer.zero_grad()
            target_loss.backward()
            self.classifier.optimizer.step()

            avg_meter['target_loss'].update(target_loss.item(), 32)

        self.classifier.lr_scheduler.step()


class MoSSDA_source(Algorithm):
    """
    Two-phase MoSSDA variant; in phase 2, use source data.

    Phase 1 optimises the feature extractor and projection head with a
    weighted sum of an MMD loss and a contrastive loss. Phase 2 freezes
    that network and trains ``self.classifier`` with cross-entropy on
    projected source features.
    """
    def __init__(self, backbone, configs, num_epochs, post_epochs, hparams, device):
        """
        Args:
            backbone: Callable that builds the feature extractor from ``configs``.
            configs: Configuration object; ``features_len``,
                ``final_out_channels``, ``kernel_type`` and ``mix_type`` are read.
            num_epochs: Number of phase-1 epochs.
            post_epochs: Number of phase-2 epochs.
            hparams: Mapping providing ``learning_rate``, ``weight_decay``,
                ``step_size``, ``lr_decay``, ``mmd_weight`` and ``ctr_weight``.
            device: Device that input batches are sent to.
        """
        super().__init__(configs, backbone, num_epochs, post_epochs)

        # Backbone (replaces the instance built by the base constructor)
        self.feature_extractor = backbone(configs)
        # Projection head for contrastive learning
        self.projection_head = nn.Sequential(
            nn.Linear(configs.features_len * configs.final_out_channels, configs.final_out_channels),
            nn.ReLU(),
            nn.Linear(configs.final_out_channels, configs.final_out_channels)
        )

        # Momentum Encoder applied over the projection head
        self.momentum_encoder = MomentumEncoder(self.projection_head)

        # Network
        self.network = nn.Sequential(self.feature_extractor, self.projection_head)

        # target classifier (post training)
        self.classifier = SSDA_classifier(configs,hparams)

        # Optimizer and scheduler
        self.optimizer = torch.optim.Adam(
            self.parameters(),
            lr = hparams['learning_rate'],
            weight_decay = hparams['weight_decay']
        )
        self.lr_scheduler = StepLR(self.optimizer, step_size = hparams['step_size'], gamma=hparams['lr_decay'])

        self.hparams = hparams
        self.device = device

        self.kernel_type = configs.kernel_type
        self.mix_type = configs.mix_type

        # losses
        self.ce_loss = nn.CrossEntropyLoss()
        self.mmd_loss = MMD_loss(kernel_type=self.kernel_type)
        self.ctr_loss = get_loss_class(self.mix_type, self.device)


    def forward(self, src_data, trg_data):
        """Compute the phase-1 losses for one source batch and one target batch.

        Args:
            src_data: ``(x, y, is_labeled)`` triple for the source domain.
            trg_data: ``(x, y, is_labeled)`` triple for the target domain.

        Returns:
            ``(src_prj_feat, trg_prj_feat, total_loss, losses)`` where
            ``losses`` maps ``'MMD_loss'``, ``'Ctr_loss'`` and
            ``'total_loss'`` to their values. ``'Ctr_loss'`` is the int 0
            when either batch has no labeled sample.
        """
        # Send batches to the configured device rather than the default
        # CUDA device.
        src_x, src_y, src_is_labeled = (t.to(self.device, non_blocking=True) for t in src_data)
        trg_x, trg_y, trg_is_labeled = (t.to(self.device, non_blocking=True) for t in trg_data)

        # Forward pass through the feature extractor
        src_FE_feat = self.feature_extractor(src_x)
        trg_FE_feat = self.feature_extractor(trg_x)

        # Projection features for contrastive learning
        src_prj_feat = self.projection_head(src_FE_feat)
        trg_prj_feat = self.projection_head(trg_FE_feat)

        loss_mmd = self.mmd_loss(src_FE_feat, trg_FE_feat)
        loss_ctr = 0

        # The contrastive term only uses labeled samples and needs at
        # least one from each domain.
        if src_is_labeled.any() and trg_is_labeled.any():
            src_mask = src_is_labeled == 1
            trg_mask = trg_is_labeled == 1
            loss_ctr = self.ctr_loss(src_prj_feat[src_mask], trg_prj_feat[trg_mask],
                                     src_y[src_mask], trg_y[trg_mask])

        total_loss = (loss_mmd*self.hparams['mmd_weight'] + loss_ctr*self.hparams['ctr_weight'])

        losses = {'MMD_loss'  : loss_mmd, 'Ctr_loss': loss_ctr, 'total_loss':total_loss}

        return src_prj_feat, trg_prj_feat, total_loss, losses

    def update(self, src_loader, trg_loader, avg_meter, logger):
        """Run phase 1 (representation) then phase 2 (classifier) training.

        Returns:
            Tuple ``(last_head, best_head, last_classifier, best_classifier)``.
            ``last_head`` is the network state dict after phase 1 and
            ``last_classifier`` the classifier state dict after phase 2.
            ``best_head`` / ``best_classifier`` are deep copies from the
            best checkpoint epoch of the respective phase, or ``None`` when
            no checkpoint epoch improved on the initial risk.
        """
        best_pretrain_risk = float('inf')
        best_posttrain_risk = float('inf')
        best_head = None
        # Initialised like best_head so the return cannot hit an unbound
        # local when phase 2 never reaches a checkpoint.
        best_classifier = None

        pretrain_epoch = self.num_epochs
        posttrain_epoch = self.post_epochs

        # Training Phase 1
        for epoch in range(1, pretrain_epoch+1):

            self.training_epoch(src_loader, trg_loader, avg_meter, epoch)
            # save the best model
            # NOTE(review): epochs are 1-based, so this fires at epochs
            # 4, 9, 14, ... - confirm that is intended.
            if (epoch+1)%5 == 0 and avg_meter['total_loss'].avg < best_pretrain_risk:
                best_pretrain_risk = avg_meter['total_loss'].avg
                best_head = deepcopy(self.network.state_dict())

            logger.debug(f'[Phase1-Epoch : {epoch}/{pretrain_epoch}]')

            for key, val in avg_meter.items():
                logger.debug(f'{key}\t : {val.avg:2.4f}')
            logger.debug(f'---------------------------------------------------------')

        last_head = self.network.state_dict()

        # Training Phase 2
        for epoch in range(1, posttrain_epoch+1):

            self.post_training_epoch(src_loader, avg_meter, epoch)

            # save the best model
            if (epoch+1)%5 == 0 and avg_meter['source_loss'].avg < best_posttrain_risk:
                best_posttrain_risk = avg_meter['source_loss'].avg
                best_classifier = deepcopy(self.classifier.state_dict())
            logger.debug(f'[Phase2-Epoch : {epoch}/{posttrain_epoch}]')

            for key, val in avg_meter.items():
                logger.debug(f'{key}\t: {val.avg:2.4f}')

            logger.debug(f'-----------------------------------------------------------')
        last_classifier = self.classifier.state_dict()

        return last_head, best_head, last_classifier, best_classifier

    def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
        """Run one phase-1 epoch over the source loader.

        The target loader is cycled so every source batch is paired with a
        target batch. The scheduler is stepped once at the end.
        """
        for src_data, trg_data in tqdm(zip(src_loader, itertools.cycle(trg_loader))):

            # Forward pass
            _, _, total_loss, losses = self.forward(src_data, trg_data)

            # Backpropagation
            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()
            self.momentum_encoder.update_momentum_encoder()

            # Update average meters with plain floats: the values returned
            # by forward() are still attached to the autograd graph, and
            # 'Ctr_loss' may be the int 0.
            # NOTE(review): 32 is a hard-coded meter weight, presumably the
            # batch size - confirm against the loader configuration.
            for key, val in losses.items():
                avg_meter[key].update(float(val), 32)

        self.lr_scheduler.step()

    def post_training_epoch(self, src_loader, avg_meter, epoch):
        """Run one phase-2 epoch: train the classifier on source batches.

        The feature extractor and projection head are frozen; only the
        classifier is updated, through the optimizer and scheduler that
        the classifier object itself exposes.
        """
        self.network.requires_grad_(False)
        self.classifier.requires_grad_(True)

        # NOTE(review): the is-labeled flag is not used here, so every
        # sample in the batch contributes to the loss - confirm intent.
        for src_x, src_y, _ in src_loader:
            src_x = src_x.to(self.device, non_blocking=True)
            src_y = src_y.to(self.device, non_blocking=True)

            src_FE_feat = self.feature_extractor(src_x)
            src_prj_feat = self.projection_head(src_FE_feat.detach())

            src_out = self.classifier(src_prj_feat)
            source_loss = self.cross_entropy(src_out, src_y)

            self.classifier.optimizer.zero_grad()
            source_loss.backward()
            self.classifier.optimizer.step()

            avg_meter['source_loss'].update(source_loss.item(), 32)

        self.classifier.lr_scheduler.step()

class MoSSDA_all(Algorithm):
    """
    Two-phase MoSSDA variant; in phase 2, use all data (source + target).

    Phase 1 optimises the feature extractor and projection head with a
    weighted sum of an MMD loss and a contrastive loss. Phase 2 freezes
    that network and trains ``self.classifier`` with the sum of the source
    and target cross-entropy losses on projected features.
    """
    def __init__(self, backbone, configs, num_epochs, post_epochs, hparams, device):
        """
        Args:
            backbone: Callable that builds the feature extractor from ``configs``.
            configs: Configuration object; ``features_len``,
                ``final_out_channels``, ``kernel_type`` and ``mix_type`` are read.
            num_epochs: Number of phase-1 epochs.
            post_epochs: Number of phase-2 epochs.
            hparams: Mapping providing ``learning_rate``, ``weight_decay``,
                ``step_size``, ``lr_decay``, ``mmd_weight`` and ``ctr_weight``.
            device: Device that input batches are sent to.
        """
        super().__init__(configs, backbone, num_epochs, post_epochs)

        # Backbone (replaces the instance built by the base constructor)
        self.feature_extractor = backbone(configs)
        # Projection head for contrastive learning
        self.projection_head = nn.Sequential(
            nn.Linear(configs.features_len * configs.final_out_channels, configs.final_out_channels),
            nn.ReLU(),
            nn.Linear(configs.final_out_channels, configs.final_out_channels)
        )

        # Momentum Encoder applied over the projection head
        self.momentum_encoder = MomentumEncoder(self.projection_head)

        # Network
        self.network = nn.Sequential(self.feature_extractor, self.projection_head)

        # target classifier (post training)
        self.classifier = SSDA_classifier(configs,hparams)

        # Optimizer and scheduler
        self.optimizer = torch.optim.Adam(
            self.parameters(),
            lr = hparams['learning_rate'],
            weight_decay = hparams['weight_decay']
        )
        self.lr_scheduler = StepLR(self.optimizer, step_size = hparams['step_size'], gamma=hparams['lr_decay'])

        self.hparams = hparams
        self.device = device

        self.kernel_type = configs.kernel_type
        self.mix_type = configs.mix_type

        # losses
        self.ce_loss = nn.CrossEntropyLoss()
        self.mmd_loss = MMD_loss(kernel_type=self.kernel_type)
        self.ctr_loss = get_loss_class(self.mix_type, self.device)


    def forward(self, src_data, trg_data):
        """Compute the phase-1 losses for one source batch and one target batch.

        Args:
            src_data: ``(x, y, is_labeled)`` triple for the source domain.
            trg_data: ``(x, y, is_labeled)`` triple for the target domain.

        Returns:
            ``(src_prj_feat, trg_prj_feat, total_loss, losses)`` where
            ``losses`` maps ``'MMD_loss'``, ``'Ctr_loss'`` and
            ``'total_loss'`` to their values. ``'Ctr_loss'`` is the int 0
            when either batch has no labeled sample.
        """
        # Send batches to the configured device rather than the default
        # CUDA device.
        src_x, src_y, src_is_labeled = (t.to(self.device, non_blocking=True) for t in src_data)
        trg_x, trg_y, trg_is_labeled = (t.to(self.device, non_blocking=True) for t in trg_data)

        # Forward pass through the feature extractor
        src_FE_feat = self.feature_extractor(src_x)
        trg_FE_feat = self.feature_extractor(trg_x)

        # Projection features for contrastive learning
        src_prj_feat = self.projection_head(src_FE_feat)
        trg_prj_feat = self.projection_head(trg_FE_feat)

        loss_mmd = self.mmd_loss(src_FE_feat, trg_FE_feat)
        loss_ctr = 0

        # The contrastive term only uses labeled samples and needs at
        # least one from each domain.
        if src_is_labeled.any() and trg_is_labeled.any():
            src_mask = src_is_labeled == 1
            trg_mask = trg_is_labeled == 1
            loss_ctr = self.ctr_loss(src_prj_feat[src_mask], trg_prj_feat[trg_mask],
                                     src_y[src_mask], trg_y[trg_mask])

        total_loss = (loss_mmd*self.hparams['mmd_weight'] + loss_ctr*self.hparams['ctr_weight'])

        losses = {'MMD_loss'  : loss_mmd, 'Ctr_loss': loss_ctr, 'total_loss':total_loss}

        return src_prj_feat, trg_prj_feat, total_loss, losses

    def update(self, src_loader, trg_loader, avg_meter, logger):
        """Run phase 1 (representation) then phase 2 (classifier) training.

        Returns:
            Tuple ``(last_head, best_head, last_classifier, best_classifier)``.
            ``last_head`` is the network state dict after phase 1 and
            ``last_classifier`` the classifier state dict after phase 2.
            ``best_head`` / ``best_classifier`` are deep copies from the
            best checkpoint epoch of the respective phase, or ``None`` when
            no checkpoint epoch improved on the initial risk.
        """
        best_pretrain_risk = float('inf')
        best_posttrain_risk = float('inf')
        best_head = None
        # Initialised like best_head so the return cannot hit an unbound
        # local when phase 2 never reaches a checkpoint.
        best_classifier = None

        pretrain_epoch = self.num_epochs
        posttrain_epoch = self.post_epochs

        # Training Phase 1
        for epoch in range(1, pretrain_epoch+1):

            self.training_epoch(src_loader, trg_loader, avg_meter, epoch)
            # save the best model
            # NOTE(review): epochs are 1-based, so this fires at epochs
            # 4, 9, 14, ... - confirm that is intended.
            if (epoch+1)%5 == 0 and avg_meter['total_loss'].avg < best_pretrain_risk:
                best_pretrain_risk = avg_meter['total_loss'].avg
                best_head = deepcopy(self.network.state_dict())

            logger.debug(f'[Phase1-Epoch : {epoch}/{pretrain_epoch}]')

            for key, val in avg_meter.items():
                logger.debug(f'{key}\t : {val.avg:2.4f}')
            logger.debug(f'---------------------------------------------------------')

        last_head = self.network.state_dict()

        # Training Phase 2
        for epoch in range(1, posttrain_epoch+1):

            self.post_training_epoch(src_loader, trg_loader, avg_meter, epoch)

            # save the best model
            if (epoch+1)%5 == 0 and avg_meter['post_loss'].avg < best_posttrain_risk:
                best_posttrain_risk = avg_meter['post_loss'].avg
                best_classifier = deepcopy(self.classifier.state_dict())
            logger.debug(f'[Phase2-Epoch : {epoch}/{posttrain_epoch}]')

            for key, val in avg_meter.items():
                logger.debug(f'{key}\t: {val.avg:2.4f}')

            logger.debug(f'-----------------------------------------------------------')
        last_classifier = self.classifier.state_dict()

        return last_head, best_head, last_classifier, best_classifier

    def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
        """Run one phase-1 epoch over the source loader.

        The target loader is cycled so every source batch is paired with a
        target batch. The scheduler is stepped once at the end.
        """
        for src_data, trg_data in tqdm(zip(src_loader, itertools.cycle(trg_loader))):

            # Forward pass
            _, _, total_loss, losses = self.forward(src_data, trg_data)

            # Backpropagation
            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()
            self.momentum_encoder.update_momentum_encoder()

            # Update average meters with plain floats: the values returned
            # by forward() are still attached to the autograd graph, and
            # 'Ctr_loss' may be the int 0.
            # NOTE(review): 32 is a hard-coded meter weight, presumably the
            # batch size - confirm against the loader configuration.
            for key, val in losses.items():
                avg_meter[key].update(float(val), 32)

        self.lr_scheduler.step()

    def post_training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
        """Run one phase-2 epoch: train the classifier on source and target.

        The feature extractor and projection head are frozen; only the
        classifier is updated, through the optimizer and scheduler that
        the classifier object itself exposes. The target loader is cycled
        so every source batch is paired with a target batch.
        """
        self.network.requires_grad_(False)
        self.classifier.requires_grad_(True)

        # NOTE(review): the is-labeled flags are not used here, so every
        # sample in both batches contributes to the loss - confirm intent.
        for (src_x, src_y, _), (trg_x, trg_y, _) in zip(src_loader, itertools.cycle(trg_loader)):

            src_x = src_x.to(self.device, non_blocking=True)
            src_y = src_y.to(self.device, non_blocking=True)
            src_FE_feat = self.feature_extractor(src_x)
            src_prj_feat = self.projection_head(src_FE_feat.detach())

            src_out = self.classifier(src_prj_feat)
            source_loss = self.cross_entropy(src_out, src_y)

            trg_x = trg_x.to(self.device, non_blocking=True)
            trg_y = trg_y.to(self.device, non_blocking=True)
            trg_FE_feat = self.feature_extractor(trg_x)
            trg_prj_feat = self.projection_head(trg_FE_feat.detach())

            trg_out = self.classifier(trg_prj_feat)
            target_loss = self.cross_entropy(trg_out, trg_y)

            total_loss = source_loss + target_loss

            self.classifier.optimizer.zero_grad()
            total_loss.backward()
            self.classifier.optimizer.step()

            # All three meters receive plain floats; 'post_loss' used to be
            # stored as a graph-attached tensor, unlike its siblings.
            losses = {
                'source_loss': source_loss.item(),
                'target_loss': target_loss.item(),
                'post_loss': total_loss.item(),
            }
            for key, val in losses.items():
                avg_meter[key].update(val, 32)

        self.classifier.lr_scheduler.step()

class MoSSDA_all_ablation(Algorithm):
    """
    Ablation w/o phase 2 = 1-stage learning; use all data (source + target).

    A single training stage optimises the weighted MMD and contrastive
    losses together with the source and target cross-entropy losses of
    ``self.classifier`` applied to the projected features.
    """
    def __init__(self, backbone, configs, num_epochs, post_epochs, hparams, device):
        """
        Args:
            backbone: Callable that builds the feature extractor from ``configs``.
            configs: Configuration object; ``features_len``,
                ``final_out_channels``, ``kernel_type`` and ``mix_type`` are read.
            num_epochs: Number of training epochs.
            post_epochs: Forwarded to the base constructor; update() does
                not run a second phase in this variant.
            hparams: Mapping providing ``learning_rate``, ``weight_decay``,
                ``step_size``, ``lr_decay``, ``mmd_weight`` and ``ctr_weight``.
            device: Device that input batches are sent to.
        """
        super().__init__(configs, backbone, num_epochs, post_epochs)

        # Backbone (replaces the instance built by the base constructor)
        self.feature_extractor = backbone(configs)
        # Projection head for contrastive learning
        self.projection_head = nn.Sequential(
            nn.Linear(configs.features_len * configs.final_out_channels, configs.final_out_channels),
            nn.ReLU(),
            nn.Linear(configs.final_out_channels, configs.final_out_channels)
        )

        # Momentum Encoder applied over the projection head
        self.momentum_encoder = MomentumEncoder(self.projection_head)

        # Network
        self.network = nn.Sequential(self.feature_extractor, self.projection_head)

        # target classifier (post training)
        self.classifier = SSDA_classifier(configs,hparams)

        # Optimizer and scheduler
        self.optimizer = torch.optim.Adam(
            self.parameters(),
            lr = hparams['learning_rate'],
            weight_decay = hparams['weight_decay']
        )
        self.lr_scheduler = StepLR(self.optimizer, step_size = hparams['step_size'], gamma=hparams['lr_decay'])

        self.hparams = hparams
        self.device = device

        self.kernel_type = configs.kernel_type
        self.mix_type = configs.mix_type

        # losses
        self.ce_loss = nn.CrossEntropyLoss()
        self.mmd_loss = MMD_loss(kernel_type=self.kernel_type)
        self.ctr_loss = get_loss_class(self.mix_type, self.device)


    def forward(self, src_data, trg_data):
        """Compute all losses for one source batch and one target batch.

        Args:
            src_data: ``(x, y, is_labeled)`` triple for the source domain.
            trg_data: ``(x, y, is_labeled)`` triple for the target domain.

        Returns:
            ``(src_prj_feat, trg_prj_feat, total_loss, losses)`` where
            ``losses`` maps ``'MMD_loss'``, ``'Ctr_loss'``, ``'Cls_loss'``,
            ``'source_loss'``, ``'target_loss'`` and ``'total_loss'`` to
            their values. ``'Ctr_loss'`` is the int 0 when either batch has
            no labeled sample.
        """
        # Send batches to the configured device rather than the default
        # CUDA device.
        src_x, src_y, src_is_labeled = (t.to(self.device, non_blocking=True) for t in src_data)
        trg_x, trg_y, trg_is_labeled = (t.to(self.device, non_blocking=True) for t in trg_data)

        # Forward pass through the feature extractor
        src_FE_feat = self.feature_extractor(src_x)
        trg_FE_feat = self.feature_extractor(trg_x)

        # Projection features for contrastive learning
        src_prj_feat = self.projection_head(src_FE_feat)
        trg_prj_feat = self.projection_head(trg_FE_feat)

        loss_mmd = self.mmd_loss(src_FE_feat, trg_FE_feat)
        loss_ctr = 0

        # The contrastive term only uses labeled samples and needs at
        # least one from each domain.
        if src_is_labeled.any() and trg_is_labeled.any():
            src_mask = src_is_labeled == 1
            trg_mask = trg_is_labeled == 1
            loss_ctr = self.ctr_loss(src_prj_feat[src_mask], trg_prj_feat[trg_mask],
                                     src_y[src_mask], trg_y[trg_mask])

        # Classification on the projected features.
        # NOTE(review): the is-labeled flags are not applied here, so every
        # sample in both batches contributes - confirm intent.
        src_out = self.classifier(src_prj_feat)
        trg_out = self.classifier(trg_prj_feat)

        source_loss = self.cross_entropy(src_out, src_y)
        target_loss = self.cross_entropy(trg_out, trg_y)
        loss_cls = source_loss + target_loss

        total_loss = (loss_mmd*self.hparams['mmd_weight'] + loss_ctr*self.hparams['ctr_weight'] +loss_cls)

        losses = {'MMD_loss'  : loss_mmd, 'Ctr_loss': loss_ctr, 'Cls_loss': loss_cls, 'source_loss': source_loss, 'target_loss': target_loss, 'total_loss':total_loss}

        return src_prj_feat, trg_prj_feat, total_loss, losses

    def update(self, src_loader, trg_loader, avg_meter, logger):
        """Run the single training stage.

        Returns:
            Tuple ``(last_head, best_head, last_classifier, best_classifier)``.
            The ``last_*`` entries are the state dicts at the end of
            training. The ``best_*`` entries are deep copies taken at the
            best checkpoint epoch, or ``None`` when no checkpoint epoch
            improved on the initial risk.
        """
        best_pretrain_risk = float('inf')
        best_head = None
        # Initialised like best_head so the return cannot hit an unbound
        # local when no checkpoint epoch is reached.
        best_classifier = None

        pretrain_epoch = self.num_epochs

        # Training 1 Stage
        for epoch in range(1, pretrain_epoch+1):

            self.training_epoch(src_loader, trg_loader, avg_meter, epoch)
            # save the best model
            # NOTE(review): epochs are 1-based, so this fires at epochs
            # 4, 9, 14, ... - confirm that is intended.
            if (epoch+1)%5 == 0 and avg_meter['total_loss'].avg < best_pretrain_risk:
                best_pretrain_risk = avg_meter['total_loss'].avg
                best_head = deepcopy(self.network.state_dict())
                best_classifier = deepcopy(self.classifier.state_dict())
            logger.debug(f'[Onephase-Epoch : {epoch}/{pretrain_epoch}]')

            for key, val in avg_meter.items():
                logger.debug(f'{key}\t : {val.avg:2.4f}')
            logger.debug(f'---------------------------------------------------------')

        last_head = self.network.state_dict()
        last_classifier = self.classifier.state_dict()

        return last_head, best_head, last_classifier, best_classifier

    def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
        """Run one epoch over the source loader.

        The target loader is cycled so every source batch is paired with a
        target batch. The scheduler is stepped once at the end.
        """
        for src_data, trg_data in tqdm(zip(src_loader, itertools.cycle(trg_loader))):

            # Forward pass
            _, _, total_loss, losses = self.forward(src_data, trg_data)

            # Backpropagation
            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()
            self.momentum_encoder.update_momentum_encoder()

            # Update average meters with plain floats: the values returned
            # by forward() are still attached to the autograd graph, and
            # 'Ctr_loss' may be the int 0.
            # NOTE(review): 32 is a hard-coded meter weight, presumably the
            # batch size - confirm against the loader configuration.
            for key, val in losses.items():
                avg_meter[key].update(float(val), 32)

        self.lr_scheduler.step()
class CDAC(Algorithm):
    """
    Cross Domain Adaptive Clustering

    Trains a feature extractor and an ``SSDA_classifier`` with a supervised
    cross-entropy term on labelled source/target samples plus the
    ``CDAC_loss`` term on unlabelled target samples.
    """
    def __init__(self, backbone, configs, num_epochs, post_epochs, hparams, device):
        """Build networks, optimisers and loss modules.

        Args:
            backbone: Callable returning the feature extractor for ``configs``.
            configs: Dataset/model configuration passed to the networks.
            num_epochs: Number of training epochs used by ``update``.
            post_epochs: Forwarded to the base class; not used by this class.
            hparams: Dict read for ``multi``, ``weight_decay``, ``lr_f``,
                ``learning_rate``, ``gamma``, ``rampup_length``, ``rampup_coef``.
            device: Device on which batches and loss modules are placed.
        """
        super().__init__(configs, backbone, num_epochs, post_epochs)

        # Feature extractor
        self.feature_extractor = backbone(configs)
        # Classifier
        self.classifier = SSDA_classifier(configs, hparams)
        weights_init(self.classifier)

        # One parameter group per tensor. The group key must be 'lr': any other
        # key (the code previously used 'learning_rate') is ignored by
        # torch.optim, so the `multi` multiplier never reached the scheduler.
        params = []
        for key, value in self.feature_extractor.named_parameters():
            if not value.requires_grad:
                continue
            lr_mult = hparams['multi'] if 'classifier' not in key else hparams['multi'] * 10
            params.append({'params': [value], 'lr': lr_mult,
                           'weight_decay': hparams['weight_decay']})

        # SGD optimizers (kept as in the reference implementation)
        self.optimizer_g = torch.optim.SGD(params, momentum=0.9,
                                           weight_decay=hparams['weight_decay'], nesterov=True)
        self.optimizer_f = torch.optim.SGD(list(self.classifier.parameters()), lr=hparams['lr_f'],
                                           momentum=0.9, weight_decay=hparams['weight_decay'],
                                           nesterov=True)

        # Per-group base multipliers consumed by inv_lr_scheduler
        self.param_lr_g = [group['lr'] for group in self.optimizer_g.param_groups]
        self.param_lr_f = [group['lr'] for group in self.optimizer_f.param_groups]

        # Loss functions
        self.criterion = nn.CrossEntropyLoss().to(device)
        self.cdac_loss = CDAC_loss()
        self.BCE = BCE_softlabels().to(device)

        self.hparams = hparams
        self.device = device

    def _to_device(self, tensor):
        """Move ``tensor`` to ``self.device`` (non-blocking copy)."""
        return tensor.to(self.device, non_blocking=True)

    def zero_grad_all(self):
        """Zero gradients for all optimizers"""
        self.optimizer_g.zero_grad()
        self.optimizer_f.zero_grad()

    def forward(self, src_data, trg_data, step):
        """Compute the CDAC training loss for one source/target batch pair.

        Args:
            src_data: 5-tuple ``(x, x_bar, x_bar2, y, is_labeled)`` for source.
            trg_data: 5-tuple with the same layout for target. ``x_bar`` and
                ``x_bar2`` are presumably augmented views - confirm with the
                dataset code.
            step: Global iteration index used for the consistency ramp-up.

        Returns:
            ``(src_feats, trg_feats, total_loss, losses)`` where ``total_loss``
            is graph-attached and ``losses`` holds Python floats for logging.
        """
        src_x, _, _, src_y, src_is_labeled = src_data
        trg_x, trg_x_bar, trg_x_bar2, trg_y, trg_is_labeled = trg_data

        # Move data to the configured device (not simply the current CUDA device)
        src_x, src_y, src_is_labeled = (
            self._to_device(t) for t in (src_x, src_y, src_is_labeled))
        trg_x, trg_y, trg_is_labeled = (
            self._to_device(t) for t in (trg_x, trg_y, trg_is_labeled))
        trg_x_bar, trg_x_bar2 = self._to_device(trg_x_bar), self._to_device(trg_x_bar2)

        # Calculate rampup coefficient
        rampup = sigmoid_rampup(step, self.hparams['rampup_length'])
        w_cons = self.hparams['rampup_coef'] * rampup

        # Labelled subsets
        src_mask = src_is_labeled == 1
        trg_mask = trg_is_labeled == 1
        im_data_src = src_x[src_mask]
        im_data_trg = trg_x[trg_mask]
        num_samp_src = im_data_src.shape[0]
        num_samp_trg = im_data_trg.shape[0]

        # Supervised loss. Each domain is handled on its own so that a batch
        # with labels in only one domain still contributes its cross-entropy.
        ce_loss_src = torch.zeros((), device=self.device)
        ce_loss_trg = torch.zeros((), device=self.device)
        if num_samp_src > 0:
            predictions_src = self.classifier(self.feature_extractor(im_data_src))
            ce_loss_src = self.criterion(predictions_src, src_y[src_mask])
        if num_samp_trg > 0:
            predictions_trg = self.classifier(self.feature_extractor(im_data_trg))
            ce_loss_trg = self.criterion(predictions_trg, trg_y[trg_mask])

        # Sample-count weighted supervised loss
        num_samp = num_samp_src + num_samp_trg
        if num_samp > 0:
            ce_loss = (num_samp_src / num_samp) * ce_loss_src + (num_samp_trg / num_samp) * ce_loss_trg
        else:
            ce_loss = torch.zeros((), device=self.device)

        # Unlabelled target subset
        unl_mask = trg_is_labeled == 0
        im_data_trg_ul = trg_x[unl_mask]
        im_data_bar_trg_ul = trg_x_bar[unl_mask]
        im_data_bar2_trg_ul = trg_x_bar2[unl_mask]

        # Test the number of samples, not the data values: an all-zero but
        # non-empty batch must still be used.
        if im_data_trg_ul.shape[0] > 0:
            # Calculate unsupervised loss
            cdac_loss, _ = self.cdac_loss(
                self.hparams,
                self.feature_extractor,
                self.classifier,
                im_data_trg_ul,
                im_data_bar_trg_ul,
                im_data_bar2_trg_ul,
                self.BCE,
                w_cons,
                self.device,
                None
            )
        else:
            # No unlabelled samples in this batch: the term contributes zero
            cdac_loss = torch.zeros((), device=self.device)

        total_loss = ce_loss + cdac_loss

        src_feats = self.feature_extractor(src_x)
        trg_feats = self.feature_extractor(trg_x)

        losses = {
            'clf_loss': ce_loss.item(),
            'unlabeled_loss': cdac_loss.item(),
            'total_loss': total_loss.item()
        }

        return src_feats, trg_feats, total_loss, losses

    def update(self, src_loader, trg_loader, avg_meter, logger):
        """Training process for the algorithm

        Runs ``self.num_epochs`` epochs and snapshots the weights whenever
        ``(epoch + 1) % 5 == 0`` and the running total loss improved.

        Returns:
            ``(last_head, best_head, last_classifier, best_classifier)`` state
            dicts; the ``best_*`` entries are ``None`` if no snapshot was taken.
        """
        best_risk = float('inf')
        best_head = None
        best_classifier = None

        train_epoch = self.num_epochs

        # Training loop (single-stage training)
        for epoch in range(1, train_epoch + 1):
            self.training_epoch(src_loader, trg_loader, avg_meter, epoch)

            # Save the best model
            current_risk = avg_meter['total_loss'].avg
            if (epoch + 1) % 5 == 0 and current_risk < best_risk:
                best_risk = current_risk
                best_head = deepcopy(self.feature_extractor.state_dict())
                best_classifier = deepcopy(self.classifier.state_dict())

            logger.debug(f'[Epoch : {epoch}/{train_epoch}]')
            for key, val in avg_meter.items():
                logger.debug(f'{key}\t : {val.avg:2.4f}')
            logger.debug('---------------------------------------------------------')

        last_head = self.feature_extractor.state_dict()
        last_classifier = self.classifier.state_dict()

        return last_head, best_head, last_classifier, best_classifier

    def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
        """Execute one training epoch

        The target loader is cycled, so the epoch length is ``len(src_loader)``.
        """
        steps_per_epoch = len(src_loader)
        joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))

        for step, (src_data, trg_data) in tqdm(joint_loader):
            global_step = (epoch - 1) * steps_per_epoch + step

            # Update learning rates using inverse scheduler
            self.optimizer_g = inv_lr_scheduler(self.param_lr_g, self.optimizer_g, global_step,
                                                init_lr=self.hparams['learning_rate'],
                                                gamma=self.hparams['gamma'])
            self.optimizer_f = inv_lr_scheduler(self.param_lr_f, self.optimizer_f, global_step,
                                                init_lr=self.hparams['learning_rate'],
                                                gamma=self.hparams['gamma'])

            # Forward pass
            _, _, total_loss, losses = self.forward(src_data, trg_data, global_step)

            # Backpropagate the graph-attached loss. Rebuilding a tensor from
            # the logged float (as was done before) yields a leaf with no
            # history, so no parameter ever received a gradient. Both terms
            # come from one forward pass, so they are optimised in one step:
            # stepping in between would modify parameters the graph still needs.
            self.zero_grad_all()
            if total_loss.requires_grad:
                total_loss.backward()

                # Gradient clipping for stability
                torch.nn.utils.clip_grad_norm_(self.feature_extractor.parameters(), 5)
                torch.nn.utils.clip_grad_norm_(self.classifier.parameters(), 5)

                self.optimizer_g.step()
                self.optimizer_f.step()

            # Update average meters
            for key, val in losses.items():
                avg_meter[key].update(val, src_data[0].size(0))

class PAC(Algorithm):
    """
    PAC: Pretraining (4-class pretext task) & Post-training (consistency) - MoSSDA_all structure

    Phase 1 trains the feature extractor with ``classifier1`` on a 4-class
    pretext task; phase 2 trains it with ``classifier`` on the real labels,
    optionally adding a pseudo-label consistency term on unlabelled target data.
    """
    def __init__(self, backbone, configs, num_epochs, post_epochs, hparams, device):
        """Build networks, optimisers and loss modules.

        Args:
            backbone: Callable returning the feature extractor for ``configs``.
            configs: Provides ``features_len``, ``final_out_channels`` and
                ``num_classes``.
            num_epochs: Number of pretraining epochs.
            post_epochs: Number of post-training epochs.
            hparams: Dict read for ``learning_rate``, ``pre_lr``, ``pre_lr_f``,
                ``lr_f``, ``weight_decay``, ``gamma`` and, optionally,
                ``cons_wt`` / ``cons_threshold``.
            device: Device on which networks and batches are placed.
        """
        super().__init__(configs, backbone, num_epochs, post_epochs)

        # Feature Extractor
        self.feature_extractor = backbone(configs).to(device)

        # Classifiers
        feat_dim = configs.features_len * configs.final_out_channels
        self.classifier1 = nn.Sequential(
            nn.Linear(feat_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 4)  # 4-class classifier
        ).to(device)
        self.classifier = nn.Sequential(
            nn.Linear(feat_dim, 256),
            nn.ReLU(),
            nn.Linear(256, configs.num_classes)  # original class classifier
        ).to(device)

        # Optimizers
        self.optimizer_g = torch.optim.SGD(self.feature_extractor.parameters(), lr=hparams['learning_rate'],
                                           momentum=0.9, weight_decay=hparams['weight_decay'], nesterov=True)
        # Head used during pretraining
        self.optimizer_f1 = torch.optim.SGD(self.classifier1.parameters(), lr=hparams['pre_lr_f'],
                                            momentum=0.9, weight_decay=hparams['weight_decay'], nesterov=True)
        self.optimizer_f2 = torch.optim.SGD(self.classifier.parameters(), lr=hparams['lr_f'],
                                            momentum=0.9, weight_decay=hparams['weight_decay'], nesterov=True)

        # Loss functions
        self.criterion = nn.CrossEntropyLoss().to(device)
        self.criterion1 = CrossEntropyWLogits(reduction='none').to(device)
        self.scaler = GradScaler()
        self.hparams = hparams
        self.device = device

    def _to_device(self, tensor):
        """Move ``tensor`` to ``self.device`` (non-blocking copy)."""
        return tensor.to(self.device, non_blocking=True)

    def forward(self, src_data, trg_data, phase):
        """Compute the loss of the requested phase for one batch pair.

        Args:
            src_data: 6-tuple ``(x, x_bar, y, is_labeled, rot_x, rot)``.
            trg_data: 6-tuple with the same layout for the target domain.
                ``rot_x`` must be 4-D; it is flattened to
                ``(dim0 * dim1, dim2, dim3)`` and ``rot`` to 1-D.
            phase: ``'pretrain'`` selects the pretext loss; any other value
                selects the post-training loss.

        Returns:
            ``(src_feats, trg_feats, loss, losses)`` where ``loss`` is the
            tensor to backpropagate and ``losses`` holds Python floats, so the
            meters do not keep autograd graphs alive.
        """
        src_x, _, src_y, src_is_labeled, src_rot_x, src_rot = src_data
        trg_x, trg_x_bar, trg_y, trg_is_labeled, trg_rot_x, trg_rot = trg_data

        src_x, src_y, src_is_labeled, src_rot_x, src_rot = (
            self._to_device(t) for t in (src_x, src_y, src_is_labeled, src_rot_x, src_rot))
        trg_x, trg_x_bar, trg_y, trg_is_labeled, trg_rot_x, trg_rot = (
            self._to_device(t) for t in (trg_x, trg_x_bar, trg_y, trg_is_labeled, trg_rot_x, trg_rot))

        # Pretext inputs: 4D -> 3D, labels flattened to match
        _, _, C, T = src_rot_x.shape  # [batch, rotations, channels, time]
        src_rot_x = src_rot_x.view(-1, C, T)  # [batch*rotations, channels, time]
        _, _, C, T = trg_rot_x.shape
        trg_rot_x = trg_rot_x.view(-1, C, T)
        src_rot = src_rot.view(-1)
        trg_rot = trg_rot.view(-1)

        src_mask = src_is_labeled == 1
        trg_mask = trg_is_labeled == 1

        # NOTE(review): the pretext features are computed in both phases but
        # only used in 'pretrain'; kept to leave the forward order unchanged.
        src_feats = self.feature_extractor(src_x)
        trg_feats = self.feature_extractor(trg_x)
        src_aug_feats = self.feature_extractor(src_rot_x)
        trg_aug_feats = self.feature_extractor(trg_rot_x)

        if phase == 'pretrain':
            src_loss = self.criterion(self.classifier1(src_aug_feats), src_rot)
            trg_loss = self.criterion(self.classifier1(trg_aug_feats), trg_rot)
            loss = src_loss + trg_loss
            losses = {'src_loss': src_loss.item(), 'trg_loss': trg_loss.item(),
                      'pretrain_loss': loss.item()}
            return src_feats, trg_feats, loss, losses

        src_pred = self.classifier(src_feats)
        trg_pred = self.classifier(trg_feats)

        # Supervised loss. Always a tensor (never the int 0), and each domain
        # contributes on its own so labels in only one domain are still used.
        cls_loss = torch.zeros((), device=self.device)
        if src_mask.any():
            cls_loss = cls_loss + self.criterion(src_pred[src_mask], src_y[src_mask])
        if trg_mask.any():
            cls_loss = cls_loss + self.criterion(trg_pred[trg_mask], trg_y[trg_mask])

        # Consistency Loss (if enabled)
        cons_wt = self.hparams.get('cons_wt', 0)
        if cons_wt > 0:
            unlabeled_mask = ~trg_mask
            if unlabeled_mask.any():
                # Pseudo-labels come from the clean view and carry no gradient
                with torch.no_grad():
                    feats_unl = self.feature_extractor(trg_x[unlabeled_mask])
                    pseudo_labels = torch.softmax(self.classifier(feats_unl), dim=1)
                    confs, _ = torch.max(pseudo_labels, dim=1)
                    pl_mask = (confs > self.hparams['cons_threshold']).float()
                aug_pred = self.classifier(self.feature_extractor(trg_x_bar[unlabeled_mask]))
                loss_cons = (self.criterion1(aug_pred, pseudo_labels) * pl_mask).mean()
            else:
                loss_cons = torch.zeros((), device=self.device)
            total_loss = cls_loss + cons_wt * loss_cons
            losses = {'cls_loss': cls_loss.item(), 'cons_loss': loss_cons.item(),
                      'posttrain_loss': total_loss.item()}
        else:
            # Previously `total_loss` was unbound on this path and raised
            # UnboundLocalError. 'posttrain_loss' is reported as well because
            # update() selects the best checkpoint from that meter.
            total_loss = cls_loss
            losses = {'cls_loss': cls_loss.item(), 'total_loss': total_loss.item(),
                      'posttrain_loss': total_loss.item()}
        return src_feats, trg_feats, total_loss, losses

    def _run_epoch(self, src_loader, trg_loader, avg_meter, epoch, phase,
                   head_optimizer, base_lr_g, base_lr_f):
        """Run one epoch of ``phase`` with inverse-decay learning rates."""
        steps_per_epoch = len(src_loader)
        joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
        for step, (src_data, trg_data) in tqdm(joint_loader):
            # Inline learning-rate schedule
            global_step = (epoch - 1) * steps_per_epoch + step
            decay = (1 + self.hparams['gamma'] * global_step) ** 0.75
            # Optimisers only read the 'lr' key of a parameter group
            for param_group in self.optimizer_g.param_groups:
                param_group['lr'] = base_lr_g / decay
            for param_group in head_optimizer.param_groups:
                param_group['lr'] = base_lr_f / decay

            # Forward and backward pass
            _, _, loss, losses = self.forward(src_data, trg_data, phase=phase)
            self.optimizer_g.zero_grad()
            head_optimizer.zero_grad()
            # A batch without any usable sample yields a constant loss that
            # has no graph; skip the optimiser step instead of failing.
            if loss.requires_grad:
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer_g)
                self.scaler.step(head_optimizer)
                self.scaler.update()

            # Update average meters
            for key, val in losses.items():
                avg_meter[key].update(val, src_data[0].size(0))

    def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
        """Run one pretraining epoch (pretext task, ``classifier1``)."""
        self._run_epoch(src_loader, trg_loader, avg_meter, epoch, 'pretrain',
                        self.optimizer_f1, self.hparams['pre_lr'], self.hparams['pre_lr_f'])

    def post_training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
        """Run one post-training epoch (real labels, ``classifier``)."""
        self._run_epoch(src_loader, trg_loader, avg_meter, epoch, 'posttrain',
                        self.optimizer_f2, self.hparams['learning_rate'], self.hparams['lr_f'])

    def update(self, src_loader, trg_loader, avg_meter, logger):
        """Run pretraining followed by post-training.

        In each phase a snapshot is taken every 5th epoch when it is the first
        one or the running phase loss improved.

        Returns:
            ``(last_head, best_post_head, last_classifier, best_post_classifier)``
            state dicts; the ``best_post_*`` entries are ``None`` if the
            post-training phase never took a snapshot.
        """
        best_pre_risk = float('inf')
        best_pre_head = None
        best_pre_classifier = None

        best_post_risk = float('inf')
        best_post_head = None
        best_post_classifier = None

        pretrain_epoch = self.num_epochs
        posttrain_epoch = self.post_epochs

        # Pretrain (4-class)
        for epoch in range(1, pretrain_epoch + 1):
            self.training_epoch(src_loader, trg_loader, avg_meter, epoch)
            current_loss = avg_meter['pretrain_loss'].avg
            if epoch % 5 == 0 and (best_pre_head is None or current_loss < best_pre_risk):
                best_pre_risk = current_loss
                best_pre_head = deepcopy(self.feature_extractor.state_dict())
                best_pre_classifier = deepcopy(self.classifier1.state_dict())

            logger.debug(f'[Pretrain Epoch {epoch}/{self.num_epochs}] Loss: {avg_meter["pretrain_loss"].avg:.4f}')

        # Post-train (consistency)
        for epoch in range(1, posttrain_epoch + 1):
            self.post_training_epoch(src_loader, trg_loader, avg_meter, epoch)
            current_loss = avg_meter['posttrain_loss'].avg
            if epoch % 5 == 0 and (best_post_head is None or current_loss < best_post_risk):
                best_post_risk = current_loss
                best_post_head = deepcopy(self.feature_extractor.state_dict())
                best_post_classifier = deepcopy(self.classifier.state_dict())

            logger.debug(f'[Post-train Epoch {epoch}/{self.post_epochs}] Loss: {avg_meter["posttrain_loss"].avg:.4f}')

        last_head = self.feature_extractor.state_dict()
        last_classifier = self.classifier.state_dict()
        # TODO: compare with the best-model tuples returned by the other
        # algorithm classes (note dated 05/09); the pretraining snapshots are
        # tracked but not returned.
        return last_head, best_post_head, last_classifier, best_post_classifier

def inv_lr_scheduler(param_lr, optimizer, iter_num, gamma=0.0001,
                     power=0.75, init_lr=0.001):
    """Apply an inverse-decay learning rate to every parameter group.

    The shared rate is ``init_lr * (1 + gamma * iter_num) ** (-power)``; group
    ``k`` receives that rate multiplied by ``param_lr[k]``.

    Args:
        param_lr: Per-group multipliers, indexed like ``optimizer.param_groups``.
        optimizer: Object exposing ``param_groups`` (a list of dicts).
        iter_num: Current global iteration.
        gamma: Decay speed.
        power: Decay exponent.
        init_lr: Rate at iteration 0.

    Returns:
        The same ``optimizer`` object, modified in place.
    """
    decay = (1 + gamma * iter_num) ** (- power)
    lr = init_lr * decay
    for idx, group in enumerate(optimizer.param_groups):
        group['lr'] = lr * param_lr[idx]
    return optimizer

class AdaMatch(Algorithm):
    """
    ADAMATCH: A UNIFIED APPROACH TO SEMISUPERVISED LEARNING AND DOMAIN ADAPTATION

    Trains a feature extractor and an ``SSDA_classifier`` with the loss
    returned by ``AdaMatch_loss``.
    """
    def __init__(self, backbone, configs, num_epochs, post_epochs, hparams, device):
        """Build networks, optimisers and loss modules.

        Args:
            backbone: Callable returning the feature extractor for ``configs``.
            configs: Dataset/model configuration passed to the networks.
            num_epochs: Number of training epochs used by ``update``.
            post_epochs: Forwarded to the base class; not used by this class.
            hparams: Dict read for ``multi``, ``weight_decay``, ``lr_f``,
                ``learning_rate``, ``gamma`` and ``warm_steps``.
            device: Device on which networks and batches are placed.
        """
        super().__init__(configs, backbone, num_epochs, post_epochs)

        self.device = device
        # Feature extractor
        self.feature_extractor = backbone(configs).to(device)

        # Classifier (SSDA_classifier is used instead of the default classifier)
        self.classifier = SSDA_classifier(configs, hparams).to(device)
        weights_init(self.classifier)

        # One parameter group per tensor. The group key must be 'lr': any other
        # key (the code previously used 'learning_rate') is ignored by
        # torch.optim, so the `multi` multiplier never reached the scheduler.
        params = []
        for key, value in self.feature_extractor.named_parameters():
            if not value.requires_grad:
                continue
            lr_mult = hparams['multi'] if 'classifier' not in key else hparams['multi'] * 10
            params.append({'params': [value], 'lr': lr_mult,
                           'weight_decay': hparams['weight_decay']})

        # Optimizers (SGD kept)
        self.optimizer_g = torch.optim.SGD(params, momentum=0.9,
                                           weight_decay=hparams['weight_decay'], nesterov=True)
        self.optimizer_f = torch.optim.SGD(list(self.classifier.parameters()), lr=hparams['lr_f'],
                                           momentum=0.9, weight_decay=hparams['weight_decay'],
                                           nesterov=True)

        # Per-group base multipliers consumed by inv_lr_scheduler
        self.param_lr_g = [group["lr"] for group in self.optimizer_g.param_groups]
        self.param_lr_f = [group["lr"] for group in self.optimizer_f.param_groups]

        # Gradient scaler used around backward/step in training_epoch
        self.scaler = GradScaler()

        # Loss functions and parameters
        self.criterion = nn.CrossEntropyLoss().to(device)
        self.adamatch_loss = AdaMatch_loss()
        self.hparams = hparams

    def _to_device(self, tensor):
        """Move ``tensor`` to ``self.device`` (non-blocking copy)."""
        return tensor.to(self.device, non_blocking=True)

    def zero_grad_all(self):
        """Zero gradients for all optimizers"""
        self.optimizer_g.zero_grad()
        self.optimizer_f.zero_grad()

    def forward(self, src_data, trg_data, step):
        """Compute the AdaMatch loss for one source/target batch pair.

        Args:
            src_data: 5-tuple ``(x, x_bar, x_bar2, y, is_labeled)`` for source.
            trg_data: 5-tuple with the same layout for target. ``x_bar`` is
                presumably an augmented view - confirm with the dataset code.
            step: Global iteration index forwarded to the loss.

        Returns:
            ``(src_feats, trg_feats, loss, losses)`` where ``loss`` is
            graph-attached and ``losses`` holds Python floats for logging.
        """
        src_x, src_x_bar, _, src_y, src_is_labeled = src_data
        trg_x, trg_x_bar, _, trg_y, trg_is_labeled = trg_data

        # Move data to the configured device (not simply the current CUDA device)
        src_x, src_x_bar, src_y, src_is_labeled = (
            self._to_device(t) for t in (src_x, src_x_bar, src_y, src_is_labeled))
        trg_x, trg_x_bar, trg_y, trg_is_labeled = (
            self._to_device(t) for t in (trg_x, trg_x_bar, trg_y, trg_is_labeled))

        src_lab = src_is_labeled == 1
        trg_lab = trg_is_labeled == 1
        trg_unl = trg_is_labeled == 0

        # Labelled / unlabelled subsets
        im_data_src = src_x[src_lab]
        im_data_bar_src = src_x_bar[src_lab]
        im_data_trg = trg_x[trg_lab]
        im_data_bar_trg = trg_x_bar[trg_lab]
        im_data_trg_ul = trg_x[trg_unl]
        im_data_bar_trg_ul = trg_x_bar[trg_unl]
        gt_labels_src = src_y[src_lab]
        gt_labels_trg = trg_y[trg_lab]

        # Calculate loss using AdaMatch loss function
        loss, source_loss, target_loss = self.adamatch_loss(
            self.hparams,
            self.feature_extractor,
            self.classifier,
            im_data_src,
            im_data_bar_src,
            im_data_trg,
            im_data_bar_trg,
            im_data_trg_ul,
            im_data_bar_trg_ul,
            gt_labels_src,
            gt_labels_trg,
            step,
            self.hparams['warm_steps'],
            self.device
        )

        src_feats = self.feature_extractor(src_x)
        trg_feats = self.feature_extractor(trg_x)

        losses = {
            'Clf_loss': source_loss.item(),
            'Unlabeled_loss': target_loss.item(),
            'total_loss': loss.item()
        }

        return src_feats, trg_feats, loss, losses

    def update(self, src_loader, trg_loader, avg_meter, logger):
        """Training process for the algorithm

        Runs ``self.num_epochs`` epochs and snapshots the weights whenever
        ``(epoch + 1) % 5 == 0`` and the running total loss improved.

        Returns:
            ``(last_head, best_head, last_classifier, best_classifier)`` state
            dicts; the ``best_*`` entries are ``None`` if no snapshot was taken.
        """
        best_risk = float('inf')
        best_head = None
        best_classifier = None

        train_epoch = self.num_epochs

        # Training loop
        for epoch in range(1, train_epoch + 1):
            self.training_epoch(src_loader, trg_loader, avg_meter, epoch)

            # Save the best model
            current_risk = avg_meter['total_loss'].avg
            if (epoch + 1) % 5 == 0 and current_risk < best_risk:
                best_risk = current_risk
                best_head = deepcopy(self.feature_extractor.state_dict())
                best_classifier = deepcopy(self.classifier.state_dict())

            logger.debug(f'[Epoch : {epoch}/{train_epoch}]')
            for key, val in avg_meter.items():
                logger.debug(f'{key}\t : {val.avg:2.4f}')
            logger.debug('---------------------------------------------------------')

        last_head = self.feature_extractor.state_dict()
        last_classifier = self.classifier.state_dict()

        return last_head, best_head, last_classifier, best_classifier

    def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
        """Execute one training epoch

        The target loader is cycled, so the epoch length is ``len(src_loader)``.
        """
        joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
        total_steps = len(src_loader)

        for step, (src_data, trg_data) in tqdm(joint_loader):
            global_step = (epoch - 1) * total_steps + step

            # Update learning rates using inverse scheduler
            self.optimizer_g = inv_lr_scheduler(self.param_lr_g, self.optimizer_g, global_step,
                                                init_lr=self.hparams['learning_rate'],
                                                gamma=self.hparams['gamma'])
            self.optimizer_f = inv_lr_scheduler(self.param_lr_f, self.optimizer_f, global_step,
                                                init_lr=self.hparams['learning_rate'],
                                                gamma=self.hparams['gamma'])

            # Forward pass
            _, _, loss, losses = self.forward(src_data, trg_data, global_step)

            # Backward pass through the gradient scaler
            self.zero_grad_all()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer_g)
            self.scaler.step(self.optimizer_f)
            self.scaler.update()

            # Update average meters
            for key, val in losses.items():
                avg_meter[key].update(val, src_data[0].size(0))


class UniSSDA(AdaMatch):
    """
    UniSSDA: Universal Semi-Supervised Domain Adaptation

    Extends ``AdaMatch`` with a second ``SSDA_classifier`` head and the loss
    returned by ``univ_ssda_loss``.
    """
    def __init__(self, backbone, configs, num_epochs, post_epochs, hparams, device):
        """Build the parent networks plus the second classifier head.

        Args:
            backbone: Callable returning the feature extractor for ``configs``.
            configs: Dataset/model configuration passed to the networks.
            num_epochs: Number of training epochs used by ``update``.
            post_epochs: Forwarded to the parent; not used by this class.
            hparams: Dict read for ``lr_f``, ``weight_decay``,
                ``learning_rate``, ``gamma`` and ``warm_steps`` (plus the keys
                the parent reads).
            device: Device on which networks and batches are placed.
        """
        super().__init__(backbone, configs, num_epochs, post_epochs, hparams, device)

        # Additional classifier. Moved to `device` like the parent's modules so
        # it does not stay on the CPU while the rest of the model is elsewhere.
        self.classifier2 = SSDA_classifier(configs, hparams).to(device)
        weights_init(self.classifier2)

        # Optimizer for classifier2
        self.optimizer_f2 = torch.optim.SGD(
            self.classifier2.parameters(),
            lr=hparams['lr_f'],
            momentum=0.9,
            weight_decay=hparams['weight_decay'],
            nesterov=True
        )

        # Per-group base multipliers consumed by inv_lr_scheduler
        self.param_lr_f2 = [pg['lr'] for pg in self.optimizer_f2.param_groups]

        # Loss functions and parameters
        self.criterion = nn.CrossEntropyLoss().to(device)
        self.uniSSDA_loss = univ_ssda_loss()
        self.hparams = hparams
        self.device = device

    def zero_grad_all(self):
        """Zero gradients for all optimizers"""
        self.optimizer_g.zero_grad()
        self.optimizer_f.zero_grad()
        self.optimizer_f2.zero_grad()

    def forward(self, src_data, trg_data, step):
        """Compute the UniSSDA loss for one source/target batch pair.

        Args:
            src_data: 5-tuple ``(x, x_bar, x_bar2, y, is_labeled)`` for source.
            trg_data: 5-tuple with the same layout for target. ``x_bar`` is
                presumably an augmented view - confirm with the dataset code.
            step: Global iteration index forwarded to the loss.

        Returns:
            ``(src_feats, trg_feats, total_loss, losses)`` where ``total_loss``
            is the sum of the first and fourth value returned by the loss
            function and ``losses`` holds Python floats for logging.
        """
        src_x, src_x_bar, _, src_y, src_is_labeled = src_data
        trg_x, trg_x_bar, _, trg_y, trg_is_labeled = trg_data

        # Move data to the configured device (not simply the current CUDA device)
        src_x, src_x_bar, src_y, src_is_labeled = (
            t.to(self.device, non_blocking=True)
            for t in (src_x, src_x_bar, src_y, src_is_labeled))
        trg_x, trg_x_bar, trg_y, trg_is_labeled = (
            t.to(self.device, non_blocking=True)
            for t in (trg_x, trg_x_bar, trg_y, trg_is_labeled))

        src_lab = src_is_labeled == 1
        trg_lab = trg_is_labeled == 1
        trg_unl = trg_is_labeled == 0

        # Labelled / unlabelled subsets
        im_data_src = src_x[src_lab]
        im_data_bar_src = src_x_bar[src_lab]
        im_data_trg = trg_x[trg_lab]
        im_data_bar_trg = trg_x_bar[trg_lab]
        im_data_trg_ul = trg_x[trg_unl]
        im_data_bar_trg_ul = trg_x_bar[trg_unl]
        gt_labels_src = src_y[src_lab]
        gt_labels_trg = trg_y[trg_lab]

        # Calculate loss using proposed loss function
        loss, labeled_loss, unlabeled_loss, loss2 = self.uniSSDA_loss(
            self.hparams,
            self.feature_extractor,
            self.classifier,
            self.classifier2,
            im_data_src,
            im_data_bar_src,
            im_data_trg,
            im_data_bar_trg,
            im_data_trg_ul,
            im_data_bar_trg_ul,
            gt_labels_src,
            gt_labels_trg,
            step,
            self.hparams['warm_steps'],
            self.device
        )

        total_loss = loss + loss2

        src_feats = self.feature_extractor(src_x)
        trg_feats = self.feature_extractor(trg_x)

        losses = {
            'Clf_loss': labeled_loss.item(),
            'Unlabeled_loss': unlabeled_loss.item(),
            'Clf_loss2': loss2.item(),
            'total_loss': total_loss.item()
        }

        return src_feats, trg_feats, total_loss, losses

    def update(self, src_loader, trg_loader, avg_meter, logger):
        """Training process for the algorithm

        Runs ``self.num_epochs`` epochs and snapshots the weights whenever
        ``(epoch + 1) % 5 == 0`` and the running total loss improved.

        Returns:
            ``(last_head, best_head, last_classifier, best_classifier)`` state
            dicts; the ``best_*`` entries are ``None`` if no snapshot was taken.
            The weights of ``classifier2`` are not part of the returned tuple.
        """
        best_risk = float('inf')
        best_head = None
        best_classifier = None

        train_epoch = self.num_epochs

        # Training loop
        for epoch in range(1, train_epoch + 1):
            self.training_epoch(src_loader, trg_loader, avg_meter, epoch)

            # Save the best model
            current_risk = avg_meter['total_loss'].avg
            if (epoch + 1) % 5 == 0 and current_risk < best_risk:
                best_risk = current_risk
                best_head = deepcopy(self.feature_extractor.state_dict())
                best_classifier = deepcopy(self.classifier.state_dict())

            logger.debug(f'[Epoch : {epoch}/{train_epoch}]')
            for key, val in avg_meter.items():
                logger.debug(f'{key}\t : {val.avg:2.4f}')
            logger.debug('---------------------------------------------------------')

        last_head = self.feature_extractor.state_dict()
        last_classifier = self.classifier.state_dict()

        return last_head, best_head, last_classifier, best_classifier

    def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
        """Execute one training epoch

        The target loader is cycled, so the epoch length is ``len(src_loader)``.
        """
        joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
        total_steps = len(src_loader)

        for step, (src_data, trg_data) in tqdm(joint_loader):
            global_step = (epoch - 1) * total_steps + step

            # Update learning rates using inverse scheduler
            self.optimizer_g = inv_lr_scheduler(self.param_lr_g, self.optimizer_g, global_step,
                                                init_lr=self.hparams['learning_rate'],
                                                gamma=self.hparams['gamma'])
            self.optimizer_f = inv_lr_scheduler(self.param_lr_f, self.optimizer_f, global_step,
                                                init_lr=self.hparams['learning_rate'],
                                                gamma=self.hparams['gamma'])
            self.optimizer_f2 = inv_lr_scheduler(self.param_lr_f2, self.optimizer_f2, global_step,
                                                 init_lr=self.hparams['learning_rate'],
                                                 gamma=self.hparams['gamma'])

            # Forward pass
            _, _, total_loss, losses = self.forward(src_data, trg_data, global_step)

            # Backward pass through the gradient scaler
            self.zero_grad_all()
            self.scaler.scale(total_loss).backward()
            self.scaler.step(self.optimizer_g)
            self.scaler.step(self.optimizer_f)
            self.scaler.step(self.optimizer_f2)
            self.scaler.update()

            # Update average meters
            for key, val in losses.items():
                avg_meter[key].update(val, src_data[0].size(0))


class DST(AdaMatch):
    """
    Debiased Self-Training (inline-optimized variant built on AdaMatch).

    On top of what ``AdaMatch.__init__`` sets up, this class adds a second
    classifier head (``classifier2``) with its own SGD optimizer and trains
    with the two-part loss returned by ``dst_loss``.
    """
    def __init__(self, backbone, configs, num_epochs, post_epochs, hparams, device):
        """Create the extra classifier head, its optimizer and the DST loss.

        Args:
            backbone: Backbone constructor, forwarded to ``AdaMatch``.
            configs: Dataset/model configuration, forwarded to ``AdaMatch``
                and ``SSDA_classifier``.
            num_epochs: Number of training epochs.
            post_epochs: Forwarded to ``AdaMatch``.
            hparams: Hyperparameter mapping; ``lr_f`` and ``weight_decay`` are
                read here.
            device: Forwarded to ``AdaMatch``.
        """
        super().__init__(backbone, configs, num_epochs, post_epochs, hparams, device)

        # Additional classifier and its optimizer.
        # NOTE(review): classifier2 is not moved to a device here — presumably
        # the caller moves the whole algorithm module; confirm.
        self.classifier2 = SSDA_classifier(configs, hparams)
        self.optimizer_f2 = torch.optim.SGD(self.classifier2.parameters(),
                                    lr=hparams['lr_f'], momentum=0.9,
                                    weight_decay=hparams['weight_decay'], nesterov=True)
        # Base learning rates, kept for the inverse LR scheduler.
        self.param_lr_f2 = [pg['lr'] for pg in self.optimizer_f2.param_groups]
        self.dst_loss = dst_loss()
        weights_init(self.classifier2)

    def zero_grad_all(self):
        """Zero gradients for all three optimizers."""
        self.optimizer_g.zero_grad()
        self.optimizer_f.zero_grad()
        self.optimizer_f2.zero_grad()

    @staticmethod
    def _repeat_loader(loader):
        """Yield batches from ``loader`` endlessly, restarting it when exhausted.

        Used instead of ``itertools.cycle``, which stores a copy of every
        batch it has yielded and then replays those stored batches rather
        than re-iterating the loader. An empty loader ends the generator.
        """
        while True:
            yielded = False
            for batch in loader:
                yielded = True
                yield batch
            if not yielded:
                return

    def forward(self, src_data, trg_data, step):
        """Compute the DST losses for one pair of source/target batches.

        Args:
            src_data: 5-tuple ``(x, x_bar, x_bar2, y, is_labeled)`` for the
                source batch (``x_bar*`` are presumably augmented views —
                confirm against the dataset).
            trg_data: 5-tuple with the same layout for the target batch.
            step: Global training step, forwarded to ``dst_loss``.

        Returns:
            Tuple ``(src_feats, trg_feats, total_loss, losses)`` where the
            features are the backbone outputs for the full ``x`` batches and
            ``losses`` maps log names to Python floats.
        """
        # The third view of each batch is not consumed by this loss.
        src_x, src_x_bar, _, src_y, src_is_labeled = src_data
        trg_x, trg_x_bar, _, trg_y, trg_is_labeled = trg_data

        # Move data to the GPU
        src_x, src_y, src_is_labeled = src_x.cuda(non_blocking=True), src_y.cuda(non_blocking=True), src_is_labeled.cuda(non_blocking=True)
        trg_x, trg_y, trg_is_labeled = trg_x.cuda(non_blocking=True), trg_y.cuda(non_blocking=True), trg_is_labeled.cuda(non_blocking=True)
        src_x_bar = src_x_bar.cuda(non_blocking=True)
        trg_x_bar = trg_x_bar.cuda(non_blocking=True)

        # Split batches by the labeled flag
        src_labeled = src_is_labeled == 1
        trg_labeled = trg_is_labeled == 1
        trg_unlabeled = trg_is_labeled == 0

        im_data_src = src_x[src_labeled]
        im_data_bar_src = src_x_bar[src_labeled]
        im_data_trg = trg_x[trg_labeled]
        im_data_bar_trg = trg_x_bar[trg_labeled]
        im_data_trg_ul = trg_x[trg_unlabeled]
        im_data_bar_trg_ul = trg_x_bar[trg_unlabeled]
        gt_labels_src = src_y[src_labeled]
        gt_labels_trg = trg_y[trg_labeled]

        # Calculate loss using the DST loss function. Positional order, per
        # the note left next to this call: hparams, backbone, classifier,
        # classifier2, im_data_src, im_data_bar_src, im_data_trg,
        # im_data_bar_trg, im_data_trg_ul, im_data_bar_trg_ul, gt_labels_src,
        # gt_labels_trg, step, warm_steps, device.
        loss1, loss2 = self.dst_loss(
            self.hparams,
            self.feature_extractor,
            self.classifier,
            self.classifier2,
            im_data_src,
            im_data_bar_src,
            im_data_trg,
            im_data_bar_trg,
            im_data_trg_ul,
            im_data_bar_trg_ul,
            gt_labels_src,
            gt_labels_trg,
            step,
            self.hparams['warm_steps'],
            self.device
        )

        total_loss = loss1 + loss2

        # NOTE(review): these extra passes only feed the return value and are
        # discarded by training_epoch; kept because other callers may rely on
        # the returned features.
        src_feats = self.feature_extractor(src_x)
        trg_feats = self.feature_extractor(trg_x)

        losses = {
            'Clf_loss': loss1.item(),
            'Clf_loss2': loss2.item(),
            'total_loss': total_loss.item()
        }

        return src_feats, trg_feats, total_loss, losses

    def update(self, src_loader, trg_loader, avg_meter, logger):
        """Train for ``self.num_epochs`` epochs and return final/best weights.

        Args:
            src_loader: Source-domain loader, passed to ``training_epoch``.
            trg_loader: Target-domain loader, passed to ``training_epoch``.
            avg_meter: Mapping from loss name to a meter exposing ``avg``;
                must contain ``'total_loss'``.
            logger: Logger used for per-epoch ``debug`` output.

        Returns:
            Tuple ``(last_head, best_head, last_classifier, best_classifier)``
            of state dicts. ``classifier2`` weights are not returned. If no
            best snapshot was ever taken, the best entries are copies of the
            final weights rather than ``None``.
        """
        best_risk = float('inf')
        best_head = None
        best_classifier = None

        train_epoch = self.num_epochs

        # Training loop
        for epoch in range(1, train_epoch+1):
            self.training_epoch(src_loader, trg_loader, avg_meter, epoch)

            # Snapshot the best model.
            # NOTE(review): with 1-based epochs this check fires on epochs
            # 4, 9, 14, ... — confirm whether `epoch % 5 == 0` was intended.
            if (epoch+1)%5 == 0:
                current_risk = avg_meter['total_loss'].avg
                if current_risk < best_risk:
                    best_risk = current_risk
                    best_head = deepcopy(self.feature_extractor.state_dict())
                    best_classifier = deepcopy(self.classifier.state_dict())

            logger.debug(f'[Epoch : {epoch}/{train_epoch}]')
            for key, val in avg_meter.items():
                logger.debug(f'{key}\t : {val.avg:2.4f}')
            logger.debug(f'---------------------------------------------------------')

        last_head = self.feature_extractor.state_dict()
        last_classifier = self.classifier.state_dict()

        # No snapshot was taken (run shorter than 4 epochs, or the loss never
        # compared below the initial bound, e.g. NaN): fall back to the final
        # weights so callers never receive None.
        if best_head is None:
            best_head = deepcopy(last_head)
            best_classifier = deepcopy(last_classifier)

        return last_head, best_head, last_classifier, best_classifier

    def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
        """Run one training epoch over paired source/target batches.

        One optimisation step is taken per source batch. The target loader is
        restarted whenever it runs out before the source loader does.

        Args:
            src_loader: Sized, re-iterable loader of source batches; its
                length is the number of steps in the epoch.
            trg_loader: Re-iterable loader of target batches.
            avg_meter: Mapping from loss name to a meter exposing
                ``update(value, n)``.
            epoch: 1-based epoch index, used to derive the global step.
        """
        total_steps = len(src_loader)
        batch_pairs = zip(src_loader, self._repeat_loader(trg_loader))

        # total= lets tqdm show progress; enumerate() itself has no length.
        for step, (src_data, trg_data) in tqdm(enumerate(batch_pairs), total=total_steps):
            global_step = (epoch - 1) * total_steps + step

            # Update learning rates using the inverse scheduler
            self.optimizer_g = inv_lr_scheduler(self.param_lr_g, self.optimizer_g, global_step,
                                              init_lr=self.hparams['learning_rate'], gamma=self.hparams['gamma'])
            self.optimizer_f = inv_lr_scheduler(self.param_lr_f, self.optimizer_f, global_step,
                                              init_lr=self.hparams['learning_rate'], gamma=self.hparams['gamma'])
            self.optimizer_f2 = inv_lr_scheduler(self.param_lr_f2, self.optimizer_f2, global_step,
                                               init_lr=self.hparams['learning_rate'], gamma=self.hparams['gamma'])

            # Forward pass
            _, _, total_loss, losses = self.forward(src_data, trg_data, global_step)

            # Backward pass through self.scaler: one scaled backward, one step
            # per optimizer, then a single scale update for the iteration.
            self.zero_grad_all()
            self.scaler.scale(total_loss).backward()
            self.scaler.step(self.optimizer_g)
            self.scaler.step(self.optimizer_f)
            self.scaler.step(self.optimizer_f2)
            self.scaler.update()

            # Update average meters, weighted by the source batch size
            for key, val in losses.items():
                avg_meter[key].update(val, src_data[0].size(0))


class CLDA(Algorithm):
    """
    CLDA: Contrastive Learning for Semi-Supervised Domain Adaptation (NeurIPS 2021)
    Paper: https://papers.nips.cc/paper/2021/file/288cd2567953f06e460a33951f55daaf-Paper.pdf
    """
    def __init__(self, backbone, configs, num_epochs, post_epochs, hparams, device):
        """Build the CLDA networks, optimizers, losses and centroid state.

        Args:
            backbone: Backbone constructor, called as ``backbone(configs)``.
            configs: Dataset/model configuration object.
            num_epochs: Number of training epochs.
            post_epochs: Forwarded to the base ``Algorithm``.
            hparams: Hyperparameter mapping. ``multi`` and ``weight_decay``
                are required; ``lr_f``, ``temperature`` and ``momentum`` fall
                back to 1.0, 0.5 and 0.9.
            device: Device used for the loss module and, later, for batches.
        """
        super().__init__(configs, backbone, num_epochs, post_epochs)

        # Replace the base-class modules: a fresh backbone and an
        # SSDA_classifier head with re-initialised weights.
        self.feature_extractor = backbone(configs)
        self.classifier = SSDA_classifier(configs, hparams)
        weights_init(self.classifier)

        # One SGD group per trainable backbone tensor; tensors whose name
        # contains 'classifier' get a 10x larger base learning rate.
        backbone_groups = [
            {
                'params': [tensor],
                'lr': hparams['multi'] * 10 if 'classifier' in name else hparams['multi'],
                'weight_decay': hparams['weight_decay'],
            }
            for name, tensor in self.feature_extractor.named_parameters()
            if tensor.requires_grad
        ]

        self.optimizer_g = torch.optim.SGD(
            backbone_groups,
            momentum=0.9,
            weight_decay=hparams['weight_decay'],
            nesterov=True,
        )
        self.optimizer_f = torch.optim.SGD(
            list(self.classifier.parameters()),
            lr=hparams.get('lr_f', 1.0),
            momentum=0.9,
            weight_decay=hparams['weight_decay'],
            nesterov=True,
        )

        # Base learning rates, remembered for the inverse LR scheduler.
        self.param_lr_g = [group["lr"] for group in self.optimizer_g.param_groups]
        self.param_lr_f = [group["lr"] for group in self.optimizer_f.param_groups]

        # Loss functions (imported from the loss module); all contrastive
        # terms share one temperature.
        temperature = hparams.get('temperature', 0.5)
        self.criterion = nn.CrossEntropyLoss().to(device)
        self.inter_domain_loss = CLDA_InterDomainContrastiveLoss(temperature=temperature)
        self.instance_loss = CLDA_InstanceContrastiveLoss(temperature=temperature)
        self.contrastive_loss = CLDA_ContrastiveLoss(temperature=temperature)

        # Hyperparameters
        self.hparams = hparams
        self.device = device

        # Running class centroids for the inter-domain contrastive alignment,
        # blended with this momentum factor.
        self.source_centroids = {}
        self.target_centroids = {}
        self.momentum = hparams.get('momentum', 0.9)

        # Wrap in DataParallel and move to the GPU
        self.feature_extractor = torch.nn.DataParallel(self.feature_extractor).cuda()
        self.classifier = torch.nn.DataParallel(self.classifier).cuda()

    def zero_grad_all(self):
        """Zero gradients for all optimizers"""
        self.optimizer_g.zero_grad()
        self.optimizer_f.zero_grad()

    def compute_centroids(self, features, labels):
        """Return the per-class mean feature vector of a batch.

        Args:
            features: Tensor indexed by sample along dim 0.
            labels: Tensor of class ids aligned with ``features`` along dim 0.

        Returns:
            dict mapping each class id found in ``labels`` (as a Python
            scalar) to the mean over dim 0 of that class's rows.
        """
        per_class = {}
        for class_id in torch.unique(labels).tolist():
            members = labels == class_id
            if members.any():
                per_class[class_id] = features[members].mean(dim=0)
        return per_class

    def update_centroids(self, new_centroids, old_centroids):
        """Update centroids with momentum"""
        if not old_centroids:
            return new_centroids
        
        updated_centroids = {}
        for class_id, new_centroid in new_centroids.items():
            if class_id in old_centroids:
                updated_centroids[class_id] = (self.momentum * old_centroids[class_id] + 
                                             (1 - self.momentum) * new_centroid)
            else:
                updated_centroids[class_id] = new_centroid
        
        # Keep old centroids for classes not present in new batch
        for class_id, old_centroid in old_centroids.items():
            if class_id not in updated_centroids:
                updated_centroids[class_id] = old_centroid
                
        return updated_centroids

    def forward(self, src_data, trg_data, step):
        """Forward pass through the network"""
        # Unpack data with error handling
        src_x, src_x_bar, src_x_bar2, src_y, src_is_labeled = src_data
        trg_x, trg_x_bar, trg_x_bar2, trg_y, trg_is_labeled = trg_data
        
        # Move to device
        src_x, src_y = src_x.to(self.device, non_blocking=True), src_y.to(self.device, non_blocking=True)
        trg_x, trg_x_aug = trg_x.to(self.device, non_blocking=True), trg_x_bar.to(self.device, non_blocking=True)
        trg_y, trg_is_labeled = trg_y.to(self.device, non_blocking=True), trg_is_labeled.to(self.device, non_blocking=True)
        
        # Get labeled and unlabeled target data
        trg_x_labeled = trg_x[trg_is_labeled == 1]
        trg_y_labeled = trg_y[trg_is_labeled == 1]
        trg_x_unlabeled = trg_x[trg_is_labeled == 0]
        trg_x_aug_unlabeled = trg_x_aug[trg_is_labeled == 0]
        
        # Forward pass through feature extractor
        src_features = self.feature_extractor(src_x)
        trg_labeled_features = self.feature_extractor(trg_x_labeled) if trg_x_labeled.size(0) > 0 else torch.empty(0, src_features.size(1)).to(self.device)
        trg_unlabeled_features = self.feature_extractor(trg_x_unlabeled) if trg_x_unlabeled.size(0) > 0 else torch.empty(0, src_features.size(1)).to(self.device)
        trg_aug_unlabeled_features = self.feature_extractor(trg_x_aug_unlabeled) if trg_x_aug_unlabeled.size(0) > 0 else torch.empty(0, src_features.size(1)).to(self.device)
        
        # Classifier predictions
        src_pred = self.classifier(src_features)
        trg_labeled_pred = self.classifier(trg_labeled_features) if trg_labeled_features.size(0) > 0 else None
        
        # Supervised loss
        supervised_loss = self.criterion(src_pred, src_y)
        if trg_labeled_pred is not None and trg_x_labeled.size(0) > 0:
            supervised_loss += self.criterion(trg_labeled_pred, trg_y_labeled)
        
        # Compute centroids for Inter-Domain Contrastive Alignment
        source_centroids = self.compute_centroids(src_features.detach(), src_y)
        target_centroids = {}
        if trg_labeled_features.size(0) > 0:
            target_centroids = self.compute_centroids(trg_labeled_features.detach(), trg_y_labeled)
        
        # Update centroids with momentum
        self.source_centroids = self.update_centroids(source_centroids, self.source_centroids)
        self.target_centroids = self.update_centroids(target_centroids, self.target_centroids)
        
        # Inter-Domain Contrastive Loss - loss 파일의 함수 사용
        inter_domain_loss = self.inter_domain_loss(self.source_centroids, self.target_centroids)
        
        # Instance Contrastive Loss - loss 파일의 함수 사용
        instance_loss = self.instance_loss(trg_unlabeled_features, trg_aug_unlabeled_features)
        
        # Additional contrastive loss using existing implementation
        contrastive_loss = torch.tensor(0.0).to(self.device)
        if src_features.size(0) > 0 and trg_labeled_features.size(0) > 0:
            contrastive_loss = self.contrastive_loss(src_features, trg_labeled_features, src_y, trg_y_labeled)
        
        # Total loss
        total_loss = (supervised_loss + 
                     self.hparams.get('lambda_inter', 1.0) * inter_domain_loss + 
                     self.hparams.get('lambda_instance', 1.0) * instance_loss +
                     self.hparams.get('lambda_contrastive', 0.1) * contrastive_loss)
        
        losses = {
            'Supervised_loss': supervised_loss.item(),
            'Inter_domain_loss': inter_domain_loss.item(),
            'Instance_loss': instance_loss.item(),
            'Contrastive_loss': contrastive_loss.item(),
            'total_loss': total_loss.item()
        }
        
        return src_features, trg_labeled_features, total_loss, losses

    def update(self, src_loader, trg_loader, avg_meter, logger):
        """Training process for the algorithm"""
        best_risk = float('inf')
        best_head = None
        best_classifier = None
        
        train_epoch = self.num_epochs
        
        # Training loop
        for epoch in range(1, train_epoch + 1):
            self.training_epoch(src_loader, trg_loader, avg_meter, epoch)
            
            # Save the best model
            if (epoch + 1) % 5 == 0 and avg_meter['total_loss'].avg < best_risk:
                best_risk = avg_meter['total_loss'].avg
                best_head = deepcopy(self.feature_extractor.state_dict())
                best_classifier = deepcopy(self.classifier.state_dict())
            
            logger.debug(f'[Epoch : {epoch}/{train_epoch}]')
            for key, val in avg_meter.items():
                logger.debug(f'{key}\t : {val.avg:2.4f}')
            logger.debug(f'---------------------------------------------------------')
        
        last_head = self.feature_extractor.state_dict()
        last_classifier = self.classifier.state_dict()
        
        return last_head, best_head, last_classifier, best_classifier

    def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
        """Run one training epoch over paired source/target batches.

        One optimisation step is taken per source batch. The target loader is
        restarted whenever it runs out before the source loader does.

        Args:
            src_loader: Sized, re-iterable loader of source batches; its
                length is the number of steps in the epoch.
            trg_loader: Re-iterable loader of target batches.
            avg_meter: Mapping from loss name to a meter exposing
                ``update(value, n)``.
            epoch: 1-based epoch index, used to derive the global step.
        """
        def _repeat(loader):
            # Restart the loader on exhaustion. itertools.cycle is avoided
            # because it stores a copy of every batch it has yielded and then
            # replays those stored batches instead of re-iterating the loader.
            while True:
                yielded = False
                for batch in loader:
                    yielded = True
                    yield batch
                if not yielded:
                    # Empty loader: stop instead of spinning forever.
                    return

        # Loop-invariant, so computed once rather than on every step.
        total_steps = len(src_loader)
        batch_pairs = zip(src_loader, _repeat(trg_loader))

        # total= lets tqdm show progress; enumerate() itself has no length.
        for step, (src_data, trg_data) in tqdm(enumerate(batch_pairs), total=total_steps):
            global_step = (epoch - 1) * total_steps + step

            # Update learning rates using the inverse scheduler (same
            # structure as the other algorithms in this file)
            self.optimizer_g = inv_lr_scheduler(self.param_lr_g, self.optimizer_g, global_step,
                                               init_lr=self.hparams['learning_rate'],
                                               gamma=self.hparams.get('gamma', 0.0001))
            self.optimizer_f = inv_lr_scheduler(self.param_lr_f, self.optimizer_f, global_step,
                                               init_lr=self.hparams['learning_rate'],
                                               gamma=self.hparams.get('gamma', 0.0001))

            # Forward pass
            _, _, total_loss, losses = self.forward(src_data, trg_data, global_step)

            # Backward pass
            self.zero_grad_all()
            total_loss.backward()

            # Gradient clipping for stability (each module clipped separately)
            torch.nn.utils.clip_grad_norm_(self.feature_extractor.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(self.classifier.parameters(), max_norm=1.0)

            self.optimizer_g.step()
            self.optimizer_f.step()

            # Update average meters, weighted by the source batch size
            for key, val in losses.items():
                avg_meter[key].update(val, src_data[0].size(0))
