from abc import ABC, abstractmethod

from utils.constants import *
from utils.general_utils import gradient_of_f_at_xi


def square_euclidean_norm(x):
    """Return the inner product of x with itself (the squared Euclidean norm for a 1-D vector x)."""
    squared_norm = np.inner(x, x)
    return squared_norm


def default_distance_func(x, y):
    """Default distance: the squared Euclidean norm of the difference x - y."""
    difference = x - y
    return square_euclidean_norm(difference)


def default_distance_gradient_func(x, y):
    """Gradient of the squared Euclidean distance ||x - y||^2 with respect to x, i.e. 2 * (x - y)."""
    difference = x - y
    return 2 * difference


def simple_proportional_function(d):
    """Linear ranking weight 1 - d shifted by DELTA_.

    NOTE(review): DELTA_ is presumably a small positive constant so the weight does not vanish at d = 1
    (ProportionalPublishersGame asserts g(1) != 0) — confirm in utils.constants.
    """
    weight = 1 - d + DELTA_
    return weight


class PublishersGame(ABC):
    """Abstract class for a publishers' game.

    Holds the game parameters and a distance function normalized to [0, 1]. It also provides the
    generic utility and welfare computations. Subclasses define how the probability of being ranked
    first is derived from the distances (calc_r / calc_r_all).
    """

    def __init__(self, k, n, s, lam, x_0=None, x_star_lst=None,
                 base_distance_func=default_distance_func, distance_gradient_func=default_distance_gradient_func):
        """
        k: embedding space dimension
        n: number of publishers
        s: number of information needs
        lam: lambda parameter - the weight of the distance from the initial docs
        x_0: initial docs in rows, shape (n, k); drawn with np.random.rand(n, k) if None
        x_star_lst: the information needs in rows, shape (s, k); a single 1-D need of length k is
            reshaped to (1, k); drawn with np.random.rand(s, k) if None
        base_distance_func: distance function d(a, b). It is divided by its value between the
            all-zeros and all-ones vectors, which is assumed to be the maximum distance.
        distance_gradient_func: gradient of base_distance_func with respect to its first argument,
            or None if there is no gradient (then self.d_grad is not defined)
        """
        self.k = k
        self.n = n
        self.s = s
        self.lam = lam
        self.N = list(range(n))
        self.x_0 = x_0 if x_0 is not None else np.random.rand(n, k)
        if x_star_lst is None:
            x_star_lst = np.random.rand(s, k)
        elif x_star_lst.ndim == 1:
            # a single information need given as a flat vector
            x_star_lst = x_star_lst.reshape((1, k))
        self.x_star_lst = x_star_lst
        max_d = base_distance_func(np.zeros(k), np.ones(k))  # we assume that this is the maximum distance
        # we assume that the range of the base distance function is [0, max_d] and we want the range of the
        # normalized function to be [0, 1].
        self.d = lambda a, b: base_distance_func(a, b) / max_d
        if distance_gradient_func is not None:  # not all distance functions have gradients
            # we divide the gradient because we divided the distance function
            self.d_grad = lambda a, b: distance_gradient_func(a, b) / max_d
        # Validate the stored arrays: the arguments themselves may be None (defaults generated above).
        assert self.x_0.shape == (n, k), f'x_0.shape={self.x_0.shape}'
        assert self.x_star_lst.shape == (s, k), f'x_star_lst.shape={self.x_star_lst.shape}'

    # Calculations functions

    def calc_dist_from_opt(self, xi, j):
        """Calculates the (normalized) distance of document xi from the j's information need."""
        return self.d(xi, self.x_star_lst[j])

    def calc_dist_from_init(self, xi, i):
        """Calculates the (normalized) distance of document xi from publisher i's initial document."""
        return self.d(xi, self.x_0[i])

    def calc_full_r(self, dist_from_opt_full, i):
        """Calculates the probability that the document of publisher i will be ranked first for each x_star.

        dist_from_opt_full[p][j] is the distance of publisher p's document from information need j.
        Returns a list of length s.
        """
        return [self.calc_r([dist_from_opt_full[p][j] for p in self.N], i) for j in range(self.s)]

    @abstractmethod
    def calc_r(self, dist_from_opt_all, i):
        """Calculates the probability that publisher i's document will be ranked first
            with respect to the distances of all publishers from one information need."""
        pass

    def calc_u(self, x, i):
        """Calculates the utility for publisher i in profile x.

        The utility is the average (over information needs) probability of being ranked first minus
        lam times the distance from publisher i's initial document.
        """
        dist_from_opt_full = [[self.calc_dist_from_opt(x[p], j) for j in range(self.s)] for p in self.N]
        return sum(self.calc_full_r(dist_from_opt_full, i)) / self.s - self.lam * self.calc_dist_from_init(x[i], i)

    def calc_full_r_all(self, dist_from_opt_full):
        """Calculates the probability that the document of each publisher will be ranked first for each x_star.

        Returns a list indexed [information need][publisher].
        """
        return [self.calc_r_all([dist_from_opt_full[p][j] for p in self.N]) for j in range(self.s)]

    @abstractmethod
    def calc_r_all(self, dist_from_opt_all):
        """Calculates the ranking for all publishers with respect to the distances from an information need."""
        pass

    def calc_u_all(self, x):
        """Calculates the utility of every publisher for profile x (a list of length n)."""
        dist_from_init_all = [self.calc_dist_from_init(x[p], p) for p in self.N]
        dist_from_opt_full = [[self.calc_dist_from_opt(x[p], j) for j in range(self.s)] for p in self.N]
        ranks = self.calc_full_r_all(dist_from_opt_full)  # indexed [information need][publisher]
        return [sum([ranks[j][p] for j in range(self.s)]) / self.s - self.lam * dist_from_init_all[p]
                for p in self.N]

    def calc_publishers_welfare(self, x):
        """Calculates the publishers' welfare: the sum of all publishers' utilities in profile x."""
        return np.sum(self.calc_u_all(x))

    def calc_users_welfare(self, x):
        """Calculates the users' welfare.

        Sums r * (1 - distance) over all publishers and information needs, divided by s.
        """
        dist_from_opt_full = [[self.calc_dist_from_opt(x[p], j) for j in range(self.s)] for p in self.N]
        temp_r = self.calc_full_r_all(dist_from_opt_full)  # indexed [information need][publisher]
        return np.sum([temp_r[j][p] * (1 - dist_from_opt_full[p][j])
                       for p in self.N for j in range(self.s)]) / self.s

    def calc_grad(self, x, i):
        """Calculates the gradient of the utility of publisher i.

        Delegates to gradient_of_f_at_xi (presumably a numerical gradient with respect to x[i] —
        confirm in utils.general_utils); subclasses override this with closed-form gradients.
        """
        return gradient_of_f_at_xi(lambda val: self.calc_u(val, i), x, i)


