import os
import time
import torch
import pickle
import numpy as np
import torch.nn as nn
from tqdm import tqdm
from utils import create_directory
from models import LinRegModel, EmbeddingNN
from utils import split_indices, IndicesDataset


def _split_sample_indices(data_dict, num_samples, train_types=None, test_types=None):
    """Return ``(train_idx, validation_idx, test_idx)`` lists of sample indices.

    Without ``test_types``, all samples are split with ``split_indices`` at
    ``train_percentage=0.7`` and the remainder is halved into validation and
    test. With ``test_types``, the test set is every sample whose entry in
    ``data_dict['cell_types']`` is in ``test_types``. The train/validation pool
    (split at 0.7) is then either the samples whose cell type is in
    ``train_types`` or, when that is None, every sample not in the test set.
    Every split uses ``seed=0`` for reproducibility.
    """
    sample_idx = list(range(num_samples))
    if test_types is None:
        train_idx, eval_idx = split_indices(sample_idx, train_percentage=0.7, seed=0)
        validation_idx, test_idx = split_indices(eval_idx, train_percentage=0.5, seed=0)
        return train_idx, validation_idx, test_idx

    cell_types = data_dict['cell_types']
    test_idx = [idx for idx, ct in enumerate(cell_types) if ct in test_types]
    if train_types is not None:
        pool = [idx for idx, ct in enumerate(cell_types) if ct in train_types]
    else:
        # Set lookup keeps this linear; list membership made it O(n^2).
        test_lookup = set(test_idx)
        pool = [idx for idx in sample_idx if idx not in test_lookup]
    train_idx, validation_idx = split_indices(pool, train_percentage=0.7, seed=0)
    return train_idx, validation_idx, test_idx


def _evaluate_models(embedding_model, pheno_model, inputs, batchnorm):
    """Return ``(embeddings, predictions)`` for ``inputs`` without tracking gradients.

    Both models are switched to eval mode for the pass and restored to train
    mode afterwards. The same ``batchnorm`` flag used during training is
    forwarded, so evaluation runs the same network configuration.
    """
    embedding_model.eval()
    pheno_model.eval()
    try:
        with torch.no_grad():
            embeddings = embedding_model.forward(inputs, batchnorm=batchnorm)
            predictions = pheno_model.forward(embeddings)
    finally:
        embedding_model.train()
        pheno_model.train()
    return embeddings, predictions


