from __future__ import print_function
import pickle
import pdb
import random
import numpy as np
from utils.utils import init_grad
from utils.dataloader import *
from utils.loss import MAELoss, GeneralizedCELoss
import heapq
import argparse
from torchvision.models import resnet50, densenet121
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import torch
import sys
sys.path.append('..')


#######* Command-line arguments ##########################
parser = argparse.ArgumentParser()
parser.add_argument('--dir', type=str, default=None, required=True,
                    help='path to save checkpoints (default: None)')
# Help text previously claimed "(default: None)" although the default is 'data'.
parser.add_argument('--data_path', type=str, default='data', metavar='PATH',
                    help='path to datasets location (default: data)')
parser.add_argument('--epochs', type=int, default=100,
                    help='number of epochs to train')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--data_type', type=str, default='waterbirds')
# Optional pickled dataset object that replaces the one built from --data_type.
parser.add_argument('--data_file', type=str)
# Passed to .cuda(); None selects the current CUDA device.
parser.add_argument('--device_id', type=int, help='device id to use')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed')
parser.add_argument('--l2', type=float, default=5e-4,
                    help='weight decay')
parser.add_argument('--lr', type=float, default=0.5,
                    help='initial learning rate')
# Passed as `times=` when upsampling the hard-sample set (--curricular).
parser.add_argument('--p_h', type=int, default=5,
                    help='Coefficient to train the hard sample set.')
parser.add_argument('--ckpt', type=str,
                    help='Previous checkpoint.')
parser.add_argument('--curricular', action='store_true')
# Loss selection; if both are given, --gce takes precedence.
parser.add_argument('--mae', action='store_true')
parser.add_argument('--gce', action='store_true')
# Only read when --curricular is set.
parser.add_argument('--idx', type=str,
                    help='The pickle file including indices of hard samples.')

args = parser.parse_args()

# Device selection: `device_id` is handed to every `.cuda()` call below.
device_id = args.device_id
use_cuda = torch.cuda.is_available()

# Seed torch, NumPy and Python's `random` with the same value (in that order)
# so runs are repeatable.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(args.seed)

# Echo the full configuration, one `name=value` per line.
print("Arguments: ########################")
print(*(f'{name}={value}' for name, value in vars(args).items()), sep='\n')
print("###################################")

#######* Load data #######################################
print('==> Preparing data..')

dataset = WildsDataset(args.data_type, args)
if args.data_file is not None:
    # Swap in a pre-built dataset object (presumably e.g. a label-noise
    # variant). SECURITY: pickle.load can execute arbitrary code -- only pass
    # files you created yourself.
    with open(args.data_file, 'rb') as f:
        dataset = pickle.load(f)
print(f"Dataset size before upsampling: [{len(dataset.train_data)}]")
if args.curricular:
    if args.idx is None:
        # Fail with a usage message instead of an opaque open(None) TypeError.
        parser.error('--idx is required when --curricular is set')
    # Hard-sample indices; kept as-is (len() check below also works for arrays).
    with open(args.idx, 'rb') as f:
        hard_idx = pickle.load(f)
    if len(hard_idx) > 0:
        # Project helper; presumably replicates the listed samples `p_h` times.
        dataset.add_training_data(hard_idx, times=args.p_h)
    else:
        print("[Warning] ! ! ! The length of errorset is zero. ! ! !")
print(f"Dataset size after upsampling: [{len(dataset.train_data)}]")
trainloader, valloader, testloader = dataset.get_loader(args)

#######* Build model #######################################
print('==> Building model..')


def bert_forward(net, inputs, targets=None):
    """Run a `transformers` sequence classifier on packed per-token features.

    `inputs` carries three channels along its third axis: index 0 holds the
    token ids, 1 the attention mask and 2 the segment (token type) ids.

    Returns element 0 of the model output. For `transformers` classifiers
    that is the logits when `targets` is None and the loss otherwise.
    """
    ids, mask, segments = (inputs[:, :, channel] for channel in range(3))
    model_output = net(
        input_ids=ids,
        attention_mask=mask,
        token_type_ids=segments,
        labels=targets,
    )
    return model_output[0]


def resnet_forward(net, inputs, targets=None):
    """Plain forward pass for vision backbones.

    `targets` is accepted only so the signature matches `bert_forward`; it is
    ignored.
    """
    del targets  # unused; kept for interface parity with bert_forward
    logits = net(inputs)
    return logits


