from collections import Counter
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import os
import torch.optim as optim

from data.util import get_dataset, IdxDataset, data2batch_size, data2model, data2preprocess
from module.loss import GeneralizedCELoss
from module.util import get_model
from module.util import get_backbone
from util import *

import warnings
warnings.filterwarnings(action='ignore')
import copy
import wandb



class Learner(object):
    """Vanilla (plain cross-entropy) training baseline with group-wise evaluation.

    Sets up train/valid/test data and loaders, two backbones (``model_b`` and
    ``model_d``) with Adam optimizers, the loss criteria and per-sample loss
    EMAs. :meth:`train` only updates ``model_b``; :meth:`save_best` checkpoints
    both models.
    """

    def __init__(self, args):
        """Build logging, data, models, optimizers and evaluation bookkeeping.

        Args:
            args: experiment configuration namespace (dataset, paths, optimizer
                and loss hyper-parameters, logging flags, device, ...).
        """
        self.args = args

        if args.wandb:
            wandb.init(project=args.exp, config=args)

        run_name = args.exp
        if args.tensorboard:
            from tensorboardX import SummaryWriter
            self.writer = SummaryWriter(f'result/{args.tensorboard_dir}/{run_name}')

        # Dataset family: the part before the first '-' (the whole name if there is none).
        dataset = args.dataset.split('-')[0]

        self.model = data2model[dataset]
        if "mnist" in args.dataset:
            # MNIST variants use ``args.model`` instead of the per-dataset default.
            self.model = args.model
        self.batch_size = data2batch_size[dataset]

        print(f'model: {self.model} || dataset: {args.dataset}')
        print(f'working with experiment: {args.exp}...')
        self.device = torch.device(args.device)

        print(self.args)

        # logging directories (all created up front)
        self.log_dir = os.path.join(args.log_dir, args.dataset, args.exp)
        self.summary_dir = os.path.join(args.log_dir, args.dataset, args.tensorboard_dir, args.exp)
        self.result_dir = os.path.join(self.log_dir, "result")
        for directory in (self.log_dir, self.summary_dir, self.result_dir):
            os.makedirs(directory, exist_ok=True)

        def make_split(dataset_split, transform_split):
            # All splits share the same dataset options; only the split names differ.
            return get_dataset(
                args.dataset,
                data_dir=args.data_dir,
                dataset_split=dataset_split,
                transform_split=transform_split,
                percent=args.percent,
                use_preprocess=data2preprocess[dataset],
                args=args,
            )

        self.train_dataset = make_split("train", "train")
        self.valid_dataset = make_split("valid", "valid")
        # The test split is loaded with the "valid" transform split.
        self.test_dataset = make_split("test", "valid")

        train_target_attr = torch.LongTensor(self.train_dataset.y)
        self.num_classes = self.train_dataset.n_classes

        # IdxDataset wraps the training set so batches also carry sample indices.
        self.train_dataset = IdxDataset(self.train_dataset)

        # make loader
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=True,
        )

        self.pretrain_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False,
        )

        self.valid_loader = DataLoader(
            self.valid_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
        )

        self.test_loader = DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
        )

        # define model and optimizer
        self.model_b = get_backbone(self.model, self.num_classes, args=self.args, pretrained=self.args.resnet_pretrained).to(self.device)
        self.model_d = get_backbone(self.model, self.num_classes, args=self.args, pretrained=self.args.resnet_pretrained).to(self.device)

        self.optimizer_b = torch.optim.Adam(
            self.model_b.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )

        self.optimizer_d = torch.optim.Adam(
            self.model_d.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )

        # define loss (per-sample, reduction='none')
        self.criterion = nn.CrossEntropyLoss(reduction='none')
        print(f'self.criterion: {self.criterion}')

        self.debias_criterion = GeneralizedCELoss(q=args.q_d) if args.debias_criterion == "GCE" else nn.CrossEntropyLoss(reduction='none')
        print(f'self.debias_criterion: {self.debias_criterion}')

        self.bias_criterion = GeneralizedCELoss(q=args.q) if args.bias_criterion == "GCE" else nn.CrossEntropyLoss(reduction='none')
        print(f'self.bias_criterion: {self.bias_criterion}')

        self.sample_loss_ema_b = EMA(torch.LongTensor(train_target_attr), num_classes=self.num_classes, alpha=args.ema_alpha)
        self.sample_loss_ema_d = EMA(torch.LongTensor(train_target_attr), num_classes=self.num_classes, alpha=args.ema_alpha)

        print(f'alpha : {self.sample_loss_ema_d.alpha}')

        # evaluation setup
        # acc for [bc, bn, ba, overall]
        self.best_valid_acc_b, self.best_test_acc_b = [0, 0, 0, 0], [0, 0, 0, 0]
        self.best_valid_loss_b, self.best_test_loss_b = [0, 0, 0, 0], [0, 0, 0, 0]
        self.best_valid_acc_d, self.best_test_acc_d = [0, 0, 0, 0], [0, 0, 0, 0]
        self.best_valid_loss_d, self.best_test_loss_d = [0, 0, 0, 0], [0, 0, 0, 0]
        self.best_valid_acc_i, self.best_test_acc_i = [0, 0, 0, 0], [0, 0, 0, 0]
        self.best_valid_loss_i, self.best_test_loss_i = [0, 0, 0, 0], [0, 0, 0, 0]

        self.dataset = dataset
        self.n_classes = self.train_dataset.dataset.n_classes
        if self.dataset in ["bar", "NICO"]:  # class-wise grouping
            self.keys = [f"cls-{i}" for i in range(self.n_classes)] + ["OA"]
        else:
            self.keys = ["BC", "BN", "BA", "OA"]

        print('finished model initialization....')

    # evaluation code for vanilla
    def evaluate(self, model, data_loader, var=False, his=False):
        """Evaluate ``model`` on ``data_loader`` and report per-group metrics.

        Groups are per class for the "bar"/"NICO" datasets and per bias state
        (bias conflicting, bias neutral, bias aligned) otherwise; an overall
        entry is always appended last.

        Args:
            model: network to evaluate; it is switched to eval mode for the pass
                and back to train mode before returning.
            data_loader: yields ``(data, attr, index)`` batches; ``attr`` columns
                are selected with ``args.target_attr_idx`` / ``args.bias_attr_idx``.
            var: if True, per-group loss variances are included in the result.
            his: if True, per-group ``wandb.Histogram`` objects of the losses are built.

        Returns:
            ``(accs, losses, hist)``, or ``(accs, losses, loss_var, hist)`` when
            ``var`` is True. Each list has one entry per group followed by the
            overall entry; empty groups report -0.01 (``None`` for histograms).
        """
        model.eval()
        pred_ls, loss_ls, label_ls, s_ls = [], [], [], []

        for data, attr, _ in tqdm(data_loader, leave=False):
            label = attr[:, self.args.target_attr_idx].to(self.device)
            s = attr[:, self.args.bias_attr_idx].to(self.device)
            data = data.to(self.device)

            with torch.no_grad():
                logit = model(data)
                loss = self.criterion(logit, label)
                pred = logit.data.max(1, keepdim=True)[1].squeeze(1)
                pred_ls.append(pred)
                loss_ls.append(loss)
                label_ls.append(label)
                s_ls.append(s)

        pred = torch.cat(pred_ls)
        loss = torch.cat(loss_ls)
        label = torch.cat(label_ls)
        s = torch.cat(s_ls)
        correct = (pred == label)

        if self.dataset in ['bar', 'NICO']:  # class-wise sub group accuracy
            group_masks = [label == cls for cls in range(self.n_classes)]
        else:  # bias state based sub group accuracy: -1 conflict, 0 neutral, 1 aligned
            bias = data_loader.dataset.bias if self.args.bias_attr_idx == 1 else data_loader.dataset.bias2
            bias_state = get_bias_state(label, s, bias).squeeze(1)
            group_masks = [bias_state == state for state in (-1, 0, 1)]

        accs, losses, loss_var, hist = [], [], [], []
        for mask in group_masks:
            acc, mean_loss, var_loss, h = self._group_stats(correct, loss, mask, his)
            accs.append(acc)
            losses.append(mean_loss)
            loss_var.append(var_loss)
            hist.append(h)

        # calc overall acc
        if self.dataset in ['cmnist', 'scmnist', "bffhq", "bar", "dogs_and_cats", "corruptedCifar10", "cifar10c", "cifar10c_mb"]:
            overall_acc = correct.sum() / float(correct.shape[0])
            overall_loss = loss.mean()
            overall_loss_var = loss.var()
        else:
            # Unweighted average over every (label, bias attribute) group that occurs.
            n_classes = data_loader.dataset.n_classes
            n_s = data_loader.dataset.n_s
            total_acc = torch.zeros((n_classes, n_s))
            total_l = torch.zeros((n_classes, n_s))
            total_lv = torch.zeros((n_classes, n_s))
            n_exist_group = 0
            for i in range(n_classes):
                mask_i = label == i
                for j in range(n_s):
                    mask = mask_i & (s == j)  # group with label=i and attr=j
                    total = correct[mask].shape[0]
                    if total == 0:  # skip if the group does not exist
                        continue
                    total_acc[i][j] = correct[mask].sum() / float(total)
                    total_l[i][j] = loss[mask].mean()
                    total_lv[i][j] = loss[mask].var()
                    n_exist_group += 1
            overall_acc = total_acc.sum() / n_exist_group
            overall_loss = total_l.sum() / n_exist_group
            overall_loss_var = total_lv.sum() / n_exist_group

        accs.append(overall_acc.item())
        losses.append(overall_loss.item())
        loss_var.append(overall_loss_var.item())
        hist.append(wandb.Histogram(loss.cpu().numpy()) if his else None)

        model.train()
        result = (accs, losses)
        if var:
            result += (loss_var,)
        result += (hist,)
        return result

    @staticmethod
    def _group_stats(correct, loss, mask, his):
        """Return (accuracy, mean loss, loss variance, histogram) of the samples selected by ``mask``.

        Empty groups yield the placeholder ``(-0.01, -0.01, -0.01, None)`` without
        computing any statistic on an empty selection.
        """
        total = correct[mask].shape[0]
        if total == 0:
            return -0.01, -0.01, -0.01, None
        group_loss = loss[mask]
        acc = (correct[mask].sum() / float(total)).item()
        h = wandb.Histogram(group_loss.cpu().numpy()) if his else None
        return acc, group_loss.mean().item(), group_loss.var().item(), h

    def save_best(self, step):
        """Checkpoint ``model_d`` and ``model_b`` (step, weights, optimizer state) into ``result_dir``."""
        checkpoints = (
            ("best_model_d.th", self.model_d, self.optimizer_d),
            ("best_model_b.th", self.model_b, self.optimizer_b),
        )
        for file_name, model, optimizer in checkpoints:
            state_dict = {
                'steps': step,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            with open(os.path.join(self.result_dir, file_name), "wb") as f:
                torch.save(state_dict, f)

        print(f'{step} model saved ...')

    def save_vanilla(self, step, best=None):
        """Checkpoint ``model_b`` as ``best_model.th`` if ``best`` is truthy, else as ``model_{step}.th``."""
        if best:
            model_path = os.path.join(self.result_dir, "best_model.th")
        else:
            model_path = os.path.join(self.result_dir, "model_{}.th".format(step))
        state_dict = {
            'steps': step,
            'state_dict': self.model_b.state_dict(),
            'optimizer': self.optimizer_b.state_dict(),
        }
        with open(model_path, "wb") as f:
            torch.save(state_dict, f)
        print(f'{step} model saved ...')

    def board_vanilla_loss(self, step, loss_b):
        """Log the training loss ``loss_b`` to wandb and/or tensorboard, per ``args`` flags."""
        if self.args.wandb:
            wandb.log({
                "loss_b_train": loss_b,
            }, step=step,)

        if self.args.tensorboard:
            self.writer.add_scalar(f"loss/loss_b_train", loss_b, step)

    def board_vanilla_acc(self, step, epoch, inference=None, his=False):
        """Evaluate ``model_b`` on valid/test, update best metrics, checkpoint and log.

        The "best" checkpoint is selected on validation overall accuracy so the
        saved model is not tuned on the test set; best test metrics are still
        tracked separately for reporting. ``inference`` is unused.
        """
        valid_accs_b, valid_losses_b, valid_var_b, valid_hist_b = self.evaluate(self.model_b, self.valid_loader, var=True, his=his)
        test_accs_b, test_losses_b, test_var_b, test_hist_b = self.evaluate(self.model_b, self.test_loader, var=True, his=his)

        print(f'epoch: {epoch}')

        if valid_accs_b[-1] >= self.best_valid_acc_b[-1]:
            self.best_valid_acc_b = valid_accs_b
            self.best_valid_loss_b = valid_losses_b
            self.save_vanilla(step, best=True)
        if test_accs_b[-1] >= self.best_test_acc_b[-1]:
            self.best_test_acc_b = test_accs_b
            self.best_test_loss_b = test_losses_b

        print(f'valid_b: {valid_accs_b} || test_b: {test_accs_b}')

        log = {
            "acc/acc_b_valid": trans_result(valid_accs_b, self.keys),
            "acc/acc_b_test": trans_result(test_accs_b, self.keys),
            "acc/best_acc_b_valid": trans_result(self.best_valid_acc_b, self.keys),
            "acc/best_acc_b_test": trans_result(self.best_test_acc_b, self.keys),

            "loss/loss_b_valid": trans_result(valid_losses_b, self.keys),
            "loss/loss_b_test": trans_result(test_losses_b, self.keys),
            "loss/best_loss_b_valid": trans_result(self.best_valid_loss_b, self.keys),
            "loss/best_loss_b_test": trans_result(self.best_test_loss_b, self.keys),

            "loss/var_b_valid": trans_result(valid_var_b, self.keys),
            "loss/var_b_test": trans_result(test_var_b, self.keys),

            "hist/loss_b_valid": trans_result(valid_hist_b, self.keys),
            "hist/loss_b_test": trans_result(test_hist_b, self.keys),
        }

        if "_mb" in self.args.dataset:  # for multiple bias datasets, do extra evaluations
            # evaluate() reads the bias column from args; switch to the second bias
            # and always restore the previous value, even if evaluation raises.
            previous_bias_idx = self.args.bias_attr_idx
            self.args.bias_attr_idx = 2
            try:
                test_accs_b_2, _, _ = self.evaluate(self.model_b, self.test_loader, his=his)
            finally:
                self.args.bias_attr_idx = previous_bias_idx
            log["acc/acc_b_test_2"] = trans_result(test_accs_b_2, self.keys)

        if self.args.wandb:
            wandb.log(log, step=step)

    def train(self, args):
        """Train ``model_b`` with mean cross-entropy for ``args.num_steps`` steps.

        Batches are drawn from ``train_loader``, which is restarted whenever it is
        exhausted; each restart counts as one finished epoch. Loss is logged every
        ``args.log_freq`` steps and evaluation runs every ``args.valid_freq`` steps.
        """
        # training vanilla ...
        train_iter = iter(self.train_loader)
        epoch = 0

        for step in tqdm(range(args.num_steps)):
            try:
                index, data, attr, _ = next(train_iter)
            except StopIteration:
                # Loader exhausted: one full pass is done. Counting epochs here is
                # correct even with drop_last=True, where the samples seen per epoch
                # may be fewer than the dataset size.
                print(f'finished epoch: {epoch}')
                epoch += 1
                train_iter = iter(self.train_loader)
                index, data, attr, _ = next(train_iter)

            data = data.to(self.device)
            attr = attr.to(self.device)
            label = attr[:, args.target_attr_idx]

            logit_b = self.model_b(data)
            loss_b_update = self.criterion(logit_b, label)
            loss = loss_b_update.mean()

            self.optimizer_b.zero_grad()
            loss.backward()
            self.optimizer_b.step()

            ##################################################
            #################### LOGGING #####################
            ##################################################

            if step % args.log_freq == 0:
                # Log a detached Python float rather than a graph-attached tensor.
                self.board_vanilla_loss(step, loss_b=loss.item())

            if step % args.valid_freq == 0:
                self.board_vanilla_acc(step, epoch)

def trans_result(accs, keys=("BC", "BN", "BA", "OA")):
    """Map a list of per-group values onto their group names.

    Args:
        accs: sequence of values (accuracies, losses, histograms, ...), one per group.
        keys: group names aligned with ``accs``. The default names the
            bias-conflicting, bias-neutral, bias-aligned and overall groups.
            A tuple is used so the default can never be shared mutable state.

    Returns:
        dict ``{keys[i]: accs[i]}`` for every index of ``accs``. Extra keys are
        ignored; an ``IndexError`` is raised if ``keys`` is shorter than ``accs``.
    """
    return {keys[i]: value for i, value in enumerate(accs)}

def get_bias_state(label, bias_label, bias):
    """Partition samples by bias state: aligned (1), conflicting (-1) or neutral (0).

    Args:
        label: target label per sample (tensor or numpy array).
        bias_label: bias-attribute value per sample; whether it is a tensor
            decides whether the result is a tensor or a numpy array.
        bias: sequence indexed by bias-attribute value; ``bias[i]`` is the target
            class aligned with attribute value ``i``, or ``None`` when samples
            with that attribute value are bias-neutral.

    Returns:
        ``(N, 1)`` float tensor (on ``bias_label``'s device) or numpy array with
        1 where the label equals the class aligned with the sample's attribute
        value, -1 where it differs, and 0 for neutral / unlisted attribute values.
    """
    n_samples = bias_label.shape[0]
    if isinstance(bias_label, torch.Tensor):
        # Allocate next to the inputs; a hard-coded .cuda() fails on CPU-only runs
        # and mismatches devices when the inputs live on a non-default GPU.
        bias_state = torch.zeros((n_samples, 1), device=bias_label.device)
    else:
        bias_state = np.zeros((n_samples, 1))
    for attr_value, bias_class in enumerate(bias):
        if bias_class is None:
            # skip bias neutral samples (their state stays 0)
            continue
        same_attr = bias_label == attr_value
        bias_state[same_attr & (label == bias_class)] = 1
        bias_state[same_attr & (label != bias_class)] = -1
    return bias_state