from typing import List, Optional, Tuple

import hydra
import pytorch_lightning as L
import pyrootutils
import copy
import GPUtil
import time
import os
from omegaconf import DictConfig
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
import pandas as pd
import multiprocessing as mp
from biggs.models.BiG_module import GwgPairSampler
from biggs.data.sequence_dataset import SequenceDataset, PreScoredSequenceDataset

pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
#       (so you don't need to force user to install project as a package)
#       (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
#       (which is used as a base for paths in "configs/paths/default.yaml")
#       (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/pyrootutils
# ------------------------------------------------------------------------------------ #
# NOTE(review): the `from biggs.models...` / `from biggs.data...` imports in the
# import block at the top of this file execute BEFORE this call, so they do not
# benefit from the PYTHONPATH setup described above. Presumably `biggs` is
# installed as a package; otherwise those imports should move below this call.

from biggs import utils

# Module-level logger shared by every function in this script.
log = utils.get_pylogger(__name__)
def to_list(x):
    """Convert a tensor-like object into a (nested) Python list.

    Args:
        x: object supporting the ``.cpu().detach().numpy().tolist()`` call
            chain (presumably a ``torch.Tensor`` -- confirm against callers).

    Returns:
        The list produced by ``tolist()`` on the NumPy view of ``x``.
    """
    host_tensor = x.cpu().detach()
    return host_tensor.numpy().tolist()

def save_pairs(pairs_dict, save_path):
    """Write every source/target pair to ``save_path``, one pair per line.

    Args:
        pairs_dict: mapping from a source sequence to an iterable of target
            sequences generated from it.
        save_path: path of the text file to (over)write. Each line has the
            form ``<source>,<target>``.
    """
    log.info(f"Saving generated pairs to {save_path}")
    with open(save_path, "w") as out_file:
        # The generator is consumed lazily while the file is open, so lines
        # are emitted in the mapping's iteration order.
        out_file.writelines(
            f"{source},{target}\n"
            for source in pairs_dict
            for target in pairs_dict[source]
        )

def _worker_fn(args):
    """Run GWG sampling for one worker on its assigned GPU.

    Takes a single tuple (rather than separate parameters) so it can be used
    directly with ``Pool.map``.

    Args:
        args (tuple): ``(worker_i, exp_cfg, inputs)`` where
            worker_i: CUDA device index; the sampler is built on
                ``cuda:<worker_i>``.
            exp_cfg: keyword arguments forwarded to ``GwgPairSampler``.
            inputs: iterable of batches to process.

    Returns:
        List with one sampler output per batch, in input order.
    """
    gpu_id, sampler_kwargs, batches = args
    sampler = GwgPairSampler(**sampler_kwargs, device=f"cuda:{gpu_id}")
    results = [sampler(batch) for batch in batches]
    log.info(f'Done with worker: {gpu_id}')
    return results

def _setup_dataset(cfg):
    """Locate the base-sequence CSV and build the initial dataset from it.

    The CSV location is derived from ``cfg.experiment.predictor_dir``: its
    parent directory is taken, every ``'ckpt'`` substring is replaced by
    ``'data'``, and the path is truncated right after the first component
    whose name contains ``'percentile'``. If no component matches, the whole
    directory is used. ``base_seqs.csv`` is expected inside that directory.

    Args:
        cfg: config exposing ``data.csv_path`` and
            ``experiment.predictor_dir``. ``cfg.data.csv_path`` is
            overwritten in place with the derived CSV path.

    Returns:
        A ``PreScoredSequenceDataset`` constructed from ``cfg.data``.

    Raises:
        ValueError: if ``cfg.data.csv_path`` is already set, or if the
            derived CSV file does not exist.
    """
    if cfg.data.csv_path is not None:
        # The path is always derived below; a preset value would be ignored.
        raise ValueError('cfg.data.csv_path must be None.')

    data_dir = os.path.dirname(cfg.experiment.predictor_dir).replace('ckpt', 'data')
    path_parts = data_dir.split('/')
    # Index of the first 'percentile' component; fall back to the last
    # component so the full directory is kept when there is no match.
    last_idx = next(
        (idx for idx, part in enumerate(path_parts) if 'percentile' in part),
        len(path_parts) - 1,
    )
    data_dir = '/'.join(path_parts[:last_idx + 1])

    cfg.data.csv_path = os.path.join(data_dir, 'base_seqs.csv')
    if not os.path.exists(cfg.data.csv_path):
        raise ValueError(f'Could not find dataset at {cfg.data.csv_path}.')

    return PreScoredSequenceDataset(**cfg.data)

