import os
import torch
import torch.distributed as dist
import numpy as np
import logging
from tqdm import tqdm
import argparse
from torch import nn


def _str_to_bool(value):
    """Convert a command-line token such as "true", "False", "1" or "0" to a bool.

    ``argparse`` with ``type=bool`` applies ``bool()`` to the raw string, so
    every non-empty value -- including "False" -- parses as ``True``. This
    converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, matching what bool("") produced before.
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_arguments_eval():
    """Parse the command-line arguments of the evaluation script.

    Reads ``sys.argv`` through ``argparse``.

    Returns:
        argparse.Namespace: the parsed arguments. ``classification`` is a real
        boolean: ``--classification False`` yields ``False``, and a bare
        ``--classification`` flag yields ``True``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=12576, help="Port")
    parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
    parser.add_argument("--num_workers", type=int, required=True, help="Number of workers for dataloader")
    parser.add_argument("--num_layers", type=int, required=True, help="Number of layers")
    parser.add_argument("--hidden_layer_dim", type=int, required=True, help="Hidden layer dimension")
    parser.add_argument("--model_path", type=str, required=True, help="Path to trained model")
    parser.add_argument("--mol_dict_dir", type=str, default="", help="Path of dictionary with small molecule dictionary")
    parser.add_argument("--log_name", type=str, required=True, help="Logging filename")
    parser.add_argument("--save_dir", type=str, required=True, help="Directory to save predictions")
    parser.add_argument("--input_dir", type=str, required=True, help="Directory with PDB files")
    parser.add_argument("--label_dict_file", type=str, required=True, help="File with label dictionaries")
    parser.add_argument("--output_dim", type=int, required=True, help="Number of output nodes")
    parser.add_argument("--bounding_boxes_dir", type=str, required=True, help="Directory with bounding boxes")
    # type=bool would treat "False" as True; parse the text explicitly instead.
    parser.add_argument("--classification", type=_str_to_bool, nargs="?", const=True, default=False,
                        help="Classification or regression")
    parser.add_argument("--test_names", type=str, required=True, help="Path to test names")
    parser.add_argument("--log_dir", type=str, default="../log_outputs/training/", help="Directory to save the log")
    return parser.parse_args()

