from __future__ import print_function
import glob
import pdb
import heapq
import os
import pickle
import random
import numpy as np
from models.resnet_alea import resnet50_alea, sample_softmax
from utils.utils import get_scf_idxes, init_grad
from utils.sgmcmc import SGLD, pSGLD, H_SA_SGHMC
from utils.loss import MAELoss
from utils.dataloader import *
from torch.autograd import Variable
from scipy.stats import entropy
import argparse
from torchvision.models import resnet34, resnet50, densenet121
import torchvision.transforms as transforms
import torchvision
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import torch
import sys
sys.path.append('..')


# Command-line interface. Defaults are rendered with ``%(default)s`` so the
# help text always matches the value actually configured.
parser = argparse.ArgumentParser()
parser.add_argument('--dir', type=str, required=True,
                    help='path to save checkpoints (required)')
parser.add_argument('--data_path', type=str, default='data', metavar='PATH',
                    help='path to datasets location (default: %(default)s)')
parser.add_argument('--epochs', type=int, default=100,
                    help='number of epochs to train (default: %(default)s)')
parser.add_argument('--save_epochs', type=int, default=5,
                    help='saving period in epochs (default: %(default)s)')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: %(default)s)')
parser.add_argument('--topk', type=int, default=300, metavar='N',
                    help='number of SCF samples (default: %(default)s)')
parser.add_argument('--data_type', type=str, default='waterbirds',
                    help='dataset name (default: %(default)s)')
parser.add_argument('--device_id', type=int, help='device id to use')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed (default: %(default)s)')
parser.add_argument('--l2', type=float, default=5e-4,
                    help='weight decay (default: %(default)s)')
parser.add_argument('--ent_reg', type=float, default=0.3,
                    help='weight of the entropy (confidence) regularizer '
                         '(default: %(default)s)')
parser.add_argument('--warmup', type=int, default=0,
                    help='warmup epochs before checkpoints are saved '
                         '(default: %(default)s)')
parser.add_argument('--lr', type=float, default=0.5,
                    help='initial learning rate (default: %(default)s)')
parser.add_argument('--pval', type=float, default=0.05,
                    help='pvalue cutoff (default: %(default)s)')
parser.add_argument('--use-ce', action='store_true', default=False,
                    help='use cross-entropy instead of the MAE loss')
parser.add_argument('--p_noise', type=float, default=0.0,
                    help='proportion of noisy labels (default: %(default)s)')

# Parse the CLI, note whether CUDA is usable and which device was requested.
args = parser.parse_args()
use_cuda = torch.cuda.is_available()
device_id = args.device_id

# Seed torch, numpy and the stdlib RNG with the same value.
for seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    seed_rng(args.seed)

# Echo every parsed argument, one ``name=value`` pair per line.
print("Arguments: ########################")
print(*(f'{k}={v}' for k, v in vars(args).items()), sep='\n')
print("###################################")
# Data
print('==> Preparing data..')
dataset = WildsDataset(args.data_type, args)
# Label noise is injected before the loaders are built.
dataset.inject_label_noise(args.p_noise)
trainloader, valloader, testloader = dataset.get_loader(args)

# The output directory may not exist yet; create it so the dump below (and
# later writes into args.dir) do not fail with FileNotFoundError.
os.makedirs(args.dir, exist_ok=True)
# Persist the dataset object, including the injected label noise.
with open(f'{args.dir}/noisy_data.pk', 'wb') as f:
    pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)

# Model
print('==> Building model..')


def bert_forward(net, inputs, targets=None):
    """Call a BERT-style ``net`` on packed inputs and return element 0 of its output.

    ``inputs`` is indexed as ``inputs[:, :, k]``: slice 0 is passed as
    ``input_ids``, slice 1 as ``attention_mask`` and slice 2 as
    ``token_type_ids``. ``targets`` is forwarded as ``labels``.

    NOTE(review): for a Hugging Face sequence-classification model, element 0
    is presumably the logits when ``targets`` is None and the loss when labels
    are supplied -- confirm before passing ``targets``.
    """
    token_ids, attn_mask, type_ids = (inputs[:, :, ch] for ch in range(3))
    result = net(
        input_ids=token_ids,
        attention_mask=attn_mask,
        token_type_ids=type_ids,
        labels=targets,
    )
    return result[0]