def _select_worker_gpus(num_workers: int, limit: int = 8) -> List[str]:
    """Return the ids of the GPUs the sampling workers should run on.

    Args:
        num_workers: maximum number of GPU ids to return.
        limit: maximum number of GPUs requested from ``GPUtil.getAvailable``.

    Returns:
        Up to ``num_workers`` GPU ids as strings, in the order reported by
        ``GPUtil.getAvailable(order='memory')``.

    Raises:
        RuntimeError: if GPUtil reports no available GPU.
    """
    # Keep every id as its own list element. Joining the ids into a single
    # string and slicing it per character would split a multi-digit id such
    # as 10 into '1' and '0'.
    gpu_ids = [
        str(gpu_id)
        for gpu_id in GPUtil.getAvailable(order='memory', limit=limit)
    ]
    if not gpu_ids:
        raise RuntimeError('No available GPUs found for GWG sampling.')
    return gpu_ids[:num_workers]


def generate_pairs(cfg: DictConfig, sample_write_path: str) -> None:
    """Generate pairs using GWG and write them to ``sample_write_path``.

    Every epoch the current dataset is split into batches, the batches are
    processed by ``_worker_fn`` (in-process when a single worker is
    configured, otherwise in a ``multiprocessing`` pool with one task per
    selected GPU), the returned pairs are recorded in the dataset, and the
    de-duplicated mutant sequences become the next epoch's dataset. The loop
    ends after ``cfg.run.max_epochs`` epochs or as soon as the dataset is
    empty. Finally ``dataset.pairs`` is written as CSV.

    Args:
        cfg: run configuration. A deep copy is made, so the caller's object
            is not modified. Reads ``cfg.run`` (``debug``, ``num_workers``,
            ``max_epochs``), ``cfg.experiment`` (keyword arguments for the
            sampler), ``cfg.data`` and the optional ``cfg.seed``.
        sample_write_path: destination CSV path for the generated pairs.

    Raises:
        ValueError: propagated from ``_setup_dataset`` when the base dataset
            cannot be resolved.
        RuntimeError: if no GPU is available for sampling.
    """
    cfg = copy.deepcopy(cfg)
    run_cfg = cfg.run

    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    dataset = _setup_dataset(cfg)

    # Special settings for debugging.
    if run_cfg.debug:
        log.info("WARNING: running in debug mode.")
        num_workers = 1
    else:
        num_workers = run_cfg.num_workers
    exp_cfg = dict(cfg.experiment)

    epoch = 0
    start_time = time.time()
    while epoch < run_cfg.max_epochs and len(dataset):
        epoch += 1
        epoch_start_time = time.time()

        # GPUs are re-selected every epoch because availability can change.
        worker_gpus = _select_worker_gpus(num_workers)
        log.info(f'Using GPUs: {worker_gpus}' )
        num_slots = len(worker_gpus)
        if num_slots < num_workers:
            log.info(
                f"WARNING: only {num_slots} of {num_workers} requested GPUs "
                "are available; batches are spread over the available ones."
            )

        # Batch size still follows the configured worker count, so each
        # batch keeps its usual size even when fewer GPUs are available.
        if run_cfg.debug:
            batch_size = 2
        else:
            batch_size = len(dataset) // num_workers + 1
        dataloader = DataLoader(dataset, batch_size=batch_size)

        # Round-robin the batches over the GPUs actually selected, so no
        # batch is dropped when there are fewer GPUs than workers.
        batches_per_worker = [[] for _ in range(num_slots)]
        for i, batch in enumerate(dataloader):
            batches_per_worker[i % num_slots].append(batch)
            if run_cfg.debug:
                # Debug mode processes only the first batch.
                break

        # Run sampling with workers
        if num_workers == 1:
            all_worker_outputs = [
                _worker_fn((worker_gpus[0], exp_cfg, batches_per_worker[0]))
            ]
        else:
            worker_args = [
                (gpu_id, exp_cfg, worker_batches)
                for gpu_id, worker_batches in zip(worker_gpus, batches_per_worker)
            ]
            with mp.Pool(num_slots) as pool:
                all_worker_outputs = pool.map(_worker_fn, worker_args)

        # Process results.
        epoch_pair_count = 0
        candidate_frames = []
        for worker_results in all_worker_outputs:
            for new_pairs in worker_results:
                if new_pairs is None:
                    continue
                # Mutants become candidate inputs for the next epoch.
                candidate_frames.append(
                    new_pairs[['mutant_sequences', 'mutant_scores']].rename(
                        columns={'mutant_sequences': 'sequences', 'mutant_scores': 'scores'}
                    )
                )
                # The value returned by add_pairs is accumulated as this
                # epoch's pair count.
                epoch_pair_count += dataset.add_pairs(new_pairs, epoch)

        candidate_seqs = None
        if candidate_frames:
            candidate_seqs = pd.concat(candidate_frames).drop_duplicates(
                subset='sequences')
        epoch_elapsed_time = time.time() - epoch_start_time

        log.info(f"Epoch {epoch} finished in {epoch_elapsed_time:.2f} seconds")
        log.info("------------------------------------")
        log.info(f"Generated {epoch_pair_count} pairs in this epoch")
        dataset.reset()
        # After the final epoch the dataset is left reset: no further round
        # will consume the candidates.
        if (
            epoch < run_cfg.max_epochs
            and candidate_seqs is not None
            and len(candidate_seqs) > 0
        ):
            dataset.add(candidate_seqs)
            dataset.cluster()
        log.info(f"Next dataset = {len(dataset)} sequences")

    dataset.pairs.to_csv(sample_write_path, index=False)
    elapsed_time = time.time() - start_time
    log.info(f'Finished generation in {elapsed_time:.2f} seconds.')
    log.info(f'Samples written to {sample_write_path}.')


