import numpy as np

from typing import Callable, List, Optional
from oracles.saddle import ArrayPair, BaseSmoothSaddleOracle, OracleLinearComb
from methods.saddle import Logger, extragradient_solver
from .base import BaseSaddleMethod
from .constraints import ConstraintsL2


class SaddlePointOracleRegularizer(BaseSmoothSaddleOracle):
    """Proximal-regularized view of a saddle oracle.

    Models the function

        eta * f(z) + 0.5 * ||x - v.x||^2 - 0.5 * ||y - v.y||^2

    where ``f`` is the wrapped ``oracle`` and ``v`` is the prox center.
    The sliding methods below use it as the inner subproblem handed to the
    inner solver.
    """

    def __init__(self, oracle: BaseSmoothSaddleOracle, eta: float, v: ArrayPair):
        # Wrapped oracle, its scaling factor, and the prox center.
        self.oracle = oracle
        self.eta = eta
        self.v = v

    def func(self, z: ArrayPair) -> float:
        """Return the regularized objective value at ``z``."""
        scaled_value = self.eta * self.oracle.func(z)
        shift_x = z.x - self.v.x
        shift_y = z.y - self.v.y
        return scaled_value + 0.5 * shift_x.dot(shift_x) - 0.5 * shift_y.dot(shift_y)

    def grad_x(self, z: ArrayPair) -> np.ndarray:
        """Return ``eta * oracle.grad_x(z) + (z.x - v.x)``."""
        scaled_grad = self.eta * self.oracle.grad_x(z)
        return scaled_grad + z.x - self.v.x

    def grad_y(self, z: ArrayPair) -> np.ndarray:
        """Return ``eta * oracle.grad_y(z) + (v.y - z.y)``.

        This is the partial derivative of ``func`` in ``y``, assuming
        ``oracle.grad_y`` returns df/dy (sign convention not visible here).
        """
        scaled_grad = self.eta * self.oracle.grad_y(z)
        return scaled_grad + self.v.y - z.y


class SaddleSliding(BaseSaddleMethod):
    """Gradient sliding with separate outer (``g``) and inner (``phi``) oracles.

    Each outer step:
      1. takes a gradient step on ``g`` from the current iterate to get ``v``;
      2. approximately solves the subproblem
         ``stepsize_outer * phi(z) + 0.5||x - v.x||^2 - 0.5||y - v.y||^2``
         with ``inner_solver`` to get ``u``;
      3. corrects ``u`` by ``stepsize_outer * (grad g(z) - grad g(u))`` and,
         when ``constraints`` is given, projects the result.

    Args:
        oracle_g: oracle used in the outer gradient step and correction. It is
            also the oracle handed to ``BaseSaddleMethod``.
        oracle_phi: oracle wrapped into the inner regularized subproblem.
        stepsize_outer: outer step size; also the weight ``eta`` of
            ``oracle_phi`` inside the subproblem.
        stepsize_inner: step size passed to ``inner_solver``.
        inner_solver: callable invoked as
            ``inner_solver(oracle, stepsize, z_start, num_iter=..., constraints=...)``
            and returning an ``ArrayPair``.
        inner_iterations: ``num_iter`` passed to ``inner_solver``.
        z_0: starting point.
        logger: optional logger forwarded to ``BaseSaddleMethod``.
        constraints: optional constraint set, passed to the inner solver and
            applied to the iterate after every outer step.
    """

    def __init__(
            self,
            oracle_g: BaseSmoothSaddleOracle,
            oracle_phi: BaseSmoothSaddleOracle,
            stepsize_outer: float,
            stepsize_inner: float,
            inner_solver: Callable,
            inner_iterations: int,
            z_0: ArrayPair,
            logger: Optional[Logger],
            constraints: Optional[ConstraintsL2] = None
    ):
        super().__init__(oracle_g, z_0, None, None, logger)
        self.oracle_g = oracle_g
        self.oracle_phi = oracle_phi
        self.stepsize_outer = stepsize_outer
        self.stepsize_inner = stepsize_inner
        self.inner_solver = inner_solver
        self.inner_iterations = inner_iterations
        self.constraints = constraints

    def step(self):
        """Perform one outer sliding iteration, updating ``self.z``."""
        # grad g(z) is needed both for v and for the correction; evaluate once.
        grad_g_z = self.oracle_g.grad(self.z)
        v = self.z - grad_g_z * self.stepsize_outer
        u = self.solve_subproblem(v)
        self.z = u + self.stepsize_outer * (grad_g_z - self.oracle_g.grad(u))
        # The correction can leave the feasible set; project as the
        # Centralized* methods in this module do.
        if self.constraints is not None:
            self.z = self.constraints.apply(self.z)

    def solve_subproblem(self, v: ArrayPair) -> ArrayPair:
        """Run ``inner_solver`` on the ``phi``-subproblem centered at ``v``.

        The solver starts from ``v`` and runs ``inner_iterations`` iterations.
        """
        suboracle = SaddlePointOracleRegularizer(self.oracle_phi, self.stepsize_outer, v)
        return self.inner_solver(
            suboracle,
            self.stepsize_inner, v, num_iter=self.inner_iterations, constraints=self.constraints)


