import os
import argparse
import pandas as pd
import numpy as np

import torch
import torch_geometric.transforms as T
from torch.optim.lr_scheduler import StepLR

# custom modules
from MOTASG_Foundation.utils import tab_printer
from MOTASG_Foundation.model import MOTASG_Foundation, DegreeDecoder, EdgeDecoder, GNNEncoder
from MOTASG_Foundation.mask import MaskEdge

# custom dataloader
from GeoDataLoader.read_geograph import read_pretrain_batch
from GeoDataLoader.geograph_sampler import GeoGraphLoader

from MOTASG_Foundation.lm_model import TextEncoder

def build_pretrain_model(args, device):
    """Assemble the MOTASG_Foundation pretraining model and move it to ``device``.

    Components, built in this order from ``args``:
      * ``MaskEdge`` with mask ratio ``args.p``;
      * ``TextEncoder`` loaded from ``args.text_lm_model_path``;
      * a graph ``GNNEncoder`` mapping ``num_omic_feature`` ->
        ``encoder_channels`` -> ``hidden_channels``;
      * an internal ``GNNEncoder`` mapping ``num_omic_feature`` ->
        ``input_dim`` -> ``input_dim``;
      * ``EdgeDecoder`` and ``DegreeDecoder`` on ``hidden_channels``.

    Args:
        args: Parsed command-line namespace (see ``arg_parse``).
        device: Torch device the finished model is moved to.

    Returns:
        The ``MOTASG_Foundation`` instance on ``device``.
    """
    edge_mask = MaskEdge(p=args.p)
    lm_encoder = TextEncoder(args.text_lm_model_path, device)

    # Both GNN encoders share dropout, batch-norm, layer type and activation.
    gnn_opts = dict(
        dropout=args.encoder_dropout,
        bn=args.bn,
        layer=args.layer,
        activation=args.encoder_activation,
    )
    cross_encoder = GNNEncoder(
        args.num_omic_feature, args.encoder_channels, args.hidden_channels,
        num_layers=args.encoder_layers, **gnn_opts,
    )
    inner_encoder = GNNEncoder(
        args.num_omic_feature, args.input_dim, args.input_dim,
        num_layers=args.internal_encoder_layers, **gnn_opts,
    )

    # Edge and degree decoders share depth and dropout.
    dec_opts = dict(num_layers=args.decoder_layers, dropout=args.decoder_dropout)
    link_head = EdgeDecoder(args.hidden_channels, args.decoder_channels, **dec_opts)
    degree_head = DegreeDecoder(args.hidden_channels, args.decoder_channels, **dec_opts)

    model = MOTASG_Foundation(
        text_input_dim=args.lm_emb_dim,
        omic_input_dim=args.num_omic_feature,
        input_dim=args.input_dim,
        text_encoder=lm_encoder,
        encoder=cross_encoder,
        internal_encoder=inner_encoder,
        edge_decoder=link_head,
        degree_decoder=degree_head,
        mask=edge_mask,
    )
    return model.to(device)



def _log_path(save_path, suffix):
    """Return ``save_path`` with its final extension replaced by ``suffix``.

    The metric logs are written next to the checkpoint, e.g. ``model.pt`` ->
    ``model_batch_auc.txt``. ``os.path.splitext`` is used instead of
    ``save_path.replace('.pt', suffix)``. The replace form has three problems:
      * it returns the checkpoint path unchanged when it contains no ``.pt``,
        so the logs and the model overwrite each other;
      * it turns ``.pth`` into ``..._x.txth``;
      * it also rewrites a ``.pt`` that appears in a directory name.
    """
    return os.path.splitext(save_path)[0] + suffix


def _dump_list(path, items):
    """Overwrite ``path`` with one ``'%s'``-formatted entry of ``items`` per line."""
    with open(path, 'w') as f:
        f.writelines("%s\n" % item for item in items)


