# -*- coding: utf-8 -*-

import numpy as np
import scipy as sp

from .algorithm import Algorithm

import sys

sys.path.insert(0, "../")
from utils import projection_Omega


class ProjectedGradientAscent(Algorithm):
    """Projected (super)gradient ascent over the matrix ``Omega``.

    Each iteration does three things:

    1. Solves an optimal transport problem whose cost is
       ``self.Mahalanobis(X, Y, Omega)``.
    2. Takes an ascent step along the matrix returned by ``self.Vpi``.
    3. Projects the result back with ``projection_Omega(..., k)``.

    Iteration stops when the relative gap between the best OT value seen so
    far and the sum of the ``k`` largest eigenvalues of that step matrix
    falls below ``self.threshold``, or after ``self.max_iter`` iterations.

    NOTE(review): relies on attributes/methods presumably provided by
    ``Algorithm``: ``threshold``, ``step_size_0``, ``max_iter``, ``verbose``,
    ``reg``, ``Mahalanobis``, ``OT`` and ``Vpi``. TODO confirm against the
    base class.
    """

    @staticmethod
    def _top_k_eigenvalue_sum(V, k):
        """Return the sum of the ``k`` largest eigenvalues of symmetric ``V``."""
        # ``subset_by_index`` replaces the ``eigvals`` keyword, which was
        # deprecated in SciPy 1.5 and removed in SciPy 1.11. Importing the
        # submodule explicitly avoids relying on ``import scipy`` having
        # loaded ``scipy.linalg`` as a side effect (not true on older SciPy).
        from scipy.linalg import eigh

        d = V.shape[0]
        eigenvalues = eigh(V, subset_by_index=[d - k, d - 1], eigvals_only=True)
        return np.sum(eigenvalues)

    def run(self, a, b, X, Y, Omega, k):
        """Run the ascent iterations.

        Parameters
        ----------
        a, b :
            Weights forwarded to ``self.OT`` and ``self.Vpi``
            (presumably the source/target marginals; TODO confirm).
        X, Y :
            Point sets forwarded to ``self.Mahalanobis`` and ``self.Vpi``.
        Omega :
            Initial iterate; updated by projected ascent steps.
        k : int
            Number of largest eigenvalues summed for the duality value.
            Also forwarded to ``projection_Omega``.

        Returns
        -------
        Omega :
            Final (projected) iterate.
        pi :
            Transport plan from the last executed iteration, or ``None`` if
            no iteration ran (``self.max_iter <= 0``).
        maxmin_values : list
            OT value computed in each iteration.
        minmax_values : list
            Top-``k`` eigenvalue sum computed in each iteration.
        """
        maxmin_values = []
        minmax_values = []
        pi = None  # remains None only if the loop body never executes
        best_maxmin = -np.inf  # running max of maxmin_values
        gap = np.inf  # no gap known yet, so never stop before iteration 0
        step_size = self.step_size_0  # fixed step size when reg != 0

        for t in range(self.max_iter):
            if gap < self.threshold:
                if self.verbose:
                    print('Precision', gap, 'attained.')
                break

            if self.verbose and t % 10 == 0:
                print('Iteration', t)

            # Optimal transport for the cost induced by the current Omega.
            C = self.Mahalanobis(X, Y, Omega)
            OT_val, pi = self.OT(a, b, C)
            maxmin_values.append(OT_val)
            best_maxmin = max(best_maxmin, OT_val)

            # Ascent step followed by projection.
            V = self.Vpi(X, Y, a, b, pi)
            if self.reg == 0:
                # Unregularized case: the supergradient method needs
                # diminishing step sizes.
                step_size = self.step_size_0 / np.sqrt(t + 1)
            Omega = projection_Omega(Omega + step_size * V, k)

            # Duality values, computed from V as it was *before* the update.
            sum_eigenvalues = self._top_k_eigenvalue_sum(V, k)
            minmax_values.append(sum_eigenvalues)
            if best_maxmin == 0:
                # The relative gap is undefined; keep iterating rather than
                # dividing by zero.
                gap = np.inf
            else:
                gap = np.abs(sum_eigenvalues - best_maxmin) / best_maxmin

        return Omega, pi, maxmin_values, minmax_values
