from collections import Counter
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler

import os
import torch.optim as optim

from data.util import get_dataset, IdxDataset
from module.loss import GeneralizedCELoss
from module.util import get_model
from module.util import get_backbone
from util import *
import torch.nn.functional as F

import warnings
warnings.filterwarnings(action='ignore')
import copy
import wandb

from learners.base_learner import *
from data.util import ProbDataset

class DPRLearner(Learner):
    def __init__(self, args):
        """Build the DPR learner.

        Args:
            args: run configuration. ``args.fd`` selects the input transform
                stored in ``self.trans`` (``args.ps`` is its patch/occlusion
                size where applicable); ``args.wandb`` enables the custom
                wandb step metrics.
        """
        super().__init__(args)

        # Only the requested transform is constructed. An unrecognised
        # ``args.fd`` leaves ``self.trans`` unset.
        transform_factories = {
            "patch-shuffle": lambda: ShufflePatches(patch_size=args.ps),
            "pixel-shuffle": lambda: ShufflePatches(patch_size=1),
            "center-occlude": lambda: CenterOcclude(occlusion_size=args.ps),
            "gray-scale": lambda: gray_scale,
        }
        make_transform = transform_factories.get(args.fd)
        if make_transform is not None:
            self.trans = make_transform()

        # label_smoothing=1 turns every target into the uniform distribution.
        self.unlearn_criterion = nn.CrossEntropyLoss(reduction='none', label_smoothing=1)

        if args.wandb:
            # Each model's curves are plotted against that model's own step counter.
            step_groups = (
                ("b_step", ("acc/acc_b", "acc/best_acc_b", "loss/loss_b")),
                ("d_step", ("acc/acc_d", "acc/best_acc_d", "loss/loss_d")),
            )
            for step_metric, metric_names in step_groups:
                wandb.define_metric(step_metric)
                for metric_name in metric_names:
                    wandb.define_metric(metric_name, step_metric=step_metric)

    def board_loss(self, step, key="b", loss=-1):
        """Log the training loss of model ``key`` to wandb.

        Does nothing unless ``self.args.wandb`` is set. The step is logged
        under ``"{key}_step"`` so it pairs with the custom step metrics.
        """
        if not self.args.wandb:
            return

        payload = {
            f"loss/loss_{key}_train": loss,
            f"{key}_step": step,
        }
        wandb.log(payload)

    def board_weights(self, step, his=False):
        """Log per-subgroup mean weights and EMA losses, optionally with histograms.

        Subgroups are the classes (from ``self.label_index``) for the 'bar' and
        'NICO' datasets, and the bias states -1/0/1 (from ``self.bias_state``)
        otherwise. A final ``None`` group covers every sample. An empty subgroup
        is reported as -0.1 with no histograms.

        Args:
            step: logging step for tensorboard / wandb.
            his: if True, also build wandb histograms of losses and weights.
        """
        class_wise = self.dataset in ['bar', 'NICO']
        if class_wise:
            group_ids = list(range(self.n_classes)) + [None]
        else:
            group_ids = [-1, 0, 1, None]

        rows = []
        for gid in group_ids:
            if gid is None:
                mask = torch.ones_like(self.label_index, dtype=torch.bool).to(self.device)
            elif class_wise:
                mask = self.label_index == gid
            else:
                mask = self.bias_state.squeeze(1) == gid

            label = self.label_index[mask]
            if label.shape[0] == 0:
                # sentinel row for an empty subgroup
                rows.append((-0.1, -0.1, -0.1, None, None, None))
                continue

            loss_b = self.sample_loss_ema_b.parameter[mask].clone().detach()
            loss_d = self.sample_loss_ema_d.parameter[mask].clone().detach()
            # Means are taken before calc_weights, which may normalise these in place.
            mean_b = loss_b.mean()
            mean_d = loss_d.mean()
            wx = self.calc_weights(loss_b, loss_d, label)

            if his:
                hists = tuple(wandb.Histogram(t.cpu().numpy()) for t in (loss_b, loss_d, wx))
            else:
                hists = (None, None, None)
            rows.append((wx.mean(), mean_b, mean_d) + hists)

        weights, loss_b_ls, loss_d_ls, hist_b_ls, hist_d_ls, hist_wx_ls = (
            list(column) for column in zip(*rows)
        )

        log_dict = {
            "w(x)_mean": trans_result(weights, self.keys),
            "loss/loss_b": trans_result(loss_b_ls, self.keys),
            "loss/loss_d": trans_result(loss_d_ls, self.keys),
            "hist/hist_loss_b": trans_result(hist_b_ls, self.keys),
            "hist/hist_loss_d": trans_result(hist_d_ls, self.keys),
            "hist/hist_weights": trans_result(hist_wx_ls, self.keys),
        }

        if self.args.tensorboard:
            for name, value in log_dict.items():
                self.writer.add_scalar(name, value, step)

        if self.args.wandb:
            wandb.log(log_dict, step=step)

    def calc_weights(self, loss_b, loss_d, label):
        """Compute per-sample weights w(x) from biased / debiased EMA losses.

        Both the normalisation and the formula are selected by ``self.args.w_func``:

        * normalisation: if it contains ``"_max"`` (or otherwise ``"_mean"``),
          each class's losses are divided by that class's max (mean) EMA loss.
          ``"_1-x"`` maps a loss ``l`` to ``1 - exp(-l)``. Otherwise the losses
          are used as-is.
        * scale: 2 if ``"s2"`` appears in ``w_func``, else 1.
        * formula: chosen by the prefix of ``w_func``. In the prefixes, ``b`` is
          the (normalised) biased loss and ``a`` the (normalised) debiased loss.

        Args:
            loss_b: per-sample loss of the biased model.
            loss_d: per-sample loss of the debiased model.
            label: per-sample class labels aligned with ``loss_b`` / ``loss_d``.

        Returns:
            Tensor of per-sample weights.

        Raises:
            NameError: if either loss contains NaN (exception type kept for
                backward compatibility).
            ValueError: if ``w_func`` does not start with a known prefix.

        Note:
            In the ``"_max"`` / ``"_mean"`` modes ``loss_b`` and ``loss_d`` are
            normalised in place. The caller in this file passes copies.
        """
        if np.isnan(loss_b.mean().item()):
            raise NameError('loss_b_ema')
        if np.isnan(loss_d.mean().item()):
            raise NameError('loss_d_ema')

        w_func = self.args.w_func
        # Un-normalised debiased loss. Used to recover the debiased model's
        # ground-truth probability exp(-loss) per sample. (Previously this was
        # "denormalised" with the last class's norm only, or crashed when no
        # class-wise norm had been computed.)
        raw_loss_d = loss_d.clone()

        if "_max" in w_func or "_mean" in w_func:
            use_max = "_max" in w_func
            label_cpu = label.cpu()
            for c in range(self.num_classes):
                class_index = np.where(label_cpu == c)[0]
                if use_max:
                    norm_loss_b = self.sample_loss_ema_b.max_loss(c) + 1e-8
                    norm_loss_d = self.sample_loss_ema_d.max_loss(c)
                else:
                    norm_loss_b = self.sample_loss_ema_b.mean_loss(c) + 1e-8
                    norm_loss_d = self.sample_loss_ema_d.mean_loss(c)
                loss_b[class_index] /= norm_loss_b
                loss_d[class_index] /= norm_loss_d
        elif "_1-x" in w_func:
            # loss l -> 1 - p, where p = exp(-l)
            loss_b = 1 - torch.exp(-loss_b)
            loss_d = 1 - torch.exp(-loss_d)

        scale = 2 if 's2' in w_func else 1
        b, a = loss_b, loss_d

        if w_func.startswith('frac_'):
            loss_weight = scale * b / (b + a + 1e-8)
        elif w_func.startswith('frac+0.01_'):
            loss_weight = scale * b / (b + a + 1e-8) + 0.01
        elif w_func.startswith('b/(1-a+b)_'):
            loss_weight = scale * b / (b + 1 - a + 1e-8)
        elif w_func.startswith('b_'):
            loss_weight = scale * b
        elif w_func.startswith('b+0.01_'):
            loss_weight = scale * b + 0.01
        elif w_func.startswith('e^(2(b-1))_'):
            loss_weight = scale * torch.exp(2 * (b - 1))
        elif w_func.startswith('log(b+1)_'):
            loss_weight = scale * torch.log(b + 1)
        elif w_func.startswith('e^(-x/(b+x))_'):
            # x: the debiased model's probability of the ground-truth label
            x = torch.exp(-raw_loss_d)
            loss_weight = scale * torch.exp(-x / (b + x))
        elif w_func.startswith('e^(-a/(b+a))_'):
            loss_weight = scale * torch.exp(-a / (b + a))
        elif w_func.startswith('e^(-1/(b+a))_'):
            loss_weight = scale * torch.exp(-1 / (b + a))
        elif w_func.startswith('be^(-1/(b+a))_'):
            loss_weight = scale * b * torch.exp(-1 / (b + a))
        elif w_func.startswith('e^(-1/(b+a)+2(b-1))_'):
            loss_weight = scale * torch.exp(-1 / (b + a) + 2 * (b - 1))
        elif w_func.startswith('e^(-1/(b+a)+3(b-1))_'):
            loss_weight = scale * torch.exp(-1 / (b + a) + 3 * (b - 1))
        elif w_func.startswith('e^(-1/(b+a)+5(b-1))_'):
            loss_weight = scale * torch.exp(-1 / (b + a) + 5 * (b - 1))
        elif w_func.startswith('be^(-1/(b+a))+0.01_'):
            loss_weight = scale * b * torch.exp(-1 / (b + a)) + 0.01
        elif w_func.startswith('be^(-1/(b+a))+0.1_'):
            loss_weight = scale * b * torch.exp(-1 / (b + a)) + 0.1
        elif w_func.startswith('e^(-1/(b+0.5a))_'):
            loss_weight = scale * torch.exp(-1 / (b + 0.5 * a))
        elif w_func.startswith('e^(-1/(b+0.3a))_'):
            loss_weight = scale * torch.exp(-1 / (b + 0.3 * a))
        else:
            raise ValueError(f"unknown w_func: {w_func!r}")

        return loss_weight

    def board_lff_wx(self, step, loss_weight, ac_flag, aw_flag, cc_flag, cw_flag):
        """Log the mean weight of bias-aligned and bias-conflicting samples.

        Aligned samples are ``ac_flag | aw_flag``; conflicting samples are
        ``cc_flag | cw_flag``. All flags are boolean masks over ``loss_weight``.
        """
        aligned = aw_flag | ac_flag
        conflicting = cw_flag | cc_flag
        log_dict = {
            "w(x)_mean/align": loss_weight[aligned].mean(),
            "w(x)_mean/conflict": loss_weight[conflicting].mean(),
        }

        if self.args.tensorboard:
            for name, value in log_dict.items():
                self.writer.add_scalar(name, value, step)

        if self.args.wandb:
            wandb.log(log_dict, step=step)

    def board_acc(self, step, key="b", model=None, best_v=None, best_t=None, inference=None, his=False):
        """Evaluate ``model`` on the valid/test splits, log the results, and track the best accuracies.

        Args:
            step: training step, logged as ``"{key}_step"``.
            key: model tag used in metric names and the checkpoint file name.
            model: model to evaluate. It is saved whenever validation accuracy
                matches or beats ``best_v``.
            best_v: best validation accuracies so far. ``None`` or empty means
                there is no best yet.
            best_t: best test accuracies so far. ``None`` or empty means there
                is no best yet.
            inference: if truthy, print the test accuracy and exit the process.
            his: forwarded to ``self.evaluate`` for the test split.

        Returns:
            ``(best_v, best_t)`` updated with this evaluation. The last entry of
            each accuracy vector is the one compared.
        """
        valid_accs, valid_losses, _ = self.evaluate(model, self.valid_loader)
        test_accs, test_losses, test_hist = self.evaluate(model, self.test_loader, his=his)

        if inference:
            # The accuracy vectors appear to hold several entries (they are
            # indexed with [-1] below), so ``.item()`` would raise here.
            print(f'test acc: {test_accs}')
            import sys
            sys.exit(0)

        best = best_v is None or len(best_v) == 0 or valid_accs[-1] >= best_v[-1]
        if best:
            best_v = valid_accs
        if best_t is None or len(best_t) == 0 or test_accs[-1] >= best_t[-1]:
            best_t = test_accs

        log = {
            f"acc/acc_{key}_valid": trans_result(valid_accs, self.keys),
            f"acc/acc_{key}_test": trans_result(test_accs, self.keys),
            # Bests of *this* model, already including this evaluation. The
            # biased model's attributes were logged before, even for key="d".
            f"acc/best_acc_{key}_valid": trans_result(best_v, self.keys),
            f"acc/best_acc_{key}_test": trans_result(best_t, self.keys),
            f"loss/loss_{key}_valid": trans_result(valid_losses, self.keys),
            f"loss/loss_{key}_test": trans_result(test_losses, self.keys),
            f"hist/loss_{key}_test": trans_result(test_hist, self.keys),
            f"{key}_step": step,
        }

        if best:
            p = os.path.join(self.log_dir, f"model_{key}_{self.args.biloss}_best_seed{self.args.seed}.pt")
            torch.save(model, p)

        if self.args.wandb:
            wandb.log(log)

        print(f'valid_{key}: {valid_accs} || test_{key}: {test_accs} ')

        return best_v, best_t

    def board_pretrain_best_acc(self, i, model_b, best_valid_acc_b, step):
        """Track the best validation accuracy of the ``i``-th committee model.

        When the last validation accuracy strictly improves on
        ``best_valid_acc_b[-1]``, a deep copy of ``model_b`` is stored in
        ``self.best_model_b``.

        Returns:
            The (possibly updated) best validation accuracies.
        """
        current_accs, _, _ = self.evaluate(model_b, self.valid_loader)
        print(f'best: {best_valid_acc_b}, curr: {current_accs}')

        if current_accs[-1] > best_valid_acc_b[-1]:
            best_valid_acc_b = current_accs
            # snapshot the weights; model_b keeps training afterwards
            self.best_model_b = copy.deepcopy(model_b)
            print(f'early model {i}th saved...')

        summary = {f"{i}_pretrain_best_valid_acc": trans_result(best_valid_acc_b, self.keys)}

        if self.args.tensorboard:
            for name, value in summary.items():
                self.writer.add_scalar(name, value, step)

        return best_valid_acc_b

    def pretrain_b_ensemble_best(self, args):
        """Train a committee of biased models and return the samples they confidently agree on.

        For each of ``args.num_bias_models`` fresh backbones:

        1. Train for ``args.biased_model_train_iter`` steps on
           ``self.train_loader``, on transformed inputs when ``args.pre == "x'"``.
           The checkpoint with the best validation accuracy is kept in
           ``self.best_model_b``.
        2. Record that checkpoint's softmax probability of the ground-truth
           label for every sample of ``self.pretrain_loader``.

        A sample "exceeds" for a model when that probability is above
        ``args.biased_model_softmax_threshold``. It is returned when at least
        ``args.agreement`` models agree.

        Returns:
            1-D long tensor of sample indices, as yielded by ``self.pretrain_loader``.

        NOTE(review): per-model probabilities are combined position-wise, and
        the indices / align flags of the last model's pass are reused. This
        assumes ``self.pretrain_loader`` yields samples in the same order on
        every pass. The label look-ups at the end also assume a sample's
        index equals its position in that order. Confirm both.
        """
        train_iter = iter(self.train_loader)
        train_num = len(self.train_dataset.dataset)
        epoch, cnt = 0, 0
        label_dict, gt_prob_dict = {}, {}
        threshold = args.biased_model_softmax_threshold

        for i in range(self.args.num_bias_models):
            best_valid_acc_b = [0, 0, 0, 0]
            print(f'{i}th model working ...')
            del self.model_b
            self.best_model_b = None
            self.model_b = get_backbone(self.model, self.num_classes, args=self.args, pretrained=self.args.resnet_pretrained, first_stage=True).to(self.device)
            self.optimizer_b = torch.optim.Adam(self.model_b.parameters(), lr=args.lr, weight_decay=args.weight_decay)

            for step in tqdm(range(self.args.biased_model_train_iter)):
                try:
                    index, data, attr, _ = next(train_iter)
                except StopIteration:
                    # loader exhausted: start a new pass over the training data
                    train_iter = iter(self.train_loader)
                    index, data, attr, _ = next(train_iter)

                data = data.to(self.device)
                attr = attr.to(self.device)
                label = attr[:, args.target_attr_idx]

                if args.pre == "x'":
                    data = self.trans(data)

                logit_b = self.model_b(data)
                loss = self.bias_criterion(logit_b, label).mean()

                self.optimizer_b.zero_grad()
                loss.backward()
                self.optimizer_b.step()

                cnt += len(index)
                if cnt >= train_num:
                    print(f'finished epoch: {epoch}')
                    epoch += 1
                    cnt = len(index)

                if step % args.valid_freq == 0:
                    best_valid_acc_b = self.board_pretrain_best_acc(i, self.model_b, best_valid_acc_b, step)

            if self.best_model_b is None:
                # Validation accuracy never rose above the initial 0. Use the
                # final weights rather than crashing on ``None.eval()``.
                self.best_model_b = copy.deepcopy(self.model_b)

            label_list, index_list, gt_prob_list, align_flag_list = [], [], [], []
            self.best_model_b.eval()

            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                for index, data, attr, _ in self.pretrain_loader:
                    index = index.to(self.device)
                    data = data.to(self.device)
                    attr = attr.to(self.device)
                    label = attr[:, args.target_attr_idx]
                    bias_label = attr[:, args.bias_attr_idx]

                    prob = torch.softmax(self.best_model_b(data), dim=-1)
                    gt_prob = torch.gather(prob, index=label.unsqueeze(1), dim=1).squeeze(1)

                    label_list += label.tolist()
                    index_list += index.tolist()
                    gt_prob_list += gt_prob.tolist()
                    align_flag_list += (label == bias_label).tolist()

            index_list = torch.tensor(index_list, dtype=torch.long)
            label_list = torch.tensor(label_list, dtype=torch.long)
            gt_prob_list = torch.tensor(gt_prob_list)
            align_flag_list = torch.tensor(align_flag_list, dtype=torch.bool)

            confident = gt_prob_list > threshold
            exceed_align = index_list[(confident & align_flag_list).nonzero().squeeze(1)]
            exceed_conflict = index_list[(confident & ~align_flag_list).nonzero().squeeze(1)]
            exceed_mask = index_list[confident.nonzero().squeeze(1)]

            label_dict[i] = label_list
            gt_prob_dict[i] = gt_prob_list

            log_dict = {
                f"{i}_exceed_align": len(exceed_align),
                f"{i}_exceed_conflict": len(exceed_conflict),
                f"{i}_exceed_mask": len(exceed_mask),
            }
            if args.tensorboard:
                for key, value in log_dict.items():
                    self.writer.add_scalar(key, value, step)

        # Per model, mark confident samples (all / bias-aligned / bias-conflicting).
        model_ids = range(self.args.num_bias_models)
        exceed_mask = [(gt_prob_dict[m] > threshold).long() for m in model_ids]
        exceed_mask_align = [((gt_prob_dict[m] > threshold) & align_flag_list).long() for m in model_ids]
        exceed_mask_conflict = [((gt_prob_dict[m] > threshold) & ~align_flag_list).long() for m in model_ids]

        # Count how many committee members find each sample confident.
        mask_sum = torch.stack(exceed_mask).sum(dim=0)
        mask_sum_align = torch.stack(exceed_mask_align).sum(dim=0)
        mask_sum_conflict = torch.stack(exceed_mask_conflict).sum(dim=0)

        # Keep samples on which at least ``args.agreement`` models agree.
        total_exceed_mask = index_list[(mask_sum >= self.args.agreement).nonzero().squeeze(1)]
        total_exceed_align = index_list[(mask_sum_align >= self.args.agreement).nonzero().squeeze(1)]
        total_exceed_conflict = index_list[(mask_sum_conflict >= self.args.agreement).nonzero().squeeze(1)]

        exceed_mask_list = [total_exceed_mask]
        print(f'exceed mask list length: {len(exceed_mask_list)}')

        # Use self.device (not the default CUDA device) so index and source agree.
        labels_on_device = label_dict[0].unsqueeze(1).to(self.device)
        curr_index_label = torch.index_select(labels_on_device, 0, total_exceed_mask.to(self.device))
        curr_align_index_label = torch.index_select(labels_on_device, 0, total_exceed_align.to(self.device))
        curr_conflict_index_label = torch.index_select(labels_on_device, 0, total_exceed_conflict.to(self.device))

        log_dict = {
            f"total_exceed_align": len(total_exceed_align),
            f"total_exceed_conflict": len(total_exceed_conflict),
            f"total_exceed_mask": len(total_exceed_mask),
        }

        for key, value in log_dict.items():
            print(f"* {key}: {value}")
        print(f"* EXCEED DATA COUNT: {Counter(curr_index_label.squeeze(1).tolist())}")
        print(f"* EXCEED DATA (ALIGN) COUNT: {Counter(curr_align_index_label.squeeze(1).tolist())}")
        print(f"* EXCEED DATA (CONFLICT) COUNT: {Counter(curr_conflict_index_label.squeeze(1).tolist())}")

        if args.tensorboard:
            for key, value in log_dict.items():
                self.writer.add_scalar(key, value, step)

        return total_exceed_mask

    def save_models(self,option,debias=False,bias=False):
        """Save the debiased and/or biased model to ``log_dir``.

        Files are named ``model_{d,b}_seed<seed>_<option>.pt``. Models whose
        name contains 'ResNet20' are saved as state dicts; all others are
        pickled whole.
        """
        tag = f"seed{self.args.seed}_" + option
        as_state_dict = 'ResNet20' in self.args.model

        targets = []
        if debias:
            targets.append(('/model_d_', self.model_d))
        if bias:
            targets.append(('/model_b_', self.model_b))

        for prefix, net in targets:
            payload = net.state_dict() if as_state_dict else net
            torch.save(payload, self.log_dir + prefix + tag + '.pt')
        

    def load_models(self,option,debias=False,bias=False):
        """Load checkpoints written by ``save_models`` with the same ``option``.

        Models whose name contains 'ResNet20' are restored from a state dict
        into the existing ``self.model_d`` / ``self.model_b``. Other models are
        unpickled whole. Tensors are mapped onto ``self.device`` while loading.
        """
        option = f"seed{self.args.seed}_" + option
        # Must match save_models' predicate. The former exact '==' test
        # mis-handled state-dict checkpoints of names merely containing 'ResNet20'.
        from_state_dict = 'ResNet20' in self.args.model
        print('=====')
        if debias:
            path = self.log_dir+'/model_d_'+option+'.pt'
            if from_state_dict:
                self.model_d.load_state_dict(torch.load(path, map_location=self.device))
                self.model_d = self.model_d.to(self.device)
            else:
                self.model_d = torch.load(path, map_location=self.device).to(self.device)
            print(f'load debiased model: {path}')
        if bias:
            path = self.log_dir+'/model_b_'+option+'.pt'
            if from_state_dict:
                self.model_b.load_state_dict(torch.load(path, map_location=self.device))
                self.model_b = self.model_b.to(self.device)
            else:
                self.model_b = torch.load(path, map_location=self.device).to(self.device)
            print(f'load biased model: {path}')
        print('=====')

    def train(self, args):
        """Run the DPR pipeline: biased model -> sampling probabilities -> debiased model.

        1. Load the biased checkpoint
           ``model_b_<biloss>_step<b_step>_seed<seed>.pt`` from ``self.log_dir``,
           running ``train_b`` first if it does not exist.
        2. Load, or compute with ``calc_prob``, the per-sample sampling
           probabilities for temperature ``args.tau``.
        3. Train the debiased model with ``train_d`` and save it as
           ``model_d_end_seed<seed>.pt``.
        """
        import sys

        # prepare biased model
        model_name = f"model_b_{args.biloss}_step{args.b_step}_seed{args.seed}.pt"
        model_path = os.path.join(self.log_dir, model_name)
        if os.path.exists(model_path):
            self.model_b = torch.load(model_path).to(self.device)
            print(f"====== loading biased model directly from {model_path} ======")
        else:
            print(f"====== fail in loading existing model, train from scratch ======")
            if args.b_step >= args.b_max_step:
                # train_b only runs steps < b_max_step, so this checkpoint would never appear.
                print(f"====== b_step {args.b_step} larger than b_max_step {args.b_max_step} ======")
                # sys.exit instead of the site-provided exit(); same exit status.
                sys.exit()
            self.train_b(args)
            self.model_b = torch.load(model_path).to(self.device)

        # per-sample sampling probability (cached per tau)
        prob_name = f"model_b_{args.biloss}_step{args.b_step}_seed{args.seed}_sampling-prob_tau{self.args.tau}.pt"
        prob_path = os.path.join(self.log_dir, prob_name)
        if os.path.exists(prob_path):
            saved_data = torch.load(prob_path)
            print(f"====== loading prob directly from {prob_path} ======")
        else:
            print(f"====== fail in loading existing prob, calculating from scratch ======")
            self.calc_prob(args, prob_path)
            saved_data = torch.load(prob_path)

        mag = saved_data['sampling_probability']
        # Normalise the per-sample scores into a sampling distribution. The old
        # round trip through 1/mag and back simplified to exactly this.
        mag_prob = mag / torch.sum(mag)

        # train debiased model
        self.train_d(args, mag_prob)
        p = os.path.join(self.log_dir, f"model_d_end_seed{self.args.seed}.pt")
        torch.save(self.model_d, p)
            
        
    def calc_prob(self, args, prob_path):
        """Score every training sample with the biased model and save the scores to ``prob_path``.

        The biased model's softmax uses temperature ``self.args.tau``. For a
        sample with label ``y``, the score is ``1 - p_b(y|x)``, clamped below at
        1e-8. Scores and labels are reordered by the sample index the loader
        yields. ``p_y_a`` is also saved: entry ``[y, k]`` is the sum of the
        biased model's class-``k`` probability over samples labelled ``y``,
        divided by the number of samples.

        The saved dict has the keys ``'p_y_a'``, ``'labels'`` and
        ``'sampling_probability'``.
        """
        self.model_b.eval()

        dataset = self.train_loader.dataset
        num_samples = len(dataset)
        # Local loader: self.train_loader is no longer replaced as a side effect.
        # Previously later training used a different loader depending on
        # whether the probability cache already existed.
        loader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False,
        )

        idxorder = torch.zeros(num_samples, dtype=torch.long)
        grad_mat = torch.zeros(num_samples)
        # dim = N x C x C (examples, y, a)
        p_y_a = torch.zeros((num_samples, dataset.dataset.n_classes, len(dataset.dataset.bias)))
        labels = None
        start = 0

        with torch.no_grad():
            for index, data, attr, _ in loader:
                data = data.to(self.device)
                label = attr.to(self.device)[:, args.target_attr_idx]
                label_cpu = label.cpu()

                end = start + len(label)
                rows = torch.arange(start, end)
                if labels is None:
                    labels = torch.zeros(num_samples, dtype=label.dtype)

                prob_b = F.softmax(self.model_b(data) / self.args.tau, dim=1)

                # 1 - probability of the ground truth under the biased model
                batch_rows = torch.arange(prob_b.shape[0], device=prob_b.device)
                grad_mat[rows] = (1 - prob_b[batch_rows, label]).cpu()

                # Index the CPU tensor with CPU labels. Recent PyTorch rejects
                # CUDA indices into a CPU tensor.
                p_y_a[rows, label_cpu] = prob_b.cpu()
                labels[rows] = label_cpu
                idxorder[rows] = index.cpu()
                start = end

        p_y_a = torch.mean(p_y_a, dim=0)
        sampling_probability = torch.clamp(grad_mat, min=1e-8)

        # restore dataset-index order
        order = torch.argsort(idxorder)

        torch.save({'p_y_a': p_y_a,
                'labels': labels[order],
                'sampling_probability': sampling_probability[order]},
                prob_path)

        print('==============================')
        print(f'save sampling probability using tau {self.args.tau}')
        print('p_y_a: ', p_y_a)
        print('==============================')

        
    def train_b(self, args):
        print('Training DPR, biased model ...')

        num_updated = 0
        train_iter = iter(self.train_loader)
        train_num = len(self.train_dataset.dataset)

        mask_index = torch.zeros(train_num, 1)
        self.bias_state = torch.zeros(train_num, 1).cuda()
        self.label_index = torch.zeros(train_num).long().cuda()

        epoch, cnt = 0, 0

        #### BiasEnsemble ####
        if args.bias_ensm:
            print("Applying Bias Ensemble for bias amplification ...")
            pseudo_align_flag = self.pretrain_b_ensemble_best(args)
            mask_index[pseudo_align_flag] = 1
        else:
            mask_index[:] = 1

        del self.model_b
        self.model_b = get_backbone(self.model, self.num_classes, args=self.args, pretrained=self.args.resnet_pretrained).to(self.device)

        self.optimizer_b = torch.optim.Adam(
                self.model_b.parameters(),
                lr=args.lr,
                weight_decay=args.weight_decay,
            )

        if args.use_lr_decay:
            self.scheduler_b = optim.lr_scheduler.StepLR(self.optimizer_b, step_size=args.lr_decay_step,gamma=args.lr_gamma)

        # get bias mapping
        bias = self.train_loader.dataset.dataset.bias
        
        # Bias train

        for step in tqdm(range(args.b_max_step)):
            # train main model
            try:
                index, data, attr, _ = next(train_iter)
            except:
                train_iter = iter(self.train_loader)
                index, data, attr, _ = next(train_iter)

            data = data.to(self.device)
            attr = attr.to(self.device)
            index = index.to(self.device)
            label = attr[:, args.target_attr_idx]
            bias_label = attr[:, args.bias_attr_idx]
            
            # patch shuffle
            data_s = self.trans(data)

            # flag_conflict = (label != bias_label)
            bias_state = get_bias_state(label, bias_label, bias)
            self.bias_state[index] = bias_state
            self.label_index[index] = label

            logit_b = self.model_b(data)
            logit_b_s = self.model_b(data_s)

            if args.biloss == "x_":
                loss_b = self.criterion(logit_b, label).detach()
            elif args.biloss == "x'":
                loss_b = self.criterion(logit_b_s, label).detach()

            if np.isnan(loss_b.mean().item()):
                raise NameError('loss_b')
            
            pred = logit_b.data.max(1, keepdim=True)[1].squeeze(1)

            curr_align_flag = torch.index_select(mask_index.to(self.device), 0, index)
            curr_align_flag = (curr_align_flag.squeeze(1) == 1)

            # update the biased classifier only on confident samples agreed by the biased committee
            if args.biloss == "x_":
                loss_b_update = self.bias_criterion(logit_b[curr_align_flag], label[curr_align_flag])
            elif args.biloss == "x'":
                loss_b_update = self.bias_criterion(logit_b_s[curr_align_flag], label[curr_align_flag])

            if np.isnan(loss_b_update.mean().item()):
                raise NameError('loss_b_update')

            loss = loss_b_update.mean()
            # num_updated += loss_weight.mean().item() * data.size(0)

            self.optimizer_b.zero_grad()
            # self.optimizer_d.zero_grad()
            loss.backward()
            self.optimizer_b.step()
            # self.optimizer_d.step()

            if args.use_lr_decay:
                self.scheduler_b.step()

            if args.use_lr_decay and step % args.lr_decay_step == 0:
                print('******* learning rate decay .... ********')
                print(f"self.optimizer_b lr: {self.optimizer_b.param_groups[-1]['lr']}")
                # print(f"self.optimizer_d lr: {self.optimizer_d.param_groups[-1]['lr']}")

            if step % args.log_freq == 0:
                self.board_loss(step, key="b", loss=loss_b_update.mean())
                bias_label = attr[:, args.bias_attr_idx]

                ### used bias labels for logging
                ac_flag = (label == bias_label) & (label == pred)
                aw_flag = (label == bias_label) & (label != pred)
                cc_flag = (label != bias_label) & (label == pred)
                cw_flag = (label != bias_label) & (label != pred)

                ac_flag = ac_flag & curr_align_flag
                aw_flag = aw_flag & curr_align_flag
                cc_flag = cc_flag & curr_align_flag
                cw_flag = cw_flag & curr_align_flag

                # self.board_lff_wx(step, loss_weight, ac_flag, aw_flag, cc_flag, cw_flag)

                # if step > len(train_iter):
                #     self.board_weights(step, his=args.his)
                # self.save_models(f"step{step}", bias=True)
                p = os.path.join(self.log_dir, f"model_b_{self.args.biloss}_step{step}_seed{self.args.seed}.pt")
                torch.save(self.model_b, p)

            if step % args.valid_freq == 0:
                self.best_valid_acc_b, self.best_test_acc_b = self.board_acc(step, key="b",  model=self.model_b, best_v=self.best_valid_acc_b, best_t=self.best_test_acc_b, his=args.his)

            cnt += len(index)
            if cnt == train_num:
                print(f'finished epoch: {epoch}')
                epoch += len(index)
                cnt = 0

    def train_d(self, args, mag_prob):
        """Train the debiased model ``model_d`` on a probability-resampled training set.

        The training dataset is re-wrapped in a DataLoader whose
        ``WeightedRandomSampler`` draws ``len(mag_prob)`` indices (with replacement,
        the sampler's default) weighted by ``mag_prob``. ``model_d`` is
        (re)initialized according to ``args.d_init`` and optimized with
        ``self.debias_criterion`` for ``args.num_steps`` steps.

        Args:
            args: run configuration. Reads ``d_init``, ``lr``, ``weight_decay``,
                ``use_lr_decay``, ``lr_decay_step``, ``lr_gamma``, ``num_steps``,
                ``num_workers``, ``target_attr_idx``, ``bias_attr_idx``,
                ``log_freq``, ``valid_freq``, ``his`` and ``tensorboard``.
            mag_prob: per-sample sampling weights, one per training example.
                Assumed to be ordered like the training dataset; TODO confirm
                against the caller.

        Side effects:
            Replaces ``self.train_loader``, ``self.optimizer_d`` and (with lr
            decay) ``self.scheduler_l``. Replaces ``self.model_d`` when
            ``d_init == "rand"``. Records ``self.bias_state`` /
            ``self.label_index`` for every sample index drawn.

        Raises:
            ValueError: if ``args.d_init`` is neither ``"b"`` nor ``"rand"``.
            NameError: if the training loss becomes NaN.
        """
        print('Training DPR, debiased model ...')

        sampler = WeightedRandomSampler(mag_prob, len(mag_prob))
        self.train_loader = DataLoader(
            self.train_loader.dataset,
            batch_size=self.batch_size,
            shuffle=False,  # sample order comes from the weighted sampler
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=True,
            sampler=sampler,
        )

        train_iter = iter(self.train_loader)
        train_num = len(self.train_dataset.dataset)

        # Per-sample bookkeeping, allocated on the same device as the batches
        # (previously hard-coded to .cuda()).
        self.bias_state = torch.zeros(train_num, 1, device=self.device)
        self.label_index = torch.zeros(train_num, dtype=torch.long, device=self.device)

        epoch, cnt = 0, 0

        # initialize debiased model
        if args.d_init == "b":
            self.model_d.load_state_dict(self.model_b.state_dict())
            print("===== initalize model_d with model_b =====")
        elif args.d_init == "rand":
            del self.model_d  # release the old weights before building a new model
            self.model_d = get_backbone(self.model, self.num_classes, args=self.args, pretrained=self.args.resnet_pretrained).to(self.device)
            print("===== initalize model_d randomly =====")
        else:
            raise ValueError("invalid initialization of model_d")

        self.optimizer_d = torch.optim.Adam(
            self.model_d.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )

        if args.use_lr_decay:
            self.scheduler_l = optim.lr_scheduler.StepLR(self.optimizer_d, step_size=args.lr_decay_step, gamma=args.lr_gamma)

        # bias mapping used to classify each sample's bias state
        bias = self.train_loader.dataset.dataset.bias

        for step in tqdm(range(args.num_steps)):
            try:
                index, data, attr, _ = next(train_iter)
            except StopIteration:
                # sampler exhausted: start a new resampled pass
                train_iter = iter(self.train_loader)
                index, data, attr, _ = next(train_iter)

            data = data.to(self.device)
            attr = attr.to(self.device)
            index = index.to(self.device)
            label = attr[:, args.target_attr_idx]
            bias_label = attr[:, args.bias_attr_idx]

            self.bias_state[index] = get_bias_state(label, bias_label, bias)
            self.label_index[index] = label

            logit_d = self.model_d(data)
            # detached plain loss, used only as a NaN guard
            loss_d = self.criterion(logit_d, label).detach()
            if np.isnan(loss_d.mean().item()):
                raise NameError('loss_d')

            loss_d_update = self.debias_criterion(logit_d, label)
            if np.isnan(loss_d_update.mean().item()):
                raise NameError('loss_d_update')

            loss = loss_d_update.mean()

            self.optimizer_d.zero_grad()
            loss.backward()
            self.optimizer_d.step()

            if args.use_lr_decay:
                self.scheduler_l.step()
                if step % args.lr_decay_step == 0:
                    print('******* learning rate decay .... ********')
                    print(f"self.optimizer_d lr: {self.optimizer_d.param_groups[-1]['lr']}")

            if step % args.log_freq == 0:
                self.board_loss(step, key="d", loss=loss)

            if step % args.valid_freq == 0:
                self.best_valid_acc_d, self.best_test_acc_d = self.board_acc(step, key="d", model=self.model_d, best_v=self.best_valid_acc_d, best_t=self.best_test_acc_d, his=args.his)

                if args.use_lr_decay and args.tensorboard:
                    self.writer.add_scalar("loss/learning rate", self.optimizer_d.param_groups[-1]['lr'], step)

            # Epoch bookkeeping by number of samples consumed. With drop_last=True
            # a pass rarely sums to exactly train_num, so compare with >= and
            # carry the remainder into the next epoch.
            cnt += len(index)
            if cnt >= train_num:
                print(f'finished epoch: {epoch}')
                epoch += 1
                cnt -= train_num


    def test(self, args):
        """Restore the best saved biased/debiased models and report their accuracy.

        Rebuilds ``model_b`` and ``model_d`` using an ``"MLP"`` backbone for
        datasets whose name contains ``'mnist'`` and ``"ResNet18"`` otherwise.
        Loads their ``'state_dict'`` entries from ``best_model_b.th`` /
        ``best_model_d.th`` under ``args.pretrained_path``, then calls
        ``self.board_lff_acc(step=0, inference=True)``.
        """
        backbone = "MLP" if 'mnist' in args.dataset else "ResNet18"
        self.model_b = get_backbone(backbone, self.num_classes, args=self.args, pretrained=self.args.resnet_pretrained).to(self.device)
        self.model_d = get_backbone(backbone, self.num_classes, args=self.args, pretrained=self.args.resnet_pretrained).to(self.device)

        # map_location lets checkpoints saved on one device (e.g. a GPU) be
        # restored on whatever device this run uses (e.g. a CPU-only host).
        ckpt_d = torch.load(os.path.join(args.pretrained_path, 'best_model_d.th'), map_location=self.device)
        self.model_d.load_state_dict(ckpt_d['state_dict'])
        ckpt_b = torch.load(os.path.join(args.pretrained_path, 'best_model_b.th'), map_location=self.device)
        self.model_b.load_state_dict(ckpt_b['state_dict'])
        self.board_lff_acc(step=0, inference=True)
        
    # evaluation with calibrated inference
    def _evaluate(self, model_b, model_d, data_loader, var=False, his=False):
        """Evaluate the calibrated combination of ``model_b`` and ``model_d``.

        Per-batch logits are combined according to ``self.args.infer``:

        - ``"ls_d-ls_b"``: ``logsigmoid(logit_d) - logsigmoid(logit_b)``
        - ``"d-b"``: ``logit_d - logit_b``

        Any other value returns ``-0.01`` placeholders without running the
        models.

        Args:
            model_b: the biased model.
            model_d: the debiased model.
            data_loader: yields ``(data, attr, index)`` batches. ``attr[:, 0]`` is
                used as the target label and ``attr[:, 1]`` as the bias attribute.
                Its dataset must expose ``bias``, and also ``n_classes`` / ``n_s``
                for datasets that use the group-averaged overall metric.
            var: if True, also return per-group loss variances.
            his: if True, build ``wandb.Histogram`` objects of per-sample losses.

        Returns:
            ``(accs, losses[, loss_var], hist)``. Each list has one entry per
            bias-state subgroup, followed by an overall entry. The subgroups are
            states ``0..self.n_classes-1`` for ``bar``/``NICO``, otherwise
            ``-1, 0, 1``; per the original docstring these are
            [bias conflict, bias neutral, bias align, overall]. Empty subgroups
            report ``-0.01``, with ``None`` as the histogram.

        Both models are switched to eval mode for the forward passes and back to
        train mode afterwards, including when an exception is raised.
        """
        infer = self.args.infer
        if infer not in ("ls_d-ls_b", "d-b"):
            # Unknown inference procedure. Return placeholders before touching
            # the models, so they are never left stuck in eval mode.
            dft = [-0.01, -0.01, -0.01, -0.01]
            result = (dft, dft)
            if var:
                result += (dft,)
            result += (dft,)
            return result

        pred_ls, loss_ls, label_ls, s_ls = [], [], [], []
        log_sig = nn.LogSigmoid()

        model_b.eval()
        model_d.eval()
        try:
            with torch.no_grad():
                for data, attr, index in tqdm(data_loader, leave=False):
                    label = attr[:, 0].to(self.device)
                    s = attr[:, 1].to(self.device)
                    data = data.to(self.device)

                    logit_b = model_b(data)
                    logit_d = model_d(data)
                    if infer == "ls_d-ls_b":
                        logit = log_sig(logit_d) - log_sig(logit_b)
                    else:  # "d-b"
                        logit = logit_d - logit_b

                    # assumes self.criterion returns per-sample losses
                    # (reduction='none'), since they are concatenated and
                    # masked below
                    loss_ls.append(self.criterion(logit, label))
                    pred_ls.append(logit.max(dim=1)[1])
                    label_ls.append(label)
                    s_ls.append(s)
        finally:
            model_b.train()
            model_d.train()

        pred = torch.cat(pred_ls)
        loss = torch.cat(loss_ls)
        label = torch.cat(label_ls)
        s = torch.cat(s_ls)
        correct = pred == label

        # partition according to bias state and calc accs
        bias_state = get_bias_state(label, s, bias=data_loader.dataset.bias).squeeze(1) \
            if False else get_bias_state(label, s, data_loader.dataset.bias).squeeze(1)

        def subgroup_stats(mask):
            """Return (acc, mean loss, loss variance, histogram) for the selected samples.

            Returns the -0.01 placeholders (and a None histogram) when the mask
            selects nothing.
            """
            group_correct = correct[mask]
            n = group_correct.shape[0]
            if n == 0:
                return -0.01, -0.01, -0.01, None
            group_loss = loss[mask]
            h = wandb.Histogram(group_loss.cpu().numpy()) if his else None
            return (
                (group_correct.sum() / float(n)).item(),
                group_loss.mean().item(),
                group_loss.var().item(),
                h,
            )

        accs, losses, loss_var, hist = [], [], [], []
        if self.dataset in ['bar', 'NICO']:  # class-wise sub group
            states = range(self.n_classes)
        else:
            states = [-1, 0, 1]
        for state in states:
            g_acc, g_loss, g_var, g_hist = subgroup_stats(bias_state == state)
            accs.append(g_acc)
            losses.append(g_loss)
            loss_var.append(g_var)
            hist.append(g_hist)

        # calc overall acc
        dataset = self.args.dataset.split('-')[0]
        if dataset in ['cmnist', 'scmnist', "bffhq", "bar", "dogs_and_cats", "corruptedCifar10"]:
            # plain sample-averaged metrics
            overall_acc = correct.sum() / float(correct.shape[0])
            overall_loss = loss.mean()
            overall_loss_var = loss.var()
        else:
            # average over the (label, bias attribute) groups that actually occur
            n_classes = data_loader.dataset.n_classes
            n_s = data_loader.dataset.n_s
            total_acc = torch.zeros((n_classes, n_s))
            total_l = torch.zeros((n_classes, n_s))
            total_lv = torch.zeros((n_classes, n_s))
            n_exist_group = 0
            for i in range(n_classes):
                mask_i = label == i
                for j in range(n_s):
                    mask = mask_i & (s == j)  # group with label=i and attr=j
                    group_correct = correct[mask]
                    total = group_correct.shape[0]
                    if total == 0:  # skip groups that do not exist
                        continue
                    group_loss = loss[mask]
                    total_acc[i][j] = group_correct.sum() / float(total)
                    total_l[i][j] = group_loss.mean()
                    total_lv[i][j] = group_loss.var()
                    n_exist_group += 1
            overall_acc = total_acc.sum() / n_exist_group
            overall_loss = total_l.sum() / n_exist_group
            overall_loss_var = total_lv.sum() / n_exist_group

        accs.append(overall_acc.item())
        losses.append(overall_loss.item())
        loss_var.append(overall_loss_var.item())
        hist.append(wandb.Histogram(loss.cpu().numpy()) if his else None)

        result = (accs, losses)
        if var:
            result += (loss_var,)
        result += (hist,)
        return result

# class ShufflePatches(object):
#   def __init__(self, patch_size):
#     self.ps = patch_size

#   def __call__(self, x):
#     # divide the batch of images into non-overlapping patches
#     u = nnf.unfold(x, kernel_size=self.ps, stride=self.ps, padding=0)
#     # permute the patches of each image in the batch
#     pu = torch.cat([b_[:, torch.randperm(b_.shape[-1])][None,...] for b_ in u], dim=0)
#     # fold the permuted patches back together
#     f = nnf.fold(pu, x.shape[-2:], kernel_size=self.ps, stride=self.ps, padding=0)
#     return f