

import torch
import numpy as np
from abc import ABC, abstractmethod


class ConstrainedLearningProblem(ABC):
    """Base class for constrained learning problems handled through their Lagrangian.

    Subclasses populate the attributes given defaults in ``__init_subclass__``:

    * ``obj_function(batch_idx)`` -- objective on a batch; returns a scalar.
    * ``constraints`` / ``rhs`` -- average constraints; each ``ell(batch_idx)``
      returns a scalar average value and its slack is ``ell(batch_idx) - c``.
    * ``pointwise`` / ``pointwise_rhs`` -- per-sample constraints; each
      ``ell(batch_idx)`` returns one value per data point in the batch, and each
      ``c`` is indexed with the same ``batch_idx``.
    * ``lambdas`` / ``mus`` -- multipliers applied to the average and pointwise
      slacks respectively when building the Lagrangian.
    * ``data_size`` -- number of data points.

    ``batch_idx`` is a ``slice``, a ``range`` or a NumPy index array, depending on
    the calling method.
    """

    def __init_subclass__(cls, *args, **kwargs):
        # Install default attributes on every subclass.
        # NOTE(review): these are *class* attributes, so the lists below are shared
        # by all instances of a subclass unless an instance rebinds them
        # (``self.constraints = [...]``). Mutating them in place (``append``)
        # affects every instance.
        cls.model = NotImplemented
        cls.parameters = NotImplemented
        cls.data = NotImplemented
        cls.data_size = NotImplemented

        # Takes batch indices, returns a scalar.
        cls.obj_function = NotImplemented

        # Each constraint takes batch indices and returns a scalar average value.
        cls.constraints = []
        cls.rhs = []

        # Each pointwise constraint takes batch indices and returns a vector with
        # one element per data point.
        cls.pointwise = []
        cls.pointwise_rhs = []

        cls.lambdas = None
        cls.mus = None

        cls._solved = False

        cls.models = []

        super().__init_subclass__(*args, **kwargs)

    # Lagrangian
    def lagrangian(self, batch_idx=None):
        """Evaluate the Lagrangian on a batch.

        Computes ``obj + sum_i lambda_i * (ell_i - c_i) + sum_j <mu_j, ell_j - c_j>``
        with every term evaluated on ``batch_idx``.

        Args:
            batch_idx: batch indices; ``None`` means the whole dataset.

        Returns:
            The Lagrangian value. Its type is whatever ``obj_function`` and the
            slacks produce; for tensors this is typically a scalar tensor.
        """
        if batch_idx is None:
            batch_idx = slice(None, None, None)

        L = self.obj_function(batch_idx)
        # Out-of-place additions: ``+=`` would modify the tensor returned by
        # obj_function in place, which corrupts caller-owned tensors and can break
        # autograd (leaf tensors, outputs saved for backward).
        for slack in self._constraint_slacks(batch_idx, dualized=True):
            L = L + slack
        for slack in self._pointwise_slacks(batch_idx, dualized=True):
            L = L + slack

        return L

    # Evaluate slacks
    def _constraint_slacks(self, batch_idx, dualized=False):
        """Return one slack ``ell(batch_idx) - c`` per average constraint.

        With ``dualized=True`` each slack is multiplied by its entry in ``lambdas``.
        """
        # With no constraints there is nothing to evaluate. Returning early also
        # avoids zip() over ``lambdas`` while it still holds its default None.
        if not self.constraints:
            return []
        if dualized:
            return [lambda_value * (ell(batch_idx) - c)
                    for lambda_value, ell, c in zip(self.lambdas, self.constraints, self.rhs)]
        return [ell(batch_idx) - c for ell, c in zip(self.constraints, self.rhs)]

    def _pointwise_slacks(self, batch_idx, dualized=False):
        """Return one per-sample slack vector ``ell(batch_idx) - c[batch_idx]`` per pointwise constraint.

        With ``dualized=True`` each vector is reduced to its dot product with the
        matching entries ``mu[batch_idx]`` of ``mus``.
        """
        # Same early exit as _constraint_slacks: ``mus`` may still be None.
        if not self.pointwise:
            return []
        if dualized:
            return [torch.dot(mu_value[batch_idx], ell(batch_idx) - c[batch_idx])
                    for mu_value, ell, c in zip(self.mus, self.pointwise, self.pointwise_rhs)]
        return [ell(batch_idx) - c[batch_idx] for ell, c in zip(self.pointwise, self.pointwise_rhs)]

    def _batch_boundaries(self, batch_size):
        """Return the ``[start, end)`` boundaries that split ``data_size`` points into batches.

        ``batch_size=None`` gives a single batch ``[0, data_size]``. Otherwise the
        boundaries are ``0, batch_size, 2*batch_size, ...``, and ``data_size`` is
        appended when the last batch is partial.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size is None:
            return [0, self.data_size]
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        boundaries = list(range(0, self.data_size + 1, batch_size))
        if boundaries[-1] < self.data_size:
            boundaries.append(self.data_size)
        return boundaries

    def primal_batch(self, batch_size=None, shuffle=True, num_workers=4):
        """Yield the Lagrangian for each batch of one pass over the data.

        Args:
            batch_size: points per batch; ``None`` means one batch with all points.
            shuffle: if True, batch indices come from a permutation drawn with
                NumPy's global RNG (``np.random.permutation``); otherwise they
                are in data order.
            num_workers: unused; kept for backward compatibility.

        Yields:
            ``(lagrangian_value, number_of_points_in_batch)``.
        """
        boundaries = self._batch_boundaries(batch_size)

        if shuffle:
            idx_epoch = np.random.permutation(self.data_size)
        else:
            idx_epoch = range(self.data_size)

        for batch_start, batch_end in zip(boundaries, boundaries[1:]):
            yield self.lagrangian(idx_epoch[batch_start:batch_end]), batch_end - batch_start

    def objective(self, batch_size=None, num_workers=None):
        """Evaluate the objective over the whole dataset.

        Args:
            batch_size: if ``None``, call ``obj_function`` once on all data.
                Otherwise evaluate consecutive (unshuffled) batches and combine them
                as an average weighted by batch length.
            num_workers: unused; kept for backward compatibility.
        """
        if batch_size is None:
            return self.obj_function(slice(None, None, None))

        obj_value = 0
        boundaries = self._batch_boundaries(batch_size)
        for batch_start, batch_end in zip(boundaries, boundaries[1:]):
            batch_value = self.obj_function(range(batch_start, batch_end))
            obj_value = obj_value + batch_value * (batch_end - batch_start) / self.data_size

        return obj_value

    def slacks(self, batch_size=None, num_workers=None):
        """Evaluate all constraint slacks over the whole dataset, without multipliers.

        Args:
            batch_size: if ``None``, evaluate on all data at once. Otherwise use
                consecutive (unshuffled) batches.
            num_workers: unused; kept for backward compatibility.

        Returns:
            ``(constraint_values, pointwise_values)``:

            * ``constraint_values`` holds one slack per average constraint. With
              batching, each is the average of the batch slacks weighted by
              batch length.
            * ``pointwise_values`` holds one per-sample slack vector per pointwise
              constraint, concatenated in data order.
        """
        if batch_size is None:
            everything = slice(None, None, None)
            constraint_values = self._constraint_slacks(everything, dualized=False)
            pointwise_values = self._pointwise_slacks(everything, dualized=False)
            return constraint_values, pointwise_values

        constraint_values = [0] * len(self.constraints)
        pointwise_pieces = [[] for _ in self.pointwise]
        boundaries = self._batch_boundaries(batch_size)
        for batch_start, batch_end in zip(boundaries, boundaries[1:]):
            batch = range(batch_start, batch_end)
            for ii, s in enumerate(self._constraint_slacks(batch, dualized=False)):
                constraint_values[ii] = constraint_values[ii] + s * (batch_end - batch_start) / self.data_size
            for ii, s in enumerate(self._pointwise_slacks(batch, dualized=False)):
                pointwise_pieces[ii].append(s)

        # Concatenate once at the end. Calling torch.cat inside the loop recopies
        # the accumulated vector on every batch (quadratic in the number of batches).
        pointwise_values = [torch.cat(pieces) if pieces else torch.zeros([0])
                            for pieces in pointwise_pieces]

        return constraint_values, pointwise_values
    