from util_model.util_model import ProteinVista, load_pretrained_model, kernel_sizes, depths
from util_model.util_helper import combine_bounding_boxes, full_bounding_boxes, setup_logging_eval, is_cuda
from util_model.util_data import get_loader_SM_inference, get_loader_inference
import logging
import os
from os.path import join
import argparse
import pandas as pd

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

import numpy as np
import time
from sklearn.metrics import accuracy_score, matthews_corrcoef, r2_score
import torch
import torch.nn as nn
import torch.multiprocessing as mp  
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
import gc
import random


def _str2bool(value):
    """Convert a command-line token such as "True"/"false"/"1"/"no" to a bool.

    argparse's ``type=bool`` simply calls ``bool(text)``, which is True for
    every non-empty string (including "False"), so an explicit parser is needed.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays False, as it was with ``type=bool``.
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError("Expected a boolean value, got %r" % (value,))


def get_arguments():
    """Build the command-line parser for the inference script and parse sys.argv.

    Returns:
        argparse.Namespace with the attributes port, batch_size, num_workers,
        pretrained_model, save_dir, classification, input_file, PLI, log_name
        and num_iterations.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=12576, help="Port")
    parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
    parser.add_argument("--num_workers", type=int, required=True, help="Number of workers for dataloader")
    parser.add_argument("--pretrained_model", type=str, required=True, help="Path of pretrained model")
    parser.add_argument("--save_dir", type=str, required=True, help="Directory to save the model predictions")
    # ``--classification False`` must really mean False; a bare
    # ``--classification`` (no value) is accepted as True.
    parser.add_argument(
        "--classification",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Classification or regression",
    )
    parser.add_argument("--input_file", type=str, default="", help="Path to input file")
    parser.add_argument("--PLI", action="store_true", help="PLI (Protein-Ligand Interaction prediction)")
    parser.add_argument("--log_name", type=str, default="", help="Logfile name")
    parser.add_argument("--num_iterations", type=int, default=1, help="Number of iterations")
    return parser.parse_args()


# Parse the command line once at import time and publish every option as a
# module-level name (batch_size, save_dir, PLI, ...); the rest of this file
# reads those names directly.
args = get_arguments()
args_dict = vars(args)
globals().update(args_dict)


# Turn on cuDNN benchmark mode and record how many CUDA devices are visible.
torch.backends.cudnn.benchmark = True
n_gpus = torch.cuda.device_count()
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Protein-only settings are the default; the PLI flag swaps in the
# small-molecule loader, the SMILES representations stored under save_dir,
# and the small-molecule input size handed to the model.
loader_function = get_loader_inference
Mol_dict = None
small_molecule_dim = 0
if PLI:
    loader_function = get_loader_SM_inference
    Mol_dict = np.load(join(save_dir, "SMILES", "SMILES_repr.npy"), allow_pickle=True).item()
    small_molecule_dim = 768
    print("Protein-Ligand prediction task")

# Logging output goes to <save_dir>/log_outputs.
logger, setting = setup_logging_eval(log_name, batch_size, join(save_dir, "log_outputs"))


def _init_process_group(gpu, n_gpus):
    """Bind this process to GPU ``gpu`` and make sure a default process group exists.

    DistributedDataParallel and ``all_gather_object`` both need an initialised
    process group, and the NCCL backend needs each rank pinned to its own
    device through ``torch.cuda.set_device``.

    Returns:
        True if the group was created here (the caller should destroy it),
        False if a group already existed.
    """
    torch.cuda.set_device(gpu)
    if torch.distributed.is_initialized():
        return False
    backend = "nccl" if torch.distributed.is_nccl_available() else "gloo"
    torch.distributed.init_process_group(
        backend=backend,
        init_method="tcp://127.0.0.1:" + str(port),
        world_size=n_gpus,
        rank=gpu,
    )
    return True


