"""
Reproduced codes for "Learning from Failure: Learning from Failure: Training Debiased Classifier from Biased Classifier"
; NeurIPS 2020 by Nam et al.,.

"""

from __future__ import print_function
import sys
sys.path.append('..')
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

from torchvision.models import resnet34, resnet50, densenet121
import random, pdb, pickle
import argparse
from utils.loss import GeneralizedCELoss
from utils.dataloader import *
from utils.utils import EMA 
import numpy as np

# Command-line configuration; the parsed namespace `args` is read throughout
# the rest of this script.
parser = argparse.ArgumentParser()
parser.add_argument('--dir', type=str, default=None, required=True, help='path to save checkpoints (required)')
parser.add_argument('--data_path', type=str, default='data', metavar='PATH',
                    help='path to datasets location (default: data)')
parser.add_argument('--epochs', type=int, default=100,
                    help='number of epochs to train')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--data_type', type=str, default='iwildcam',
                    help='Dataset type of WILDS (default: iwildcam)')
parser.add_argument('--data_file', type=str,
                    help='optional pickled dataset to load instead of the freshly built one')
parser.add_argument('--device_id', type=int, help='device id to use')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed')
parser.add_argument('--l2', type=float, default=5e-4,
                    help='weight decay')
# NOTE(review): 0.5 is an unusually large learning rate for Adam (used for both
# models in this script); confirm this default is intended.
parser.add_argument('--lr', type=float, default=0.5,
                    help='initial learning rate')
parser.add_argument('--ckpt_b',
                    help='Previous checkpoint of the biased model.')
parser.add_argument('--ckpt_d',
                    help='Previous checkpoint of the debiased model.')
parser.add_argument('--p_noise', type=float, default=0.0,
                    help='proportion of noisy labels')
# NOTE(review): LfF reportedly uses q=0.7 for the GCE loss; depending on how
# utils.loss.GeneralizedCELoss is implemented, q=0 may make the biased model's
# loss degenerate. Confirm this default.
parser.add_argument('--q', type=float, default=0.0,
                    help='q value for GCE.')

args = parser.parse_args()

device_id = args.device_id
use_cuda = torch.cuda.is_available()

# Seed the torch, numpy and Python `random` generators.
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

print("Arguments: ########################")
print('\n'.join(f'{k}={v}' for k, v in vars(args).items()))
print("###################################")

#######* Load data #######################################
print('==> Preparing data..')

# 'with_idx' makes trainloader yield the indices of samples in the batch;
# train() uses them to index the per-sample loss EMAs.
dataset = WildsDataset(args.data_type, args, with_idx=True)
dataset.inject_label_noise(args.p_noise)
if args.data_file is not None:
    # Replace the dataset built above with a previously pickled one (presumably a
    # `noisy_data.pk` written by an earlier run, so the same noisy labels are reused).
    # NOTE: pickle.load can execute arbitrary code -- only pass trusted files.
    with open(args.data_file, 'rb') as f:
        dataset = pickle.load(f)
trainloader, valloader, testloader = dataset.get_loader(args)
# Persist the dataset actually used (including injected label noise) so it can
# be passed back in via --data_file.
with open(f'{args.dir}/noisy_data.pk', 'wb') as f:
    pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)

#######* Build model #######################################
print('==> Building model..')
# Two networks with identical architecture: net_b is trained as the biased model
# (criterion_b) and net_d as the debiased model (re-weighted CE); see train().
# Every branch must define both, since all later code uses net_b and net_d.
# 'iwildcam' is this script's default --data_type; 'iwilds' is kept for
# backward compatibility with existing invocations.
if args.data_type in ['iwildcam', 'iwilds', 'rxrx1', 'waterbirds', 'celebA']:
    net_b = resnet50(pretrained=True)  #* Load pre-trained model
    net_d = resnet50(pretrained=True)  #* Load pre-trained model
    net_b.fc = torch.nn.Linear(net_b.fc.in_features, dataset.target_dim)  #* change output dimension
    net_d.fc = torch.nn.Linear(net_d.fc.in_features, dataset.target_dim)  #* change output dimension