class ExtendedPublishersGame(PublishersGame, ABC):
    """Abstract class for an extended publishers' game.

    It keeps a current profile self.x, which starts as a copy of the initial docs. It lazily caches
    per-publisher values for that profile:
    - the distances from the information needs and from the initial doc,
    - the probabilities of being ranked first,
    - the utility.
    update_x / update_x_i invalidate these caches.
    """

    def __init__(self, k, n, s, lam, x_0, x_star_lst,
                 base_distance_func=default_distance_func, distance_gradient_func=default_distance_gradient_func):
        super().__init__(k, n, s, lam, x_0, x_star_lst, base_distance_func, distance_gradient_func)
        # Copy the initial docs stored by the base class, not the argument: x_0 may be None, in which
        # case the base class generated them.
        self.x = self.x_0.copy()
        self._reset_dist_caches()
        self._reset_rank_caches()

    def initialize(self):
        """Initialize the game: reset the current profile to the initial docs and clear all caches."""
        self.update_x(self.x_0)

    # Cache helpers

    def _reset_dist_caches(self):
        """Drop the cached distances of every publisher."""
        self.dist_from_opt = [None] * self.n
        self.dist_from_init = [None] * self.n

    def _reset_rank_caches(self):
        """Drop the cached ranking probabilities and utilities of every publisher."""
        self.r = [None] * self.n
        self.r_flag = False
        self.u = [None] * self.n
        self.u_flag = False

    def _ensure_r_all(self):
        """Compute self.r for all publishers in one pass, unless that was already done."""
        if not self.r_flag:
            self.save_r_all()
            self.r_flag = True

    # Get functions

    def get_dist_from_opt_full(self, i):
        """Gets the distance from each information need for publisher i (cached)."""
        if self.dist_from_opt[i] is None:
            self.dist_from_opt[i] = [self.calc_dist_from_opt(self.x[i], j) for j in range(self.s)]
        return self.dist_from_opt[i]

    def get_dist_from_init(self, i):
        """Gets the distance from the initial document for publisher i (cached)."""
        if self.dist_from_init[i] is None:
            self.dist_from_init[i] = self.calc_dist_from_init(self.x[i], i)
        return self.dist_from_init[i]

    def get_full_r(self, i):
        """Gets the probability for being ranked first for publisher i for each information need (cached)."""
        if self.r[i] is None:
            self.r[i] = self.calc_full_r([self.get_dist_from_opt_full(j) for j in self.N], i)
        return self.r[i]

    def get_u(self, i):
        """Gets the utility for publisher i (cached)."""
        if self.u[i] is None:
            self.u[i] = sum(self.get_full_r(i)) / self.s - self.lam * self.get_dist_from_init(i)
        return self.u[i]

    def get_grad(self, i):
        """Get the gradient of the utility of publisher i at the current profile."""
        return self.calc_grad(self.x, i)

    def get_grad_all(self):
        """Gets the gradients of the utilities for all publishers."""
        return [self.get_grad(i) for i in self.N]

    def get_publishers_welfare(self):
        """Gets the publishers' welfare: the sum of all publishers' utilities."""
        if not self.u_flag:
            self.save_u_all()
            self.u_flag = True
        return sum(self.u)

    def get_users_welfare(self):
        """Gets the users' welfare: sum of r * (1 - distance) over publishers and needs, divided by s."""
        self._ensure_r_all()
        sum_welfare = 0
        for i in self.N:
            dists = self.get_dist_from_opt_full(i)
            sum_welfare += sum([self.r[i][j] * (1 - dists[j]) for j in range(self.s)])
        return sum_welfare / self.s

    # Set functions

    def update_x(self, new_x):
        """Updates the current profile to a copy of new_x and clears all caches. Returns self."""
        assert new_x.shape == (self.n, self.k), f'new_x.shape={new_x.shape}'
        self.x = new_x.copy()
        self._reset_dist_caches()
        self._reset_rank_caches()
        return self

    def update_x_i(self, new_x_i, i):
        """Updates the current profile to new_x_i for publisher i. Returns self."""
        assert new_x_i.shape[0] == self.k, f'new_x_i.shape={new_x_i.shape}'
        self.x[i] = new_x_i.copy()
        # Only publisher i's distances change, but every publisher's ranking probability and utility
        # depend on them, so those caches are cleared for everyone.
        self.dist_from_opt[i] = None
        self.dist_from_init[i] = None
        self._reset_rank_caches()
        return self

    # Calculations functions

    def calc_u_deviation(self, xi, i):
        """Calculates the utility for publisher i if he deviates to xi (others keep their current docs)."""
        dist_from_opt_full = [[self.calc_dist_from_opt(xi, need) for need in range(self.s)]
                              if p == i else self.get_dist_from_opt_full(p) for p in self.N]
        return sum(self.calc_full_r(dist_from_opt_full, i)) / self.s - self.lam * self.calc_dist_from_init(xi, i)

    # Saving functions

    def save_r_all(self):
        """Saves the probability for being ranked first in the current profile for each publisher's document."""
        base_ranks = self.calc_full_r_all([self.get_dist_from_opt_full(i) for i in self.N])
        # base_ranks is indexed [information need][publisher]; self.r is indexed [publisher][need]
        self.r = [[base_ranks[j][i] for j in range(self.s)] for i in self.N]

    def save_u_all(self):
        """Saves the utility in the current profile for each publisher."""
        self._ensure_r_all()
        self.u = [self.get_u(i) for i in self.N]


