from __future__ import print_function
import glob
import pdb
import heapq
import os
import pickle
import random
import numpy as np
from models.resnet_alea import resnet50_alea, sample_softmax
from utils.utils import get_scf_idxes
from utils.sgmcmc import SGLD, pSGLD, H_SA_SGHMC
from utils.dataloader import *
from torch.autograd import Variable
from scipy.stats import entropy
import argparse
from torchvision.models import resnet34, resnet50, densenet121
import torchvision.transforms as transforms
import torchvision
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import torch
import sys
sys.path.append('..')


parser = argparse.ArgumentParser()
parser.add_argument('--dir', type=str, default=None, required=True,
                    help='directory to save outputs (noisy_data.pk, error_idx.pk)')
parser.add_argument('--data_path', type=str, default='data', metavar='PATH',
                    help='path to datasets location (default: data)')
parser.add_argument('--epochs', type=int, default=100,
                    help='number of epochs to train (default: 100)')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--data_type', type=str, default='iwildcam',
                    help='Dataset type of WILDS (default: iwildcam)')
parser.add_argument('--device_id', type=int, help='device id to use')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed')
parser.add_argument('--l2', type=float, default=5e-4,
                    help='weight decay')
parser.add_argument('--lr', type=float, default=0.5,
                    help='initial learning rate')
parser.add_argument('--p_noise', type=float, default=0.0,
                    help='proportion of noisy labels')

args = parser.parse_args()
device_id = args.device_id
use_cuda = torch.cuda.is_available()
# Seed every RNG the script touches (torch, numpy, python `random`).
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
print("Arguments: ########################")
print('\n'.join(f'{k}={v}' for k, v in vars(args).items()))
print("###################################")

# Data
print('==> Preparing data..')
dataset = WildsDataset(args.data_type, args)
dataset.inject_label_noise(args.p_noise)
trainloader, valloader, testloader = dataset.get_loader(args)
# NOTE(review): args.temperature is not read anywhere in this script — confirm
# whether something downstream still needs it.
args.temperature = 1/dataset.N_training

# Persist the dataset object *after* label-noise injection. Create the output
# directory first: open() does not create missing parent directories.
os.makedirs(args.dir, exist_ok=True)
with open(f'{args.dir}/noisy_data.pk', 'wb') as f:
    pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)

# Model
print('==> Building model..')
# 'iwildcam' is the --data_type default; 'iwilds' is kept as an alias. Before,
# the default matched no branch and `net` was never defined.
if args.data_type in ['iwildcam', 'iwilds', 'rxrx1', 'waterbirds', 'celebA']:
    net = resnet50(pretrained=True)  # * Load pre-trained model
    # * change output dimension
    net.fc = torch.nn.Linear(net.fc.in_features, dataset.target_dim)
elif args.data_type in ['fmow', 'camelyon17']:
    net = densenet121(pretrained=True)
    # DenseNet's head is `classifier` (not `fc`); resize it to the task's classes
    # instead of keeping the 1000-way ImageNet head.
    net.classifier = torch.nn.Linear(net.classifier.in_features,
                                     dataset.target_dim)
else:
    raise ValueError(f'Unsupported data_type: {args.data_type}')

if use_cuda:
    net.cuda(device_id)
    # NOTE(review): benchmark=True lets cuDNN pick algorithms by timing, which
    # may differ between runs even with deterministic=True; set it to False if
    # run-to-run reproducibility matters more than speed.
    cudnn.benchmark = True
    cudnn.deterministic = True


