from collections import Counter
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import os
import torch.optim as optim

from data.util import get_dataset, IdxDataset, data2batch_size, data2model, data2preprocess
from module.loss import GeneralizedCELoss
from module.util import get_model
from module.util import get_backbone
from util import *

import warnings
warnings.filterwarnings(action='ignore')
import copy
import wandb

import torch.nn.functional as F

class TarLearner(object):
    """Vanilla trainer for a classifier that also reconstructs its input.

    ``model_b`` is optimized with cross-entropy plus a weighted reconstruction
    (MSE) term and a weighted total-variation term on the reconstruction.
    Validation/test metrics are reported per bias group:
    ``[bias conflict, bias neutral, bias align, overall]``.
    ``model_d`` / ``optimizer_d`` / the EMA trackers are built here but only
    ``save_best`` touches ``model_d`` in this class.
    """

    def __init__(self, args):
        self.args = args

        if args.wandb:
            # uses the module-level ``wandb`` import
            wandb.init(project=args.exp, config=args)

        run_name = args.exp
        if args.tensorboard:
            from tensorboardX import SummaryWriter
            self.writer = SummaryWriter(f'result/{args.tensorboard_dir}/{run_name}')

        self.model = data2model[args.dataset]
        if "mnist" in args.dataset:
            self.model = args.model
        self.batch_size = data2batch_size[args.dataset]

        print(f'model: {self.model} || dataset: {args.dataset}')
        print(f'working with experiment: {args.exp}...')

        # logging directories (os.makedirs returns None, so keep the path separately)
        self.log_dir = os.path.join(args.log_dir, args.dataset, args.exp)
        os.makedirs(self.log_dir, exist_ok=True)
        self.device = torch.device(args.device)

        print(self.args)

        self.summary_dir = os.path.join(args.log_dir, args.dataset, args.tensorboard_dir, args.exp)
        self.result_dir = os.path.join(self.log_dir, "result")
        os.makedirs(self.summary_dir, exist_ok=True)
        os.makedirs(self.result_dir, exist_ok=True)

        self.train_dataset = self._load_split("train", "train")
        self.valid_dataset = self._load_split("valid", "valid")
        self.test_dataset = self._load_split("test", "valid")

        # The class label is the integer prefix of each training file name
        # (text before the first '_'); the class count is derived from it.
        # TODO: move to a dataset class attribute.
        train_target_attr = torch.LongTensor(
            [int(os.path.basename(path).split('_')[0]) for path in self.train_dataset.data]
        )
        self.num_classes = torch.max(train_target_attr).item() + 1

        self.train_dataset = IdxDataset(self.train_dataset)

        # make loaders
        self.train_loader = self._make_loader(self.train_dataset, shuffle=True, drop_last=True)
        self.pretrain_loader = self._make_loader(self.train_dataset, shuffle=False)
        self.valid_loader = self._make_loader(self.valid_dataset, shuffle=True)
        self.test_loader = self._make_loader(self.test_dataset, shuffle=True)

        # define model and optimizer
        self.model_b = get_backbone(self.model, self.num_classes, args=self.args, pretrained=self.args.resnet_pretrained).to(self.device)
        self.model_d = get_backbone(self.model, self.num_classes, args=self.args, pretrained=self.args.resnet_pretrained).to(self.device)

        self.optimizer_b = torch.optim.Adam(
            self.model_b.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )
        self.optimizer_d = torch.optim.Adam(
            self.model_d.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )

        # define loss (per-sample, so callers can aggregate per bias group)
        self.criterion = nn.CrossEntropyLoss(reduction='none')
        print(f'self.criterion: {self.criterion}')

        self.bias_criterion = GeneralizedCELoss(q=0.7) if args.bias_criterion == "GCE" else nn.CrossEntropyLoss(reduction='none')
        print(f'self.bias_criterion: {self.bias_criterion}')

        self.sample_loss_ema_b = EMA(torch.LongTensor(train_target_attr), num_classes=self.num_classes, alpha=args.ema_alpha)
        self.sample_loss_ema_d = EMA(torch.LongTensor(train_target_attr), num_classes=self.num_classes, alpha=args.ema_alpha)

        print(f'alpha : {self.sample_loss_ema_d.alpha}')
        # acc / loss for [bc, bn, ba, overall]
        self.best_valid_acc_b, self.best_test_acc_b = [0, 0, 0, 0], [0, 0, 0, 0]
        self.best_valid_loss_b, self.best_test_loss_b = [0, 0, 0, 0], [0, 0, 0, 0]
        self.best_valid_acc_d, self.best_test_acc_d = [0, 0, 0, 0], [0, 0, 0, 0]
        self.best_valid_loss_d, self.best_test_loss_d = [0, 0, 0, 0], [0, 0, 0, 0]

        print('finished model initialization....')

    def _load_split(self, dataset_split, transform_split):
        """Build one dataset split with the shared ``get_dataset`` arguments."""
        args = self.args
        return get_dataset(
            args.dataset,
            data_dir=args.data_dir,
            dataset_split=dataset_split,
            transform_split=transform_split,
            percent=args.percent,
            use_preprocess=data2preprocess[args.dataset],
            args=args,
        )

    def _make_loader(self, dataset, shuffle, drop_last=False):
        """Wrap ``dataset`` in a pinned-memory DataLoader with the run's batch size."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.args.num_workers,
            pin_memory=True,
            drop_last=drop_last,
        )

    # evaluation code for vanilla
    def evaluate(self, model, data_loader, var=False):
        """Evaluate ``model`` on ``data_loader`` and report per-bias-group metrics.

        Each batch is unpacked as ``(data, attr, index)``; ``attr[:, 0]`` is the
        target label and ``attr[:, 1]`` the bias attribute. ``get_bias_state``
        assigns each sample to bias-conflicting (-1), bias-neutral (0) or
        bias-aligned (1).

        Args:
            model: network whose forward returns ``(logits, reconstruction)``.
            data_loader: loader whose ``dataset`` exposes a ``bias`` attribute.
            var: if True, also return per-group loss variances.

        Returns:
            ``(accs, losses)`` or, with ``var=True``, ``(accs, losses, loss_var)``.
            Each is a list ``[bias conflict, bias neutral, bias align, overall]``.
            Groups with no samples are reported as ``-0.01``.
        """
        model.eval()
        pred_ls, loss_ls, label_ls, s_ls = [], [], [], []
        try:
            with torch.no_grad():
                for data, attr, _ in tqdm(data_loader, leave=False):
                    label = attr[:, 0].to(self.device)
                    s = attr[:, 1].to(self.device)
                    data = data.to(self.device)

                    logit, _imgs_rec = model(data)
                    pred_ls.append(logit.max(1)[1])
                    loss_ls.append(self.criterion(logit, label))
                    label_ls.append(label)
                    s_ls.append(s)
        finally:
            # always hand the model back in training mode, even on error
            model.train()

        pred = torch.cat(pred_ls)
        loss = torch.cat(loss_ls)
        label = torch.cat(label_ls)
        s = torch.cat(s_ls)
        correct = (pred == label)

        # partition according to bias state and calc accs
        bias_state = get_bias_state(label, s, data_loader.dataset.bias).squeeze(1)

        accs, losses, loss_var = [], [], []
        for state in (-1, 0, 1):
            mask = bias_state == state
            group_correct = correct[mask]
            total = group_correct.shape[0]
            if total == 0:
                # sentinel for "no samples in this group"
                accs.append(-0.01)
                losses.append(-0.01)
                loss_var.append(-0.01)
                continue
            group_loss = loss[mask]
            accs.append((group_correct.sum() / float(total)).item())
            losses.append(group_loss.mean().item())
            loss_var.append(group_loss.var().item())

        # overall metrics
        accs.append((correct.sum() / float(correct.shape[0])).item())
        losses.append(loss.mean().item())
        loss_var.append(loss.var().item())

        if var:
            return accs, losses, loss_var
        return accs, losses

    def _save_checkpoint(self, model_path, step, model, optimizer):
        """Write ``{'steps', 'state_dict', 'optimizer'}`` to ``model_path``."""
        state_dict = {
            'steps': step,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        with open(model_path, "wb") as f:
            torch.save(state_dict, f)

    def save_best(self, step):
        """Save both ``model_d`` and ``model_b`` as the best checkpoints."""
        self._save_checkpoint(os.path.join(self.result_dir, "best_model_d.th"),
                              step, self.model_d, self.optimizer_d)
        self._save_checkpoint(os.path.join(self.result_dir, "best_model_b.th"),
                              step, self.model_b, self.optimizer_b)
        print(f'{step} model saved ...')

    def save_vanilla(self, step, best=None):
        """Save ``model_b``: to ``best_model.th`` if ``best``, else ``model_{step}.th``."""
        if best:
            model_path = os.path.join(self.result_dir, "best_model.th")
        else:
            model_path = os.path.join(self.result_dir, "model_{}.th".format(step))
        self._save_checkpoint(model_path, step, self.model_b, self.optimizer_b)
        print(f'{step} model saved ...')

    def board_vanilla_loss(self, step, loss_b):
        """Log the training loss to wandb / tensorboard (tensors converted to floats)."""
        if isinstance(loss_b, torch.Tensor):
            # log a plain number instead of a graph-attached tensor
            loss_b = loss_b.item()

        if self.args.wandb:
            wandb.log({
                "loss_b_train": loss_b,
            }, step=step,)

        if self.args.tensorboard:
            self.writer.add_scalar("loss/loss_b_train", loss_b, step)

    def board_vanilla_acc(self, step, epoch, inference=None):
        """Evaluate ``model_b`` on valid/test, track best metrics, and log them.

        The best checkpoint is selected on *validation* overall accuracy, so
        the test split is never used for model selection. Best test metrics are
        still tracked separately for reporting. ``inference`` is unused and kept
        for interface compatibility.
        """
        valid_accs_b, valid_losses_b, valid_var_b = self.evaluate(self.model_b, self.valid_loader, var=True)
        test_accs_b, test_losses_b, test_var_b = self.evaluate(self.model_b, self.test_loader, var=True)

        print(f'epoch: {epoch}')

        if valid_accs_b[-1] >= self.best_valid_acc_b[-1]:
            self.best_valid_acc_b = valid_accs_b
            self.best_valid_loss_b = valid_losses_b
            self.save_vanilla(step, best=True)
        if test_accs_b[-1] >= self.best_test_acc_b[-1]:
            self.best_test_acc_b = test_accs_b
            self.best_test_loss_b = test_losses_b

        print(f'valid_b: {valid_accs_b} || test_b: {test_accs_b}')

        if self.args.tensorboard:
            for split, accs in (("valid", valid_accs_b), ("test", test_accs_b)):
                for suffix, value in zip(("_bc", "_bn", "_ba", ""), accs):
                    self.writer.add_scalar(f"acc/acc_b_{split}{suffix}", value, step)
            # add_scalar accepts a single value: log the overall (index 3) best accuracy
            self.writer.add_scalar("acc/best_acc_b_valid", self.best_valid_acc_b[3], step)
            self.writer.add_scalar("acc/best_acc_b_test", self.best_test_acc_b[3], step)

        if self.args.wandb:
            wandb.log({
                "acc/acc_b_valid": trans_result(valid_accs_b),
                "acc/acc_b_test": trans_result(test_accs_b),
                "acc/best_acc_b_valid": trans_result(self.best_valid_acc_b),
                "acc/best_acc_b_test": trans_result(self.best_test_acc_b),

                "loss/loss_b_valid": trans_result(valid_losses_b),
                "loss/loss_b_test": trans_result(test_losses_b),
                "loss/best_loss_b_valid": trans_result(self.best_valid_loss_b),
                "loss/best_loss_b_test": trans_result(self.best_test_loss_b),

                "loss/var_b_valid": trans_result(valid_var_b),
                "loss/var_b_test": trans_result(test_var_b),
            }, step=step)

    def train(self, args):
        """Train ``model_b`` for ``args.num_steps`` optimizer steps.

        Loss = CE(logits, label) + ``args.Lambda`` * MSE(reconstruction, input)
        + ``args.Lambda2`` * total variation of the reconstruction.
        """
        train_iter = iter(self.train_loader)
        epoch = 0

        for step in tqdm(range(args.num_steps)):
            try:
                _, data, attr, _ = next(train_iter)
            except StopIteration:
                # Loader exhausted: a full pass is done. Counting epochs here
                # works with drop_last=True, unlike comparing sample counts.
                print(f'finished epoch: {epoch}')
                epoch += 1
                train_iter = iter(self.train_loader)
                _, data, attr, _ = next(train_iter)

            data = data.to(self.device)
            attr = attr.to(self.device)
            label = attr[:, args.target_attr_idx]

            logit_b, imgs_rec = self.model_b(data)
            loss_clf = F.cross_entropy(logit_b, label)
            loss_rec = F.mse_loss(imgs_rec, data)
            loss_tv = self.total_variation_loss(imgs_rec)
            loss = loss_clf + args.Lambda * loss_rec + args.Lambda2 * loss_tv

            self.optimizer_b.zero_grad()
            loss.backward()
            self.optimizer_b.step()

            ##################################################
            #################### LOGGING #####################
            ##################################################

            if step % args.log_freq == 0:
                self.board_vanilla_loss(step, loss_b=loss.item())

            if step % args.valid_freq == 0:
                self.board_vanilla_acc(step, epoch)

    def total_variation_loss(self, img):
        """Mean squared total variation of a 4-D ``(N, C, H, W)`` image batch.

        Sums squared differences between vertically and horizontally adjacent
        pixels, then divides by the total number of elements.
        """
        bs_img, c_img, h_img, w_img = img.size()
        tv_h = torch.pow(img[:, :, 1:, :] - img[:, :, :-1, :], 2).sum()
        tv_w = torch.pow(img[:, :, :, 1:] - img[:, :, :, :-1], 2).sum()
        return (tv_h + tv_w) / (bs_img * c_img * h_img * w_img)

def trans_result(accs):
    """Name the first four entries of a metric list for logging.

    The entries map to BC (bias conflict), BN (bias neutral),
    BA (bias align) and OA (overall), in that order.
    """
    group_names = ("BC", "BN", "BA", "OA")
    return {name: accs[pos] for pos, name in enumerate(group_names)}

def get_bias_state(label, bias_label, bias):
    """Assign each sample a bias state: 1 aligned, -1 conflicting, 0 neutral.

    For each bias attribute value ``i`` with ``bias[i]`` not None, a sample with
    ``bias_label == i`` is aligned if its ``label == bias[i]``, otherwise
    conflicting. Samples whose bias value maps to None, or that no entry
    matches, stay 0.

    Args:
        label: per-sample target labels (torch tensor or numpy array).
        bias_label: per-sample bias attribute values, same kind as ``label``.
        bias: sequence indexed by bias value giving the associated class, or None.

    Returns:
        An ``(N, 1)`` float array/tensor of states. For tensors it is allocated
        on ``bias_label.device``, so CPU and any CUDA device work.
    """
    n = bias_label.shape[0]
    if isinstance(bias_label, torch.Tensor):
        bias_state = torch.zeros((n, 1), device=bias_label.device)
    else:
        bias_state = np.zeros((n, 1))
    for i, bias_class in enumerate(bias):
        if bias_class is None:
            # skip bias neutral samples
            continue
        same_bias = bias_label == i
        bias_state[same_bias & (label == bias_class)] = 1
        bias_state[same_bias & (label != bias_class)] = -1
    return bias_state