from __future__ import print_function
import pickle
import pdb
import random
from torchvision.datasets import Food101
from torchvision.models import efficientnet_b4
import numpy as np
from utils.sgmcmc import SGLD
from utils.dta_model import DeepDTA
from utils.dta_data import DataSet, DTIDataset, DataLoader, collate_dataset
from utils.utils import init_grad, get_scf_idxes
from utils.dataloader import *
import argparse
from torchvision.models import resnet50, densenet121
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import torch
import sys
sys.path.append('..')


# Command-line interface. --fold_num selects which of the train folds is held
# out for validation (see the data-loading section below).
parser = argparse.ArgumentParser()
parser.add_argument('--dir', type=str, default=None, required=True,
                    help='path to save checkpoints (default: None)')
parser.add_argument('--epochs', type=int, default=10,
                    help='number of epochs to train')
parser.add_argument('--device_id', type=int, help='device id to use')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed')
parser.add_argument('--l2', type=float, default=1e-4,
                    help='weight decay')
parser.add_argument('--lr', type=float, default=1e-2,
                    help='initial learning rate')
parser.add_argument('--save_epochs', type=int, default=2,
                    help='Saving period')
# required + choices: without them a missing value reached int(None) and a bad
# value only failed later as an IndexError into the fold list.
parser.add_argument("-f", "--fold_num", type=int, required=True,
                    choices=range(5),
                    help="Fold number. It must be one of the {0,1,2,3,4}.")
# The value is compared against lowercase 'kiba' and used as data/<type>/,
# so the help text spells out the expected form.
parser.add_argument("--type", default='None',
                    help="Dataset name, lowercase (e.g. 'davis' or 'kiba'); "
                         "data is read from data/<type>/.")

args = parser.parse_args()

# Device selection: device_id is None unless --device_id was passed.
use_cuda = torch.cuda.is_available()
device_id = args.device_id

# Seed torch, NumPy and Python's `random` from the same --seed value.
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

# Echo every parsed argument as name=value, framed by banner lines.
print("Arguments: ########################")
print('\n'.join('{}={}'.format(name, value)
                for name, value in vars(args).items()))
print("###################################")

# Index of the train fold held out for validation; the other folds are
# concatenated for training.
FOLD_NUM = int(args.fold_num)


class DataSetting:
    """Configuration object handed to ``DataSet.parse_data`` / ``read_sets``.

    Every attribute is derived from the global ``args.type``.
    """

    def __init__(self):
        dataset_name = args.type
        # Directory the dataset files are read from.
        self.dataset_path = f'data/{dataset_name}/'
        self.problem_type = '1'
        # True for every dataset except 'kiba'. Presumably tells parse_data
        # whether to log-transform affinities; confirm in utils.dta_data.
        self.is_log = dataset_name != 'kiba'


data_setting = DataSetting()

# Dataset-specific length limits passed to DataSet:
# KIBA (1000, 100), DAVIS (1200, 85). Presumably (max protein length,
# max SMILES length); TODO confirm against utils.dta_data.DataSet.
dataset = DataSet(data_setting.dataset_path,
                  1000 if args.type == 'kiba' else 1200,
                  100 if args.type == 'kiba' else 85)
smiles, proteins, Y = dataset.parse_data(data_setting)
test_fold, train_folds = dataset.read_sets(data_setting)

# train_folds[FOLD_NUM] is the validation split. Fail early with a clear
# message instead of an IndexError further down.
if not 0 <= FOLD_NUM < len(train_folds):
    raise ValueError(
        f'fold_num must be in [0, {len(train_folds) - 1}], got {FOLD_NUM}')

# Fold entries index into the list of observed (non-NaN) positions of Y.
# Map them back to (drug row, protein column) pairs.
label_row_inds, label_col_inds = np.where(~np.isnan(Y))
test_drug_indices = label_row_inds[test_fold]
test_protein_indices = label_col_inds[test_fold]

# Training uses every train fold except the held-out validation fold.
# The fold count comes from the data instead of being hard-coded to 5.
train_fold_sum = [idx
                  for fold_idx, fold in enumerate(train_folds)
                  if fold_idx != FOLD_NUM
                  for idx in fold]