class ProportionalPublishersGame(ExtendedPublishersGame):
    """General proportional publishers' game.

    Each publisher p gets the weight g(d_p) for its distance d_p from an information need. Its
    probability of being ranked first is its weight divided by the sum of all weights.
    """

    def __init__(self, k, n, s, lam, x_0, x_star_lst, g=simple_proportional_function,
                 base_distance_func=default_distance_func, distance_gradient_func=default_distance_gradient_func):
        super().__init__(k, n, s, lam, x_0, x_star_lst, base_distance_func, distance_gradient_func)
        self.g = g
        # the weight must not vanish at either end of the normalized distance range
        assert self.g(0) != 0, f'g(0)={self.g(0)}'
        assert self.g(1) != 0, f'g(1)={self.g(1)}'

    def calc_r(self, dist_from_opt_all, i):
        """Probability that publisher i is ranked first: g(d_i) / sum_p g(d_p)."""
        weights = np.array([self.g(dist) for dist in dist_from_opt_all])
        total = sum(weights)
        return weights[i] / total

    def calc_r_all(self, dist_from_opt_all):
        """Probabilities of being ranked first for every publisher, as a list."""
        weights = np.array(list(map(self.g, dist_from_opt_all)))
        total = sum(weights)
        normalized = weights / total
        return normalized.tolist()


