from __future__ import print_function
import pickle
import pdb
import random
from torchvision.datasets import Food101
from torchvision.models import efficientnet_b4
import numpy as np
from utils.dta_model import DeepDTA
from utils.dta_data import DataSet, DTIDataset, DataLoader, collate_dataset, get_spurious_prots, get_spurious_chems
from utils.utils import init_grad
from utils.dataloader import *
import heapq
import argparse
from torchvision.models import resnet50, densenet121
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import torch
import sys
sys.path.append('..')

# Command-line interface of the training script.
parser = argparse.ArgumentParser()
parser.add_argument('--dir', type=str, default=None, required=True,
                    help='path to save checkpoints (default: None)')
parser.add_argument('--epochs', type=int, default=100,
                    help='number of epochs to train')
# NOTE(review): the data loaders further down are built with a fixed batch
# size of 256, so this option currently only feeds the "Num batch" print-out.
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--device_id', type=int, help='device id to use')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed')
parser.add_argument('--l2', type=float, default=1e-4,
                    help='weight decay')
parser.add_argument('--lr', type=float, default=1e-3,
                    help='initial learning rate')
parser.add_argument('--p_h', type=int, default=5,
                    help='Coefficient to train the hard sample set.')
parser.add_argument('--ckpt',
                    help='Previous checkpoint.')
parser.add_argument('--curricular', action='store_true')
parser.add_argument('--idx', type=str,
                    help='The pickle file including indices of hard samples.')
# The fold number is converted with int() and used to index the five training
# folds, so it is mandatory and restricted to 0..4. Rejecting bad values here
# gives a clear usage error instead of a late TypeError/IndexError, and stops
# negative values from silently selecting a fold from the end of the list
# while every fold is also used for training.
parser.add_argument("-f", "--fold_num", type=int, required=True,
                    choices=range(5),
                    help="Fold number. It must be one of the {0,1,2,3,4}.")
parser.add_argument("--type", default='None',
                    help="Davis or Kiba; dataset select.")

args = parser.parse_args()

# Device selection: CUDA is used whenever it is available; ``device_id`` may
# be None when --device_id is not given.
use_cuda = torch.cuda.is_available()
device_id = args.device_id

# Seed the Python, NumPy and PyTorch random number generators with one value.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Echo the parsed arguments, one ``key=value`` pair per line.
print("Arguments: ########################")
print(*(f'{k}={v}' for k, v in vars(args).items()), sep='\n')
print("###################################")

# Index of the fold held out for validation: {0,1,2,3,4}.
FOLD_NUM = int(args.fold_num)


class DataSetting:
    """Settings object handed to ``DataSet.parse_data`` / ``DataSet.read_sets``.

    All values are derived from the module-level ``args``.
    """

    def __init__(self):
        # Dataset directory is selected by the --type argument.
        self.dataset_path = f'data/{args.type}/'
        self.problem_type = '1'
        # Every dataset type except 'kiba' has this flag enabled.
        self.is_log = args.type != 'kiba'


data_setting = DataSetting()

# Per-benchmark constructor sizes: KIBA (1000, 100), DAVIS (1200, 85).
# NOTE(review): presumably maximum sequence lengths -- confirm against DataSet.
dataset = DataSet(data_setting.dataset_path,
                  *((1000, 100) if args.type == 'kiba' else (1200, 85)))
smiles, proteins, Y = dataset.parse_data(data_setting)
test_fold, train_folds = dataset.read_sets(data_setting)

# (row, column) coordinates of every entry of Y that is not NaN.
label_row_inds, label_col_inds = np.where(~np.isnan(Y))

# Merge every training fold except the one held out for validation.
train_fold_sum = [sample
                  for i in range(5) if i != FOLD_NUM
                  for sample in train_folds[i]]

# Fold entries index into the coordinate arrays above.
train_drug_indices = label_row_inds[train_fold_sum]
train_protein_indices = label_col_inds[train_fold_sum]

valid_drug_indices = label_row_inds[train_folds[FOLD_NUM]]
valid_protein_indices = label_col_inds[train_folds[FOLD_NUM]]

