from collections import Counter
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import os
import torch.optim as optim

from data.util import get_dataset, IdxDataset
from module.loss import GeneralizedCELoss
from module.util import get_model
from module.util import get_backbone
from util import *

import warnings
warnings.filterwarnings(action='ignore')
import copy
import wandb

from learners.base_learner import *

class DisentShapeLearner(Learner):
    def __init__(self, args):
        """Initialise the learner and the input transform stored in ``self.trans``.

        ``args.fd`` selects the transform:

        * ``"patch-shuffle"``  -> ``ShufflePatches(patch_size=args.ps)``
        * ``"pixel-shuffle"``  -> ``ShufflePatches(patch_size=1)``
        * ``"center-occlude"`` -> ``CenterOcclude(occlusion_size=args.ps)``

        Any other value leaves ``self.trans`` undefined.
        """
        super().__init__(args)

        # Builders are lazy so ``args.ps`` is only read by the options that use it.
        transform_options = (
            ("patch-shuffle", lambda: ShufflePatches(patch_size=args.ps)),
            ("pixel-shuffle", lambda: ShufflePatches(patch_size=1)),
            ("center-occlude", lambda: CenterOcclude(occlusion_size=args.ps)),
        )
        fd_choice = args.fd
        for option_name, build_transform in transform_options:
            if fd_choice == option_name:
                self.trans = build_transform()
                break

        # label_smoothing=1 turns every CE target into the uniform distribution
        # over classes; losses are kept per sample (reduction='none').
        self.unlearn_criterion = nn.CrossEntropyLoss(reduction='none', label_smoothing=1)

    # evaluation code for disent
    def evaluate_disent(self, model_b, model_d, data_loader, model='label'):
        """Evaluate the concatenated-feature classifier, split by bias state.

        For every batch, features ``z_l`` from ``model_d`` and ``z_b`` from
        ``model_b`` are concatenated as ``[z_l, z_b]`` and classified with
        ``model_b.fc`` when ``model == 'bias'``, otherwise with ``model_d.fc``.

        Args:
            model_b: bias model; supplies the second half of the features.
            model_d: debiased model; supplies the first half of the features.
            data_loader: yields ``(data, attr, index)`` batches where
                ``attr[:, 0]`` is the target label and ``attr[:, 1]`` the bias
                attribute; ``data_loader.dataset.bias`` is forwarded to
                ``get_bias_state``.
            model: ``'bias'`` selects ``model_b.fc``; anything else ``model_d.fc``.

        Returns:
            ``(accs, losses)``: each a list of four floats — the values for the
            samples whose bias state is -1, 0 and 1 (``-0.01`` for an empty
            subgroup), followed by the value over all samples.
            ``self.criterion`` must return per-sample losses, since they are
            concatenated and masked per subgroup.

        Both models are put in eval mode for the pass and back in train mode
        afterwards.
        """
        model_b.eval()
        model_d.eval()

        use_extract = 'mnist' in self.args.dataset and self.args.model == 'MLP'

        def pooled_features(net, inputs):
            # Capture the avgpool output through a forward hook; the network's
            # own prediction is discarded. The hook is always removed.
            captured = []
            handle = net.avgpool.register_forward_hook(self.concat_dummy(captured))
            try:
                net(inputs)
            finally:
                handle.remove()
            return captured[0]

        pred_ls, loss_ls, label_ls, s_ls = [], [], [], []
        with torch.no_grad():
            for data, attr, _ in tqdm(data_loader, leave=False):
                label = attr[:, 0].to(self.device)
                s = attr[:, 1].to(self.device)
                data = data.to(self.device)

                # Use the models passed in (not self.model_*) so that both
                # branches evaluate the same networks the caller asked for.
                if use_extract:
                    z_l = model_d.extract(data)
                    z_b = model_b.extract(data)
                else:
                    z_l = pooled_features(model_d, data)
                    z_b = pooled_features(model_b, data)

                z_origin = torch.cat((z_l, z_b), dim=1)
                head = model_b.fc if model == 'bias' else model_d.fc
                pred_label = head(z_origin)

                loss_ls.append(self.criterion(pred_label, label))
                pred_ls.append(pred_label.argmax(dim=1))
                label_ls.append(label)
                s_ls.append(s)

        pred = torch.cat(pred_ls)
        loss = torch.cat(loss_ls)
        label = torch.cat(label_ls)
        s = torch.cat(s_ls)
        correct = pred == label

        # Partition samples by bias state and compute per-subgroup metrics.
        bias_state = get_bias_state(label, s, data_loader.dataset.bias).squeeze(1)

        accs, losses = [], []
        for state in (-1, 0, 1):
            in_state = bias_state == state
            count = int(in_state.sum())
            if count == 0:
                accs.append(-0.01)
                losses.append(-0.01)
                continue
            accs.append((correct[in_state].sum() / float(count)).item())
            losses.append(loss[in_state].mean().item())

        # Overall accuracy / loss across every sample.
        accs.append((correct.sum() / float(correct.shape[0])).item())
        losses.append(loss.mean().item())

        model_b.train()
        model_d.train()
        return accs, losses

    def board_disent_loss(self, step, loss_dis_conflict, loss_dis_align, loss_swap_conflict, loss_swap_align, lambda_swap):
        
        if self.args.wandb:
            wandb.log({
                "loss/loss_dis_conflict": loss_dis_conflict,
                "loss/loss_dis_align": loss_swap_conflict,
                "loss/loss_swap_conflict": loss_swap_conflict,
                "loss/loss_swap_align": loss_swap_align,
                "loss/loss": (loss_dis_conflict + loss_dis_align) + lambda_swap * (loss_swap_conflict + loss_swap_align)
                }, step=step)
        
        if self.args.tensorboard:
            self.writer.add_scalar(f"loss/loss_dis_conflict",  loss_dis_conflict, step)
            self.writer.add_scalar(f"loss/loss_dis_align",     loss_dis_align, step)
            self.writer.add_scalar(f"loss/loss_swap_conflict", loss_swap_conflict, step)
            self.writer.add_scalar(f"loss/loss_swap_align",    loss_swap_align, step)
            self.writer.add_scalar(f"loss/loss",               (loss_dis_conflict + loss_dis_align) + lambda_swap * (loss_swap_conflict + loss_swap_align), step)

    def board_weights(self, step):
        """Log EMA-loss and sample-weight statistics per bias-state subgroup.

        Subgroups are the tracked samples whose ``self.bias_state`` equals -1,
        0 and 1, followed by all tracked samples. For each subgroup the means
        of the bias / debiased EMA losses and of ``calc_weights`` are collected
        together with wandb histograms of the same values; an empty subgroup
        contributes ``-0.1`` to each mean list and ``None`` to each histogram
        list. Every list is passed through ``trans_result`` before logging.
        """
        subgroup_states = (-1, 0, 1, None)  # None = every tracked sample
        mean_weight, mean_loss_b, mean_loss_d = [], [], []
        hist_loss_b, hist_loss_d, hist_weight = [], [], []
        flat_states = self.bias_state.squeeze(1)

        for state in subgroup_states:
            if state is None:
                members = torch.ones_like(self.label_index, dtype=torch.bool).to(self.device)
            else:
                members = flat_states == state
            member_labels = self.label_index[members]

            if member_labels.shape[0] == 0:
                # Placeholders keep every list aligned with subgroup_states.
                mean_weight.append(-0.1)
                mean_loss_b.append(-0.1)
                mean_loss_d.append(-0.1)
                hist_loss_b.append(None)
                hist_loss_d.append(None)
                hist_weight.append(None)
                continue

            ema_b = self.sample_loss_ema_b.parameter[members].clone().detach()
            ema_d = self.sample_loss_ema_d.parameter[members].clone().detach()
            mean_loss_b.append(ema_b.mean())
            mean_loss_d.append(ema_d.mean())
            hist_loss_b.append(wandb.Histogram(ema_b.cpu().numpy()))
            hist_loss_d.append(wandb.Histogram(ema_d.cpu().numpy()))

            sample_weights = self.calc_weights(ema_b, ema_d, member_labels)
            mean_weight.append(sample_weights.mean())
            hist_weight.append(wandb.Histogram(sample_weights.cpu().numpy()))

        log_dict = {
            "w(x)_mean": trans_result(mean_weight),
            "loss/loss_b": trans_result(mean_loss_b),
            "loss/loss_d": trans_result(mean_loss_d),
            "hist/hist_loss_b": trans_result(hist_loss_b),
            "hist/hist_loss_d": trans_result(hist_loss_d),
            "hist/hist_weights": trans_result(hist_weight),
        }

        # NOTE(review): tensorboard's add_scalar receives the same values,
        # including the histogram entries; whether that works depends on what
        # trans_result returns — confirm.
        if self.args.tensorboard:
            for key, value in log_dict.items():
                self.writer.add_scalar(key, value, step)

        if self.args.wandb:
            wandb.log(log_dict, step=step)

    def calc_weights(self, loss_b, loss_d, label):
        if np.isnan(loss_b.mean().item()):
            raise NameError('loss_b_ema')
        if np.isnan(loss_d.mean().item()):
            raise NameError('loss_d_ema')

        label_cpu = label.cpu()

        for c in range(self.num_classes):
            class_index = np.where(label_cpu == c)[0]
            if "_max" in self.args.w_func:
                norm_loss_b = self.sample_loss_ema_b.max_loss(c) + 1e-8
                norm_loss_d = self.sample_loss_ema_d.max_loss(c)
            elif "_mean" in self.args.w_func:
                norm_loss_b = self.sample_loss_ema_b.mean_loss(c) + 1e-8
                norm_loss_d = self.sample_loss_ema_d.mean_loss(c)
            elif "_1-x" in self.args.w_func:
                x_b = torch.exp(-loss_b)
                x_d = torch.exp(-loss_d)
                loss_b = 1 - x_b
                loss_d = 1 - x_d
                break
            else:
                break
            loss_b[class_index] /= norm_loss_b
            loss_d[class_index] /= norm_loss_d
        
        scale = 2 if 's2' in self.args.w_func else 1
        
        if self.args.w_func.startswith('frac_'):
            loss_weight = scale * loss_b / (loss_b + loss_d + 1e-8)
        if self.args.w_func.startswith('frac+0.01_'):
            loss_weight = scale * loss_b / (loss_b + loss_d + 1e-8) + 0.01
        if self.args.w_func.startswith('b/(1-a+b)_'):
            loss_weight = scale * loss_b / (loss_b + 1 - loss_d + 1e-8)
        elif self.args.w_func.startswith('b_'):
            loss_weight = scale * loss_b
        elif self.args.w_func.startswith('b+0.01_'):
            loss_weight = scale * loss_b + 0.01
        elif self.args.w_func.startswith('e^(2(b-1))_'):
            pow = 2 * (loss_b - 1)
            loss_weight = scale * torch.exp(pow)
        elif self.args.w_func.startswith('log(b+1)_'):
            loss_weight = scale * torch.log(loss_b+1)
        elif self.args.w_func.startswith('e^(-x/(b+x))_'):
            loss_d *= norm_loss_d # denormalize
            x = torch.exp(-loss_d)   # recover probability
            pow = -x/(loss_b+x)
            loss_weight = scale * torch.exp(pow)
        elif self.args.w_func.startswith('e^(-a/(b+a))_'):
            pow = -loss_d/(loss_b+loss_d)
            loss_weight = scale * torch.exp(pow)
        elif self.args.w_func.startswith('e^(-1/(b+a))_'):
            pow = -1/(loss_b+loss_d)
            loss_weight = scale * torch.exp(pow)
        elif self.args.w_func.startswith('be^(-1/(b+a))_'):
            pow = -1/(loss_b+loss_d)
            loss_weight = scale * loss_b * torch.exp(pow)
        elif self.args.w_func.startswith('e^(-1/(b+a)+2(b-1))_'):
            pow = -1/(loss_b+loss_d) + 2 * (loss_b - 1)
            loss_weight = scale * torch.exp(pow)
        elif self.args.w_func.startswith('e^(-1/(b+a)+3(b-1))_'):
            pow = -1/(loss_b+loss_d) + 3 * (loss_b - 1)
            loss_weight = scale * torch.exp(pow)
        elif self.args.w_func.startswith('e^(-1/(b+a)+5(b-1))_'):
            pow = -1/(loss_b+loss_d) + 5 * (loss_b - 1)
            loss_weight = scale * torch.exp(pow)
        elif self.args.w_func.startswith('be^(-1/(b+a))+0.01_'):
            pow = -1/(loss_b+loss_d)
            loss_weight = scale * loss_b * torch.exp(pow) + 0.01
        elif self.args.w_func.startswith('be^(-1/(b+a))+0.1_'):
            pow = -1/(loss_b+loss_d)
            loss_weight = scale * loss_b * torch.exp(pow) + 0.1

        elif self.args.w_func.startswith('e^(-1/(b+0.5a))_'):
            pow = -1/(loss_b+0.5*loss_d)
            loss_weight = scale * torch.exp(pow)
        elif self.args.w_func.startswith('e^(-1/(b+0.3a))_'):
            pow = -1/(loss_b+0.3*loss_d)
            loss_weight = scale * torch.exp(pow)

        return loss_weight

    def board_pretrain_best_acc(self, i, model_b, best_valid_acc_b, step):
        # check label network
        valid_accs_b, valid_losses_b = self.evaluate(model_b, self.valid_loader)

        print(f'best: {best_valid_acc_b}, curr: {valid_accs_b}')

        if valid_accs_b[-1] > best_valid_acc_b[-1]:
            best_valid_acc_b = valid_accs_b

            ######### copy parameters #########
            self.best_model_b = copy.deepcopy(model_b)
            print(f'early model {i}th saved...')

        log_dict = {
            f"{i}_pretrain_best_valid_acc": trans_result(best_valid_acc_b),
        }

        if self.args.tensorboard:
            for key, value in log_dict.items():
                self.writer.add_scalar(key, value, step)
                
        # if self.args.wandb:
        #     wandb.log(log_dict, step=step)

        return best_valid_acc_b
 
    def board_disent_acc(self, step, inference=None):
        # check label network
        valid_accs_d, valid_losses_d = self.evaluate_disent(self.model_b, self.model_d, self.valid_loader, model='label')
        test_accs_d, test_losses_d = self.evaluate_disent(self.model_b, self.model_d, self.test_loader, model='label')
        
        valid_accs_b, valid_losses_b = self.evaluate_disent(self.model_b, self.model_d, self.valid_loader, model='bias')
        test_accs_b, test_losses_b = self.evaluate_disent(self.model_b, self.model_d, self.test_loader, model='bias')
        
        if inference:
            print(f'test acc: {test_accs_d.item()}')
            import sys
            sys.exit(0)

        if valid_accs_b[-1] >= self.best_valid_acc_b[-1]:
            self.best_valid_acc_b = valid_accs_b
            self.best_valid_loss_b = valid_losses_b

        if test_accs_b[-1] >= self.best_test_acc_b[-1]:
            self.best_test_acc_b = test_accs_b
            self.best_test_loss_b = test_losses_b

        if valid_accs_d[-1] > self.best_valid_acc_d[-1]:
            self.best_valid_acc_d = valid_accs_d
            self.best_valid_loss_d = valid_losses_d

        if test_accs_d[-1] >= self.best_test_acc_d[-1]:
            self.best_test_acc_d = test_accs_d
            self.best_test_loss_d = test_losses_d
            self.save_best(step)

        if self.args.tensorboard:
            self.writer.add_scalar(f"acc/acc_b_valid", valid_accs_b, step)
            self.writer.add_scalar(f"acc/acc_b_test", test_accs_b, step)
            self.writer.add_scalar(f"acc/acc_d_valid", valid_accs_d, step)
            self.writer.add_scalar(f"acc/acc_d_test", test_accs_d, step)
            self.writer.add_scalar(f"acc/best_acc_b_valid", self.best_valid_acc_b, step)
            self.writer.add_scalar(f"acc/best_acc_b_test", self.best_test_acc_b, step)
            self.writer.add_scalar(f"acc/best_acc_d_valid", self.best_valid_acc_d, step)
            self.writer.add_scalar(f"acc/best_acc_d_test", self.best_test_acc_d, step)
            
        if self.args.wandb:
            wandb.log({
                "acc/acc_b_valid": trans_result(valid_accs_b),
                "acc/acc_b_test": trans_result(test_accs_b),
                "acc/acc_d_valid": trans_result(valid_accs_d),
                "acc/acc_d_test": trans_result(test_accs_d),
                "acc/best_acc_b_valid": trans_result(self.best_valid_acc_b),
                "acc/best_acc_b_test": trans_result(self.best_test_acc_b),
                "acc/best_acc_d_valid": trans_result(self.best_valid_acc_d),
                "acc/best_acc_d_test": trans_result(self.best_test_acc_d),
                
                "loss/loss_b_valid": trans_result(valid_losses_b),
                "loss/loss_b_test": trans_result(test_losses_b),
                "loss/best_loss_b_valid": trans_result(self.best_valid_loss_b),
                "loss/best_loss_b_test": trans_result(self.best_test_loss_b),
                "loss/loss_d_valid": trans_result(valid_losses_d),
                "loss/loss_d_test": trans_result(test_losses_d),
                "loss/best_loss_d_valid": trans_result(self.best_valid_loss_d),
                "loss/best_loss_d_test": trans_result(self.best_test_loss_d)
            }, step=step)

        print(f'valid_b: {valid_accs_b} || test_b: {test_accs_b} ')
        print(f'valid_d: {valid_accs_d} || test_d: {test_accs_d} ')

    def concat_dummy(self, z):
        """Return a forward hook that records features and zero-pads the output.

        The hook appends the module output, flattened to ``(batch, features)``,
        to the list ``z`` and replaces the module output with
        ``cat((output, zeros_like(output)), dim=1)`` (dummy zero features); the
        resulting network prediction is not used by the callers in this class.

        Args:
            z: list that receives one feature tensor per forward pass.

        Returns:
            A callable suitable for ``register_forward_hook``.
        """
        def hook(module, inputs, output):
            # flatten(1) instead of squeeze(): squeeze() also removes the batch
            # dimension when the batch holds a single sample, which breaks the
            # later torch.cat(..., dim=1) of the two feature halves.
            z.append(output.flatten(1))
            return torch.cat((output, torch.zeros_like(output)), dim=1)
        return hook

    def pretrain_b_ensemble_best(self, args):
        """Train an ensemble of bias models and return the samples they agree on.

        For each of ``self.args.num_bias_models`` rounds, a fresh backbone is
        trained for ``self.args.biased_model_train_iter`` steps on
        ``self.train_loader`` (on ``self.trans(data)`` when
        ``args.pre == "x'"``). ``board_pretrain_best_acc`` keeps a copy of the
        checkpoint with the best overall validation accuracy. That checkpoint
        (or the final weights if none was kept) scores every sample of
        ``self.pretrain_loader``. A sample "exceeds" for that model when the
        softmax probability of its target label is above
        ``args.biased_model_softmax_threshold``.

        A sample is finally selected when at least ``self.args.agreement``
        models mark it as exceeding.

        Args:
            args: experiment configuration (both ``args`` and ``self.args`` are
                read; presumably the same object — TODO confirm).

        Returns:
            CPU LongTensor of the sample indices (as yielded by
            ``self.pretrain_loader``) of the selected samples.
        """
        train_iter = iter(self.train_loader)
        train_num = len(self.train_dataset.dataset)
        threshold = args.biased_model_softmax_threshold
        epoch, cnt = 0, 0
        label_dict, gt_prob_dict = {}, {}

        for i in range(self.args.num_bias_models):
            best_valid_acc_b = [0, 0, 0, 0]
            print(f'{i}th model working ...')
            del self.model_b
            self.best_model_b = None
            self.model_b = get_backbone(self.model, self.num_classes, args=self.args, pretrained=self.args.resnet_pretrained, first_stage=True).to(self.device)
            self.optimizer_b = torch.optim.Adam(self.model_b.parameters(), lr=args.lr, weight_decay=args.weight_decay)

            for step in tqdm(range(self.args.biased_model_train_iter)):
                try:
                    index, data, attr, _ = next(train_iter)
                except StopIteration:
                    # Loader exhausted: start another pass. Real errors propagate.
                    train_iter = iter(self.train_loader)
                    index, data, attr, _ = next(train_iter)

                data = data.to(self.device)
                attr = attr.to(self.device)
                label = attr[:, args.target_attr_idx]

                if args.pre == "x'":
                    # Train the bias model on the transformed view.
                    data = self.trans(data)

                loss = self.bias_criterion(self.model_b(data), label).mean()

                self.optimizer_b.zero_grad()
                loss.backward()
                self.optimizer_b.step()

                cnt += len(index)
                if cnt >= train_num:
                    print(f'finished epoch: {epoch}')
                    epoch += 1
                    cnt = len(index)

                if step % args.valid_freq == 0:
                    best_valid_acc_b = self.board_pretrain_best_acc(i, self.model_b, best_valid_acc_b, step)

            scorer = self.best_model_b
            if scorer is None:
                # No validation round improved on the initial [0, 0, 0, 0];
                # score with the final weights rather than crash on None.
                print(f'no improved checkpoint for bias model {i}; using its final weights')
                scorer = self.model_b

            index_list, label_list, gt_prob_list, align_flag_list = self._score_pretrain_set(scorer, args)
            label_dict[i] = label_list
            gt_prob_dict[i] = gt_prob_list

            if args.tensorboard:
                exceeds = gt_prob_list > threshold
                per_model_counts = {
                    f"{i}_exceed_align": len(index_list[exceeds & align_flag_list]),
                    f"{i}_exceed_conflict": len(index_list[exceeds & ~align_flag_list]),
                    f"{i}_exceed_mask": len(index_list[exceeds]),
                }
                for key, value in per_model_counts.items():
                    self.writer.add_scalar(key, value, step)

        # Element-wise voting across models.
        # NOTE(review): this assumes self.pretrain_loader visits samples in the
        # same order for every model (no shuffling); index_list and
        # align_flag_list come from the last model's pass.
        votes = torch.stack([gt_prob_dict[m] > threshold for m in range(self.args.num_bias_models)])
        align_row = align_flag_list.unsqueeze(0)
        agreement = self.args.agreement
        total_exceed_mask = index_list[votes.long().sum(dim=0) >= agreement]
        total_exceed_align = index_list[(votes & align_row).long().sum(dim=0) >= agreement]
        total_exceed_conflict = index_list[(votes & ~align_row).long().sum(dim=0) >= agreement]

        # NOTE(review): total_exceed_* hold loader-provided sample indices but are
        # used as positions into the first model's label list; this is only
        # correct when the pretrain loader yields indices 0..N-1 in order — confirm.
        labels_first = label_dict[0]
        curr_index_label = labels_first[total_exceed_mask]
        curr_align_index_label = labels_first[total_exceed_align]
        curr_conflict_index_label = labels_first[total_exceed_conflict]

        log_dict = {
            "total_exceed_align": len(total_exceed_align),
            "total_exceed_conflict": len(total_exceed_conflict),
            "total_exceed_mask": len(total_exceed_mask),
        }

        for key, value in log_dict.items():
            print(f"* {key}: {value}")
        print(f"* EXCEED DATA COUNT: {Counter(curr_index_label.tolist())}")
        print(f"* EXCEED DATA (ALIGN) COUNT: {Counter(curr_align_index_label.tolist())}")
        print(f"* EXCEED DATA (CONFLICT) COUNT: {Counter(curr_conflict_index_label.tolist())}")

        if args.tensorboard:
            for key, value in log_dict.items():
                self.writer.add_scalar(key, value, step)

        return total_exceed_mask

    def _score_pretrain_set(self, model, args):
        """Score every ``self.pretrain_loader`` sample with ``model``.

        Returns CPU tensors ``(indices, labels, gt_probs, align_flags)``:
        the sample indices yielded by the loader, the target labels, the
        softmax probability given to the target label, and whether the
        target label equals the bias attribute.
        """
        model.eval()
        indices, labels, gt_probs, align_flags = [], [], [], []
        with torch.no_grad():  # inference only; no autograd graph needed
            for index, data, attr, _ in self.pretrain_loader:
                data = data.to(self.device)
                attr = attr.to(self.device)
                label = attr[:, args.target_attr_idx]
                bias_label = attr[:, args.bias_attr_idx]

                prob = torch.softmax(model(data), dim=-1)
                gt_prob = torch.gather(prob, index=label.unsqueeze(1), dim=1).squeeze(1)

                indices += index.tolist()
                labels += label.tolist()
                gt_probs += gt_prob.tolist()
                align_flags += (label == bias_label).tolist()

        return (
            torch.tensor(indices, dtype=torch.long),
            torch.tensor(labels, dtype=torch.long),
            torch.tensor(gt_probs),
            torch.tensor(align_flags, dtype=torch.bool),
        )

    def train(self, args):
        epoch, cnt = 0, 0
        print('Training DisEnt ...')
        train_num = len(self.train_dataset)

        # self.model_d   : model for predicting intrinsic attributes ((E_i,C_i) in the main paper)
        # self.model_d.fc: fc layer for predicting intrinsic attributes (C_i in the main paper)
        # self.model_b   : model for predicting bias attributes ((E_b, C_b) in the main paper)
        # self.model_b.fc: fc layer for predicting bias attributes (C_b in the main paper)

        #################
        # define models
        #################
        if 'mnist' in args.dataset and args.model == 'MLP':
            model_name = 'mlp_DISENTANGLE'
        else:
            model_name = 'resnet_DISENTANGLE'

        print(f'criterion: {self.criterion}')
        print(f'debias criterion: {self.debias_criterion}')
        print(f'bias criterion: {self.bias_criterion}')

        train_iter = iter(self.train_loader)
        train_num = len(self.train_dataset.dataset)

        self.bias_state = torch.zeros(train_num, 1).cuda()
        self.label_index = torch.zeros(train_num).long().cuda()

        mask_index = torch.zeros(train_num, 1)
        epoch, cnt = 0, 0

        #### BiasEnsemble ####
        if args.bias_ensm:
            print("Applying Bias Ensemble for bias amplification ...")
            pseudo_align_flag = self.pretrain_b_ensemble_best(args)
            mask_index[pseudo_align_flag] = 1
        else:
            mask_index[:] = 1

        del self.model_b
        self.model_b = get_model(model_name, self.num_classes).to(self.device)
        self.model_d = get_model(model_name, self.num_classes).to(self.device)

        ##################
        # define optimizer
        ##################

        self.optimizer_d = torch.optim.Adam(
            self.model_d.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )

        self.optimizer_b = torch.optim.Adam(
            self.model_b.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )

        if args.use_lr_decay:
            self.scheduler_b = optim.lr_scheduler.StepLR(self.optimizer_b, step_size=args.lr_decay_step,
                                                         gamma=args.lr_gamma)
            self.scheduler_l = optim.lr_scheduler.StepLR(self.optimizer_d, step_size=args.lr_decay_step,
                                                         gamma=args.lr_gamma)

        bias = self.train_loader.dataset.dataset.bias

        for step in tqdm(range(args.num_steps)):
            try:
                index, data, attr, image_path = next(train_iter)
            except:
                train_iter = iter(self.train_loader)
                index, data, attr, image_path = next(train_iter)

            data = data.to(self.device)
            attr = attr.to(self.device)
            index = index.to(self.device)
            label = attr[:, args.target_attr_idx].to(self.device)
            
            # patch shuffle
            data_s = self.trans(data)

            bias_label = attr[:, args.bias_attr_idx]
            # flag_align, flag_conflict = (label == bias_label), (label != bias_label)
            bias_state = get_bias_state(label, bias_label, bias)
            self.bias_state[index] = bias_state
            self.label_index[index] = label

            # Feature extraction
            # Prediction by concatenating zero vectors (dummy vectors).
            # We do not use the prediction here.
            if 'mnist' in args.dataset and args.model == 'MLP':
                z_l_o = self.model_d.extract(data)
                z_l_s = self.model_d.extract(data_s)
                z_b_o = self.model_b.extract(data)
                z_b_s = self.model_b.extract(data_s)
            else:
                z_b_o = []
                hook_fn = self.model_b.avgpool.register_forward_hook(self.concat_dummy(z_b_o))
                _ = self.model_b(data)
                hook_fn.remove()
                z_b_o = z_b_o[0]
                
                z_b_s = []
                hook_fn = self.model_b.avgpool.register_forward_hook(self.concat_dummy(z_b_s))
                _ = self.model_b(data_s)
                hook_fn.remove()
                z_b_s = z_b_s[0]

                z_l_o = []
                hook_fn = self.model_d.avgpool.register_forward_hook(self.concat_dummy(z_l_o))
                _ = self.model_d(data)
                hook_fn.remove()
                z_l_o = z_l_o[0]
                
                z_l_s = []
                hook_fn = self.model_d.avgpool.register_forward_hook(self.concat_dummy(z_l_s))
                _ = self.model_d(data_s)
                hook_fn.remove()
                z_l_s = z_l_s[0]

            if args.biloss == "x_":
                z_b = z_b_o
                # z_l = z_l_o
            elif args.biloss == "x'":
                z_b = z_b_s
                # z_l = z_l_s
            z_l = z_l_o
            # z=[z_l, z_b]
            # Gradients of z_b are not backpropagated to z_l (and vice versa) in order to guarantee disentanglement of representation.
            
            z_conflict = torch.cat((z_l, z_b.detach()), dim=1)
            z_align = torch.cat((z_l.detach(), z_b), dim=1)
            # z_conflict_s = torch.cat((z_l, z_b_s.detach()), dim=1)
            # z_align_s = torch.cat((z_l.detach(), z_b_s), dim=1)

            # Prediction using z=[z_l, z_b]
            pred_conflict = self.model_d.fc(z_conflict)
            pred_align = self.model_b.fc(z_align)
            # pred_conflict_s = self.model_d.fc(z_conflict_s)
            # pred_align_s = self.model_b.fc(z_align_s)
            

            # if args.biloss == "x_":
            #     loss_dis_align = self.criterion(pred_align, label).detach()
            # elif args.biloss == "x'":
            #     loss_dis_align = self.criterion(pred_align_s, label).detach()
            loss_dis_conflict = self.criterion(pred_conflict, label).detach()
            loss_dis_align = self.criterion(pred_align, label).detach()

            # EMA sample loss
            self.sample_loss_ema_d.update(loss_dis_conflict, index)
            self.sample_loss_ema_b.update(loss_dis_align, index)

            # class-wise normalize
            loss_dis_conflict = self.sample_loss_ema_d.parameter[index].clone().detach()
            loss_dis_align = self.sample_loss_ema_b.parameter[index].clone().detach()

            # loss_dis_conflict = loss_dis_conflict.to(self.device)
            # loss_dis_align = loss_dis_align.to(self.device)
            
            

            # for c in range(self.num_classes):
            #     class_index = torch.where(label == c)[0].to(self.device)
            #     max_loss_conflict = self.sample_loss_ema_d.max_loss(c)
            #     max_loss_align = self.sample_loss_ema_b.max_loss(c)
            #     loss_dis_conflict[class_index] /= max_loss_conflict
            #     loss_dis_align[class_index] /= max_loss_align

            # loss_weight = loss_dis_align / (loss_dis_align + loss_dis_conflict + 1e-8)  # Eq.1 (reweighting module) in the main paper
            loss_weight = self.calc_weights(loss_dis_align, loss_dis_conflict, label)
            loss_dis_conflict = self.debias_criterion(pred_conflict, label) * loss_weight.to(self.device)  # Eq.2 W(z)CE(C_i(z),y)

            curr_align_flag = torch.index_select(mask_index.to(self.device), 0, index)
            curr_align_flag = (curr_align_flag.squeeze(1) == 1)
            
            # if args.biloss == "x_":
            #     loss_dis_align = self.bias_criterion(pred_align[curr_align_flag], label[curr_align_flag])
            # elif args.biloss == "x'":
            #     loss_dis_align = self.bias_criterion(pred_align_s[curr_align_flag], label[curr_align_flag])
            loss_dis_align = self.bias_criterion(pred_align[curr_align_flag], label[curr_align_flag])

            # feature-level augmentation : augmentation after certain iteration (after representation is disentangled at a certain level)
            if step > args.curr_step:
                indices = np.random.permutation(z_b.size(0))
                z_b_swap = z_b[indices]  # z tilde
                label_swap = label[indices]  # y tilde
                curr_align_flag = curr_align_flag[indices]

                # Prediction using z_swap=[z_l, z_b tilde]
                # Again, gradients of z_b tilde are not backpropagated to z_l (and vice versa) in order to guarantee disentanglement of representation.
                z_mix_conflict = torch.cat((z_l, z_b_swap.detach()), dim=1)
                z_mix_align = torch.cat((z_l.detach(), z_b_swap), dim=1)

                # Prediction using z_swap
                pred_mix_conflict = self.model_d.fc(z_mix_conflict)
                pred_mix_align = self.model_b.fc(z_mix_align)

                loss_swap_conflict = self.debias_criterion(pred_mix_conflict, label) * loss_weight.to(self.device)  # Eq.3 W(z)CE(C_i(z_swap),y)
                loss_swap_align = self.bias_criterion(pred_mix_align[curr_align_flag], label_swap[curr_align_flag])
                lambda_swap = self.args.lambda_swap  # Eq.3 lambda_swap_b

            else:
                # before feature-level augmentation
                loss_swap_conflict = torch.tensor([0]).float()
                loss_swap_align = torch.tensor([0]).float()
                lambda_swap = 0

            loss_dis = loss_dis_conflict.mean() + args.lambda_dis_align * loss_dis_align.mean()  # Eq.2 L_dis
            loss_swap = loss_swap_conflict.mean() + args.lambda_swap_align * loss_swap_align.mean()  # Eq.3 L_swap
            loss = loss_dis + lambda_swap * loss_swap  # Eq.4 Total objective

            self.optimizer_d.zero_grad()
            self.optimizer_b.zero_grad()
            loss.backward()
            self.optimizer_d.step()
            self.optimizer_b.step()

            if step >= args.curr_step and args.use_lr_decay:
                self.scheduler_b.step()
                self.scheduler_l.step()

            if args.use_lr_decay and step % args.lr_decay_step == 0:
                print('******* learning rate decay .... ********')
                print(f"self.optimizer_b lr: {self.optimizer_b.param_groups[-1]['lr']}")
                print(f"self.optimizer_d lr: {self.optimizer_d.param_groups[-1]['lr']}")

            if step % args.log_freq == 0:
                self.board_disent_loss(
                    step=step,
                    loss_dis_conflict=loss_dis_conflict.mean(),
                    loss_dis_align=args.lambda_dis_align * loss_dis_align.mean(),
                    loss_swap_conflict=loss_swap_conflict.mean(),
                    loss_swap_align=args.lambda_swap_align * loss_swap_align.mean(),
                    lambda_swap=lambda_swap
                )

                bias_label = attr[:, args.bias_attr_idx]
                pred = pred_conflict.data.max(1, keepdim=True)[1].squeeze(1)

                ac_flag = (label == bias_label) & (label == pred)
                aw_flag = (label == bias_label) & (label != pred)
                cc_flag = (label != bias_label) & (label == pred)
                cw_flag = (label != bias_label) & (label != pred)

                ac_flag = ac_flag & curr_align_flag
                aw_flag = aw_flag & curr_align_flag
                cc_flag = cc_flag & curr_align_flag
                cw_flag = cw_flag & curr_align_flag

                loss_dis_align_temp = self.criterion(pred_align, label)
                # self.board_lff_wx(step, loss_weight, ac_flag, aw_flag, cc_flag, cw_flag)
                if step > len(train_iter):
                    self.board_weights(step)

            if step % args.valid_freq == 0:
                self.board_disent_acc(step)

            cnt += data.shape[0]
            if cnt == train_num:
                print(f'finished epoch: {epoch}')
                epoch += len(index)
                cnt = 0

    def test(self, args):
        """Restore the best saved disentangled models and report their accuracy.

        Builds fresh ``model_d`` / ``model_b`` instances. It uses the
        ``'mlp_DISENTANGLE'`` architecture when ``args.dataset`` contains
        ``'mnist'`` and ``'resnet_DISENTANGLE'`` otherwise. It then loads their
        weights from ``best_model_d.th`` / ``best_model_b.th`` under
        ``args.pretrained_path`` and calls ``board_disent_acc`` with
        ``step=0, inference=True``.

        Args:
            args: run configuration; this method reads ``args.dataset`` and
                ``args.pretrained_path``.
        """
        model_name = 'mlp_DISENTANGLE' if 'mnist' in args.dataset else 'resnet_DISENTANGLE'
        self.model_d = get_model(model_name, self.num_classes).to(self.device)
        self.model_b = get_model(model_name, self.num_classes).to(self.device)

        # map_location remaps the stored tensors onto the device this run uses.
        # Without it, a checkpoint saved on a GPU cannot be restored on a
        # CPU-only machine or on a different GPU.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoints = (
            (self.model_d, 'best_model_d.th'),
            (self.model_b, 'best_model_b.th'),
        )
        for model, filename in checkpoints:
            checkpoint = torch.load(
                os.path.join(args.pretrained_path, filename),
                map_location=self.device,
            )
            model.load_state_dict(checkpoint['state_dict'])

        self.board_disent_acc(step=0, inference=True)