def _load_text_embeddings(args, pretrain_model):
    """Return ``(name_embeddings, desc_embeddings)`` for the graph entities.

    If ``args.train_text`` is truthy, the names (``bmgc_name.csv``, column
    ``Names_and_IDs``) and descriptions (``bmgc_desc.csv``, column
    ``Description``) are embedded with ``pretrain_model.text_encoder``. The
    results are then cached as ``x_name_emb.npy`` / ``x_desc_emb.npy`` under
    ``args.data_path``.

    Otherwise those cached files are loaded, reshaped to
    ``(-1, args.lm_emb_dim)`` and returned as torch tensors.
    """
    if args.train_text:
        name_df = pd.read_csv(os.path.join(args.data_path, 'bmgc_name.csv'))
        desc_df = pd.read_csv(os.path.join(args.data_path, 'bmgc_desc.csv'))
        name_sentences = [str(name) for name in name_df['Names_and_IDs'].tolist()]
        desc_sentences = [str(desc) for desc in desc_df['Description'].tolist()]
        text_encoder = pretrain_model.text_encoder
        text_encoder.load_model()
        name_embeddings = text_encoder.generate_embeddings(
            name_sentences, batch_size=args.pretrain_text_batch_size, text_emb_dim=args.lm_emb_dim)
        print(f'Name Embeddings Shape: {name_embeddings.shape}')
        text_encoder.save_embeddings(name_embeddings, os.path.join(args.data_path, 'x_name_emb.npy'))
        desc_embeddings = text_encoder.generate_embeddings(
            desc_sentences, batch_size=args.pretrain_text_batch_size, text_emb_dim=args.lm_emb_dim)
        print(f'Description Embeddings Shape: {desc_embeddings.shape}')
        text_encoder.save_embeddings(desc_embeddings, os.path.join(args.data_path, 'x_desc_emb.npy'))
    else:
        name_embeddings = np.load(os.path.join(args.data_path, 'x_name_emb.npy')).reshape(-1, args.lm_emb_dim)
        name_embeddings = torch.from_numpy(name_embeddings)
        print(f'Name Embeddings Shape: {name_embeddings.shape}')
        desc_embeddings = np.load(os.path.join(args.data_path, 'x_desc_emb.npy')).reshape(-1, args.lm_emb_dim)
        desc_embeddings = torch.from_numpy(desc_embeddings)
        print(f'Description Embeddings Shape: {desc_embeddings.shape}')
    return name_embeddings, desc_embeddings