test_drug_indices = label_row_inds[test_fold]
test_protein_indices = label_col_inds[test_fold]

# One dataset object per split; only the training split passes ``spurious``.
dti_dataset = DTIDataset(smiles, proteins, Y,
                         train_drug_indices, train_protein_indices,
                         spurious=False)
valid_dti_dataset = DTIDataset(smiles, proteins, Y,
                               valid_drug_indices, valid_protein_indices)
test_dti_dataset = DTIDataset(smiles, proteins, Y,
                              test_drug_indices, test_protein_indices)

print(f"Dataset size before upsampling: [{len(dti_dataset)}]")

# Optionally append ``p_h`` extra copies of every hard sample to the
# training data.
if args.idx is not None:
    # SECURITY: unpickling runs arbitrary code embedded in the file, so
    # --idx must only ever point at a trusted file.
    with open(args.idx, 'rb') as f:
        hard_idx = pickle.load(f)
    print(f"The number of the SCF samples: [{len(hard_idx)}]")
    # len() rather than truthiness: the pickled object may be an array.
    if len(hard_idx) > 0:
        upsampled_data = np.array([dti_dataset[idx] for idx in hard_idx])
        upsampled_data = np.repeat(upsampled_data, args.p_h, axis=0)
        dti_dataset.X = np.concatenate((dti_dataset.X, upsampled_data), axis=0)
    else:
        print("[Warning] ! ! ! The length of errorset is zero. ! ! !")

print(f"Dataset size after upsampling: [{len(dti_dataset)}]")

# Options shared by all three loaders.
# NOTE(review): the batch size is fixed at 256 here; the --batch-size
# argument is not consulted -- confirm this is intended.
_LOADER_OPTS = dict(batch_size=256, collate_fn=collate_dataset,
                    pin_memory=True, num_workers=8)

# Only the training loader shuffles; evaluation loaders keep dataset order.
trainloader = DataLoader(dti_dataset, shuffle=True, **_LOADER_OPTS)
validloader = DataLoader(valid_dti_dataset, shuffle=False, **_LOADER_OPTS)
testloader = DataLoader(test_dti_dataset, shuffle=False, **_LOADER_OPTS)

#######* Build model #######################################
print('==> Building model..')
net = DeepDTA()

if use_cuda:
    # Move the model to the requested GPU (the current device when
    # ``device_id`` is None).
    net.cuda(device_id)
    cudnn.benchmark = True
    cudnn.deterministic = True
    # Anomaly detection is switched off explicitly. The autograd profiler and
    # NVTX emitter are context managers that do nothing unless entered with
    # ``with``, so there is nothing to disable for them here.
    torch.autograd.set_detect_anomaly(False)

#######* Train models #######################################


