from typing import List
import torch
from argparse import ArgumentParser
import pytorch_lightning as pl
import os
import logging
import dgl



# Adapted from DeepInter
# Applies the radial basis function to the distance matrix D
def rbf(D):
    """Expand a distance map into Gaussian radial-basis-function channels.

    Every distance ``d`` is compared against 64 evenly spaced centers between
    2 and 22 and mapped to ``exp(-((d - center) / sigma) ** 2)``, where
    ``sigma`` is the width of that range divided by the number of centers.

    Args:
        D: Distance tensor of shape (1, L, L).

    Returns:
        Tensor of shape (64, L, L), created on the same device and with the
        same dtype as ``D``.
    """
    lower, upper, num_centers = 2., 22., 64
    width = (upper - lower) / num_centers
    # Build the centers directly on D's device/dtype so that no implicit
    # transfer or cast happens during the subtraction below.
    centers = torch.linspace(lower, upper, num_centers, device=D.device, dtype=D.dtype).unsqueeze(0)

    # Move the leading axis to the end so it broadcasts against the centers.
    channels_last = D.permute(1, 2, 0)
    scaled_offset = (channels_last - centers) / width
    activations = torch.exp(-scaled_offset.pow(2))
    # Restore the channels-first layout.
    return activations.permute(2, 0, 1)


# Returns current GPU memory usage
def memory_mb():
    """Report the CUDA memory counters exposed by ``torch.cuda``, scaled to MiB.

    Returns:
        A 4-tuple ``(allocated, max_allocated, reserved, max_reserved)``, each
        value being the corresponding ``torch.cuda`` counter divided by
        1024 twice.
    """
    allocated = torch.cuda.memory_allocated() / 1024 / 1024
    peak_allocated = torch.cuda.max_memory_allocated() / 1024 / 1024
    reserved = torch.cuda.memory_reserved() / 1024 / 1024
    peak_reserved = torch.cuda.max_memory_reserved() / 1024 / 1024
    return allocated, peak_allocated, reserved, peak_reserved

# Prints the time difference in minutes, seconds, milliseconds
def print_time_diff(start_time, end_time, task_name='Task'):
    """Print the elapsed time between two timestamps as minutes, seconds and milliseconds.

    The duration is decomposed with integer arithmetic on the timedelta's
    ``days``/``seconds``/``microseconds`` fields, so the millisecond part is
    not lost to floating-point rounding (a remainder of exactly 1 ms is
    reported as 1 rather than 0).

    Args:
        start_time: Timestamp taken before the task. ``end_time - start_time``
            must yield a ``datetime.timedelta``-like object.
        end_time: Timestamp taken after the task; expected to be no earlier
            than ``start_time``.
        task_name: Label inserted into the printed message.
    """
    time_diff = end_time - start_time
    # Whole milliseconds, computed exactly from the integer timedelta fields.
    total_milliseconds = (time_diff.days * 86400 + time_diff.seconds) * 1000 + time_diff.microseconds // 1000
    minutes, remainder = divmod(total_milliseconds, 60 * 1000)
    seconds, milliseconds = divmod(remainder, 1000)
    print(f'Time taken for {task_name}: {minutes} minutes {seconds} seconds {milliseconds} milliseconds')


# -------------------------------------------------------------------------------------------------------------------------------------
# Following code curated for DeepInteract (https://github.com/BioinfoMachineLearning/DeepInteract):
# -------------------------------------------------------------------------------------------------------------------------------------
def custom_dgl_picp_collate(complex_dicts: List[dict]):
    """Collate per-complex dictionaries into one batch tuple.

    The ``'graph1'`` and ``'graph2'`` entries of all complexes are each merged
    with ``dgl.batch``. Every other field is gathered into a plain Python
    list holding one entry per complex, in input order.

    Args:
        complex_dicts: One dictionary per complex, providing the keys
            ``'graph1'``, ``'graph2'``, ``'seqA'``, ``'seqB'``, ``'distA'``,
            ``'distB'``, ``'examples'``, ``'complex'`` and ``'crop_ind'``.

    Returns:
        A 9-tuple ``(batched_graph1, batched_graph2, seqA_list, seqB_list,
        distA_list, distB_list, examples_list, complex_list, crop_ind_list)``.
    """
    def gather(key):
        # One value per complex, preserving the order of the incoming batch.
        return [entry[key] for entry in complex_dicts]

    graph1_batch = dgl.batch(gather('graph1'))
    graph2_batch = dgl.batch(gather('graph2'))
    return (
        graph1_batch,
        graph2_batch,
        gather('seqA'),
        gather('seqB'),
        gather('distA'),
        gather('distB'),
        gather('examples'),
        gather('complex'),
        gather('crop_ind'),
    )