class CentralizedSaddleSliding(BaseSaddleMethod):
    """Centralized gradient sliding over the equal-weight average of ``oracles``.

    The global oracle is ``OracleLinearComb`` of the node oracles with weights
    ``1/n``. Each step:
      1. collects all node gradients at ``z``;
      2. steps along ``mean_grad(z) - grad_0(z)`` to get ``v``;
      3. approximately solves node 0's regularized subproblem around ``v``
         with ``extragradient_solver`` to get ``u``;
      4. collects all node gradients at ``u`` and applies the correction
         ``stepsize_outer * (mean_grad(z) - grad_0(z) - mean_grad(u) + grad_0(u))``;
      5. projects onto ``constraints`` when given.

    Args:
        oracles: per-node oracles; node 0's oracle defines the subproblem.
        stepsize_outer: outer step size; also ``eta`` in the subproblem.
        stepsize_inner: step size for ``extragradient_solver``.
        inner_iterations: ``num_iter`` for ``extragradient_solver``.
        z_0: starting point.
        logger: optional logger forwarded to ``BaseSaddleMethod``.
        constraints: optional constraint set for the inner solver and the
            outer iterate.
    """

    def __init__(
            self,
            oracles: List[BaseSmoothSaddleOracle],
            stepsize_outer: float,
            stepsize_inner: float,
            inner_iterations: int,
            z_0: ArrayPair,
            logger: Optional[Logger] = None,
            constraints: Optional[ConstraintsL2] = None
    ):
        self._num_nodes = len(oracles)
        oracle_sum = OracleLinearComb(oracles, [1 / self._num_nodes] * self._num_nodes)
        super().__init__(oracle_sum, z_0, None, None, logger)
        self.oracle_list = oracles
        self.stepsize_outer = stepsize_outer
        self.stepsize_inner = stepsize_inner
        self.inner_iterations = inner_iterations
        self.constraints = constraints

    def step(self):
        """Perform one outer iteration and update the cost counters."""
        grad_z_list = [oracle.grad(self.z) for oracle in self.oracle_list]
        grad_z = ArrayPair.mean(grad_z_list)
        # Node 0's gradient is already in the list; reuse it instead of recomputing.
        grad_1_z = grad_z_list[0]
        v = self.z - self.stepsize_outer * (grad_z - grad_1_z)
        u = self.solve_subproblem(0, v)
        self.gradient_calls += 2 * self.inner_iterations
        grad_u_list = [oracle.grad(u) for oracle in self.oracle_list]
        grad_u = ArrayPair.mean(grad_u_list)
        grad_1_u = grad_u_list[0]
        self.z = u + self.stepsize_outer * (grad_z - grad_1_z - grad_u + grad_1_u)
        if self.constraints is not None:
            self.z = self.constraints.apply(self.z)
        self.current_round_volume += 4 * self._num_nodes
        self.gradient_calls += 2 * self._num_nodes

    def solve_subproblem(self, m: int, v: ArrayPair) -> ArrayPair:
        """Run ``extragradient_solver`` on node ``m``'s subproblem centered at ``v``.

        The solver starts from ``v`` and runs ``inner_iterations`` iterations.
        """
        suboracle = SaddlePointOracleRegularizer(self.oracle_list[m], self.stepsize_outer, v)
        return extragradient_solver(suboracle,
                                    self.stepsize_inner, v, num_iter=self.inner_iterations,
                                    constraints=self.constraints)
        
        
