import gzip
import time
from descriptastorus.descriptors import rdNormalizedDescriptors
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem.Scaffolds import MurckoScaffold
from tqdm import tqdm
import torch
from scipy.stats import spearmanr
from sklearn.metrics import roc_auc_score, average_precision_score
import numpy as np
from rdkit.Chem import DataStructs
from rdkit.Avalon.pyAvalonTools import GetAvalonCountFP
from rdkit.Chem import rdReducedGraphs
from model.utils import pca_twice
from random import shuffle
# Module-level descriptastorus RDKit2DNormalized descriptor generator, created
# once at import time and used by rdkit_2d_normalized_features below.
generator = rdNormalizedDescriptors.RDKit2DNormalized()


def custom_collate_fn(batch):
    """Collate ``(sequence, label)`` pairs into a zero-padded batch.

    Samples are ordered by sequence length (longest first), zero-padded to the
    longest sequence, and paired with a 0/1 attention mask that marks the real
    (non-padding) positions. The input list is not modified.

    Args:
        batch: list of ``(sequence, label)`` pairs. Each ``sequence`` is a 2-D
            array-like whose ``shape[1]`` is read as the embedding size; all
            sequences are expected to share that size.

    Returns:
        dict with
        ``'input'``: float32 tensor ``(batch, max_len, emb_size)``;
        ``'attention_mask'``: float32 tensor ``(batch, max_len)``;
        ``'labels'``: ``torch.tensor`` of the labels, in the sorted order.
    """
    # Sort by the length of the *sequence* (element 0); element 1 is the label.
    ordered = sorted(batch, key=lambda item: len(item[0]), reverse=True)
    sequences, labels = zip(*ordered)
    emb_size = sequences[0].shape[1]

    # After the sort, the first sequence is the longest one.
    max_len = len(sequences[0])

    padded_sequences = torch.zeros((len(sequences), max_len, emb_size), dtype=torch.float32)
    attention_masks = torch.zeros((len(sequences), max_len))

    for i, seq in enumerate(sequences):
        seq_len = len(seq)
        padded_sequences[i, :seq_len] = torch.as_tensor(seq, dtype=torch.float32)
        attention_masks[i, :seq_len] = 1  # 1 marks non-padding positions

    return {
        'input': padded_sequences,
        'attention_mask': attention_masks,
        'labels': torch.tensor(labels)
    }


def evaluation(prediction, labels, metrics):
    """Score a model ``prediction`` tensor against ``labels``.

    Both tensors are detached and moved to the CPU first. ``spearman``,
    ``auroc`` and ``auprc`` are returned negated (presumably so callers can
    minimise every metric uniformly). For ``auroc``/``auprc`` a softmax is taken
    over the last dimension and column 1 is used as the positive-class score.

    Args:
        prediction: model output tensor.
        labels: ground-truth tensor.
        metrics: one of ``'mae'``, ``'spearman'``, ``'auroc'``, ``'auprc'``.

    Returns:
        ``(metric_value, prediction, labels)``. The CPU prediction is
        softmax-normalised for ``auroc``/``auprc``. For an unsupported metric a
        message is printed and ``None`` is returned.
    """
    preds = prediction.detach().cpu()
    target = labels.detach().cpu()

    if metrics == 'mae':
        score = (preds - target).abs().mean().item()
    elif metrics == 'spearman':
        score = -spearmanr(preds, target)[0]
    elif metrics in ('auroc', 'auprc'):
        preds = torch.nn.functional.softmax(preds, dim=-1)
        scorer = roc_auc_score if metrics == 'auroc' else average_precision_score
        score = -scorer(target, preds[:, 1])
    else:
        print("Not supported type of metrics, supported metrics are mae, spearman, auroc or auprc")
        return None
    return score, preds, target


