import argparse
from dataclasses import asdict, dataclass
import functools
import logging
import os
from pathlib import Path
import pickle
import random
from typing import Callable, Dict, List, Optional, Union

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import yaml

# from ccvae.models.ccvae import ContrastiveCVAE
from ccvae.data.loaders import (
    celligner_labels,
    kang_labels,
    uci_income_labels,
    load_celligner,
    load_kang,
    load_kang_trvae,
    load_kang_trvae_counts,
    load_uci_income,
    make_strata,
    prepare_training_data,
    batch_sampling_modes,
)
from ccvae.metrics import compute_mmd, knn_metric
from ccvae import metric_handlers
from ccvae.nn.utils import Encoder, GaussianDecoder, MultinomialDecoder, calc_input_dims
from ccvae.nn.config import ModelConfig, TrainConfig, TestConfig
from ccvae.pl.vae import VAE
from ccvae.pl.cvae import CVAE
from ccvae.pl.ccvae import ContrastiveCVAE
from ccvae.pl.trvae import TrVAE
from ccvae.pl.trainer import create_trainer

# Root logging layout: timestamp, emitting logger name and source line, level, message.
_LOG_FORMAT = "%(asctime)s - %(name)s:%(lineno)d - %(levelname)s - %(message)s"
_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT)

# Module-level logger; records at INFO and above are emitted.
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.INFO)


def _select_data_pipeline(dataset, batch_sampling, top_var_number):
    """Choose the loader, label function and strata function for a dataset.

    Args:
        dataset: Dataset name; one of "celligner", "kang", "kang-trvae",
            "kang-trvae-counts" or "uci-income".
        batch_sampling: Batch sampling mode. ``None`` is treated the same as
            "uniform".
        top_var_number: Forwarded to ``load_celligner`` as ``top_var_number``;
            not used for the other datasets.

    Returns:
        A tuple ``(load_data_fn, input_label_fn, strata_fn)``. ``strata_fn`` is
        ``None`` whenever uniform batch sampling is selected.

    Raises:
        NotImplementedError: If ``batch_sampling`` is not supported for the
            requested dataset.
        ValueError: If ``dataset`` is not a known dataset name.
    """
    # Dataset: Celligner
    #                   c in batch | c in encoder | d in batch | d in encoder
    # vae               True       | False        | True       | False
    # conditional model True       | True         | True       | Optional
    #
    # Dataset: Kang
    #                   c in batch | c in encoder | d in batch | d in encoder
    # vae               True       | False        | False      | False
    # conditional model True       | True         | False      | Optional
    uniform_sampling = batch_sampling is None or batch_sampling == "uniform"

    # input_label_fn returns a list of tensors, one for each set of labels.
    input_label_fn: Optional[Callable[[pd.DataFrame, bool], List[torch.Tensor]]] = None

    if dataset == "celligner":
        load_data_fn = functools.partial(load_celligner, top_var_number=top_var_number)
        if batch_sampling == "batched-by-category":
            # This really just batches diseases together rather than the more
            # conventional meaning of stratification.
            # Do not include disease labels ("d in batch" above) in input_label_fn.
            dataset_includes_group_labels = False
            strata_fn = (lambda metadata_df, batch_size:
                         (make_strata(metadata_df.disease, batch_size), 'disease')
                         )
        elif batch_sampling == "inverse-weighted":
            # Inverse-weighted stratified sampling; disease labels are included
            # in input_label_fn.
            dataset_includes_group_labels = True
            # Treat all diseases with fewer than 50 samples as one stratum. This
            # includes Bile Duct Cancer (63 samples, 26 cell-lines) as the
            # smallest separate disease.
            celligner_disease_size_cutoff = 50
            strata_fn = (lambda metadata_df, _:
                         (make_strata(metadata_df.disease,
                                      celligner_disease_size_cutoff),
                          'disease')
                         )
        elif uniform_sampling:
            dataset_includes_group_labels = False
            strata_fn = None
        else:
            raise NotImplementedError(
                f"Batch sampling mode ({batch_sampling}) not implemented for celligner data."
            )
        input_label_fn = functools.partial(celligner_labels, dataset_includes_group_labels)
        return load_data_fn, input_label_fn, strata_fn

    # The remaining datasets have no stratification and only support uniform
    # batch sampling. Built here (not at module level) so the loaders are only
    # looked up when the function is called.
    uniform_only_datasets = {
        "kang": (load_kang, kang_labels),
        "kang-trvae": (load_kang_trvae, kang_labels),
        "kang-trvae-counts": (load_kang_trvae_counts, kang_labels),
        "uci-income": (load_uci_income, uci_income_labels),
    }
    if dataset not in uniform_only_datasets:
        raise ValueError(f'Unsupported dataset name {dataset}')
    if not uniform_sampling:
        raise NotImplementedError(
            f'Batch sampling mode ({batch_sampling}) not implemented for {dataset} data; '
            'only "uniform" is supported.'
        )
    load_data_fn, input_label_fn = uniform_only_datasets[dataset]
    return load_data_fn, input_label_fn, None


