from typing import Tuple, Dict, Callable, List, Union

# python garbage collector (used together with torch.cuda.empty_cache in RunContext.cleanup)
import gc
import torch
import torch.nn as nn

import logging
import os.path as osp
from pathlib import Path
import glob
from omegaconf import OmegaConf, DictConfig
import wandb
import json

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from torch_geometric.transforms import Compose
from src.datatypes.sparse import MyToUndirected

import src.data.pipelines as ppl
from src.data.datamodule import GraphDataModule
from src.data.simple_transforms.molecular import GraphToMoleculeConverter

from src.models.generator import ReinsertionDenoisingModel
from src.noise.data_transform.subgraph_sampler import SubgraphSampler
from src.noise.data_transform.subsequence_sampler import SubsequenceSampler, resolve_collater

from src.metrics.sampling import BaseSamplingMetric, GraphSamplingMetrics, MoleculeSamplingMetrics, SamplingMetricsHandler


# Root folder for datasets; each dataset's configured ``root`` is joined onto it.
DATASETS_ROOT: str = 'datasets'
# Root folder for runs: one sub-folder per config name, holding versioned run directories.
CHECKPOINT_PATH: str = 'checkpoints'

import os
import random
import numpy as np
from copy import deepcopy

# from https://github.com/Lightning-AI/lightning/issues/1565
def seed_everything(seed=191117):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.use_deterministic_algorithms(True)


class ContextException(Exception):
    """Raised when a run context cannot be set up (e.g. resuming without a version, or no matching run folder)."""


class PrepareData:
    """Pipeline step that hands each datapoint to a removal process.

    ``removal_process.prepare_data(datapoint=...)`` is called and its return
    value ignored; the very same datapoint object is passed on, so the step
    can be chained inside a ``Compose`` transform.
    """

    def __init__(self, removal_process):
        # any object exposing ``prepare_data(datapoint=...)``
        self.removal_process = removal_process

    def __call__(self, batch):
        prepare = self.removal_process.prepare_data
        prepare(datapoint=batch)
        return batch

