import os
import time
import torch
import pickle
import argparse
import numpy as np
import torch.nn as nn
from tqdm import tqdm
from utils import create_directory
from data_generation import DataHandler, generate_binary_tree
from models import LinRegModel, CellEmbeddingLinear, EmbeddingNN
from utils import newick_to_adjacency_matrix, build_parent_path_mat, split_indices, IndicesDataset


def build_arg_parser():
    """Return the command-line parser for the baseline embedding experiment.

    Integer flags documented as "1->True ; 0->False" are interpreted by the
    caller, not by argparse.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--latent-corr', type=float, default=0.0)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--feature-dim', type=int, default=16)
    parser.add_argument('--latent-dim', type=int, default=5)
    parser.add_argument('--embedding-dim', type=int, default=5)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--tree-depth', type=int, default=8)
    parser.add_argument('--pop-size', type=int, default=5000)
    parser.add_argument('--output-dir', type=str, default='baseline_exp')
    parser.add_argument('--early-stopping', type=int, default=5)
    parser.add_argument('--use-cuda', type=int, default=1, help="1->True ; 0->False")
    parser.add_argument('--tanh', type=int, default=0, help="1->True ; 0->False")
    parser.add_argument('--ksparse', type=int, default=1)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--mask-dim', type=int, default=0)
    return parser


def build_path_tuples(node_names, phenotypes_by_cell):
    """Map every sample (in flattened order) to the column of its tree node.

    Args:
        node_names: node names in the same order used for the parent-path matrix.
        phenotypes_by_cell: per-node sample containers; ``len(phenotypes_by_cell[k])``
            is the number of samples belonging to ``node_names[k]``.

    Returns:
        A list of ``(sample_index, path_index)`` tuples. ``sample_index`` runs
        0..total-1 in node order; ``path_index`` is the position of the FIRST
        node carrying that name (``list.index`` semantics, so duplicate names
        collapse onto their first occurrence).
    """
    # Precompute first-occurrence positions once instead of calling
    # node_names.index() per node (which made the loop quadratic).
    first_position = {}
    for position, name in enumerate(node_names):
        first_position.setdefault(name, position)

    path_tuples = []
    for node_idx, name in enumerate(node_names):
        path_idx = first_position[name]
        start = len(path_tuples)
        n_samples = len(phenotypes_by_cell[node_idx])
        path_tuples.extend((start + i, path_idx) for i in range(n_samples))
    return path_tuples


def main():
    """Run one baseline experiment and pickle its metrics.

    Steps: simulate a binary cell tree and expression/phenotype data, jointly
    train an embedding model E and a linear predictor P with MSE loss and
    early stopping on a validation split, then evaluate on a test split and
    write ``metric_results`` under ``experiment_output/<output-dir>``.
    """
    args = build_arg_parser().parse_args()

    # setting the torch device
    use_cuda = torch.cuda.is_available() and args.use_cuda == 1
    device = torch.device("cuda:0" if use_cuda else "cpu")

    base_output_dir = os.path.join(os.path.abspath('.'), 'experiment_output', args.output_dir)
    create_directory(base_output_dir, remove_curr=False)

    print('generating binary tree of depth {}'.format(args.tree_depth))
    cell_tree, _ = generate_binary_tree(depth=args.tree_depth)

    pp_ordered_nodes, parent_child = newick_to_adjacency_matrix(cell_tree, pops_list=None)
    pp_ordered_nodes = [node.name for node in pp_ordered_nodes]
    # Unnamed (None) nodes get unique placeholder names, numbered in list order.
    internal_count = 0
    for i, name in enumerate(pp_ordered_nodes):
        if name is None:
            pp_ordered_nodes[i] = 'internal_{}'.format(internal_count)
            internal_count += 1

    parent_path = build_parent_path_mat(parent_child)

    data_class = DataHandler(feat_dim=args.feature_dim, latent_dim=args.latent_dim, pop_size=args.pop_size,
                             latent_correlation=args.latent_corr,
                             parent_path=parent_path, cell_tree=cell_tree, pp_ordered_names=pp_ordered_nodes,
                             tanh=bool(args.tanh), k_sparse=args.ksparse, seed=args.seed, mask_dim=args.mask_dim)

    # Flatten per-node data into one sample axis. The shape is passed
    # positionally: the `newshape=` keyword is deprecated since NumPy 2.1.
    X = np.reshape(data_class.expression_array, (-1, data_class.expression_array.shape[-1]))
    y = np.asarray(list(data_class.phenotypes_by_cell)).flatten()

    # Embedding model E in the text
    embedding_dim = args.embedding_dim
    if args.tanh:
        embedding_model = EmbeddingNN(input_dim=data_class.feat_dim, output_dim=embedding_dim, p=0).double()
    else:
        embedding_model = CellEmbeddingLinear(input_dim=data_class.feat_dim, output_dim=embedding_dim).double()
    embedding_model.to(device)

    # Predictive model P in the text
    pheno_model = LinRegModel(input_dim=embedding_dim).double()
    pheno_model.to(device)

    # 50% train, 25% validation, 25% test; fixed seed for reproducible splits.
    sample_idx = list(range(y.shape[0]))
    train_idx, eval_idx = split_indices(sample_idx, train_percentage=0.5, seed=0)
    validation_idx, test_idx = split_indices(eval_idx, train_percentage=0.5, seed=0)
    train_set = IndicesDataset(np.asarray(train_idx, dtype=np.int64))

    X = torch.tensor(X, device=device, dtype=torch.double)
    y = torch.tensor(y, device=device, dtype=torch.double)

    # storing the true latents for later analysis
    latents_arr = np.asarray(list(data_class.latents_by_cell))
    latents_arr = np.reshape(latents_arr, (-1, latents_arr.shape[-1]))

    metrics_dict = {
        'valid_mse': 0.0,
        'test_mse': 0.0,
        'valid_generative_latents': latents_arr[validation_idx],
        'valid_embeddings': None,
        'test_generative_latents': latents_arr[test_idx],
        'test_embeddings': None
    }

    # Parameters for shuffle batch
    params = {'batch_size': args.batch_size,
              'shuffle': True,
              'num_workers': 0,
              'drop_last': False}
    train_batch_gen = torch.utils.data.DataLoader(train_set, **params)

    loss_function = nn.MSELoss().to(device)
    optimizer = torch.optim.Adam(list(embedding_model.parameters()) + list(pheno_model.parameters()), lr=args.lr)

    def embed_and_score(inputs, targets):
        """Run E then P on `inputs`; return (embeddings, MSE loss vs `targets`)."""
        embeddings = embedding_model(inputs)
        predictions = pheno_model(embeddings)
        # Targets are 1-D. LinRegModel's output shape is not visible here; if it
        # is (N, 1), MSELoss would silently broadcast to (N, N). reshape_as is a
        # no-op when the shapes already match.
        return embeddings, loss_function(predictions.reshape_as(targets), targets)

    # pre-collecting the validation and test tensors
    validation_X = X[validation_idx]
    validation_y = y[validation_idx]
    test_X = X[test_idx]
    test_y = y[test_idx]

    # Identify the tree node (parent-path column) of each test sample for later analysis.
    path_tuples = build_path_tuples(pp_ordered_nodes, data_class.phenotypes_by_cell)
    metrics_dict['test_y'] = test_y.detach().cpu().numpy()
    metrics_dict['test_nodes'] = np.asarray(path_tuples)[test_idx][:, 1]

    # (best validation MSE, epoch it was reached) for early stopping
    best_result = (np.inf, 0)

    for epoch in range(args.epochs):
        embedding_model.train()
        pheno_model.train()
        start = time.time()
        for batch_idx in tqdm(train_batch_gen):
            optimizer.zero_grad()
            _, loss = embed_and_score(X[batch_idx], y[batch_idx])
            loss.backward()
            optimizer.step()
        end = time.time()
        print('Time for train epoch: {} seconds'.format(round(end - start, 2)))

        # validation set (eval mode, no autograd graph)
        embedding_model.eval()
        pheno_model.eval()
        with torch.no_grad():
            valid_embeddings, valid_loss = embed_and_score(validation_X, validation_y)
        valid_mse = valid_loss.item()
        print("Validation loss: {}".format(valid_mse))
        if valid_mse < best_result[0]:
            best_result = (valid_mse, epoch)

        if epoch == args.epochs - 1 or epoch - best_result[1] >= args.early_stopping:
            metrics_dict['valid_mse'] = valid_mse
            metrics_dict['valid_embeddings'] = valid_embeddings.detach().cpu().numpy()
            print('Ending training at epoch', epoch)
            # NOTE(review): the test set is scored with the current (last-epoch)
            # weights, not the best-validation weights; nothing saves/restores those.
            with torch.no_grad():
                test_embeddings, test_loss = embed_and_score(test_X, test_y)
            metrics_dict['test_mse'] = test_loss.item()
            metrics_dict['test_embeddings'] = test_embeddings.detach().cpu().numpy()
            break

    with open(os.path.join(base_output_dir, "metric_results"), 'wb') as f:
        pickle.dump([metrics_dict], f)
    print('Experiment complete')


if __name__ == '__main__':
    main()