def evaluation_lightgbm(prediction, labels, metrics):
    """Score non-neural (e.g. LightGBM) predictions against ``labels``.

    ``prediction`` is passed to the ranking metrics as-is, with no softmax
    (presumably it is already a positive-class score; confirm with the caller).
    ``spearman``, ``auroc`` and ``auprc`` are returned negated.

    Args:
        prediction: predictions as a NumPy array, array-like or torch tensor.
        labels: ground truth, same kinds accepted as ``prediction``.
        metrics: one of ``'mae'``, ``'spearman'``, ``'auroc'``, ``'auprc'``.

    Returns:
        ``(metric_value, prediction, labels)`` with the inputs returned
        unchanged. For an unsupported metric a message is printed and ``None``
        is returned.
    """
    if metrics == 'mae':
        # torch.abs rejects NumPy arrays, so coerce non-tensor inputs first
        # (previously the mae branch crashed for ndarray predictions).
        pred_t = prediction if torch.is_tensor(prediction) else torch.as_tensor(np.asarray(prediction))
        label_t = labels if torch.is_tensor(labels) else torch.as_tensor(np.asarray(labels))
        test_metrics = torch.mean(torch.abs(pred_t - label_t)).cpu().detach().item()
    elif metrics == 'spearman':
        test_metrics = -spearmanr(prediction, labels)[0]
    elif metrics == 'auroc':
        test_metrics = -roc_auc_score(labels, prediction)
    elif metrics == 'auprc':
        test_metrics = -average_precision_score(labels, prediction)
    else:
        print("Not supported type of metrics, supported metrics are mae, spearman, auroc or auprc")
        return None
    return test_metrics, prediction, labels


def evaluation_belka(prediction, labels):
    """Compute ROC-AUC and average-precision scores for ``prediction`` tensors.

    Returns:
        ``(roc_auc, ap, ap_micro)``. ``ap`` uses scikit-learn's default
        averaging and ``ap_micro`` uses ``average="micro"``.
    """
    # Transfer each tensor to the CPU once and reuse it for all three metrics.
    y_true = labels.cpu()
    y_score = prediction.cpu()
    auroc = roc_auc_score(y_true, y_score)
    ap_default = average_precision_score(y_true, y_score)
    ap_micro = average_precision_score(y_true, y_score, average="micro")
    return auroc, ap_default, ap_micro


def test_single(net, test_data_loader, test_dataset, args, device, pid):
    """Run inference and write one sigmoid score per sample to a gzipped CSV.

    Each batch from ``test_data_loader`` unpacks to ``(input_ids,
    attention_mask, idx)``. The network output (last dim squeezed) is passed
    through a sigmoid and written with columns ``id`` (from ``idx``) and
    ``binds``. ``test_dataset`` is accepted but not used here.

    The output path is ``args.output`` + the last ``/``-separated component of
    ``args.dataset_train`` + ``".<pid>.prediction.gz"``. Elapsed time is printed.
    NOTE(review): ``net.eval()`` is not called here; presumably the caller does.
    """
    t0 = time.time()
    logit_chunks = []
    id_chunks = []
    with torch.no_grad():
        for input_ids, attention_mask, idx in tqdm(test_data_loader):
            out = net(input_ids.to(device), attention_mask.to(device))
            logit_chunks.append(out.squeeze(-1))
            id_chunks.append(idx)
        logits = torch.cat(logit_chunks)
    ids = torch.cat(id_chunks)
    probs = torch.sigmoid(logits)

    frame = pd.DataFrame.from_dict({'id': ids.tolist(), 'binds': probs.tolist()})
    out_path = (args.output + args.dataset_train.split("/")[-1] + "." + str(pid)
                + ".prediction.gz")
    with gzip.open(out_path, "wt") as handle:
        frame.to_csv(handle, index=False)
    print("Test time", time.time() - t0)


