from collections import Counter
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


import os
import torch.optim as optim
from sklearn.cluster import KMeans

from data.util import get_dataset, IdxDataset, get_opendataset
from module.loss import *
from module.util import get_model, get_backbone, get_pretrained, remove_fc

from util import *
import shutil
import pandas as pd
import warnings

from PIL import Image


from sklearn.mixture import GaussianMixture
warnings.filterwarnings(action='ignore')
import copy

class Learner(object):
    def __init__(self, args):
        """Set up data, models, losses and optimizers for one debiasing experiment.

        Args:
            args: experiment configuration namespace. This constructor reads
                ``dataset``, ``model``, ``resnet_pretrained``, ``exp``, ``log_dir``,
                ``device``, ``data_dir``, ``percent``, ``use_orbis``, ``open_type``
                (only when ``use_orbis`` is set), ``num_workers``, ``lr``,
                ``weight_decay`` and ``ema_alpha``.

        Raises:
            KeyError: if ``args.dataset`` is not a key of the per-dataset tables below.
        """
        self.args = args

        # Per-dataset settings. Only 'cmnist' defers the architecture choice to args.model.
        data2model = {
            'cmnist': args.model,
            'cifar10c': "ResNet18",
            'bar': "ResNet18",
            'bffhq': "ResNet18",
            'dogs_and_cats': "ResNet18",
        }
        data2batch_size = {
            'cmnist': 256,
            'cifar10c': 256,
            'bar': 64,
            'bffhq': 64,
            'dogs_and_cats': 64,
        }
        data2preprocess = {
            'cmnist': True,
            'cifar10c': True,
            'bar': True,
            'bffhq': True,
            'dogs_and_cats': True,
        }
        data2numclasses = {
            'cmnist': 10,
            'cifar10c': 10,
            'bar': 6,
            'bffhq': 2,
            'dogs_and_cats': 2,
        }

        self.model = data2model[args.dataset]
        self.pretrain = args.resnet_pretrained
        self.batch_size = data2batch_size[args.dataset]

        # Logging directories: build the path first, then create it.
        # (os.makedirs returns None, so its result must not be stored.)
        self.log_dir = os.path.join(args.log_dir, args.dataset, args.exp)
        os.makedirs(self.log_dir, exist_ok=True)
        self.result_dir = os.path.join(self.log_dir, "result")
        os.makedirs(self.result_dir, exist_ok=True)

        self.device = torch.device(args.device)
        self.logger = logger(self.log_dir, args)

        self.logger(f'model: {self.model} || dataset: {args.dataset}')
        self.logger(f'working with experiment: {args.exp}...')

        use_preprocess = data2preprocess[args.dataset]

        def load_split(dataset_split, transform_split):
            # Every split shares the dataset name, root, bias percentage and preprocessing flag.
            return get_dataset(
                args.dataset,
                data_dir=args.data_dir,
                dataset_split=dataset_split,
                transform_split=transform_split,
                percent=args.percent,
                use_preprocess=use_preprocess,
            )

        self.train_dataset = load_split("train", "train")
        # Separate instance of the training split, used by the bias-model pretraining stage.
        self.pretrain_dataset = load_split("train", "train")
        self.valid_dataset = load_split("valid", "valid")
        # The test split is read with the evaluation ("valid") transforms.
        self.test_dataset = load_split("test", "valid")

        self.num_classes = data2numclasses[args.dataset]
        self.train_num_classes = self.num_classes

        if args.use_orbis:
            self.open_dataset = get_opendataset(
                args.dataset,
                data_dir=args.data_dir,
                open_type=args.open_type,
                num_classes=self.num_classes,
            )
            self.logger(f"Num class: {self.num_classes} // Num data: {len(self.train_dataset)} // Open data: {len(self.open_dataset)}")
        else:
            self.logger(f"Num class: {self.num_classes} // Num data: {len(self.train_dataset)}")

        # Wrap the training sets in IdxDataset; batches from these loaders are
        # unpacked index-first elsewhere in this class (index, data, attr, ...).
        self.train_dataset = IdxDataset(self.train_dataset)
        self.pretrain_dataset = IdxDataset(self.pretrain_dataset)

        def make_loader(dataset, shuffle, drop_last=False):
            # Shared loader settings; only shuffling and last-batch handling vary.
            return DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=shuffle,
                num_workers=args.num_workers,
                pin_memory=True,
                drop_last=drop_last,
            )

        self.train_loader = make_loader(self.train_dataset, shuffle=True, drop_last=True)
        self.pretrain_train_loader = make_loader(self.pretrain_dataset, shuffle=True, drop_last=True)
        # Ordered, complete pass over the pretraining set.
        self.pretrain_valid_loader = make_loader(self.pretrain_dataset, shuffle=False)
        # NOTE(review): the evaluation loaders shuffle. Metrics here are order-independent
        # (results are scattered by index or summed), but shuffling still consumes RNG state.
        self.valid_loader = make_loader(self.valid_dataset, shuffle=True)
        self.test_loader = make_loader(self.test_dataset, shuffle=True)
        if args.use_orbis:
            self.open_loader = make_loader(self.open_dataset, shuffle=True, drop_last=True)

        # define loss
        self.criterion = nn.CrossEntropyLoss(reduction='none')
        self.logger(f'self.criterion: {self.criterion}')

        self.bias_criterion = GeneralizedCELoss(q=0.7)
        self.logger(f'self.bias_criterion: {self.bias_criterion}')

        self.cont_criterion = SupConLoss()
        self.logger(f'self.cont_criterion: {self.cont_criterion}')

        # define model and optimizer (biased model "b", debiased model "d")
        self.model_b = get_model(self.model, self.num_classes, self.pretrain).to(self.device)
        self.model_d = get_model(self.model, self.num_classes, self.pretrain).to(self.device)

        self.optimizer_b = torch.optim.Adam(
            self.model_b.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )
        self.optimizer_d = torch.optim.Adam(
            self.model_d.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )

        self.best_valid_acc_b, self.best_test_acc_b = 0., 0.
        self.best_valid_acc_d, self.best_test_acc_d = 0., 0.

        if self.args.use_orbis:
            self.orbis(args)

        train_target_attr = self.train_loader.dataset.get_labels()

        self.sample_loss_ema_b = EMA(torch.LongTensor(train_target_attr), num_classes=self.num_classes, alpha=args.ema_alpha)
        self.sample_loss_ema_d = EMA(torch.LongTensor(train_target_attr), num_classes=self.num_classes, alpha=args.ema_alpha)

        self.logger('finished model initialization....')

    # Evaluate per type 
    def evaluate_type(self, model, data_loader, loader_type):
        """Evaluate ``model`` on ``data_loader`` and split metrics by bias alignment.

        A sample is "aligned" when its target label (``attr[:, 0]``) equals its
        bias label (``attr[:, 1]``), and "conflicting" otherwise.

        Args:
            model: classifier mapping a batch of inputs to logits.
            data_loader: loader whose batches are ``(index, data, attr, ...)`` when
                ``loader_type == 'train'`` and ``(data, attr, index, ...)`` otherwise.
                ``index`` must address positions in ``data_loader.dataset``.
            loader_type: ``'train'`` selects the index-first batch layout.

        Returns:
            Tuple ``(all_acc, algn_acc, algn_bac, algn_loss, conf_acc, conf_bac,
            conf_loss)`` of 0-d tensors. ``*_bac`` is the fraction of predictions
            that match the bias label, and ``*_loss`` is the mean of ``self.criterion``.
            A subgroup with no samples yields NaN for its metrics.

        The model is put in eval mode for the pass and always returned to train
        mode afterwards, even if evaluation raises.
        """
        model.eval()

        num_samples = len(data_loader.dataset)
        all_correct = torch.zeros(num_samples)
        all_bcorrect = torch.zeros(num_samples)
        all_loss = torch.zeros(num_samples)
        all_label = torch.zeros(num_samples)
        all_blabel = torch.zeros(num_samples)

        try:
            for batch in tqdm(data_loader, leave=False):
                if loader_type == 'train':
                    index, data, attr = batch[0], batch[1], batch[2]
                else:
                    data, attr, index = batch[0], batch[1], batch[2]

                label = attr[:, 0].to(self.device)
                blabel = attr[:, 1]  # stays on CPU; compared against pred.cpu()
                data = data.to(self.device)

                with torch.no_grad():
                    logit = model(data)
                    pred = logit.max(dim=1)[1]
                    correct = (pred == label).float().cpu()
                    bcorrect = (pred.cpu() == blabel).float()
                    loss = self.criterion(logit, label).cpu()

                # Scatter per-sample results by dataset index, so loader order is irrelevant.
                all_correct[index] = correct
                all_bcorrect[index] = bcorrect
                all_loss[index] = loss
                all_label[index] = label.float().cpu()
                all_blabel[index] = blabel.float()
        finally:
            # Restore training mode even if evaluation raised.
            model.train()

        algn_pos = torch.where(all_label == all_blabel)[0]
        conf_pos = torch.where(all_label != all_blabel)[0]

        algn_acc = torch.mean(all_correct[algn_pos])
        conf_acc = torch.mean(all_correct[conf_pos])

        algn_bac = torch.mean(all_bcorrect[algn_pos])
        conf_bac = torch.mean(all_bcorrect[conf_pos])

        algn_loss = torch.mean(all_loss[algn_pos])
        conf_loss = torch.mean(all_loss[conf_pos])

        all_acc = torch.mean(all_correct)

        return all_acc, algn_acc, algn_bac, algn_loss, conf_acc, conf_bac, conf_loss


    

    # evaluation code for vanilla
    def evaluate(self, model, data_loader, evaltype = 'test'):
        """Return the overall top-1 accuracy of ``model`` on ``data_loader``.

        With ``evaltype == 'test'`` batches are read as ``(data, attr, index, ...)``;
        any other value reads them as ``(index, data, attr, ...)``. The target label
        is ``attr[:, 0]``. The model is switched to eval mode for the pass and back
        to train mode before returning.
        """
        model.eval()

        # Positions of the input tensor and the attribute tensor inside each batch.
        data_pos, attr_pos = (0, 1) if evaltype == 'test' else (1, 2)

        n_correct, n_seen = 0, 0
        for batch in tqdm(data_loader, leave=False):
            inputs = batch[data_pos].to(self.device)
            targets = batch[attr_pos][:, 0].to(self.device)

            with torch.no_grad():
                predictions = model(inputs).data.max(1, keepdim=True)[1].squeeze(1)
                hits = (predictions == targets).long()
                n_correct += hits.sum()
                n_seen += hits.shape[0]

        accuracy = n_correct / float(n_seen)
        model.train()
        return accuracy


    # Evaluate per type 
    def evaluate_disent_type(self, model_b, model_d, data_loader, loader_type, model='label'):
        """Evaluate the disentangled classifier and split metrics by bias alignment.

        Features are extracted from both networks with ``feat_ext=True`` and
        concatenated as ``(z_l, z_b)`` (``model_d`` features first). The
        concatenation is classified by ``model_b.fc`` when ``model == 'bias'`` and by
        ``model_d.fc`` otherwise. A sample is "aligned" when ``attr[:, 0]`` equals
        ``attr[:, 1]``.

        Args:
            model_b: biased network; must accept ``feat_ext=True`` and expose ``fc``.
            model_d: debiased network; same requirements as ``model_b``.
            data_loader: loader whose batches are ``(index, data, attr, ...)`` when
                ``loader_type == 'train'`` and ``(data, attr, index, ...)`` otherwise.
            loader_type: ``'train'`` selects the index-first batch layout.
            model: ``'bias'`` to score with ``model_b.fc``; anything else uses ``model_d.fc``.

        Returns:
            Tuple ``(all_acc, algn_acc, algn_bac, algn_loss, conf_acc, conf_bac,
            conf_loss)`` of 0-d tensors. A subgroup with no samples yields NaN.

        Both models are returned to train mode afterwards, even if evaluation raises.
        """
        model_b.eval()
        model_d.eval()

        head = model_b.fc if model == 'bias' else model_d.fc

        num_samples = len(data_loader.dataset)
        all_correct = torch.zeros(num_samples)
        all_bcorrect = torch.zeros(num_samples)
        all_loss = torch.zeros(num_samples)
        all_label = torch.zeros(num_samples)
        all_blabel = torch.zeros(num_samples)

        try:
            for batch in tqdm(data_loader, leave=False):
                if loader_type == 'train':
                    index, data, attr = batch[0], batch[1], batch[2]
                else:
                    data, attr, index = batch[0], batch[1], batch[2]

                label = attr[:, 0].to(self.device)
                blabel = attr[:, 1]  # stays on CPU; compared against pred.cpu()
                data = data.to(self.device)

                with torch.no_grad():
                    # Use the models passed in (not self.model_*), so features come from
                    # the same networks whose eval mode and fc head are used here.
                    z_b = model_b(data, feat_ext=True)
                    z_l = model_d(data, feat_ext=True)
                    z_origin = torch.cat((z_l, z_b), dim=1)

                    pred_label = head(z_origin)
                    pred = pred_label.max(dim=1)[1]
                    correct = (pred == label).float().cpu()
                    bcorrect = (pred.cpu() == blabel).float()
                    loss = self.criterion(pred_label, label).cpu()

                # Scatter per-sample results by dataset index, so loader order is irrelevant.
                all_correct[index] = correct
                all_bcorrect[index] = bcorrect
                all_loss[index] = loss
                all_label[index] = label.float().cpu()
                all_blabel[index] = blabel.float()
        finally:
            model_b.train()
            model_d.train()

        algn_pos = torch.where(all_label == all_blabel)[0]
        conf_pos = torch.where(all_label != all_blabel)[0]

        algn_acc = torch.mean(all_correct[algn_pos])
        conf_acc = torch.mean(all_correct[conf_pos])

        algn_bac = torch.mean(all_bcorrect[algn_pos])
        conf_bac = torch.mean(all_bcorrect[conf_pos])

        algn_loss = torch.mean(all_loss[algn_pos])
        conf_loss = torch.mean(all_loss[conf_pos])

        all_acc = torch.mean(all_correct)

        return all_acc, algn_acc, algn_bac, algn_loss, conf_acc, conf_bac, conf_loss
    
    # evaluation code for disent
    def evaluate_disent(self, model_b, model_d, data_loader, model='label'):
        """Return the overall top-1 accuracy of the disentangled classifier.

        Batches are read as ``(data, attr, ...)``, and the target label is
        ``attr[:, 0]``. Features from ``model_d`` and ``model_b`` (``feat_ext=True``)
        are concatenated in that order and classified by ``model_b.fc`` when
        ``model == 'bias'``, otherwise by ``model_d.fc``.

        Both models are returned to train mode afterwards, even if evaluation raises.
        """
        model_b.eval()
        model_d.eval()

        head = model_b.fc if model == 'bias' else model_d.fc
        total_correct, total_num = 0, 0

        try:
            for batch in tqdm(data_loader, leave=False):
                data = batch[0].to(self.device)
                label = batch[1][:, 0].to(self.device)

                with torch.no_grad():
                    # Use the models passed in (not self.model_*), so features come from
                    # the same networks whose eval mode and fc head are used here.
                    z_b = model_b(data, feat_ext=True)
                    z_l = model_d(data, feat_ext=True)
                    z_origin = torch.cat((z_l, z_b), dim=1)

                    pred = head(z_origin).max(dim=1)[1]
                    correct = (pred == label).long()
                    total_correct += correct.sum()
                    total_num += correct.shape[0]
        finally:
            model_b.train()
            model_d.train()

        return total_correct / float(total_num)

    def save_best(self, step):
        """Checkpoint the debiased and biased models together with their optimizers.

        Writes ``best_model_d.th`` and then ``best_model_b.th`` into ``self.result_dir``,
        overwriting existing files. Each file holds a dict with the keys ``'steps'``,
        ``'state_dict'`` and ``'optimizer'``.
        """
        targets = (
            ("best_model_d.th", self.model_d, self.optimizer_d),
            ("best_model_b.th", self.model_b, self.optimizer_b),
        )
        for filename, net, opt in targets:
            payload = dict(
                steps=step,
                state_dict=net.state_dict(),
                optimizer=opt.state_dict(),
            )
            destination = os.path.join(self.result_dir, filename)
            with open(destination, "wb") as handle:
                torch.save(payload, handle)
    

    def vanilla_acc(self, step, inference=None):
        """Evaluate ``self.model_d`` on train/valid/test, update bests and log a report.

        If ``inference`` is truthy, only the test accuracy is logged and the process
        exits with status 0. Otherwise ``best_valid_acc_d`` is raised when the current
        validation accuracy is not lower. When the current test accuracy is not lower
        than ``best_test_acc_d``, both models are checkpointed via ``save_best`` and
        ``best_test_acc_d`` is updated. Aligned/conflicting statistics are then logged
        per split.
        """
        split_loaders = [
            ('train', self.train_loader),
            ('valid', self.valid_loader),
            ('test', self.test_loader),
        ]
        # Evaluated in the order train -> valid -> test.
        results = {
            name: self.evaluate_type(self.model_d, loader, loader_type=name)
            for name, loader in split_loaders
        }
        valid_acc = results['valid'][0]
        test_acc = results['test'][0]

        if inference:
            self.logger(f'test acc: {test_acc.item()}')
            import sys
            sys.exit(0)

        if valid_acc >= self.best_valid_acc_d:
            self.best_valid_acc_d = valid_acc

        # NOTE(review): checkpoints are selected by test accuracy, not validation accuracy.
        if test_acc >= self.best_test_acc_d:
            self.save_best(step)
            self.best_test_acc_d = test_acc

        self.logger(f'==========={step}===========')
        self.logger(f'----------Overall----------')
        self.logger(f'valid_d: {valid_acc:.4f} || test_d: {test_acc:.4f} ')
        self.logger(f'Best Acc: {self.best_test_acc_d:.4f}')

        for name, _ in split_loaders:
            _, algn_acc, algn_bac, algn_loss, conf_acc, conf_bac, conf_loss = results[name]
            self.logger(f'----------({name}) Bias Statistics----------')
            self.logger(f'[{"Aligned":10s}] Acc: {algn_acc:.4f} || Bias Acc: {algn_bac:.4f} || Loss: {algn_loss:.4f}')
            self.logger(f'[{"Conflict":10s}] Acc: {conf_acc:.4f} || Bias Acc: {conf_bac:.4f} || Loss: {conf_loss:.4f}')

        self.logger(f'============================')



    def lff_acc(self, step, inference=None):
        """Evaluate both ``self.model_d`` and ``self.model_b`` per split and log a report.

        For each of train/valid/test, the debiased model is evaluated before the
        biased one. If ``inference`` is truthy, only the debiased test accuracy is
        logged and the process exits with status 0. Otherwise the four best-accuracy
        trackers are raised when the current value is not lower. A new best debiased
        test accuracy also checkpoints both models via ``save_best``. The report lists
        biased stats, debiased stats and the relative difficulty
        ``loss_b / (loss_b + loss_d)`` for each split.
        """
        splits = (
            ('train', self.train_loader),
            ('valid', self.valid_loader),
            ('test', self.test_loader),
        )
        stats_d, stats_b = {}, {}
        for split, loader in splits:
            stats_d[split] = self.evaluate_type(self.model_d, loader, loader_type=split)
            stats_b[split] = self.evaluate_type(self.model_b, loader, loader_type=split)

        valid_acc_b, test_acc_b = stats_b['valid'][0], stats_b['test'][0]
        valid_acc_d, test_acc_d = stats_d['valid'][0], stats_d['test'][0]

        if inference:
            self.logger(f'test acc: {test_acc_d.item()}')
            import sys
            sys.exit(0)

        if valid_acc_b >= self.best_valid_acc_b:
            self.best_valid_acc_b = valid_acc_b
        if test_acc_b >= self.best_test_acc_b:
            self.best_test_acc_b = test_acc_b
        if valid_acc_d >= self.best_valid_acc_d:
            self.best_valid_acc_d = valid_acc_d
        # NOTE(review): checkpoints are selected by test accuracy, not validation accuracy.
        if test_acc_d >= self.best_test_acc_d:
            self.best_test_acc_d = test_acc_d
            self.save_best(step)

        def stat_row(tag, acc, bias_acc, loss):
            # One "[Tag       ] Acc / Bias Acc / Loss" line.
            return f'[{tag:10s}] Acc: {acc:.4f} || Bias Acc: {bias_acc:.4f} || Loss: {loss:.4f}'

        def rel_row(tag, loss_b, loss_d):
            # Share of the biased model's loss in the combined loss.
            return f'[{tag:10s}] Relative Difficulty: {loss_b / (loss_b + loss_d):.4f}'

        self.logger(f'==========={step}===========')
        self.logger(f'----------Overall----------')
        self.logger(f'valid_b: {valid_acc_b:.4f} || test_b: {test_acc_b:.4f} ')
        self.logger(f'valid_d: {valid_acc_d:.4f} || test_d: {test_acc_d:.4f} ')
        self.logger(f'Best Acc: {self.best_test_acc_d:.4f}')

        for split, _ in splits:
            _, a_acc_b, a_bac_b, a_loss_b, c_acc_b, c_bac_b, c_loss_b = stats_b[split]
            _, a_acc_d, a_bac_d, a_loss_d, c_acc_d, c_bac_d, c_loss_d = stats_d[split]

            self.logger(f'----------({split}) Bias Statistics----------')
            self.logger(stat_row("Aligned", a_acc_b, a_bac_b, a_loss_b))
            self.logger(stat_row("Conflict", c_acc_b, c_bac_b, c_loss_b))

            self.logger(f'----------({split}) Debias Statistics----------')
            self.logger(stat_row("Aligned", a_acc_d, a_bac_d, a_loss_d))
            self.logger(stat_row("Conflict", c_acc_d, c_bac_d, c_loss_d))

            self.logger(f'----------({split}) Relative Diff. ----------')
            self.logger(rel_row("Aligned", a_loss_b, a_loss_d))
            self.logger(rel_row("Conflict", c_loss_b, c_loss_d))

        self.logger(f'============================')


    def disent_acc(self, step, inference=None):
        """Evaluate the disentangled classifier with both heads per split and log a report.

        For each of train/valid/test, ``evaluate_disent_type`` runs first with
        ``model='label'`` (reported as debiased) and then with ``model='bias'``
        (reported as biased). If ``inference`` is truthy, only the label-head test
        accuracy is logged and the process exits with status 0. Otherwise the
        best-accuracy trackers are updated (the debiased validation tracker only on
        strict improvement). A new best debiased test accuracy checkpoints both models
        via ``save_best``. Only the train split's rows include "Bias Acc".
        """
        splits = (
            ('train', self.train_loader),
            ('valid', self.valid_loader),
            ('test', self.test_loader),
        )
        stats_d, stats_b = {}, {}
        for split, loader in splits:
            stats_d[split] = self.evaluate_disent_type(self.model_b, self.model_d, loader, loader_type=split, model='label')
            stats_b[split] = self.evaluate_disent_type(self.model_b, self.model_d, loader, loader_type=split, model='bias')

        valid_acc_b, test_acc_b = stats_b['valid'][0], stats_b['test'][0]
        valid_acc_d, test_acc_d = stats_d['valid'][0], stats_d['test'][0]

        if inference:
            self.logger(f'test acc: {test_acc_d.item()}')
            import sys
            sys.exit(0)

        if valid_acc_b >= self.best_valid_acc_b:
            self.best_valid_acc_b = valid_acc_b
        if test_acc_b >= self.best_test_acc_b:
            self.best_test_acc_b = test_acc_b
        if valid_acc_d > self.best_valid_acc_d:
            self.best_valid_acc_d = valid_acc_d
        # NOTE(review): checkpoints are selected by test accuracy, not validation accuracy.
        if test_acc_d >= self.best_test_acc_d:
            self.best_test_acc_d = test_acc_d
            self.save_best(step)

        def stat_row(tag, acc, bias_acc, loss, with_bias_acc):
            # "[Tag       ] Acc ... [|| Bias Acc ...] || Loss ..."
            bias_part = f' || Bias Acc: {bias_acc:.4f}' if with_bias_acc else ''
            return f'[{tag:10s}] Acc: {acc:.4f}{bias_part} || Loss: {loss:.4f}'

        def rel_row(tag, loss_b, loss_d):
            # Share of the biased head's loss in the combined loss.
            return f'[{tag:10s}] Relative Difficulty: {loss_b / (loss_b + loss_d):.4f}'

        self.logger(f'==========={step}===========')
        self.logger(f'----------Overall----------')
        self.logger(f'valid_b: {valid_acc_b:.4f} || test_b: {test_acc_b:.4f} ')
        self.logger(f'valid_d: {valid_acc_d:.4f} || test_d: {test_acc_d:.4f} ')
        self.logger(f'Best Acc: {self.best_test_acc_d:.4f}')

        for split, _ in splits:
            show_bac = split == 'train'
            _, a_acc_b, a_bac_b, a_loss_b, c_acc_b, c_bac_b, c_loss_b = stats_b[split]
            _, a_acc_d, a_bac_d, a_loss_d, c_acc_d, c_bac_d, c_loss_d = stats_d[split]

            self.logger(f'----------({split}) Bias Statistics----------')
            self.logger(stat_row("Aligned", a_acc_b, a_bac_b, a_loss_b, show_bac))
            self.logger(stat_row("Conflict", c_acc_b, c_bac_b, c_loss_b, show_bac))

            self.logger(f'----------({split}) Debias Statistics----------')
            self.logger(stat_row("Aligned", a_acc_d, a_bac_d, a_loss_d, show_bac))
            self.logger(stat_row("Conflict", c_acc_d, c_bac_d, c_loss_d, show_bac))

            self.logger(f'----------({split}) Relative Diff. ----------')
            self.logger(rel_row("Aligned", a_loss_b, a_loss_d))
            self.logger(rel_row("Conflict", c_loss_b, c_loss_d))

        self.logger(f'============================')



    def pretrain_best_acc(self, i, model_b, best_valid_acc_b, step):
        """Validate one bias-model candidate and keep a snapshot of its best state.

        Args:
            i: index of the bias model being pretrained (used in log messages only).
            model_b: model under training; it is scored on ``self.valid_loader``.
            best_valid_acc_b: best validation accuracy seen so far for this model.
            step: current training step (accepted for interface compatibility; unused).

        Returns:
            The updated best validation accuracy.

        Side effects:
            Stores ``copy.deepcopy(model_b)`` in ``self.best_model_b`` when validation
            accuracy strictly improves. It also does so when no snapshot exists yet, so
            after at least one call the snapshot read by the pretraining loop is never
            ``None`` (e.g. when the first accuracy is 0).
        """
        valid_accs_b = self.evaluate(model_b, self.valid_loader)

        self.logger(f'best: {best_valid_acc_b:.4f}, curr: {valid_accs_b:.4f}')

        improved = valid_accs_b > best_valid_acc_b
        if improved:
            best_valid_acc_b = valid_accs_b

        if improved or getattr(self, 'best_model_b', None) is None:
            ######### copy parameters #########
            self.best_model_b = copy.deepcopy(model_b)
            self.logger(f'early model {i}th saved...')

        return best_valid_acc_b
 

    def concat_dummy(self, z):
        """Build a forward hook that records features and zero-pads them along dim 1.

        The returned hook appends ``output.squeeze()`` to the list ``z``. It then
        returns ``output`` concatenated with an all-zeros tensor of the same shape
        along dimension 1, which replaces the module's output.
        """
        def _record_and_pad(_module, _inputs, feats):
            z.append(feats.squeeze())
            padding = torch.zeros_like(feats)
            return torch.cat([feats, padding], dim=1)

        return _record_and_pad


    def pretrain_b_ensemble_best(self, args):
        """Train an ensemble of bias models and return indices they agree on.

        For each of ``args.num_bias_models`` members, a fresh bias model is built
        with ``get_backbone`` and trained for ``args.biased_model_train_iter`` steps
        on ``self.pretrain_train_loader`` with ``self.bias_criterion``. The snapshot
        with the best validation accuracy (see ``pretrain_best_acc``) is then scored
        on ``self.pretrain_valid_loader``. A sample "exceeds" for a member when that
        member's softmax probability on the ground-truth label is above
        ``args.biased_model_softmax_threshold``.

        Returns:
            A 1-D int64 CPU tensor of dataset indices that exceed the threshold for
            at least ``args.agreement`` members.
        """
        train_iter = iter(self.pretrain_train_loader)
        train_num = len(self.pretrain_dataset.dataset)
        threshold = args.biased_model_softmax_threshold
        num_models = self.args.num_bias_models
        epoch, cnt = 0, 0
        gt_prob_per_model = []
        index_list = label_list = align_flag = None

        for i in range(num_models):
            best_valid_acc_b = 0
            self.logger(f'{i}th model working ...')
            del self.model_b
            self.best_model_b = None
            self.model_b = get_backbone(self.model, self.train_num_classes, args=self.args, pretrained=self.args.resnet_pretrained, first_stage=True).to(self.device)
            self.optimizer_b = torch.optim.Adam(self.model_b.parameters(), lr=args.lr, weight_decay=args.weight_decay)

            for step in tqdm(range(self.args.biased_model_train_iter)):
                try:
                    index, data, attr, _, _ = next(train_iter)
                except StopIteration:
                    # The iterator is shared across members; restart it when exhausted.
                    train_iter = iter(self.pretrain_train_loader)
                    index, data, attr, _, _ = next(train_iter)

                data = data.to(self.device)
                label = attr.to(self.device)[:, args.target_attr_idx]

                loss = self.bias_criterion(self.model_b(data), label).mean()

                self.optimizer_b.zero_grad()
                loss.backward()
                self.optimizer_b.step()

                cnt += len(index)
                if cnt >= train_num:
                    self.logger(f'finished epoch: {epoch}')
                    epoch += 1
                    cnt = len(index)

                if step % args.valid_freq == 0:
                    best_valid_acc_b = self.pretrain_best_acc(i, self.model_b, best_valid_acc_b, step)

            if self.best_model_b is None:
                # pretrain_best_acc only snapshots on a strict improvement over 0;
                # fall back to the final weights instead of crashing on None.
                self.logger(f'{i}th model never improved on validation; using its final weights')
                self.best_model_b = copy.deepcopy(self.model_b)

            index_list, label_list, gt_prob, align_flag = self._bias_member_scores(self.best_model_b, args)
            gt_prob_per_model.append(gt_prob)

        # (num_models, N) 0/1 matrix: member k is confident on sample n's true label.
        # Every member's scores are sorted by dataset index, so columns line up.
        exceed = (torch.stack(gt_prob_per_model) > threshold).long()
        mask_sum = exceed.sum(dim=0)
        mask_sum_align = (exceed * align_flag.long()).sum(dim=0)
        mask_sum_conflict = (exceed * (~align_flag).long()).sum(dim=0)

        agree = mask_sum >= self.args.agreement
        agree_align = mask_sum_align >= self.args.agreement
        agree_conflict = mask_sum_conflict >= self.args.agreement

        total_exceed_mask = index_list[agree]
        total_exceed_align = index_list[agree_align]
        total_exceed_conflict = index_list[agree_conflict]

        log_dict = {
            "total_exceed_align": len(total_exceed_align),
            "total_exceed_conflict": len(total_exceed_conflict),
            "total_exceed_mask": len(total_exceed_mask),
        }
        for key, value in log_dict.items():
            self.logger(f"* {key}: {value}")
        # Ground-truth labels come straight from the (index-sorted) label tensor,
        # so no assumption is needed that dataset indices equal positions.
        self.logger(f"* EXCEED DATA COUNT: {Counter(label_list[agree].tolist())}")
        self.logger(f"* EXCEED DATA (ALIGN) COUNT: {Counter(label_list[agree_align].tolist())}")
        self.logger(f"* EXCEED DATA (CONFLICT) COUNT: {Counter(label_list[agree_conflict].tolist())}")
        return total_exceed_mask

    def _bias_member_scores(self, model, args):
        """Score one trained bias model on ``self.pretrain_valid_loader``.

        Returns:
            ``(index, label, gt_prob, align_flag)`` CPU tensors (int64, int64,
            float32, bool), sorted by dataset index so that scores from different
            ensemble members line up element-wise whatever the loader order.
            ``gt_prob`` is the softmax probability assigned to the ground-truth
            label; ``align_flag`` is ``label == bias_label``.
        """
        indices, labels, gt_probs, align_flags = [], [], [], []
        model.eval()
        with torch.no_grad():
            for index, data, attr, _, _ in self.pretrain_valid_loader:
                data = data.to(self.device)
                attr = attr.to(self.device)
                label = attr[:, args.target_attr_idx]
                bias_label = attr[:, args.bias_attr_idx]

                prob = torch.softmax(model(data), dim=-1)
                gt_prob = torch.gather(prob, index=label.unsqueeze(1), dim=1).squeeze(1)

                indices += index.tolist()
                labels += label.tolist()
                gt_probs += gt_prob.tolist()
                align_flags += (label == bias_label).tolist()

        sorted_index, order = torch.sort(torch.tensor(indices, dtype=torch.long))
        return (
            sorted_index,
            torch.tensor(labels, dtype=torch.long)[order],
            torch.tensor(gt_probs, dtype=torch.float)[order],
            torch.tensor(align_flags, dtype=torch.bool)[order],
        )

        
    def orbis(self, args):
        """Select relevant open-set samples, then add the training images to the open set.

        When ``args.open_use`` is set, each open-set sample's feature (from the
        backbone built by ``get_pretrained`` + ``remove_fc``) is compared by cosine
        similarity with each class centroid of the training features. For every
        class, samples whose similarity exceeds ``args.tau`` are kept, so a sample
        can be listed once per class it passes. The (num_classes x num_open)
        similarity matrix is cached as a CSV under ``./relevant_log/`` and reused
        when present. When ``args.open_use`` is not set, no open-set sample is kept.
        In both cases the training images are then appended via ``add_images``.
        """
        if args.open_use:
            print("Relevant sampling ...")
            relevant_log = f'./relevant_log/open_{args.open_type}_target_{args.dataset}_{args.percent}.csv'

            if os.path.isfile(relevant_log):
                print('Relevant log load...')
                alignment = pd.read_csv(relevant_log).to_numpy()
            else:
                print('Relevant log generate...')
                alignment = self._orbis_class_alignment(args)
                # The cache directory may not exist yet (e.g. on a fresh checkout).
                os.makedirs(os.path.dirname(relevant_log), exist_ok=True)
                pd.DataFrame(alignment).to_csv(relevant_log, index=False)

            alignment = torch.tensor(alignment)
            indices = []
            for cidx in range(self.num_classes):
                indices.extend(torch.where(alignment[cidx, :] > args.tau)[0].tolist())

            print(f'Openset sampling ... {len(indices)}')
            self.open_loader.dataset.sampling(indices)
        else:
            self.open_loader.dataset.sampling([])
        train_data = self.train_loader.dataset.dataset.data
        self.open_loader.dataset.add_images(train_data)

    def _orbis_class_alignment(self, args):
        """Return a (num_classes, num_open) numpy array of cosine similarities.

        Row ``c`` holds the similarity between the mean training feature of class
        ``c`` (labels taken from ``attr[:, args.target_attr_idx]``) and every
        open-set feature, both L2-normalized.
        """
        model_open = remove_fc(get_pretrained(self.num_classes, args))
        model_open.eval()

        train_size = len(self.train_loader.dataset)
        all_train_feat = None
        all_train_label = torch.zeros(train_size).long()
        with torch.no_grad():
            for index, data, attr, _, _ in tqdm(self.train_loader):
                # flatten(1) keeps the batch dimension even for a final batch of
                # size 1, where squeeze() would collapse it.
                feat = model_open(data.to(self.device)).detach().flatten(1).cpu()
                if all_train_feat is None:
                    all_train_feat = torch.zeros((train_size, feat.shape[1]))
                all_train_feat[index] = feat
                all_train_label[index] = attr[:, args.target_attr_idx]

        open_size = len(self.open_loader.dataset)
        all_open_feat = None
        with torch.no_grad():
            for data1, _, index in tqdm(self.open_loader):
                feat = model_open(data1.to(self.device)).detach().flatten(1).cpu()
                if all_open_feat is None:
                    all_open_feat = torch.zeros((open_size, feat.shape[1]))
                all_open_feat[index] = feat

        # Class-wise centroid of training features.
        centroid_feat = torch.zeros((self.num_classes, all_train_feat.shape[1]))
        for cidx in range(self.num_classes):
            pos = torch.where(all_train_label == cidx)[0]
            centroid_feat[cidx] = torch.mean(all_train_feat[pos], dim=0)

        # Cosine similarity = dot product of L2-normalized vectors.
        centroid_feat_norm = F.normalize(centroid_feat, dim=1).to(self.device)
        all_open_feat_norm = F.normalize(all_open_feat, dim=1).to(self.device)
        return torch.mm(centroid_feat_norm, all_open_feat_norm.T).cpu().numpy()


    def train_vanilla(self, args):
        """Train ``self.model_d`` with plain per-sample loss (plus optional ORBIS term).

        Each step draws a batch from ``self.train_loader`` (restarting the iterator
        when exhausted) and minimizes the mean of ``self.criterion``. When
        ``args.use_orbis`` is set, the contrastive loss ``self.cont_criterion`` on
        the two tensors yielded per open-set batch (presumably two augmented views),
        scaled by ``args.lbd``, is added. ``self.vanilla_acc`` runs every
        ``args.valid_freq`` steps.

        Raises:
            NameError: if a loss becomes NaN (the message names the offending loss).
        """
        self.logger('Training Vanilla ...')

        train_iter = iter(self.train_loader)
        train_num = len(self.train_loader.dataset)

        if args.use_orbis:
            open_iter = iter(self.open_loader)

        # Per-sample buffers (re)initialized on the instance; not read in this method.
        self.conflicting_index = torch.zeros(train_num, 1)
        self.label_index = torch.zeros(train_num).long().cuda()

        epoch, cnt = 0, 0

        if args.use_lr_decay:
            self.scheduler_d = optim.lr_scheduler.StepLR(self.optimizer_d, step_size=args.lr_decay_step, gamma=args.lr_gamma)

        for step in tqdm(range(args.num_steps)):
            # train main model
            try:
                index, data, attr, _, _ = next(train_iter)
            except StopIteration:
                train_iter = iter(self.train_loader)
                index, data, attr, _, _ = next(train_iter)

            data = data.to(self.device)
            attr = attr.to(self.device)
            label = attr[:, args.target_attr_idx]

            loss_d = self.criterion(self.model_d(data), label)
            if np.isnan(loss_d.mean().item()):
                raise NameError('loss_d')
            loss = loss_d.mean()

            if args.use_orbis:
                try:
                    open_data1, open_data2, _ = next(open_iter)
                except StopIteration:
                    open_iter = iter(self.open_loader)
                    open_data1, open_data2, _ = next(open_iter)

                open_data1 = open_data1.to(self.device)
                open_data2 = open_data2.to(self.device)

                # Normalized projection-head outputs, stacked along dim 1 for the
                # contrastive criterion.
                feat_d1 = F.normalize(self.model_d(open_data1, head_ext=True), dim=1)
                feat_d2 = F.normalize(self.model_d(open_data2, head_ext=True), dim=1)
                feat_d = torch.stack((feat_d1, feat_d2), dim=1)
                cont_loss = self.cont_criterion(feat_d) * args.lbd

                if np.isnan(cont_loss.mean().item()):
                    raise NameError('cont_loss')

                # Out-of-place add: avoids mutating a tensor recorded in the graph.
                loss = loss + cont_loss.mean()

            self.optimizer_d.zero_grad()
            loss.backward()
            self.optimizer_d.step()

            if args.use_lr_decay:
                self.scheduler_d.step()
                if (step + 1) % args.lr_decay_step == 0:
                    self.logger('******* learning rate decay .... ********')
                    self.logger(f"self.optimizer_d lr: {self.optimizer_d.param_groups[-1]['lr']}")

            if (step + 1) % args.valid_freq == 0:
                self.vanilla_acc(step)

            cnt += len(index)
            if cnt == train_num:
                self.logger(f'finished epoch: {epoch}')
                # One full pass over the data is one epoch (previously advanced by the batch size).
                epoch += 1
                cnt = 0

    def train_lff(self, args):
        """Train LfF (Learning from Failure) without BiasEnsemble.

        ``self.model_b`` is trained with ``self.bias_criterion`` on every sample.
        ``self.model_d`` is trained with per-sample ``self.criterion`` weighted by
        ``loss_b / (loss_b + loss_d)``, where both losses are EMA-smoothed per
        sample (``self.sample_loss_ema_*``) and divided by their class-wise maximum.
        Optionally adds the ORBIS open-set contrastive loss (``args.use_orbis``).
        ``self.lff_acc`` runs every ``args.valid_freq`` steps.

        Raises:
            NameError: if any intermediate loss or weight becomes NaN.
        """
        self.logger('Training LfF without BE ...')

        train_iter = iter(self.train_loader)
        train_num = len(self.train_loader.dataset)
        # Per-sample bookkeeping filled during training: bias-conflicting flags
        # (label != bias label) and each sample's label.
        self.conflicting_index = torch.zeros(train_num, 1)
        self.label_index = torch.zeros(train_num).long().cuda()

        epoch, cnt = 0, 0

        if args.use_orbis:
            open_iter = iter(self.open_loader)

        self.logger(f'alpha : {self.sample_loss_ema_d.alpha}')
        if args.use_lr_decay:
            self.scheduler_b = optim.lr_scheduler.StepLR(self.optimizer_b, step_size=args.lr_decay_step, gamma=args.lr_gamma)
            self.scheduler_d = optim.lr_scheduler.StepLR(self.optimizer_d, step_size=args.lr_decay_step, gamma=args.lr_gamma)

        for step in tqdm(range(args.num_steps)):
            # train main model
            try:
                index, data, attr, _, _ = next(train_iter)
            except StopIteration:
                train_iter = iter(self.train_loader)
                index, data, attr, _, _ = next(train_iter)

            data = data.to(self.device)
            attr = attr.to(self.device)
            index = index.to(self.device)
            label = attr[:, args.target_attr_idx]
            bias_label = attr[:, args.bias_attr_idx]

            self.conflicting_index[index[label != bias_label]] = 1
            self.label_index[index] = label

            logit_b = self.model_b(data)
            logit_d = self.model_d(data)

            loss_b = self.criterion(logit_b, label).cpu().detach()
            loss_d = self.criterion(logit_d, label).cpu().detach()

            if np.isnan(loss_b.mean().item()):
                raise NameError('loss_b')
            if np.isnan(loss_d.mean().item()):
                raise NameError('loss_d')

            # EMA sample loss
            self.sample_loss_ema_b.update(loss_b, index)
            self.sample_loss_ema_d.update(loss_d, index)

            loss_b = self.sample_loss_ema_b.parameter[index].clone().detach()
            loss_d = self.sample_loss_ema_d.parameter[index].clone().detach()

            if np.isnan(loss_b.mean().item()):
                raise NameError('loss_b_ema')
            if np.isnan(loss_d.mean().item()):
                raise NameError('loss_d_ema')

            # Class-wise normalization by the largest EMA loss of each class present.
            # Both divisions carry an epsilon so an all-zero class cannot yield 0/0.
            label_cpu = label.cpu()
            for c in torch.unique(label_cpu):
                class_index = np.where(label_cpu == c)[0]
                max_loss_b = self.sample_loss_ema_b.max_loss(c) + 1e-8
                max_loss_d = self.sample_loss_ema_d.max_loss(c) + 1e-8
                loss_b[class_index] /= max_loss_b
                loss_d[class_index] /= max_loss_d

            # re-weighting based on loss value / generalized CE for biased model
            loss_weight = loss_b / (loss_b + loss_d + 1e-8)
            if np.isnan(loss_weight.mean().item()):
                raise NameError('loss_weight')

            loss_b_update = self.bias_criterion(logit_b, label)
            loss_d_update = self.criterion(logit_d, label) * loss_weight.to(self.device)

            if np.isnan(loss_b_update.mean().item()):
                raise NameError('loss_b_update')
            if np.isnan(loss_d_update.mean().item()):
                raise NameError('loss_d_update')

            loss = loss_b_update.mean() + loss_d_update.mean()

            if args.use_orbis:
                try:
                    open_data1, open_data2, _ = next(open_iter)
                except StopIteration:
                    open_iter = iter(self.open_loader)
                    open_data1, open_data2, _ = next(open_iter)

                open_data1 = open_data1.to(self.device)
                open_data2 = open_data2.to(self.device)

                feat_d1 = F.normalize(self.model_d(open_data1, head_ext=True), dim=1)
                feat_d2 = F.normalize(self.model_d(open_data2, head_ext=True), dim=1)
                feat_d = torch.stack((feat_d1, feat_d2), dim=1)
                cont_loss = self.cont_criterion(feat_d) * args.lbd

                if np.isnan(cont_loss.mean().item()):
                    raise NameError('cont_loss')

                loss = loss + cont_loss.mean()

            self.optimizer_b.zero_grad()
            self.optimizer_d.zero_grad()
            loss.backward()
            self.optimizer_b.step()
            self.optimizer_d.step()

            if args.use_lr_decay:
                self.scheduler_b.step()
                self.scheduler_d.step()
                if (step + 1) % args.lr_decay_step == 0:
                    self.logger('******* learning rate decay .... ********')
                    self.logger(f"self.optimizer_b lr: {self.optimizer_b.param_groups[-1]['lr']}")
                    self.logger(f"self.optimizer_d lr: {self.optimizer_d.param_groups[-1]['lr']}")

            if (step + 1) % args.valid_freq == 0:
                self.lff_acc(step)

            cnt += len(index)
            if cnt == train_num:
                self.logger(f'finished epoch: {epoch}')
                # One full pass over the data is one epoch (previously advanced by the batch size).
                epoch += 1
                cnt = 0

    def train_disent(self, args):
        """Train DisEnt (disentangled feature augmentation) without BiasEnsemble.

        Builds fresh intrinsic (``self.model_d``) and bias (``self.model_b``) models
        with Adam optimizers. Then, for ``args.num_steps`` steps, it:
          * concatenates intrinsic features z_l with bias features z_b, detaching
            the other branch so each classifier head only trains its own encoder;
          * weights the intrinsic head's loss with EMA-based relative difficulty
            (Eq.1) and trains the bias head with ``GeneralizedCELoss(q=0.7)`` (Eq.2);
          * for steps after ``args.curr_step``, adds the same losses on z_b
            permuted within the batch (Eq.3);
          * optionally adds the ORBIS open-set contrastive loss scaled by ``args.lbd``.
        ``self.disent_acc`` runs every ``args.valid_freq`` steps.

        Raises:
            NameError: if the ORBIS contrastive loss becomes NaN.
        """
        epoch, cnt = 0, 0
        self.logger('Training Disent without BE ...')
        train_num = len(self.train_dataset)

        if args.use_orbis:
            open_iter = iter(self.open_loader)

        # self.model_d   : model for predicting intrinsic attributes ((E_i,C_i) in the main paper)
        # self.model_d.fc: fc layer for predicting intrinsic attributes (C_i in the main paper)
        # self.model_b   : model for predicting bias attributes ((E_b, C_b) in the main paper)
        # self.model_b.fc: fc layer for predicting bias attributes (C_b in the main paper)
        if 'cmnist' in args.dataset and args.model == 'MLP':
            model_name = 'mlp_DISENTANGLE'
        else:
            model_name = 'resnet_DISENTANGLE'

        self.model_b = get_model(model_name, self.num_classes, self.pretrain).to(self.device)
        self.model_d = get_model(model_name, self.num_classes, self.pretrain).to(self.device)

        self.optimizer_d = torch.optim.Adam(
            self.model_d.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )
        self.optimizer_b = torch.optim.Adam(
            self.model_b.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )

        if args.use_lr_decay:
            self.scheduler_b = optim.lr_scheduler.StepLR(self.optimizer_b, step_size=args.lr_decay_step, gamma=args.lr_gamma)
            self.scheduler_d = optim.lr_scheduler.StepLR(self.optimizer_d, step_size=args.lr_decay_step, gamma=args.lr_gamma)

        self.bias_criterion = GeneralizedCELoss(q=0.7)

        self.logger(f'criterion: {self.criterion}')
        self.logger(f'bias criterion: {self.bias_criterion}')
        train_iter = iter(self.train_loader)

        for step in tqdm(range(args.num_steps)):
            # train main model
            try:
                index, data, attr, _, _ = next(train_iter)
            except StopIteration:
                train_iter = iter(self.train_loader)
                index, data, attr, _, _ = next(train_iter)

            data = data.to(self.device)
            attr = attr.to(self.device)
            index = index.to(self.device)
            label = attr[:, args.target_attr_idx]

            # Feature extraction. A hook-based alternative that zero-pads the
            # pooled features (used only to reproduce LfF's CIFAR10C numbers) can
            # be built with self.concat_dummy on each model's avgpool layer.
            z_b = self.model_b(data, feat_ext=True)
            z_l = self.model_d(data, feat_ext=True)

            # z=[z_l, z_b]
            # Gradients of z_b are not backpropagated to z_l (and vice versa) in order to guarantee disentanglement of representation.
            z_conflict = torch.cat((z_l, z_b.detach()), dim=1)
            z_align = torch.cat((z_l.detach(), z_b), dim=1)

            # Prediction using z=[z_l, z_b]
            pred_conflict = self.model_d.fc(z_conflict)
            pred_align = self.model_b.fc(z_align)

            loss_dis_conflict = self.criterion(pred_conflict, label).detach()
            loss_dis_align = self.criterion(pred_align, label).detach()

            # EMA sample loss
            self.sample_loss_ema_d.update(loss_dis_conflict, index)
            self.sample_loss_ema_b.update(loss_dis_align, index)

            # class-wise normalize
            loss_dis_conflict = self.sample_loss_ema_d.parameter[index].clone().detach().to(self.device)
            loss_dis_align = self.sample_loss_ema_b.parameter[index].clone().detach().to(self.device)

            for c in torch.unique(label.cpu()):
                class_index = torch.where(label == c)[0].to(self.device)
                max_loss_conflict = self.sample_loss_ema_d.max_loss(c)
                max_loss_align = self.sample_loss_ema_b.max_loss(c)
                loss_dis_conflict[class_index] /= max_loss_conflict
                loss_dis_align[class_index] /= max_loss_align

            loss_weight = loss_dis_align / (loss_dis_align + loss_dis_conflict + 1e-8)                          # Eq.1 (reweighting module) in the main paper
            loss_dis_conflict = self.criterion(pred_conflict, label) * loss_weight.to(self.device) * 0.5       # Eq.2 W(z)CE(C_i(z),y)
            loss_dis_align = self.bias_criterion(pred_align, label)                                             # Eq.2 GCE(C_b(z),y)

            # feature-level augmentation : augmentation after certain iteration (after representation is disentangled at a certain level)
            if step > args.curr_step:
                indices = np.random.permutation(z_b.size(0))
                z_b_swap = z_b[indices]         # z tilde
                label_swap = label[indices]     # y tilde

                # Prediction using z_swap=[z_l, z_b tilde]
                # Again, gradients of z_b tilde are not backpropagated to z_l (and vice versa) in order to guarantee disentanglement of representation.
                z_mix_align = torch.cat((z_l.detach(), z_b_swap), dim=1)
                z_mix_conflict = torch.cat((z_l, z_b_swap.detach()), dim=1)

                pred_mix_align = self.model_b.fc(z_mix_align)
                pred_mix_conflict = self.model_d.fc(z_mix_conflict)

                loss_swap_conflict = self.criterion(pred_mix_conflict, label) * loss_weight.to(self.device)     # Eq.3 W(z)CE(C_i(z_swap),y)
                loss_swap_align = self.bias_criterion(pred_mix_align, label_swap)                               # Eq.3 GCE(C_b(z_swap),y tilde)

                lambda_swap = self.args.lambda_swap  # Eq.3 lambda_swap_b
            else:
                # before feature-level augmentation: zero placeholders on the same device as the loss
                loss_swap_conflict = torch.zeros(1, device=self.device)
                loss_swap_align = torch.zeros(1, device=self.device)
                lambda_swap = 0

            loss_dis = loss_dis_conflict.mean() + args.lambda_dis_align * loss_dis_align.mean()          # Eq.2 L_dis
            loss_swap = loss_swap_conflict.mean() + args.lambda_swap_align * loss_swap_align.mean()      # Eq.3 L_swap
            loss = loss_dis + lambda_swap * loss_swap                                                    # Eq.4 Total objective

            if args.use_orbis:
                try:
                    open_data1, open_data2, _ = next(open_iter)
                except StopIteration:
                    open_iter = iter(self.open_loader)
                    open_data1, open_data2, _ = next(open_iter)

                open_data1 = open_data1.to(self.device)
                open_data2 = open_data2.to(self.device)

                feat_d1 = F.normalize(self.model_d(open_data1, head_ext=True), dim=1)
                feat_d2 = F.normalize(self.model_d(open_data2, head_ext=True), dim=1)
                feat_d = torch.stack((feat_d1, feat_d2), dim=1)
                cont_loss = self.cont_criterion(feat_d) * args.lbd

                if np.isnan(cont_loss.mean().item()):
                    raise NameError('cont_loss')

                loss = loss + cont_loss.mean()

            self.optimizer_d.zero_grad()
            self.optimizer_b.zero_grad()
            loss.backward()
            self.optimizer_d.step()
            self.optimizer_b.step()

            if step >= args.curr_step and args.use_lr_decay:
                self.scheduler_b.step()
                self.scheduler_d.step()

            if args.use_lr_decay and (step + 1) % args.lr_decay_step == 0:
                self.logger('******* learning rate decay .... ********')
                self.logger(f"self.optimizer_b lr: { self.optimizer_b.param_groups[-1]['lr']}")
                self.logger(f"self.optimizer_d lr: { self.optimizer_d.param_groups[-1]['lr']}")

            if (step + 1) % args.valid_freq == 0:
                self.disent_acc(step)

            cnt += data.shape[0]
            if cnt == train_num:
                self.logger(f'finished epoch: {epoch}')
                epoch += 1
                cnt = 0


    def train_lff_be(self, args):
        """Train LfF with BiasEnsemble pseudo-labels.

        ``pretrain_b_ensemble_best`` first selects the dataset indices that the
        bias ensemble agrees on. A fresh ``self.model_b`` is then trained with
        ``self.criterion`` only on those indices. ``self.model_d`` is trained with
        per-sample loss weighted by ``loss_b / (loss_b + loss_d)`` from
        EMA-smoothed, class-normalized losses, as in LfF. Optionally adds the ORBIS
        open-set contrastive loss. ``self.lff_acc`` runs every ``args.valid_freq``
        steps.

        Raises:
            NameError: if any intermediate loss or weight becomes NaN.
        """
        self.logger('Training LfF with BiasEnsemble ...')

        train_iter = iter(self.train_loader)
        train_num = len(self.train_loader.dataset)

        mask_index = torch.zeros(train_num, 1)
        # Per-sample bookkeeping filled during training: bias-conflicting flags
        # (label != bias label) and each sample's label.
        self.conflicting_index = torch.zeros(train_num, 1)
        self.label_index = torch.zeros(train_num).long().cuda()

        epoch, cnt = 0, 0
        if args.use_orbis:
            open_iter = iter(self.open_loader)

        #### BiasEnsemble ####
        # Indices the ensemble agrees on; only these samples train model_b below.
        pseudo_align_flag = self.pretrain_b_ensemble_best(args)
        mask_index[pseudo_align_flag] = 1
        # Copied to the device once instead of on every step.
        mask_index_device = mask_index.to(self.device)

        del self.model_b
        self.model_b = get_backbone(self.model, self.num_classes, args=self.args, pretrained=self.args.resnet_pretrained).to(self.device)

        self.optimizer_b = torch.optim.Adam(
            self.model_b.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )

        if args.use_lr_decay:
            self.scheduler_b = optim.lr_scheduler.StepLR(self.optimizer_b, step_size=args.lr_decay_step, gamma=args.lr_gamma)
            self.scheduler_d = optim.lr_scheduler.StepLR(self.optimizer_d, step_size=args.lr_decay_step, gamma=args.lr_gamma)

        for step in tqdm(range(args.num_steps)):
            # train main model
            try:
                index, data, attr, _, _ = next(train_iter)
            except StopIteration:
                train_iter = iter(self.train_loader)
                index, data, attr, _, _ = next(train_iter)

            data = data.to(self.device)
            attr = attr.to(self.device)
            index = index.to(self.device)
            label = attr[:, args.target_attr_idx]
            bias_label = attr[:, args.bias_attr_idx]

            self.conflicting_index[index[label != bias_label]] = 1
            self.label_index[index] = label

            logit_b = self.model_b(data)
            logit_d = self.model_d(data)

            loss_b = self.criterion(logit_b, label).cpu().detach()
            loss_d = self.criterion(logit_d, label).cpu().detach()

            if np.isnan(loss_b.mean().item()):
                raise NameError('loss_b')
            if np.isnan(loss_d.mean().item()):
                raise NameError('loss_d')

            # EMA sample loss
            self.sample_loss_ema_b.update(loss_b, index)
            self.sample_loss_ema_d.update(loss_d, index)

            loss_b = self.sample_loss_ema_b.parameter[index].clone().detach()
            loss_d = self.sample_loss_ema_d.parameter[index].clone().detach()

            if np.isnan(loss_b.mean().item()):
                raise NameError('loss_b_ema')
            if np.isnan(loss_d.mean().item()):
                raise NameError('loss_d_ema')

            # Class-wise normalization by the largest EMA loss of each class present.
            # Both divisions carry an epsilon so an all-zero class cannot yield 0/0.
            label_cpu = label.cpu()
            for c in torch.unique(label_cpu):
                class_index = np.where(label_cpu == c)[0]
                max_loss_b = self.sample_loss_ema_b.max_loss(c) + 1e-8
                max_loss_d = self.sample_loss_ema_d.max_loss(c) + 1e-8
                loss_b[class_index] /= max_loss_b
                loss_d[class_index] /= max_loss_d

            # re-weighting based on loss value
            loss_weight = loss_b / (loss_b + loss_d + 1e-8)
            if np.isnan(loss_weight.mean().item()):
                raise NameError('loss_weight')

            curr_align_flag = torch.index_select(mask_index_device, 0, index).squeeze(1) == 1

            if curr_align_flag.any():
                loss_b_update = self.criterion(logit_b[curr_align_flag], label[curr_align_flag])
            else:
                # No pseudo-aligned sample in this batch: averaging an empty
                # selection would be NaN and abort training, so contribute zero.
                loss_b_update = logit_b.new_zeros(1)
            loss_d_update = self.criterion(logit_d, label) * loss_weight.to(self.device)

            if np.isnan(loss_b_update.mean().item()):
                raise NameError('loss_b_update')
            if np.isnan(loss_d_update.mean().item()):
                raise NameError('loss_d_update')

            loss = loss_b_update.mean() + loss_d_update.mean()

            if args.use_orbis:
                try:
                    open_data1, open_data2, _ = next(open_iter)
                except StopIteration:
                    open_iter = iter(self.open_loader)
                    open_data1, open_data2, _ = next(open_iter)

                open_data1 = open_data1.to(self.device)
                open_data2 = open_data2.to(self.device)

                feat_d1 = F.normalize(self.model_d(open_data1, head_ext=True), dim=1)
                feat_d2 = F.normalize(self.model_d(open_data2, head_ext=True), dim=1)
                feat_d = torch.stack((feat_d1, feat_d2), dim=1)
                cont_loss = self.cont_criterion(feat_d) * args.lbd

                if np.isnan(cont_loss.mean().item()):
                    raise NameError('cont_loss')

                loss = loss + cont_loss.mean()

            self.optimizer_b.zero_grad()
            self.optimizer_d.zero_grad()
            loss.backward()
            self.optimizer_b.step()
            self.optimizer_d.step()

            if args.use_lr_decay:
                self.scheduler_b.step()
                self.scheduler_d.step()
                if (step + 1) % args.lr_decay_step == 0:
                    self.logger('******* learning rate decay .... ********')
                    self.logger(f"self.optimizer_b lr: {self.optimizer_b.param_groups[-1]['lr']}")
                    self.logger(f"self.optimizer_d lr: {self.optimizer_d.param_groups[-1]['lr']}")

            if (step + 1) % args.valid_freq == 0:
                self.lff_acc(step)

            cnt += len(index)
            if cnt == train_num:
                self.logger(f'finished epoch: {epoch}')
                # One full pass over the data is one epoch (previously advanced by the batch size).
                epoch += 1
                cnt = 0
    

    def train_disent_be(self, args):
        """Train DisEnt (disentangled feature augmentation) with BiasEnsemble.

        ``pretrain_b_ensemble_best`` is run first to obtain a pseudo
        bias-aligned flag over the training set. Fresh ``model_b`` / ``model_d``
        are then trained jointly:

        * ``model_b`` (bias branch) gets plain CE only on samples flagged as
          pseudo bias-aligned.
        * ``model_d`` (intrinsic branch) gets CE re-weighted per sample by
          ``loss_b / (loss_b + loss_d)`` computed from class-normalized EMA
          losses (Eq.1/Eq.2), plus feature-swap augmentation once
          ``step > args.curr_step`` (Eq.3).
        * With ``args.use_orbis``, a contrastive loss on two views from
          ``self.open_loader`` (scaled by ``args.lbd``) is added for ``model_d``.

        Args:
            args: experiment namespace. Reads dataset/model selection,
                optimizer settings (lr, weight_decay, use_lr_decay,
                lr_decay_step, lr_gamma), loss weights (lambda_dis_align,
                lambda_swap, lambda_swap_align, lbd), curr_step, num_steps,
                valid_freq, target_attr_idx, bias_attr_idx and use_orbis.
        """
        self.logger('Training DisEnt with BiasEnsemble ...')

        if args.use_orbis:
            open_iter = iter(self.open_loader)

        # self.model_d   : model for predicting intrinsic attributes ((E_i,C_i) in the main paper)
        # self.model_d.fc: fc layer for predicting intrinsic attributes (C_i in the main paper)
        # self.model_b   : model for predicting bias attributes ((E_b, C_b) in the main paper)
        # self.model_b.fc: fc layer for predicting bias attributes (C_b in the main paper)

        #################
        # define models
        #################
        if 'cmnist' in args.dataset and args.model == 'MLP':
            model_name = 'mlp_DISENTANGLE'
        else:
            model_name = 'resnet_DISENTANGLE'

        self.logger(f'criterion: {self.criterion}')
        self.logger(f'bias criterion: {self.bias_criterion}')

        train_iter = iter(self.train_loader)
        train_num = len(self.train_loader.dataset)

        # Ground-truth bookkeeping: which samples are bias-conflicting, and their labels.
        self.conflicting_index = torch.zeros(train_num, 1)
        self.label_index = torch.zeros(train_num).long().to(self.device)

        # 1 marks samples that BiasEnsemble flagged as pseudo bias-aligned.
        mask_index = torch.zeros(train_num, 1)
        epoch, cnt = 0, 0

        #### BiasEnsemble ####
        pseudo_align_flag = self.pretrain_b_ensemble_best(args)

        # Discard the ensemble-pretrained bias model and start both branches fresh.
        del self.model_b
        self.model_b = get_model(model_name, self.num_classes, self.pretrain).to(self.device)
        self.model_d = get_model(model_name, self.num_classes, self.pretrain).to(self.device)

        ##################
        # define optimizer
        ##################
        self.optimizer_d = torch.optim.Adam(
            self.model_d.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )

        self.optimizer_b = torch.optim.Adam(
            self.model_b.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )

        if args.use_lr_decay:
            self.scheduler_b = optim.lr_scheduler.StepLR(self.optimizer_b, step_size=args.lr_decay_step,
                                                         gamma=args.lr_gamma)
            self.scheduler_d = optim.lr_scheduler.StepLR(self.optimizer_d, step_size=args.lr_decay_step,
                                                         gamma=args.lr_gamma)

        mask_index[pseudo_align_flag] = 1
        # The mask is constant during training: move it to the device once, not every step.
        mask_index = mask_index.to(self.device)

        for step in tqdm(range(args.num_steps)):
            # Cycle the training loader indefinitely; only exhaustion restarts it,
            # other loading errors propagate.
            try:
                index, data, attr, _, _ = next(train_iter)
            except StopIteration:
                train_iter = iter(self.train_loader)
                index, data, attr, _, _ = next(train_iter)

            data = data.to(self.device)
            attr = attr.to(self.device)
            index = index.to(self.device)
            label = attr[:, args.target_attr_idx]
            bias_label = attr[:, args.bias_attr_idx]

            flag_conflict = (label != bias_label)
            # conflicting_index lives on CPU, so index it with CPU indices.
            self.conflicting_index[index[flag_conflict].cpu()] = 1
            self.label_index[index] = label

            # Feature extraction (intrinsic z_l from model_d, bias z_b from model_b).
            z_b = self.model_b(data, feat_ext=True)
            z_l = self.model_d(data, feat_ext=True)

            # z = [z_l, z_b]. Gradients of z_b are not backpropagated to z_l (and vice
            # versa) in order to guarantee disentanglement of the representation.
            z_conflict = torch.cat((z_l, z_b.detach()), dim=1)
            z_align = torch.cat((z_l.detach(), z_b), dim=1)

            # Prediction using z = [z_l, z_b]
            pred_conflict = self.model_d.fc(z_conflict)
            pred_align = self.model_b.fc(z_align)

            loss_dis_conflict = self.criterion(pred_conflict, label).detach()
            loss_dis_align = self.criterion(pred_align, label).detach()

            # EMA sample loss
            self.sample_loss_ema_d.update(loss_dis_conflict, index)
            self.sample_loss_ema_b.update(loss_dis_align, index)

            loss_dis_conflict = self.sample_loss_ema_d.parameter[index].clone().detach().to(self.device)
            loss_dis_align = self.sample_loss_ema_b.parameter[index].clone().detach().to(self.device)

            # Class-wise normalization by each class's maximum EMA loss. The epsilon
            # (as in the LfF loop) avoids 0/0 when a class's max loss is exactly zero.
            for c in torch.unique(label.cpu()):
                class_index = torch.where(label == c)[0]
                max_loss_conflict = self.sample_loss_ema_d.max_loss(c) + 1e-8
                max_loss_align = self.sample_loss_ema_b.max_loss(c) + 1e-8
                loss_dis_conflict[class_index] /= max_loss_conflict
                loss_dis_align[class_index] /= max_loss_align

            loss_weight = loss_dis_align / (loss_dis_align + loss_dis_conflict + 1e-8)  # Eq.1 (reweighting module) in the main paper
            loss_dis_conflict = self.criterion(pred_conflict, label) * loss_weight  # Eq.2 W(z)CE(C_i(z),y)

            # The bias branch is trained only on pseudo bias-aligned samples.
            curr_align_flag = torch.index_select(mask_index, 0, index).squeeze(1) == 1
            if curr_align_flag.any():
                loss_dis_align = self.criterion(pred_align[curr_align_flag], label[curr_align_flag])
            else:
                # No pseudo-aligned sample in this batch: mean() over an empty
                # selection would be NaN, so contribute zero instead.
                loss_dis_align = pred_align.new_zeros(1)

            # feature-level augmentation: applied after a certain iteration (once the
            # representation is disentangled to some degree)
            if step > args.curr_step:
                indices = np.random.permutation(z_b.size(0))
                z_b_swap = z_b[indices]  # z tilde
                label_swap = label[indices]  # y tilde
                swap_align_flag = curr_align_flag[indices]

                # Prediction using z_swap = [z_l, z_b tilde]; again gradients do not
                # cross between the two branches.
                z_mix_conflict = torch.cat((z_l, z_b_swap.detach()), dim=1)
                z_mix_align = torch.cat((z_l.detach(), z_b_swap), dim=1)

                pred_mix_conflict = self.model_d.fc(z_mix_conflict)
                pred_mix_align = self.model_b.fc(z_mix_align)

                loss_swap_conflict = self.criterion(pred_mix_conflict, label) * loss_weight  # Eq.3 W(z)CE(C_i(z_swap),y)
                if swap_align_flag.any():
                    loss_swap_align = self.criterion(pred_mix_align[swap_align_flag], label_swap[swap_align_flag])
                else:
                    loss_swap_align = pred_mix_align.new_zeros(1)
                lambda_swap = self.args.lambda_swap  # Eq.3 lambda_swap_b
            else:
                # before feature-level augmentation
                loss_swap_conflict = torch.tensor([0]).float()
                loss_swap_align = torch.tensor([0]).float()
                lambda_swap = 0

            loss_dis = loss_dis_conflict.mean() + args.lambda_dis_align * loss_dis_align.mean()  # Eq.2 L_dis
            loss_swap = loss_swap_conflict.mean() + args.lambda_swap_align * loss_swap_align.mean()  # Eq.3 L_swap
            loss = loss_dis + lambda_swap * loss_swap  # Eq.4 Total objective

            if args.use_orbis:
                try:
                    open_data1, open_data2, _ = next(open_iter)
                except StopIteration:
                    open_iter = iter(self.open_loader)
                    open_data1, open_data2, _ = next(open_iter)

                open_data1 = open_data1.to(self.device)
                open_data2 = open_data2.to(self.device)

                # Two views of the same open samples -> (batch, 2, dim) for the contrastive loss.
                feat_d1 = F.normalize(self.model_d(open_data1, head_ext=True), dim=1)
                feat_d2 = F.normalize(self.model_d(open_data2, head_ext=True), dim=1)
                feat_d = torch.cat((feat_d1.unsqueeze(dim=1), feat_d2.unsqueeze(dim=1)), dim=1)
                cont_loss = self.cont_criterion(feat_d) * args.lbd

                if np.isnan(cont_loss.mean().item()):
                    raise NameError('cont_loss')

                loss += cont_loss.mean()

            self.optimizer_d.zero_grad()
            self.optimizer_b.zero_grad()
            loss.backward()
            self.optimizer_d.step()
            self.optimizer_b.step()

            # LR schedulers only start stepping once swap augmentation is reached.
            if step >= args.curr_step and args.use_lr_decay:
                self.scheduler_b.step()
                self.scheduler_d.step()

            if args.use_lr_decay and (step + 1) % args.lr_decay_step == 0:
                self.logger('******* learning rate decay .... ********')
                self.logger(f"self.optimizer_b lr: {self.optimizer_b.param_groups[-1]['lr']}")
                self.logger(f"self.optimizer_d lr: {self.optimizer_d.param_groups[-1]['lr']}")

            if (step + 1) % args.valid_freq == 0:
                self.disent_acc(step)

            # Epoch bookkeeping: '>=' also fires when the loader drops the last batch.
            cnt += data.shape[0]
            if cnt >= train_num:
                self.logger(f'finished epoch: {epoch}')
                epoch += 1
                cnt = 0


    def test_lff(self, args):
        """Evaluate the best LfF checkpoints saved under ``args.pretrained_path``.

        Rebuilds ``model_b`` / ``model_d`` (MLP backbones for cmnist with
        ``model == 'MLP'``, ResNet18 otherwise), loads ``best_model_d.th`` and
        ``best_model_b.th`` and runs ``lff_acc`` in inference mode.

        Checkpoints are loaded with ``map_location=self.device`` so that weights
        saved on a GPU can be evaluated on whatever device this learner uses.
        """
        if 'cmnist' in self.args.dataset and self.args.model == 'MLP':
            backbone = "MLP"
        else:
            backbone = "ResNet18"

        self.model_b = get_backbone(backbone, self.num_classes, args=self.args, pretrained=self.args.resnet_pretrained).to(self.device)
        self.model_d = get_backbone(backbone, self.num_classes, args=self.args, pretrained=self.args.resnet_pretrained).to(self.device)

        ckpt_d = torch.load(os.path.join(args.pretrained_path, 'best_model_d.th'), map_location=self.device)
        self.model_d.load_state_dict(ckpt_d['state_dict'])
        ckpt_b = torch.load(os.path.join(args.pretrained_path, 'best_model_b.th'), map_location=self.device)
        self.model_b.load_state_dict(ckpt_b['state_dict'])
        self.lff_acc(step=0, inference=True)

    def test_disent(self, args):
        """Evaluate the best DisEnt checkpoints saved under ``args.pretrained_path``.

        Rebuilds ``model_d`` / ``model_b`` as ``mlp_DISENTANGLE`` (cmnist with
        ``model == 'MLP'``) or ``resnet_DISENTANGLE`` otherwise, loads
        ``best_model_d.th`` and ``best_model_b.th`` and runs ``disent_acc`` in
        inference mode.

        Checkpoints are loaded with ``map_location=self.device`` so that weights
        saved on a GPU can be evaluated on whatever device this learner uses.
        """
        if 'cmnist' in self.args.dataset and self.args.model == 'MLP':
            model_name = 'mlp_DISENTANGLE'
        else:
            model_name = 'resnet_DISENTANGLE'

        self.model_d = get_model(model_name, self.num_classes, self.pretrain).to(self.device)
        self.model_b = get_model(model_name, self.num_classes, self.pretrain).to(self.device)

        ckpt_d = torch.load(os.path.join(args.pretrained_path, 'best_model_d.th'), map_location=self.device)
        self.model_d.load_state_dict(ckpt_d['state_dict'])
        ckpt_b = torch.load(os.path.join(args.pretrained_path, 'best_model_b.th'), map_location=self.device)
        self.model_b.load_state_dict(ckpt_b['state_dict'])
        self.disent_acc(step=0, inference=True)

    def tsne(self, iter):
        """Save a 2-D t-SNE plot of ``model_d`` features for train and open data.

        Features come from a deep copy of ``self.model_d`` with its fc removed,
        evaluated without gradients. Training points are split per class into
        bias-aligned (target label == bias label) and bias-conflicting; open-set
        points are drawn as stars. The figure is written to
        ``./tsne/tsne_{iter:04d}.png``.

        Args:
            iter: step number used in the output filename (the name shadows the
                builtin ``iter``; it is kept for caller compatibility).
        """
        from sklearn.manifold import TSNE
        import matplotlib.pyplot as plt

        tsne = TSNE(n_components=2)

        model = remove_fc(copy.deepcopy(self.model_d))
        model.eval()

        num_train = len(self.train_loader.dataset)
        num_open = len(self.open_loader.dataset)
        all_train_feat = None
        all_train_label = torch.zeros(num_train).long()
        all_train_blabel = torch.zeros(num_train).long()
        all_open_feat = None

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            for index, data, attr, _, _ in tqdm(self.train_loader):
                # flatten(1) keeps the batch dim even for a batch of size 1.
                feat = model(data.to(self.device)).flatten(1).cpu()
                if all_train_feat is None:
                    all_train_feat = torch.zeros((num_train, feat.shape[1]))
                all_train_feat[index] = feat
                all_train_label[index] = attr[:, self.args.target_attr_idx]
                all_train_blabel[index] = attr[:, self.args.bias_attr_idx]

            for batch in tqdm(self.open_loader):
                # NOTE(review): the training loops unpack open batches as 3-tuples
                # (x1, x2, idx) while this used to expect 4 items; taking elements 0
                # and 2 works for both layouts.
                data, index = batch[0], batch[2]
                feat = model(data.to(self.device)).flatten(1).cpu()
                if all_open_feat is None:
                    all_open_feat = torch.zeros((num_open, feat.shape[1]))
                all_open_feat[index] = feat

        tsne_feat = torch.cat([all_train_feat, all_open_feat], dim=0).numpy()
        tsne_result = tsne.fit_transform(tsne_feat)
        train_result = tsne_result[:num_train]
        open_result = tsne_result[num_train:]

        is_align = all_train_label == all_train_blabel
        classes = range(self.num_classes)

        for c in classes:
            idx = torch.where(is_align & (all_train_label == c))[0].numpy()
            plt.scatter(train_result[idx, 0], train_result[idx, 1], alpha=0.1, label=f'Align class {c + 1}')
        for c in classes:
            idx = torch.where(~is_align & (all_train_label == c))[0].numpy()
            plt.scatter(train_result[idx, 0], train_result[idx, 1], alpha=0.3, label=f'Conflict class {c + 1}')

        plt.scatter(open_result[:, 0], open_result[:, 1], alpha=1.0, label='Open 1', marker='*')

        plt.legend()

        os.makedirs('./tsne', exist_ok=True)
        # Zero-pad (04d) so filenames contain no spaces and sort by step.
        plt.savefig(f'./tsne/tsne_{iter:04d}.png')
        plt.close()