class RunContext:
    """Build and drive one experiment run from a Hydra configuration.

    After :meth:`configure` (or :meth:`from_config`) the context holds:

    * ``run_directory``, ``version``, ``run_id`` -- where checkpoints are
      stored (all three are None in debug mode);
    * ``datamodule`` and ``dataset_infos`` (the train-split info);
    * ``model`` -- a ``ReinsertionDenoisingModel``;
    * ``trainer`` -- a Lightning ``Trainer``.

    Call :meth:`close` when done to finish the wandb run, if one was started.
    """

    def __init__(self, console_logger: logging.Logger | int | None = None):
        """Create an unconfigured context.

        Args:
            console_logger: forwarded to ``setup_logger``: an existing logger,
                an ``int`` logging level, or None for INFO level.
        """
        # set to True by _configure_logger once wandb.init has been called
        self.wandb_active = False
        self.logger = setup_logger(console_logger)

    @classmethod
    def from_config(cls, cfg: DictConfig, console_logger: logging.Logger | int | None = None) -> 'RunContext':
        """Create a context and configure it from ``cfg`` in one call."""
        context = cls(console_logger)
        context.configure(cfg)
        return context

    def configure(self, cfg):
        """Set up run directory, datamodule, model and trainer from ``cfg``.

        ``cfg`` is first passed through ``preprocess_config``. The keys read
        here are ``config_name``, ``debug``, ``persistent``, ``resume``,
        ``version``, ``seed``, ``profile``, ``data``, ``model``, ``run``,
        ``metric``, ``platform`` and, optionally, ``group_name`` (defaults to
        ``config_name``).

        Raises:
            ContextException: on an invalid resume/version combination or when
                the run to resume cannot be found.
        """
        # preprocess configuration (resolve interpolations, drop the hydra node)
        cfg = preprocess_config(cfg)

        # configure context
        self.config_name = cfg['config_name']
        self.group_name = cfg['group_name'] if 'group_name' in cfg else self.config_name
        self.debug = cfg['debug']
        self.persistent = cfg['persistent']
        self.resume = cfg['resume']
        self.version = cfg['version']
        self.seed = cfg['seed']
        self.profile = cfg['profile']
        self.cfg = cfg

        self.logger.info(f'Configuring run "{self.config_name}"')
        self.logger.info(f'Debug mode: {self.debug}')
        self.logger.info(f'Resuming run: {self.resume}')
        self.logger.info(f'Version: {self.version}')
        self.logger.info(f'Seed: {self.seed}')
        self.logger.info(f'Profiling: {self.profile}')

        # set seed
        seed_everything(self.seed)

        # check that context parameters are valid
        self.validate_context()
        self.logger.info('Context validated with success')

        # checkpoints, run ids, etc..., are all contained in the run directory
        self.run_directory, self.version, self.run_id = self._setup_run_directory(self.config_name)
        self.logger.info(f'Run directory: {self.run_directory}')

        # configuring datamodule
        self.logger.info('Configuring and loading datamodule...')
        self.datamodule = self._configure_datamodule(cfg['data'], cfg['model']['removal'])
        self.dataset_infos = self.datamodule.get_info('train')
        self.logger.info('Datamodule configured with success')

        # configuring model
        self.logger.info('Configuring and loading model...')
        datatype = ppl.REGISTERED_DATATYPES[cfg['data']['dataset']['name']]
        self.model = self._configure_model(datatype, cfg['model'], cfg['run']['training'], cfg['metric'], self.datamodule, self.dataset_infos)
        self.logger.info('Model configured with success')

        # configuring trainer
        self.logger.info('Configuring trainer...')
        self.trainer = self._configure_trainer(cfg['run'], cfg['platform'])
        self.logger.info('Trainer configured with success')

        if not self.resume and not self.debug:
            self.logger.info('Creating run directory...')
            self.run_directory.mkdir(parents=True)

        self.logger.info('Configuration completed')

    def fit(self, ckpt='last'):
        """Train the model; ``ckpt`` is forwarded to ``Trainer.fit`` as ``ckpt_path``."""
        return self.trainer.fit(
            model=self.model,
            datamodule=self.datamodule,
            ckpt_path=ckpt,
        )

    def validate(self, ckpt='last'):
        """Validate the model; ``ckpt`` is forwarded to ``Trainer.validate`` as ``ckpt_path``."""
        return self.trainer.validate(
            model=self.model,
            datamodule=self.datamodule,
            ckpt_path=ckpt,
        )

    def test(self, ckpt='last'):
        """Test the model; ``ckpt`` is forwarded to ``Trainer.test`` as ``ckpt_path``."""
        return self.trainer.test(
            model=self.model,
            datamodule=self.datamodule,
            ckpt_path=ckpt,
        )

    def get_training_info(self):
        """Return the trainer's current epoch."""
        return self.trainer.current_epoch

    def load_checkpoint(self, checkpoint_name: str = None, strict: bool = True):
        """Replace ``self.model`` with a checkpoint stored in the run directory.

        Args:
            checkpoint_name: file name inside ``run_directory``, with or
                without the ``.ckpt`` suffix; defaults to ``'last'``.
            strict: forwarded to ``load_from_checkpoint``.

        The current model's sampling metrics, samples converter and console
        logger are passed to the restored model.
        """
        if checkpoint_name is None:
            checkpoint_name = 'last'
        if not checkpoint_name.endswith('.ckpt'):
            checkpoint_name += '.ckpt'
        self.model = self.model.__class__.load_from_checkpoint(
            str(self.run_directory / checkpoint_name),
            sampling_metrics=self.model.sampling_metrics,
            inference_samples_converter=self.model.inference_samples_converter,
            console_logger=self.model.console_logger,
            strict=strict,
        )
        self.logger.info(f'Loaded checkpoint {checkpoint_name}')

    def load_module_from(self, ckpt_path: str, module: str):
        """Replace one sub-model of ``self.model`` with the one from another checkpoint.

        Args:
            ckpt_path: path to a ``ReinsertionDenoisingModel`` checkpoint.
            module: ``'reinsertion'`` (copies ``reinsertion_model``) or
                ``'denoising'`` (copies ``denoising_model``).

        Raises:
            ValueError: if ``module`` is not one of the names above. This is
                checked before the checkpoint is loaded.
        """
        submodules = {
            'reinsertion': 'reinsertion_model',
            'denoising': 'denoising_model',
        }
        if module not in submodules:
            raise ValueError(f"Unknown module '{module}': expected 'reinsertion' or 'denoising'")

        self.logger.info(f'Loading module {module} from checkpoint {ckpt_path}')

        other_model = ReinsertionDenoisingModel.load_from_checkpoint(
            checkpoint_path=ckpt_path,
            sampling_metrics=self.model.sampling_metrics,
            inference_samples_converter=self.model.inference_samples_converter,
            console_logger=self.model.console_logger,
        )

        attribute = submodules[module]
        setattr(self.model, attribute, getattr(other_model, attribute))

    def evaluate_all_checkpoints(self, log_table: bool = True, save_dictionary: bool = True):
        """Run the test split on every checkpoint of the run (``last*`` excluded).

        Args:
            log_table: log the results as a wandb table and a table artifact.
            save_dictionary: write the results to
                ``run_directory/test_dictionary.json``.

        Returns:
            ``(table, dictionary)``. ``table`` is a ``wandb.Table`` with columns
            ``run``, ``seed``, ``ckpt`` plus the test metrics, or None when no
            checkpoint was found. ``dictionary`` maps checkpoint file names to
            their metric dictionaries.
        """
        table = None
        columns = None
        dictionary = {}

        for ckpt in self.get_all_checkpoints(include_last=False):
            self.logger.info(f'Current checkpoint: {ckpt}')

            # test the model using current checkpoint
            curr_metrics = self.test(ckpt=str(ckpt))[0]

            if table is None:
                # the first checkpoint's metric names define the table columns
                columns = list(curr_metrics.keys())
                table = wandb.Table(columns=['run', 'seed', 'ckpt'] + columns)

            table.add_data(self.group_name, self.seed, ckpt.name, *[curr_metrics[k] for k in columns])
            dictionary[ckpt.name] = {k: curr_metrics[k] for k in columns}

        if table is None:
            self.logger.warning('No checkpoints found to evaluate')

        # an empty evaluation has no table to log
        if log_table and table is not None:
            test_table = wandb.Artifact(f'test_table_{self.run_id}', type='table')
            test_table.add(table, 'test_table')
            wandb.log({"test/test_table": table})
            wandb.log_artifact(test_table)

        if save_dictionary:
            if self.run_directory is None:
                self.logger.warning('No run directory available: test dictionary not saved')
            else:
                with open(self.run_directory / 'test_dictionary.json', 'w') as f:
                    json.dump(dictionary, f)

        return table, dictionary

    def get_all_checkpoints(self, include_path=True, include_last=False) -> List[Union[str, Path]]:
        """List the ``*.ckpt`` files of the run directory, sorted by path.

        Args:
            include_path: return full ``Path`` objects (True) or bare file names.
            include_last: keep files whose name starts with ``'last'``.

        Returns:
            The matching checkpoints; an empty list when there is no run
            directory (debug mode).
        """
        if self.run_directory is None:
            return []

        checkpoints = sorted(self.run_directory.glob('*.ckpt'))

        if not include_last:
            checkpoints = [c for c in checkpoints if not c.name.startswith('last')]

        if not include_path:
            checkpoints = [c.name for c in checkpoints]

        return checkpoints

    def sample_batch(self, which_split: str = 'train'):
        """Set up ``which_split`` and return the first batch of its dataloader."""
        self.datamodule.setup(which_split)
        batch = next(iter(self.datamodule.get_dataloader(which_split)))
        return batch

    def dry_run(self, which_split: str = 'train', num_steps: int = 1, no_grad: bool = True):
        """Run a few steps of one split, then restore the trainer and model state.

        Args:
            which_split: ``'train'``, ``'valid'``, ``'test'`` or ``'gen'``;
                ``'gen'`` samples ``64 * num_steps`` graphs from the model.
            num_steps: number of batches to run (for ``'gen'``, a multiplier
                of 64 samples).
            no_grad: disable autograd for the duration of the call.

        Raises:
            ValueError: if ``which_split`` is not one of the values above.

        The following are restored in a ``finally`` block, even if the run
        fails: grad mode, the trainer batch limits, the optimizer step counter
        (back to the previous ``global_step``), the batch-ready counter, the
        epoch progress, and the model's ``_disable_generation`` / ``debug``
        flags.
        """
        if which_split not in ('train', 'valid', 'test', 'gen'):
            raise ValueError(
                f"Unknown split '{which_split}': expected one of 'train', 'valid', 'test', 'gen'"
            )

        # save all relevant states
        grad_state = torch.is_grad_enabled()

        trn_curr_steps = self.trainer.limit_train_batches
        val_curr_steps = self.trainer.limit_val_batches
        tst_curr_steps = self.trainer.limit_test_batches
        curr_step = self.trainer.global_step
        curr_batch_step = self.trainer.fit_loop.epoch_loop.batch_progress.current.ready
        epoch_progress = deepcopy(self.trainer.fit_loop.epoch_progress.current)

        disable_generation = self.model._disable_generation
        debug_state = self.model.debug

        try:    # run the desired split in safety

            if no_grad:
                torch.set_grad_enabled(False)
            self.model._disable_generation = True
            self.model.debug = True

            if which_split == 'train':
                # temporarily limit training to num_steps batches, without validation
                self.trainer.limit_train_batches = num_steps
                self.trainer.limit_val_batches = 0
                self.trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = 0
                self.trainer.fit_loop.epoch_loop.batch_progress.current.ready = 0
                self.trainer.fit_loop.epoch_progress.reset()

                # reset trainer to restart training
                # recall you cannot set global_step
                self.trainer.fit_loop.epoch_loop.batch_progress.current.reset()

                self.fit()

            elif which_split == 'valid':
                # temporarily update the trainer
                self.trainer.limit_val_batches = num_steps

                self.validate()

            elif which_split == 'test':
                # temporarily update the trainer
                self.trainer.limit_test_batches = num_steps

                self.test()

            elif which_split == 'gen':
                # temporarily re-enable generation on the model
                self.model._disable_generation = False

                self.model.sample_batch(64 * num_steps)

        finally:    # restore torch, trainer and model states

            if no_grad:
                torch.set_grad_enabled(grad_state)

            self.trainer.limit_train_batches = trn_curr_steps
            self.trainer.limit_val_batches = val_curr_steps
            self.trainer.limit_test_batches = tst_curr_steps
            self.trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = curr_step
            self.trainer.fit_loop.epoch_loop.batch_progress.current.ready = curr_batch_step
            self.trainer.fit_loop.epoch_progress.current = epoch_progress

            self.model._disable_generation = disable_generation
            self.model.debug = debug_state

    def cleanup(self):
        """Run Python garbage collection and release cached CUDA memory."""
        gc.collect()
        torch.cuda.empty_cache()

    def close(self):
        """Finish the wandb run started by this context, if any."""
        if self.wandb_active:
            wandb.finish()
            self.wandb_active = False

    ############################################################################
    #                             UTILITY METHODS                              #
    ############################################################################

    def validate_context(self):
        """Check the resume/version combination.

        Raises:
            ContextException: if resuming without a version.

        A version given for a new (non-resumed) run only triggers a warning.
        """
        if self.resume and self.version is None:
            raise ContextException('Cannot resume run without specifying a version')
        if not self.resume and self.version is not None:
            self.logger.warning('Version specified, but not resuming run. Version will be ignored.')

    @staticmethod
    def _parse_version(name: str) -> int | None:
        """Return ``<version>`` from a folder named ``v<version>_<run_id>``, or None if it does not match."""
        prefix, sep, _ = name.partition('_')
        if sep and prefix.startswith('v') and prefix[1:].isdecimal():
            return int(prefix[1:])
        return None

    def _setup_run_directory(self, config_name: str) -> Tuple[Path | None, int | None, str | None]:
        """Resolve the run directory, version and run id for this run.

        Layout: ``CHECKPOINT_PATH/<config_name>/v<version>_<run_id>``.

        * debug mode: returns ``(None, None, None)``;
        * resuming: uses ``self.version`` (``-1`` selects the highest existing
          version) and reads the run id back from the matching folder name;
        * new run: generates a wandb run id; in persistent mode the version is
          one more than the highest existing ``v<int>_*`` folder (0 if none),
          otherwise it is 0. The version folder itself is not created here
          (only its parent, in persistent mode).

        Entries of the config folder that are not ``v<int>_*`` directories are
        ignored when looking for existing versions.

        Raises:
            ContextException: when resuming and no (or more than one) matching
                run folder exists.
        """
        if self.debug:
            return None, None, None

        run_path = Path(CHECKPOINT_PATH, config_name)

        if self.resume:
            version = self.version
            # version -1 means "resume the latest version"
            if version == -1:
                existing_versions = [
                    v for v in (self._parse_version(p.name) for p in run_path.glob('v*') if p.is_dir())
                    if v is not None
                ]
                if not existing_versions:
                    raise ContextException(f'No runs found matching the name {config_name}')
                version = max(existing_versions)

            # check that the run path exists (checking the prefix)
            matched_run_path = list(run_path.glob(f'v{version}_*'))
            if len(matched_run_path) == 0:
                raise ContextException(f'No run found matching the version {version}')
            if len(matched_run_path) > 1:
                raise ContextException(f'Multiple runs found matching the version {version}')

            # now we know the run exists, so we can get the run id
            run_path = matched_run_path[0]
            run_id = run_path.name.split('_')[-1]

        else:
            # generate run id
            run_id = wandb.util.generate_id()
            version = 0

            if self.persistent:
                # check that the run path exists, and if not, create it
                run_path.mkdir(parents=True, exist_ok=True)

                # next version = latest existing version + 1
                existing_versions = [
                    v for v in (self._parse_version(p.name) for p in run_path.iterdir() if p.is_dir())
                    if v is not None
                ]
                if existing_versions:
                    version = max(existing_versions) + 1

            # create the new run path
            run_path = Path(run_path, f'v{version}_{run_id}')

        return run_path, version, run_id

    ############################################################################
    #                          CONFIGURATION METHODS                           #
    ############################################################################

    def _configure_datamodule(
            self,
            cfg_data: DictConfig,
            cfg_removal: DictConfig
        ) -> GraphDataModule:
        """Build the ``GraphDataModule`` for the configured dataset.

        Args:
            cfg_data: data config with ``dataset`` (``name``, ``root``,
                ``download``, ``preprocess``), ``dataloader`` and
                ``datatransform`` (``sample_whole_sequence``, ``num_sequences``)
                sections.
            cfg_removal: removal-process config used to build the sampler.

        The dataloader config is either a single config containing
        ``batch_size`` or a mapping of sub-configs (presumably one per split;
        a ``train`` entry is required). A ``collate_fn`` entry in any of them
        is replaced with the result of ``resolve_collater``.
        """
        dataset_name = cfg_data['dataset']['name']

        def setup_collater(cfg_dataloader):
            # turn the configured collate_fn identifier into the actual collater
            if 'collate_fn' in cfg_dataloader:
                cfg_dataloader['collate_fn'] = resolve_collater(cfg_dataloader['collate_fn'])

        cfg_dataloader = OmegaConf.to_container(cfg_data['dataloader'])

        self.logger.info(f'Using dataset "{dataset_name}"')
        if 'batch_size' in cfg_dataloader:
            bs = cfg_dataloader["batch_size"]
            setup_collater(cfg_dataloader)
        else:
            bs = cfg_dataloader["train"]["batch_size"]
            for k in cfg_dataloader:
                setup_collater(cfg_dataloader[k])
        self.logger.info(f'Batch size: {bs}')

        # setting up data transform with graph subsampling
        cfg_datatf = cfg_data['datatransform']

        if cfg_datatf['sample_whole_sequence']:
            num_sequences = cfg_datatf['num_sequences']
            data_sampler = SubsequenceSampler.create_subsequence_sampler(
                process_config=cfg_removal,
                num_sequences=num_sequences
            )
            self.logger.info(f'Sampling {num_sequences} sequences of subgraphs from 1 datapoint')
        else:
            data_sampler = SubgraphSampler.create_subgraph_sampler(
                process_config=cfg_removal
            )
            self.logger.info(f'Sampling 1 random subgraphs from 1 datapoint')

        # the sampler itself is not applied here: only its removal process
        # prepares each datapoint after it is made undirected
        postprocess_pl = Compose([
            MyToUndirected(),
            PrepareData(data_sampler.removal_process)
        ])

        # get download and preprocess pipelines (a null config means no kwargs)
        download_pl_kwargs = cfg_data['dataset']['download']
        if download_pl_kwargs is None:
            download_pl_kwargs = {}
        preprocess_pl_kwargs = cfg_data['dataset']['preprocess']
        if preprocess_pl_kwargs is None:
            preprocess_pl_kwargs = {}

        # create datamodule
        datamodule = GraphDataModule(
            root_dir=osp.join(DATASETS_ROOT, cfg_data['dataset']['root']),
            download_pl=ppl.REGISTERED_DOWNLOAD_PIPELINES[dataset_name](**download_pl_kwargs),
            preprocess_pl=ppl.REGISTERED_PREPROCESS_PIPELINES[dataset_name](**preprocess_pl_kwargs),
            postprocess_pl=postprocess_pl,
            dataloader_config=cfg_dataloader
        )

        return datamodule

    def _configure_inference_samples_converter(self, datatype: str, dataset_infos) -> Callable:
        """Return the converter applied to generated graphs, or None.

        Only the ``'molecular'`` datatype has a converter: a
        ``GraphToMoleculeConverter`` built from ``dataset_infos['atom_types']``
        and ``dataset_infos['bond_types']``.
        """
        if datatype == 'molecular':
            # get graph to molecule converter for transforming graphs into molecules
            return GraphToMoleculeConverter(
                atom_decoder=dataset_infos['atom_types'],
                bond_decoder=dataset_infos['bond_types']
            )
        return None

    def _configure_sampling_metrics(
            self,
            datatype: str,
            cfg_metrics: DictConfig,
            datamodule: GraphDataModule,
            inference_samples_converter: Callable
        ) -> Union[SamplingMetricsHandler, nn.ModuleDict]:
        """Build the sampling metrics handler(s) used during generation.

        The train and test splits are set up with transforms disabled before
        the handlers are built, and the datamodule's datasets are cleared
        afterwards (also when building fails).

        If ``cfg_metrics`` has a ``valid`` or ``test`` entry, one handler is
        built per top-level entry and returned in an ``nn.ModuleDict`` keyed
        ``'_<entry>'``; otherwise a single handler is built from the whole
        ``cfg_metrics``. ``datatype`` is currently unused.
        """
        datamodule.setup('train', disable_transform=True)
        datamodule.setup('test', disable_transform=True)

        def configure_split_samp_metr(curr_cfg_metrics):
            return SamplingMetricsHandler(
                datamodule=datamodule,
                generation_cfg=curr_cfg_metrics['generation'],
                metrics_cfg=curr_cfg_metrics['metrics'],
                samples_converter=inference_samples_converter
            )

        try:
            if 'valid' in cfg_metrics or 'test' in cfg_metrics:
                metrics = nn.ModuleDict({
                    f'_{split}': configure_split_samp_metr(curr_cfg) for split, curr_cfg in cfg_metrics.items()
                })
            else:
                metrics = configure_split_samp_metr(cfg_metrics)
        finally:
            datamodule.clear_datasets()

        return metrics

    def _configure_model(
            self,
            datatype: str,
            cfg_model: DictConfig,
            cfg_training: DictConfig,
            cfg_metrics: DictConfig,
            datamodule: GraphDataModule,
            dataset_infos: Dict
        ):
        """Build the ``ReinsertionDenoisingModel`` with its converter, metrics and logger.

        The model runs in debug mode when the context is in debug mode or is
        not persistent. Its console logger is the ``'generator'`` logger at
        the same level as this context's logger.
        """
        # get graph to molecule converter for transforming graphs into molecules
        inference_samples_converter = self._configure_inference_samples_converter(
            datatype=datatype,
            dataset_infos=dataset_infos
        )
        self.logger.info(f'Generated graphs are transformed by: {type(inference_samples_converter)}')

        # get sampling metrics
        sampling_metrics = self._configure_sampling_metrics(
            datatype=datatype,
            cfg_metrics=cfg_metrics,
            datamodule=datamodule,
            inference_samples_converter=inference_samples_converter
        )
        self.logger.info(f'Sampling metrics: {type(sampling_metrics).__name__}')

        # setup console logger
        console_logger = logging.getLogger('generator')
        console_logger.setLevel(self.logger.level)
        console_logger = setup_logger(console_logger)

        # create model
        model = ReinsertionDenoisingModel(
            architecture_config=cfg_model['architecture'],
            diffusion_config=cfg_model['diffusion'],
            removal_config=cfg_model['removal'],
            dataset_info=dataset_infos,
            run_config=cfg_training,
            sampling_metrics=sampling_metrics,
            inference_samples_converter=inference_samples_converter,
            console_logger=console_logger,
            conditional_generator=cfg_model['conditional'],
            debug=self.debug or (not self.persistent)
        )

        return model

    def _configure_checkpoint(self, cfg_checkpoint: DictConfig) -> ModelCheckpoint:
        """Build the checkpoint callback writing into the run directory.

        In persistent, non-debug runs the callback keeps a "last" checkpoint
        plus whatever ``cfg_checkpoint`` requests. Otherwise it is configured
        never to save (``save_last=False``, ``save_top_k=0``) and
        ``cfg_checkpoint`` is ignored.
        """
        if not self.debug and self.persistent:
            return ModelCheckpoint(
                dirpath=self.run_directory,
                filename=None,
                save_last=True,
                **cfg_checkpoint
            )

        # callback that never saves anything: disabling both "last" and
        # top-k saving is enough
        return ModelCheckpoint(
            dirpath=self.run_directory,
            filename=None,
            save_last=False,
            save_top_k=0
        )

    def _configure_logger(self, cfg_logger: DictConfig):
        """Start a wandb run; only in persistent, non-debug runs, and at most once.

        The run is named after ``config_name``, uses ``run_id`` as id, is
        resumed when ``resume`` is set, records the resolved config, and
        receives extra ``wandb.init`` arguments from ``cfg_logger['wandb']``.
        """
        if not self.debug and self.persistent and not self.wandb_active:
            wandb.init(
                name=self.config_name,
                resume=self.resume,
                id=self.run_id,
                config=OmegaConf.to_container(self.cfg),
                **cfg_logger['wandb']
            )
            self.wandb_active = True

    def _configure_trainer(
            self,
            cfg_run: DictConfig,
            cfg_platform: DictConfig
        ) -> Trainer:
        """Build the Lightning ``Trainer`` (and, as side effects, checkpointing and wandb).

        * a checkpoint callback is added only if ``cfg_run`` has a
          ``checkpoint`` entry; wandb is started only if it has ``logger``
          (both are no-ops in debug / non-persistent runs);
        * GPUs are used when ``gpus_available(cfg_platform)``; more than one
          GPU selects the ``'ddp'`` strategy;
        * ``cfg_run['running_test']`` limits every split to 20 batches;
        * with ``profile`` enabled, an ``AdvancedProfiler`` writes
          ``./perf_logs*`` (previous files are removed first).
        """
        callbacks = []

        # if debug is activated or persistent is deactivated, checkpointing and logging are deactivated!
        if 'checkpoint' in cfg_run:
            callbacks.append(self._configure_checkpoint(cfg_run['checkpoint']))

        if 'logger' in cfg_run:
            self._configure_logger(cfg_run['logger'])

        # configure trainer
        gpus_ok = gpus_available(cfg_platform)
        gpus_num = cfg_platform['gpus'] if gpus_ok else 0

        self.logger.info(f'Using GPU: {gpus_ok}, N={gpus_num}')
        self.logger.info(f'Number of epochs: {cfg_run["trainer"]["max_epochs"]}')

        if self.profile:
            # imported lazily: only needed (and only required to exist) when profiling
            from pytorch_lightning.profiler import AdvancedProfiler
            # remove file if exists
            for f in glob.glob('perf_logs*'):
                os.remove(f)
            profiler = AdvancedProfiler(dirpath=".", filename="perf_logs")
        else:
            profiler = None

        cfg_trainer = cfg_run['trainer']
        limit_batches = 20 if cfg_run['running_test'] else None

        # build trainer
        trainer = Trainer(
            # training
            max_epochs=cfg_trainer['max_epochs'],

            # validation
            val_check_interval=cfg_trainer['val_check_interval'],
            check_val_every_n_epoch=cfg_trainer['check_val_every_n_epoch'],
            num_sanity_val_steps=cfg_trainer['num_sanity_val_steps'],

            # testing
            limit_train_batches=limit_batches,
            limit_val_batches=limit_batches,
            limit_test_batches=limit_batches,

            # computing devices
            accelerator='gpu' if gpus_ok else 'cpu',
            devices=gpus_num if gpus_ok else None,
            strategy='ddp' if gpus_num > 1 else None,

            # visualization and debugging
            fast_dev_run=self.debug,
            enable_progress_bar=cfg_trainer['enable_progress_bar'],

            # logging and checkpointing
            logger=False,

            # callbacks (None when no callback was configured)
            callbacks=callbacks or None,

            # for network debugging in case of NaNs
            profiler=profiler,
            detect_anomaly=False
        )

        return trainer