def test(net, test_data_loader, test_dataset, args, device, pid):
    """Predict per-protein scores and write them to a gzipped CSV.

    Each batch from ``test_data_loader`` unpacks to ``(input_ids,
    attention_mask, idx)``. Network outputs (last dim squeezed) go through a
    sigmoid. For every ``idx``, ``test_dataset.data[idx][1]`` is iterated:
    for each entry ``v``, ``v[0]`` is looked up in
    ``test_dataset.protein_name_map`` to select the score column, and ``v[1]``
    is written as the ``id``.

    The output path is ``args.output`` + the last ``/``-separated component of
    ``args.dataset_train`` + ``".<pid>.prediction.gz"``. Elapsed time is printed.
    NOTE(review): ``net.eval()`` is not called here; presumably the caller does.
    """
    start = time.time()
    score_chunks = []
    id_chunks = []
    with torch.no_grad():
        for input_ids, attention_mask, idx in tqdm(test_data_loader):
            p = net(input_ids.to(device), attention_mask.to(device))
            score_chunks.append(p.squeeze(-1))
            id_chunks.append(idx)
        all_p_test = torch.cat(score_chunks)
    all_ids = torch.cat(id_chunks)

    # Bring all scores to the host in one transfer as nested Python lists.
    # Indexing a device tensor and calling .item() per value forced one device
    # synchronisation per written row.
    score_rows = torch.sigmoid(all_p_test).cpu().tolist()

    records = test_dataset.data
    protein_name_map = test_dataset.protein_name_map
    final_ids = []
    final_p = []
    for idx, row in zip(all_ids, score_rows):
        for v in records[idx.item()][1]:
            final_ids.append(v[1])
            final_p.append(row[protein_name_map[v[0]]])

    output_test = pd.DataFrame.from_dict({'id': final_ids, 'binds': final_p})
    out_path = (args.output + args.dataset_train.split("/")[-1] + "." + str(pid)
                + ".prediction.gz")
    with gzip.open(out_path, "wt") as f:
        output_test.to_csv(f, index=False)
    print("Test time", time.time() - start)


def test_bdb(net, test_data_loader, test_dataset, args, device, pid):
    """Predict per-protein scores for a four-input model and write a gzipped CSV.

    Each batch from ``test_data_loader`` unpacks to ``(input_ids,
    attention_mask, idx)`` followed by three extra ``(input_ids_bK,
    attention_mask_bK)`` pairs. All eight tensors are moved to ``device`` and
    passed to ``net`` in that order. Outputs (last dim squeezed) go through a
    sigmoid. For every ``idx``, ``test_dataset.data[idx][1]`` is iterated: for
    each entry ``v``, ``v[0]`` is looked up in ``test_dataset.protein_name_map``
    to select the score column, and ``v[1]`` is written as the ``id``.

    The output path is ``args.output`` + the last ``/``-separated component of
    ``args.dataset_train`` + ``".<pid>.prediction.gz"``. Elapsed time is printed.
    NOTE(review): ``net.eval()`` is not called here; presumably the caller does.
    """
    start = time.time()
    score_chunks = []
    id_chunks = []
    with torch.no_grad():
        for input_ids, attention_mask, idx, input_ids_b1, attention_mask_b1, \
                input_ids_b2, attention_mask_b2, input_ids_b3, attention_mask_b3 in tqdm(test_data_loader):
            p = net(input_ids.to(device),
                    attention_mask.to(device),
                    input_ids_b1.to(device),
                    attention_mask_b1.to(device),
                    input_ids_b2.to(device),
                    attention_mask_b2.to(device),
                    input_ids_b3.to(device),
                    attention_mask_b3.to(device))
            score_chunks.append(p.squeeze(-1))
            id_chunks.append(idx)
        all_p_test = torch.cat(score_chunks)
    all_ids = torch.cat(id_chunks)

    # Bring all scores to the host in one transfer as nested Python lists.
    # Indexing a device tensor and calling .item() per value forced one device
    # synchronisation per written row.
    score_rows = torch.sigmoid(all_p_test).cpu().tolist()

    records = test_dataset.data
    protein_name_map = test_dataset.protein_name_map
    final_ids = []
    final_p = []
    for idx, row in zip(all_ids, score_rows):
        for v in records[idx.item()][1]:
            final_ids.append(v[1])
            final_p.append(row[protein_name_map[v[0]]])

    output_test = pd.DataFrame.from_dict({'id': final_ids, 'binds': final_p})
    out_path = (args.output + args.dataset_train.split("/")[-1] + "." + str(pid)
                + ".prediction.gz")
    with gzip.open(out_path, "wt") as f:
        output_test.to_csv(f, index=False)
    print("Test time", time.time() - start)