def _build_decoder(
    model_config: ModelConfig,
    tensor_dataset,
    decoder_input_dim,
    gene_expression_dim,
    return_hidden: bool,
    use_cuda: bool,
):
    """Create the decoder and the matching baseline distribution.

    The baseline distribution is fitted to the first tensor of
    ``tensor_dataset`` and is only used to report a reference log-probability.

    Returns:
        A tuple ``(decoder, baseline_dist)``.

    Raises:
        ValueError: If ``model_config.likelihood`` is neither "gaussian" nor
            "multinomial".
    """
    if model_config.likelihood == "gaussian":
        decoder = GaussianDecoder(
            gene_expression_dim,
            decoder_input_dim,
            model_config.hidden_dim,
            model_config.num_layers,
            return_hidden=return_hidden,
            use_batchnorm=model_config.use_batchnorm,
        )
        # Unit-variance Gaussian centred on the global mean of the data.
        baseline_dist = torch.distributions.Normal(
            loc=tensor_dataset.tensors[0].mean(), scale=1
        )
        return decoder, baseline_dist

    if model_config.likelihood == "multinomial":
        # Log of each feature's share of the total counts.
        gene_counts = tensor_dataset.tensors[0].sum(axis=0)
        baseline_logits = (gene_counts / gene_counts.sum()).log()
        if use_cuda:
            baseline_logits = baseline_logits.cuda()
        # NOTE(review): unlike GaussianDecoder, num_layers is not forwarded
        # here -- confirm this is intentional.
        decoder = MultinomialDecoder(
            gene_expression_dim,
            decoder_input_dim,
            model_config.hidden_dim,
            baseline=baseline_logits,
            return_hidden=return_hidden,
            use_batchnorm=model_config.use_batchnorm,
        )
        baseline_dist = torch.distributions.Multinomial(
            logits=baseline_logits, validate_args=False
        )
        return decoder, baseline_dist

    # Raise instead of `assert False`: asserts are stripped under `python -O`.
    raise ValueError(f"{model_config.likelihood} is not handled when creating decoder")