class LinearProportionalPublishersGame(ProportionalPublishersGame):
    """Linear Proportional Publishers' Game: g(d) = b - d with b > 1.

    For one information need, r_i = (b - d_i) / (n * b - S) with S = sum_p d_p. Hence
    d r_i / d d_i = (b - d_i) / (n * b - S) ** 2 - 1 / (n * b - S).
    """

    def __init__(self, k, n, s, lam, x_0, x_star_lst, b=1 + DELTA_, base_distance_func=default_distance_func,
                 distance_gradient_func=default_distance_gradient_func):
        # b > 1 keeps g(d) = b - d strictly positive on the normalized distance range [0, 1]
        assert b > 1, f'b={b}'

        def g(d):
            return b - d

        super().__init__(k, n, s, lam, x_0, x_star_lst, g, base_distance_func, distance_gradient_func)
        self.b = b

    def get_grad_all(self):
        """Closed-form utility gradients of all publishers at the current profile, one row per publisher."""
        dist_matrix = np.array([self.get_dist_from_opt_full(p) for p in self.N])  # [publisher, need]
        # gradient of the -lam * d(x_p, x_0[p]) term
        grads = np.array([- self.lam * self.d_grad(self.x[p], self.x_0[p]) for p in self.N])
        for need_idx in range(self.s):
            dist_col = dist_matrix[:, need_idx]
            inv_denominator = 1 / (self.n * self.b - np.sum(dist_col))
            need = self.x_star_lst[need_idx]
            need_grads = np.array([self.d_grad(self.x[p], need) for p in self.N])
            dr_dd = ((self.b - dist_col) * (inv_denominator ** 2) - inv_denominator)
            grads += dr_dd[:, np.newaxis] * need_grads / self.s
        return grads

    def calc_grad(self, x, i):
        """Calculates the gradient of the utility of publisher i."""
        dist_matrix = np.array([[self.calc_dist_from_opt(x[p], need_idx) for need_idx in range(self.s)]
                                for p in self.N])
        grad = - self.lam * self.d_grad(x[i], self.x_0[i])  # initial doc gradient
        for need_idx in range(self.s):
            inv_denominator = 1 / (self.n * self.b - np.sum(dist_matrix[:, need_idx]))
            dr_dd = ((self.b - dist_matrix[i, need_idx]) * (inv_denominator ** 2) - inv_denominator)
            grad += dr_dd * self.d_grad(x[i], self.x_star_lst[need_idx]) / self.s
        return grad