def custom_dgl_picp_collate_predict(complex_dicts: List[dict]):
    """Collate per-complex dictionaries into one batch tuple for prediction.

    The ``'graph1'`` and ``'graph2'`` entries of all complexes are each merged
    with ``dgl.batch``. The sequence, structure-sequence and distance fields
    are gathered into plain Python lists holding one entry per complex, in
    input order. No label/example fields are read or returned.

    Args:
        complex_dicts: One dictionary per complex, providing the keys
            ``'graph1'``, ``'graph2'``, ``'seqA'``, ``'seqB'``,
            ``'struct_seqA'``, ``'struct_seqB'``, ``'distA'`` and ``'distB'``.

    Returns:
        An 8-tuple ``(batched_graph1, batched_graph2, seqA_list, seqB_list,
        struct_seqA_list, struct_seqB_list, distA_list, distB_list)``.
    """
    def gather(key):
        # One value per complex, preserving the order of the incoming batch.
        return [entry[key] for entry in complex_dicts]

    graph1_batch = dgl.batch(gather('graph1'))
    graph2_batch = dgl.batch(gather('graph2'))
    return (
        graph1_batch,
        graph2_batch,
        gather('seqA'),
        gather('seqB'),
        gather('struct_seqA'),
        gather('struct_seqB'),
        gather('distA'),
        gather('distB'),
    )


def glorot_orthogonal(tensor, scale):
    """Initialize ``tensor`` in place with an orthogonal matrix rescaled Glorot-style.

    The tensor is first filled by ``torch.nn.init.orthogonal_`` and then
    multiplied by ``sqrt(scale / ((fan_a + fan_b) * var))``, where ``fan_a``
    and ``fan_b`` are the sizes of its last two dimensions and ``var`` is the
    variance of the freshly initialized values. After rescaling, the variance
    of the tensor therefore equals ``scale / (fan_a + fan_b)``.

    Args:
        tensor: Tensor (or parameter) with at least two dimensions to
            initialize, or ``None``, in which case nothing happens.
        scale: Target scale, given as a Python number or a tensor. It is never
            modified, even when a tensor is passed.
    """
    if tensor is None:
        return
    torch.nn.init.orthogonal_(tensor.data)
    # Initialization must not be recorded by autograd.
    with torch.no_grad():
        fan_sum = tensor.size(-2) + tensor.size(-1)
        # A new value is computed instead of `scale /= ...`, which would
        # overwrite the caller's tensor when `scale` is passed as a tensor.
        factor = (scale / (fan_sum * tensor.var())).sqrt()
    tensor.data *= factor

def get_geo_feats_from_edges(orig_edge_feats: torch.Tensor, feature_indices: dict):
    """Split an edge feature tensor into its geometric feature groups by column.

    Args:
        orig_edge_feats: Edge feature tensor indexed as ``[row, column]``.
        feature_indices: Mapping that provides the column bounds
            ``'edge_dist_feats_start'``/``'edge_dist_feats_end'``,
            ``'edge_dir_feats_start'``/``'edge_dir_feats_end'``,
            ``'edge_orient_feats_start'``/``'edge_orient_feats_end'`` and the
            column selector ``'edge_amide_angles'``.

    Returns:
        A 4-tuple ``(dist_feats, dir_feats, o_feats, amide_feats)``. The first
        three are column slices of ``orig_edge_feats``; the last is whatever
        indexing the columns with ``feature_indices['edge_amide_angles']``
        yields.
    """
    every_row = slice(None)

    distance_cols = slice(feature_indices['edge_dist_feats_start'], feature_indices['edge_dist_feats_end'])
    distance_part = orig_edge_feats[every_row, distance_cols]

    direction_cols = slice(feature_indices['edge_dir_feats_start'], feature_indices['edge_dir_feats_end'])
    direction_part = orig_edge_feats[every_row, direction_cols]

    orientation_cols = slice(feature_indices['edge_orient_feats_start'], feature_indices['edge_orient_feats_end'])
    orientation_part = orig_edge_feats[every_row, orientation_cols]

    amide_part = orig_edge_feats[every_row, feature_indices['edge_amide_angles']]
    return distance_part, direction_part, orientation_part, amide_part

def calculate_top_k_prec(sorted_pred_indices: torch.Tensor, labels: torch.Tensor, k: int):
    """Calculate the top-k interaction precision.

    Predictions are visited in the order given by ``sorted_pred_indices``.
    Entries whose label is ``-1`` are skipped and do not count towards the
    ``k`` scored predictions. The precision is the sum of the labels of the
    first ``k`` scorable predictions divided by ``k``.

    If fewer than ``k`` scorable predictions exist, all of them are scored and
    the denominator stays ``k`` (the missing ones count as incorrect) instead
    of running past the end of the index tensor.

    Args:
        sorted_pred_indices: 1-D integer tensor of indices into ``labels``,
            ordered from most to least confident prediction.
        labels: 1-D label tensor; ``-1`` marks entries to ignore.
        k: Number of predictions to score; must be positive.

    Returns:
        The precision as a Python float.

    Raises:
        ValueError: If ``k`` is not positive.
    """
    if k <= 0:
        raise ValueError(f'k must be a positive integer, got {k}')
    # Labels re-ordered by prediction rank, then stripped of ignored entries;
    # a single vectorised pass replaces the per-element Python loop.
    ranked_labels = labels[sorted_pred_indices]
    top_labels = ranked_labels[ranked_labels != -1][:k]
    num_correct = torch.sum(top_labels).item()
    return num_correct / k