def get_arguments_CL():
    """Parse the command-line arguments of the contrastive-learning training script.

    Reads ``sys.argv`` through ``argparse``. The options are declared in a
    table and registered in order, so the generated help text lists them in
    the order given below.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    option_table = (
        ("--port", {"type": int, "default": 12576, "help": "Port"}),
        ("--learning_rate", {"type": float, "required": True, "help": "Learning rate"}),
        ("--weight_decay", {"type": float, "default": 0.01, "help": "Weight decay"}),
        ("--batch_size", {"type": int, "required": True, "help": "Batch size"}),
        ("--num_workers", {"type": int, "required": True, "help": "Number of workers for dataloader"}),
        ("--num_epochs", {"type": int, "required": True, "help": "Number of epochs"}),
        ("--num_layers", {"type": int, "required": True, "help": "Number of layers"}),
        ("--hidden_layer_dim", {"type": int, "required": True, "help": "Hidden layer dimension"}),
        ("--input_dir", {"type": str, "required": True, "help": "Input directory"}),
        ("--save_dir", {"type": str, "required": True, "help": "Save directory"}),
        ("--bounding_boxes_dir", {"type": str, "required": True, "help": "Directory with bounding boxes"}),
        ("--log_name", {"type": str, "required": True, "help": "Logging filename"}),
        ("--log_dir", {"type": str, "default": "../log_outputs/training/", "help": "Directory to save the log"}),
        ("--esm_embeddings_path", {"type": str, "required": True, "help": "Path to ESM-2 embeddings"}),
        ("--train_names", {"type": str, "required": True, "help": "Path to train names"}),
        ("--val_names", {"type": str, "required": True, "help": "Path to validation names"}),
        ("--pretrained_model", {"type": str, "default": "", "help": "Path of pretrained model"}),
        ("--temperature", {"type": float, "default": 0.07, "help": "Temperature for contrastive loss"}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def is_cuda(device):
    """Return True when ``device`` refers to a CUDA device.

    Comparing against ``torch.device('cuda')`` is only true for the un-indexed
    device, so ``cuda:0`` (what per-rank code normally holds) was reported as
    non-CUDA. The device *type* is compared instead.

    Args:
        device: a ``torch.device`` or a device string such as ``"cuda:1"``.
            Any other value (including ``None``) yields ``False``.

    Returns:
        bool: whether the device type is ``"cuda"``.
    """
    if isinstance(device, torch.device):
        return device.type == 'cuda'
    if isinstance(device, str):
        try:
            return torch.device(device).type == 'cuda'
        except RuntimeError:
            # Not a parseable device string.
            return False
    return False


def setup_logging(log_name, learning_rate, batch_size, log_dir):
    """Configure the root logger to also write to a per-run log file.

    The file is ``<log_dir>/ProteinVista_<log_name>_lr_<learning_rate>_bs_<batch_size>.txt``
    and is opened in append mode.

    Args:
        log_name: free-form run name embedded in the file name.
        learning_rate: learning rate embedded in the file name.
        batch_size: batch size embedded in the file name.
        log_dir: directory for the log file; created if missing. An empty
            string means the current working directory.

    Returns:
        tuple: ``(logger, setting)`` -- the root logger and the run identifier
        string used as the file stem.
    """
    setting = f"ProteinVista_{log_name}_lr_{learning_rate}_bs_{batch_size}"
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # FileHandler does not create parent directories; do it here so a fresh
    # checkout (or several ranks racing) does not fail on the first run.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    filename = os.path.join(log_dir, f"{setting}.txt")

    logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)

    # Attach the file handler only once per file, otherwise repeated calls
    # make every record appear multiple times in the log.
    target = os.path.abspath(filename)
    already_attached = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == target
        for handler in logger.handlers
    )
    if not already_attached:
        logger.addHandler(logging.FileHandler(filename=filename, mode='a'))
    return logger, setting

def setup_logging_eval(log_name, batch_size, log_dir):
    """Configure the root logger to also write to a per-evaluation log file.

    The file is ``<log_dir>/ProteinVista_<log_name>_bs_<batch_size>.txt`` and
    is opened in append mode.

    Args:
        log_name: free-form run name embedded in the file name.
        batch_size: batch size embedded in the file name.
        log_dir: directory for the log file; created if missing. An empty
            string means the current working directory.

    Returns:
        tuple: ``(logger, setting)`` -- the root logger and the run identifier
        string used as the file stem.
    """
    setting = f"ProteinVista_{log_name}_bs_{batch_size}"
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # FileHandler does not create parent directories; do it here so a missing
    # log directory does not abort the evaluation.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    filename = os.path.join(log_dir, f"{setting}.txt")

    logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)

    # Attach the file handler only once per file, otherwise repeated calls
    # make every record appear multiple times in the log.
    target = os.path.abspath(filename)
    already_attached = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == target
        for handler in logger.handlers
    )
    if not already_attached:
        logger.addHandler(logging.FileHandler(filename=filename, mode='a'))
    return logger, setting

def _collect_pairs(df, labels_dict):
    """Record every row of ``df`` in ``labels_dict`` and return its pair keys in row order."""
    keys = []
    # Iterate the columns positionally rather than through ``df.loc[ind, col]``:
    # label-based lookup returns a Series (not a scalar) when index labels
    # repeat, e.g. after a concat without ``ignore_index``, and is far slower.
    rows = zip(
        df["Protein-Id"].to_numpy(),
        df["Ligand_SMILES"].to_numpy(),
        df["output"].to_numpy(),
    )
    for protein_id, smiles, label in rows:
        key = protein_id + "_" + smiles
        keys.append(key)
        labels_dict[key] = label
    return keys


def create_labels_dict(train_df, val_df):
    """Build the label lookup and the ordered pair keys for the train/val splits.

    Each row contributes the key ``"<Protein-Id>_<Ligand_SMILES>"`` mapped to
    the row's ``"output"`` value. Validation rows are processed after training
    rows, so a key present in both splits keeps the validation label.

    Args:
        train_df: DataFrame with columns "Protein-Id", "Ligand_SMILES", "output".
        val_df: DataFrame with the same columns.

    Returns:
        tuple: ``(labels_dict, train_data, val_data)`` where ``labels_dict``
        maps pair key to label and the two lists hold the pair keys of each
        split in row order.
    """
    labels_dict = {}
    train_data = _collect_pairs(train_df, labels_dict)
    val_data = _collect_pairs(val_df, labels_dict)
    return labels_dict, train_data, val_data

def combine_bounding_boxes(bounding_boxes_dir):
    """Merge all ``bounding_boxes_<i>.npy`` shard files of a directory into one dict.

    Shards are merged in ascending index order, so when two shards contain the
    same key the higher-numbered shard wins.

    The directory is scanned for the shard files that actually exist. Using
    the number of directory entries as the upper index bound would silently
    drop shards whose index is not below that count (numbering that starts at
    1, has gaps, ...).

    Args:
        bounding_boxes_dir: directory containing the shard files. Each shard
            is loaded with ``np.load(...).item()`` and passed to ``dict.update``.

    Returns:
        dict: the union of all shard dictionaries.
    """
    prefix, suffix = "bounding_boxes_", ".npy"
    shards = []
    for name in os.listdir(bounding_boxes_dir):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        index = name[len(prefix):-len(suffix)]
        if index.isascii() and index.isdigit():
            shards.append((int(index), name))
    shards.sort()

    combined_bounding_boxes = {}
    for _, name in tqdm(shards, desc="Combining Bounding Boxes"):
        # NOTE: allow_pickle=True executes arbitrary pickled code on load;
        # only point this at shard files produced by a trusted pipeline.
        boxes = np.load(os.path.join(bounding_boxes_dir, name), allow_pickle=True).item()
        combined_bounding_boxes.update(boxes)
    return combined_bounding_boxes

def full_bounding_boxes(combined_bounding_boxes, input_data, train_names = None):
    """Select the bounding boxes that belong to the entries of ``input_data``.

    The key used to look up ``combined_bounding_boxes`` depends on
    ``train_names``:

    * ``"CL"``: the input key itself.
    * ``"ROSETTA"``: ``key + ".npy"``.
    * anything else: the part of the key before the first ``"_"``, plus ``".npy"``.

    Entries without a matching bounding box (or whose key cannot be turned
    into a lookup key) are skipped. Only those expected failures are
    tolerated; anything else -- including ``KeyboardInterrupt`` -- propagates.

    Args:
        combined_bounding_boxes: mapping from lookup key to bounding boxes.
        input_data: iterable of input keys.
        train_names: lookup mode, see above.

    Returns:
        dict: input key -> bounding boxes, for every key that could be resolved.
    """
    full_combined_bounding_boxes = {}
    skipped = 0
    for key in input_data:
        try:
            if train_names == "CL":
                lookup = key
            elif train_names != "ROSETTA":
                lookup = key.split("_")[0] + ".npy"
            else:
                lookup = key + ".npy"
            full_combined_bounding_boxes[key] = combined_bounding_boxes[lookup]
        except (LookupError, AttributeError, TypeError):
            # Missing box or non-string key: best-effort skip, counted below.
            skipped += 1
    if skipped:
        # Use the root logger object directly: logging.warning() would call
        # basicConfig() implicitly and pre-empt the logging setup done later.
        logging.getLogger().warning(
            "full_bounding_boxes: skipped %d input entries without a bounding box", skipped
        )
    return full_combined_bounding_boxes


def setup(rank, world_size, port):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = port

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """Destroy the default process group, if one was initialised.

    Safe to call unconditionally (e.g. from a ``finally`` block): when
    ``torch.distributed`` is unavailable or no process group exists, this is a
    no-op instead of raising.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()