@hydra.main(version_base="1.3", config_path="../configs", config_name="BiG.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    """Save the run config and generate one sample file per seed index.

    Outputs are written under
    ``<cfg.experiment.predictor_dir>/samples/<cfg.run.run_name>/``: the
    resolved config as ``config.yaml`` and one ``seed_<i>.csv`` per seed
    index ``i`` in ``1..cfg.run.num_seeds``.

    Args:
        cfg: Hydra-composed configuration.

    Returns:
        None. No metric is returned by this entry point.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    utils.extras(cfg)
    # force=True keeps this call safe when main() runs more than once in the
    # same process (e.g. Hydra multirun); without it the second call raises
    # RuntimeError because the start method is already set.
    mp.set_start_method('spawn', force=True)

    # Set-up output path
    run_cfg = cfg.run
    output_dir = os.path.join(cfg.experiment.predictor_dir, 'samples', run_cfg.run_name)
    os.makedirs(output_dir, exist_ok=True)
    cfg_write_path = os.path.join(output_dir, 'config.yaml')
    with open(cfg_write_path, 'w') as f:
        OmegaConf.save(config=cfg, f=f)
    log.info(f'Config saved to {cfg_write_path}')

    # Generate samples for multiple seeds.
    # NOTE(review): `seed` only selects the output file name; it is not
    # passed to generate_pairs, which seeds from `cfg.seed` when that is set.
    # With a fixed `cfg.seed` every iteration is re-seeded identically --
    # confirm this is intended.
    for seed in range(1, run_cfg.num_seeds + 1):
        sample_write_path = os.path.join(output_dir, f'seed_{seed}.csv')
        log.info(f'On seed {seed}. Saving results to {sample_write_path}')
        generate_pairs(cfg, sample_write_path)

# Script entry point: the @hydra.main decorator on `main` supplies `cfg`,
# so it is called without arguments here.
if __name__ == "__main__":
    main()
