# Import packages
from typing import List, Optional

import numpy as np
import torch

class NPGame:
    """Randomly generated N-player quadratic game with a closed-form equilibrium.

    Player ``i`` controls a strategy vector ``x_i`` of length ``n_dim`` and has objective

        f_i(x) = 0.5 * x_i^T Abar_i x_i + abar_i^T x_i + sum_{j != i} x_i^T Bbar_ij x_j,

    where ``Abar_i``, ``abar_i`` and ``Bbar_ij`` are means over sampled data instances
    (``objective_function`` can also average over a chosen subset of instances).
    The couplings satisfy ``B[j, i] = -B[i, j]^T``.
    """

    def __init__(self, n_player: int, n_dim: int, n_data: int,
                 L_A: float, mu_A: float,
                 L_B: float, mu_B: float,
                 device: torch.device):
        """
        Initialize an N-player game with n_player players, each player having n_dim
        dimensions, and n_data sampled instances. The game is defined by matrices A, B
        and vectors a.

        For every instance, each player i gets a symmetric matrix A[i] and a standard
        Gaussian vector a[i], and each pair i < j gets a symmetric matrix B[i, j] with
        B[j, i] = -B[i, j]^T. Sampling uses NumPy's global RNG (seed it with
        ``np.random.seed`` for a reproducible game).

        Parameters:
        n_player (int): The number of players in the game.
        n_dim (int): The dimensionality of the strategy space for each player.
        n_data (int): The number of instances (data points) in the game.
        L_A (float): The largest eigenvalue of every sampled matrix A.
        mu_A (float): The smallest eigenvalue of every sampled matrix A.
        L_B (float): The largest eigenvalue of every sampled matrix B[i, j], i < j.
        mu_B (float): The smallest eigenvalue of every sampled matrix B[i, j], i < j.
        device (torch.device): The device to store the tensors.

        Attributes:
        A (torch.Tensor): Shape (n_player, n_data, n_dim, n_dim); matrices A per player and instance.
        B (torch.Tensor): Shape (n_player, n_player, n_data, n_dim, n_dim); coupling matrices per
            pair of players and instance (B[i, i] is zero).
        a (torch.Tensor): Shape (n_player, n_data, n_dim); vectors a per player and instance.
        M (torch.Tensor): Shape (n_player * n_dim, n_player * n_dim); block (i, j) is the
            instance-mean of A[i] when i == j and of B[i, j] otherwise.
        z (torch.Tensor): Shape (n_player * n_dim,); stacked instance-means of a.
        x_optimal (torch.Tensor): Shape (n_player, n_dim); the solution of M @ x = -z.
        mu (torch.Tensor): Scalar (CPU); smallest eigenvalue over the players' mean A matrices.
        Lmax (torch.Tensor): Scalar (CPU); largest spectral norm over the players' mean A matrices.
        L (torch.Tensor): Scalar (CPU); spectral norm of M.
        n_player (int): The number of players in the game.
        device (torch.device): The device to store the tensors.
        """
        self.n_player = n_player
        self.device = device

        # Sample on the CPU and move everything to `device` once at the end, rather
        # than round-tripping each small matrix through the device.
        A = torch.zeros(n_player, n_data, n_dim, n_dim)
        B = torch.zeros(n_player, n_player, n_data, n_dim, n_dim)
        a = torch.zeros(n_player, n_data, n_dim)

        # Draw order (per instance: each player's A then a, then the couplings) is the
        # order in which the NumPy random stream is consumed.
        for idx in range(n_data):
            for player in range(n_player):
                A[player, idx] = torch.tensor(self._random_symmetric(mu_A, L_A, n_dim))
                a[player, idx] = torch.tensor(np.random.normal(0, 1, n_dim))
            for player1 in range(n_player):
                for player2 in range(player1 + 1, n_player):
                    B[player1, player2, idx] = torch.tensor(
                        self._random_symmetric(mu_B, L_B, n_dim))
                    # Antisymmetric coupling: these blocks cancel in the symmetric
                    # part of the game operator M.
                    B[player2, player1, idx] = -B[player1, player2, idx].transpose(-2, -1)

        # Instance-means, computed once and shared by M and the constants below.
        A_bar = A.mean(dim=1)  # (n_player, n_dim, n_dim)
        B_bar = B.mean(dim=2)  # (n_player, n_player, n_dim, n_dim)
        a_bar = a.mean(dim=1)  # (n_player, n_dim)

        M, z = self._assemble_operator(A_bar, B_bar, a_bar)

        # Solve M @ x_optimal = -z, then reshape to one row per player.
        x_optimal = torch.linalg.solve(M, -z).view(n_player, n_dim)

        # Because the off-diagonal blocks of M are antisymmetric, the symmetric part of M
        # is block-diagonal in the A_bar[i]; its smallest eigenvalue is the strong
        # monotonicity constant (an eigenvalue, not a singular value).
        self.mu = torch.linalg.eigvalsh(A_bar).min()
        self.Lmax = torch.linalg.matrix_norm(A_bar, ord=2).max()

        # Lipschitz constant of x -> M x + z.
        self.L = torch.linalg.matrix_norm(M, ord=2)

        # Keep all game tensors on the same device.
        self.M, self.z = M.to(device), z.to(device)
        self.A, self.B, self.a = A.to(device), B.to(device), a.to(device)
        self.x_optimal = x_optimal.to(device)

    @staticmethod
    def _random_symmetric(mu: float, L: float, dim: int) -> np.ndarray:
        """Sample a symmetric ``dim x dim`` matrix with eigenvalues in [mu, L].

        Assuming ``mu <= L``, the smallest and largest eigenvalues are pinned to ``mu``
        and ``L``. For ``dim == 1`` both pins hit the same entry, so only ``L`` remains.
        """
        evalues = np.random.uniform(mu, L, dim)
        evalues[0], evalues[-1] = mu, L
        gaussian = np.random.normal(0, 1, (dim, dim))
        # eigh (not eig) returns a real orthonormal eigenbasis for the symmetric Gram
        # matrix, so Q @ diag(evalues) @ Q.T has exactly the eigenvalues `evalues`.
        _, Q = np.linalg.eigh(gaussian.T @ gaussian)
        return Q @ np.diag(evalues) @ Q.T

    @staticmethod
    def _assemble_operator(A_bar: torch.Tensor, B_bar: torch.Tensor,
                           a_bar: torch.Tensor):
        """Build the flat game operator M and offset z from per-player means.

        Block (i, j) of M is A_bar[i] when i == j and B_bar[i, j] otherwise, so row
        block i of ``M @ x + z`` is the gradient of player i's objective w.r.t. x_i.
        """
        n_player, n_dim = a_bar.shape
        blocks = B_bar.clone()
        players = torch.arange(n_player)
        blocks[players, players] = A_bar
        # (i, j, r, c) -> (i, r, j, c): entry lands at row i*n_dim + r, col j*n_dim + c.
        M = blocks.permute(0, 2, 1, 3).reshape(n_player * n_dim, n_player * n_dim)
        z = a_bar.reshape(n_player * n_dim)
        return M, z

    def objective_function(self, player: int, x: torch.Tensor,
                           index: Optional[List[int]] = None) -> torch.Tensor:
        """
        Computes the objective function for a specific player in a N-player game.

        Parameters:
        player (int): The index of the player for whom the objective function is computed.
        x (torch.Tensor): The tensor representing the strategy of all players
                          (``x[j]`` is player j's strategy).
        index (List[int], optional): Indices of the data instances to average A, a and B
                                     over. If not provided, all instances are used.

        Returns:
        torch.Tensor: The computed objective function value for the specified player.

        Raises:
        ValueError: If ``index`` is given but selects no instances (the mean would be NaN).
        """
        # pass to correct device
        x = x.to(self.device)

        if index is None:
            rows = slice(None)  # select every data instance
        else:
            if len(index) == 0:
                raise ValueError("index must select at least one data instance")
            rows = index

        x_p = x[player]
        A_bar = torch.mean(self.A[player][rows], dim=0)
        a_bar = torch.mean(self.a[player][rows], dim=0)

        coupling_term = 0
        for player2 in range(self.n_player):
            if player2 != player:
                B_bar = torch.mean(self.B[player][player2][rows], dim=0)
                coupling_term += torch.t(x_p) @ B_bar @ x[player2]

        return (.5 * torch.t(x_p) @ A_bar @ x_p
                + torch.t(a_bar) @ x_p
                + coupling_term)

    def opt_dist(self, x: torch.Tensor) -> torch.Tensor:
        """
        Squared Euclidean distance between a strategy tensor and the equilibrium.

        Parameters:
        x (torch.Tensor): The tensor representing the strategy of all players.

        Returns:
        torch.Tensor: ``||x - x_optimal||^2`` (Frobenius norm over all players' entries).
        """
        # pass to correct device
        x = x.to(self.device)
        return torch.norm(x - self.x_optimal) ** 2
            