# example for converting a smiles string into the values
def rdkit_2d_normalized_features(smiles: str):
    """Return the module-level ``generator`` features for ``smiles`` as float32.

    The first item returned by ``generator.process`` is discarded (presumably a
    success flag; it is not checked here). The remaining items become a float32
    tensor with NaN entries replaced by 0.0.
    """
    output = generator.process(smiles)
    _status, raw_values = output[0], output[1:]
    values = torch.tensor(raw_values, dtype=torch.float32)
    return values.masked_fill(torch.isnan(values), 0.0)


def rdkit_descriptors(smiles):
    """Compute every descriptor listed in ``Descriptors._descList`` for a SMILES.

    Args:
        smiles: SMILES string.

    Returns:
        1-D float32 tensor with one value per descriptor, in ``_descList``
        order. NaN values are replaced with 0.0.

    Raises:
        ValueError: if RDKit cannot parse ``smiles``.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        # MolFromSmiles returns None on parse failure; fail here with a clear
        # message rather than handing None to every descriptor function.
        raise ValueError(f"Could not parse SMILES: {smiles!r}")

    descriptor_values = [desc_fn(molecule) for _, desc_fn in Descriptors._descList]
    features = torch.tensor(descriptor_values, dtype=torch.float32)
    features[torch.isnan(features)] = 0.0
    return features


def get_scaffold(smiles):
    """Return the SMILES of the Murcko scaffold of ``smiles``.

    Returns ``None`` when RDKit cannot parse the input.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is not None:
        return Chem.MolToSmiles(MurckoScaffold.GetScaffoldForMol(mol))
    return None


def sort_df_with_smiles_by_scaffold(data):
    """Return ``data`` reordered so records sharing a scaffold are adjacent.

    Args:
        data: indexable sequence of records whose element ``[0]`` is a SMILES.

    Returns:
        New list of the same records. Records with a parseable SMILES come
        first, ordered by scaffold SMILES in descending string order. Records
        whose SMILES cannot be parsed (``get_scaffold`` returns ``None``) come
        last. The sort is stable, so ties keep their original relative order.
    """
    scaffolds = [get_scaffold(record[0]) for record in data]

    # Key on (parsed?, scaffold) so None never has to be compared with a str,
    # which raises TypeError in Python 3. With reverse=True, parsed entries
    # sort before unparseable ones.
    order = sorted(
        range(len(data)),
        key=lambda i: (scaffolds[i] is not None, scaffolds[i] or ""),
        reverse=True,
    )
    return [data[i] for i in order]


def kfold_indices(data, n_splits=5, is_shuffle=True):
    """Split ``range(len(data))`` into ``n_splits`` train/test index folds.

    Every sample appears in exactly one test fold. When ``len(data)`` is not
    divisible by ``n_splits``, the first ``len(data) % n_splits`` folds get one
    extra test sample (the same layout as scikit-learn's ``KFold``).

    Args:
        data: any sized collection; only ``len(data)`` is used.
        n_splits: number of folds (must be >= 1).
        is_shuffle: shuffle the indices with ``random.shuffle`` before
            splitting, so ``random.seed`` controls reproducibility.

    Returns:
        List of ``(train_indices, test_indices)`` NumPy integer arrays.

    Raises:
        ValueError: if ``n_splits`` < 1.
    """
    if n_splits < 1:
        raise ValueError(f"n_splits must be >= 1, got {n_splits}")

    n_samples = len(data)
    indices = np.arange(n_samples)
    if is_shuffle:
        shuffle(indices)

    # Distribute the remainder over the first folds. The old `n // k` fold size
    # never placed the last `n % k` samples in any test fold.
    base_size, remainder = divmod(n_samples, n_splits)
    folds = []
    test_start = 0
    for i in range(n_splits):
        test_end = test_start + base_size + (1 if i < remainder else 0)
        test_indices = indices[test_start:test_end]
        train_indices = np.concatenate((indices[:test_start], indices[test_end:]))
        folds.append((train_indices, test_indices))
        test_start = test_end

    return folds


