from util_model.util_model import ProteinVista, load_pretrained_model, kernel_sizes, depths
from util_model.util_helper import create_labels_dict, combine_bounding_boxes, full_bounding_boxes, setup_logging, is_cuda
from util_model.util_data import get_loader_SM, get_loader
import logging
import os
from os.path import join
import argparse
import pandas as pd

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

import numpy as np
import time
from sklearn.metrics import accuracy_score, matthews_corrcoef, r2_score
import torch
import torch.nn as nn
import torch.multiprocessing as mp  
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
import gc
import random


def get_arguments():
    """Parse the command-line options of the training script.

    ``--learning_rate``, ``--batch_size``, ``--num_workers`` and
    ``--num_epochs`` are mandatory; every other option has a default and the
    boolean switches default to ``False``.

    Returns:
        argparse.Namespace: The parsed options, read from ``sys.argv``.
    """
    # One (flag, add_argument keyword arguments) entry per option. The order is
    # the order in which the options are registered (and shown by ``--help``).
    option_table = (
        ("--port", dict(type=int, default=12576, help="Port")),
        ("--learning_rate", dict(type=float, required=True, help="Learning rate")),
        ("--weight_decay", dict(type=float, default=0.01, help="Weight decay")),
        ("--batch_size", dict(type=int, required=True, help="Batch size")),
        ("--num_workers", dict(type=int, required=True, help="Number of workers for dataloader")),
        ("--num_epochs", dict(type=int, required=True, help="Number of epochs")),
        ("--pretrained_model", dict(type=str, default="", help="Path of pretrained model")),
        ("--save_dir", dict(type=str, default="", help="Directory to save the model")),
        ("--classification", dict(action="store_true", help="Classification or regression")),
        ("--balance_classes", dict(action="store_true", help="Balance classes")),
        ("--train_file", dict(type=str, default="", help="Path to train file")),
        ("--val_file", dict(type=str, default="", help="Path to validation file")),
        ("--pos_class_weight", dict(type=float, default=1.0, help="Positive class weight")),
        ("--fix_Molformer", dict(action="store_true", help="Fix Molformer")),
        ("--PLI", dict(action="store_true", help="PLI (Protein-Ligand Interaction prediction)")),
        ("--log_name", dict(type=str, default="", help="Logfile name")),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


# Parse the command line once and publish every option as a module-level
# global (learning_rate, batch_size, save_dir, PLI, ...). The rest of the file
# reads these names directly, so the injection below must stay.
args = get_arguments()
args_dict = vars(args)
globals().update(args_dict)


# Let cuDNN benchmark and pick convolution algorithms.
torch.backends.cudnn.benchmark = True
n_gpus = torch.cuda.device_count()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if not PLI:
    # Protein-only task: plain loader, no small-molecule input.
    loader_function, Mol_dict, small_molecule_dim = get_loader, None, 0
else:
    # Protein-ligand task: the loader additionally receives the stored SMILES
    # representations.
    loader_function = get_loader_SM
    # NOTE(review): allow_pickle=True unpickles the file's contents; only load
    # representation files from a trusted source.
    Mol_dict = np.load(
        join(save_dir, "SMILES", "SMILES_repr.npy"),
        allow_pickle=True,
    ).item()
    small_molecule_dim = 768
    print("Protein-Ligand prediction task")

# setup_logging returns the logger plus a `setting` string that is used later
# as the prefix of the checkpoint file name.
logger, setting = setup_logging(
    log_name,
    learning_rate,
    batch_size,
    join(save_dir, "log_outputs"),
)


def _setup_process_group(rank, world_size, master_port):
    """Join the single-node NCCL process group that DDP needs.

    This file neither defines nor imports a ``setup`` helper, so the
    initialisation is done here with ``torch.distributed`` directly.
    """
    import torch.distributed as dist

    # All workers are started on this machine by mp.spawn (or there is just
    # one), so they rendezvous on localhost at the port given on the CLI.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ["MASTER_PORT"] = str(master_port)
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)