def _build_model(
    model_config: ModelConfig,
    train_config: TrainConfig,
    encoder,
    decoder,
    n_groups,
    forward_pass_use_groups,
):
    """Instantiate the Lightning module selected by ``model_config.model``.

    Raises:
        ValueError: If no constructor is wired up for ``model_config.model``.
    """
    # Keyword arguments shared by every model.
    optim_kwargs = dict(
        learning_rate=train_config.learning_rate,
        gamma=train_config.gamma,
        beta=model_config.kl_beta,
    )
    # Keyword arguments shared by the models that take group labels.
    group_kwargs = dict(
        n_groups=n_groups,
        forward_pass_use_groups=forward_pass_use_groups,
    )
    latent_dim = model_config.latent_dim

    if model_config.model == "vae":
        return VAE(encoder, decoder, latent_dim, **optim_kwargs)
    if model_config.model == "cvae":
        return CVAE(
            encoder,
            decoder,
            latent_dim,
            penalty=model_config.cvae_penalty,
            penalty_scale=model_config.penalty_scale,
            **group_kwargs,
            **optim_kwargs,
        )
    if model_config.model == "contrastive_cvae":
        return ContrastiveCVAE(
            encoder,
            decoder,
            latent_dim,
            penalty_scale=model_config.penalty_scale,
            entropy_relative_scale=model_config.entropy_relative_scale,
            penalty_exp_factor=model_config.penalty_exp_factor,
            **group_kwargs,
            **optim_kwargs,
        )
    if model_config.model == "trvae":
        return TrVAE(
            encoder,
            decoder,
            latent_dim,
            penalty_scale=model_config.penalty_scale,
            penalise_z=model_config.penalise_z,
            rbf_version=model_config.rbf_version,
            **group_kwargs,
            **optim_kwargs,
        )

    # NOTE(review): "tr_cvamp" is accepted elsewhere in this file but has no
    # constructor here, so it ends up in this error as well.
    # Raise instead of `assert False`: asserts are stripped under `python -O`.
    raise ValueError(f"{model_config.model} is not handled when creating model")


