'''Modified from https://github.com/alinlab/LfF and https://github.com/kakaoenterprise/Learning-Debiased-Disentangled'''

from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import os
from data.util import get_dataset, IdxDataset
from torch_module.util import get_model
from torch_util import adjust_learning_rate
import time
import pandas as pd
import numpy as np


class Learner(object):
    def __init__(self, args):
        """Build the datasets and data loaders used by :meth:`retrain`.

        Args:
            args: Run configuration object. This constructor reads
                ``dataset``, ``gpu``, ``data_dir``, ``percent``,
                ``use_type0``, ``use_type1``, ``inf_path`` and
                ``num_workers`` from it, and writes ``mini_batch_size``
                back onto it.

        Raises:
            KeyError: If ``args.dataset`` is not a key of the per-dataset
                tables below, or if the file at ``args.inf_path`` has no
                ``index`` column.
        """
        # Per-dataset settings, all keyed by ``args.dataset``.
        data2model = {
            'cmnist': "ResNet18",
            'cifar10c': "ResNet18",
            'cifar10_lff': "ResNet18",
            'waterbird': "ResNet50",
            'bffhq': "ResNet18",
        }
        data2mini_batch_size = {
            'cmnist': 256,
            'cifar10c': 256,
            'cifar10_lff': 256,
            'waterbird': 50,
            'bffhq': 64,
        }
        data2batch_size = {
            'cmnist': 1024,
            'cifar10c': 1024,
            'cifar10_lff': 1024,
            'waterbird': 50,
            'bffhq': 256,
        }
        data2preprocess = {
            'cmnist': True,
            'cifar10c': True,
            'cifar10_lff': True,
            'waterbird': True,
            'bffhq': True,
        }

        # Holds the architecture *name* here; retrain() later rebinds
        # self.model to the actual network.
        self.model = data2model[args.dataset]
        self.batch_size = data2batch_size[args.dataset]
        # Written onto args before get_dataset() receives the same object.
        args.mini_batch_size = data2mini_batch_size[args.dataset]
        self.mini_batch_size = args.mini_batch_size

        self.device = f'cuda:{args.gpu}'
        self.args = args

        def load_split(dataset_split, transform_split):
            """Load one split; only the split/transform names vary."""
            return get_dataset(
                args,
                args.dataset,
                data_dir=args.data_dir,
                dataset_split=dataset_split,
                transform_split=transform_split,
                percent=args.percent,
                use_preprocess=data2preprocess[args.dataset],
                use_type0=args.use_type0,
                use_type1=args.use_type1,
            )

        def make_loader(dataset, batch_size, shuffle):
            """Wrap ``dataset`` in a DataLoader with the shared settings."""
            return DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=args.num_workers,
                pin_memory=True,
            )

        self.train_dataset = load_split("train", "train")
        self.valid_dataset = load_split("valid", "valid")
        # The test split is loaded with the "valid" transform.
        self.test_dataset = load_split("test", "valid")

        # Tab-separated file whose 'index' column holds positions into the
        # training set; those samples form the "unbias" subset.
        df = pd.read_csv(args.inf_path, sep='\t')
        unbias_idx = df['index'].values.astype(np.int64)

        unbias_set = torch.utils.data.Subset(self.train_dataset, indices=unbias_idx)
        self.train_loader = make_loader(unbias_set, self.batch_size, shuffle=False)

        # create bias dataset: every training position NOT listed in the file
        all_idx = np.arange(len(self.train_dataset))
        bias_idx = np.delete(all_idx, unbias_idx)
        bias_set = torch.utils.data.Subset(self.train_dataset, indices=bias_idx)
        self.aug_train_loader = make_loader(bias_set, self.batch_size, shuffle=True)

        # Collect the target label of every training sample; where the
        # labels live depends on the dataset class.
        if args.dataset == 'cifar10_lff':
            train_target_attr = torch.LongTensor(self.train_dataset.query_attr)
        elif args.dataset == 'waterbird':
            train_target_attr = torch.LongTensor(self.train_dataset.y_array)
        else:
            # Each entry of .data is split on '_' and its second-to-last
            # field is parsed as the integer label.
            train_target_attr = torch.LongTensor(
                [int(entry.split('_')[-2]) for entry in self.train_dataset.data]
            )

        # Labels are treated as 0-based, so the class count is max + 1.
        self.num_classes = torch.max(train_target_attr).item() + 1
        # Wrapped only after the Subsets above were built on the raw dataset.
        self.train_dataset = IdxDataset(self.train_dataset)

        # make loader
        self.valid_loader = make_loader(
            self.valid_dataset, self.mini_batch_size, shuffle=False
        )
        self.test_loader = make_loader(
            self.test_dataset, self.mini_batch_size, shuffle=False
        )

        self.best_valid_acc_b, self.best_test_acc_b = 0., 0.
        self.best_valid_acc_d, self.best_test_acc_d = 0., 0.


    def evaluate_ours(self, args, model, data_loader):
        model.eval()
        total_correct, total_num = 0, 0

        pred_list = []
        label_list = []
        bias_label_list = []

        bias_align_total_correct, bias_align_total_num = 0, 0
        bias_free_total_correct, bias_free_total_num = 0, 0
        for data, attr, _ in data_loader:
            label = attr[:, args.target_attr_idx]
            data = data.to(self.device)
            label = label.to(self.device)

            bias_label = attr[:, 1]
            bias_label = bias_label.to(self.device)

            with torch.no_grad():
                logit = model(data)
                pred = logit.data.max(1, keepdim=True)[1].squeeze(1)
                correct = (pred == label).long()
                total_correct += correct.sum()
                total_num += correct.shape[0]
                
                bias_align_mask = (bias_label == label)
                bias_align_correct = (pred == label)[bias_align_mask].long()
                bias_align_total_correct += bias_align_correct.sum()
                bias_align_total_num += bias_align_mask.sum().item()
                
                bias_free_correct = (pred == label)[~bias_align_mask].long()
                bias_free_total_correct += bias_free_correct.sum()
                bias_free_total_num += (~bias_align_mask).sum().item()
                
            pred_list.append(pred.cpu().detach())
            label_list.append(label.cpu().detach())
            bias_label_list.append(bias_label.cpu().detach())
        
        pred_list = torch.cat(pred_list, dim=0)
        label_list = torch.cat(label_list, dim=0)
        bias_label_list = torch.cat(bias_label_list, dim=0)
        accuracy_matrix = torch.zeros(self.num_classes, self.num_classes)
        for i in range(self.num_classes):
            class_mask = (label_list == i)
            for j in range(self.num_classes):
                class_bias_mask = (bias_label_list == j)
                total_sample_num = (class_mask & class_bias_mask).sum().item()
                correct_sample_num = ((pred_list == label_list) & class_mask & class_bias_mask).sum().item()
                if total_sample_num != 0:
                    accuracy_matrix[i, j] = correct_sample_num/float(total_sample_num)
                else: # if no sample in this group, set accuracy to a large number
                    accuracy_matrix[i, j] = 100000

        accs = total_correct/float(total_num)
        bias_align_accs = bias_align_total_correct/float(bias_align_total_num)
        bias_free_accs = bias_free_total_correct/float(bias_free_total_num)
        
        return accs, bias_align_accs, bias_free_accs, accuracy_matrix.numpy()

    def retrain(self, args):
        """Fine-tune a saved checkpoint on both training loaders, then test.

        A network is built with ``get_model``, initialised from the
        checkpoint at ``args.model_path``, given a freshly initialised
        ``fc`` layer, and optimised for ``args.num_steps`` steps on
        ``loss_d + args.b_weight * loss_b``, where ``loss_d`` is the mean
        cross-entropy on a batch from ``self.train_loader`` and ``loss_b``
        the mean cross-entropy on a batch from ``self.aug_train_loader``.
        Afterwards the model is evaluated on ``self.test_loader`` and one
        dataset-specific metric is printed.

        Args:
            args: Run configuration. Reads ``dataset``, ``dfa``,
                ``model_path``, ``lr``, ``weight_decay``, ``num_steps``,
                ``cosine`` and ``b_weight`` (``evaluate_ours`` additionally
                reads ``target_attr_idx``).

        Side effects:
            Sets ``self.model_d``, ``self.model``, ``self.optimizer_d`` and
            ``self.criterion``; prints to stdout.
        """
        print("retrain starts...")

        if args.dataset == 'waterbird':
            arch = 'resnet50_DISENTANGLE' if args.dfa else 'ResNet50'
        else:
            arch = 'resnet18_DISENTANGLE' if args.dfa else 'ResNet18'
        self.model_d = get_model(arch, self.num_classes)

        # Load onto the configured device rather than a hard-coded GPU 0 so
        # the checkpoint follows ``args.gpu`` like everything else here.
        # NOTE(review): torch.load unpickles the file -- only load
        # checkpoints that come from a trusted source.
        saved_model_dict = torch.load(args.model_path, map_location=self.device)
        state_dict = saved_model_dict['state_dict']

        if args.dfa:
            # Drop the checkpoint's classifier weights before loading.
            del state_dict['fc.weight']
            del state_dict['fc.bias']

        self.model_d.load_state_dict(state_dict, strict=False)

        # The classifier head is replaced *after* loading, so it always
        # starts from a fresh initialisation.
        if args.dataset == 'bffhq':
            self.model_d.fc = nn.Linear(512, 2)
        elif args.dataset == 'waterbird':
            self.model_d.fc = nn.Linear(2048, 2)
        else:
            self.model_d.fc = nn.Linear(512, 10)

        self.model = self.model_d.to(self.device)
        self.model.zero_grad()

        if args.dataset == 'waterbird':
            self.optimizer_d = torch.optim.SGD(
                self.model.parameters(),
                lr=args.lr,
                weight_decay=args.weight_decay,
            )
        else:
            self.optimizer_d = torch.optim.Adam(
                self.model.parameters(),
                lr=args.lr,
                weight_decay=args.weight_decay,
            )

        self.criterion = nn.CrossEntropyLoss(reduction='none')

        def next_batch(batch_iter, loader):
            """Return ``(iterator, batch)``, restarting an exhausted loader.

            Only ``StopIteration`` triggers the restart, so genuine
            data-loading errors and Ctrl-C are no longer swallowed.
            """
            try:
                return batch_iter, next(batch_iter)
            except StopIteration:
                batch_iter = iter(loader)
                return batch_iter, next(batch_iter)

        # NOTE(review): the model is optimised while in eval mode, as in the
        # original code -- confirm this is intentional. Nothing in the loop
        # switches the mode back, so setting it once is sufficient.
        self.model.eval()

        train_iter = iter(self.train_loader)
        aug_train_iter = iter(self.aug_train_loader)
        pbar = tqdm(range(args.num_steps))
        for step in pbar:
            if args.cosine:
                adjust_learning_rate(args, self.optimizer_d, step)

            train_iter, (batch_data, batch_attr, _) = next_batch(
                train_iter, self.train_loader
            )
            aug_train_iter, (aug_batch_data, aug_batch_attr, _) = next_batch(
                aug_train_iter, self.aug_train_loader
            )

            # Column 0 of the attribute tensor is used as the training label.
            data = batch_data.to(self.device)
            label = batch_attr[:, 0].to(self.device)
            aug_data = aug_batch_data.to(self.device)
            aug_label = aug_batch_attr[:, 0].to(self.device)

            loss_b = self.criterion(self.model(aug_data), aug_label).mean()
            loss_d = self.criterion(self.model(data), label).mean()
            loss = loss_d + args.b_weight * loss_b

            self.optimizer_d.zero_grad()
            loss.backward()
            self.optimizer_d.step()

        test_accs_d, _, test_accs_free, accuracy_matrix = self.evaluate_ours(
            args, self.model_d, self.test_loader
        )

        if args.dataset == 'waterbird':
            print(f'test_acc_worst: {np.min(accuracy_matrix)*100:.2f}')
        elif args.dataset == 'bffhq':
            print(f'test_acc_conflict: {test_accs_free*100:.2f}')
        else:
            print(f'test_acc: {test_accs_d*100:.2f}')