# Pick backbone, forward adapter, optimizer and mid-epoch evaluation interval
# for the requested dataset. Every branch must define `net`, `model_forward`,
# `optimizer` and `eval_iter` (the latter is read by `train()`).
if args.data_type in ['iwilds', 'rxrx1', 'waterbirds', 'celebA']:
    net = resnet50(pretrained=True)  # * Load pre-trained model
    # * change output dimension
    net.fc = torch.nn.Linear(net.fc.in_features, dataset.target_dim)
    model_forward = resnet_forward
    optimizer = optim.SGD(net.parameters(), lr=args.lr,
                          weight_decay=args.l2, momentum=0.9)
    eval_iter = 500
elif args.data_type in ['fmow', 'camelyon17']:
    net = densenet121(pretrained=True)
    # Replace the pretrained ImageNet head, as done for ResNet above; without
    # this the model emits ImageNet logits instead of dataset.target_dim.
    net.classifier = torch.nn.Linear(net.classifier.in_features,
                                     dataset.target_dim)
    model_forward = resnet_forward
    optimizer = optim.SGD(net.parameters(), lr=args.lr,
                          weight_decay=args.l2, momentum=0.9)
    # Was missing for this branch, making train() raise NameError.
    eval_iter = 500
elif args.data_type in ['civilcomments']:
    from transformers import BertForSequenceClassification
    net = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased',
        num_labels=2)
    model_forward = bert_forward
    optimizer = optim.AdamW(net.parameters(), lr=args.lr,
                            weight_decay=args.l2, eps=1e-8)
    eval_iter = 10000
else:
    # Fail fast instead of a later NameError on `net`.
    raise ValueError(f"Unsupported data_type: {args.data_type!r}")

if args.ckpt is not None:
    # The net is still on CPU here, so load the weights onto CPU too; this also
    # allows GPU-saved checkpoints to be restored on CPU-only machines.
    net.load_state_dict(torch.load(args.ckpt, map_location='cpu'))

if use_cuda:
    net.cuda(device_id)
    # NOTE(review): benchmark=True lets cuDNN pick algorithms per run, which
    # undermines deterministic=True; kept as-is to avoid a speed regression.
    cudnn.benchmark = True
    cudnn.deterministic = True
    torch.autograd.set_detect_anomaly(False)
    # The former profiler.profile(False) / emit_nvtx(False) calls only built
    # context-manager objects that were never entered, so they were removed.

#######* Train models #######################################


def train(epoch):
    """Run one training epoch over `trainloader`.

    Every `eval_iter` optimizer steps (counted across epochs in the global
    `iterations`), the model is evaluated on the test and validation loaders.
    Its weights are saved under ``args.dir`` whenever validation accuracy or
    validation worst-group accuracy beats the best value seen so far.

    Args:
        epoch: Epoch index, used for logging only.
    """
    print('\nEpoch: %d' % epoch)
    global iterations, prev_best_acc, prev_best_woacc
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets, _metadata) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(device_id), targets.cuda(device_id)

        # Project helper used in place of optimizer.zero_grad().
        init_grad(net)
        outputs = model_forward(net, inputs)
        loss = criterion(outputs, targets)
        # NOTE(review): a no-op when `loss` already tracks gradients. If it
        # ever does not, this hides the problem (backward() then reaches no
        # parameters) instead of raising; confirm it is still needed.
        loss.requires_grad_(True)
        loss.backward()
        if args.data_type in ['civilcomments']:
            # Gradient clipping for the BERT model.
            torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs.detach(), 1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % 100 == 0:
            print('Loss: %.3f | Acc: %.3f%% (%d/%d)'
                  % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
        iterations += 1

        if iterations % eval_iter == 0:
            test(epoch)
            cur_acc, cur_woacc = test(epoch, loader='val')
            if cur_acc > prev_best_acc:
                print("Save the best acc model !")
                torch.save(net.state_dict(), args.dir +
                           f'/{args.data_type}_model_best_val.pt')
                prev_best_acc = cur_acc
            if cur_woacc > prev_best_woacc:
                print("Save the best wo-acc model !")
                torch.save(net.state_dict(), args.dir +
                           f'/{args.data_type}_model_bestwo_val.pt')
                prev_best_woacc = cur_woacc
            # test() puts the network in eval mode; restore training mode so
            # the rest of the epoch does not train with BatchNorm/Dropout
            # frozen in inference behaviour.
            net.train()


def test(epoch, loader='test'):
    """Evaluate the network on the test or validation split.

    Args:
        epoch: Epoch index; currently unused, kept for call-site symmetry.
        loader: ``'test'`` evaluates ``testloader``; any other value
            evaluates ``valloader``.

    Returns:
        ``(acc, worst_acc)``. `acc` is the overall accuracy as a fraction.
        `worst_acc` is ``eval_result['acc_wg']`` from the dataset's evaluator
        (presumably worst-group accuracy).

    Raises:
        ValueError: if the selected loader yields no samples.

    Leaves ``net`` in eval mode; callers that keep training must call
    ``net.train()`` afterwards.
    """
    net.eval()
    eval_loader = testloader if loader == 'test' else valloader
    test_loss = 0
    correct = 0
    total = 0
    pred_batches = []
    target_batches = []
    metadata_batches = []
    with torch.no_grad():
        for inputs, targets, metadata in eval_loader:
            if use_cuda:
                inputs, targets = inputs.cuda(
                    device_id), targets.cuda(device_id)
            outputs = model_forward(net, inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            # The same predictions feed both the accuracy and the group-wise
            # evaluation below.
            pred_batches.append(predicted.cpu())
            target_batches.append(targets.cpu())
            metadata_batches.append(metadata.cpu())

    if total == 0:
        raise ValueError(f"The '{loader}' loader yielded no samples.")

    # `correct` is a plain int, so this prints "N" rather than "tensor(N)".
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss/len(eval_loader), correct, total,
        100. * correct / total))

    eval_result, eval_result_str = dataset.dataset.eval(
        torch.cat(pred_batches), torch.cat(target_batches),
        torch.cat(metadata_batches))
    worst_acc = float(eval_result['acc_wg'])
    print(worst_acc)
    acc = correct / total
    print(eval_result_str)
    return acc, worst_acc


