import numpy as np
import torch
import argparse
import os
import sys
import math
from BayesianDTI.utils import *
from torch.utils.data import DataLoader
from BayesianDTI.datahelper import *
from BayesianDTI.model import DeepDTA
from BayesianDTI.loss import *
from scipy.stats import t
import matplotlib.pyplot as plt
import matplotlib
matplotlib.style.use('ggplot')

# Command-line interface of the ensemble training script.
parser = argparse.ArgumentParser()
# The fold is restricted to the five folds iterated over further down in this
# script. Without the restriction a negative value would exclude no fold from
# the training set while still selecting a validation fold by negative
# indexing (i.e. the validation data would also be trained on), and a too
# large value would only fail after the data has been loaded.
parser.add_argument("-f", "--fold_num", type=int, default=0,
                    choices=[0, 1, 2, 3, 4],
                    help="Fold number. It must be one of the {0,1,2,3,4}.")
parser.add_argument("-e", "--epochs", type=int, default=200,
                    help="Number of epochs.")
parser.add_argument("-o", "--output", default='test',
                    help="The output directory.")
parser.add_argument("--type", default='None',
                    help="Davis or Kiba; dataset select.")
parser.add_argument("--model",
                    help="The trained baseline model. If given, keep train the model.")
parser.add_argument("--l2", type=float, default=0.0001,
                    help="Coefficient of L2 regularization")
parser.add_argument("--fo", type=float, default=0.0,
                    help="Coefficient of activation orthogonalization regularization")
parser.add_argument("--eoe", type=float, default=1e-4,
                    help="OOD loss")
parser.add_argument("--cuda", type=int, default=1, help="cuda device number")

args = parser.parse_args()
torch.cuda.set_device(args.cuda)
# The dataset name is compared and used in paths in lower case from here on.
args.type = args.type.lower()
# NOTE: `dir` shadows the builtin, but the rest of the script reads this name.
dir = args.output
print("Arguments: ########################")
print('\n'.join(f'{k}={v}' for k, v in vars(args).items()))
print("###################################")

try:
    # makedirs (unlike mkdir) also creates missing parent directories, so a
    # nested output path works; an existing directory is still only reported.
    os.makedirs(args.output)
except FileExistsError:
    print("The output directory {} is already exist.".format(args.output))

#######################################################################
### Load data
FOLD_NUM = args.fold_num  # {0,1,2,3,4}; already an int (argparse type=int)

class DataSetting:
    """Settings object passed to the dataset parsing and fold-reading calls.

    All attributes are derived from the module-level ``args`` at the moment
    the instance is created.
    """

    def __init__(self):
        dataset_name = args.type
        # Directory that holds the files of the selected dataset.
        self.dataset_path = f'data/{dataset_name}/'
        self.problem_type = '1'
        # Enabled for every dataset except 'kiba'.
        self.is_log = dataset_name != 'kiba'

data_setting = DataSetting()

# The two numeric arguments depend on the dataset:
# KIBA (1000, 100), DAVIS (1200, 85).
# NOTE(review): presumably maximum drug/protein sequence lengths - confirm
# against the DataSet constructor.
dataset = DataSet(data_setting.dataset_path,
                  1000 if args.type == 'kiba' else 1200,
                  100 if args.type == 'kiba' else 85)
smiles, proteins, Y = dataset.parse_data(data_setting)
test_fold, train_folds = dataset.read_sets(data_setting)

# Row / column positions of every non-NaN entry of Y. The fold lists hold
# positions into these two arrays.
label_row_inds, label_col_inds = np.where(~np.isnan(Y))
test_drug_indices = label_row_inds[test_fold]
test_protein_indices = label_col_inds[test_fold]

# Training set: the concatenation of all folds except FOLD_NUM, which is
# held out as the validation fold below.
train_fold_sum = [
    idx
    for i in range(5)
    if i != FOLD_NUM
    for idx in train_folds[i]
]

train_drug_indices = label_row_inds[train_fold_sum]
train_protein_indices = label_col_inds[train_fold_sum]

valid_drug_indices = label_row_inds[train_folds[FOLD_NUM]]
valid_protein_indices = label_col_inds[train_folds[FOLD_NUM]]

dti_dataset = DTIDataset(smiles, proteins, Y, train_drug_indices, train_protein_indices)
valid_dti_dataset = DTIDataset(smiles, proteins, Y, valid_drug_indices, valid_protein_indices)
test_dti_dataset = DTIDataset(smiles, proteins, Y, test_drug_indices, test_protein_indices)

# Only the training loader is shuffled. The evaluation loaders keep a fixed
# order: the validation loss is averaged per batch, so reshuffling would make
# it fluctuate between epochs for identical weights whenever the last batch
# is smaller than the others.
dataloader = DataLoader(dti_dataset, batch_size=256, shuffle=True, collate_fn=collate_dataset)
valid_dataloader = DataLoader(valid_dti_dataset, batch_size=256, shuffle=False, collate_fn=collate_dataset)
test_dataloader = DataLoader(test_dti_dataset, batch_size=256, shuffle=False, collate_fn=collate_dataset)
##########################################################################
### Define models
device = 'cuda:{}'.format(args.cuda)