def get_morgangen_embeddings(morgangen_model, input_ids, attention_mask):
    """Mean-pool the model's last encoder hidden state over unmasked tokens.

    Runs ``morgangen_model`` under ``torch.no_grad()`` with
    ``output_hidden_states=True`` and averages
    ``outputs.encoder_hidden_states[-1]`` over the sequence dimension. Only
    positions where ``attention_mask`` is non-zero contribute.

    Args:
        morgangen_model: callable returning an object with an
            ``encoder_hidden_states`` sequence (presumably a HF encoder-decoder).
        input_ids: token ids, forwarded to the model.
        attention_mask: ``(batch, seq_len)`` mask; ``unsqueeze(-1)`` broadcasts
            it against the ``(batch, seq_len, hidden)`` states.

    Returns:
        ``(batch, hidden)`` tensor of pooled embeddings. A row whose mask is
        all zeros yields zeros instead of NaN.
    """
    with torch.no_grad():
        outputs = morgangen_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        last_hidden_states = outputs.encoder_hidden_states[-1]
        mask = attention_mask.unsqueeze(-1).to(last_hidden_states.dtype)
        summed = (last_hidden_states * mask).sum(1)
        # Clamp the token count so a fully-masked row gives 0 rather than 0/0 = NaN.
        counts = mask.sum(1).clamp(min=1e-9)
        drug = summed / counts
        return drug


def get_morgangen_pca_embeddings(morgangen_model, input_ids, attention_mask):
    """Reduce the first 10 last-layer encoder states per sample with ``pca_twice``.

    Runs ``morgangen_model`` under ``torch.no_grad()`` with
    ``output_hidden_states=True``. For each sample, it takes rows ``[:10]`` of
    ``encoder_hidden_states[-1]`` and applies ``pca_twice(..., 64, 10)``. The
    per-sample results are stacked and flattened to ``(batch, -1)``.

    NOTE(review): the ``[:10]`` slice ignores ``attention_mask``, so padding
    positions may be included for short inputs. Confirm this is intended.
    """
    with torch.no_grad():
        encoder_out = morgangen_model(input_ids=input_ids, attention_mask=attention_mask,
                                      output_hidden_states=True)
        final_layer = encoder_out.encoder_hidden_states[-1]
        reduced = torch.stack([pca_twice(sample[:10], 64, 10) for sample in final_layer])
        return reduced.view(reduced.size(0), -1)


def count_to_array(fingerprint, dtype=np.int32):
    """Convert an RDKit fingerprint vector to a dense 1-D NumPy array.

    Args:
        fingerprint: a vector accepted by ``DataStructs.ConvertToNumpyArray``
            (e.g. the count fingerprint from ``GetAvalonCountFP``).
        dtype: element type of the output. Defaults to ``np.int32`` because an
            ``np.int8`` buffer cannot represent counts above 127 (they would
            wrap or raise, depending on the NumPy version).

    Returns:
        NumPy array of ``dtype`` holding the fingerprint values.
    """
    # Start empty; ConvertToNumpyArray resizes the array to the fingerprint length.
    array = np.zeros((0,), dtype=dtype)

    DataStructs.ConvertToNumpyArray(fingerprint, array)

    return array


def get_avalon_fingerprints(molecules, n_bits=1024):
    """Return the ``n_bits``-long Avalon count fingerprint as a NumPy array.

    Despite the plural name, ``molecules`` is the single RDKit molecule that is
    passed straight to ``GetAvalonCountFP``.
    """
    return count_to_array(GetAvalonCountFP(molecules, nBits=n_bits))


def get_erg_fingerprints(molecules):
    """Return ``rdReducedGraphs.GetErGFingerprint`` for the given RDKit molecule."""
    return rdReducedGraphs.GetErGFingerprint(molecules)


def get_avalon_erg_fingerprints(smiles):
    """Concatenate the Avalon count and ErG fingerprints of ``smiles`` into a tensor.

    The SMILES is parsed once with ``Chem.MolFromSmiles``. The Avalon part comes
    first and the ErG part second, joined along axis 0.
    """
    mol = Chem.MolFromSmiles(smiles)
    parts = (get_avalon_fingerprints(mol), get_erg_fingerprints(mol))
    return torch.tensor(np.concatenate(parts, axis=0))