def compute_fmax(y_true, y_scores, thresholds=np.linspace(0, 1, 101)):
    """
    Compute Fmax for multi-label GO term prediction.

    For every threshold the scores are binarised (``score >= tau``), precision
    and recall are computed per sample, averaged over all samples, and
    combined into an F1 value. The largest F1 over all thresholds is returned.
    A sample with no predicted labels contributes precision 0, and a sample
    with no true labels contributes recall 0.

    NOTE(review): some Fmax definitions average precision only over samples
    that have at least one prediction at the threshold; this function averages
    over all samples. Confirm which convention reported numbers should follow.

    Parameters:
        y_true (array-like): binary ground truth matrix (num_samples x num_classes)
        y_scores (array-like): predicted scores matrix (num_samples x num_classes)
        thresholds (iterable): list of thresholds to evaluate

    Returns:
        float: maximum F1 score (Fmax); 0.0 when there are no samples
    """
    y_true = np.asarray(y_true)
    y_scores = np.asarray(y_scores)
    num_samples = y_true.shape[0]
    if num_samples == 0:
        # Nothing to average; avoids NaN means and their RuntimeWarnings.
        return 0.0

    # Treat every leading-axis entry as one sample, whatever the trailing shape.
    if y_true.ndim != 2:
        y_true = y_true.reshape(num_samples, -1)
    if y_scores.ndim != 2:
        y_scores = y_scores.reshape(num_samples, -1)

    not_true = 1 - y_true
    fmax = 0.0

    for tau in thresholds:
        y_pred = (y_scores >= tau).astype(int)

        # Per-sample confusion counts, computed for all samples at once.
        tp = np.sum(y_true * y_pred, axis=1)
        fp = np.sum(not_true * y_pred, axis=1)
        fn = np.sum(y_true * (1 - y_pred), axis=1)

        predicted = tp + fp
        actual = tp + fn
        # Samples with a zero denominator keep the pre-filled 0.0.
        precision = np.divide(tp, predicted, out=np.zeros(num_samples, dtype=float), where=predicted > 0)
        recall = np.divide(tp, actual, out=np.zeros(num_samples, dtype=float), where=actual > 0)

        avg_precision = precision.mean()
        avg_recall = recall.mean()

        if avg_precision + avg_recall > 0:
            f1 = 2 * avg_precision * avg_recall / (avg_precision + avg_recall)
            fmax = max(fmax, f1)

    return float(fmax)



class ContrastiveLoss(nn.Module):
    """Symmetric cross-entropy contrastive loss between two batches of projections.

    Both inputs are L2-normalised along dim 1, their pairwise dot products are
    divided by ``temperature``, and the i-th row of one input is treated as the
    positive match of the i-th row of the other. The result is the mean of the
    cross-entropy taken over the rows and over the columns of that matrix.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()
        self.temperature = temperature

    def forward(self, protein_proj, esm_proj):
        """Return the scalar loss for one batch.

        Assumes both inputs are (batch, dim) with the same batch size -- TODO confirm
        against the caller.
        """
        unit_protein = nn.functional.normalize(protein_proj, dim=1)
        unit_esm = nn.functional.normalize(esm_proj, dim=1)

        # Entry (i, j) compares protein i with ESM embedding j.
        logits = torch.matmul(unit_protein, unit_esm.T) / self.temperature

        # Matching pairs sit on the diagonal.
        targets = torch.arange(logits.size(0), device=logits.device)

        protein_to_esm = self.criterion(logits, targets)
        esm_to_protein = self.criterion(logits.t().contiguous(), targets)
        return (protein_to_esm + esm_to_protein) * 0.5