def _resolve_labels_dict():
    """Return the labels mapping for the loaders.

    ``labels_dict`` is created inside the ``__main__`` block. Workers started
    with ``mp.spawn`` re-import this module without executing that block, so
    the name does not exist there; in that case the mapping is rebuilt from
    the CSV files named on the command line.
    """
    labels = globals().get("labels_dict")
    if labels is None:
        # Presumably create_labels_dict is deterministic for the same inputs,
        # so every worker ends up with the same mapping -- TODO confirm.
        labels, _, _ = create_labels_dict(pd.read_csv(train_file), pd.read_csv(val_file))
    return labels


def _unpack_batch(batch, gpu, on_cuda):
    """Split a loader batch into ``(X, smiles_emb, smiles_attn_mask, y)``.

    For the protein-only task the two SMILES entries are ``None``. When running
    on CUDA all tensors are moved to device ``gpu``.
    """
    if PLI:
        X, smiles_emb, smiles_attn_mask, y = batch
    else:
        X, y = batch
        smiles_emb = smiles_attn_mask = None

    if on_cuda:
        X = X.to(gpu, non_blocking=True)
        y = y.to(gpu, non_blocking=True)
        if PLI:
            smiles_emb = smiles_emb.to(gpu, non_blocking=True)
            smiles_attn_mask = smiles_attn_mask.to(gpu, non_blocking=True)
    return X, smiles_emb, smiles_attn_mask, y


def _classification_metrics(y_true, y_pred, split_label=""):
    """Return ``(MCC, accuracy)`` for probabilities ``y_pred`` rounded to 0/1.

    NaN predictions are dropped (with a warning) before scoring; if nothing is
    left both metrics are reported as 0. ``split_label`` is inserted into the
    warning texts ("" for training, "validation " for validation).
    """
    true = np.array(y_true)
    pred = np.array(y_pred)
    nan_mask = np.isnan(pred)
    if np.any(nan_mask):
        logging.warning(f"Found {np.sum(nan_mask)} NaN values in {split_label}y_pred. Filtering them out.")
        keep = ~nan_mask
        true, pred = true[keep], pred[keep]
        if len(true) == 0:
            logging.warning(f"No valid {split_label}predictions left after filtering NaNs")
            return 0, 0
    rounded = np.round(pred)
    return matthews_corrcoef(true, rounded), accuracy_score(true, rounded)


def _mean_across_ranks(value, device, world_size):
    """Average a Python scalar over all ranks of the default process group."""
    import torch.distributed as dist

    # all_reduce works in place on one tensor per rank, so no per-rank output
    # list has to be allocated; an explicit dtype keeps all ranks consistent.
    total = torch.tensor(float(value), dtype=torch.float64, device=device)
    dist.all_reduce(total, op=dist.ReduceOp.SUM)
    return (total / world_size).item()