class RootProportionalPublishersGame(ProportionalPublishersGame):
    """Root Proportional Publishers' Game: g(d) = (1 + DELTA_ - d) ** a with 0 < a < 1.

    With w_p = (1 + DELTA_ - d_p) ** a and W = sum_p w_p, the ranking probability is r_i = w_i / W.
    Its derivative is d r_i / d d_i = a * w_i / (1 + DELTA_ - d_i) * (w_i / W ** 2 - 1 / W).
    """

    def __init__(self, k, n, s, lam, x_0, x_star_lst, a=0.5, base_distance_func=default_distance_func,
                 distance_gradient_func=default_distance_gradient_func):
        assert 0 < a < 1, f'a={a}'
        super().__init__(k, n, s, lam, x_0, x_star_lst, lambda d: np.power(1 + DELTA_ - d, a),
                         base_distance_func, distance_gradient_func)
        self.a = a

    def get_grad_all(self):
        """Closed-form utility gradients of all publishers at the current profile, one row per publisher."""
        dist_matrix = np.array([self.get_dist_from_opt_full(p) for p in self.N])  # [publisher, need]
        weight_matrix = (1 + DELTA_ - dist_matrix) ** self.a
        # gradient of the -lam * d(x_p, x_0[p]) term
        grads = np.array([- self.lam * self.d_grad(self.x[p], self.x_0[p]) for p in self.N])
        for need_idx in range(self.s):
            dist_col = dist_matrix[:, need_idx]
            weight_col = weight_matrix[:, need_idx]
            inv_total = 1 / np.sum(weight_col)
            need = self.x_star_lst[need_idx]
            need_grads = np.array([self.d_grad(self.x[p], need) for p in self.N])
            dr_dd = (self.a * (weight_col / (1 + DELTA_ - dist_col)) *
                     (weight_col * (inv_total ** 2) - inv_total))
            grads += dr_dd[:, np.newaxis] * need_grads / self.s
        return grads

    def calc_grad(self, x, i):
        """Calculates the gradient of the utility of publisher i."""
        dist_matrix = np.array([[self.calc_dist_from_opt(x[p], need_idx) for need_idx in range(self.s)]
                                for p in self.N])
        weight_matrix = (1 + DELTA_ - dist_matrix) ** self.a
        grad = - self.lam * self.d_grad(x[i], self.x_0[i])  # initial doc gradient
        for need_idx in range(self.s):
            inv_total = 1 / np.sum(weight_matrix[:, need_idx])
            w_i = weight_matrix[i, need_idx]
            dr_dd = self.a * (w_i / (1 + DELTA_ - dist_matrix[i, need_idx])) * (w_i * (inv_total ** 2) - inv_total)
            grad += dr_dd * self.d_grad(x[i], self.x_star_lst[need_idx]) / self.s
        return grad


class LogProportionalPublishersGame(ProportionalPublishersGame):
    """Logarithmic Proportional Publishers' Game: g(d) = log(c - d) with c > 2.

    With w_p = log(c - d_p) and W = sum_p w_p, the ranking probability is r_i = w_i / W.
    Its derivative is d r_i / d d_i = 1 / (c - d_i) * (w_i / W ** 2 - 1 / W).
    """

    def __init__(self, k, n, s, lam, x_0, x_star_lst, c=2 + DELTA_, base_distance_func=default_distance_func,
                 distance_gradient_func=default_distance_gradient_func):
        # Distances are normalized to [0, 1]. log(c - d) is strictly positive on that whole range only
        # when c > 2; otherwise some weights are <= 0 or NaN and the "probabilities" are meaningless.
        assert c > 2, f'c={c}'
        g = lambda d: np.log(c - d)
        super().__init__(k, n, s, lam, x_0, x_star_lst, g, base_distance_func, distance_gradient_func)
        self.c = c

    def get_grad_all(self):
        """Closed-form utility gradients of all publishers at the current profile, one row per publisher."""
        dists_opt_full = np.array([self.get_dist_from_opt_full(i) for i in self.N])  # [publisher, need]
        dists_opt_log = np.log(self.c - dists_opt_full)
        res = np.array([- self.lam * self.d_grad(self.x[i], self.x_0[i]) for i in self.N])  # initial doc gradient
        for j in range(self.s):
            sum_log_d_opt = np.sum(dists_opt_log[:, j])
            temp = 1 / sum_log_d_opt
            dists_opt_grad = np.array([self.d_grad(self.x[i], self.x_star_lst[j]) for i in self.N])
            rank_grad = 1 / (self.c - dists_opt_full[:, j]) * (dists_opt_log[:, j] * (temp ** 2) - temp)
            opt_grad = rank_grad[:, np.newaxis] * dists_opt_grad
            res += opt_grad / self.s
        return res

    def calc_grad(self, x, i):
        """Calculates the gradient of the utility of publisher i."""
        dists_opt_full = np.array([[self.calc_dist_from_opt(x[j], k) for k in range(self.s)] for j in self.N])
        dists_opt_log = np.log(self.c - dists_opt_full)
        res = - self.lam * self.d_grad(x[i], self.x_0[i])  # initial doc gradient
        for j in range(self.s):
            sum_log_d_opt = np.sum(dists_opt_log[:, j])
            temp = 1 / sum_log_d_opt
            dists_opt_grad = self.d_grad(x[i], self.x_star_lst[j])
            rank_grad = 1 / (self.c - dists_opt_full[i, j]) * (dists_opt_log[i, j] * (temp ** 2) - temp)
            opt_grad = rank_grad * dists_opt_grad
            res += opt_grad / self.s
        return res