def train(epoch):
    """Run one training epoch of the module-level ``net`` over ``trainloader``.

    Uses the module-level ``optimizer`` and ``criterion``, then prints the mean
    per-batch loss and the accuracy against the training labels (which may
    contain injected noise).

    Args:
        epoch: Epoch number; used only for logging.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    for inputs, targets, _metadata in trainloader:
        if use_cuda:
            inputs, targets = inputs.cuda(device_id), targets.cuda(device_id)

        optimizer.zero_grad()
        logits = net(inputs)
        loss = criterion(logits, targets)
        # The loss already carries a graph through the trainable network, so
        # no `requires_grad_` is needed (it would only mask a frozen model).
        loss.backward()
        optimizer.step()

        num_batches += 1
        train_loss += loss.item()
        predicted = logits.detach().argmax(dim=1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    if num_batches == 0:
        print('Loss: n/a | Acc: n/a (empty training loader)')
        return
    print('Loss: %.3f | Acc: %.3f%% (%d/%d)'
          % (train_loss/num_batches, 100.*correct/total, correct, total))


def test(epoch, loader='test'):
    """Evaluate the module-level ``net`` on the test split or the validation split.

    Prints the average per-batch loss and the accuracy, then prints the report
    returned by ``dataset.dataset.eval`` for the hard predictions.

    Args:
        epoch: Epoch number (unused; kept for call-site symmetry with ``train``).
        loader: ``'test'`` selects ``testloader``; any other value selects
            ``valloader``.

    Returns:
        Accuracy as a float in [0, 1].

    Raises:
        ValueError: if the selected loader yields no samples.
    """
    net.eval()
    data_loader = testloader if loader == 'test' else valloader
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    pred_batches = []
    truth_batches = []
    metadata_batches = []
    with torch.no_grad():
        for inputs, targets, metadata in data_loader:
            if use_cuda:
                inputs, targets = inputs.cuda(device_id), targets.cuda(device_id)
            logits = net(inputs)
            test_loss += criterion(logits, targets).item()
            # Softmax is monotonic, so argmax over logits gives the hard label.
            # Using one `predicted` for both accuracy and the WILDS eval keeps
            # them consistent.
            predicted = logits.argmax(dim=1)
            correct += predicted.eq(targets).sum().item()
            total += targets.size(0)
            num_batches += 1
            # Keep whole batches and concatenate once at the end, rather than
            # splitting into per-sample tensors.
            pred_batches.append(predicted.cpu())
            truth_batches.append(targets.cpu())
            metadata_batches.append(metadata.cpu())

    if total == 0:
        raise ValueError(f'{loader!r} loader yielded no samples')

    # `correct` is a plain int, so the report shows "123/500" rather than
    # "tensor(123)/500".
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss/num_batches, correct, total,
        100. * correct / total))
    _, eval_result_str = dataset.dataset.eval(
        torch.cat(pred_batches), torch.cat(truth_batches),
        torch.cat(metadata_batches))
    print(eval_result_str)
    return correct/total


def get_errorset_index(criterion):
    """Return the training-set indices of samples the current ``net`` misclassifies.

    Runs the model over the training split without shuffling and compares each
    argmax prediction with the stored training label (which may be noisy).

    Args:
        criterion: Unused; kept so existing callers keep working.

    Returns:
        The entries of ``dataset.train_data.indices`` at the positions where the
        prediction differs from the label. A plain-list ``indices`` is first
        converted to a NumPy array.
    """
    model = net

    trainloader_ordered, _, _ = dataset.get_loader(args, shuffle_train=False)
    model.eval()
    pred_batches = []
    target_batches = []
    with torch.no_grad():
        for inputs, targets, _metadata in trainloader_ordered:
            if use_cuda:
                inputs, targets = inputs.cuda(device_id), targets.cuda(device_id)
            # argmax over logits picks the same class as argmax over softmax.
            # Only the hard label is moved to the host, not the (B, C) scores.
            pred_batches.append(model(inputs).argmax(dim=1).cpu().numpy())
            target_batches.append(targets.cpu().numpy())

    indices = dataset.train_data.indices
    if isinstance(indices, list):
        # A Python list cannot be fancy-indexed by a NumPy index array.
        indices = np.asarray(indices)
    if not pred_batches:
        # Empty loader: no errors. np.concatenate([]) would raise instead.
        return indices[:0]

    preds = np.concatenate(pred_batches)
    labels = np.concatenate(target_batches)
    # NOTE(review): assumes the ordered loader yields samples in the same order
    # as dataset.train_data.indices and drops none — confirm in get_loader.
    error_index = np.where(preds != labels)
    return indices[error_index]


# Loss shared by train()/test(); SGD with momentum and L2 weight decay.
print(f"L2 reg : {args.l2}")
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    lr=args.lr,
    momentum=0.9,
    weight_decay=args.l2,
)

# Train for args.epochs epochs, reporting test-split metrics after each one.
for epoch in range(1, args.epochs + 1):
    train(epoch)
    test(epoch)

# Find the training samples the final model gets wrong and pickle their indices.
error_idx = get_errorset_index(nn.NLLLoss(reduction='none'))
print(f"Length of the errorset : {len(error_idx)}")

error_idx_path = f'{args.dir}/error_idx.pk'
print(f'Save the index of the error samples for [{error_idx_path}]')
with open(error_idx_path, 'wb') as f:
    pickle.dump(error_idx, f, pickle.HIGHEST_PROTOCOL)
