"""Script for running inference and evaluation."""

import os
import time
import numpy as np
import hydra
import torch
import pandas as pd
import glob
import GPUtil
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from omegaconf import DictConfig, OmegaConf
from experiments import utils as eu
from models.flow_module import FlowModule


# Select PyTorch's 'high' float32 matmul precision mode for the whole process;
# see the torch.set_float32_matmul_precision documentation for the
# speed/accuracy trade-off this mode implies.
torch.set_float32_matmul_precision('high')
# Module-level logger obtained from the project's experiment utilities.
log = eu.get_pylogger(__name__)


class EvalRunner:
    """Load a trained FlowModule checkpoint, sample from it, and score samples.

    The runner supports two tasks, selected by ``cfg.inference.task``:
    ``'unconditional'`` and ``'scaffolding'``. Each task has its own checkpoint
    path, evaluation dataset, and metrics routine.
    """

    def __init__(self, cfg: DictConfig):
        """Initialize sampler.

        Args:
            cfg: inference config. It is merged with the ``config.yaml`` stored
                next to the selected checkpoint.

        Raises:
            ValueError: if ``cfg.inference.task`` is not a supported task.
        """

        # Read in checkpoint.
        if cfg.inference.task == 'unconditional':
            ckpt_path = cfg.inference.unconditional_ckpt_path
        elif cfg.inference.task == 'scaffolding':
            ckpt_path = cfg.inference.scaffolding_ckpt_path
        else:
            # self._infer_cfg is not assigned yet at this point, so the task
            # name must be read from the incoming config.
            raise ValueError(f'Unknown task {cfg.inference.task}')
        ckpt_dir = os.path.dirname(ckpt_path)
        ckpt_cfg = OmegaConf.load(os.path.join(ckpt_dir, 'config.yaml'))

        # Set-up config. Struct mode is disabled so the merge (and the later
        # assignment of `inference_dir`) may introduce keys missing from
        # either side. `ckpt_cfg` is passed last to OmegaConf.merge, so its
        # values win on conflicting keys.
        OmegaConf.set_struct(cfg, False)
        OmegaConf.set_struct(ckpt_cfg, False)
        cfg = OmegaConf.merge(cfg, ckpt_cfg)
        cfg.experiment.checkpointer.dirpath = './'
        self._cfg = cfg
        self._exp_cfg = cfg.experiment
        self._infer_cfg = cfg.inference
        self._samples_cfg = self._infer_cfg.samples
        self._rng = np.random.default_rng(self._infer_cfg.seed)

        # Set-up output directory only on rank 0. Environment variables are
        # strings, so LOCAL_RANK must be converted before comparing with 0;
        # otherwise an explicitly set LOCAL_RANK='0' would skip this branch.
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        if local_rank == 0:
            inference_dir = self.setup_inference_dir(ckpt_path)
            self._exp_cfg.inference_dir = inference_dir
            config_path = os.path.join(inference_dir, 'config.yaml')
            with open(config_path, 'w') as f:
                OmegaConf.save(config=self._cfg, f=f)
            log.info(f'Saving inference config to {config_path}')

        # Read checkpoint and initialize module.
        self._flow_module = FlowModule.load_from_checkpoint(
            checkpoint_path=ckpt_path,
            cfg=self._cfg,
            folding_cfg=self._infer_cfg.folding,
        )
        log.info(pl.utilities.model_summary.ModelSummary(self._flow_module))
        self._flow_module.eval()
        # Hand the inference/sample settings to the module for prediction.
        self._flow_module._infer_cfg = self._infer_cfg
        self._flow_module._samples_cfg = self._samples_cfg

    @property
    def inference_dir(self):
        """Output directory as reported by the loaded flow module."""
        return self._flow_module.inference_dir

    def setup_inference_dir(self, ckpt_path):
        """Create the output directory for this run and return its path.

        The directory is
        ``<predict_dir>/<checkpoint name>/<task>/<inference_subdir>``, where
        the checkpoint name is the last three '/'-separated components of
        ``ckpt_path`` with the ``.ckpt`` suffix removed.

        Args:
            ckpt_path: path to the checkpoint file being evaluated.

        Returns:
            Path of the (possibly already existing) output directory.
        """
        self._ckpt_name = '/'.join(ckpt_path.replace('.ckpt', '').split('/')[-3:])
        output_dir = os.path.join(
            self._infer_cfg.predict_dir,
            self._ckpt_name,
            self._infer_cfg.task,
            self._infer_cfg.inference_subdir,
        )
        os.makedirs(output_dir, exist_ok=True)
        log.info(f'Saving results to {output_dir}')
        return output_dir

    def run_sampling(self):
        """Run the flow module's predict loop over the task's dataset.

        Picks up to ``num_gpus`` of the GPUs that GPUtil reports as available
        (at most 8 are considered) and runs ``Trainer.predict`` with the DDP
        strategy.

        Raises:
            ValueError: if the configured task is not supported.
        """
        devices = GPUtil.getAvailable(
            order='memory', limit=8)[:self._infer_cfg.num_gpus]
        log.info(f"Using devices: {devices}")
        log.info(f'Evaluating {self._infer_cfg.task}')
        if self._infer_cfg.task == 'unconditional':
            eval_dataset = eu.LengthDataset(self._samples_cfg)
        elif self._infer_cfg.task == 'scaffolding':
            eval_dataset = eu.ScaffoldingDataset(self._samples_cfg)
        else:
            raise ValueError(f'Unknown task {self._infer_cfg.task}')
        dataloader = torch.utils.data.DataLoader(
            eval_dataset, batch_size=1, shuffle=False, drop_last=False)
        trainer = Trainer(
            accelerator="gpu",
            strategy="ddp",
            devices=devices,
        )
        trainer.predict(self._flow_module, dataloaders=dataloader)

    def _write_designable_metrics(self, sample_dir, top_sample_csv):
        """Summarize the ``designable`` column and write ``designable.csv``.

        Args:
            sample_dir: directory that receives ``designable.csv``.
            top_sample_csv: DataFrame of top samples; must already contain a
                boolean ``designable`` column.
        """
        metrics_df = pd.DataFrame(data={
            'Total designable': top_sample_csv.designable.sum(),
            'Designable': top_sample_csv.designable.mean(),
            'Total samples': len(top_sample_csv),
        }, index=[0])
        designable_csv_path = os.path.join(sample_dir, 'designable.csv')
        metrics_df.to_csv(designable_csv_path, index=False)
        # The project utility receives both the frame and the CSV path it was
        # just written to, along with the samples themselves.
        eu.calculate_diversity_novelty(
            sample_dir, metrics_df, top_sample_csv, designable_csv_path)

    def compute_unconditional_metrics(self, output_dir):
        """Compute designability metrics for unconditional samples.

        A sample counts as designable when its ``bb_rmsd`` is at most 2.0.

        Args:
            output_dir: directory holding the sampling results; also receives
                ``designable.csv``.
        """
        log.info(f'Calculating metrics for {output_dir}')
        top_sample_csv = eu.get_all_top_samples(output_dir)
        top_sample_csv['designable'] = top_sample_csv.bb_rmsd <= 2.0
        self._write_designable_metrics(output_dir, top_sample_csv)

    def compute_scaffolding_metrics(self, output_dir):
        """Compute designability metrics for every scaffolding target.

        Every entry of ``output_dir`` whose path does not contain ``.yaml`` is
        treated as a target directory. A sample counts as designable when its
        ``scaffold_rmsd`` is at most 2.0 and its ``motif_rmsd`` is at most 1.0.

        Args:
            output_dir: directory holding one sub-directory per target.
        """
        all_targets = glob.glob(os.path.join(output_dir, '*'))
        # Skip the saved config file that lives next to the target folders.
        all_targets = [x for x in all_targets if '.yaml' not in x]
        for target_dir in all_targets:
            log.info(f'Calculating metrics for {target_dir}')
            top_sample_csv = eu.get_all_top_samples(
                target_dir, csv_fname='*/top_sample.csv')
            top_sample_csv['designable'] = (
                (top_sample_csv.scaffold_rmsd <= 2.0)
                & (top_sample_csv.motif_rmsd <= 1.0)
            )
            self._write_designable_metrics(target_dir, top_sample_csv)


@hydra.main(version_base=None, config_path="../configs", config_name="inference_unconditional")
def run(cfg: DictConfig) -> None:
    """Sample with the configured checkpoint, then compute the task's metrics.

    Args:
        cfg: Hydra-provided inference config.

    Raises:
        ValueError: if ``cfg.inference.task`` is neither 'unconditional' nor
            'scaffolding'.
    """
    log.info(f'Starting inference with {cfg.inference.num_gpus} GPUs')
    began_at = time.time()

    # Build the runner (loads the checkpoint) and generate samples.
    runner = EvalRunner(cfg)
    runner.run_sampling()

    # Reject unsupported tasks up front, then pick the matching metrics routine.
    task = cfg.inference.task
    if task not in ('unconditional', 'scaffolding'):
        raise ValueError(f'Unknown task {task}')
    if task == 'unconditional':
        metrics_fn = runner.compute_unconditional_metrics
    else:
        metrics_fn = runner.compute_scaffolding_metrics
    metrics_fn(runner.inference_dir)

    elapsed_time = time.time() - began_at
    log.info(f'Finished in {elapsed_time:.2f}s')

# Script entry point: invoke the Hydra-decorated runner when this file is
# executed directly rather than imported.
if __name__ == '__main__':
    run()