class CentralizedExtragradientSliding(BaseSaddleMethod):
    """Centralized extragradient sliding over the equal-weight average of ``oracles``.

    Each step:
      1. collects all node gradients at ``z``;
      2. steps along ``mean_grad(z) - grad_0(z)`` with step ``theta`` to get ``v``;
      3. approximately solves node 0's subproblem, with weight ``theta``,
         around ``v`` to get ``u``;
      4. collects all node gradients at ``u`` and sets
         ``z <- z + stepsize_outer * alpha * (u - z) - stepsize_outer * mean_grad(u)``;
      5. projects onto ``constraints`` when given.

    Args:
        oracles: per-node oracles; node 0's oracle defines the subproblem.
        stepsize_outer: step size of the final outer update.
        stepsize_inner: step size for ``extragradient_solver``.
        alpha: weight of the ``(u - z)`` term in the outer update.
        theta: step size for ``v`` and ``eta`` of the subproblem.
        inner_iterations: ``num_iter`` for ``extragradient_solver``.
        z_0: starting point.
        logger: optional logger forwarded to ``BaseSaddleMethod``.
        constraints: optional constraint set for the inner solver and the
            outer iterate.
    """

    def __init__(
            self,
            oracles: List[BaseSmoothSaddleOracle],
            stepsize_outer: float,
            stepsize_inner: float,
            alpha: float,
            theta: float,
            inner_iterations: int,
            z_0: ArrayPair,
            logger: Optional[Logger] = None,
            constraints: Optional[ConstraintsL2] = None
    ):
        self._num_nodes = len(oracles)
        oracle_sum = OracleLinearComb(oracles, [1 / self._num_nodes] * self._num_nodes)
        super().__init__(oracle_sum, z_0, None, None, logger)
        self.oracle_list = oracles
        self.stepsize_outer = stepsize_outer
        self.stepsize_inner = stepsize_inner
        self.alpha = alpha
        self.theta = theta
        self.inner_iterations = inner_iterations
        self.constraints = constraints

    def step(self):
        """Perform one outer iteration and update the cost counters."""
        grad_z_list = [oracle.grad(self.z) for oracle in self.oracle_list]
        self.gradient_calls += 2 * self._num_nodes
        grad_z = ArrayPair.mean(grad_z_list)
        # Node 0's gradient is already in the list; reuse it instead of recomputing.
        grad_1_z = grad_z_list[0]
        v = self.z - self.theta * (grad_z - grad_1_z)
        u = self.solve_subproblem(0, v)
        self.gradient_calls += 2 * self.inner_iterations
        grad_u_list = [oracle.grad(u) for oracle in self.oracle_list]
        self.gradient_calls += 2 * self._num_nodes
        grad_u = ArrayPair.mean(grad_u_list)
        self.z = self.z + self.stepsize_outer * self.alpha * (u - self.z) - self.stepsize_outer * grad_u
        if self.constraints is not None:
            self.z = self.constraints.apply(self.z)
        self.current_round_volume += 2 * self._num_nodes

    def solve_subproblem(self, m: int, v: ArrayPair) -> ArrayPair:
        """Run ``extragradient_solver`` on node ``m``'s subproblem centered at ``v``.

        The subproblem weight is ``theta``. The solver starts from ``v`` and runs
        ``inner_iterations`` iterations.
        """
        suboracle = SaddlePointOracleRegularizer(self.oracle_list[m], self.theta, v)
        return extragradient_solver(suboracle,
                                    self.stepsize_inner, v, num_iter=self.inner_iterations,
                                    constraints=self.constraints)