def extract_object(obj: object):
    """Convert a tensor to a NumPy array; return any other object unchanged.

    Tensor subclasses (e.g. ``torch.nn.Parameter``) are converted as well, and
    the tensor is detached first so that tensors which require gradients can
    be converted without raising.

    Args:
        obj: A ``torch.Tensor`` (on any device) or an arbitrary other object.

    Returns:
        A NumPy array holding the tensor's values when ``obj`` is a tensor,
        otherwise ``obj`` itself.
    """
    if isinstance(obj, torch.Tensor):
        # detach(): numpy() refuses tensors that are part of an autograd graph.
        return obj.detach().cpu().numpy()
    return obj

def collect_args():
    """Build the argument parser used for training/testing.

    The parser is first extended with the PyTorch Lightning ``Trainer`` flags
    via ``pl.Trainer.add_argparse_args`` and then with the project-specific
    model, data, logging, seed, optimisation and miscellaneous options below.

    Returns:
        The configured ``ArgumentParser``. Arguments are not parsed here; the
        caller is expected to invoke ``parse_args`` on the result.
    """
    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)

    # -----------------
    # Model arguments
    # -----------------
    parser.add_argument('--model_name', type=str, default='GINI', help='Default option is GINI')
    parser.add_argument('--num_gnn_layers', type=int, default=2, help='Number of GNN layers')
    parser.add_argument('--num_interact_layers', type=int, default=14, help='Number of layers in interaction module')
    parser.add_argument('--metric_to_track', type=str, default='val_ce', help='Scheduling and early stop')

    # -----------------
    # Data arguments
    # -----------------
    parser.add_argument('--knn', type=int, default=20, help='Number of nearest neighbor edges for each node')
    # --self_loops / --no_self_loops share one destination; with neither flag given it is False.
    parser.add_argument('--self_loops', action='store_true', dest='self_loops', help='Allow node self-loops')
    parser.add_argument('--no_self_loops', action='store_false', dest='self_loops', help='Disable self-loops')
    parser.add_argument('--db5_percent_to_use', type=float, default=1.0, help='Fraction of DB5-Plus dataset to use')
    parser.add_argument('--training_with_db5', action='store_true', dest='training_with_db5', help='Train on DB5-Plus')
    parser.add_argument('--db5_data_dir', type=str, default='datasets/DB5/final/raw', help='Path to DB5-Plus')
    parser.add_argument('--pn_ratio', type=float, default=0.1,
                        help='Positive-negative class ratio to instate during training with DIPS-Plus')
    parser.add_argument('--dips_percent_to_use', type=float, default=1.00,
                        help='Fraction of DIPS-Plus dataset splits to use')
    parser.add_argument('--dips_data_dir', type=str, default='datasets/DIPS/final/raw', help='Path to DIPS')
    parser.add_argument('--casp_capri_data_dir', type=str, default='datasets/CASP_CAPRI/final/raw', help='CAPRI path')
    parser.add_argument('--casp_capri_percent_to_use', type=float, default=1.0, help='Fraction of CASP-CAPRI to use')
    parser.add_argument('--process_complexes', action='store_true', dest='process_complexes',
                        help='Check if all complexes for a dataset are processed and, if not, process those remaining')
    parser.add_argument('--testing_with_casp_capri', action='store_true', dest='testing_with_casp_capri',
                        help='Test on the 13th and 14th CASP-CAPRI\'s dataset of homo and heterodimers')
    parser.add_argument('--input_dataset_dir', type=str, default='datasets/Input',
                        help='Path to directory in which to generate features and outputs for the given inputs')
    parser.add_argument('--psaia_dir', type=str, default='~/Programs/PSAIA_1.0_source/bin/linux/psa',
                        help='Path to locally-compiled copy of PSAIA (i.e., to PSA, one of its CLIs)')
    parser.add_argument('--psaia_config', type=str, default='datasets/builder/psaia_config_file_input.txt',
                        help='Path to input config file for PSAIA')
    parser.add_argument('--hhsuite_db', type=str,
                        default='~/Data/Databases/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt',
                        help='Path to downloaded and extracted HH-suite3-compatible database (e.g., BFD or Uniclust30)')

    # -----------------
    # Logging arguments
    # -----------------
    parser.add_argument('--logger_name', type=str, default='TensorBoard', help='Which logger to use for experiments')
    parser.add_argument('--experiment_name', type=str, default=None, help='Logger experiment name')
    parser.add_argument('--project_name', type=str, default='DeepInteract', help='Logger project name')
    parser.add_argument('--entity', type=str, default='bml-lab', help='Logger entity (i.e. team) name')
    parser.add_argument('--run_id', type=str, default='', help='Logger run ID')
    parser.add_argument('--offline', action='store_true', dest='offline', help='Whether to log locally or remotely')
    parser.add_argument('--online', action='store_false', dest='offline', help='Whether to log locally or remotely')
    parser.add_argument('--tb_log_dir', type=str, default='tb_logs', help='Where to store TensorBoard log files')
    parser.set_defaults(offline=False)  # Default to using online logging mode

    # -----------------
    # Seed arguments
    # -----------------
    parser.add_argument('--seed', type=int, default=None, help='Seed for NumPy and PyTorch')

    # -----------------
    # Meta-arguments
    # -----------------
    parser.add_argument('--batch_size', type=int, default=1, help='Number of samples included in each data batch')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-2, help='Decay rate of optimizer weight')
    parser.add_argument('--num_epochs', type=int, default=50, help='Maximum number of epochs to run for training')
    parser.add_argument('--dropout_rate', type=float, default=0.2, help='Dropout (forget) rate')
    parser.add_argument('--patience', type=int, default=5, help='Number of epochs to wait until early stopping')
    parser.add_argument('--pad', action='store_true', dest='pad', help='Whether to zero pad interaction tensors')

    # -----------------
    # Miscellaneous
    # -----------------
    parser.add_argument('--max_hours', type=int, default=1, help='Maximum number of hours to allot for training')
    parser.add_argument('--max_minutes', type=int, default=55, help='Maximum number of minutes to allot for training')
    parser.add_argument('--multi_gpu_backend', type=str, default='ddp', help='Multi-GPU backend for training')
    parser.add_argument('--num_gpus', type=int, default=1, help='Number of GPUs to use (e.g. -1 = all available GPUs)')
    parser.add_argument('--auto_choose_gpus', action='store_true', dest='auto_choose_gpus', help='Auto-select GPUs')
    parser.add_argument('--num_compute_nodes', type=int, default=1, help='Number of compute nodes to use')
    parser.add_argument('--gpu_precision', type=int, default=32, help='Bit size used during training (e.g. 16-bit)')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of CPU threads for loading data')
    parser.add_argument('--profiler_method', type=str, default=None, help='PL profiler to use (e.g. simple)')
    # The default is resolved against the working directory at the time the parser is built.
    parser.add_argument('--ckpt_dir', type=str, default=os.path.join(os.getcwd(), 'checkpoints'),
                        help='Directory in which to save checkpoints')
    parser.add_argument('--ckpt_name', type=str, default='', help='Filename of best checkpoint')
    parser.add_argument('--min_delta', type=float, default=5e-6, help='Minimum percentage of change required to'
                                                                      ' "metric_to_track" before early stopping'
                                                                      ' after surpassing patience')
    parser.add_argument('--accum_grad_batches', type=int, default=1,
                        help='Number of batches over which to accumulate gradients')
    parser.add_argument('--grad_clip_val', type=float, default=0.5, help='Norm over which to clip gradients')
    parser.add_argument('--grad_clip_algo', type=str, default='norm', help='Algorithm with which to clip gradients')
    parser.add_argument('--stc_weight_avg', action='store_true', dest='stc_weight_avg', help='Smooth loss landscape')
    parser.add_argument('--find_lr', action='store_true', dest='find_lr', help='Find an optimal learning rate a priori')
    parser.add_argument('--input_indep', action='store_true', dest='input_indep', help='Whether to zero input for test')

    return parser


def process_args(args):
    """Finalize parsed arguments and seed the random number generators.

    When no seed was supplied (``args.seed is None``) a fixed default of 42 is
    used. An explicitly supplied seed, including ``0``, is kept as given. The
    seed is then passed to ``pl.seed_everything`` and ``dgl.seed``.

    Args:
        args: Parsed argument namespace exposing a ``seed`` attribute.

    Returns:
        The same ``args`` object, with ``args.seed`` guaranteed to be set.
    """
    # ---------------------------------------
    # Seed fixing for random numbers
    # ---------------------------------------
    # Compare against None rather than truthiness so that `--seed 0` is honoured.
    if args.seed is None:
        args.seed = 42  # np.random.randint(100000)
    logging.info(f'Seeding everything with random seed {args.seed}')
    pl.seed_everything(args.seed)
    dgl.seed(args.seed)

    return args