def train_baseline(data_dict, y, lr=0.0001, weight_decay=0.0001, epochs=200, batch_size=512,
                    output_dir='baseline_rnaseq', early_stopping=5, use_cuda=True, batchnorm=False, train_types=None,
                    test_types=None):
    """Train the baseline embedding + linear-regression model and pickle its metrics.

    An ``EmbeddingNN`` maps features to a ``latent_dim``-sized embedding and a
    ``LinRegModel`` predicts ``y`` from it. Both are trained jointly with Adam
    on an MSE loss. Training stops after ``epochs`` epochs, or earlier once the
    validation MSE has not improved for ``early_stopping`` epochs. The test set
    is then evaluated once, using the final model.

    Args:
        data_dict: Must contain ``'X'`` (2-D feature array; rows are samples),
            ``'Z'`` (2-D array whose column count sets the embedding size) and
            ``'path_dict'`` (indexable by sample index; the value is stored as
            the sample's node). ``'cell_types'`` is required when
            ``test_types`` is given.
        y: Regression targets, converted to a double tensor and indexed by
            sample index.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        epochs: Maximum number of training epochs.
        batch_size: Mini-batch size for training.
        output_dir: Sub-directory of ``./experiment_output`` for the results.
        early_stopping: Patience, in epochs without validation improvement.
        use_cuda: Use ``cuda:0`` when it is available.
        batchnorm: Passed through to ``EmbeddingNN.forward`` for training and
            evaluation.
        train_types: Optional cell types restricting the train/validation pool.
        test_types: Optional cell types defining the test set.

    Side effects:
        Writes ``metric_results`` (a pickled one-element list holding the
        metrics dict) to the output directory.
    """
    use_cuda = bool(use_cuda)
    device = torch.device("cuda:0" if torch.cuda.is_available() and use_cuda else "cpu")

    base_output_dir = os.path.join(os.path.abspath('.'), 'experiment_output', output_dir)
    create_directory(base_output_dir, remove_curr=False)
    X = data_dict['X']

    num_samples = X.shape[0]
    num_feats = X.shape[1]
    latent_dim = data_dict['Z'].shape[1]

    # (sample index, path node) pairs; only column 0 is used to index X / y in training.
    path_tuples = [(idx, data_dict['path_dict'][idx]) for idx in range(num_samples)]
    path_array = np.asarray(path_tuples)
    train_idx, validation_idx, test_idx = _split_sample_indices(
        data_dict, num_samples, train_types=train_types, test_types=test_types)

    train_set = IndicesDataset(np.asarray([path_tuples[idx] for idx in train_idx], dtype=np.int64))

    X = torch.tensor(X, device=device, dtype=torch.double)
    y = torch.tensor(y, device=device, dtype=torch.double)

    metrics_dict = {
        'valid_mse': None,
        'valid_loss': None,
        'test_mse': None,
        'valid_generative_latents': data_dict['Z'][validation_idx],
        'valid_embeddings': None,
        'test_generative_latents': data_dict['Z'][test_idx],
        'test_embeddings': None
    }

    train_batch_gen = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True,
                                                  num_workers=0, drop_last=False)

    # Embedding model E in the text
    embedding_model = EmbeddingNN(input_dim=num_feats, output_dim=latent_dim).double()
    embedding_model.to(device)

    # Predictive model P in the text
    pheno_model = LinRegModel(input_dim=latent_dim).double()
    pheno_model.to(device)

    loss_function = nn.MSELoss().to(device)
    optimizer = torch.optim.Adam(list(embedding_model.parameters()) + list(pheno_model.parameters()), lr=lr,
                                 weight_decay=weight_decay)

    # pre-collecting the validation and test tensors
    validation_X = X[validation_idx]
    validation_y = y[validation_idx]
    test_X = X[test_idx]
    test_y = y[test_idx]
    metrics_dict['test_y'] = test_y.detach().cpu().numpy()
    metrics_dict['test_nodes'] = path_array[test_idx][:, 1]

    # (best validation MSE, epoch it was reached) for early stopping
    best_result = (np.inf, 0)

    for epoch in range(epochs):
        start = time.time()
        for batch_tuples in tqdm(train_batch_gen):
            batch_samples = batch_tuples[:, 0]
            batch_X = X[batch_samples]
            batch_y = y[batch_samples]

            optimizer.zero_grad()
            embeddings = embedding_model.forward(batch_X, batchnorm=batchnorm)
            y_hat = pheno_model.forward(embeddings)

            loss = loss_function(y_hat, batch_y)
            loss.backward()
            optimizer.step()
        end = time.time()
        print('Time for train epoch: {} seconds'.format(round(end - start, 2)))

        valid_embeddings, y_hat = _evaluate_models(embedding_model, pheno_model, validation_X, batchnorm)
        valid_mse = loss_function(y_hat, validation_y).item()
        print("Validation MSE: {}".format(valid_mse))
        if valid_mse < best_result[0]:
            best_result = (valid_mse, epoch)
        if epoch == epochs - 1 or epoch - best_result[1] >= early_stopping:
            metrics_dict['valid_mse'] = valid_mse
            metrics_dict['valid_embeddings'] = valid_embeddings.detach().cpu().numpy()
            print('Ending training at epoch', epoch)
            # Test evaluation also runs without autograd so no graph is built over the test set.
            test_embeddings, y_hat = _evaluate_models(embedding_model, pheno_model, test_X, batchnorm)
            metrics_dict['test_mse'] = loss_function(y_hat, test_y).item()
            metrics_dict['test_embeddings'] = test_embeddings.detach().cpu().numpy()
            break
    with open(os.path.join(base_output_dir, "metric_results"), 'wb') as f:
        pickle.dump([metrics_dict], f)
    print('Experiment complete')