def _predict(model, loader, gpu, device):
    """Run one pass over ``loader`` and return ``{sample name: model output}``."""
    predictions = {}
    with torch.no_grad(), autocast():
        for step, batch in enumerate(loader):
            print(step)
            # NOTE(review): ``batch`` is unpacked into 2 or 4 items below, so
            # this guard only triggers for a degenerate batch object - confirm
            # what the loader yields in that case.
            if len(batch) <= 1:
                continue

            if PLI:
                X, smiles_emb, smiles_attn_mask, names = batch
            else:
                X, names = batch

            if is_cuda(device):
                X = X.to(gpu, non_blocking=True)
                if PLI:
                    smiles_emb = smiles_emb.to(gpu, non_blocking=True)
                    smiles_attn_mask = smiles_attn_mask.to(gpu, non_blocking=True)

            if PLI:
                output = model(X, smiles_emb, smiles_attn_mask)
            else:
                output = model(X)

            if classification:
                output = torch.sigmoid(output)

            output = output.cpu().numpy()
            predictions.update(zip(names, output))
    return predictions


def _write_predictions(input_df, predictions):
    """Save ``predictions`` as .npy and as a "Prediction" column of ``input_df``.

    Rows of ``input_df`` for which no prediction exists (for example inputs
    dropped earlier because they had no bounding box) keep an empty
    "Prediction" cell instead of aborting the whole export.
    """
    np.save(join(save_dir, "predictions", "predictions.npy"), predictions)
    input_df["Prediction"] = ""
    missing = 0
    for ind in input_df.index:
        name = input_df.loc[ind, "Protein-Id"] + "_" + input_df.loc[ind, "Ligand_SMILES"]
        if name in predictions:
            input_df.loc[ind, "Prediction"] = predictions[name]
        else:
            missing += 1
    if missing:
        logging.warning("No prediction available for %d of %d input rows", missing, len(input_df.index))

    input_df.to_csv(join(save_dir, "predictions", "predictions.csv"), index=False)


def inference(gpu, device, input_data, input_df, combined_bounding_boxes, bs_dict, n_gpus):
    """Predict for all inputs on one process/GPU and, on rank 0, save the results.

    The model is built, the checkpoint at ``pretrained_model`` is loaded, and
    ``num_iterations`` passes over the data are made (the loader is seeded with
    the iteration number). With several GPUs the per-rank predictions are
    gathered after every pass. Rank 0 keeps the per-sample mean over all
    passes and writes ``predictions.npy`` / ``predictions.csv`` below
    ``<save_dir>/predictions``.

    Besides its parameters the function reads these module-level settings:
    small_molecule_dim, pretrained_model, num_iterations, loader_function,
    save_dir, Mol_dict, num_workers, PLI, classification and port.

    Args:
        gpu: Index of this process; used as CUDA device id and as rank.
        device: torch device used for the CUDA check, checkpoint loading and
            the data loader.
        input_data: Sample names handed to the loader as ``filenames``.
        input_df: DataFrame with "Protein-Id" and "Ligand_SMILES" columns; it
            receives a "Prediction" column on rank 0.
        combined_bounding_boxes: Handed to the loader as
            ``split_files_with_boxes``.
        bs_dict: Batch-size mapping handed to the loader.
        n_gpus: Number of participating processes.
    """
    use_cuda = is_cuda(device)
    # DDP (below) cannot be constructed without a process group.
    owns_process_group = _init_process_group(gpu, n_gpus) if use_cuda else False

    num_layers = 5
    model = ProteinVista(
        channels=5,
        output_dim=1,
        num_layers=num_layers,
        kernel_sizes=kernel_sizes[:num_layers],
        depths=depths[:num_layers],
        hidden_layer_dim=256,
        small_molecule_dim=small_molecule_dim,
    )

    if use_cuda:
        model = model.to(gpu)
        model = DDP(model, device_ids=[gpu])

    # Checkpoints may store the weights directly or under "model_state_dict".
    checkpoint = torch.load(pretrained_model, map_location=device)
    state_dict = checkpoint["model_state_dict"] if "model_state_dict" in checkpoint else checkpoint
    model = load_pretrained_model(model, state_dict)
    model.eval()

    mean_pred = {}   # running mean of the predictions per sample (rank 0 only)
    times_seen = {}  # number of passes that contributed to each mean

    for iteration in range(num_iterations):
        valloader = loader_function(
            data_path=save_dir,
            filenames=input_data,
            device=device,
            gpu=gpu,
            Mol_dict=Mol_dict,
            split_files_with_boxes=combined_bounding_boxes,
            bs_dict=bs_dict,
            shuffle=True,
            num_workers=num_workers,
            n_gpus=n_gpus,
            rank=gpu,
            seed=iteration,
            augment=False,
            return_names=True,
        )

        print("Number of batches: ", len(valloader))

        pred_dict = _predict(model, valloader, gpu, device)

        if torch.distributed.is_initialized() and n_gpus > 1:
            # Every rank receives the predictions of all ranks.
            gathered = [None] * n_gpus
            torch.distributed.all_gather_object(gathered, pred_dict)
            pred_dict = {}
            for rank_predictions in gathered:
                pred_dict.update(rank_predictions)

        if gpu == 0:
            # Per-sample counts keep the mean correct even if a sample is not
            # predicted in every pass.
            for name, value in pred_dict.items():
                seen = times_seen.get(name, 0)
                if seen == 0:
                    mean_pred[name] = value
                else:
                    mean_pred[name] = (mean_pred[name] * seen + value) / (seen + 1)
                times_seen[name] = seen + 1

    if gpu == 0:
        _write_predictions(input_df, mean_pred)

    if owns_process_group:
        torch.distributed.destroy_process_group()

    