def resnet_forward(net, inputs, targets=None):
    """Run ``net`` on ``inputs`` and return its output unchanged.

    ``targets`` is accepted so the signature matches ``bert_forward``; it is
    ignored here.
    """
    outputs = net(inputs)
    return outputs


# Select the pretrained backbone and the matching forward adapter.
if args.data_type in ['iwilds', 'rxrx1', 'waterbirds', 'celebA']:
    net = resnet50(pretrained=True)  # * Load pre-trained model
    # * change output dimension
    net.fc = torch.nn.Linear(net.fc.in_features, dataset.target_dim)
    model_forward = resnet_forward
elif args.data_type in ['fmow', 'camelyon17']:
    net = densenet121(pretrained=True)
    # Replace the ImageNet head as well, so the output dimension matches the
    # dataset just like in the resnet branch.
    # NOTE: checkpoints saved with the original 1000-way head will no longer
    # load into this model.
    net.classifier = torch.nn.Linear(
        net.classifier.in_features, dataset.target_dim)
    model_forward = resnet_forward
elif args.data_type in ['civilcomments']:
    from transformers import BertForSequenceClassification
    net = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased',
        num_labels=2)
    model_forward = bert_forward
else:
    # Fail here with a clear message instead of a NameError on `net` later.
    raise ValueError(f'Unsupported data_type: {args.data_type!r}')

if use_cuda:
    net.cuda(device_id)
    cudnn.benchmark = True
    cudnn.deterministic = True
    torch.autograd.set_detect_anomaly(False)
    # The former profiler.profile(False) / emit_nvtx(False) calls only built
    # context-manager objects that were never entered, so they had no effect.