elif args.data_type in ['fmow', 'camelyon17']:
    net_b = densenet121(pretrained=True)
    net_d = densenet121(pretrained=True)
    net_b.classifier = torch.nn.Linear(net_b.classifier.in_features, dataset.target_dim)  #* change output dimension
    net_d.classifier = torch.nn.Linear(net_d.classifier.in_features, dataset.target_dim)  #* change output dimension
else:
    from models import ResNet18
    net_b = ResNet18(num_classes=dataset.target_dim)
    net_d = ResNet18(num_classes=dataset.target_dim)

# Each checkpoint is optional and loaded independently. map_location='cpu' lets
# checkpoints saved on a GPU load on CPU-only machines; load_state_dict copies
# the values into the (still CPU-resident) model parameters either way.
if args.ckpt_b is not None:
    net_b.load_state_dict(torch.load(args.ckpt_b, map_location='cpu'))
if args.ckpt_d is not None:
    net_d.load_state_dict(torch.load(args.ckpt_d, map_location='cpu'))

if use_cuda:
    net_b.cuda(device_id)
    net_d.cuda(device_id)
    # NOTE(review): benchmark=True lets cuDNN choose algorithms by timing, which
    # may vary between runs, so bit-exact reproducibility is not guaranteed even
    # with deterministic=True.
    cudnn.benchmark = True
    cudnn.deterministic = True

#######* Train models #######################################

