from collections import Counter
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import os
import torch.optim as optim

from data.util import get_dataset, IdxDataset
from module.loss import GeneralizedCELoss
from module.util import get_model
from module.util import get_backbone
from util import *

import warnings
warnings.filterwarnings(action='ignore')
import copy
import wandb

from learners.base_learner import *

class HardLearner(Learner):
    def __init__(self, args):
        """Run the base learner set-up, then install a per-sample criterion.

        Args:
            args: run configuration forwarded unchanged to the base class.
        """
        super().__init__(args)
        # reduction='none' keeps one loss value per sample instead of a batch mean.
        per_sample_ce = nn.CrossEntropyLoss(reduction='none')
        self.criterion = per_sample_ce
        print(f'self.criterion: {self.criterion}')

    def board_lff_loss(self, step, loss_b, loss_d):
        """Report the two training losses to the enabled logging backends.

        Args:
            step: global step the values are recorded at.
            loss_b: value logged under ``loss/loss_b_train``.
            loss_d: value logged under ``loss/loss_d_train``.
        """
        train_losses = {
            "loss/loss_b_train": loss_b,
            "loss/loss_d_train": loss_d,
        }

        # Each backend is optional and switched on by its own flag in args.
        if self.args.wandb:
            wandb.log(train_losses, step=step)

        if self.args.tensorboard:
            for tag, value in train_losses.items():
                self.writer.add_scalar(tag, value, step)

    def board_weights(self, step):
        """Log mean sample weight and mean smoothed losses per bias-state subgroup.

        Statistics are gathered for the samples whose ``self.bias_state`` entry
        equals -1, 0 and 1, and finally for all samples together. A subgroup
        without any sample contributes the placeholder -0.1 to every statistic.

        Args:
            step: global step the values are recorded at.
        """
        mean_weights, mean_losses_b, mean_losses_d = [], [], []

        for state in (-1, 0, 1, None):
            if state is None:
                # None stands for "no filtering": select every sample.
                selector = torch.ones_like(self.label_index, dtype=torch.bool).to(self.device)
            else:
                selector = self.bias_state.squeeze(1) == state

            group_label = self.label_index[selector]
            if group_label.shape[0] == 0:
                # Empty subgroup: record the placeholder instead of a mean.
                for bucket in (mean_weights, mean_losses_b, mean_losses_d):
                    bucket.append(-0.1)
                continue

            # Detached copies of the EMA-smoothed losses. Their means must be
            # taken before calc_weights, which normalizes its inputs in place.
            group_loss_b = self.sample_loss_ema_b.parameter[selector].clone().detach()
            group_loss_d = self.sample_loss_ema_d.parameter[selector].clone().detach()
            mean_losses_b.append(group_loss_b.mean())
            mean_losses_d.append(group_loss_d.mean())

            group_weight = self.calc_weights(group_loss_b, group_loss_d, group_label)
            mean_weights.append(group_weight.mean())

        log_dict = {
            "w(x)_mean": trans_result(mean_weights),
            "loss/loss_b": trans_result(mean_losses_b),
            "loss/loss_d": trans_result(mean_losses_d),
        }

        if self.args.tensorboard:
            for tag, value in log_dict.items():
                self.writer.add_scalar(tag, value, step)

        if self.args.wandb:
            wandb.log(log_dict, step=step)

    def calc_weights(self, loss_b, loss_d, label):
        """Turn the two smoothed per-sample losses into per-sample weights.

        Each loss is first divided, class by class, by a statistic of its EMA
        tracker: ``max_loss(c)`` when ``args.w_func`` contains "max", otherwise
        ``mean_loss(c)`` when it contains "mean". The weight is then
        ``scale * loss_b / (loss_b + loss_d + 1e-8)``, where ``scale`` is 2 if
        ``args.w_func`` contains "s2" and 1 otherwise.

        Note:
            ``loss_b`` and ``loss_d`` are normalized in place; pass clones when
            the raw values are still needed afterwards.

        Args:
            loss_b: per-sample losses taken from ``self.sample_loss_ema_b``.
            loss_d: per-sample losses taken from ``self.sample_loss_ema_d``.
            label: class label of every sample, aligned with the losses.

        Returns:
            Tensor of per-sample weights with the same shape as ``loss_b``.

        Raises:
            NameError: if the mean of ``loss_b`` or ``loss_d`` is NaN.
            ValueError: if ``args.w_func`` names neither "max" nor "mean".
        """
        if np.isnan(loss_b.mean().item()):
            raise NameError('loss_b_ema')
        if np.isnan(loss_d.mean().item()):
            raise NameError('loss_d_ema')

        # The normalizer does not depend on the class, so resolve it once and
        # fail with a clear message instead of hitting an unbound local below.
        w_func = self.args.w_func
        if "max" in w_func:
            class_norm_b = self.sample_loss_ema_b.max_loss
            class_norm_d = self.sample_loss_ema_d.max_loss
        elif "mean" in w_func:
            class_norm_b = self.sample_loss_ema_b.mean_loss
            class_norm_d = self.sample_loss_ema_d.mean_loss
        else:
            raise ValueError(
                f"args.w_func must contain 'max' or 'mean', got {w_func!r}"
            )

        label_cpu = label.cpu()

        for c in range(self.num_classes):
            class_index = np.where(label_cpu == c)[0]
            # Only the loss_b normalizer carries the 1e-8 guard; the loss_d
            # normalizer is used as returned by the tracker.
            norm_loss_b = class_norm_b(c) + 1e-8
            norm_loss_d = class_norm_d(c)
            loss_b[class_index] /= norm_loss_b
            loss_d[class_index] /= norm_loss_d

        scale = 2 if 's2' in w_func else 1

        return scale * loss_b / (loss_b + loss_d + 1e-8)

    def board_lff_wx(self, step, loss_weight, ac_flag, aw_flag, cc_flag, cw_flag):
        """Log the mean sample weight of two sample groups.

        Args:
            step: global step the values are recorded at.
            loss_weight: per-sample weights, indexed by the masks below.
            ac_flag: mask OR-ed with ``aw_flag`` to form the "align" group.
            aw_flag: mask OR-ed with ``ac_flag`` to form the "align" group.
            cc_flag: mask OR-ed with ``cw_flag`` to form the "conflict" group.
            cw_flag: mask OR-ed with ``cc_flag`` to form the "conflict" group.
        """
        align_group = aw_flag | ac_flag
        conflict_group = cw_flag | cc_flag

        log_dict = {
            "w(x)_mean/align": loss_weight[align_group].mean(),
            "w(x)_mean/conflict": loss_weight[conflict_group].mean(),
        }

        if self.args.tensorboard:
            for tag, value in log_dict.items():
                self.writer.add_scalar(tag, value, step)

        if self.args.wandb:
            wandb.log(log_dict, step=step)

    def board_lff_acc(self, step, inference=None):
        """Evaluate both models on the validation and test loaders and log the results.

        The best-so-far records of each model are refreshed whenever the last
        entry of the new accuracy is at least the last entry of the stored one;
        ``self.save_best(step)`` is called only when that happens for the test
        accuracy of ``model_d``.

        Args:
            step: global step the values are recorded at.
            inference: when truthy, print the test accuracy of ``model_d`` and
                terminate the process with ``sys.exit(0)`` right after the
                evaluation, without updating records or logging.
        """
        valid_accs_b, valid_losses_b = self.evaluate(self.model_b, self.valid_loader)
        test_accs_b, test_losses_b = self.evaluate(self.model_b, self.test_loader)

        valid_accs_d, valid_losses_d = self.evaluate(self.model_d, self.valid_loader)
        test_accs_d, test_losses_d = self.evaluate(self.model_d, self.test_loader)

        if inference:
            print(f'test acc: {test_accs_d.item()}')
            import sys
            sys.exit(0)

        def not_worse(current, best):
            # Records are compared on the last entry of the accuracy result.
            return current[-1] >= best[-1]

        if not_worse(valid_accs_b, self.best_valid_acc_b):
            self.best_valid_acc_b, self.best_valid_loss_b = valid_accs_b, valid_losses_b

        if not_worse(test_accs_b, self.best_test_acc_b):
            self.best_test_acc_b, self.best_test_loss_b = test_accs_b, test_losses_b

        if not_worse(valid_accs_d, self.best_valid_acc_d):
            self.best_valid_acc_d, self.best_valid_loss_d = valid_accs_d, valid_losses_d

        if not_worse(test_accs_d, self.best_test_acc_d):
            self.best_test_acc_d, self.best_test_loss_d = test_accs_d, test_losses_d
            self.save_best(step)

        # Collected after the records were refreshed so the "best" entries are current.
        acc_report = {
            "acc/acc_b_valid": valid_accs_b,
            "acc/acc_b_test": test_accs_b,
            "acc/best_acc_b_valid": self.best_valid_acc_b,
            "acc/best_acc_b_test": self.best_test_acc_b,
            "acc/acc_d_valid": valid_accs_d,
            "acc/acc_d_test": test_accs_d,
            "acc/best_acc_d_valid": self.best_valid_acc_d,
            "acc/best_acc_d_test": self.best_test_acc_d,
        }
        loss_report = {
            "loss/loss_b_valid": valid_losses_b,
            "loss/loss_b_test": test_losses_b,
            "loss/best_loss_b_valid": self.best_valid_loss_b,
            "loss/best_loss_b_test": self.best_test_loss_b,
            "loss/loss_d_valid": valid_losses_d,
            "loss/loss_d_test": test_losses_d,
            "loss/best_loss_d_valid": self.best_valid_loss_d,
            "loss/best_loss_d_test": self.best_test_loss_d,
        }

        # Tensorboard receives the raw accuracies only; wandb receives accuracies
        # and losses, each passed through trans_result.
        if self.args.tensorboard:
            for tag, value in acc_report.items():
                self.writer.add_scalar(tag, value, step)

        if self.args.wandb:
            wandb_payload = {
                tag: trans_result(value)
                for tag, value in {**acc_report, **loss_report}.items()
            }
            wandb.log(wandb_payload, step=step)

        print(f'valid_b: {valid_accs_b} || test_b: {test_accs_b} ')
        print(f'valid_d: {valid_accs_d} || test_d: {test_accs_d} ')

    def board_pretrain_best_acc(self, i, model_b, best_valid_acc_b, step):
        """Validate a biased model during pretraining and snapshot it on improvement.

        Args:
            i: index of the biased model being pretrained (prefix of the log key).
            model_b: model evaluated on ``self.valid_loader``.
            best_valid_acc_b: best validation accuracy seen so far for this
                model. May be a plain number (a caller's starting value such as
                0) or a result previously returned by this method.
            step: step the value is recorded at.

        Returns:
            The new accuracy result when it improved on ``best_valid_acc_b``,
            otherwise ``best_valid_acc_b`` unchanged.

        Side effects:
            On improvement a deep copy of ``model_b`` is stored in
            ``self.best_model_b``.
        """
        def overall(acc):
            # Accuracy results are compared on their last entry. A plain number
            # has no entries, so it is compared as-is instead of failing on
            # acc[-1] (TypeError for int/float, IndexError for 0-dim arrays).
            try:
                return acc[-1]
            except (TypeError, IndexError):
                return acc

        # check label network
        valid_accs_b, _ = self.evaluate(model_b, self.valid_loader)

        print(f'best: {best_valid_acc_b}, curr: {valid_accs_b}')

        if overall(valid_accs_b) > overall(best_valid_acc_b):
            best_valid_acc_b = valid_accs_b

            ######### copy parameters #########
            self.best_model_b = copy.deepcopy(model_b)
            print(f'early model {i}th saved...')

        log_dict = {
            f"{i}_pretrain_best_valid_acc": trans_result(best_valid_acc_b),
        }

        if self.args.tensorboard:
            for key, value in log_dict.items():
                self.writer.add_scalar(key, value, step)

        if self.args.wandb:
            wandb.log(log_dict, step=step)

        return best_valid_acc_b

    def pretrain_b_ensemble_best(self, args):
        """Pretrain several biased models and return the samples they agree on.

        For each of ``args.num_bias_models`` rounds a fresh ``self.model_b`` is
        trained for ``args.biased_model_train_iter`` steps with
        ``self.bias_criterion``. The snapshot stored in ``self.best_model_b`` by
        ``board_pretrain_best_acc`` is then run over ``self.pretrain_loader``,
        and a sample counts as confident for that model when the softmax
        probability of its ground-truth class exceeds
        ``args.biased_model_softmax_threshold``. A sample is finally selected
        when at least ``args.agreement`` models found it confident.

        Note:
            Per-model results are combined position by position, and labels are
            looked up by using the returned sample indices as positions. This
            assumes ``self.pretrain_loader`` yields the samples in the same,
            index-ordered sequence on every pass — TODO confirm it is not
            shuffled.

        Args:
            args: run configuration; reads ``lr``, ``weight_decay``,
                ``target_attr_idx``, ``bias_attr_idx``, ``valid_freq``,
                ``biased_model_softmax_threshold`` and ``tensorboard``.

        Returns:
            1-D tensor with the indices of the selected samples.
        """
        train_iter = iter(self.train_loader)
        train_num = len(self.train_dataset.dataset)
        epoch, cnt = 0, 0
        label_dict, gt_prob_dict = {}, {}
        threshold = args.biased_model_softmax_threshold

        for i in range(self.args.num_bias_models):
            best_valid_acc_b = 0
            print(f'{i}th model working ...')
            del self.model_b
            self.best_model_b = None
            self.model_b = get_backbone(self.model, self.num_classes, args=self.args, pretrained=self.args.resnet_pretrained, first_stage=True).to(self.device)
            self.optimizer_b = torch.optim.Adam(self.model_b.parameters(), lr=args.lr, weight_decay=args.weight_decay)

            # ---- train this biased model, validating every args.valid_freq steps ----
            for step in tqdm(range(self.args.biased_model_train_iter)):
                try:
                    index, data, attr, _ = next(train_iter)
                except StopIteration:
                    # Loader exhausted: start another pass over the training set.
                    train_iter = iter(self.train_loader)
                    index, data, attr, _ = next(train_iter)

                data = data.to(self.device)
                attr = attr.to(self.device)
                label = attr[:, args.target_attr_idx]

                logit_b = self.model_b(data)
                loss = self.bias_criterion(logit_b, label).mean()

                self.optimizer_b.zero_grad()
                loss.backward()
                self.optimizer_b.step()

                cnt += len(index)
                if cnt >= train_num:
                    print(f'finished epoch: {epoch}')
                    epoch += 1
                    cnt = len(index)

                if step % args.valid_freq == 0:
                    best_valid_acc_b = self.board_pretrain_best_acc(i, self.model_b, best_valid_acc_b, step)

            # ---- ground-truth confidence of the best snapshot on the pretrain loader ----
            label_list, index_list, gt_prob_list, align_flag_list = [], [], [], []
            self.best_model_b.eval()

            # Pure inference: no autograd graph is needed for the probabilities.
            with torch.no_grad():
                for index, data, attr, _ in self.pretrain_loader:
                    data = data.to(self.device)
                    attr = attr.to(self.device)
                    label = attr[:, args.target_attr_idx]
                    bias_label = attr[:, args.bias_attr_idx]

                    prob = torch.softmax(self.best_model_b(data), dim=-1)
                    gt_prob = torch.gather(prob, index=label.unsqueeze(1), dim=1).squeeze(1)

                    label_list += label.tolist()
                    index_list += index.tolist()
                    gt_prob_list += gt_prob.tolist()
                    align_flag_list += (label == bias_label).tolist()

            index_list = torch.tensor(index_list)
            label_list = torch.tensor(label_list)
            gt_prob_list = torch.tensor(gt_prob_list)
            align_flag_list = torch.tensor(align_flag_list, dtype=torch.bool)

            # Samples above the threshold, split by whether label equals bias label.
            confident = gt_prob_list > threshold
            exceed_align = index_list[confident & align_flag_list]
            exceed_conflict = index_list[confident & ~align_flag_list]
            exceed_mask = index_list[confident]

            label_dict[i] = label_list
            gt_prob_dict[i] = gt_prob_list

            log_dict = {
                f"{i}_exceed_align": len(exceed_align),
                f"{i}_exceed_conflict": len(exceed_conflict),
                f"{i}_exceed_mask": len(exceed_mask),
            }
            if args.tensorboard:
                for key, value in log_dict.items():
                    self.writer.add_scalar(key, value, step)

        # ---- aggregate the votes of all biased models ----
        # One row per model, one column per sample; index_list and
        # align_flag_list are the ones collected in the last round.
        votes = torch.stack([gt_prob_dict[m] > threshold for m in range(self.args.num_bias_models)])

        # Number of models that consider each sample confident.
        mask_sum = votes.long().sum(dim=0)
        mask_sum_align = (votes & align_flag_list).long().sum(dim=0)
        mask_sum_conflict = (votes & ~align_flag_list).long().sum(dim=0)

        # A sample is kept when at least args.agreement models voted for it.
        total_exceed_mask = index_list[mask_sum >= self.args.agreement]
        total_exceed_align = index_list[mask_sum_align >= self.args.agreement]
        total_exceed_conflict = index_list[mask_sum_conflict >= self.args.agreement]

        exceed_mask_list = [total_exceed_mask]

        print(f'exceed mask list length: {len(exceed_mask_list)}')

        # Look the labels up on self.device rather than a hard-coded .cuda(), so
        # both operands of index_select always share one device.
        labels_first_round = label_dict[0].unsqueeze(1).to(self.device)
        curr_index_label = torch.index_select(labels_first_round, 0, total_exceed_mask.long().to(self.device))
        curr_align_index_label = torch.index_select(labels_first_round, 0, total_exceed_align.long().to(self.device))
        curr_conflict_index_label = torch.index_select(labels_first_round, 0, total_exceed_conflict.long().to(self.device))

        log_dict = {
            f"total_exceed_align": len(total_exceed_align),
            f"total_exceed_conflict": len(total_exceed_conflict),
            f"total_exceed_mask": len(total_exceed_mask),
        }

        for key, value in log_dict.items():
            print(f"* {key}: {value}")
        print(f"* EXCEED DATA COUNT: {Counter(curr_index_label.squeeze(1).tolist())}")
        print(f"* EXCEED DATA (ALIGN) COUNT: {Counter(curr_align_index_label.squeeze(1).tolist())}")
        print(f"* EXCEED DATA (CONFLICT) COUNT: {Counter(curr_conflict_index_label.squeeze(1).tolist())}")

        if args.tensorboard:
            for key, value in log_dict.items():
                self.writer.add_scalar(key, value, step)

        return total_exceed_mask

    def train(self, args):
        """Jointly train ``model_b`` and ``model_d`` with loss-based re-weighting.

        On every step the per-sample cross-entropy of both models is folded
        into their EMA trackers, ``calc_weights`` turns the smoothed losses
        into sample weights, ``model_d`` is trained on the weighted
        cross-entropy and ``model_b`` is trained with ``self.bias_criterion``
        on the selected samples only. When ``args.bias_ensm`` is set, the
        selection comes from ``pretrain_b_ensemble_best``; otherwise every
        sample is selected.

        Args:
            args: run configuration; reads ``bias_ensm``, ``lr``,
                ``weight_decay``, ``use_lr_decay``, ``lr_decay_step``,
                ``lr_gamma``, ``num_steps``, ``target_attr_idx``,
                ``bias_attr_idx``, ``log_freq``, ``valid_freq`` and
                ``tensorboard``.

        Raises:
            NameError: named after the offending quantity when the mean of a
                loss or of the sample weights is NaN.
        """
        print('Training LfF ...')

        def raise_if_nan(value, tag):
            # NaN guard; the exception type and tag are what this trainer raises.
            if np.isnan(value.mean().item()):
                raise NameError(tag)

        train_iter = iter(self.train_loader)
        train_num = len(self.train_dataset.dataset)

        mask_index = torch.zeros(train_num, 1)
        # Allocated on self.device (not a hard-coded .cuda()) so they share the
        # device of the batch tensors used to index and fill them.
        self.bias_state = torch.zeros(train_num, 1).to(self.device)
        self.label_index = torch.zeros(train_num).long().to(self.device)

        epoch, cnt = 0, 0

        #### BiasEnsemble ####
        if args.bias_ensm:
            print("Applying Bias Ensemble for bias amplification ...")
            pseudo_align_flag = self.pretrain_b_ensemble_best(args)
            mask_index[pseudo_align_flag] = 1
        else:
            mask_index[:] = 1

        # The selection mask no longer changes; move it to the device once
        # instead of on every training step.
        mask_index = mask_index.to(self.device)

        del self.model_b
        self.model_b = get_backbone(self.model, self.num_classes, args=self.args, pretrained=self.args.resnet_pretrained).to(self.device)

        self.optimizer_b = torch.optim.Adam(
                self.model_b.parameters(),
                lr=args.lr,
                weight_decay=args.weight_decay,
            )

        if args.use_lr_decay:
            self.scheduler_b = optim.lr_scheduler.StepLR(self.optimizer_b, step_size=args.lr_decay_step,gamma=args.lr_gamma)
            self.scheduler_l = optim.lr_scheduler.StepLR(self.optimizer_d, step_size=args.lr_decay_step,gamma=args.lr_gamma)

        # get bias mapping
        bias = self.train_loader.dataset.dataset.bias

        for step in tqdm(range(args.num_steps)):
            # train main model
            try:
                index, data, attr, _ = next(train_iter)
            except StopIteration:
                # Loader exhausted: start another pass over the training set.
                train_iter = iter(self.train_loader)
                index, data, attr, _ = next(train_iter)

            data = data.to(self.device)
            attr = attr.to(self.device)
            index = index.to(self.device)
            label = attr[:, args.target_attr_idx]
            bias_label = attr[:, args.bias_attr_idx]

            # Remember bias state and label of every sample seen so far;
            # board_weights reads both buffers.
            self.bias_state[index] = get_bias_state(label, bias_label, bias)
            self.label_index[index] = label

            logit_b = self.model_b(data)
            logit_d = self.model_d(data)

            loss_b = self.criterion(logit_b, label).detach()
            loss_d = self.criterion(logit_d, label).detach()
            raise_if_nan(loss_b, 'loss_b')
            raise_if_nan(loss_d, 'loss_d')

            # EMA sample loss: smoothly update the loss of each sample
            self.sample_loss_ema_b.update(loss_b, index)
            self.sample_loss_ema_d.update(loss_d, index)

            # Use the smoothed losses (as detached copies, since calc_weights
            # normalizes its inputs in place) to derive the sample weights.
            loss_b = self.sample_loss_ema_b.parameter[index].clone().detach()
            loss_d = self.sample_loss_ema_d.parameter[index].clone().detach()
            raise_if_nan(loss_b, 'loss_b_ema')
            raise_if_nan(loss_d, 'loss_d_ema')

            loss_weight = self.calc_weights(loss_b, loss_d, label)
            raise_if_nan(loss_weight, 'loss_weight')

            curr_align_flag = torch.index_select(mask_index, 0, index)
            curr_align_flag = (curr_align_flag.squeeze(1) == 1)

            # update the biased classifier only on confident samples agreed by the biased committee
            # NOTE(review): when no sample of the batch is selected the criterion
            # receives empty tensors; presumably the mean below is then NaN and
            # trips the guard — confirm whether such batches should be skipped.
            loss_b_update = self.bias_criterion(logit_b[curr_align_flag], label[curr_align_flag])
            loss_d_update = self.criterion(logit_d, label) * loss_weight.to(self.device)
            raise_if_nan(loss_b_update, 'loss_b_update')
            raise_if_nan(loss_d_update, 'loss_d_update')

            loss = loss_b_update.mean() + loss_d_update.mean()

            self.optimizer_b.zero_grad()
            self.optimizer_d.zero_grad()
            loss.backward()
            self.optimizer_b.step()
            self.optimizer_d.step()

            if args.use_lr_decay:
                self.scheduler_b.step()
                self.scheduler_l.step()

            if args.use_lr_decay and step % args.lr_decay_step == 0:
                print('******* learning rate decay .... ********')
                print(f"self.optimizer_b lr: {self.optimizer_b.param_groups[-1]['lr']}")
                print(f"self.optimizer_d lr: {self.optimizer_d.param_groups[-1]['lr']}")

            if step % args.log_freq == 0:
                self.board_lff_loss(step, loss_b_update.mean(), loss_d_update.mean())

                # Subgroup statistics are logged only once the step count
                # exceeds the length of the loader iterator.
                if step > len(train_iter):
                    self.board_weights(step)

            if step % args.valid_freq == 0:
                self.board_lff_acc(step)

                if args.use_lr_decay and args.tensorboard:
                    self.writer.add_scalar(f"loss/learning rate", self.optimizer_d.param_groups[-1]['lr'], step)

            cnt += len(index)
            if cnt == train_num:
                print(f'finished epoch: {epoch}')
                # One full pass over the data is one epoch, whatever the batch size.
                epoch += 1
                cnt = 0

    def test(self, args):
        """Rebuild both models, load their best checkpoints and report test accuracy.

        The backbone is "MLP" when ``args.dataset`` contains "mnist" and
        "ResNet18" otherwise. Weights are read from ``best_model_d.th`` and
        ``best_model_b.th`` inside ``args.pretrained_path``.

        Args:
            args: run configuration; reads ``dataset`` and ``pretrained_path``.

        Note:
            Ends by calling ``board_lff_acc(step=0, inference=True)``.
        """
        backbone = "MLP" if 'mnist' in args.dataset else "ResNet18"
        self.model_b = get_backbone(backbone, self.num_classes, args=self.args, pretrained=self.args.resnet_pretrained).to(self.device)
        self.model_d = get_backbone(backbone, self.num_classes, args=self.args, pretrained=self.args.resnet_pretrained).to(self.device)

        checkpoints = (
            (self.model_d, 'best_model_d.th'),
            (self.model_b, 'best_model_b.th'),
        )
        for model, filename in checkpoints:
            # map_location remaps the stored tensors onto self.device, so a
            # checkpoint saved on another device (e.g. a different GPU, or a GPU
            # when running on CPU) still loads.
            checkpoint = torch.load(os.path.join(args.pretrained_path, filename), map_location=self.device)
            model.load_state_dict(checkpoint['state_dict'])

        self.board_lff_acc(step=0, inference=True)