################################################################################
#                              UTILITY FUNCTIONS                               #
################################################################################

def preprocess_config(cfg: DictConfig) -> DictConfig:
    """Resolve every interpolation in ``cfg`` and drop the Hydra node.

    When ``cfg`` carries Hydra's own config (``cfg.hydra``), it is first
    reduced to ``runtime.choices``, so that interpolations using the Hydra
    config-group choices can still be resolved. The whole ``hydra`` node is
    then removed from the result. Configs without a ``hydra`` node (e.g.
    loaded with ``OmegaConf.load``) are simply resolved.

    Note:
        When present, ``cfg.hydra`` is replaced in the caller's object.

    Returns:
        A new, fully resolved ``DictConfig`` without the ``hydra`` key.
    """
    if 'hydra' in cfg:
        # resolve configuration interpolations which are using hydra choices
        choices = cfg.hydra.runtime.choices
        cfg.hydra = OmegaConf.create({'runtime': {'choices': choices}})

    resolved = OmegaConf.to_container(cfg, resolve=True)

    # remove hydra configuration
    resolved.pop('hydra', None)

    return OmegaConf.create(resolved)


################################################################################
#                           CONFIGURATION FUNCTIONS                            #
################################################################################

def setup_logger(logger: logging.Logger|int = None):
    if logger is None:
        level = logging.INFO
    elif isinstance(logger, int):
        level = logger
        logger = None
    elif isinstance(logger, logging.Logger):
        level = logger.level

    if logger is None:
        logger = logging.getLogger('configurator')
        logger.setLevel(level=level)
        
    # remove all handlers
    logger.handlers = []
    # set format of logger
    formatter = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
    # create console handler
    ch = logging.StreamHandler()
    ch.setLevel(level=level)
    ch.setFormatter(formatter)
    # add console handler to logger
    logger.addHandler(ch)

    return logger


def gpus_available(platform_config):
    """Tell whether CUDA is usable and the platform config asks for at least one GPU.

    ``platform_config['gpus']`` is only looked up when CUDA is available.
    """
    if not torch.cuda.is_available():
        return False
    return platform_config['gpus'] > 0
    