def main(
    data_dir: Path,
    dataset: str,
    batch_sampling: str,
    forward_pass_use_groups: bool,
    top_var_number: int,
    model_config: ModelConfig,
    train_config: TrainConfig,
    test_config: TestConfig,
    use_cuda: bool,
    seed: int,
    output_dir: Path,
    metrics_tracking_fn: Optional[Callable],
    enable_profiler: bool = False,
):
    """Train one model on one dataset and compute its metrics.

    Steps: seed the RNGs, load and split the data, save the held-out test
    indices, build the encoder/decoder/model, fit with the Lightning trainer
    and finally run the dataset-specific metric handlers on the resulting
    checkpoints.

    Args:
        data_dir: Directory passed to ``prepare_training_data``.
        dataset: Dataset name (see ``_select_data_pipeline``).
        batch_sampling: Batch sampling mode; ``None`` means uniform.
        forward_pass_use_groups: Forwarded to ``calc_input_dims`` and to the
            models that take group labels.
        top_var_number: Feature filter forwarded to ``load_celligner``. When
            set for the celligner dataset the data cache is bypassed.
        model_config: Architecture and loss settings.
        train_config: Batch size, epochs, learning rate and gamma.
        test_config: Categories / condition to hold out.
        use_cuda: Train on one GPU and move the multinomial baseline to CUDA.
        seed: Seed passed to ``pl.seed_everything``.
        output_dir: Receives ``test_indices.pkl`` and a ``logs`` sub-directory;
            created if missing.
        metrics_tracking_fn: Optional callback forwarded to
            ``metric_handlers.calc_metrics``.
        enable_profiler: Wrap training in ``torch.profiler.profile``.

    Raises:
        ValueError: Unknown dataset, likelihood or model.
        NotImplementedError: Batch sampling mode unsupported for the dataset.
    """
    pl.seed_everything(seed)

    # Load data
    load_data_fn, input_label_fn, strata_fn = _select_data_pipeline(
        dataset, batch_sampling, top_var_number
    )

    # The cache is bypassed when celligner is loaded with a feature filter --
    # presumably because cached data would not reflect the filter; TODO confirm.
    use_cache = not (top_var_number is not None and dataset == "celligner")

    tensor_dataset, train_loader, val_loader, metadata_df, test_indices = prepare_training_data(
        data_dir=data_dir,
        load_data_fn=load_data_fn,
        batch_size=train_config.batch_size,
        input_label_fn=input_label_fn,
        use_cuda=use_cuda,
        batch_sampling=batch_sampling,
        strata_fn=strata_fn,
        dataset_name=dataset,
        categories_to_leave_out=test_config.categories,
        condition_to_leave_out=test_config.condition,
        use_cache=use_cache,
    )

    # Save test indices. Make sure the directory exists so that a direct call
    # to main() does not fail only after the data has been loaded.
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / "test_indices.pkl", 'wb') as fp:
        pickle.dump(test_indices, fp)

    # Instantiate model

    # Number of group classes is used by contrastive models depending on whether batches are grouped
    # This is derived from the tensor_dataset. The model may only use the groups for calculating
    # the loss rather than also including it as input to the encoder and decoder.
    if len(tensor_dataset.tensors) == 3:
        n_groups = tensor_dataset.tensors[2].shape[-1]
    else:
        n_groups = None

    # Set up encoder/decoder from possible inputs (X, c, d) and (z, c, d). Their inclusion
    # depends on the model type and whether conditioning on all labels is selected in the
    # CLI args.
    encoder_input_dim, decoder_input_dim = calc_input_dims(tensor_dataset,
                                                           model_config,
                                                           forward_pass_use_groups)
    gene_expression_dim = tensor_dataset.tensors[0].shape[1]

    return_hidden = model_config.model in ("trvae", "tr_cvamp")

    feature_dims = ", ".join(str(t.shape[1]) for t in tensor_dataset.tensors)
    LOGGER.info(
        "Input tensor shapes %s x (%s)", tensor_dataset.tensors[0].shape[0], feature_dims
    )
    LOGGER.info(
        "Encoder input dim %s; decoder input dim %s; decoder output dim %s",
        encoder_input_dim,
        decoder_input_dim,
        gene_expression_dim,
    )

    # The encoder is created before the decoder so that parameter
    # initialisation consumes the seeded RNG in a fixed order.
    encoder = Encoder(
        encoder_input_dim,
        model_config.latent_dim,
        model_config.hidden_dim,
        learn_sigma=model_config.learn_sigma,
        n_layers=model_config.num_layers,
        use_batchnorm=model_config.use_batchnorm,
        bandwidth=model_config.bandwidth,
    )
    decoder, baseline_dist = _build_decoder(
        model_config,
        tensor_dataset,
        decoder_input_dim,
        gene_expression_dim,
        return_hidden,
        use_cuda,
    )

    with torch.no_grad():
        baseline_logprob = (
            baseline_dist.log_prob(torch.as_tensor(tensor_dataset.tensors[0]))
            .mean()
            .item()
        )
        LOGGER.info(
            "Baseline log prob (using independent marginals): %f", baseline_logprob
        )

    model = _build_model(
        model_config, train_config, encoder, decoder, n_groups, forward_pass_use_groups
    )

    tb_dir = output_dir / "logs"
    trainer_args = dict(
        output_dir=tb_dir,
        num_epochs=train_config.num_epochs,
        gpus=1 if use_cuda else None,
        checkpoint_metric_name="valid_loss",
        checkpoint_monitor_mode="min",
        early_stopping=False,
        early_stopping_delta=1e-6,
        early_stopping_patience=50,
        weights_summary='full',
    )
    if enable_profiler:
        LOGGER.warning(
            "Pytorch profiler enabled; writing TensorBoard logs to %s", tb_dir
        )
        with torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=2, warmup=2, active=6, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(str(tb_dir)),
        ) as profiler:
            trainer, checkpoint_callback = create_trainer(
                **trainer_args, profiler=profiler
            )
            trainer.fit(model, train_loader, val_dataloaders=val_loader)
    else:
        trainer, checkpoint_callback = create_trainer(**trainer_args)
        trainer.fit(model, train_loader, val_dataloaders=val_loader)

    if checkpoint_callback is not None:
        # Evaluate the last checkpoint followed by the best-k checkpoints.
        model_list = [tb_dir / "last.ckpt"]
        model_list.extend(Path(k) for k in checkpoint_callback.best_k_models.keys())
    else:
        LOGGER.info(
            "No checkpoint callback found, calculating metrics and results for current model instance instead."
        )
        model_list = [model]

    metric_handlers.calc_metrics(
        handlers=metric_handlers.select_metric_handlers(dataset),
        output_dir=output_dir,
        dataset=tensor_dataset,
        sample_metadata_df=metadata_df,
        models=model_list,
        dataset_name=dataset,
        model_type=model_config.model,
        load_model_fn=model.load_from_checkpoint,
        baseline_logprob=baseline_logprob,
        use_cuda=use_cuda,
        metrics_tracking_fn=metrics_tracking_fn,
    )