class CentralizedSaddleSlidingVR(BaseSaddleMethod):
    """Centralized sliding with a randomly refreshed anchor point ``w``.

    All node gradients at ``w`` are cached. Each step:
      1. forms ``z_bar = (1 - p) * z + p * w`` with ``p = probability``;
      2. steps along ``mean_grad(w) - grad_0(w)`` to get ``v`` and solves
         node 0's subproblem around ``v`` to get ``u``;
      3. samples one node ``j`` uniformly and corrects ``u`` with
         ``stepsize_outer * (grad_j(w) - grad_0(w) - grad_j(u) + grad_0(u))``;
      4. projects onto ``constraints`` when given;
      5. with probability ``p`` moves ``w`` to the new ``z`` and recomputes
         all cached node gradients.

    Args:
        oracles: per-node oracles; node 0's oracle defines the subproblem.
        stepsize_outer: outer step size; also ``eta`` in the subproblem.
        stepsize_inner: step size for ``extragradient_solver``.
        inner_iterations: ``num_iter`` for ``extragradient_solver``.
        probability: mixing weight for ``z_bar`` and anchor-refresh probability.
        z_0: starting point (also the initial anchor).
        logger: optional logger forwarded to ``BaseSaddleMethod``.
        constraints: optional constraint set for the inner solver and the
            outer iterate.
    """

    def __init__(
            self,
            oracles: List[BaseSmoothSaddleOracle],
            stepsize_outer: float,
            stepsize_inner: float,
            inner_iterations: int,
            probability: float,
            z_0: ArrayPair,
            logger: Optional[Logger] = None,
            constraints: Optional[ConstraintsL2] = None
    ):
        self._num_nodes = len(oracles)
        oracle_sum = OracleLinearComb(oracles, [1 / self._num_nodes] * self._num_nodes)
        super().__init__(oracle_sum, z_0, None, None, logger)
        self.oracle_list = oracles
        self.stepsize_outer = stepsize_outer
        self.stepsize_inner = stepsize_inner
        self.inner_iterations = inner_iterations
        self.probability = probability
        self.z = z_0
        self.constraints = constraints
        # Anchor point and the cached per-node / averaged gradients at it.
        self.w = self.z.copy()
        self.grad_w_list = [oracle.grad(self.w) for oracle in self.oracle_list]
        self.grad_w = ArrayPair.mean(self.grad_w_list)

    def step(self):
        """Perform one outer iteration, possibly refreshing the anchor."""
        self.z_bar = (1 - self.probability) * self.z + self.probability * self.w
        v = self.z_bar - self.stepsize_outer * (self.grad_w - self.grad_w_list[0])
        u = self.solve_subproblem(0, v)
        self.gradient_calls += 2 * self.inner_iterations
        # Sample one node uniformly (the call form is kept so seeded runs reproduce).
        j = np.random.choice(self._num_nodes, 1, replace=False)[0]
        grad_j_u = self.oracle_list[j].grad(u)
        grad_1_u = self.oracle_list[0].grad(u)
        self.current_round_volume += 2
        self.gradient_calls += 2
        self.z = u + self.stepsize_outer * (self.grad_w_list[j] - self.grad_w_list[0] - grad_j_u + grad_1_u)
        if self.constraints is not None:
            self.z = self.constraints.apply(self.z)
        if np.random.uniform() < self.probability:
            # Move the anchor to the new iterate and recompute all node gradients there.
            self.w = self.z
            self.grad_w_list = [oracle.grad(self.w) for oracle in self.oracle_list]
            self.gradient_calls += 2 * self._num_nodes
            self.current_round_volume += 2 * self._num_nodes
            self.grad_w = ArrayPair.mean(self.grad_w_list)

    def solve_subproblem(self, m: int, v: ArrayPair) -> ArrayPair:
        """Run ``extragradient_solver`` on node ``m``'s subproblem centered at ``v``.

        The solver starts from ``v`` and runs ``inner_iterations`` iterations.
        """
        suboracle = SaddlePointOracleRegularizer(self.oracle_list[m], self.stepsize_outer, v)
        return extragradient_solver(suboracle,
                                    self.stepsize_inner, v, num_iter=self.inner_iterations,
                                    constraints=self.constraints)