train_drug_indices = label_row_inds[train_fold_sum]
train_protein_indices = label_col_inds[train_fold_sum]

valid_drug_indices = label_row_inds[train_folds[FOLD_NUM]]
valid_protein_indices = label_col_inds[train_folds[FOLD_NUM]]

dti_dataset = DTIDataset(
    smiles, proteins, Y, train_drug_indices, train_protein_indices, spurious=False)
valid_dti_dataset = DTIDataset(
    smiles, proteins, Y, valid_drug_indices, valid_protein_indices)
test_dti_dataset = DTIDataset(
    smiles, proteins, Y, test_drug_indices, test_protein_indices)

# Loader settings shared by all three splits; only training shuffles.
loader_kwargs = {'batch_size': 256, 'collate_fn': collate_dataset,
                 'pin_memory': True, 'num_workers': 8}
trainloader = DataLoader(dti_dataset, shuffle=True, **loader_kwargs)
validloader = DataLoader(valid_dti_dataset, shuffle=False, **loader_kwargs)
testloader = DataLoader(test_dti_dataset, shuffle=False, **loader_kwargs)

#######* Build model #######################################
print('==> Building model..')
net = DeepDTA()

if use_cuda:
    net.cuda(device_id)
    # NOTE(review): with benchmark=True cuDNN autotunes and may choose a
    # different (individually deterministic) algorithm from run to run, which
    # undermines deterministic=True. Set benchmark=False if bit-exact
    # reproducibility matters.
    cudnn.benchmark = True
    cudnn.deterministic = True
    # Make sure autograd anomaly detection (slow) is switched off.
    torch.autograd.set_detect_anomaly(False)
    # The former `torch.autograd.profiler.profile(False)` and
    # `emit_nvtx(False)` calls only built context-manager objects that were
    # never entered, so they had no effect and were removed.

#######* Train models #######################################