def _save_checkpoint(model, optimizer, path):
    """Write model and optimizer state to ``path``, unwrapping DDP if present."""
    # On CPU the model is never wrapped in DDP and has no ``.module``.
    net = model.module if isinstance(model, DDP) else model
    torch.save({'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()},
               path)


def trainer(gpu, device, train_data, val_data, combined_bounding_boxes, bs_dict, n_gpus):
    """Train and validate a ProteinVista model in one process.

    Hyper-parameters and paths (``learning_rate``, ``num_epochs``, ``save_dir``,
    ``classification``, ``PLI``, ...) are read from the module-level globals
    that are populated from the command line.

    Args:
        gpu: Index of the CUDA device used by this process; it is also used as
            the distributed rank and handed to the loaders as ``rank``.
        device: ``torch.device`` chosen by the caller (CUDA or CPU).
        train_data: Training file names handed to the loader factory.
        val_data: Validation file names handed to the loader factory.
        combined_bounding_boxes: Mapping handed to the loaders as
            ``split_files_with_boxes``.
        bs_dict: Batch-size mapping handed to the loaders.
        n_gpus: Number of participating processes (world size).

    Side effects:
        Logs progress, and rank 0 writes ``<setting>_best_model.pth`` into
        ``<save_dir>/models`` whenever the validation score (MCC for
        classification, R2 for regression) improves.
    """
    on_cuda = is_cuda(device)
    labels = _resolve_labels_dict()

    logging.info("GPU: " + str(gpu) + "Port: " + str(port))
    print(len(val_data), len(train_data))
    print(device, gpu, num_workers, n_gpus, gpu)
    valloader = loader_function(
        data_path=save_dir,
        filenames=val_data,
        device=device,
        gpu=gpu,
        labels_dict=labels,
        Mol_dict=Mol_dict,
        split_files_with_boxes=combined_bounding_boxes,
        bs_dict=bs_dict,
        shuffle=False,
        num_workers=num_workers,
        n_gpus=n_gpus,
        rank=gpu,
        seed=1,
        # NOTE(review): augmentation is also enabled for validation; kept as
        # is, but confirm this is intended.
        augment=True,
    )
    logging.info("Validation data loaded")
    print(len(valloader))

    if on_cuda:
        torch.cuda.set_device(gpu)  # Select the device before joining the group
        _setup_process_group(gpu, n_gpus, str(port))
        torch.manual_seed(0)

    num_layers = 5
    model = ProteinVista(
        channels=5,
        output_dim=1,
        num_layers=num_layers,
        kernel_sizes=kernel_sizes[:num_layers],
        depths=depths[:num_layers],
        hidden_layer_dim=256,
        small_molecule_dim=small_molecule_dim,
    )
    logging.info("Model created")

    if gpu == 0:
        logging.info("Number of trainable parameters: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

    checkpoint = None
    if pretrained_model != "":
        logging.info("Loading pretrained model: " + pretrained_model)
        checkpoint = torch.load(pretrained_model, map_location=device)
        # Checkpoints are either a bare state dict or a dict wrapping one.
        if "model_state_dict" in checkpoint:
            model = load_pretrained_model(model, checkpoint["model_state_dict"])
        else:
            model = load_pretrained_model(model, checkpoint)

    if on_cuda:
        model = model.to(gpu)
        model = DDP(model, device_ids=[gpu])

    if classification:
        pos_weight = torch.tensor([pos_class_weight])
        if on_cuda:
            pos_weight = pos_weight.to(gpu)
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    else:
        criterion = nn.MSELoss()

    optimizer = torch.optim.AdamW(params=list(model.parameters()), lr=learning_rate, weight_decay=weight_decay)

    if checkpoint is not None and "optimizer_state_dict" in checkpoint:
        # Best effort: an incompatible optimizer state must not abort the run,
        # but KeyboardInterrupt/SystemExit are no longer swallowed.
        try:
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        except Exception:
            logging.info("Could not load optimizer state.")

    # MCC and R2 can both be negative, so start below any reachable score;
    # otherwise a run that never exceeds 0 would never write a checkpoint.
    best_val_score = float("-inf")
    checkpoint_path = join(save_dir, "models", setting + "_best_model.pth")

    # Mixed precision is only meaningful on CUDA; on CPU both are pass-through.
    scaler = GradScaler(enabled=on_cuda)
    logging.info("Starting training:")
    for epochs in range(num_epochs):

        # The loader is rebuilt every epoch with the epoch number as its seed.
        trainloader = loader_function(
            data_path=save_dir,
            filenames=train_data,
            device=device,
            gpu=gpu,
            labels_dict=labels,
            Mol_dict=Mol_dict,
            split_files_with_boxes=combined_bounding_boxes,
            bs_dict=bs_dict,
            shuffle=True,
            num_workers=num_workers,
            n_gpus=n_gpus,
            rank=gpu,
            seed=epochs,
            augment=True,
        )
        logging.info("Train data loaded")
        logging.info("Length of trainloader: " + str(len(trainloader)))

        epoch_time = time.perf_counter()
        model.train()

        train_loss = 0.0
        train_batches = 0  # batches that actually contributed to train_loss
        y_true, y_pred = [], []

        for step, batch in enumerate(trainloader):
            print("Step: " + str(step))
            X, smiles_emb, smiles_attn_mask, y = _unpack_batch(batch, gpu, on_cuda)
            # Batches with fewer than two samples are skipped during training.
            if len(y) <= 1:
                continue

            optimizer.zero_grad(set_to_none=True)

            # Only the forward pass and the loss run under autocast; the
            # backward pass and optimizer step are performed outside of it.
            with autocast(enabled=on_cuda):
                if PLI:
                    output = model(X, smiles_emb, smiles_attn_mask)
                else:
                    output = model(X)
                y, output = y.float(), output.float()
                loss = criterion(output.view(-1), y.view(-1))

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item()
            train_batches += 1

            if classification:
                output = torch.sigmoid(output)
            # Collect targets/predictions for BOTH task types so the running
            # classification metrics below have data to work on.
            y_true.extend(y.detach().cpu().numpy().reshape(-1))
            y_pred.extend(output.detach().cpu().numpy().reshape(-1))

            if step % 20 == 0:
                # Running metrics are computed over the latest 500 samples.
                y_true, y_pred = y_true[-(500):], y_pred[-(500):]
                if classification:
                    train_mcc, train_acc = _classification_metrics(y_true, y_pred)
                    logging.info("Epoch: " + str(epochs) + " Step: " + str(step) + " Loss: " + str(train_loss/train_batches) + \
                                    " MCC: " + str(train_mcc) + "ACC: " + str(train_acc))
                else:
                    train_mse = np.mean((np.array(y_true) - np.array(y_pred))**2)
                    train_r2 = r2_score(y_true, y_pred)

                    logging.info("Epoch: " + str(epochs) + " Step: " + str(step) + " Loss: " + str(train_loss/train_batches) + \
                                    " MSE: " + str(train_mse) + " R2: " + str(train_r2))

        # Drop the references first so the collector and the CUDA caching
        # allocator can actually release them; rebinding (instead of `del`)
        # is also safe when no batch was processed.
        X = output = loss = None
        gc.collect()
        torch.cuda.empty_cache()

        logging.info("Epoch: " + str(epochs) + " Loss: " + str(train_loss/max(train_batches, 1)))
        if gpu == 0:
            epoch_time = time.perf_counter() - epoch_time
            logging.info("Epoch time: " + str(epoch_time))

        # Validation must run in eval mode (no dropout, frozen normalisation
        # statistics); model.train() is re-applied at the top of each epoch.
        model.eval()
        val_loss = 0.0
        val_batches = 0
        y_true, y_pred = [], []
        with torch.no_grad(), autocast(enabled=on_cuda):
            for batch in valloader:
                X, smiles_emb, smiles_attn_mask, y = _unpack_batch(batch, gpu, on_cuda)
                # Check the number of samples, not the tuple arity; in eval
                # mode only an empty batch has to be skipped.
                if len(y) == 0:
                    continue

                if PLI:
                    output = model(X, smiles_emb, smiles_attn_mask)
                else:
                    output = model(X)
                y, output = y.float(), output.float()

                loss = criterion(output.view(-1), y.view(-1))
                val_loss += loss.item()
                val_batches += 1

                if classification:
                    output = torch.sigmoid(output)

                y_true.extend(y.detach().cpu().numpy().reshape(-1))
                y_pred.extend(output.detach().cpu().numpy().reshape(-1))

        mean_val_loss = val_loss / max(val_batches, 1)
        if classification:
            val_mcc, val_acc = _classification_metrics(y_true, y_pred, "validation ")
            logging.info("Validation Loss: " + str(mean_val_loss) + " Validation MCC: " + str(val_mcc) + " Validation ACC: " + str(val_acc) + " GPU: " + str(gpu))
            val_score, score_name = val_mcc, "MCC"
        else:
            val_mse = np.mean((np.array(y_true) - np.array(y_pred))**2)
            val_r2 = r2_score(y_true, y_pred)
            logging.info("Validation Loss: " + str(mean_val_loss) + " Validation MSE: " + str(val_mse) + " Validation R2: " + str(val_r2) + " GPU: " + str(gpu))
            val_score, score_name = val_r2, "R2"

        X = output = loss = None
        gc.collect()
        torch.cuda.empty_cache()

        # Every rank scores its own share of the validation data; average the
        # scores so that all ranks take the same "new best" decision.
        if n_gpus > 1:
            val_score = _mean_across_ranks(val_score, device, n_gpus)
            logging.info(f"Mean validation {score_name} across all GPUs: {val_score}")

        if val_score > best_val_score:
            logging.info("New best model found")
            best_val_score = val_score
            if gpu == 0:
                _save_checkpoint(model, optimizer, checkpoint_path)

    if on_cuda:
        import torch.distributed as dist

        # Leave the process group cleanly once training has finished.
        if dist.is_initialized():
            dist.destroy_process_group()


    
if __name__ == '__main__':
    # Entry point: prepare the output directory, labels and bounding boxes,
    # then run `trainer` either once per GPU (mp.spawn) or once in-process.

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(join(save_dir, "models"), exist_ok=True)

    train_df = pd.read_csv(train_file)
    val_df = pd.read_csv(val_file)
    labels_dict, train_data, val_data = create_labels_dict(train_df, val_df)

    # The same batch size is used for every box size key.
    bs_dict = {'[64 64 64]' : batch_size,
                '[96 96 96]' : batch_size,
                '[128 128 128]' : batch_size,
                '[160 160 160]' : batch_size,
                }

    input_data = list(labels_dict.keys())
    logging.info("Number of input files: " + str(len(input_data)))

    combined_bounding_boxes_1 = combine_bounding_boxes(join(save_dir, "protein_3D_bounding_boxes"))
    combined_bounding_boxes = full_bounding_boxes(combined_bounding_boxes_1, input_data)
    logging.info("Number of bounding boxes: " + str(len(combined_bounding_boxes)))

    if len(combined_bounding_boxes) == 0:
        # No box matched the input names: retry with keys rebuilt from the
        # second '-'-separated field of each original key plus ".npy".
        combined_bounding_boxes_1 = {key.split('-')[1] + ".npy": value for key, value in combined_bounding_boxes_1.items()}
        combined_bounding_boxes = full_bounding_boxes(combined_bounding_boxes_1, input_data)
        logging.info("Number of bounding boxes: " + str(len(combined_bounding_boxes)))

    # Keep only samples that have a bounding box. The result is sorted because
    # set iteration order is not stable between runs, which would make the
    # seeded shuffling in the loaders irreproducible.
    keys = list(combined_bounding_boxes.keys())
    train_data = sorted(set(train_data).intersection(keys))
    val_data = sorted(set(val_data).intersection(keys))
    logging.info("Train data size: " + str(len(train_data)) + " Validation data size: " + str(len(val_data)))

    if torch.cuda.is_available():
        device = torch.device('cuda')
        device_ids = list(range(torch.cuda.device_count()))
        gpus = len(device_ids)
        args.world_size = gpus

    else:
        device = torch.device('cpu')
        args.world_size = -1

    if torch.cuda.is_available() and n_gpus > 1:
        # One worker per GPU; mp.spawn passes the process index as the first
        # argument of `trainer`, followed by this tuple.
        mp.spawn(trainer, nprocs=n_gpus, args=(device, train_data, val_data, combined_bounding_boxes, bs_dict, gpus))
    else:
        trainer(0, device, train_data, val_data, combined_bounding_boxes, bs_dict, 1)