def train(epoch):
    """Train the biased (net_b) and debiased (net_d) models jointly for one epoch.

    Implements the LfF update. For each mini-batch:

    1. The detached per-sample cross-entropy of both models is folded into the
       per-sample moving averages ``sample_loss_ema_b`` / ``sample_loss_ema_d``.
    2. Each averaged loss is divided by its EMA's per-class ``max_loss``.
    3. The relative-difficulty weight ``W = L_b / (L_b + L_d)`` is formed.
    4. ``net_b`` is trained with ``criterion_b``, and ``net_d`` with
       cross-entropy scaled by ``W``.

    Every 500 global iterations (counted across epochs) ``net_d`` is evaluated
    on the test and validation splits. When the validation accuracy or the
    ``acc_wg`` score returned by ``test`` improves, the state dicts of both
    models are saved to ``args.dir``.

    Args:
        epoch: Current epoch index, used for logging and forwarded to ``test``.

    Side effects:
        Updates the module-level globals ``iterations``, ``prev_best_acc`` and
        ``prev_best_woacc``, both sample-loss EMAs, and both optimizers.
    """
    global iterations, prev_best_acc, prev_best_woacc
    print('\nEpoch: %d' % epoch)
    net_b.train()
    net_d.train()
    train_loss_b, train_loss_d = 0.0, 0.0
    correct_b, correct_d = 0, 0
    total = 0

    for batch_idx, (indices, inputs, targets, _metadata) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(device_id), targets.cuda(device_id)

        outputs_b = net_b(inputs)
        outputs_d = net_d(inputs)

        # Per-sample CE of both models, detached and on CPU: these only feed the
        # sample weights, never the gradients.
        loss_b = criterion_d(outputs_b, targets).cpu().detach()
        loss_d = criterion_d(outputs_d, targets).cpu().detach()

        # Update moving average loss per sample
        sample_loss_ema_b.update(loss_b, indices)
        sample_loss_ema_d.update(loss_d, indices)

        # Get the moving average loss
        loss_b = sample_loss_ema_b.parameter[indices].clone().detach()
        loss_d = sample_loss_ema_d.parameter[indices].clone().detach()

        # Class-wise normalize for the moving averaged loss
        # -> There is no specific description of this part in the main manuscripts and
        # supplementary of Nam et al.,. So we refer to the official repository:
        # https://github.com/alinlab/LfF/blob/master/train.py
        label_cpu = targets.cpu()
        for c in range(dataset.target_dim):
            class_mask = label_cpu == c
            max_loss_b = sample_loss_ema_b.max_loss(c)
            max_loss_d = sample_loss_ema_d.max_loss(c)
            loss_b[class_mask] = loss_b[class_mask] / max_loss_b
            loss_d[class_mask] = loss_d[class_mask] / max_loss_d

        # Relative difficulty weight; the epsilon guards against 0 / 0.
        loss_weight = loss_b / (loss_b + loss_d + 1e-10)
        loss_b_update = criterion_b(outputs_b, targets)
        # Move the CPU weights to wherever the logits live instead of calling
        # .cuda() unconditionally, so training also works without CUDA.
        loss_d_update = criterion_d(outputs_d, targets) * loss_weight.to(outputs_d.device)
        loss = loss_b_update.mean() + loss_d_update.mean()

        optimizer_b.zero_grad()
        optimizer_d.zero_grad()
        loss.backward()
        optimizer_b.step()
        optimizer_d.step()

        train_loss_b += loss_b_update.mean().item()
        train_loss_d += loss_d_update.mean().item()
        _, predicted_b = torch.max(outputs_b.detach(), 1)
        _, predicted_d = torch.max(outputs_d.detach(), 1)
        total += targets.size(0)
        correct_b += predicted_b.eq(targets).sum().item()
        correct_d += predicted_d.eq(targets).sum().item()

        if batch_idx % 100 == 0:
            print('Loss_Biased: %.3f | Loss_Debiased: %.3f | Acc_Biased: %.3f%% (%d/%d) | Acc_Debiased: %.3f%% (%d/%d)'
                  % (train_loss_b / (batch_idx + 1), train_loss_d / (batch_idx + 1), 100. * correct_b / total,
                     correct_b, total, 100. * correct_d / total, correct_d, total))

        if iterations % 500 == 0:
            test(net_d, epoch)
            cur_acc, cur_woacc = test(net_d, epoch, loader='val')
            # Save net_b alongside net_d (as the end-of-epoch check also does) so
            # each *_biased.pt file always pairs with its debiased checkpoint.
            if cur_acc > prev_best_acc:
                print("Save the best acc model !")
                torch.save(net_d.state_dict(), args.dir + f'/{args.data_type}_model_best_val.pt')
                torch.save(net_b.state_dict(), args.dir + f'/{args.data_type}_model_best_val_biased.pt')
                prev_best_acc = cur_acc
            if cur_woacc > prev_best_woacc:
                print("Save the best wo-acc model !")
                torch.save(net_d.state_dict(), args.dir + f'/{args.data_type}_model_bestwo_val.pt')
                torch.save(net_b.state_dict(), args.dir + f'/{args.data_type}_model_bestwo_val_biased.pt')
                prev_best_woacc = cur_woacc
        iterations += 1

def _run_inference(model, data_loader):
    """Run ``model`` over ``data_loader`` without gradients and collect the results.

    Returns ``(avg_loss, predictions, targets, metadata)``:

    - ``avg_loss`` is the mean over batches of the per-batch mean ``criterion_d`` loss.
    - ``predictions``, ``targets`` and ``metadata`` are CPU tensors concatenated
      along the batch dimension.

    Raises:
        ValueError: if the loader yields no batches.
    """
    loss_sum = 0.0
    pred_batches, target_batches, metadata_batches = [], [], []
    with torch.no_grad():
        for inputs, targets, metadata in data_loader:
            if use_cuda:
                inputs, targets = inputs.cuda(device_id), targets.cuda(device_id)
            outputs = model(inputs)
            loss_sum += criterion_d(outputs, targets).mean().item()
            _, predicted = torch.max(outputs, 1)
            pred_batches.append(predicted.cpu())
            target_batches.append(targets.cpu())
            metadata_batches.append(metadata.cpu())
    if not pred_batches:
        raise ValueError('evaluation loader yielded no batches')
    return (loss_sum / len(data_loader), torch.cat(pred_batches),
            torch.cat(target_batches), torch.cat(metadata_batches))


