import os
from abc import ABCMeta, abstractmethod

import numpy as np

import torch
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

from ...data.spectrum_dataset import SpectrumDataset
from ...data.spectrum_transform import ComposeSpectrumTransform
from ...data.spectrum_transform import AugmentationSpectrumTransform
from .property_prediction_model import PropertyPredictionModel

def filename_is_ckpt(filename):
    """Return True if *filename* has a ``.ckpt`` extension."""
    return os.path.splitext(filename)[1] == '.ckpt'


def get_filename_epoch(filename):
    """Return the epoch number encoded in a checkpoint filename.

    Expects names of the form ``epoch=<N>-...`` (e.g. ``epoch=3-step=120.ckpt``):
    the text before the first '-' is split on '=' and its second field is parsed
    as an int. Other naming schemes raise IndexError/ValueError.
    """
    return int(filename.split('-')[0].split('=')[1])

class Model(metaclass=ABCMeta):
    """Abstract training/inference driver for spectrum property-prediction models.

    Subclasses implement the three ``_setup_*`` hooks. :meth:`setup` calls them and
    then wraps ``self._base_net`` in a :class:`PropertyPredictionModel`, which the
    training, prediction, evaluation and checkpoint-loading methods operate on.
    """

    def __init__(
            self, train_data_path, test_data_path, model_path, num_epochs, batch_size,
            num_batch_per_update, num_workers, augment_spectrum, prefetch_factor, optimizer,
            machine_params, wandb_args, gradient_clip_val, precision):
        """Store configuration; no I/O or model construction happens here.

        Args:
            train_data_path: Path loaded by ``get_dataset('train')``.
            test_data_path: Path loaded by ``get_dataset('test')``.
            model_path: Directory where checkpoints are written and searched for.
            num_epochs: ``max_epochs`` for the training ``pl.Trainer``.
            batch_size: Default data-loader batch size.
            num_batch_per_update: ``accumulate_grad_batches`` for training.
            num_workers: Default data-loader worker count.
            augment_spectrum: Stored only; not read by this class (presumably used by
                subclasses when building transforms — confirm).
            prefetch_factor: Default data-loader prefetch factor (used only when
                ``num_workers > 0``).
            optimizer: Passed to ``PropertyPredictionModel``.
            machine_params: Extra keyword arguments forwarded to every ``pl.Trainer``.
            wandb_args: Keyword arguments for ``WandbLogger``.
            gradient_clip_val: ``gradient_clip_val`` for training.
            precision: ``precision`` for every ``pl.Trainer``.
        """
        self._train_data_path = train_data_path
        self._test_data_path = test_data_path
        self._model_path = model_path
        self._num_epochs = num_epochs
        self._batch_size = batch_size
        self._num_batch_per_update = num_batch_per_update
        self._num_workers = num_workers
        self._augment_spectrum = augment_spectrum
        self._prefetch_factor = prefetch_factor
        self._optimizer = optimizer
        self._machine_params = machine_params
        self._wandb_args = wandb_args
        self._gradient_clip_val = gradient_clip_val
        self._precision = precision
        # Filled in by setup(): _base_net by the subclass hook, _base_model afterwards.
        self._base_net = None
        self._base_model = None
        # Dict keyed by transform type ('train', 'test', 'prediction', ...).
        self._spectrum_transform = None
        # Dict with 'labeled' and 'unlabeled' collate functions.
        self._spectra_collate = None

    def setup(self):
        """Run the subclass setup hooks, then wrap the network for Lightning.

        ``_setup_base_model`` is expected to assign ``self._base_net``, which is
        read immediately afterwards.
        """
        self._setup_spectrum_transform()
        self._setup_spectra_collate()
        self._setup_base_model()

        self._base_model = PropertyPredictionModel(self._base_net, self._optimizer)

    def get_dataset(
            self, datatype_source, spectrum_transform_type=None, keep_idx=None,
            filter=None):
        """Load the configured 'train' or 'test' dataset.

        Args:
            datatype_source: ``'train'`` or ``'test'``; selects the data path.
            spectrum_transform_type: Key into the transform dict; defaults to
                *datatype_source*.
            keep_idx: Forwarded to ``SpectrumDataset.load``.
            filter: Forwarded to ``SpectrumDataset.load``.

        Raises:
            ValueError: if *datatype_source* is not ``'train'`` or ``'test'``.
        """
        data_paths = {'train': self._train_data_path, 'test': self._test_data_path}
        if datatype_source not in data_paths:
            raise ValueError(
                "datatype_source must be 'train' or 'test', got {!r}".format(datatype_source))

        if spectrum_transform_type is None:
            spectrum_transform_type = datatype_source

        return SpectrumDataset.load(
            data_paths[datatype_source],
            transform=self._spectrum_transform[spectrum_transform_type],
            keep_idx=keep_idx, filter=filter)

    def create_dataset(
            self, spectra, spectrum_transform_type='prediction', keep_idx=None, filter=None):
        """Build an in-memory ``SpectrumDataset`` from *spectra* using the selected transform."""
        return SpectrumDataset(
            spectra, transform=self._spectrum_transform[spectrum_transform_type], keep_idx=keep_idx,
            filter=filter)

    def load_dataset(
        self, data_path, spectrum_transform_type='prediction', keep_idx=None, filter=None):
        """Load a ``SpectrumDataset`` from an arbitrary *data_path* using the selected transform."""
        return SpectrumDataset.load(
            data_path, transform=self._spectrum_transform[spectrum_transform_type],
            keep_idx=keep_idx, filter=filter)

    def get_data_loader(
            self, dataset, batch_size=None, shuffle=None, num_workers=None, prefetch_factor=None,
            drop_last=None):
        """Wrap *dataset* in a ``DataLoader``, picking the collate function from its first item.

        Items are expected to be ``(contents, _)`` pairs. A bare ``np.ndarray`` as contents
        selects the ``'unlabeled'`` collate function; anything else selects ``'labeled'``.
        ``None`` arguments fall back to the configured defaults (``shuffle`` and
        ``drop_last`` default to False).
        """
        # Peek at the first item to decide which collate function applies.
        dataset_contents, _ = dataset[0]
        collate_key = 'unlabeled' if isinstance(dataset_contents, np.ndarray) else 'labeled'

        num_workers = self._num_workers if num_workers is None else num_workers
        extra_kwargs = {}
        if num_workers > 0:
            # DataLoader rejects an explicit prefetch_factor when loading in the main
            # process (num_workers == 0), so only forward it when workers are used.
            extra_kwargs['prefetch_factor'] = (
                self._prefetch_factor if prefetch_factor is None else prefetch_factor)

        return DataLoader(
            dataset,
            batch_size=self._batch_size if batch_size is None else batch_size,
            shuffle=False if shuffle is None else shuffle,
            collate_fn=self._spectra_collate[collate_key],
            num_workers=num_workers,
            drop_last=False if drop_last is None else drop_last,
            **extra_kwargs)

    def train(self):
        """Fit the model on the train split, validating on the test split.

        Every epoch's checkpoint is kept in ``model_path``. If that directory already
        holds checkpoints, training resumes from the one with the highest epoch.
        """
        # setup dataloaders
        train_data_loader = self.get_data_loader(
            self.get_dataset('train'), shuffle=True, drop_last=True)
        test_data_loader = self.get_data_loader(
            self.get_dataset('test'), shuffle=False, drop_last=False)

        # save_top_k=-1 keeps a checkpoint for every epoch
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            dirpath=self._model_path, save_top_k=-1)

        wandb_logger = WandbLogger(**self._wandb_args)

        trainer = pl.Trainer(
            accumulate_grad_batches=self._num_batch_per_update,
            max_epochs=self._num_epochs,
            gradient_clip_val=self._gradient_clip_val,
            precision=self._precision,
            logger=wandb_logger,
            callbacks=[checkpoint_callback],
            **self._machine_params
        )

        try:
            ckpt_path = self._checkpoint_path()
        except FileNotFoundError:
            # Nothing to resume from: start a newly initialized model.
            ckpt_path = None

        trainer.fit(
            model=self._base_model, train_dataloaders=train_data_loader,
            val_dataloaders=test_data_loader, ckpt_path=ckpt_path)

    def predict_properties(self, data):
        """Predict properties for dataloader *data* and undo the target scaling.

        Returns:
            ``np.ndarray`` of float64 predictions, one row block per batch, stacked
            vertically and passed through the train transform's inverse scaler.
        """
        trainer = pl.Trainer(precision=self._precision, **self._machine_params)

        properties = np.vstack(trainer.predict(model=self._base_model, dataloaders=data))
        # np.float was an alias for the builtin float (float64) and no longer exists
        # in NumPy >= 1.24.
        properties = properties.astype(np.float64)
        # NOTE(review): relies on the 'train' transform exposing a fitted private
        # `_scaler`; presumably targets were scaled during training — confirm.
        properties = self._spectrum_transform['train']._scaler.inverse_transform(properties)

        return properties

    def eval_loss(self, data, limit_test_batches=1.0):
        """Run ``trainer.test`` on *data* and return the ``'test_loss'`` of the first result.

        *limit_test_batches* is forwarded to ``pl.Trainer`` to evaluate on a subset.
        """
        trainer = pl.Trainer(
            precision=self._precision, limit_test_batches=limit_test_batches,
            **self._machine_params)

        return trainer.test(model=self._base_model, dataloaders=data)[0]['test_loss']

    def load_base_model(self, epoch=None):
        """Replace ``self._base_model`` with a model restored from a checkpoint.

        Args:
            epoch: Epoch whose checkpoint to load; ``None`` loads the highest epoch.

        Raises:
            FileNotFoundError: if no matching checkpoint exists in ``model_path``.
        """
        ckpt_path = self._checkpoint_path(epoch)
        # load_from_checkpoint is a classmethod that returns a new instance; it does
        # not modify the object it is called on, so the result must be stored.
        self._base_model = PropertyPredictionModel.load_from_checkpoint(
            ckpt_path, base_net=self._base_net, optimizer=self._optimizer)

    def _checkpoint_filenames(self):
        """Return names of ``.ckpt`` files in ``model_path`` (empty if it is not a directory)."""
        if not os.path.isdir(self._model_path):
            return []
        return [name for name in os.listdir(self._model_path) if filename_is_ckpt(name)]

    def _checkpoint_path(self, epoch=None):
        """Return the full path of the checkpoint for *epoch*, or of the highest epoch if None.

        Raises:
            FileNotFoundError: if no matching checkpoint exists in ``model_path``.
        """
        filenames = self._checkpoint_filenames()
        if epoch is not None:
            # The trailing '-' keeps e.g. epoch=1 from matching 'epoch=11-...'.
            tag = 'epoch={}-'.format(epoch)
            filenames = [name for name in filenames if tag in name]
        if not filenames:
            raise FileNotFoundError(
                'no checkpoint{} found in {!r}'.format(
                    '' if epoch is None else ' for epoch {}'.format(epoch),
                    self._model_path))

        if epoch is None:
            ckpt_filename = max(filenames, key=get_filename_epoch)
        else:
            ckpt_filename = filenames[0]
        return os.path.join(self._model_path, ckpt_filename)

    @abstractmethod
    def _setup_spectrum_transform(self):
        """Populate ``self._spectrum_transform`` (dict keyed by transform type)."""

    @abstractmethod
    def _setup_spectra_collate(self):
        """Populate ``self._spectra_collate`` with 'labeled' and 'unlabeled' collate functions."""

    @abstractmethod
    def _setup_base_model(self):
        """Populate ``self._base_net``, which ``setup`` wraps in a PropertyPredictionModel."""
