

from abc import ABC, abstractmethod

import numpy as np

import torch

import matplotlib.pyplot as plot

import logging

import copy

## Define default solver settings
def settings(**kwargs):
    """Build the solver settings dictionary.

    Every keyword argument overrides (or adds) the setting of the same name.
    Two settings are derived when they are not given explicitly:

    * ``verbose``: period, in iterations, of INFO-level progress messages.
      Defaults to one tenth of ``iterations`` as a positive integer. The
      solver tests ``iteration % verbose == 0``, so a float period such as
      0.7 or 1.5 would silently skip most messages.
    * ``logger``: this module's logger. A console handler is added if neither
      it nor its ancestors have handlers, and its level is set to INFO.

    Returns:
        dict mapping setting names to values.
    """
    all_settings = {
        'iterations': 100,
        'batch_size': None,
        'lr_p0': 0.001,
        'lr_p_decay': None,
        'lr_p_period': None,
        'lr_d0': 0.1,
        'lr_d_decay': None,
        'lr_d_period': None,
        'lambdas0': 1,
        'mus0': 1,
        'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
        'STOP_DIVERGENCE': 1e4,
        'STOP_ABS_PVAL': None,
        'STOP_ABS_DGAP': None,
        'STOP_REL_DGAP': None,
        'STOP_ABS_PFEAS': None,
        'STOP_REL_PFEAS': None,
        'STOP_PGRAD': None,
        'STOP_DGRAD': None,
        'STOP_NO_UPDATE': None,
        'STOP_USER_DEFINED': (lambda state_dict: False)
    }

    # User-supplied values take precedence over the defaults
    all_settings.update(kwargs)

    if 'verbose' not in all_settings:
        # Integer period >= 1 so the modulo test in the solver matches exactly
        all_settings['verbose'] = max(1, int(all_settings['iterations']) // 10)

    if 'logger' not in all_settings:
        all_settings['logger'] = logging.getLogger(__name__)

        # Add stream handler (console), if not defined already
        if not all_settings['logger'].hasHandlers():
            all_settings['logger'].addHandler(logging.StreamHandler())

        # Set level to info
        all_settings['logger'].setLevel(logging.INFO)

    return all_settings




class PrimalDual(ABC):
    """Base class for primal-dual solvers.

    Subclasses implement :meth:`primal`, which takes one primal descent step
    on ``problem`` and returns ``(primal_value, primal_grad_norm)``.
    :meth:`solve` alternates that step with a dual ascent step (:meth:`dual`)
    whenever the problem has constraints. It records traces in
    ``self.state_dict`` and checks the ``STOP_*`` early-stopping criteria
    from the settings.

    Attributes:
        settings: dict of solver settings.
        state_dict: solver state (iteration counter, step sizes, latest values,
            their ``*_prev`` counterparts and the ``*_log`` traces).
        models: deep copies of ``problem.model`` saved once the iteration
            counter exceeds 70% of ``settings['iterations']`` (constrained
            problems only).
        slack_evolution: per-iteration arrays of constraint slacks.
    """
    def __init__(self, settings_):
        """Initialize the solver state.

        Args:
            settings_: settings dict (see ``settings()``), or None to use the
                default settings.
        """
        ### Solver setup
        if settings_ is not None:
            self.settings = settings_
        else:
            # Default settings
            self.settings = settings()

        # Initializations
        self.state_dict = {}
        self.state_dict['iteration'] = 0
        self.state_dict['lr_p'] = self.settings['lr_p0']
        self.state_dict['lr_d'] = self.settings['lr_d0']
        self.state_dict['no_update_iterations'] = self.settings['STOP_NO_UPDATE']

        # Logging variables
        self.state_dict['primal_value_log'] = np.zeros(0)
        self.state_dict['dual_value_log'] = np.zeros(0)
        self.state_dict['feas_log'] = np.zeros(0)
        self.state_dict['rel_feas_log'] = np.zeros(0)
        self.state_dict['lambdas_log'] = None
        self.state_dict['mus_log'] = None
        self.state_dict['all_slacks'] = np.zeros((0,0))

        self.models = []
        self.slack_evolution = []

    # Primal descent step
    @abstractmethod
    def primal(self, problem):
        """Take one primal descent step on ``problem``.

        Must return ``(primal_value, primal_grad_norm)``.
        """
        pass

    # Dual ascent step
    def dual(self, problem):
        """Take one projected dual (sub)gradient ascent step.

        Each dual variable moves along its slack with step size
        ``state_dict['lr_d']`` and is then clipped at zero.

        Returns:
            (lagrangian_value, dual_grad_norm): ``problem.lagrangian()`` after
            the update, and the sum of squared slacks.
        """
        constraint_value, pointwise_value = problem.slacks(batch_size = self.settings['batch_size'])
        dual_grad_norm = 0

        for ii, slack in enumerate(constraint_value):
            problem.lambdas[ii] += self.state_dict['lr_d'] * slack
            # Project onto lambda >= 0 in place. Assigning a bare 0 would replace
            # the tensor with a Python int and break later tensor calls
            # (e.g. `.to('cpu')` when logging).
            if torch.is_tensor(problem.lambdas[ii]):
                problem.lambdas[ii].clamp_(min=0)
            elif problem.lambdas[ii] < 0:
                problem.lambdas[ii] = 0
            dual_grad_norm += slack**2

        for ii, slack in enumerate(pointwise_value):
            problem.mus[ii] += self.state_dict['lr_d'] * slack
            problem.mus[ii][problem.mus[ii] < 0] = 0
            dual_grad_norm += torch.norm(slack)**2

        lagrangian_value = problem.lagrangian()

        return lagrangian_value, dual_grad_norm


    def update_state(self, **kwargs):
        """Store each keyword in ``state_dict``, keeping the old value as ``<key>_prev``.

        The first time a key is stored, ``<key>_prev`` is set to the new value.
        """
        for key, value in kwargs.items():
            if key in self.state_dict:
                self.state_dict[key+'_prev'] = self.state_dict[key]
            else:
                self.state_dict[key+'_prev'] = value

            self.state_dict[key] = value


    # Primal-dual algorithm
    def solve(self, problem):
        """Run up to ``settings['iterations']`` primal-dual iterations on ``problem``.

        Dual variables are initialized from ``lambdas0``/``mus0`` when
        ``problem.lambdas``/``problem.mus`` are None. Each iteration:

        * takes a primal step;
        * for constrained problems, takes a dual step and computes the duality
          gap and the feasibility measures;
        * logs the traces in ``state_dict``;
        * checks divergence and the stopping criteria.

        The counter continues from ``state_dict['iteration']``, so repeated
        calls resume. ``problem._solved`` is set when a stopping criterion
        fires or when the counter reaches ``settings['iterations']``. Nothing
        is done (beyond a debug message) if it is already set.
        """
        if problem._solved:
            self.settings['logger'].debug('Problem solved!')
        else:
            # Initializations
            if problem.constraints:
                if problem.lambdas is None:
                    problem.lambdas = [torch.tensor(self.settings['lambdas0'], dtype = torch.float, requires_grad = False, device = self.settings['device']) for i in problem.constraints]
                if self.state_dict['lambdas_log'] is None:
                    self.state_dict['lambdas_log'] = np.zeros([0, len(problem.constraints)])
            else:
                problem.lambdas = []

            if problem.pointwise:
                if problem.mus is None:
                    problem.mus = [self.settings['mus0']*torch.ones_like(rhs, dtype = torch.float, requires_grad = False, device = self.settings['device']) for rhs in problem.pointwise_rhs]
                if self.state_dict['mus_log'] is None:
                    self.state_dict['mus_log'] = np.zeros([0, len(problem.pointwise), 3])
            else:
                problem.mus = []


            ### START OF ITERATIONS ###
            for self.state_dict['iteration'] in range(self.state_dict['iteration'], self.state_dict['iteration'] + self.settings['iterations']):
                ### PRIMAL ###
                primal_value, primal_grad_norm = self.primal(problem)
                self.update_state(primal_value = primal_value,
                                  primal_grad_norm = primal_grad_norm)

                # Update primal step size
                if self.settings['lr_p_period'] is not None and self.settings['lr_p_decay'] is not None:
                    if self.state_dict['iteration'] % self.settings['lr_p_period'] == self.settings['lr_p_period']-1:
                        self.state_dict['lr_p'] *= self.settings['lr_p_decay']

                # Log values
                self.state_dict['primal_value_log'] = np.append(self.state_dict['primal_value_log'], primal_value)


                ### DUAL ###
                if problem.constraints or problem.pointwise:
                    dual_value, dual_grad_norm = self.dual(problem)
                    self.update_state(dual_value = dual_value,
                                      dual_grad_norm = dual_grad_norm)

                    # Update dual step size
                    if self.settings['lr_d_period'] is not None and self.settings['lr_d_decay'] is not None:
                        if self.state_dict['iteration'] % self.settings['lr_d_period'] == self.settings['lr_d_period']-1:
                            self.state_dict['lr_d'] *= self.settings['lr_d_decay']

                    # Duality gap. The relative gap is undefined for a (near-)zero
                    # primal value. Otherwise it is taken w.r.t. |P|, so negative
                    # objectives get a finite gap, consistent with |(P-D)/P| in plots().
                    self.update_state(duality_gap = np.abs(primal_value - dual_value))
                    if abs(primal_value) < 1e-6:
                        self.update_state(rel_duality_gap = float("inf"))
                    else:
                        self.update_state(rel_duality_gap = self.state_dict['duality_gap']/abs(primal_value))

                    # Feasibility
                    with torch.no_grad():
                        constraint_value, pointwise_value = problem.slacks(self.settings['batch_size'])

                        if constraint_value:
                            constraint_feas = np.array([slacks.to('cpu') for slacks in constraint_value])
                            constraint_rel_feas = np.array([(slack/rhs).to('cpu').numpy() if (rhs != 0) else float("Inf")\
                                                            for slack, rhs in zip(constraint_value, problem.rhs)])
                            self.update_state(constraint_feas = constraint_feas,
                                              constraint_rel_feas = constraint_rel_feas)

                            self.slack_evolution.append(constraint_feas)

                        if pointwise_value:
                            pointwise_feas = np.concatenate([value.to('cpu').numpy() for value in pointwise_value])
                            # A zero rhs makes the relative slack undefined. Use an
                            # inf array of the slack's shape, because np.concatenate
                            # cannot join a bare scalar.
                            pointwise_rel_feas = np.concatenate([(slack/rhs).to('cpu').numpy() if torch.all(rhs != 0) else np.full(tuple(slack.shape), np.inf)\
                                                                 for slack, rhs in zip(pointwise_value, problem.pointwise_rhs)])
                            self.update_state(pointwise_feas = pointwise_feas,
                                              pointwise_rel_feas = pointwise_rel_feas,
                                              mus_max = np.max(np.array([torch.max(value).item() for value in problem.mus])))

                        # Worst-case (largest) slack over both constraint kinds; a
                        # missing kind contributes -inf.
                        feas = np.max([self.state_dict.get('constraint_feas', np.array(-np.inf)).max(),
                                       self.state_dict.get('pointwise_feas', np.array(-np.inf)).max()])
                        rel_feas = np.max([self.state_dict.get('constraint_rel_feas', np.array(-np.inf)).max(),
                                       self.state_dict.get('pointwise_rel_feas', np.array(-np.inf)).max()])
                        self.update_state(feas = feas, rel_feas = rel_feas)


                    # Log values
                    self.state_dict['dual_value_log'] = np.append(self.state_dict['dual_value_log'] , self.state_dict['dual_value'])
                    self.state_dict['feas_log'] = np.append(self.state_dict['feas_log'], self.state_dict['feas'])
                    self.state_dict['rel_feas_log'] = np.append(self.state_dict['rel_feas_log'], self.state_dict['rel_feas'])
                    if self.state_dict['lambdas_log'] is not None:
                        self.state_dict['lambdas_log'] = np.append(self.state_dict['lambdas_log'],
                                                                   np.array([lambda_value.to('cpu') for lambda_value in problem.lambdas], ndmin = 2),
                                                                   axis=0)
                        self.update_state(lambdas_max = np.max(self.state_dict['lambdas_log'][-1,:]))

                    if self.state_dict['mus_log'] is not None:
                        # One row per pointwise constraint: (mean, min, max) of its multipliers
                        self.state_dict['mus_log'] = np.append(self.state_dict['mus_log'],
                                                               np.array([[mu.mean().to('cpu'), mu.min().to('cpu'), mu.max().to('cpu')] \
                                                                         for mu in problem.mus], ndmin=3),
                                                               axis=0)
                        self.update_state(mus_max = np.max(self.state_dict['mus_log'][-1,:,2]))

                    # Tracking message
                    log_message = (f"Iteration {self.state_dict['iteration']}: P = {self.state_dict['primal_value']:.3g} / "
                                   f"DGAP = {self.state_dict['duality_gap']:.3g} / "
                                   f"REL_DGAP = {self.state_dict['rel_duality_gap']:.3g} / "
                                   f"PFEAS = {self.state_dict['feas']:.3g} / "
                                   f"REL_PFEAS = {self.state_dict['rel_feas']:.3g} / "
                                   f"PGRAD = {self.state_dict['primal_grad_norm']:.3g} / "
                                   f"DGRAD = {self.state_dict['dual_grad_norm']:.3g}")
                else:
                    # Tracking message
                    log_message = (f"Iteration {self.state_dict['iteration']}: P = {self.state_dict['primal_value']:.3g} / "
                                   f"PGRAD = {self.state_dict['primal_grad_norm']:.3g}")


                ### EARLY STOPPING ###
                # Divergence: primal value or any multiplier too large or NaN
                if self.state_dict['primal_value'] >= self.settings['STOP_DIVERGENCE'] or np.isnan(self.state_dict['primal_value']) or \
                    self.state_dict.get('lambdas_max', 0) >= self.settings['STOP_DIVERGENCE'] or np.isnan(self.state_dict.get('lambdas_max', 0)) or \
                        self.state_dict.get('mus_max', 0) >= self.settings['STOP_DIVERGENCE'] or np.isnan(self.state_dict.get('mus_max', 0)):
                    self.settings['logger'].error(f"Algorithm diverged ({self.state_dict['iteration']+1} iterations)")
                    break

                # Stopping criteria
                if self.check_stopping_criteria():
                    problem._solved = True
                    break

                if self.settings['STOP_USER_DEFINED'](self.state_dict):
                    problem._solved = True
                    break

                # Save current model to list
                if (problem.constraints or problem.pointwise):
                    if (self.state_dict['iteration'] > self.settings['iterations']*0.7):
                        self.models.append(copy.deepcopy(problem.model))

                # Tracking
                if self.state_dict['iteration'] % self.settings['verbose'] == 0:
                    self.settings['logger'].info(log_message)
                else:
                    self.settings['logger'].debug(log_message)


            ### END OF ITERATIONS ###
            # Convert the 0-based index of the last iteration into an iteration count
            self.state_dict['iteration'] += 1
            if self.state_dict['iteration'] == self.settings['iterations']:
                problem._solved = True

            ### LOGGING ###
            # End of optimization details
            self.settings['logger'].info('')
            self.settings['logger'].info('===== FINAL REPORT =====')
            if problem.constraints or problem.pointwise:
                self.settings['logger'].info('(%d iterations) P = %.3g / DGAP = %.3g / REL_DGAP = %.3g / PFEAS = %.3g / REL_PFEAS = %.3g / PGRAD = %.3g / DGRAD = %.3g',
                                             self.state_dict['iteration'],
                                             self.state_dict['primal_value'],
                                             self.state_dict['duality_gap'],
                                             self.state_dict['rel_duality_gap'],
                                             self.state_dict['feas'],
                                             self.state_dict['rel_feas'],
                                             self.state_dict['primal_grad_norm'],
                                             self.state_dict['dual_grad_norm'])
            else:
                self.settings['logger'].info('(%d iterations) P = %.3g / PGRAD = %.3g',
                                              self.state_dict['iteration'],
                                              self.state_dict['primal_value'],
                                              self.state_dict['primal_grad_norm'])
            self.settings['logger'].info('')


    def check_stopping_criteria(self):
        """Evaluate the ``STOP_*`` criteria from the settings.

        A criterion whose threshold(s) are None is skipped. Each criterion that
        fires logs a message. The duality-gap criteria are only evaluated once a
        dual value exists, i.e. for constrained problems. As a side effect, the
        ``STOP_NO_UPDATE`` countdown in ``state_dict['no_update_iterations']``
        is reset or decremented.

        Returns:
            True if at least one criterion is met.
        """
        stopping = False

        # Primal value and absolute feasibility. Both thresholds must be set,
        # otherwise the feasibility comparison would be made against None.
        if self.settings['STOP_ABS_PVAL'] is not None and self.settings['STOP_ABS_PFEAS'] is not None:
            if self.state_dict['primal_value'] < self.settings['STOP_ABS_PVAL'] and \
                self.state_dict.get('feas', 0) < self.settings['STOP_ABS_PFEAS']:
                self.settings['logger'].info('Stopping criterion: small primal value and primal feasibility')
                stopping = True

        # Stalled: positive entries mean a constraint violation shrank since the
        # previous iteration
        constraint_diff = np.clip(self.state_dict.get('constraint_feas_prev', 0), 0, None) - \
            np.clip(self.state_dict.get('constraint_feas', 0), 0, None)
        pointwise_diff = np.clip(self.state_dict.get('pointwise_feas_prev', 0), 0, None) - \
            np.clip(self.state_dict.get('pointwise_feas', 0), 0, None)

        if self.settings['STOP_NO_UPDATE'] is not None:
            if self.state_dict['primal_value'] < self.state_dict['primal_value_prev'] or \
                np.any(constraint_diff > 0) or np.any(pointwise_diff > 0):
                self.update_state(no_update_iterations = self.settings['STOP_NO_UPDATE'])
            else:
                self.state_dict['no_update_iterations'] -= 1
                if self.state_dict['no_update_iterations'] == 0:
                    self.settings['logger'].info("Stopping criterion: no improvement on the objective or feasibility "
                                                 f"for {self.settings['STOP_NO_UPDATE']} iterations.")
                    stopping = True

        # Gradient norms
        if self.settings['STOP_PGRAD'] is not None and self.settings['STOP_DGRAD'] is not None:
            if self.state_dict['primal_grad_norm'] < self.settings['STOP_PGRAD'] and \
                self.state_dict.get('dual_grad_norm', 0) < self.settings['STOP_DGRAD']:
                self.settings['logger'].info('Stopping criterion: small gradients')
                stopping = True
        elif self.settings['STOP_PGRAD'] is not None:
            if self.state_dict['primal_grad_norm'] < self.settings['STOP_PGRAD']:
                self.settings['logger'].info('Stopping criterion: small gradients')
                stopping = True
        elif self.settings['STOP_DGRAD'] is not None:
            if self.state_dict.get('dual_grad_norm', 0) < self.settings['STOP_DGRAD']:
                self.settings['logger'].info('Stopping criterion: small gradients')
                stopping = True

        # Criteria that only apply to constrained problem
        if 'dual_value' in self.state_dict:
            # Absolute duality gap and feasibility
            if self.settings['STOP_ABS_DGAP'] is not None and self.settings['STOP_ABS_PFEAS'] is not None:
                if self.state_dict['duality_gap'] < self.settings['STOP_ABS_DGAP'] and \
                    self.state_dict['feas'] < self.settings['STOP_ABS_PFEAS']:
                    self.settings['logger'].info('Stopping criterion: absolute duality gap and primal feasibility')
                    stopping = True

            # Relative gap and feasibility
            if self.settings['STOP_REL_DGAP'] is not None and self.settings['STOP_REL_PFEAS'] is not None:
                if self.state_dict['rel_duality_gap'] < self.settings['STOP_REL_DGAP'] and \
                    self.state_dict['rel_feas'] < self.settings['STOP_REL_PFEAS']:
                    self.settings['logger'].info('Stopping criterion: relative duality gap and primal feasibility')
                    stopping = True

        return stopping


    # Trace plots
    def plots(self):
        """Plot the logged solver traces and return the matplotlib figure.

        Unconstrained problems get a single primal-value plot. Constrained
        problems get a 2x3 grid with the following panels:

        * primal/dual values and relative duality gap;
        * dual variables (average);
        * absolute and relative feasibility;
        * dual variables (pointwise).

        Stopping thresholds are drawn as dashed lines when set. Panels for
        absent dual variables are removed.
        """
        ### Create figure
        if self.state_dict['lambdas_log'] is None and self.state_dict['mus_log'] is None:
            fig, axes = plot.subplots(1,1)

            ### Solver traces
            # Primal value plot
            axes.plot(np.arange(1,self.state_dict['iteration'] + 1),  self.state_dict['primal_value_log'][:self.state_dict['iteration']], label = 'Primal')
            axes.grid()
            axes.autoscale()

        else:
            fig, axes = plot.subplots(2, 3, sharex = True)

            axes[0,0].set_title('Duality gap')
            axes[0,1].set_title('Relative duality gap')
            axes[0,2].set_title('Dual variables (average)')
            axes[1,0].set_title('Feasibility')
            axes[1,1].set_title('Relative feasibility')
            axes[1,2].set_title('Dual variables (pointwise)')

            plot.tight_layout()
            ### Solver traces
            # Primal-dual values plot
            axes[0,0].plot(np.arange(1,self.state_dict['iteration'] + 1),
                           self.state_dict['primal_value_log'],
                           label = 'Primal')
            axes[0,0].plot(np.arange(1,self.state_dict['iteration'] + 1),
                           self.state_dict['dual_value_log'],
                           label = 'Dual')
            axes[0,0].legend()

            # Relative duality gap plot
            P = self.state_dict['primal_value_log']
            D = self.state_dict['dual_value_log']
            axes[0,1].plot(np.arange(1,self.state_dict['iteration'] + 1), np.abs((P-D)/P))
            if self.settings['STOP_REL_DGAP'] is not None:
                axes[0,1].hlines(self.settings['STOP_REL_DGAP'],
                                 1, self.state_dict['iteration'],
                                 color = 'black', linestyle = 'dashed')
            axes[0,1].set_yscale('log')

            # Absolute feasibility
            axes[1,0].plot(np.arange(1,self.state_dict['iteration'] + 1),
                           self.state_dict['feas_log'])
            if self.settings['STOP_ABS_PFEAS'] is not None:
                axes[1,0].hlines(self.settings['STOP_ABS_PFEAS'],
                                 1, self.state_dict['iteration'],
                                 color = 'black', linestyle = 'dashed')

            # Relative feasibility
            axes[1,1].plot(np.arange(1,self.state_dict['iteration'] + 1),
                           self.state_dict['rel_feas_log'])
            if self.settings['STOP_REL_PFEAS'] is not None:
                axes[1,1].hlines(self.settings['STOP_REL_PFEAS'],
                                 1, self.state_dict['iteration'],
                                 color = 'black', linestyle = 'dashed')

            # Dual variables
            if self.state_dict['lambdas_log'] is not None:
                axes[0,2].plot(np.arange(1,self.state_dict['iteration'] + 1),
                               self.state_dict['lambdas_log'])
            else:
                fig.delaxes(axes[0,2])

            if self.state_dict['mus_log'] is not None:
                # Shaded min-max band with the mean on top, one color per constraint
                for ii in range(self.state_dict['mus_log'].shape[1]):
                    axes[1,2].fill_between(np.arange(1,self.state_dict['iteration'] + 1),
                                           self.state_dict['mus_log'][:,ii,1],
                                           self.state_dict['mus_log'][:,ii,2],
                                           alpha = 0.25, color = f'C{ii}')
                    axes[1,2].plot(np.arange(1,self.state_dict['iteration'] + 1),
                                   self.state_dict['mus_log'][:,ii,0],
                                   color = f'C{ii}')
            else:
                fig.delaxes(axes[1,2])

            for ax in axes.flat:
                ax.grid()
                ax.autoscale()

        # Show figure
        fig.show()

        return fig


class TorchedPrimalDual(PrimalDual):
    """Primal-dual solver whose primal and dual steps are taken by optimizer objects.

    Args:
        primal_solver: callable building an optimizer from
            ``problem.parameters``. The result must provide ``param_groups``,
            ``zero_grad``, ``step``, ``state_dict`` and ``load_state_dict``
            (presumably a ``torch.optim`` optimizer class -- confirm with callers).
        dual_solver: same kind of callable, applied to the dual variables
            ``problem.lambdas + problem.mus``.
        settings_: solver settings dict, or None for the default settings.
    """
    def __init__(self, primal_solver, dual_solver, settings_ = None):
        super().__init__(settings_)
        # Optimizer factories. The optimizers themselves are built on the first
        # primal/dual call, when the problem's variables are available.
        self._primal_solver = primal_solver
        self._dual_solver  = dual_solver

        self.primal_solver = None
        self.dual_solver  = None
        # Optimizer states: saved after every step, and loaded into a freshly
        # built optimizer when present.
        self.state_dict['primal_solver'] = None
        self.state_dict['dual_solver'] = None

    # Primal descent step
    def primal(self, problem):
        """Build the primal optimizer, then take the first primal step.

        Afterwards ``self.primal`` is rebound to :meth:`primal_post`, so later
        iterations reuse the same optimizer.

        Returns:
            ``(primal_value, primal_grad_norm)`` from :meth:`primal_post`.
        """
        ### Initializations on the first call
        # Torch solver
        self.primal_solver = self._primal_solver(problem.parameters)
        if self.state_dict['primal_solver'] is not None:
            self.primal_solver.load_state_dict(self.state_dict['primal_solver'])

        ### Evaluate primal
        primal_output = self.primal_post(problem)

        ### Do not run solver initialization again
        self.primal = self.primal_post

        return primal_output


    def primal_post(self, problem):
        """Take one optimizer step per batch yielded by ``problem.primal_batch``.

        Each batch yields ``(lagrangian_value, normalization)``.

        Returns:
            (primal_value, primal_grad_norm):

            * ``primal_value``: the ``normalization``-weighted mean of the batch
              Lagrangian values.
            * ``primal_grad_norm``: the sum of squared gradient norms weighted
              by ``normalization**2``, divided by the squared total
              normalization.
        """
        # Adjust step size
        for param_group in self.primal_solver.param_groups:
            param_group['lr'] = self.state_dict['lr_p']

        primal_value = 0
        primal_grad_norm = 0
        total_n = 0

        self.primal_solver.zero_grad()
        for lagrangian_value, normalization in problem.primal_batch(batch_size = self.settings['batch_size']):
            lagrangian_value.backward()
            self.primal_solver.step()

            with torch.no_grad():
                primal_value += lagrangian_value.item()*normalization
                # Parameters not reached by backward() have no gradient; skip them
                primal_grad_norm += np.sum([p.grad.norm().cpu()**2 for p in problem.parameters
                                            if p.grad is not None]) * normalization**2
                total_n += normalization

            self.primal_solver.zero_grad()

        self.state_dict['primal_solver'] = self.primal_solver.state_dict()

        with torch.no_grad():
            primal_value /= total_n
            primal_grad_norm /= total_n**2

        return primal_value, primal_grad_norm


    # Dual ascent step
    def dual(self, problem):
        """Build the dual optimizer, then take the first dual step.

        Afterwards ``self.dual`` is rebound to :meth:`dual_post`, so later
        iterations reuse the same optimizer.

        Returns:
            ``(lagrangian_value, dual_grad_norm)`` from :meth:`dual_post`.
        """
        ### Initializations on the first call
        self.dual_solver = self._dual_solver(problem.lambdas + problem.mus)
        if self.state_dict['dual_solver'] is not None:
            self.dual_solver.load_state_dict(self.state_dict['dual_solver'])

        ### Evaluate dual
        dual_output = self.dual_post(problem)

        ### Do not run solver initialization again
        self.dual = self.dual_post

        return dual_output


    def dual_post(self, problem):
        """Take one projected dual ascent step with the dual optimizer.

        The negated slacks are installed as gradients, so the optimizer's
        descent step increases each multiplier along its slack. The
        multipliers are then clipped at zero.

        Returns:
            (lagrangian_value, dual_grad_norm):

            * ``lagrangian_value``: objective plus multiplier-weighted slacks,
              re-evaluated after the step, as a Python float.
            * ``dual_grad_norm``: the sum of squared slacks before the step.
        """
        for param_group in self.dual_solver.param_groups:
            param_group['lr'] = self.state_dict['lr_d']

        with torch.no_grad():
            constraint_value, pointwise_value = problem.slacks(batch_size = self.settings['batch_size'])

            dual_grad_norm = 0
            for ii, slack in enumerate(constraint_value):
                problem.lambdas[ii].grad = -slack
                dual_grad_norm += slack**2

            for ii, slack in enumerate(pointwise_value):
                problem.mus[ii].grad = -slack
                dual_grad_norm += torch.norm(slack)**2

            self.dual_solver.step()

            # Projection onto the nonnegative orthant
            for ii, _ in enumerate(problem.lambdas):
                problem.lambdas[ii][problem.lambdas[ii] < 0] = 0

            for ii, _ in enumerate(problem.mus):
                problem.mus[ii][problem.mus[ii] < 0] = 0

            obj_value = problem.objective(self.settings['batch_size'])
            constraint_slacks, pointwise_slacks = problem.slacks(self.settings['batch_size'])
            lagrangian_value = obj_value.item()
            # Reduce every term to a Python float with `.item()`. np.sum over CUDA
            # tensors fails because they cannot be converted to NumPy arrays.
            if constraint_slacks:
                lagrangian_value += sum((lambda_value.cpu()*slack.cpu()).sum().item()
                                        for lambda_value, slack in zip(problem.lambdas, constraint_slacks))
            if pointwise_slacks:
                # Elementwise product + sum equals torch.dot for 1-D tensors and
                # also accepts multi-dimensional pointwise slacks.
                lagrangian_value += sum((mu_value*slack).sum().item()
                                        for mu_value, slack in zip(problem.mus, pointwise_slacks))

        self.state_dict['dual_solver'] = self.dual_solver.state_dict()

        return lagrangian_value, dual_grad_norm
    