def test(model, epoch, loader='test'):
    """Evaluate ``model`` on the test or validation split.

    Prints the average loss and overall accuracy, then the WILDS evaluation
    summary produced by ``dataset.dataset.eval``.

    Args:
        model: Network to evaluate. Its train/eval mode is restored on return.
        epoch: Epoch index (unused here; kept for call compatibility).
        loader: ``'test'`` selects ``testloader``; any other value selects
            ``valloader``.

    Returns:
        ``(acc, worst_acc)`` where ``acc`` is the overall accuracy as a
        fraction and ``worst_acc`` is ``float(eval_result['acc_wg'])``.

    Raises:
        ValueError: if the selected loader yields no batches.
    """
    data_loader = testloader if loader == 'test' else valloader
    was_training = model.training
    model.eval()
    try:
        avg_loss, predictions, targets, metadata = _run_inference(model, data_loader)
    finally:
        # Restore the caller's mode (train() expects its models in train mode).
        model.train(was_training)

    total = targets.size(0)
    correct = predictions.eq(targets).sum().item()
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        avg_loss, correct, total,
        100. * correct / total))

    eval_result, eval_result_str = dataset.dataset.eval(predictions, targets, metadata)
    worst_acc = float(eval_result['acc_wg'])
    print(worst_acc)
    acc = correct / total
    print(eval_result_str)
    return acc, worst_acc


datasize = len(dataset.train_data)
# Mini-batches per epoch (ceiling division; assumes the loader keeps a final
# partial batch, i.e. drop_last=False -- TODO confirm in utils.dataloader).
num_batch = (datasize + args.batch_size - 1) // args.batch_size
print(f"Num batch: [{num_batch}]")

# Per-sample cross-entropy (reduction='none'). It is used for:
# - the EMA difficulty scores,
# - the debiased model's re-weighted training loss,
# - evaluation in test().
criterion_d = nn.CrossEntropyLoss(reduction='none')
# Generalized CE used as the biased model's training objective.
criterion_b = GeneralizedCELoss(q=args.q)

optimizer_b = optim.Adam(net_b.parameters(), lr=args.lr, weight_decay=args.l2)
optimizer_d = optim.Adam(net_d.parameters(), lr=args.lr, weight_decay=args.l2)

# Per-sample loss moving averages for both models. train() indexes them with
# the sample indices yielded by trainloader (presumably positions within
# dataset.train_data, matching the order of train_labels below).
train_labels = torch.LongTensor(dataset.dataset.y_array[dataset.train_data.indices])
sample_loss_ema_b = EMA(train_labels, alpha=0.7)
sample_loss_ema_d = EMA(train_labels, alpha=0.7)

# When resuming from a debiased checkpoint, report its starting test accuracy.
if args.ckpt_d is not None:
    test(net_d, 0)

# Best validation scores seen so far and the global step counter. train() reads
# and updates these via `global`, so they must exist before the first epoch.
prev_best_acc, prev_best_woacc = -1, -1
iterations = 0
for epoch in range(args.epochs):
    train(epoch)
    test(net_b, epoch)
    test(net_d, epoch)

    # Model selection uses the debiased model's validation scores; the biased
    # model is checkpointed alongside it.
    cur_acc, cur_woacc = test(net_d, epoch, loader='val')
    if cur_acc > prev_best_acc:
        print("Save the best acc model !")
        torch.save(net_d.state_dict(), args.dir + f'/{args.data_type}_model_best_val.pt')
        torch.save(net_b.state_dict(), args.dir + f'/{args.data_type}_model_best_val_biased.pt')
        prev_best_acc = cur_acc
    if cur_woacc > prev_best_woacc:
        print("Save the best wo-acc model !")
        torch.save(net_d.state_dict(), args.dir + f'/{args.data_type}_model_bestwo_val.pt')
        torch.save(net_b.state_dict(), args.dir + f'/{args.data_type}_model_bestwo_val_biased.pt')
        prev_best_woacc = cur_woacc