def get_hard_index(model, criterion, p=0.2):
    """Return dataset indices of the training samples with the highest loss.

    Args:
        model: Network to score with; it is left in eval mode.
        criterion: Loss that returns one value per sample, e.g.
            ``nn.CrossEntropyLoss(reduction='none')``. A scalar (mean-reduced)
            loss cannot be mapped back to individual samples.
        p: Fraction of the training set to return.

    Returns:
        ``int(len(scores) * p)`` entries of ``dataset.train_data.indices``,
        ordered by decreasing loss. This assumes the unshuffled train loader
        yields samples in the same order as ``dataset.train_data.indices``.

    Raises:
        ValueError: if `criterion` returns a scalar instead of per-sample losses.
    """
    trainloader_ordered, _, _ = dataset.get_loader(args, shuffle_train=False)
    model.eval()
    score_list = []
    with torch.no_grad():
        for inputs, targets, _metadata in trainloader_ordered:
            if use_cuda:
                inputs, targets = inputs.cuda(
                    device_id), targets.cuda(device_id)
            # Use the architecture-specific adapter: calling model(inputs)
            # directly breaks for BERT, whose inputs pack ids/mask/segments.
            outputs = model_forward(model, inputs)
            loss = criterion(outputs, targets)
            if loss.dim() == 0:
                raise ValueError(
                    "get_hard_index needs per-sample losses; "
                    "construct the criterion with reduction='none'.")
            score_list.append(loss.cpu().numpy())
    if not score_list:
        return []
    score_list = np.concatenate(score_list)
    score_dict = dict(zip(dataset.train_data.indices, score_list))
    # * get top k proportion from score_dict
    large_index = heapq.nlargest(
        int(len(score_list)*p), score_dict, key=score_dict.get)
    return large_index


datasize = len(dataset.train_data)
# Ceiling division: a final partial batch still counts (assumes the train
# loader does not drop it). Previously a float `size/bs + 1`.
num_batch = (datasize + args.batch_size - 1) // args.batch_size
print(f"Num batch: [{num_batch}]")

# Loss selection: --gce wins over --mae, otherwise plain cross-entropy.
if args.gce:
    criterion = GeneralizedCELoss(reduction='mean')
elif args.mae:
    criterion = MAELoss()
else:
    criterion = nn.CrossEntropyLoss(reduction='mean')

# Global optimizer-step counter, advanced by train().
iterations = 0
# Suffix of the final checkpoint file name.
mt = 0

if args.ckpt is not None:
    # Baseline evaluation of the restored checkpoint before any training.
    test(0)

# Best validation scores so far; train() also reads and updates these.
prev_best_acc, prev_best_woacc = -1, -1
for epoch in range(args.epochs):
    train(epoch)
    test(epoch)

    cur_acc, cur_woacc = test(epoch, loader='val')
    if cur_acc > prev_best_acc:
        print("Save the best acc model !")
        torch.save(net.state_dict(), args.dir +
                   f'/{args.data_type}_model_best_val.pt')
        prev_best_acc = cur_acc
    if cur_woacc > prev_best_woacc:
        print("Save the best wo-acc model !")
        torch.save(net.state_dict(), args.dir +
                   f'/{args.data_type}_model_bestwo_val.pt')
        prev_best_woacc = cur_woacc

torch.save(net.state_dict(), args.dir + f'/{args.data_type}_model_{mt}.pt')