class SoftmaxPublishersGame(ProportionalPublishersGame):
    """Softmax Publishers' Game: g(d) = exp(-beta * d).

    The ranking probability is r_i = exp(-beta * d_i) / sum_p exp(-beta * d_p). Since
    g'(d) = -beta * g(d), its derivative is d r_i / d d_i = -beta * r_i * (1 - r_i).
    """

    def __init__(self, k, n, s, lam, x_0, x_star_lst, beta=1, base_distance_func=default_distance_func,
                 distance_gradient_func=default_distance_gradient_func):
        g = lambda d: np.exp(beta * (- d))
        super().__init__(k, n, s, lam, x_0, x_star_lst, g, base_distance_func, distance_gradient_func)
        self.beta = beta

    def get_grad_all(self):
        """Closed-form utility gradients of all publishers at the current profile, one row per publisher."""
        dists_opt_full = np.array([self.get_dist_from_opt_full(i) for i in self.N])  # [publisher, need]
        res = np.array([- self.lam * self.d_grad(self.x[i], self.x_0[i]) for i in self.N])  # initial doc gradient
        for j in range(self.s):
            probs = np.exp(self.beta * (- dists_opt_full[:, j]))
            probs = probs / sum(probs)
            dists_opt_grad = np.array([self.d_grad(self.x[i], self.x_star_lst[j]) for i in self.N])
            # chain rule through exp(-beta * d) contributes the factor beta
            rank_grad = - self.beta * probs * (1 - probs)
            opt_grad = rank_grad[:, np.newaxis] * dists_opt_grad
            res += opt_grad / self.s
        return res

    def calc_grad(self, x, i):
        """Calculates the gradient of the utility of publisher i."""
        dists_opt_full = np.array([[self.calc_dist_from_opt(x[j], k) for k in range(self.s)] for j in self.N])
        res = - self.lam * self.d_grad(x[i], self.x_0[i])  # initial doc gradient
        for j in range(self.s):
            probs = np.exp(self.beta * (- dists_opt_full[:, j]))
            probs = probs / sum(probs)
            dists_opt_grad = self.d_grad(x[i], self.x_star_lst[j])
            # chain rule through exp(-beta * d) contributes the factor beta
            rank_grad = - self.beta * probs[i] * (1 - probs[i])
            opt_grad = rank_grad * dists_opt_grad
            res += opt_grad / self.s
        return res


# Registry mapping each game-type key (from the utils.constants star import) to its game class.
GAME_TYPES = {
    PROPORTIONAL: ProportionalPublishersGame,
    LINEAR_PROPORTIONAL: LinearProportionalPublishersGame,
    ROOT_PROPORTIONAL: RootProportionalPublishersGame,
    LOG_PROPORTIONAL: LogProportionalPublishersGame,
    SOFTMAX: SoftmaxPublishersGame,
}