def train(epoch):
    """Run one training epoch of ``net`` over ``trainloader``.

    Uses the module-level ``criterion``, ``optimizer`` and ``init_grad``.
    The running mean of the batch losses is printed every 50 batches.

    Args:
        epoch: Epoch number, used only for logging.

    Returns:
        Mean of the per-batch losses for this epoch, or ``nan`` if
        ``trainloader`` yielded no batches. Callers that ignore the result
        are unaffected.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    running_loss = 0.0
    num_batches = 0

    for batch_idx, (d, p, y) in enumerate(trainloader):
        if use_cuda:
            d, p, y = d.cuda(device_id), p.cuda(device_id), y.cuda(device_id)
        # Reshape targets to (batch, 1) on every device. Previously only the
        # CUDA path did this, so on CPU l1_loss broadcast predictions against
        # a (batch,) target. Assumes net returns (batch, 1), as the CUDA path
        # already did.
        y = y.unsqueeze(1)
        init_grad(net)
        outputs = net(d, p)
        loss = criterion(outputs, y)
        # No requires_grad_(True) here: if the loss has no autograd path to
        # the parameters, backward() should fail loudly rather than silently
        # train nothing.
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
        if batch_idx % 50 == 0:
            print('Loss: %.3f' % (running_loss / (batch_idx + 1)))

    return running_loss / num_batches if num_batches else float('nan')


def test(epoch, loader='test'):
    """Evaluate ``net`` on the test or validation split and return the mean loss.

    Args:
        epoch: Current epoch number. It is unused and kept only for
            call-site compatibility.
        loader: ``'test'`` evaluates on ``testloader``; any other value
            evaluates on ``validloader``. Previously the argument was ignored
            and ``testloader`` was always used.

    Returns:
        Per-sample average of ``criterion`` over the chosen split, or ``nan``
        if the split is empty. Batch losses are weighted by batch size, which
        assumes ``criterion`` uses mean reduction (true for ``F.l1_loss``'s
        default).
    """
    is_test = loader == 'test'
    data_loader = testloader if is_test else validloader
    net.eval()
    total_loss = 0.0
    num_samples = 0

    with torch.no_grad():
        for d, p, y in data_loader:
            if use_cuda:
                d, p, y = d.cuda(device_id), p.cuda(device_id), y.cuda(device_id)
            # (batch, 1) targets on every device; previously CPU skipped this
            # and the loss broadcast incorrectly.
            y = y.unsqueeze(1)
            outputs = net(d, p)
            batch_size = y.size(0)
            total_loss += criterion(outputs, y).item() * batch_size
            num_samples += batch_size

    avg_loss = total_loss / num_samples if num_samples else float('nan')
    print('\n{} set: Average loss: {:.4f}\n'.format(
        'Test' if is_test else 'Valid', avg_loss))
    return avg_loss


def get_normal_entropy(ens_pred_list):
    """Return per-sample differential entropy of a Gaussian fitted to predictions.

    The variance is taken over axis 0 (one row per ensemble member). The
    entropy of N(mu, var) is 0.5 * log(2 * pi * var) + 0.5.

    Args:
        ens_pred_list: Array-like of predictions stacked along axis 0.

    Returns:
        numpy entropies with axis 0 reduced away. A zero variance yields
        ``-inf``, and numpy emits a divide-by-zero warning.
    """
    sample_var = np.var(ens_pred_list, axis=0)
    log_term = np.log(sample_var * 2 * np.pi)
    return 0.5 * log_term + 0.5


def get_uncertain_index(model_paths):
    """Select training samples whose ensemble predictions disagree the most.

    Each checkpoint is loaded into the global ``net`` and used to predict the
    whole training set (``dti_dataset``) in a fixed order. The spread across
    checkpoints is turned into a Gaussian entropy per sample and handed to
    ``get_scf_idxes``.

    Args:
        model_paths: Iterable of ``state_dict`` checkpoint paths.

    Returns:
        Whatever ``get_scf_idxes`` returns for the per-sample entropies.

    Raises:
        ValueError: if ``model_paths`` is empty.

    Side effect: ``net`` is left holding the last checkpoint's weights.
    """
    model_paths = list(model_paths)
    if not model_paths:
        raise ValueError('get_uncertain_index needs at least one checkpoint path')

    # Unshuffled so that sample i lines up across all checkpoints. It does not
    # depend on the checkpoint, so it is built once instead of per model.
    trainloader_ordered = DataLoader(
        dti_dataset, batch_size=256, shuffle=False, collate_fn=collate_dataset,
        pin_memory=True, num_workers=8)

    ens_pred_list = []
    for model_path in model_paths:
        # map_location='cpu' lets GPU-saved checkpoints load anywhere;
        # load_state_dict copies the values onto net's current device.
        net.load_state_dict(torch.load(model_path, map_location='cpu'))
        net.eval()
        pred_list = []
        with torch.no_grad():
            for d, p, _ in trainloader_ordered:
                if use_cuda:
                    d, p = d.cuda(device_id), p.cuda(device_id)
                outputs = net(d, p)
                # reshape(-1) instead of squeeze(): a final batch of size 1
                # would otherwise become a 0-d array and break np.concatenate.
                pred_list.append(outputs.reshape(-1).cpu().numpy())
        ens_pred_list.append(np.concatenate(pred_list))

    # Shape (num_models, num_samples).
    ens_pred_list = np.stack(ens_pred_list)
    entropy = get_normal_entropy(ens_pred_list)
    return get_scf_idxes(entropy)


def get_hard_index(model_paths, k=500):
    """Return the training samples the ensemble fits worst.

    Each checkpoint predicts the whole training set (``dti_dataset``) in a
    fixed order. Predictions are averaged over checkpoints, and samples are
    ranked by squared error against the labels.

    Args:
        model_paths: Iterable of ``state_dict`` checkpoint paths.
        k: Number of indices to return (fewer if the dataset is smaller).

    Returns:
        numpy array of up to ``k`` training-set indices, largest squared
        error first.

    Raises:
        ValueError: if ``model_paths`` is empty.

    Side effect: ``net`` is left holding the last checkpoint's weights.
    """
    model_paths = list(model_paths)
    if not model_paths:
        raise ValueError('get_hard_index needs at least one checkpoint path')

    # Unshuffled so predictions and labels line up by sample; built once.
    trainloader_ordered = DataLoader(
        dti_dataset, batch_size=256, shuffle=False, collate_fn=collate_dataset,
        pin_memory=True, num_workers=8)

    ens_pred_list = []
    label_list = []
    for model_idx, model_path in enumerate(model_paths):
        net.load_state_dict(torch.load(model_path, map_location='cpu'))
        net.eval()
        # Labels are identical for every checkpoint; collect them once.
        collect_labels = model_idx == 0
        pred_list = []
        with torch.no_grad():
            for d, p, y in trainloader_ordered:
                if use_cuda:
                    d, p = d.cuda(device_id), p.cuda(device_id)
                outputs = net(d, p)
                # reshape(-1) keeps a size-1 final batch 1-D; squeeze() made
                # it 0-d and np.concatenate failed.
                pred_list.append(outputs.reshape(-1).cpu().numpy())
                if collect_labels:
                    label_list.append(y.reshape(-1).cpu().numpy())
        ens_pred_list.append(np.concatenate(pred_list))

    mean_pred = np.mean(np.stack(ens_pred_list), axis=0)
    labels = np.concatenate(label_list)
    mse_by_sample = (mean_pred - labels) ** 2

    # Descending by error.
    hard_idx = np.argsort(mse_by_sample)[::-1]
    return hard_idx[:k]


# Optimizer / loss setup. datasize is passed to SGLD (presumably to scale the
# likelihood term; see utils.sgmcmc.SGLD).
datasize = len(dti_dataset)
# Mini-batches per epoch at batch size 256, using ceiling division. The old
# `datasize/256+1` produced a float and overcounted.
num_batch = (datasize + 255) // 256
# sigma = 1/sqrt(l2). Presumably the std of SGLD's Gaussian prior.
# NOTE(review): --l2 must be > 0.
norm_sigma = (1/args.l2)**0.5
print(f"Num batch: [{num_batch}]")
# L1 (mean absolute error) loss.
criterion = F.l1_loss
optimizer = SGLD(net.parameters(), datasize, lr=args.lr,
                 norm_sigma=norm_sigma, addnoise=True)

# Best-loss placeholders. The original line assigned prev_best_mse twice and
# never bound best_mse.
best_mse, prev_best_mse = 10000, 10000

# Train for args.epochs epochs, saving a checkpoint every args.save_epochs.
for epoch in range(1, args.epochs + 1):
    train(epoch)
    # Monitoring only: the test-split loss is printed and discarded.
    test(epoch, 'test')

    # epoch starts at 1, so the old `epoch != 0` guard was redundant.
    if epoch % args.save_epochs == 0:
        # Save from the CPU so checkpoints load on machines without CUDA.
        net.cpu()
        print('save the model')
        torch.save(net.state_dict(), args.dir +
                   f'/{args.type}_id_model_{epoch}.pt')
        # Only move back to the GPU when one is in use. The original
        # unconditional net.cuda(device_id) crashed on CPU-only machines.
        if use_cuda:
            net.cuda(device_id)

# glob was previously only available through a star import; import it
# explicitly so this section does not depend on utils.dataloader's exports.
import glob

# All checkpoints for this dataset type under args.dir, sorted for a
# deterministic order. NOTE(review): the pattern also matches checkpoints left
# over from earlier runs in the same directory.
checkpoint_paths = sorted(glob.glob(args.dir + f'/{args.type}_id_model_*.pt'))
if not checkpoint_paths:
    raise SystemExit(
        f'No checkpoints matching {args.type}_id_model_*.pt found in '
        f'{args.dir}; check --epochs and --save_epochs.')

SCF_idx = get_uncertain_index(checkpoint_paths)
hard_idx = get_hard_index(checkpoint_paths)
print(f"Length of the SCF set : {len(SCF_idx)}")
print(f"Length of the Hard set : {len(hard_idx)}")

print(
    f'Save the index of the uncertain samples for [{args.dir}/uncertain_idx.pk]')
with open(f'{args.dir}/uncertain_idx.pk', 'wb') as f:
    pickle.dump(SCF_idx, f, pickle.HIGHEST_PROTOCOL)

print(f'Save the index of the hard samples for [{args.dir}/hard_idx.pk]')
with open(f'{args.dir}/hard_idx.pk', 'wb') as f:
    pickle.dump(hard_idx, f, pickle.HIGHEST_PROTOCOL)