if __name__ == '__main__':
    # Script entry point: prepare the inputs, then run inference on one
    # process or on one spawned process per GPU.

    # Output folder for predictions.npy / predictions.csv.
    predictions_dir = join(save_dir, "predictions")
    if not os.path.exists(predictions_dir):
        os.makedirs(predictions_dir)

    # One sample name "<Protein-Id>_<Ligand_SMILES>" per row of the input table.
    input_df = pd.read_csv(input_file)
    input_data = [
        input_df.loc[row, "Protein-Id"] + "_" + input_df.loc[row, "Ligand_SMILES"]
        for row in input_df.index
    ]

    # The same batch size is used for every listed box-size key.
    bs_dict = dict.fromkeys(
        ('[64 64 64]', '[96 96 96]', '[128 128 128]', '[160 160 160]'),
        batch_size,
    )

    logging.info("Number of input files: " + str(len(input_data)))

    raw_bounding_boxes = combine_bounding_boxes(join(save_dir, "protein_3D_bounding_boxes"))
    combined_bounding_boxes = full_bounding_boxes(raw_bounding_boxes, input_data)
    logging.info("Number of bounding boxes: " + str(len(combined_bounding_boxes)))

    if len(combined_bounding_boxes) == 0:
        # Nothing matched: retry with keys rebuilt from the part after the
        # first '-' plus a ".npy" suffix.
        renamed_bounding_boxes = {}
        for box_key, box_value in raw_bounding_boxes.items():
            renamed_bounding_boxes[box_key.split('-')[1] + ".npy"] = box_value
        combined_bounding_boxes = full_bounding_boxes(renamed_bounding_boxes, input_data)
        logging.info("Number of bounding boxes: " + str(len(combined_bounding_boxes)))

    # Keep only the inputs that have a bounding box.
    box_names = list(combined_bounding_boxes.keys())
    input_data = list(set(input_data).intersection(box_names))
    logging.info("Input data size: " + str(len(input_data)))

    cuda_available = torch.cuda.is_available()
    if cuda_available:
        device = torch.device('cuda')
        gpus = torch.cuda.device_count()
        args.world_size = gpus
    else:
        device = torch.device('cpu')
        args.world_size = -1

    if cuda_available and n_gpus > 1:
        # mp.spawn passes the process index as the first argument of inference.
        spawn_args = (device, input_data, input_df, combined_bounding_boxes, bs_dict, gpus)
        mp.spawn(inference, nprocs=n_gpus, args=spawn_args)
    else:
        inference(0, device, input_data, input_df, combined_bounding_boxes, bs_dict, 1)