class CentralizedSaddleSlidingVRMB(BaseSaddleMethod):
    """Mini-batch variant of the anchored centralized sliding method.

    Each step:
      1. forms ``z_bar = (1 - p) * z + p * w`` with ``p = probability``;
      2. samples ``batch_size`` distinct nodes and accumulates, per sampled
         ``j``, the difference ``grad_j - grad_0`` at ``z`` minus that at the
         anchor, plus ``alpha`` times its change from ``z_prev`` to ``z``;
      3. builds ``v`` from ``z_bar``, the anchor's averaged gradient difference
         and the batch average, then solves node 0's subproblem around ``v``;
         the result becomes the new ``z``;
      4. refreshes the cached anchor gradients when the anchor changed, and
         with probability ``p`` moves the anchor to the new ``z``.

    Args:
        oracles: per-node oracles; node 0's oracle defines the subproblem.
        stepsize_outer: outer step size; also ``eta`` in the subproblem.
        stepsize_inner: step size for ``extragradient_solver``.
        inner_iterations: ``num_iter`` for ``extragradient_solver``.
        probability: mixing weight for ``z_bar`` and anchor-move probability.
        alpha: weight of the ``z_prev -> z`` gradient-difference term.
        batch_size: number of distinct nodes sampled per step (must not
            exceed ``len(oracles)``, since sampling is without replacement).
        z_0: starting point (also the initial anchor).
        logger: optional logger forwarded to ``BaseSaddleMethod``.
        constraints: optional constraint set for the inner solver.
    """

    def __init__(
            self,
            oracles: List[BaseSmoothSaddleOracle],
            stepsize_outer: float,
            stepsize_inner: float,
            inner_iterations: int,
            probability: float,
            alpha: float,
            batch_size: int,
            z_0: ArrayPair,
            logger: Optional[Logger] = None,
            constraints: Optional[ConstraintsL2] = None
    ):
        self._num_nodes = len(oracles)
        oracle_sum = OracleLinearComb(oracles, [1 / self._num_nodes] * self._num_nodes)
        super().__init__(oracle_sum, z_0, None, None, logger)
        self.oracle_list = oracles
        self.stepsize_outer = stepsize_outer
        self.stepsize_inner = stepsize_inner
        self.inner_iterations = inner_iterations
        self.probability = probability
        self.alpha = alpha
        self.batch_size = batch_size
        self.z = z_0
        self.constraints = constraints
        self.w = self.z.copy()
        self.z_prev = self.z.copy()
        self.w_prev = self.w.copy()
        # Cached per-node / averaged gradients at the anchor.
        self.grad_w_prev_list = [oracle.grad(self.w_prev) for oracle in self.oracle_list]
        self.grad_w_prev = ArrayPair.mean(self.grad_w_prev_list)

    def step(self):
        """Perform one outer iteration and update the cost counters."""
        self.z_bar = (1 - self.probability) * self.z + self.probability * self.w
        sample_indices = np.random.choice(self._num_nodes, self.batch_size, replace=False)
        # Node 0's gradients at z and z_prev do not depend on the sampled
        # index; evaluate them once rather than once per sample.
        grad_1_z = self.oracle_list[0].grad(self.z)
        grad_1_z_prev = self.oracle_list[0].grad(self.z_prev)
        sum_term = ArrayPair.zeros_like(self.z)
        for j in sample_indices:
            grad_j_z = self.oracle_list[j].grad(self.z)
            grad_j_z_prev = self.oracle_list[j].grad(self.z_prev)
            sum_term += (grad_j_z - grad_1_z
                         - self.grad_w_prev_list[j] + self.grad_w_prev_list[0]
                         + self.alpha * (grad_j_z - grad_1_z - grad_j_z_prev + grad_1_z_prev))
        v = self.z_bar - self.stepsize_outer * (self.grad_w_prev - self.grad_w_prev_list[0]) - self.stepsize_outer / self.batch_size * sum_term
        u = self.solve_subproblem(0, v)
        self.gradient_calls += 2 * self.inner_iterations
        self.z_prev = self.z
        self.z = u
        self.gradient_calls += 2 * self.batch_size
        # NOTE(review): the intent is "the anchor was moved last step". Whether
        # `!=` compares identity or values depends on ArrayPair (not visible
        # here); if it compares arrays elementwise this may be ambiguous —
        # confirm, or switch to `is not`, since w/w_prev change only by rebinding.
        if self.w_prev != self.w:
            self.grad_w_prev_list = [oracle.grad(self.w) for oracle in self.oracle_list]
            self.grad_w_prev = ArrayPair.mean(self.grad_w_prev_list)
            self.gradient_calls += 2 * self._num_nodes
        self.w_prev = self.w
        if np.random.uniform() < self.probability:
            self.w = self.z
            self.current_round_volume += 2 * self._num_nodes + 4 * self.batch_size
        else:
            self.current_round_volume += 4 * self.batch_size

    def solve_subproblem(self, m: int, v: ArrayPair) -> ArrayPair:
        """Run ``extragradient_solver`` on node ``m``'s subproblem centered at ``v``.

        The solver starts from ``v`` and runs ``inner_iterations`` iterations.
        """
        suboracle = SaddlePointOracleRegularizer(self.oracle_list[m], self.stepsize_outer, v)
        return extragradient_solver(suboracle,
                                    self.stepsize_inner, v, num_iter=self.inner_iterations,
                                    constraints=self.constraints)