def pretrain_foundation(args, device):
    """Pretrain MOTASG_Foundation with a link-prediction pretext task.

    Rows of ``pretrain_plain_feature.npy`` (cells) are shuffled with
    ``args.sf_seed`` and processed in consecutive chunks of
    ``args.pretrain_batch_size``. For every batch the loader yields, a
    ``RandomLinkSplit`` (10% test edges, no validation edges) is drawn. The
    model runs ``train_step`` on the training split and ``test_step`` on the
    test split.

    Side effects (paths derived from ``args.save_path``):
      * the checkpoint directory is created if it is missing;
      * the loss / AUC / AP histories are rewritten after every batch to
        ``<stem>_batch_avg_loss_list.txt``, ``<stem>_all_step_avg_loss_list.txt``,
        ``<stem>_batch_auc.txt`` and ``<stem>_batch_ap.txt``;
      * ``state_dict()`` is saved to ``args.save_path`` whenever a batch gets
        both a lower loss and an AUC at least as high as the best seen so far.

    Args:
        args: Parsed command-line namespace (see ``arg_parse``).
        device: Torch device string, e.g. ``'cpu'`` or ``'cuda:0'``.

    Returns:
        The model after the last batch. This may differ from the best
        checkpoint written to ``args.save_path``.
    """
    save_dir = os.path.dirname(args.save_path)
    if save_dir:  # os.makedirs('') raises for a bare filename like 'model.pt'
        os.makedirs(save_dir, exist_ok=True)

    # Load data
    print('--- LOADING TRAINING FILES ... ---')
    xAll = np.load(os.path.join(args.data_path, 'pretrain_plain_feature.npy'))
    all_edge_index = np.load(os.path.join(args.data_path, 'edge_index.npy'))
    internal_edge_index = np.load(os.path.join(args.data_path, 'internal_edge_index.npy'))
    ppi_edge_index = np.load(os.path.join(args.data_path, 'ppi_edge_index.npy'))

    # Shuffle the pretraining rows with a fixed seed for reproducibility.
    np.random.seed(args.sf_seed)
    shuffle_idx = np.random.permutation(xAll.shape[0])
    xAll = xAll[shuffle_idx]

    num_cell = xAll.shape[0]
    num_entity = xAll.shape[1]
    all_edge_index = torch.from_numpy(all_edge_index).long()
    internal_edge_index = torch.from_numpy(internal_edge_index).long()
    ppi_edge_index = torch.from_numpy(ppi_edge_index).long()

    # Build Pretrain Model
    pretrain_model = build_pretrain_model(args, device)
    num_feature = args.num_omic_feature
    optimizer = torch.optim.Adam(pretrain_model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    # scheduler = StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    pretrain_model.reset_parameters()

    name_embeddings, desc_embeddings = _load_text_embeddings(args, pretrain_model)
    # torch.as_tensor leaves tensors untouched and also accepts NumPy arrays.
    # NOTE(review): the return type of generate_embeddings is not visible here.
    name_embeddings = torch.as_tensor(name_embeddings).float().to(device)
    desc_embeddings = torch.as_tensor(desc_embeddings).float().to(device)

    batch_size = args.pretrain_batch_size
    loss_log = _log_path(args.save_path, '_batch_avg_loss_list.txt')
    step_loss_log = _log_path(args.save_path, '_all_step_avg_loss_list.txt')
    auc_log = _log_path(args.save_path, '_batch_auc.txt')
    ap_log = _log_path(args.save_path, '_batch_ap.txt')

    batch_avg_loss_list = []
    all_step_avg_loss_list = []
    batch_auc_list = []
    batch_ap_list = []
    # inf (not an arbitrary cap like 1000) so the first batch can always win.
    best_loss = float('inf')
    best_auc = 0.0
    for index in range(0, num_cell, batch_size):
        upper_index = min(index + batch_size, num_cell)
        current_cell_num = upper_index - index
        geo_datalist = read_pretrain_batch(index, upper_index, xAll, num_feature, num_entity,
                                           all_edge_index, internal_edge_index, ppi_edge_index)
        # The loader batch size equals the chunk size, so presumably a single
        # batch is yielded per chunk.
        dataset_loader = GeoGraphLoader.load_graph(geo_datalist, batch_size, args.pretrain_num_workers)

        for data in dataset_loader:
            print(f'Starting {index} - {upper_index}')
            print('Start Training (Link Prediction Pretext Training)...')

            # num_val=0.0, so the validation split is unused.
            train_data, _, test_data = T.RandomLinkSplit(num_test=0.1, num_val=0.0,
                                                         is_undirected=False,
                                                         split_labels=True,
                                                         add_negative_train_samples=False)(data)

            train_data = train_data.to(device)
            avg_loss, step_avg_loss_list = pretrain_model.train_step(data=train_data,
                                                                     num_entity=num_entity,
                                                                     name_embeddings=name_embeddings,
                                                                     desc_embeddings=desc_embeddings,
                                                                     optimizer=optimizer,
                                                                     alpha=args.alpha,
                                                                     batch_size=current_cell_num,
                                                                     grad_norm=args.grad_norm)

            # scheduler.step()  # <- decays the learning rate at epoch level
            # print(f"LR: {scheduler.get_last_lr()[0]}")

            batch_avg_loss_list.append(avg_loss)
            all_step_avg_loss_list.extend(step_avg_loss_list)
            _dump_list(loss_log, batch_avg_loss_list)
            _dump_list(step_loss_log, all_step_avg_loss_list)

            test_data = test_data.to(device)
            test_auc, test_ap = pretrain_model.test_step(test_data,
                                                         test_data.pos_edge_label_index,
                                                         test_data.neg_edge_label_index)
            batch_auc_list.append(test_auc)
            batch_ap_list.append(test_ap)
            _dump_list(auc_log, batch_auc_list)
            _dump_list(ap_log, batch_ap_list)
            print(f'Link Prediction Pretraining Results:\n'
                  f'AUC: {test_auc:.2%}',
                  f'AP: {test_ap:.2%}')
            print(f'Pretraining {upper_index} done!')

            # Checkpoint only when the loss improves AND the AUC does not drop.
            if avg_loss < best_loss and test_auc >= best_auc:
                best_auc = test_auc
                best_loss = avg_loss
                print(f'Best AUC: {best_auc}')
                print(f'Best Loss: {best_loss}')
                torch.save(pretrain_model.state_dict(), args.save_path)
    return pretrain_model


def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', 'yes' or '0'.

    Without a converter, argparse stores the raw string, so
    ``--train_text False`` would yield the truthy string ``'False'``.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def arg_parse():
    """Parse the pretraining command-line arguments from ``sys.argv``.

    Returns:
        argparse.Namespace with model, optimisation and I/O settings.
    """
    parser = argparse.ArgumentParser()

    # pre-training parameters
    parser.add_argument('--pretrain_batch_size', type=int, default=4, help='Batch size for pretraining. (default: 4)')
    parser.add_argument('--pretrain_text_batch_size', type=int, default=64, help='Batch size for pretraining text. (default: 64)')
    parser.add_argument('--sf_seed', type=int, default=0, help='Seed for shuffling the data. (default: 0)')
    parser.add_argument('--text_lm_model_path', nargs='?', default='dmis-lab/biobert-v1.1', help='Path to the pretrained language model. (default: dmis-lab/biobert-v1.1)')
    # Accepts `--train_text`, `--train_text True` and `--train_text False`.
    parser.add_argument('--train_text', type=_str2bool, nargs='?', const=True, default=False,
                        help='Whether to generate name/description embeddings with the language model '
                             '(True) instead of loading the cached .npy files (False). (default: False)')

    parser.add_argument('--layer', nargs='?', default='gat', help='GNN layer, (default: gat)')
    parser.add_argument('--encoder_activation', nargs='?', default='leaky_relu', help='Activation function for GNN encoder, (default: leaky_relu)')

    parser.add_argument('--num_omic_feature', type=int, default=1, help='Omic feature size. (default: 1)')
    parser.add_argument('--lm_emb_dim', type=int, default=1, help='Text embedding dimension. (default: 1)')

    parser.add_argument('--input_dim', type=int, default=1, help='Input feature dimension. (default: 1)')
    parser.add_argument('--encoder_channels', type=int, default=8, help='Channels of GNN encoder layers. (default: 8)')
    parser.add_argument('--hidden_channels', type=int, default=8, help='Channels of hidden representation. (default: 8)')
    parser.add_argument('--decoder_channels', type=int, default=4, help='Channels of decoder layers. (default: 4)')

    parser.add_argument('--encoder_layers', type=int, default=2, help='Number of layers for encoder. (default: 2)')
    parser.add_argument('--internal_encoder_layers', type=int, default=4, help='Number of layers for internal encoder. (default: 4)')
    parser.add_argument('--decoder_layers', type=int, default=2, help='Number of layers for decoders. (default: 2)')
    parser.add_argument('--encoder_dropout', type=float, default=0.2, help='Dropout probability of encoder. (default: 0.2)')
    parser.add_argument('--decoder_dropout', type=float, default=0.2, help='Dropout probability of decoder. (default: 0.2)')
    parser.add_argument('--alpha', type=float, default=0., help='loss weight for degree prediction. (default: 0.)')

    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate for pre-training. (default: 0.001)')
    parser.add_argument('--weight_decay', type=float, default=5e-5, help='weight_decay for link prediction training. (default: 5e-5)')
    # parser.add_argument('--lr_step_size', type=int, default=2, help='Step size for learning rate decay. (default: 5)')
    # parser.add_argument('--lr_gamma', type=float, default=0.3, help='Gamma for learning rate decay. (default: 0.5)')
    parser.add_argument('--grad_norm', type=float, default=1.0, help='grad_norm for training. (default: 1.0)')
    parser.add_argument('--pretrain_num_workers', dest='pretrain_num_workers', type=int, default=0, help='Number of workers to load data.')

    parser.add_argument('--start', nargs='?', default='node', help='Which Type to sample starting nodes for random walks, (default: node)')
    parser.add_argument('--p', type=float, default=0.00005, help='Mask ratio for MaskEdge')

    parser.add_argument('--bn', action='store_true', help='Whether to use batch normalization for GNN encoder. (default: False)')
    parser.add_argument('--data_path', nargs='?', default='./data/pretrain_plain_data', help='Path to the pretrain data. (default: ./data/pretrain_plain_data)')
    parser.add_argument('--save_path', nargs='?', default='./checkpoints/pretrained_models_gat/pretrained_plain_foundation.pt', help='save path for model. (default: pretrained_plain_foundation.pt)')
    parser.add_argument('--device', type=int, default=0, help='CUDA device index; a negative value forces CPU. (default: 0)')

    return parser.parse_args()


if __name__ == "__main__":
    # Parse the CLI and echo the resolved settings as a table.
    args = arg_parse()
    print(tab_printer(args))

    # A negative --device forces CPU. Otherwise use the requested GPU,
    # falling back to CPU when CUDA is unavailable.
    use_cuda = args.device >= 0 and torch.cuda.is_available()
    device = f'cuda:{args.device}' if use_cuda else 'cpu'

    # Run pretraining; checkpoints and metric logs are written to disk.
    pretrain_model = pretrain_foundation(args, device)