def train(epoch):
    """Train ``net`` for one pass over ``trainloader``.

    Each step minimizes ``criterion(outputs, targets)`` plus ``args.ent_reg``
    times the mean negative entropy of the predicted distribution, then calls
    ``optimizer.step()``. Running loss and accuracy are printed every 100
    batches.

    Args:
        epoch: Epoch number; only used in the progress banner.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets, _) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(device_id), targets.cuda(device_id)

        # Project helper; presumably resets accumulated gradients -- TODO confirm.
        init_grad(net)
        outputs = model_forward(net, inputs)
        loss = criterion(outputs, targets)
        loss.requires_grad_(True)

        # confidence regularization
        # log_softmax stays finite where softmax underflows to 0; the previous
        # softmax(...).log() produced -inf there and -inf * 0 = NaN below.
        log_probs = F.log_softmax(outputs, dim=-1)
        # Mean of sum(p * log p), i.e. the negative entropy of the predictions.
        reg = torch.mean(torch.sum(log_probs * log_probs.exp(), dim=-1))
        (loss + reg * args.ent_reg).backward()
        optimizer.step()

        train_loss += loss.mean().item()
        _, predicted = torch.max(outputs, 1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % 100 == 0:
            print('Loss: %.3f | Acc: %.3f%% (%d/%d)'
                  % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))


def test(epoch, loader='test'):
    """Evaluate ``net`` on the test or validation split.

    Prints the average loss and accuracy, then the summary string returned by
    ``dataset.dataset.eval`` for the arg-max predictions.

    Args:
        epoch: Not used by this function.
        loader: ``'test'`` selects ``testloader``; any other value selects
            ``valloader``.

    Returns:
        Fraction of correctly classified samples.
    """
    net.eval()
    eval_loader = testloader if loader == 'test' else valloader

    loss_sum, n_correct, n_seen = 0, 0, 0
    prob_batches, target_batches, meta_batches = [], [], []

    with torch.no_grad():
        for inputs, targets, metadata in eval_loader:
            if use_cuda:
                inputs = inputs.cuda(device_id)
                targets = targets.cuda(device_id)
            target_batches.append(targets.cpu())
            meta_batches.append(metadata.cpu())

            logits = model_forward(net, inputs)
            prob_batches.append(F.softmax(logits, dim=-1).cpu())
            loss_sum += criterion(logits, targets).item()

            hits = logits.max(1)[1].eq(targets)
            n_seen += targets.size(0)
            n_correct += hits.sum().item()

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        loss_sum/len(eval_loader), n_correct, n_seen,
        100. * n_correct / n_seen))

    # Concatenating the per-batch tensors yields the same per-sample rows as
    # stacking them one sample at a time.
    hard_preds = torch.cat(prob_batches).max(-1)[1]
    all_targets = torch.cat(target_batches)
    all_metadata = torch.cat(meta_batches)
    eval_result, eval_result_str = dataset.dataset.eval(
        hard_preds, all_targets, all_metadata)
    print(eval_result_str)
    return n_correct/n_seen


def get_uncertain_index(model_paths, criterion):
    """Select training indices from the averaged predictions of saved checkpoints.

    Every checkpoint in ``model_paths`` is loaded into the module-level
    ``net`` and run over the unshuffled training loader. The softmax outputs
    are averaged across checkpoints, and two index sets are derived from that
    average.

    Args:
        model_paths: Iterable of checkpoint paths for ``net.load_state_dict``.
        criterion: Unused; kept so existing callers keep working.

    Returns:
        ``(scf, errors)``, both taken from ``dataset.train_data.indices``.
        ``scf`` are the positions chosen by ``get_scf_idxes`` from the
        per-sample entropy of the averaged prediction with ``p=args.pval``
        (the selection rule lives in that helper). ``errors`` are the positions
        whose averaged arg-max differs from the target.

    Raises:
        ValueError: If ``model_paths`` is empty.
    """
    model_paths = list(model_paths)
    if not model_paths:
        # Without this guard the failure surfaces later as an obscure
        # numpy/scipy error on an empty prediction list.
        raise ValueError('get_uncertain_index: no checkpoint paths were given')

    # The ordered loader does not depend on the checkpoint, so build it once.
    trainloader_ordered, _, _ = dataset.get_loader(args, shuffle_train=False)

    ens_pred_list = []
    ens_target_list = None
    for model_path in model_paths:
        # map_location keeps loading independent of the device used for saving;
        # load_state_dict copies the values onto net's own device.
        net.load_state_dict(torch.load(model_path, map_location='cpu'))
        net.eval()

        # Targets are identical for every checkpoint; collect them only once.
        need_targets = ens_target_list is None
        pred_list = []
        target_list = []
        with torch.no_grad():
            for inputs, targets, _ in trainloader_ordered:
                if use_cuda:
                    inputs = inputs.cuda(device_id)
                outputs = model_forward(net, inputs)
                pred_list.append(F.softmax(outputs, dim=-1).cpu().numpy())
                if need_targets:
                    target_list.append(targets.cpu().numpy())

        ens_pred_list.append(np.concatenate(pred_list))
        if need_targets:
            ens_target_list = np.concatenate(target_list)

    indices = dataset.train_data.indices
    mean_probs = np.mean(ens_pred_list, axis=0)
    scf_indices = get_scf_idxes(entropy(mean_probs, axis=1), p=args.pval)
    error_indices = np.where(
        np.argmax(mean_probs, axis=1) != ens_target_list)

    return indices[scf_indices], indices[error_indices]


def test_sgld(loader='test'):
    """Evaluate the ensemble of saved ``*_id_model_*.pt`` checkpoints.

    Each checkpoint found under ``args.dir`` is loaded into ``net`` and
    evaluated on the chosen split (per-checkpoint loss and accuracy are
    printed). The softmax outputs are then averaged across checkpoints and
    the arg-max of that average is scored with ``dataset.dataset.eval``.

    Args:
        loader: ``'test'`` selects ``testloader``; any other value selects
            ``valloader`` (same convention as ``test``).

    Returns:
        Accuracy of the averaged (ensemble) prediction on the chosen split.

    Raises:
        RuntimeError: If no checkpoint matches the pattern.
    """
    ckpts = glob.glob(args.dir + f'/{args.data_type}_id_model_*.pt')
    print(ckpts)
    if not ckpts:
        raise RuntimeError(
            f'test_sgld: no checkpoints match '
            f'{args.dir}/{args.data_type}_id_model_*.pt')

    eval_loader = testloader if loader == 'test' else valloader
    sgld_pred_list = []
    truth_res = None
    metadata_list = None

    for ckpt in ckpts:
        # map_location keeps loading independent of the device used for saving.
        net.load_state_dict(torch.load(ckpt, map_location='cpu'))
        net.eval()
        test_loss = 0
        correct = 0
        total = 0
        prob_batches = []
        # Targets and metadata are the same for every checkpoint.
        need_truth = truth_res is None
        target_batches = []
        meta_batches = []
        with torch.no_grad():
            for inputs, targets, metadata in eval_loader:
                if use_cuda:
                    inputs, targets = inputs.cuda(
                        device_id), targets.cuda(device_id)
                if need_truth:
                    target_batches.append(targets.cpu())
                    meta_batches.append(metadata.cpu())
                outputs = model_forward(net, inputs)
                prob_batches.append(F.softmax(outputs, dim=-1).cpu())
                test_loss += criterion(outputs, targets).item()
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss/len(eval_loader), correct, total,
            100. * correct / total))
        sgld_pred_list.append(torch.cat(prob_batches))
        if need_truth:
            truth_res = torch.cat(target_batches)
            metadata_list = torch.cat(meta_batches)

    # Average class probabilities over checkpoints, then take the arg-max.
    sgld_pred = torch.stack(sgld_pred_list).mean(dim=0).max(-1)[1]

    eval_result, eval_result_str = dataset.dataset.eval(
        sgld_pred, truth_res, metadata_list)
    print(eval_result_str)
    # Return the ensemble's accuracy, not that of whichever checkpoint glob
    # happened to list last.
    return sgld_pred.eq(truth_res).sum().item() / truth_res.size(0)


iterations = 1
# MAE loss by default; --use-ce switches to cross-entropy.
criterion = nn.CrossEntropyLoss() if args.use_ce else MAELoss()
datasize = len(dataset.train_data)
# norm_sigma handed to SGLD is derived from the L2 coefficient as sqrt(1 / l2).
norm_sigma = (1/args.l2)**0.5
print(f"L2 reg : {args.l2}, Norm sigma: {norm_sigma}")
optimizer = SGLD(net.parameters(), datasize, lr=args.lr,
                 norm_sigma=norm_sigma, addnoise=True)
# The training loop below is disabled by wrapping it in a string literal; the
# script currently only evaluates checkpoints that already exist in args.dir.
"""
for epoch in range(1, args.epochs+1):
    train(epoch)
    test(epoch)

    if epoch % args.save_epochs == 0 and epoch != 0:
        if epoch > args.warmup:
            net.cpu()
            print('save the model')
            torch.save(net.state_dict(), args.dir +
                       f'/{args.data_type}_id_model_{epoch}.pt')
            net.cuda(device_id)
"""

# Glob pattern shared by the ensemble evaluation and the index extraction.
ckpt_pattern = args.dir + f'/{args.data_type}_id_model_*.pt'
print(ckpt_pattern)
test_sgld()

SCF_idx, error_idx = get_uncertain_index(
    glob.glob(ckpt_pattern), nn.NLLLoss(reduction='none'))


# NOTE(review): the message says "errorset" but the length reported is that of
# SCF_idx; error_idx is not used after this point.
print(f"Length of the errorset : {len(SCF_idx)}")

print(
    f'Save the index of the uncertain samples for [{args.dir}/uncertain_idx.pk]')
with open(f'{args.dir}/uncertain_idx.pk', 'wb') as f:
    pickle.dump(SCF_idx, f, pickle.HIGHEST_PROTOCOL)