# Sequentially train `max_ens_step` DeepDTA models. From the second member
# on, each model is additionally pushed away (on perturbed inputs) from the
# frozen best checkpoint of the previous member, called `prior`.
max_ens_step = 5
prior = None
for ens_step in range(max_ens_step):
    ##########################################################################
    ### Training Ensemble Members
    ##########################################################################
    ckpt_path = dir + f'/dti_model_best_{ens_step + 1}.model'

    dti_model = DeepDTA().to(device)
    objective_mse = torch.nn.MSELoss()
    opt = torch.optim.Adam(dti_model.parameters(), lr=0.001, weight_decay=args.l2)
    # Exponential decay of the learning rate: lr = 0.001 * 0.99 ** epoch.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=opt,
                                                  lr_lambda=lambda epoch: 0.99 ** epoch)
    # inf guarantees that the first finite validation loss is checkpointed.
    best_loss = float('inf')
    # In-memory copy of the weights belonging to `best_loss`.
    best_state = None
    # Per-epoch losses of this member; read by the plotting code after the loop.
    train_loss_history = []
    valid_loss_history = []
    for epoch in range(args.epochs):
        total_loss = 0.
        total_valid_loss = 0.
        act_loss = 0
        total_eoe_loss = 0.0
        dti_model.train()
        for it, (d, p, y) in enumerate(dataloader, start=1):
            d, p, y = d.to(device), p.to(device), y.unsqueeze(1).to(device)
            pred = dti_model(d, p)

            opt.zero_grad()
            loss = objective_mse(pred, y)
            (loss + args.fo * dti_model.activation_ort).backward()
            total_loss += loss.item()
            act_loss += dti_model.activation_ort.item()
            opt.step()

            #* OOD loss
            if ens_step > 0 and args.eoe > 0.0:
                # Drop the gradients of the MSE step that was just applied;
                # backward() accumulates into .grad, so without this the MSE
                # gradient would be applied a second time together with the
                # OOD gradient.
                opt.zero_grad()
                # Random per-sample perturbation magnitude in [100, 5100).
                delta = torch.rand([d.shape[0], 1, 1]).to(device) * 5000 + 100
                pred = dti_model(d, p, delta=delta)
                # NOTE(review): `prior` is never switched to eval mode; if
                # DeepDTA contains dropout/batch-norm its output here is
                # computed in training mode - confirm this is intended.
                pred_p = prior(d, p, delta=delta).detach()
                # Minimising -log(squared difference) pushes the prediction
                # away from the prior's prediction on the perturbed input.
                # NOTE(review): becomes inf if both predictions coincide.
                eoe_loss = (-torch.log((pred - pred_p).square())).mean()
                (args.eoe * eoe_loss).backward()
                total_eoe_loss += eoe_loss.item()
                opt.step()
                opt.zero_grad()

            if it % 50 == 0:
                if ens_step > 0:
                    print(f"Iter {it}: Train MSE [{total_loss/it:.5f}] Act ort loss: [{act_loss/it:.5f}] OOD loss: [{total_eoe_loss/it:.5f}]")
                else:
                    print(f"Iter {it}: Train MSE [{total_loss/it:.5f}] Act ort loss: [{act_loss/it:.5f}]")

        scheduler.step()

        dti_model.eval()
        # No gradients are needed for evaluation; skipping graph construction
        # saves memory and time without changing the computed loss.
        with torch.no_grad():
            for d_v, p_v, y_v in valid_dataloader:
                y_v = y_v.unsqueeze(1).to(device)
                pred_v = dti_model(d_v.to(device), p_v.to(device))
                loss_v = objective_mse(pred_v, y_v)
                total_valid_loss += loss_v.item()

        # Both values are means of the per-batch MSE.
        train_loss = total_loss/len(dataloader)
        valid_loss = total_valid_loss/len(valid_dataloader)

        train_loss_history.append(train_loss)
        valid_loss_history.append(valid_loss)

        if best_loss >= valid_loss:
            # The on-disk checkpoint keeps its original format (whole module).
            torch.save(dti_model, ckpt_path)
            best_state = {name: tensor.detach().clone()
                          for name, tensor in dti_model.state_dict().items()}
            best_loss = valid_loss

        # The training loss is the MSE objective, so it is labelled as such.
        print(f"Epoch {epoch + 1}: Train MSE [{train_loss:.5f}] Val MSE [{valid_loss:.5f}]")

    dti_model.activation_ort = None
    # The best checkpoint of this member becomes the frozen prior of the next.
    prior = DeepDTA().to(device)
    if best_state is None:
        # Nothing was checkpointed in this run (e.g. zero epochs); fall back
        # to whatever checkpoint exists on disk, as the script did before.
        best_state = torch.load(ckpt_path).state_dict()
    prior.load_state_dict(best_state)
    for w in prior.parameters():
        w.requires_grad_(False)

##########################################################################
# Plot the per-epoch loss curves. The history lists are re-created for every
# ensemble member, so this shows the member that was trained last.
fig = plt.figure(figsize=(15,5))
# Both curves are mean-squared errors; the labels distinguish the data split.
plt.plot(valid_loss_history, label="Validation MSE")
plt.plot(train_loss_history, label="Train MSE")
plt.title("Validate loss")
plt.xlabel("Validate steps")
plt.legend(facecolor='white', edgecolor='black')

plt.tight_layout()

plt.savefig(dir + "/MultitaskLoss.png")
# Release the figure once it has been written to disk.
plt.close(fig)
##########################################################################
