from collections import Counter
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import os
import torch.optim as optim

from data.util import get_dataset, IdxDataset
from module.loss import GeneralizedCELoss
from module.util import get_model
from module.util import get_backbone
from util import *

import warnings
warnings.filterwarnings(action='ignore')
import copy
import wandb

from learners.base_learner import *

class ShapeLearner(Learner):
    def __init__(self, args):
        """Initialise the base learner, the input transform and the unlearning loss.

        ``args.fd`` selects the transform stored in ``self.trans``:
        ``"patch-shuffle"`` / ``"pixel-shuffle"`` / ``"center-occlude"`` /
        ``"gray-scale"``. For any other value ``self.trans`` is not set by this
        method.
        """
        super(ShapeLearner, self).__init__(args)

        mode = args.fd
        if mode in ("patch-shuffle", "pixel-shuffle"):
            # pixel shuffling reuses ShufflePatches with patch_size=1
            patch = args.ps if mode == "patch-shuffle" else 1
            self.trans = ShufflePatches(patch_size=patch)
        elif mode == "center-occlude":
            self.trans = CenterOcclude(occlusion_size=args.ps)
        elif mode == "gray-scale":
            self.trans = gray_scale

        # Per-sample cross entropy with label_smoothing=1 (fully smoothed targets).
        self.unlearn_criterion = nn.CrossEntropyLoss(label_smoothing=1, reduction='none')

    def board_lff_loss(self, step, loss_b, loss_d):
        
        if self.args.wandb:
            wandb.log({
                "loss/loss_b_train": loss_b,
                "loss/loss_d_train": loss_d,
                }, step=step)
        
        if self.args.tensorboard:
            self.writer.add_scalar(f"loss/loss_b_train", loss_b, step)
            self.writer.add_scalar(f"loss/loss_d_train", loss_d, step)

    def board_weights(self, step, his=False):
        """Log per-subgroup EMA losses and sample weights.

        Subgroups are one per class plus the whole training set when
        ``self.dataset`` is ``'bar'`` or ``'NICO'``; otherwise one per bias
        state (-1, 0, 1) plus the whole training set. For each subgroup the
        mean EMA loss of both models and the mean weight from
        ``calc_weights`` are collected; empty subgroups get the placeholder
        ``-0.1``.

        Args:
            step: step value attached to the logged entries.
            his: when truthy, additionally build ``wandb.Histogram`` objects
                of the losses and weights (``None`` is logged otherwise).
        """
        class_wise = self.dataset in ['bar', 'NICO']
        if class_wise:
            # class-wise sub group
            # NOTE(review): other methods read ``self.num_classes``; confirm
            # that ``self.n_classes`` is defined on the base learner as well.
            states = list(range(self.n_classes)) + [None]
        else:
            states = [-1, 0, 1, None]

        mean_weights, mean_losses_b, mean_losses_d = [], [], []
        hists_b, hists_d, hists_w = [], [], []

        for state in states:
            # ``None`` stands for "all samples".
            if state is None:
                selector = torch.ones_like(self.label_index, dtype=torch.bool).to(self.device)
            elif class_wise:
                selector = self.label_index == state
            else:
                selector = self.bias_state.squeeze(1) == state
            labels = self.label_index[selector]

            # skip if certain subgroup is empty
            if labels.shape[0] == 0:
                for scalars in (mean_weights, mean_losses_b, mean_losses_d):
                    scalars.append(-0.1)
                for hists in (hists_b, hists_d, hists_w):
                    hists.append(None)
                continue

            ema_b = self.sample_loss_ema_b.parameter[selector].clone().detach()
            ema_d = self.sample_loss_ema_d.parameter[selector].clone().detach()
            # Means are taken before calc_weights, which may rescale the
            # tensors it receives in place.
            mean_losses_b.append(ema_b.mean())
            mean_losses_d.append(ema_d.mean())
            wx = self.calc_weights(ema_b, ema_d, labels)
            mean_weights.append(wx.mean())

            if his:
                hists_b.append(wandb.Histogram(ema_b.cpu().numpy()))
                hists_d.append(wandb.Histogram(ema_d.cpu().numpy()))
                hists_w.append(wandb.Histogram(wx.cpu().numpy()))
            else:
                hists_b.append(None)
                hists_d.append(None)
                hists_w.append(None)

        log_dict = {
            "w(x)_mean": trans_result(mean_weights, self.keys),
            "loss/loss_b": trans_result(mean_losses_b, self.keys),
            "loss/loss_d": trans_result(mean_losses_d, self.keys),
            "hist/hist_loss_b": trans_result(hists_b, self.keys),
            "hist/hist_loss_d": trans_result(hists_d, self.keys),
            "hist/hist_weights": trans_result(hists_w, self.keys),
        }

        if self.args.tensorboard:
            for tag in log_dict:
                self.writer.add_scalar(tag, log_dict[tag], step)

        if self.args.wandb:
            wandb.log(log_dict, step=step)

    def calc_weights(self, loss_b, loss_d, label):
        """Compute per-sample re-weighting factors from the two EMA losses.

        ``self.args.w_func`` controls the computation:

        * normalisation: if it contains ``"_max"`` / ``"_mean"`` each sample's
          loss is divided by the max / mean EMA loss of its class; if it
          contains ``"_1-x"`` the losses are mapped to ``1 - exp(-loss)``;
          otherwise the losses are used unchanged;
        * formula: selected by the prefix of ``w_func`` (e.g. ``"frac_"`` is
          ``b / (b + a)``, with ``b`` = ``loss_b`` and ``a`` = ``loss_d``);
        * scale: the weights are doubled when ``w_func`` contains ``"s2"``.

        Args:
            loss_b: per-sample losses of the biased model, aligned with ``label``.
            loss_d: per-sample losses of the debiased model, aligned with ``label``.
            label: class label of every sample.

        Returns:
            Tensor with one weight per sample.

        Raises:
            NameError: if the mean of ``loss_b`` or ``loss_d`` is NaN.
            ValueError: if ``w_func`` does not start with a known formula prefix.

        Note:
            With ``"_max"`` / ``"_mean"`` normalisation the tensors passed as
            ``loss_b`` and ``loss_d`` are normalised in place.
        """
        if np.isnan(loss_b.mean().item()):
            raise NameError('loss_b_ema')
        if np.isnan(loss_d.mean().item()):
            raise NameError('loss_d_ema')

        w_func = self.args.w_func

        # Snapshot of the un-normalised debiased loss. The 'e^(-x/(b+x))_'
        # formula needs it to recover the probability; multiplying the
        # normalised loss by a single class norm would be wrong for samples of
        # other classes and impossible when no normalisation was applied.
        raw_loss_d = loss_d.clone()

        label_cpu = label.cpu()
        for c in range(self.num_classes):
            class_index = np.where(label_cpu == c)[0]
            if "_max" in w_func:
                norm_loss_b = self.sample_loss_ema_b.max_loss(c) + 1e-8
                norm_loss_d = self.sample_loss_ema_d.max_loss(c)
            elif "_mean" in w_func:
                norm_loss_b = self.sample_loss_ema_b.mean_loss(c) + 1e-8
                norm_loss_d = self.sample_loss_ema_d.mean_loss(c)
            elif "_1-x" in w_func:
                # Not class dependent: transform once and leave the loop.
                loss_b = 1 - torch.exp(-loss_b)
                loss_d = 1 - torch.exp(-loss_d)
                break
            else:
                break
            loss_b[class_index] /= norm_loss_b
            loss_d[class_index] /= norm_loss_d

        scale = 2 if 's2' in w_func else 1

        # The prefixes are mutually exclusive, so a single chain is sufficient.
        if w_func.startswith('frac_'):
            loss_weight = scale * loss_b / (loss_b + loss_d + 1e-8)
        elif w_func.startswith('frac+0.01_'):
            loss_weight = scale * loss_b / (loss_b + loss_d + 1e-8) + 0.01
        elif w_func.startswith('b/(1-a+b)_'):
            loss_weight = scale * loss_b / (loss_b + 1 - loss_d + 1e-8)
        elif w_func.startswith('b_'):
            loss_weight = scale * loss_b
        elif w_func.startswith('b+0.01_'):
            loss_weight = scale * loss_b + 0.01
        elif w_func.startswith('e^(2(b-1))_'):
            exponent = 2 * (loss_b - 1)
            loss_weight = scale * torch.exp(exponent)
        elif w_func.startswith('log(b+1)_'):
            loss_weight = scale * torch.log(loss_b + 1)
        elif w_func.startswith('e^(-x/(b+x))_'):
            x = torch.exp(-raw_loss_d)  # recover probability from the raw loss
            exponent = -x / (loss_b + x)
            loss_weight = scale * torch.exp(exponent)
        elif w_func.startswith('e^(-a/(b+a))_'):
            exponent = -loss_d / (loss_b + loss_d)
            loss_weight = scale * torch.exp(exponent)
        elif w_func.startswith('e^(-1/(b+a))_'):
            exponent = -1 / (loss_b + loss_d)
            loss_weight = scale * torch.exp(exponent)
        elif w_func.startswith('be^(-1/(b+a))_'):
            exponent = -1 / (loss_b + loss_d)
            loss_weight = scale * loss_b * torch.exp(exponent)
        elif w_func.startswith('e^(-1/(b+a)+2(b-1))_'):
            exponent = -1 / (loss_b + loss_d) + 2 * (loss_b - 1)
            loss_weight = scale * torch.exp(exponent)
        elif w_func.startswith('e^(-1/(b+a)+3(b-1))_'):
            exponent = -1 / (loss_b + loss_d) + 3 * (loss_b - 1)
            loss_weight = scale * torch.exp(exponent)
        elif w_func.startswith('e^(-1/(b+a)+5(b-1))_'):
            exponent = -1 / (loss_b + loss_d) + 5 * (loss_b - 1)
            loss_weight = scale * torch.exp(exponent)
        elif w_func.startswith('be^(-1/(b+a))+0.01_'):
            exponent = -1 / (loss_b + loss_d)
            loss_weight = scale * loss_b * torch.exp(exponent) + 0.01
        elif w_func.startswith('be^(-1/(b+a))+0.1_'):
            exponent = -1 / (loss_b + loss_d)
            loss_weight = scale * loss_b * torch.exp(exponent) + 0.1
        elif w_func.startswith('e^(-1/(b+0.5a))_'):
            exponent = -1 / (loss_b + 0.5 * loss_d)
            loss_weight = scale * torch.exp(exponent)
        elif w_func.startswith('e^(-1/(b+0.3a))_'):
            exponent = -1 / (loss_b + 0.3 * loss_d)
            loss_weight = scale * torch.exp(exponent)
        else:
            raise ValueError(f'unknown weight function: {w_func!r}')

        return loss_weight

    def board_lff_wx(self, step, loss_weight, ac_flag, aw_flag, cc_flag, cw_flag):
        log_dict = {
                    "w(x)_mean/align": loss_weight[aw_flag | ac_flag].mean(),
                    "w(x)_mean/conflict": loss_weight[cw_flag | cc_flag].mean(),
                }
        if self.args.tensorboard:
            for key, value in log_dict.items():
                self.writer.add_scalar(key, value, step)
                
        if self.args.wandb:
            wandb.log(log_dict, step=step)

    def board_lff_acc(self, step, inference=None, his=False):
        """Evaluate both models, update the best scores and log the results.

        ``model_b`` and ``model_d`` are evaluated on the validation and test
        loaders. The last entry of each accuracy result decides whether the
        stored best accuracy/loss is replaced; a new best test accuracy of
        ``model_d`` also triggers ``self.save_best(step)``.

        Args:
            step: step value attached to the logged entries.
            inference: when truthy, print the test accuracy of ``model_d`` and
                terminate the process with ``sys.exit(0)``.
            his: forwarded to ``self.evaluate`` for the test loaders.
        """
        # check label network
        valid_accs_b, valid_losses_b, _ = self.evaluate(self.model_b, self.valid_loader)
        test_accs_b, test_losses_b, test_hist_b = self.evaluate(self.model_b, self.test_loader, his=his)

        valid_accs_d, valid_losses_d, _ = self.evaluate(self.model_d, self.valid_loader)
        test_accs_d, test_losses_d, test_hist_d = self.evaluate(self.model_d, self.test_loader, his=his)

        if inference:
            # NOTE(review): ``test_accs_d`` is indexed with [-1] below; confirm
            # that ``.item()`` is valid for the type ``evaluate`` returns.
            print(f'test acc: {test_accs_d.item()}')
            import sys
            sys.exit(0)

        if valid_accs_b[-1] >= self.best_valid_acc_b[-1]:
            self.best_valid_acc_b = valid_accs_b
            self.best_valid_loss_b = valid_losses_b

        if test_accs_b[-1] >= self.best_test_acc_b[-1]:
            self.best_test_acc_b = test_accs_b
            self.best_test_loss_b = test_losses_b

        if valid_accs_d[-1] >= self.best_valid_acc_d[-1]:
            self.best_valid_acc_d = valid_accs_d
            self.best_valid_loss_d = valid_losses_d

        if test_accs_d[-1] >= self.best_test_acc_d[-1]:
            self.best_test_acc_d = test_accs_d
            self.best_test_loss_d = test_losses_d
            self.save_best(step)

        log = {
            "acc/acc_b_valid": trans_result(valid_accs_b, self.keys),
            "acc/acc_b_test": trans_result(test_accs_b, self.keys),
            "acc/best_acc_b_valid": trans_result(self.best_valid_acc_b, self.keys),
            "acc/best_acc_b_test": trans_result(self.best_test_acc_b, self.keys),
            "acc/acc_d_valid": trans_result(valid_accs_d, self.keys),
            "acc/acc_d_test": trans_result(test_accs_d, self.keys),
            "acc/best_acc_d_valid": trans_result(self.best_valid_acc_d, self.keys),
            "acc/best_acc_d_test": trans_result(self.best_test_acc_d, self.keys),
            "acc/best_acc_i_valid": trans_result(self.best_valid_acc_i, self.keys),
            "acc/best_acc_i_test": trans_result(self.best_test_acc_i, self.keys),

            "loss/loss_b_valid": trans_result(valid_losses_b, self.keys),
            "loss/loss_b_test": trans_result(test_losses_b, self.keys),
            "loss/best_loss_b_valid": trans_result(self.best_valid_loss_b, self.keys),
            "loss/best_loss_b_test": trans_result(self.best_test_loss_b, self.keys),
            "loss/loss_d_valid": trans_result(valid_losses_d, self.keys),
            "loss/loss_d_test": trans_result(test_losses_d, self.keys),
            "loss/best_loss_d_valid": trans_result(self.best_valid_loss_d, self.keys),
            "loss/best_loss_d_test": trans_result(self.best_test_loss_d, self.keys),
            "loss/best_loss_i_valid": trans_result(self.best_valid_loss_i, self.keys),
            "loss/best_loss_i_test": trans_result(self.best_test_loss_i, self.keys),

            "hist/loss_b_test": trans_result(test_hist_b, self.keys),
            "hist/loss_d_test": trans_result(test_hist_d, self.keys),
        }

        if "_mb" in self.args.dataset:
            # For multiple-bias datasets, evaluate once more with bias
            # attribute 2. The previous setting is restored afterwards (also
            # on failure) rather than being overwritten with a fixed value.
            previous_bias_attr_idx = self.args.bias_attr_idx
            self.args.bias_attr_idx = 2
            try:
                test_accs_b_2, _, _ = self.evaluate(self.model_b, self.test_loader, his=his)
                test_accs_d_2, _, _ = self.evaluate(self.model_d, self.test_loader, his=his)
                log.update({
                    "acc/acc_b_test_2": trans_result(test_accs_b_2, self.keys),
                    "acc/acc_d_test_2": trans_result(test_accs_d_2, self.keys),
                })
            finally:
                self.args.bias_attr_idx = previous_bias_attr_idx

        if self.args.wandb:
            wandb.log(log, step=step)

        print(f'valid_b: {valid_accs_b} || test_b: {test_accs_b} ')
        print(f'valid_d: {valid_accs_d} || test_d: {test_accs_d} ')

    def board_pretrain_best_acc(self, i, model_b, best_valid_acc_b, step):
        """Validate a pre-training model and snapshot it when it improves.

        Args:
            i: index of the biased model, used in messages and the log tag.
            model_b: model to evaluate on ``self.valid_loader``.
            best_valid_acc_b: best validation accuracies so far; the last
                entry is the one compared.
            step: step value attached to the tensorboard scalar.

        Returns:
            The updated best validation accuracies. When the last entry of
            the current result is strictly larger, a deep copy of ``model_b``
            is stored in ``self.best_model_b``.
        """
        # check label network
        current_accs, _losses, _ = self.evaluate(model_b, self.valid_loader)

        print(f'best: {best_valid_acc_b}, curr: {current_accs}')

        improved = current_accs[-1] > best_valid_acc_b[-1]
        if improved:
            best_valid_acc_b = current_accs
            # keep a frozen snapshot of the parameters
            self.best_model_b = copy.deepcopy(model_b)
            print(f'early model {i}th saved...')

        tag = f"{i}_pretrain_best_valid_acc"
        summary = trans_result(best_valid_acc_b, self.keys)

        if self.args.tensorboard:
            self.writer.add_scalar(tag, summary, step)

        return best_valid_acc_b

    def pretrain_b_ensemble_best(self, args):
        """Train a committee of biased models and return the samples they agree on.

        For each of ``self.args.num_bias_models`` models: a fresh backbone is
        trained for ``self.args.biased_model_train_iter`` steps on
        ``self.train_loader``, the snapshot with the best validation accuracy
        is kept, and that snapshot's softmax probability for the ground-truth
        class is recorded for every sample of ``self.pretrain_loader``.
        A sample is "confident" for a model when that probability exceeds
        ``args.biased_model_softmax_threshold``; it is selected when at least
        ``self.args.agreement`` models are confident about it.

        The per-model results are combined position by position, so
        ``self.pretrain_loader`` is assumed to yield samples in the same order
        on every pass — TODO confirm.

        Args:
            args: run configuration (learning rate, thresholds, attribute
                indices, logging switches).

        Returns:
            1-D tensor holding the indices (as yielded by
            ``self.pretrain_loader``) of the selected samples.

        Raises:
            ValueError: if ``self.args.num_bias_models`` is smaller than 1.
        """
        if self.args.num_bias_models < 1:
            raise ValueError('num_bias_models must be at least 1')

        train_iter = iter(self.train_loader)
        train_num = len(self.train_dataset.dataset)
        epoch, cnt = 0, 0
        label_dict, gt_prob_dict = {}, {}
        threshold = args.biased_model_softmax_threshold
        step = 0  # stays defined for logging even if no training step runs

        for i in range(self.args.num_bias_models):
            best_valid_acc_b = [0, 0, 0, 0]
            print(f'{i}th model working ...')
            del self.model_b
            self.best_model_b = None
            self.model_b = get_backbone(self.model, self.num_classes, args=self.args, pretrained=self.args.resnet_pretrained, first_stage=True).to(self.device)
            self.optimizer_b = torch.optim.Adam(self.model_b.parameters(), lr=args.lr, weight_decay=args.weight_decay)

            for step in tqdm(range(self.args.biased_model_train_iter)):
                try:
                    index, data, attr, _ = next(train_iter)
                except StopIteration:
                    # Loader exhausted: start the next pass. Any other error
                    # (e.g. from data loading) is allowed to propagate.
                    train_iter = iter(self.train_loader)
                    index, data, attr, _ = next(train_iter)

                data = data.to(self.device)
                attr = attr.to(self.device)
                label = attr[:, args.target_attr_idx]

                if args.pre == "x'":
                    data = self.trans(data)

                logit_b = self.model_b(data)
                loss_b_update = self.bias_criterion(logit_b, label)
                loss = loss_b_update.mean()

                self.optimizer_b.zero_grad()
                loss.backward()
                self.optimizer_b.step()

                cnt += len(index)
                if cnt >= train_num:
                    print(f'finished epoch: {epoch}')
                    epoch += 1
                    cnt = len(index)

                if step % args.valid_freq == 0:
                    best_valid_acc_b = self.board_pretrain_best_acc(i, self.model_b, best_valid_acc_b, step)

            if self.best_model_b is None:
                # No validation pass improved on the initial best, so no
                # snapshot was stored; fall back to the model as trained.
                print(f'no early model saved for {i}th model, using the last state')
                self.best_model_b = self.model_b

            label_list, index_list, gt_prob_list, align_flag_list = [], [], [], []
            self.best_model_b.eval()

            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                for index, data, attr, _ in self.pretrain_loader:
                    data = data.to(self.device)
                    attr = attr.to(self.device)
                    label = attr[:, args.target_attr_idx]
                    bias_label = attr[:, args.bias_attr_idx]

                    logit_b = self.best_model_b(data)
                    prob = torch.softmax(logit_b, dim=-1)
                    gt_prob = torch.gather(prob, index=label.unsqueeze(1), dim=1).squeeze(1)

                    label_list += label.tolist()
                    index_list += index.tolist()
                    gt_prob_list += gt_prob.tolist()
                    align_flag_list += (label == bias_label).tolist()

            index_list = torch.tensor(index_list, dtype=torch.long)
            label_list = torch.tensor(label_list, dtype=torch.long)
            gt_prob_list = torch.tensor(gt_prob_list)
            align_flag_list = torch.tensor(align_flag_list, dtype=torch.bool)

            confident = gt_prob_list > threshold
            exceed_align = index_list[confident & align_flag_list]
            exceed_conflict = index_list[confident & ~align_flag_list]
            exceed_mask = index_list[confident]

            label_dict[i] = label_list
            gt_prob_dict[i] = gt_prob_list

            log_dict = {
                f"{i}_exceed_align": len(exceed_align),
                f"{i}_exceed_conflict": len(exceed_conflict),
                f"{i}_exceed_mask": len(exceed_mask),
            }
            if args.tensorboard:
                for key, value in log_dict.items():
                    self.writer.add_scalar(key, value, step)

        # Count, per sample, how many models are confident about it (overall,
        # and restricted to bias-aligned / bias-conflicting samples).
        model_ids = range(self.args.num_bias_models)
        mask_sum = torch.stack(
            [(gt_prob_dict[m] > threshold).long() for m in model_ids]).sum(dim=0)
        mask_sum_align = torch.stack(
            [((gt_prob_dict[m] > threshold) & align_flag_list).long() for m in model_ids]).sum(dim=0)
        mask_sum_conflict = torch.stack(
            [((gt_prob_dict[m] > threshold) & ~align_flag_list).long() for m in model_ids]).sum(dim=0)

        # Positions (in pretrain_loader order) reaching the required agreement.
        agreed = (mask_sum >= self.args.agreement).nonzero().squeeze(1)
        agreed_align = (mask_sum_align >= self.args.agreement).nonzero().squeeze(1)
        agreed_conflict = (mask_sum_conflict >= self.args.agreement).nonzero().squeeze(1)

        total_exceed_mask = index_list[agreed]
        total_exceed_align = index_list[agreed_align]
        total_exceed_conflict = index_list[agreed_conflict]

        exceed_mask_list = [total_exceed_mask]
        print(f'exceed mask list length: {len(exceed_mask_list)}')

        # Labels are stored in loader order, so they are looked up by position
        # rather than by dataset index.
        reference_labels = label_dict[0]
        log_dict = {
            "total_exceed_align": len(total_exceed_align),
            "total_exceed_conflict": len(total_exceed_conflict),
            "total_exceed_mask": len(total_exceed_mask),
        }

        for key, value in log_dict.items():
            print(f"* {key}: {value}")
        print(f"* EXCEED DATA COUNT: {Counter(reference_labels[agreed].tolist())}")
        print(f"* EXCEED DATA (ALIGN) COUNT: {Counter(reference_labels[agreed_align].tolist())}")
        print(f"* EXCEED DATA (CONFLICT) COUNT: {Counter(reference_labels[agreed_conflict].tolist())}")

        if args.tensorboard:
            for key, value in log_dict.items():
                self.writer.add_scalar(key, value, step)

        return total_exceed_mask

    def train(self, args):
        """Jointly train the biased model ``model_b`` and the debiased model ``model_d``.

        Each step:

        1. draws a batch and builds its transformed copy with ``self.trans``;
        2. runs both models on the original and the transformed batch;
        3. updates the per-sample EMA losses and derives the sample weights
           with ``calc_weights``;
        4. updates ``model_b`` with ``self.bias_criterion`` on the samples
           selected by the bias committee (all samples when
           ``args.bias_ensm`` is false) and ``model_d`` with the weighted
           loss selected by ``args.deloss``;
        5. logs / validates every ``args.log_freq`` / ``args.valid_freq`` steps.

        Args:
            args: run configuration.

        Raises:
            ValueError: if ``args.biloss`` or ``args.deloss`` is not supported.
            NameError: if one of the monitored losses or weights becomes NaN
                (the message names the offending quantity).
        """
        print('Training LfF ...')

        # Fail fast on unsupported loss selections instead of part-way
        # through the first step.
        if args.biloss not in ("x_", "x'"):
            raise ValueError('loss function not implemented')
        if args.deloss not in ("w_", "b_ce-y+ce-null", "b_ce-y-ce-y"):
            raise ValueError('loss function not implemented')

        def raise_if_nan(value, name):
            # NameError is kept as the failure type used throughout this class.
            if np.isnan(value.mean().item()):
                raise NameError(name)

        train_iter = iter(self.train_loader)
        train_num = len(self.train_dataset.dataset)

        mask_index = torch.zeros(train_num, 1)
        self.bias_state = torch.zeros(train_num, 1).to(self.device)
        self.label_index = torch.zeros(train_num).long().to(self.device)

        epoch, cnt = 0, 0

        #### BiasEnsemble ####
        if args.bias_ensm:
            print("Applying Bias Ensemble for bias amplification ...")
            pseudo_align_flag = self.pretrain_b_ensemble_best(args)
            mask_index[pseudo_align_flag] = 1
        else:
            mask_index[:] = 1

        # The selection mask does not change during training: move it to the
        # device once instead of on every step.
        confident_mask = (mask_index.squeeze(1) == 1).to(self.device)

        del self.model_b
        self.model_b = get_backbone(self.model, self.num_classes, args=self.args, pretrained=self.args.resnet_pretrained).to(self.device)

        self.optimizer_b = torch.optim.Adam(
            self.model_b.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )

        if args.use_lr_decay:
            self.scheduler_b = optim.lr_scheduler.StepLR(self.optimizer_b, step_size=args.lr_decay_step, gamma=args.lr_gamma)
            self.scheduler_l = optim.lr_scheduler.StepLR(self.optimizer_d, step_size=args.lr_decay_step, gamma=args.lr_gamma)

        # get bias mapping
        bias = self.train_loader.dataset.dataset.bias

        for step in tqdm(range(args.num_steps)):
            try:
                index, data, attr, _ = next(train_iter)
            except StopIteration:
                # Loader exhausted: start the next pass. Any other error
                # (e.g. from data loading) is allowed to propagate.
                train_iter = iter(self.train_loader)
                index, data, attr, _ = next(train_iter)

            data = data.to(self.device)
            attr = attr.to(self.device)
            index = index.to(self.device)
            label = attr[:, args.target_attr_idx]
            bias_label = attr[:, args.bias_attr_idx]

            # transformed copy of the batch (e.g. patch shuffle)
            data_s = self.trans(data)

            bias_state = get_bias_state(label, bias_label, bias)
            self.bias_state[index] = bias_state
            self.label_index[index] = label

            # All four forward passes are kept even when a configuration does
            # not use every output, so that the models see exactly the same
            # inputs as before.
            logit_b = self.model_b(data)
            logit_b_s = self.model_b(data_s)
            logit_d = self.model_d(data)
            logit_d_s = self.model_d(data_s)

            # Biased-model logits used for the loss: original or transformed input.
            logit_b_used = logit_b if args.biloss == "x_" else logit_b_s

            loss_b = self.criterion(logit_b_used, label).detach()
            loss_d = self.criterion(logit_d, label).detach()
            raise_if_nan(loss_b, 'loss_b')
            raise_if_nan(loss_d, 'loss_d')

            # EMA sample loss: smoothly update the loss of each sample
            self.sample_loss_ema_b.update(loss_b, index)
            self.sample_loss_ema_d.update(loss_d, index)

            # Use the smoothed losses of the current batch for the weights.
            loss_b = self.sample_loss_ema_b.parameter[index].clone().detach()
            loss_d = self.sample_loss_ema_d.parameter[index].clone().detach()
            raise_if_nan(loss_b, 'loss_b_ema')
            raise_if_nan(loss_d, 'loss_d_ema')

            loss_weight = self.calc_weights(loss_b, loss_d, label)
            raise_if_nan(loss_weight, 'loss_weight')

            # update the biased classifier only on confident samples agreed by the biased committee
            curr_align_flag = confident_mask[index]
            loss_b_update = self.bias_criterion(logit_b_used[curr_align_flag], label[curr_align_flag])

            w = loss_weight.to(self.device)

            if args.deloss == "w_":
                loss_d_update = self.debias_criterion(logit_d, label) * w
            elif args.deloss == "b_ce-y+ce-null":
                null_label = torch.zeros_like(label)
                loss_d_update = self.debias_criterion(logit_d, label) * w + args.beta * (1 - w) * (self.criterion(logit_d, label) + self.unlearn_criterion(logit_d_s, null_label))
            elif args.deloss == "b_ce-y-ce-y":
                loss_d_update = self.debias_criterion(logit_d, label) * w + args.beta * (1 - w) * (self.criterion(logit_d, label) - self.criterion(logit_d_s, label))
            else:
                raise ValueError('loss function not implemented')

            raise_if_nan(loss_b_update, 'loss_b_update')
            raise_if_nan(loss_d_update, 'loss_d_update')

            loss = loss_b_update.mean() + loss_d_update.mean()

            self.optimizer_b.zero_grad()
            self.optimizer_d.zero_grad()
            loss.backward()
            self.optimizer_b.step()
            self.optimizer_d.step()

            if args.use_lr_decay:
                self.scheduler_b.step()
                self.scheduler_l.step()

                if step % args.lr_decay_step == 0:
                    print('******* learning rate decay .... ********')
                    print(f"self.optimizer_b lr: {self.optimizer_b.param_groups[-1]['lr']}")
                    print(f"self.optimizer_d lr: {self.optimizer_d.param_groups[-1]['lr']}")

            if step % args.log_freq == 0:
                self.board_lff_loss(step, loss_b_update.mean(), loss_d_update.mean())

                # Subgroup statistics are only logged once the step counter
                # exceeds the number of batches in one pass of the loader.
                if step > len(train_iter):
                    self.board_weights(step, his=args.his)

            if step % args.valid_freq == 0:
                self.board_lff_acc(step, his=args.his)

                if args.use_lr_decay and args.tensorboard:
                    self.writer.add_scalar("loss/learning rate", self.optimizer_d.param_groups[-1]['lr'], step)

            # Epoch bookkeeping (for the progress message only): count one
            # epoch per ``train_num`` processed samples.
            cnt += len(index)
            if cnt >= train_num:
                print(f'finished epoch: {epoch}')
                epoch += 1
                cnt = 0

    def test(self, args):
        """Load the best checkpoints of both models and run one evaluation.

        The backbone is ``"MLP"`` when ``args.dataset`` contains ``'mnist'``
        and ``"ResNet18"`` otherwise. Weights are read from
        ``best_model_d.th`` and ``best_model_b.th`` in
        ``args.pretrained_path``. ``board_lff_acc`` is called with
        ``inference=True``, which prints the test accuracy and exits the
        process.

        Args:
            args: run configuration providing ``dataset`` and
                ``pretrained_path``.
        """
        backbone = "MLP" if 'mnist' in args.dataset else "ResNet18"
        self.model_b = get_backbone(backbone, self.num_classes, args=self.args, pretrained=self.args.resnet_pretrained).to(self.device)
        self.model_d = get_backbone(backbone, self.num_classes, args=self.args, pretrained=self.args.resnet_pretrained).to(self.device)

        # map_location lets checkpoints saved on another device (e.g. a GPU)
        # be loaded onto the device used for this run.
        checkpoint_d = torch.load(os.path.join(args.pretrained_path, 'best_model_d.th'), map_location=self.device)
        self.model_d.load_state_dict(checkpoint_d['state_dict'])
        checkpoint_b = torch.load(os.path.join(args.pretrained_path, 'best_model_b.th'), map_location=self.device)
        self.model_b.load_state_dict(checkpoint_b['state_dict'])

        self.board_lff_acc(step=0, inference=True)
        
    # evaluation with calibrated inference
    def _evaluate(self, model_b, model_d, data_loader, var=False, his=False):
        """
        Evaluate the combination of the biased and debiased models.

        The prediction logit is built from both models according to
        ``self.args.infer``:
            "ls_d-ls_b": LogSigmoid(logit_d) - LogSigmoid(logit_b)
            "d-b":       logit_d - logit_b
        Any other value makes the method return placeholder results (four
        entries of -0.01 per list) without touching the models.

        model_b, model_d: models called as ``model(data)`` to obtain logits.
        data_loader: yields ``(data, attr, index)`` batches; ``attr[:, 0]`` is
            read as the label and ``attr[:, 1]`` as the bias attribute. Its
            dataset must expose ``bias`` (and ``n_classes`` / ``n_s`` when the
            group-balanced overall average is used).
        var: when True, the per-group loss variances are included in the result.
        his: when True, a ``wandb.Histogram`` of the losses is built per group.

        Returns a tuple ``(accs, losses, hist)`` or, with ``var=True``,
        ``(accs, losses, loss_var, hist)``. Every list holds one entry per
        bias state followed by the overall value; empty groups are reported
        as -0.01 (``None`` for the histogram).

        return list of acc: [bias conflict, bias neutral, bias align, overall]
        """
        infer = self.args.infer
        if infer not in ("ls_d-ls_b", "d-b"):
            # Unsupported inference rule: answer with placeholders *before*
            # switching the models to eval mode, so they are never left in
            # eval mode by this early exit.
            n_outputs = 4 if var else 3
            return tuple([-0.01] * 4 for _ in range(n_outputs))

        model_b.eval()
        model_d.eval()
        pred_ls = []
        loss_ls = []
        label_ls = []
        s_ls = []

        ls = nn.LogSigmoid()
        for data, attr, _ in tqdm(data_loader, leave=False):
            label = attr[:, 0].to(self.device)
            s = attr[:, 1].to(self.device)
            data = data.to(self.device)

            with torch.no_grad():
                logit_b = model_b(data)
                logit_d = model_d(data)
                if infer == "ls_d-ls_b":
                    logit = ls(logit_d) - ls(logit_b)
                else:  # "d-b"
                    logit = logit_d - logit_b
                # assumes self.criterion returns one loss per sample
                # (the results are concatenated below) -- TODO confirm
                loss = self.criterion(logit, label)
                pred = logit.max(1)[1]
                pred_ls.append(pred)
                loss_ls.append(loss)
                label_ls.append(label)
                s_ls.append(s)

        pred = torch.cat(pred_ls)
        loss = torch.cat(loss_ls)
        label = torch.cat(label_ls)
        s = torch.cat(s_ls)
        bias = data_loader.dataset.bias
        correct = (pred == label)

        # partition according to bias state and calc accs
        bias_state = get_bias_state(label, s, bias).squeeze(1)

        def group_stats(mask):
            """Return (acc, mean loss, loss variance, histogram) for ``mask``."""
            group_correct = correct[mask]
            total = group_correct.shape[0]
            if total == 0:
                # Empty group: report placeholders instead of NaN statistics.
                return -0.01, -0.01, -0.01, None
            group_loss = loss[mask]
            group_hist = wandb.Histogram(group_loss.cpu().numpy()) if his else None
            return (
                (group_correct.sum() / float(total)).item(),
                group_loss.mean().item(),
                group_loss.var().item(),
                group_hist,
            )

        if self.dataset in ['bar', 'NICO']:  # class-wise sub group
            states = range(self.n_classes)
        else:
            states = [-1, 0, 1]

        accs = []
        losses = []
        loss_var = []
        hist = []
        for state in states:
            acc, mean_loss, variance, group_hist = group_stats(bias_state == state)
            accs.append(acc)
            losses.append(mean_loss)
            loss_var.append(variance)
            hist.append(group_hist)

        # calc overall acc
        dataset = self.args.dataset.split('-')[0]
        if dataset in ['cmnist', 'scmnist', "bffhq", "bar", "dogs_and_cats", "corruptedCifar10"]:
            # plain average over all samples
            overall_acc = correct.sum() / float(correct.shape[0])
            overall_loss = loss.mean()
            overall_loss_var = loss.var()
        else:
            # unweighted average over the (label, attribute) groups that exist
            n_classes = data_loader.dataset.n_classes
            n_s = data_loader.dataset.n_s
            total_acc = torch.zeros((n_classes, n_s))
            total_l = torch.zeros((n_classes, n_s))
            total_lv = torch.zeros((n_classes, n_s))
            n_exist_group = 0
            for i in range(n_classes):
                mask_i = label == i
                for j in range(n_s):
                    mask = mask_i & (s == j)  # group with label=i and attr=j
                    group_correct = correct[mask]
                    total = group_correct.shape[0]
                    if total == 0:  # skip if the group does not exist
                        continue
                    group_loss = loss[mask]
                    total_acc[i, j] = group_correct.sum() / float(total)
                    total_l[i, j] = group_loss.mean()
                    total_lv[i, j] = group_loss.var()
                    n_exist_group += 1
            overall_acc = total_acc.sum() / n_exist_group
            overall_loss = total_l.sum() / n_exist_group
            overall_loss_var = total_lv.sum() / n_exist_group

        overall_hist = wandb.Histogram(loss.cpu().numpy()) if his else None

        accs.append(overall_acc.item())
        losses.append(overall_loss.item())
        loss_var.append(overall_loss_var.item())
        hist.append(overall_hist)

        model_b.train()
        model_d.train()
        result = (accs, losses)
        if var:
            result += (loss_var,)
        result += (hist,)
        return result

# class ShufflePatches(object):
#   def __init__(self, patch_size):
#     self.ps = patch_size

#   def __call__(self, x):
#     # divide the batch of images into non-overlapping patches
#     u = nnf.unfold(x, kernel_size=self.ps, stride=self.ps, padding=0)
#     # permute the patches of each image in the batch
#     pu = torch.cat([b_[:, torch.randperm(b_.shape[-1])][None,...] for b_ in u], dim=0)
#     # fold the permuted patches back together
#     f = nnf.fold(pu, x.shape[-2:], kernel_size=self.ps, stride=self.ps, padding=0)
#     return f