def train(epoch):
    """Run one optimisation pass over ``trainloader``.

    Prints an epoch banner, then the running mean of the batch losses every
    50 batches. Uses the module-level ``net``, ``criterion`` and
    ``optimizer``.

    Args:
        epoch: Epoch index, used only for the banner.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0

    for batch_idx, (d, p, y) in enumerate(trainloader):
        if use_cuda:
            d, p, y = d.cuda(device_id), p.cuda(device_id), y.cuda(device_id)
        # Give the target a trailing dimension on every device, not only on
        # CUDA: otherwise a (B,) target is broadcast against the model output
        # inside the loss on CPU runs.
        y = y.unsqueeze(1)

        # Project helper; presumably resets the accumulated gradients.
        init_grad(net)
        outputs = net(d, p)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        if batch_idx % 50 == 0:
            print('Loss: %.3f' % (train_loss/(batch_idx+1)))


def _mean_loss_by_group(group_ids, losses):
    """Average ``losses`` within each distinct value of ``group_ids``.

    Returns one mean per distinct id, in ascending id order (the order
    produced by ``np.unique``). ``group_ids`` and ``losses`` must have the
    same length.
    """
    _, inverse, counts = np.unique(
        np.asarray(group_ids), return_inverse=True, return_counts=True)
    sums = np.bincount(inverse.reshape(-1), weights=losses,
                       minlength=len(counts))
    return sums / counts


def test(epoch, loader='test'):
    """Evaluate ``net`` on the test or validation split.

    Args:
        epoch: Unused; kept so existing call sites keep working.
        loader: ``'test'`` selects the test split; any other value selects
            the validation split.

    Returns:
        Tuple ``(mean_batch_loss, worst_mse)``. ``mean_batch_loss`` is the
        mean of the per-batch criterion values over the evaluated loader.
        ``worst_mse`` is the largest per-protein MSE when ``args.type`` is
        ``'kiba'`` and the largest per-ligand MSE otherwise.
    """
    net.eval()
    if loader == 'test':
        eval_set, eval_loader = test_dti_dataset, testloader
    else:
        eval_set, eval_loader = valid_dti_dataset, validloader

    test_loss = 0
    loss_per_samples = []

    with torch.no_grad():
        for d, p, y in eval_loader:
            if use_cuda:
                d, p, y = d.cuda(device_id), p.cuda(device_id), y.cuda(device_id)
            # Same target shape on CPU and CUDA, so the loss never broadcasts.
            y = y.unsqueeze(1)

            outputs = net(d, p)
            test_loss += criterion(outputs, y).item()
            # reshape(-1) keeps a 1-D array even for a final batch of size
            # one, which np.concatenate requires.
            loss_per_samples.append(
                ((outputs - y) ** 2).reshape(-1).cpu().numpy())

    loss_per_samples = np.concatenate(loss_per_samples)

    # Sample i of the loader is matched with protein_idxes[i] / drug_idxes[i];
    # this assumes the evaluation loader does not shuffle.
    mse_by_prot = _mean_loss_by_group(eval_set.protein_idxes, loss_per_samples)
    mse_by_lig = _mean_loss_by_group(eval_set.drug_idxes, loss_per_samples)

    worst_prot_mse = mse_by_prot.max()
    worst_lig_mse = mse_by_lig.max()

    # Normalise by the loader that was actually evaluated.
    mean_loss = test_loss / len(eval_loader)

    print('\nTest set: Average loss: {:.4f}, worst-protein loss: {:.4f}, worst-ligand loss: {:.4f}\n'.format(
        mean_loss, worst_prot_mse, worst_lig_mse))

    worst_mse = worst_prot_mse if args.type == 'kiba' else worst_lig_mse

    return mean_loss, worst_mse


datasize = len(dti_dataset)
# Take the batch count from the loader so the printed number matches the
# batches that are really iterated, whatever batch size the loader uses.
num_batch = len(trainloader)
print(f"Num batch: [{num_batch}]")

criterion = F.mse_loss
optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.l2)
sp_prots = get_spurious_prots(dti_dataset, test_dti_dataset)
sp_chems = get_spurious_chems(dti_dataset, test_dti_dataset)

if args.ckpt is not None:
    # NOTE(review): the checkpoint path is never loaded into ``net``, so this
    # evaluates the freshly initialised model -- confirm whether a
    # ``load_state_dict`` call is missing.
    test(0)

# Best validation scores seen so far; lower is better.
prev_best_mse, prev_best_wo_mse = float('inf'), float('inf')
for epoch in range(args.epochs):
    train(epoch)
    # Test-set metrics are only printed; model selection uses validation.
    test(epoch, 'test')

    cur_mse, cur_wo_mse = test(epoch, 'val')
    if cur_wo_mse < prev_best_wo_mse:
        print("Save the best wo mse model !")
        torch.save(net.state_dict(), args.dir +
                   f'/{args.type}_model_best_wo_mse.pt')
        prev_best_wo_mse = cur_wo_mse
    if cur_mse < prev_best_mse:
        print("Save the best mse model !")
        torch.save(net.state_dict(), args.dir +
                   f'/{args.type}_model_best_mse.pt')
        prev_best_mse = cur_mse

# Always keep the weights from the last epoch as well.
torch.save(net.state_dict(), args.dir + f'/{args.type}_model_final.pt')
