import copy
import os
import pickle

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import average_precision_score

from base.datasets import load_data
from base.networks import OTmap_MLP, Classifier_MLP, Classifier_Linear


class Trainer:
    """Two-stage trainer for fair classification across a source and a target group.

    Stage 1 (``train_OT_bidirect``) fits two transport maps: source -> target
    ("forward") and target -> source ("backward"). Each map is trained to keep
    a weighted squared displacement small while matching the other domain
    under a multi-bandwidth Gaussian MMD.

    Stage 2 (``train_FairC``) trains two classifiers that receive the features
    plus one group-indicator column. Each classifier is penalised when its
    output on a sample differs from its output on the transported sample with
    the indicator flipped.

    Call ``_load_data`` before either stage. ``train_FairC`` also relies on
    ``output_dir``, ``otmap_hidden_dims`` and ``last_act``, which are set by
    ``train_OT_bidirect``.
    """

    def __init__(self):
        # Limit intra-op CPU thread usage for the whole process.
        torch.set_num_threads(4)

    # ------------------------------------------------------------------
    # handling data
    # ------------------------------------------------------------------

    def _load_data(self, source, target, batch_size, seed):
        """Load loaders and metadata for both domains via ``load_data``.

        For each of source and target, this sets the all/train/val/test loaders
        and the matching categorical-column id sets. It also sets
        ``self.input_dim`` and ``self.all_weights``, a numpy array of weights
        used in the OT cost. ``train_OT_bidirect`` converts it to a CUDA tensor.
        """
        (
            (self.allloader_source, self.trainloader_source, self.valloader_source, self.testloader_source),
            (self.allloader_target, self.trainloader_target, self.valloader_target, self.testloader_target),
            (self.all_source_categorical_id_set, self.train_source_categorical_id_set,
             self.val_source_categorical_id_set, self.test_source_categorical_id_set),
            (self.all_target_categorical_id_set, self.train_target_categorical_id_set,
             self.val_target_categorical_id_set, self.test_target_categorical_id_set),
            (self.input_dim, self.all_weights),
        ) = load_data(source, target, batch_size, seed)

    # ------------------------------------------------------------------
    # creating directories
    # ------------------------------------------------------------------

    def _create_dir(self, args):
        """Create (if needed) and remember the checkpoint directory for this source/target pair."""
        self.output_dir = f'checkpoints/source-{args.source}_target-{args.target}/alg-ftm/'
        os.makedirs(self.output_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # building networks
    # ------------------------------------------------------------------

    def _build_networks(self, act, bn, dropout):
        """Build one OT map on the GPU.

        The map goes from ``input_dim`` to ``input_dim``. Its hidden sizes and
        final activation come from ``self.otmap_hidden_dims`` and
        ``self.last_act``, which must already be set.
        """
        otmap = OTmap_MLP(input_dim=self.input_dim,
                          hidden_dims=self.otmap_hidden_dims,
                          output_dim=self.input_dim,
                          act=act,
                          bn=bn,
                          dropout=dropout,
                          last_act=self.last_act).cuda()
        return otmap

    def _rounding_categorical(self, reps, categorical_id_set):
        """Snap the categorical columns of ``reps`` to discrete values, in place.

        Every id group with more than one column is treated as a one-hot
        block. The column(s) holding the row maximum become 1 and the rest
        become 0. The 1e-7 offset pushes the maximum just above 0.5 so it
        rounds up, and ties can yield several 1s.

        Single-column groups are rounded to the nearest integer.

        ``torch.round`` has zero gradient, so these columns send no gradient
        back through the map.

        Returns ``reps`` itself, modified in place.
        """
        for categorical_ids in categorical_id_set:
            if len(categorical_ids) > 1:
                block = reps[:, categorical_ids]
                shifted = block - block.max(dim=1, keepdim=True).values + 0.5 + 1e-7
                # clamp keeps non-max entries at 0 even when they trail the max by
                # more than 0.5 (plain rounding would produce negative integers).
                reps[:, categorical_ids] = torch.round(shifted).clamp(min=0.0)
            elif len(categorical_ids) == 1:
                reps[:, categorical_ids] = torch.round(reps[:, categorical_ids])
        return reps

    # ------------------------------------------------------------------
    # internal computing functions
    # ------------------------------------------------------------------

    def _compute_ot_loss(self, x, y, weight):
        """Return the weighted squared transport cost, averaged over the batch.

        Per-sample squared differences are flattened to (B, -1) and multiplied
        by ``weight``, which is broadcast against them. Presumably there is one
        weight per feature; confirm against ``load_data``. The products are
        summed per sample and averaged over samples.
        """
        diff = ((x - y) ** 2).view(x.size(0), -1)
        return (diff * weight).sum(dim=1).mean()

    def _compute_ipm_loss(self, x, y, sigmas=None):
        """Return the distribution-matching loss between ``x`` and ``y``.

        This is the multi-bandwidth MMD from ``_compute_MMD``.

        Raises:
            ValueError: if ``sigmas`` is not given.
        """
        if sigmas is None:
            raise ValueError('sigmas must be provided')
        return self._compute_MMD(x, y, sigmas)

    def _compute_fair_loss(self, preds0, preds1, surrogate=None, tau=None):
        """Return a penalty on the gap between column 1 of ``preds0`` and ``preds1``.

        Surrogates:
            None: mean of the per-sample squared differences (pairs rows).
            'linear': squared difference of the column means.
            'hinge': squared difference of the means of relu(s + 1).
            'slide': squared difference of the means of the ramp
                clamp(s, 0, tau) / tau. Requires ``tau``.

        Raises:
            ValueError: if ``surrogate == 'slide'`` and ``tau`` is None.
            NotImplementedError: for any other surrogate name.
        """
        s0, s1 = preds0[:, 1], preds1[:, 1]
        if surrogate is None:
            return ((s0 - s1) ** 2).mean()
        if surrogate == 'linear':
            return (s0.mean() - s1.mean()) ** 2
        if surrogate == 'hinge':
            return (torch.relu(s0 + 1.0).mean() - torch.relu(s1 + 1.0).mean()) ** 2
        if surrogate == 'slide':
            if tau is None:
                raise ValueError("surrogate='slide' requires tau")
            ramp0 = (torch.relu(s0) / tau) - (torch.relu(s0 - tau) / tau)
            ramp1 = (torch.relu(s1) / tau) - (torch.relu(s1 - tau) / tau)
            return (ramp0.mean() - ramp1.mean()) ** 2
        raise NotImplementedError

    def _Gaussian_kernel_matrix(self, Xi, Xj, sigma=1.0):
        """Return the Gaussian kernel matrix exp(-||xi - xj||^2 / (2 sigma^2)).

        Args:
            Xi (torch.tensor): (B1, d)
            Xj (torch.tensor): (B2, d)
            sigma (float): kernel bandwidth.

        Returns:
            torch.tensor: (B1, B2)
        """
        return torch.exp(-torch.cdist(Xi, Xj, p=2) ** 2 / (2.0 * sigma ** 2))

    def _compute_MMD(self, source_reps, target_reps, sigmas=(1.0,)):
        """Return the Gaussian-kernel MMD^2, summed over the bandwidths in ``sigmas``.

        Full kernel-matrix means are used, diagonal terms included. This makes
        it the biased (V-statistic) estimate.

        Args:
            source_reps (torch.tensor): (B1, d)
            target_reps (torch.tensor): (B2, d)
            sigmas (iterable of float): kernel bandwidths; must be non-empty.

        Returns:
            torch.tensor: 0-d tensor.

        Raises:
            ValueError: if ``sigmas`` is empty.
        """
        sigmas = list(sigmas)
        if not sigmas:
            raise ValueError('sigmas must be non-empty')
        mmd = 0.0
        for sigma in sigmas:
            KXX = self._Gaussian_kernel_matrix(source_reps, source_reps, sigma=sigma)
            KXY = self._Gaussian_kernel_matrix(source_reps, target_reps, sigma=sigma)
            KYY = self._Gaussian_kernel_matrix(target_reps, target_reps, sigma=sigma)
            mmd = mmd + KXX.mean() - 2 * KXY.mean() + KYY.mean()
        return mmd

    def _compute_accuracy(self, preds, probs, labels):
        """Return (accuracy, balanced accuracy, average precision), each rounded to 4 decimals.

        Balanced accuracy is the mean of the per-class recalls for labels 0 and
        1. It is NaN if a class is absent.

        Args:
            preds (torch.tensor): (B, ) 0 or 1
            probs (torch.tensor): (B, ) scores used for average precision
            labels (torch.tensor): (B, ) 0 or 1
        """
        acc = (preds == labels).float().mean()
        recall_neg = (preds[labels == 0] == labels[labels == 0]).float().mean()
        recall_pos = (preds[labels == 1] == labels[labels == 1]).float().mean()
        bacc = (recall_neg + recall_pos) / 2.0
        ap = average_precision_score(labels.detach().cpu().numpy(), probs.detach().cpu().numpy())
        # float() works whether sklearn returns a numpy scalar or a Python float.
        return round(acc.item(), 4), round(bacc.item(), 4), round(float(ap), 4)

    def _compute_fairness(self, preds, probs, labels, groups):
        """Return eight group-fairness gaps, each rounded to 4 decimals.

        The return order is (dp, mdp, sqdp, sdp, eqopp, meqopp, eo, meo):
            dp / mdp: gap in mean hard prediction / mean score between groups.
            sqdp: gap in mean squared score.
            sdp: mean over thresholds 0.1, 0.2, ..., 1.0 of the gap in
                P(score > tau).
            eqopp / meqopp: gap in mean prediction / score among label-1 samples.
            eo / meo: average of the label-0 gap and the label-1 gap.

        Args:
            preds (torch.tensor): (B, ) 0 or 1
            probs (torch.tensor): (B, ) [0, 1]
            labels (torch.tensor): (B, ) 0 or 1
            groups (torch.tensor): (B, ) 0 or 1
        """
        g0, g1 = groups == 0, groups == 1
        l0, l1 = labels == 0, labels == 1

        preds0, preds1 = preds[g0], preds[g1]
        probs0, probs1 = probs[g0], probs[g1]
        preds00, preds10 = preds[g0 & l0], preds[g1 & l0]
        probs00, probs10 = probs[g0 & l0], probs[g1 & l0]
        preds01, preds11 = preds[g0 & l1], preds[g1 & l1]
        probs01, probs11 = probs[g0 & l1], probs[g1 & l1]

        dp = (preds0.mean() - preds1.mean()).abs()
        mdp = (probs0.mean() - probs1.mean()).abs()
        sqdp = ((probs0 ** 2).mean() - (probs1 ** 2).mean()).abs()
        sdp = float(np.mean([
            ((probs0 > tau).float().mean() - (probs1 > tau).float().mean()).abs().item()
            for tau in np.linspace(0.1, 1.0, 10)
        ]))
        eqopp = (preds01.mean() - preds11.mean()).abs()
        meqopp = (probs01.mean() - probs11.mean()).abs()
        eo = ((preds00.mean() - preds10.mean()).abs() + eqopp) / 2.0
        meo = ((probs00.mean() - probs10.mean()).abs() + meqopp) / 2.0

        return (round(dp.item(), 4), round(mdp.item(), 4), round(sqdp.item(), 4), round(sdp, 4),
                round(eqopp.item(), 4), round(meqopp.item(), 4), round(eo.item(), 4), round(meo.item(), 4))

    # ------------------------------------------------------------------
    # internal batch-opening functions
    # ------------------------------------------------------------------

    def _next_batch(self, iter_attr, fallback_loader):
        """Return the next batch from ``self.<iter_attr>``.

        When that iterator is exhausted, it is replaced by a fresh iterator
        over ``fallback_loader``, and the first batch from it is returned.
        """
        try:
            return next(getattr(self, iter_attr))
        except StopIteration:
            fresh = iter(fallback_loader)
            setattr(self, iter_attr, fresh)
            return next(fresh)

    @staticmethod
    def _with_group(reps, group):
        """Append a constant group-indicator column (``group``) to the 2-D tensor ``reps`` on the GPU."""
        flag = torch.full((reps.size(0), 1), float(group)).cuda()
        return torch.cat([reps, flag], dim=1)

    def _get_full_batch(self, loader, otmap=None, C=None, only_x=False):
        """Concatenate a whole loader into tensors without tracking gradients.

        If ``otmap`` is given, features are transported first. If ``C`` is also
        given, the argmax of ``C`` on the transported features is returned
        instead. ``C`` is ignored when ``otmap`` is None.

        NOTE(review): ``C`` gets the mapped features without a group-indicator
        column. The classifiers built in ``train_FairC`` expect
        ``input_dim + 1`` features, so confirm which model is meant here.

        Returns ``all_x`` if ``only_x``, else ``(all_x, all_y)``.
        """
        with torch.no_grad():
            all_x, all_y = [], []
            for images, labels in loader:
                images, labels = images.view(-1, self.input_dim).cuda(), labels.cuda()
                if otmap is None:
                    x = images
                elif C is None:
                    x = otmap(images).detach()
                else:
                    x = C(otmap(images)).detach().argmax(dim=1)
                all_x.append(x)
                all_y.append(labels)
            all_x, all_y = torch.cat(all_x), torch.cat(all_y)
        return all_x if only_x else (all_x, all_y)

    def _collect_outputs(self, C_model, classifier, loader, group):
        """Run ``classifier`` over ``loader`` with indicator ``group`` appended.

        Runs without gradients and returns ``(outputs, labels)``, each
        concatenated over the loader.
        """
        outputs, all_labels = [], []
        with torch.no_grad():
            for inputs, labels in loader:
                inputs, labels = inputs.cuda(), labels.cuda()
                if 'MLP' in C_model:
                    inputs = inputs.view(-1, self.input_dim)
                outputs.append(classifier(self._with_group(inputs, group)).detach())
                all_labels.append(labels)
        return torch.cat(outputs), torch.cat(all_labels)

    def _validate(self, C_model, classifier, valloader_source, valloader_target):
        """Evaluate ``classifier`` on the source loader (group 0) and the target loader (group 1).

        Returns:
            all_preds: argmax predictions with the true group indicator.
            all_probs: output column 1 with the true group indicator.
            all_labels: labels.
            all_groups: group ids (0 = source, 1 = target).
            all_flipped_preds: argmax predictions when each domain is fed the
                other domain's indicator.
        """
        classifier.eval()

        # Pass 1: true indicators (source -> 0, target -> 1).
        outputs0, labels0 = self._collect_outputs(C_model, classifier, valloader_source, 0)
        outputs1, labels1 = self._collect_outputs(C_model, classifier, valloader_target, 1)
        outputs = torch.cat([outputs0, outputs1])

        all_preds = outputs.argmax(dim=1).float().cuda()
        all_probs = outputs.cuda()[:, 1].flatten()
        all_labels = torch.cat([labels0, labels1]).cuda()
        all_groups = torch.cat([torch.zeros(outputs0.size(0)), torch.ones(outputs1.size(0))]).cuda()

        # Pass 2: flipped indicators (source -> 1, target -> 0).
        flipped0, _ = self._collect_outputs(C_model, classifier, valloader_source, 1)
        flipped1, _ = self._collect_outputs(C_model, classifier, valloader_target, 0)
        all_flipped_preds = torch.cat([flipped0, flipped1]).argmax(dim=1).float().cuda()

        return all_preds, all_probs, all_labels, all_groups, all_flipped_preds

    # ------------------------------------------------------------------
    # training bidirectional optimal transporter
    # ------------------------------------------------------------------

    def _ot_train_step(self, otmap, optimizer, reps, other_reps, categorical_id_set, sigmas, lmda_ipm):
        """Take one gradient step on ``otmap``.

        The loss is the OT cost plus ``lmda_ipm`` times the MMD between the
        mapped ``reps`` and ``other_reps``. Returns ``(ot_loss, ipm_loss)``.
        """
        mapped = otmap(reps)
        mapped = self._rounding_categorical(mapped.clone(), categorical_id_set)

        ot_loss = self._compute_ot_loss(reps, mapped, self.all_weights)
        ipm_loss = self._compute_ipm_loss(mapped, other_reps, sigmas=sigmas)
        loss = ot_loss + lmda_ipm * ipm_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return ot_loss, ipm_loss

    def _map_loader(self, otmap, loader, categorical_id_set):
        """Map every batch of ``loader`` through ``otmap`` without gradients.

        Categorical columns are rounded. Returns ``(reps, mapped_reps)``, each
        concatenated over the loader.
        """
        all_reps, all_mapped = [], []
        with torch.no_grad():
            for inputs, _ in loader:
                reps = inputs.cuda().view(-1, self.input_dim)
                mapped = self._rounding_categorical(otmap(reps), categorical_id_set)
                all_reps.append(reps.detach())
                all_mapped.append(mapped.detach())
        return torch.cat(all_reps), torch.cat(all_mapped)

    def train_OT_bidirect(self, args):
        """Stage 1: train the forward (source->target) and backward (target->source) OT maps.

        Each map minimises the weighted squared transport cost plus
        ``args.lmda_ipm`` times a multi-bandwidth MMD to the other domain.

        After ``args.warmup_steps``, both maps are validated every
        ``args.val_freq`` iterations. A map's state dict is saved to
        ``self.output_dir`` whenever its validation objective
        (OT + lmda_ipm * MMD, with MMD > 0) improves, and an in-memory copy is
        kept in ``self.otmap_*_best``. Training stops at the first validation
        point that follows 20 consecutive validations with no improvement in
        either map.

        If both checkpoints already exist, the method returns right after the
        networks and optimizers are built.
        """
        torch.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        print('[Info] Creating directories')
        self._create_dir(args)

        print('[Info] Building networks')
        num_hidden_layers = {'MLP': 3, '2MLP': 2, '1MLP': 1}
        if args.OT_model not in num_hidden_layers:
            raise ValueError(f'unknown OT_model: {args.OT_model}')
        self.otmap_hidden_dims = [self.input_dim] * num_hidden_layers[args.OT_model]
        self.last_act = args.last_act

        # Gaussian-kernel bandwidths of the MMD term.
        sigmas = [0.01, 0.1, 1.0, 10.0, 100.0]

        self.all_weights = torch.from_numpy(self.all_weights).cuda()
        otmap_forward = self._build_networks(act=args.act, bn=args.bn, dropout=args.dropout)
        otmap_backward = self._build_networks(act=args.act, bn=args.bn, dropout=args.dropout)
        self.otmap_forward_best = otmap_forward
        self.otmap_backward_best = otmap_backward

        try:
            map_forward_optimizer = getattr(torch.optim, args.opt)(otmap_forward.parameters(), lr=args.lr, weight_decay=args.wd, betas=(0.5, 0.999))
            map_backward_optimizer = getattr(torch.optim, args.opt)(otmap_backward.parameters(), lr=args.lr, weight_decay=args.wd, betas=(0.5, 0.999))
        except TypeError:
            # Optimizers that take no `betas` (e.g. SGD) use momentum instead.
            map_forward_optimizer = getattr(torch.optim, args.opt)(otmap_forward.parameters(), lr=args.lr, weight_decay=args.wd, momentum=0.85)
            map_backward_optimizer = getattr(torch.optim, args.opt)(otmap_backward.parameters(), lr=args.lr, weight_decay=args.wd, momentum=0.85)

        # NOTE(review): batches are first drawn from the *all* loaders, and
        # num_iterations is sized from them, but exhausted iterators are
        # refilled from the *train* loaders. Confirm which split is intended.
        self.data_iter_source = iter(self.allloader_source)
        self.data_iter_target = iter(self.allloader_target)
        self.num_iterations = int(min(len(self.data_iter_source), len(self.data_iter_target)) * args.epochs)
        self.saved_iters = list(range(args.val_freq, self.num_iterations + args.val_freq, args.val_freq))

        train_losses = {'ot': {'forward': [], 'backward': []},
                        'ipm': {'forward': [], 'backward': []}}
        val_losses = {'ot': {'forward': [], 'backward': []},
                      'ipm': {'forward': [], 'backward': []}}

        print('[Info] STEP 1 begins')

        if (os.path.exists(self.output_dir + 'otmap_forward_best.pt')
                and os.path.exists(self.output_dir + 'otmap_backward_best.pt')):
            return

        forward_measures, backward_measures = [], []
        best_forward_measures, best_backward_measures = [], []
        best_forward_measure, best_backward_measure = 1e+10, 1e+10
        stale_validations = 0

        for it in range(1, self.num_iterations + 1):
            print(f'iter: {it}', end='\r')

            otmap_forward.train()
            otmap_backward.train()

            source_images, _ = self._next_batch('data_iter_source', self.trainloader_source)
            source_reps = source_images.cuda().view(-1, self.input_dim).detach()
            target_images, _ = self._next_batch('data_iter_target', self.trainloader_target)
            target_reps = target_images.cuda().view(-1, self.input_dim).detach()

            otmap_forward.melt()
            otmap_backward.melt()

            ot_forward_loss, D_forward_loss = self._ot_train_step(
                otmap_forward, map_forward_optimizer, source_reps, target_reps,
                self.train_source_categorical_id_set, sigmas, args.lmda_ipm)
            ot_backward_loss, D_backward_loss = self._ot_train_step(
                otmap_backward, map_backward_optimizer, target_reps, source_reps,
                self.train_target_categorical_id_set, sigmas, args.lmda_ipm)

            for kind, direction, value in (('ot', 'forward', ot_forward_loss),
                                           ('ot', 'backward', ot_backward_loss),
                                           ('ipm', 'forward', D_forward_loss),
                                           ('ipm', 'backward', D_backward_loss)):
                # History is stored rounded to 5 decimals.
                train_losses[kind][direction].append(float('{:4.5f}'.format(value.item())))

            if it <= args.warmup_steps or it % args.val_freq != 0:
                continue

            print(f'[STEP 1] Validation [{it}/{self.num_iterations}]')
            stale_validations += 1
            if stale_validations > 20:
                return

            otmap_forward.network.eval()
            otmap_backward.network.eval()
            all_source_reps, all_mapped_source_reps = self._map_loader(
                otmap_forward, self.valloader_source, self.val_source_categorical_id_set)
            all_target_reps, all_mapped_target_reps = self._map_loader(
                otmap_backward, self.valloader_target, self.val_target_categorical_id_set)

            with torch.no_grad():
                ot_val_forward_loss = self._compute_ot_loss(all_source_reps, all_mapped_source_reps, self.all_weights).item()
                ot_val_backward_loss = self._compute_ot_loss(all_target_reps, all_mapped_target_reps, self.all_weights).item()
                mmd_val_forward_loss = self._compute_ipm_loss(all_mapped_source_reps, all_target_reps, sigmas=sigmas).item()
                mmd_val_backward_loss = self._compute_ipm_loss(all_mapped_target_reps, all_source_reps, sigmas=sigmas).item()

            val_losses['ot']['forward'].append(float(ot_val_forward_loss))
            val_losses['ot']['backward'].append(float(ot_val_backward_loss))
            val_losses['ipm']['forward'].append(float(mmd_val_forward_loss))
            val_losses['ipm']['backward'].append(float(mmd_val_backward_loss))

            current_forward_measure = ot_val_forward_loss + args.lmda_ipm * mmd_val_forward_loss
            current_backward_measure = ot_val_backward_loss + args.lmda_ipm * mmd_val_backward_loss
            # An MMD of exactly 0 is treated as degenerate and never counted as best.
            is_best = {
                'forward': (current_forward_measure < best_forward_measure) and (mmd_val_forward_loss > 0.0),
                'backward': (current_backward_measure < best_backward_measure) and (mmd_val_backward_loss > 0.0),
            }
            forward_measures.append(current_forward_measure)
            backward_measures.append(current_backward_measure)

            if is_best['forward']:
                best_forward_measure = current_forward_measure
                best_forward_measures.append(best_forward_measure)
                print(f'         GOT BEST forward | T loss {round(ot_val_forward_loss, 4)} MMD {round(mmd_val_forward_loss, 4)}')
                # Snapshot: the live network keeps training after this point.
                self.otmap_forward_best = copy.deepcopy(otmap_forward)
                torch.save(otmap_forward.network.state_dict(), self.output_dir + 'otmap_forward_best.pt')

            if is_best['backward']:
                best_backward_measure = current_backward_measure
                best_backward_measures.append(best_backward_measure)
                print(f'         GOT BEST backward | T loss {round(ot_val_backward_loss, 4)} MMD {round(mmd_val_backward_loss, 4)}')
                self.otmap_backward_best = copy.deepcopy(otmap_backward)
                torch.save(otmap_backward.network.state_dict(), self.output_dir + 'otmap_backward_best.pt')

            if is_best['forward'] or is_best['backward']:
                stale_validations = 0

    # ------------------------------------------------------------------
    # training fair classifier
    # ------------------------------------------------------------------

    def _load_otmap(self, args, filename):
        """Rebuild an OT map, load its state dict from ``self.model_dir + filename``, and move it to the GPU in eval mode."""
        otmap = OTmap_MLP(input_dim=self.input_dim,
                          hidden_dims=self.otmap_hidden_dims,
                          output_dim=self.input_dim,
                          act=args.act,
                          bn=args.bn,
                          dropout=args.dropout,
                          last_act=self.last_act)
        otmap.eval()
        otmap.network.load_state_dict(torch.load(self.model_dir + filename))
        otmap.cuda()
        return otmap

    def _build_classifier(self, C_model):
        """Build one CUDA classifier for ``C_model``.

        The classifier takes ``input_dim + 1`` features (features plus the
        group flag) and has 2 outputs.

        Raises:
            ValueError: for an unknown ``C_model``.
        """
        in_dim = self.input_dim + 1
        if C_model == 'MLP':
            model = Classifier_MLP(in_dim, hidden_dims=[self.input_dim, self.input_dim], output_dim=2, act='ReLU')
        elif C_model == '1MLP':
            model = Classifier_MLP(in_dim, hidden_dims=[self.input_dim], output_dim=2, act='ReLU')
        elif 'Linear' in C_model:
            model = Classifier_Linear(in_dim, output_dim=2)
        else:
            raise ValueError(f'unknown C_model: {C_model}')
        return model.cuda()

    @staticmethod
    def _init_bests():
        """Return a fresh best-metrics record for 'val' and 'test'.

        Accuracy metrics start at 0 and fairness gaps start at 1.
        """
        def split():
            return {'acc': 0.0, 'bacc': 0.0, 'ap': 0.0,
                    'dp': 1.0, 'mdp': 1.0, 'sqdp': 1.0, 'sdp': 1.0,
                    'eqopp': 1.0, 'meqopp': 1.0, 'eo': 1.0, 'meo': 1.0}
        return {'val': split(), 'test': split()}

    def _evaluate(self, C_model, classifier, loader_source, loader_target):
        """Return the accuracy and fairness metrics of ``classifier`` on the two loaders as a dict."""
        preds, probs, labels, groups, _ = self._validate(C_model, classifier, loader_source, loader_target)
        acc, bacc, ap = self._compute_accuracy(preds, probs, labels)
        dp, mdp, sqdp, sdp, eqopp, meqopp, eo, meo = self._compute_fairness(preds, probs, labels, groups)
        return {'acc': acc, 'bacc': bacc, 'ap': ap,
                'dp': dp, 'mdp': mdp, 'sqdp': sqdp, 'sdp': sdp,
                'eqopp': eqopp, 'meqopp': meqopp, 'eo': eo, 'meo': meo}

    def _update_bests(self, args, classifier, bests, i):
        """Validate ``classifier`` and, on improvement, record its val and test metrics in ``bests``.

        The selection score is val accuracy when ``args.lmda_f == 0``, and
        ``acc - mdp`` otherwise. Results whose dp is exactly 0 or 1 are
        rejected as trivial.
        """
        val = self._evaluate(args.C_model, classifier, self.valloader_source, self.valloader_target)
        if args.lmda_f == 0.0:
            better = val['acc'] >= bests['val']['acc']
        else:
            better = (val['acc'] - val['mdp']) >= (bests['val']['acc'] - bests['val']['mdp'])
        non_trivial = (val['dp'] != 0.0) and (val['dp'] != 1.0)
        if not (better and non_trivial):
            return
        print(f'Validation [{i}/{args.C_epochs}] Got BEST!')
        bests['val'].update(val)
        bests['test'].update(self._evaluate(args.C_model, classifier, self.testloader_source, self.testloader_target))

    def _fair_classifier_step(self, classifier, optimizer, scheduler, batches, otmap, anchor_group, lmda_f):
        """Take one optimisation step of a group-aware classifier with a transport-based fairness penalty.

        ``batches`` is ``((source_reps, source_labels), (target_reps, target_labels))``.
        The loss has two parts:
            * cross-entropy on both domains (source flag 0, target flag 1);
            * ``lmda_f`` times the fair loss between the outputs on the
              ``anchor_group`` domain and the outputs on its ``otmap``-transported
              version, fed with the flipped flag. The map receives no gradient.

        Returns ``(loss, ce_loss, fair_loss)``.
        """
        (source_reps, source_labels), (target_reps, target_labels) = batches

        classifier.network.train()
        source_preds = classifier(self._with_group(source_reps, 0))
        target_preds = classifier(self._with_group(target_reps, 1))
        ce_loss = F.cross_entropy(source_preds, source_labels) + F.cross_entropy(target_preds, target_labels)

        if anchor_group == 0:
            anchor_reps, anchor_preds = source_reps, source_preds
        else:
            anchor_reps, anchor_preds = target_reps, target_preds
        otmap.network.eval()
        with torch.no_grad():
            mapped_reps = otmap(anchor_reps)
        mapped_preds = classifier(self._with_group(mapped_reps, 1 - anchor_group))

        fair_loss = lmda_f * self._compute_fair_loss(anchor_preds, mapped_preds)
        loss = ce_loss + fair_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        return loss, ce_loss, fair_loss

    def train_FairC(self, args):
        """Stage 2: train two fair classifiers using the stage-1 OT maps.

        ``C_0`` is penalised for gaps between its outputs on source samples
        (flag 0) and on their forward-mapped versions (flag 1). ``C_1`` does the
        same for target samples (flag 1) and their backward-mapped versions
        (flag 0). The penalty weight is ``args.lmda_f``.

        After step 100, each classifier is validated every ``args.C_val_freq``
        steps. The val/test metrics of its best validation are pickled to
        ``best_results_0.pickle`` / ``best_results_1.pickle`` in
        ``self.result_dir``.

        Requires ``otmap_forward_best.pt`` and ``otmap_backward_best.pt`` in
        ``self.output_dir``, plus the attributes set by ``train_OT_bidirect``.
        """
        torch.manual_seed(args.C_seed)

        self.result_dir = self.output_dir + f'lmda_f-{args.lmda_f}/C_seed-{args.C_seed}/'
        os.makedirs(self.result_dir, exist_ok=True)

        # Load the best OT maps saved in stage 1 (forward first, then backward).
        self.model_dir = self.output_dir
        self.otmap_forward_best = self._load_otmap(args, 'otmap_forward_best.pt')
        self.otmap_backward_best = self._load_otmap(args, 'otmap_backward_best.pt')
        print('[Info] load saved OT maps!')

        C_0 = self._build_classifier(args.C_model)
        C_1 = self._build_classifier(args.C_model)

        try:
            optimizer_0 = getattr(torch.optim, args.opt)(C_0.network.parameters(), lr=args.C_lr, weight_decay=args.wd, betas=(0.5, 0.999))
            optimizer_1 = getattr(torch.optim, args.opt)(C_1.network.parameters(), lr=args.C_lr, weight_decay=args.wd, betas=(0.5, 0.999))
        except TypeError:
            # Optimizers that take no `betas` (e.g. SGD) get only a learning rate.
            optimizer_0 = getattr(torch.optim, args.opt)(C_0.network.parameters(), lr=args.C_lr)
            optimizer_1 = getattr(torch.optim, args.opt)(C_1.network.parameters(), lr=args.C_lr)
        # Halve the learning rate every 500 steps.
        scheduler_0 = torch.optim.lr_scheduler.StepLR(optimizer_0, step_size=500, gamma=0.5)
        scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=500, gamma=0.5)

        bests_0 = self._init_bests()
        bests_1 = self._init_bests()

        self.data_iter_source = iter(self.trainloader_source)
        self.data_iter_target = iter(self.trainloader_target)

        for i in range(1, args.C_epochs + 1):
            source_images, source_labels = self._next_batch('data_iter_source', self.trainloader_source)
            source_reps = source_images.cuda().view(-1, self.input_dim)
            source_labels = source_labels.long().cuda()

            target_images, target_labels = self._next_batch('data_iter_target', self.trainloader_target)
            target_reps = target_images.cuda().view(-1, self.input_dim)
            target_labels = target_labels.long().cuda()

            batches = ((source_reps, source_labels), (target_reps, target_labels))
            self._fair_classifier_step(C_0, optimizer_0, scheduler_0, batches,
                                       self.otmap_forward_best, 0, args.lmda_f)
            loss, ce_loss, fair_loss = self._fair_classifier_step(C_1, optimizer_1, scheduler_1, batches,
                                                                  self.otmap_backward_best, 1, args.lmda_f)

            print(f'[STEP 2] lmda_f {args.lmda_f} [{i}/{args.C_epochs}] Loss: {round(loss.item(), 4)} | CE Loss: {round(ce_loss.item(), 4)} | Fair Loss: {round(fair_loss.item(), 4)}', end='\r')

            if i > 100 and i % args.C_val_freq == 0:
                C_0.network.eval()
                self._update_bests(args, C_0, bests_0, i)
                C_1.network.eval()
                self._update_bests(args, C_1, bests_1, i)

        for idx, bests in enumerate((bests_0, bests_1)):
            path = self.result_dir + f'best_results_{idx}.pickle'
            with open(path, 'wb') as f:
                pickle.dump(bests, f)
                print(f'[STEP 2] results saved at {path}')