import sys

sys.path.append("../")

import numpy as np
import math
from typing import List
from methods.saddle import Logger, CentralizedSaddleSlidingVRMB, ConstraintsL2
from oracles.saddle import ArrayPair, BaseSmoothSaddleOracle


class CentralizedSaddleSlidingVRMBRunner(object):
    """Configure, build and run a ``CentralizedSaddleSlidingVRMB`` method.

    Intended call order:
        1. ``compute_method_params()`` derives the method hyperparameters
           from the stored problem constants;
        2. ``create_method(z_0)`` builds the method with L2 constraints;
        3. ``run(max_iter, max_time)`` delegates to the method's ``run``.
    """

    def __init__(
            self,
            oracles: List[BaseSmoothSaddleOracle],
            L: float,
            mu: float,
            delta: float,
            r_x: float,
            r_y: float,
            eps: float,
            n_nodes: int,
            logger: Logger
    ):
        """Store the problem description; hyperparameters are not computed yet.

        Args:
            oracles: Saddle oracles handed to the method unchanged.
            L: Constant used in the inner step size ``1 / (step * L + 1)``
                (presumably a smoothness constant -- confirm with callers).
            mu: Stored on the instance; not used by this class.
            delta: Stored on the instance; not used by this class.
            r_x: First radius passed to ``ConstraintsL2``.
            r_y: Second radius passed to ``ConstraintsL2``.
            eps: Stored on the instance; not used by this class.
            n_nodes: Number of nodes; determines ``probability`` and
                ``batch_size``.
            logger: Logger handed to the method.
        """
        self.oracles = oracles
        self.L = L
        self.mu = mu
        self.delta = delta
        self.r_x = r_x
        self.r_y = r_y
        self.eps = eps
        self.n_nodes = n_nodes
        self.logger = logger
        # Guards create_method() against being called before the
        # hyperparameters exist.
        self._params_computed = False

    def compute_method_params(self):
        """Derive the method hyperparameters and store them on the instance.

        Sets:
            probability: ``1 / sqrt(n_nodes)``.
            batch_size: ``floor(0.2 * floor(sqrt(n_nodes)))``, clamped to at
                least 1.
            step, alpha: both 1.
            step_inner: ``1 / (step * L + 1)``.
            T_inner: 200 inner iterations.

        Also prints the outer and inner step sizes.

        Raises:
            ValueError: If ``n_nodes`` is smaller than 1.
        """
        if self.n_nodes < 1:
            raise ValueError("n_nodes must be a positive integer")
        sqrt_n = math.sqrt(self.n_nodes)
        self.probability = 1 / sqrt_n
        # int(0.2 * k) truncates to 0 whenever k < 5 (i.e. n_nodes < 25), so
        # the clamp to 1 must be applied *after* scaling; otherwise the
        # method would be configured with an empty batch.
        self.batch_size = max(1, int(0.2 * int(sqrt_n)))
        self.step = 1
        self.alpha = 1
        self.step_inner = 1 / (self.step * self.L + 1)
        self.T_inner = 200
        self._params_computed = True
        print('SVOGS outer:', self.step, 'inner:', self.step_inner)

    def create_method(self, z_0: ArrayPair):
        """Build the ``CentralizedSaddleSlidingVRMB`` method into ``self.method``.

        Args:
            z_0: Starting point passed to the method as ``z_0``.

        Raises:
            ValueError: If ``compute_method_params`` has not been called.
        """
        if not self._params_computed:
            raise ValueError("Call compute_method_params first")

        self.method = CentralizedSaddleSlidingVRMB(
            oracles=self.oracles,
            stepsize_outer=self.step,
            stepsize_inner=self.step_inner,
            inner_iterations=self.T_inner,
            probability=self.probability,
            alpha=self.alpha,
            batch_size=self.batch_size,
            z_0=z_0,
            logger=self.logger,
            constraints=ConstraintsL2(self.r_x, self.r_y)
        )

    def run(self, max_iter, max_time=None):
        """Run the method built by ``create_method``.

        Args:
            max_iter: Forwarded to the method's ``run``.
            max_time: Forwarded to the method's ``run``; defaults to None.

        Raises:
            AttributeError: If ``create_method`` has not been called, since
                ``self.method`` is only assigned there.
        """
        self.method.run(max_iter, max_time)