def create_arg_parser():
    """Build the command-line parser for the training script.

    The options are declared as an ordered table of ``(flag, keyword
    arguments)`` pairs and registered in that order, so ``--help`` lists them
    in the order below.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    # NOTE(review): --dataset, --model and --likelihood have neither a default
    # nor required=True, so they parse to None when omitted.
    data_options = [
        ("--data-dir", dict(default="data")),
        ("--dataset", dict(
            choices=["celligner", "kang", "kang-trvae", "kang-trvae-counts", "uci-income"],
        )),
        ("--batching-mode", dict(
            choices=batch_sampling_modes,
            default=None,
            help='The method of stratified sampling, for datasets that require it',
        )),
        ("--forward-use-groups", dict(
            choices=[0, 1],
            default=0,
            type=int,
            help=(
                'If 1 (true) then the group labels are passed as input to the '
                'encoder and decoder, otherwise they are only used for grouping '
                'the penalty calculation for models where this applies. '
                'Ignored when batching-mode is batched-by-category.'
            ),
        )),
        ("--top-var-number", dict(
            type=int,
            default=None,
            help="No. of features to filter for each condition.",
        )),
    ]

    # Model config
    model_options = [
        ("--model", dict(choices=["vae", "cvae", "contrastive_cvae", "trvae", "tr_cvamp"])),
        ("--likelihood", dict(choices=["gaussian", "multinomial"])),
        ("--hidden-dim", dict(type=int, default=10)),
        ("--latent-dim", dict(type=int, default=16)),
        ("--num-layers", dict(type=int, default=1)),
        ("--penalty-scale", dict(type=float, default=1.0)),
        ("--cvae-penalty", dict(default=None)),
        ("--entropy-relative-scale", dict(
            type=float,
            default=1.0,
            help="Relevant only for CCVAE. Scales the entropy term relative to the cross-mixture penalty",
        )),
        ("--penalty-exp-factor", dict(
            type=float,
            default=1.0,
            help="Relevant only for CCVAE. Exponent for each mixture component. 1.0 ~ l2 reg; 0.5 ~ l1 reg",
        )),
        ("--kl-beta", dict(
            type=float,
            default=1.0,
            help="Beta-VAE scale factor for the KL term in the VAE ELBO",
        )),
        ("--use-batchnorm", dict(
            type=int,  # int rather than bool to allow Polyaxon to parse
            default=0,
            help="Whether to use batchnorm in the decoder. Encoder not yet implemented",  # TODO
        )),
        ("--learn-sigma", dict(
            type=str,
            choices=["fix_all", "learn_all", "learn_but_decouple", "learn_elbo_fix_penalty"],
            default="fix_all",
            help="Options to learn/fix sigma for elbo term and penalty term separately.",
        )),
        ("--bandwidth", dict(
            type=float,
            default=0.1,
            help="The constant value of the posterior Gaussian scale. If learn-sigma is True, this is ignored.",
        )),
        ("--penalise-z", dict(
            type=int,  # int rather than bool to allow Polyaxon to parse
            default=0,
            help="Whether to penalise z. If False, penalise first hidden layer. Applicable to TrVAE",
        )),
        ("--rbf-version", dict(
            type=int,
            default=0,  # This is multiscale version from TrVAE
            help="RBF kernel version. Applicable to TrVAE only. For versions, see the global variables in the modules.",
        )),
    ]

    # Training config
    training_options = [
        ("--batch-size", dict(type=int, default=50)),
        ("--num-epochs", dict(type=int, default=10)),
        ("--learning-rate", dict(type=float, default=0.01)),
        ("--use-cuda", dict(action="store_true", default=False)),
        ("--seed", dict(type=int, default=0)),
        ("--output-dir", dict(default="/tmp/ccvae")),
        ("--profiler", dict(
            action="store_true",
            default=False,
            help="Enable Pytorch profiler, logging to TensorBoard (Lightning only).",
        )),
    ]

    # Test split config. Both options may be repeated; values accumulate in a list.
    # NOTE(review): --category-to-leave-out parses values as int although its
    # help text gives category names as examples -- confirm which is intended.
    test_split_options = [
        ("--category-to-leave-out", dict(
            help="One or more 'categories' (e.g. CD14+ Monocytes or Breast Cancer) that you want to leave out",
            action='append',
            type=int,
            default=None,
        )),
        ("--condition-to-leave-out", dict(
            help="The condition you want to leave out (e.g. 'perturbed' or 'unperturbed')",
            action='append',
            default=None,
        )),
    ]

    parser = argparse.ArgumentParser()
    for option_table in (data_options, model_options, training_options, test_split_options):
        for flag, settings in option_table:
            parser.add_argument(flag, **settings)
    return parser


if __name__ == "__main__":
    # Script entry point: parse the CLI, resolve the output location, record
    # the raw arguments next to the results and start training.
    args = create_arg_parser().parse_args()

    if args.output_dir != "polyaxon":
        # Plain local run: no experiment tracking, write into the given directory.
        log_metrics_to_tracking = None
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
    else:
        # The literal value "polyaxon" selects the Polyaxon-managed outputs
        # path and forwards metrics to the Polyaxon experiment.
        from polyaxon_client.tracking import get_outputs_path, Experiment

        experiment = Experiment()

        def polyx_log_metrics(*metric_args, **metric_kwargs):
            experiment.log_metrics(*metric_args, **metric_kwargs)

        log_metrics_to_tracking = polyx_log_metrics
        output_dir = Path(get_outputs_path())

    # Keep a record of the exact command-line configuration with the outputs.
    config_path = output_dir / "config.yaml"
    with open(config_path, "w") as fp:
        yaml.dump(vars(args), fp)

    model_settings = ModelConfig(
        model=args.model,
        likelihood=args.likelihood,
        latent_dim=args.latent_dim,
        hidden_dim=args.hidden_dim,
        num_layers=args.num_layers,
        penalty_scale=args.penalty_scale,
        entropy_relative_scale=args.entropy_relative_scale,
        penalty_exp_factor=args.penalty_exp_factor,
        cvae_penalty=args.cvae_penalty,
        kl_beta=args.kl_beta,
        # These two flags arrive as ints from the CLI.
        use_batchnorm=bool(args.use_batchnorm),
        learn_sigma=args.learn_sigma,
        bandwidth=args.bandwidth,
        penalise_z=bool(args.penalise_z),
        rbf_version=args.rbf_version,
    )
    # gamma has no CLI option and is fixed at 1.0.
    train_settings = TrainConfig(
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        gamma=1.0,
    )
    test_settings = TestConfig(
        categories=args.category_to_leave_out,
        condition=args.condition_to_leave_out,
    )

    main(
        data_dir=Path(args.data_dir),
        dataset=args.dataset,
        batch_sampling=args.batching_mode,
        forward_pass_use_groups=args.forward_use_groups,
        top_var_number=args.top_var_number,
        model_config=model_settings,
        train_config=train_settings,
        test_config=test_settings,
        use_cuda=args.use_cuda,
        seed=args.seed,
        output_dir=output_dir,
        metrics_tracking_fn=log_metrics_to_tracking,
        enable_profiler=args.profiler,
    )
