""" Solver module. """

from abc import ABCMeta, abstractmethod
from ..model.network import Network
import lightning.pytorch as pl
from ..utils import check_consistency
from ..problem import AbstractProblem
import torch


class SolverInterface(pl.LightningModule, metaclass=ABCMeta):
    """ Solver base class. """
    def __init__(self,
                 models,
                 problem,
                 optimizers,
                 optimizers_kwargs,
                 extra_features=None):
        """
        Validate the inputs, wrap each model in a ``Network`` and build one
        optimizer per model.

        :param models: A torch neural network model instance, or a list of
            model instances (one per optimizer).
        :type models: torch.nn.Module | list(torch.nn.Module)
        :param problem: A problem definition instance.
        :type problem: AbstractProblem
        :param optimizers: An optimizer class (a subclass of
            ``torch.optim.Optimizer``, not an instance), or a list of such
            classes with one entry per model. Each class is instantiated here
            with the parameters of the corresponding wrapped model.
        :type optimizers: type | list(type)
        :param optimizers_kwargs: Keyword arguments used to instantiate the
            optimizer, or a list of such dictionaries with one entry per
            optimizer.
        :type optimizers_kwargs: dict | list(dict)
        :param list(torch.nn.Module) extra_features: the additional input
            features to use as augmented input. If ``None`` no extra features
            are passed. If it is a list of ``torch.nn.Module``, the extra
            feature list is passed to all models. If it is a list of extra
            features' lists, each single list of extra features is passed to
            the model at the same position.
        :raises ValueError: if the number of optimizers differs from the
            number of models, if the number of keyword-argument dictionaries
            differs from the number of optimizers, or if a list of extra
            features' lists does not have one entry per model.
        """
        super().__init__()

        # check consistency of the inputs
        check_consistency(models, torch.nn.Module)
        check_consistency(problem, AbstractProblem)
        check_consistency(optimizers, torch.optim.Optimizer, subclass=True)
        check_consistency(optimizers_kwargs, dict)

        # put everything in a list if only one input
        if not isinstance(models, list):
            models = [models]
        if not isinstance(optimizers, list):
            optimizers = [optimizers]
        # Wrap a lone dict independently of ``optimizers``: ``len`` of a bare
        # dict counts its keys, and positional indexing into it would fail.
        if isinstance(optimizers_kwargs, dict):
            optimizers_kwargs = [optimizers_kwargs]

        # number of models and optimizers
        len_model = len(models)
        len_optimizer = len(optimizers)
        len_optimizer_kwargs = len(optimizers_kwargs)

        # check length consistency optimizers
        if len_model != len_optimizer:
            raise ValueError('You must define one optimizer for each model. '
                             f'Got {len_model} models, and {len_optimizer}'
                             ' optimizers.')

        # check length consistency optimizers kwargs
        if len_optimizer_kwargs != len_optimizer:
            raise ValueError('You must define one dictionary of keyword'
                             ' arguments for each optimizer. '
                             f'Got {len_optimizer} optimizers, and'
                             f' {len_optimizer_kwargs} dictionaries.')

        # extra features handling
        if extra_features is None:
            extra_features = [None] * len_model
        elif (len(extra_features) == 0
              or not isinstance(extra_features[0], (tuple, list))):
            # a single (possibly empty) list of extra features: the same
            # list is shared by every model
            extra_features = [extra_features] * len_model
        elif len(extra_features) != len_model:
            # a list of extra features' lists must have one entry per model
            raise ValueError('You passed a list of extra features lists with '
                             'length different from the number of models. '
                             f'Expected {len_model}, got '
                             f'{len(extra_features)}. If you want to use '
                             'the same list of extra features for all models, '
                             'just pass a list of extra features and not a '
                             'list of lists of extra features.')

        # assigning model and optimizers
        # NOTE(review): these are plain Python lists, so the wrapped models
        # are not registered as submodules of this LightningModule through
        # them (a ``torch.nn.ModuleList`` would be). Confirm that subclasses
        # register the models themselves if device moves / checkpoints rely
        # on it.
        self._pina_models = []
        self._pina_optimizers = []

        for model, extra, optimizer, kwargs in zip(models, extra_features,
                                                   optimizers,
                                                   optimizers_kwargs):
            model_ = Network(model=model, extra_features=extra)
            # each optimizer only sees the parameters of its own model
            optim_ = optimizer(model_.parameters(), **kwargs)
            self._pina_models.append(model_)
            self._pina_optimizers.append(optim_)

        # assigning problem
        self._pina_problem = problem

    @abstractmethod
    def forward(self):
        """Forward pass; must be implemented by concrete solvers."""
        pass

    @abstractmethod
    def training_step(self):
        """Training step; must be implemented by concrete solvers."""
        pass

    @abstractmethod
    def configure_optimizers(self):
        """Optimizer configuration; must be implemented by concrete solvers."""
        pass

    @property
    def models(self):
        """
        The list of models, each wrapped in a ``Network`` together with its
        extra features."""
        return self._pina_models

    @property
    def optimizers(self):
        """
        The list of optimizers, one per model, in the same order as
        ``models``."""
        return self._pina_optimizers

    @property
    def problem(self):
        """
        The problem formulation."""
        return self._pina_problem