import numpy as np
import torch
import argparse
import random
from tqdm import tqdm
from sparse_max import Sparsemax
import json

def parse_args():
    """Parse the experiment's command-line options.

    Returns:
        argparse.Namespace with ``num_players``, ``prob``, ``action_space``,
        ``out_num``, ``seed`` and ``json`` (``None`` when ``--json`` is not
        supplied).
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ('--num_players', {'type': int, 'default': 10, 'help': 'Number of players'}),
        ('--prob', {'type': float, 'default': 0.3, 'help': 'Probability of connection'}),
        ('--action_space', {'type': int, 'default': 5, 'help': 'Maximum action space'}),
        ('--out_num', {'type': int, 'default': 1000, 'help': 'Number of outputs'}),
        ('--seed', {'type': int, 'default': 0, 'help': 'Random Seed'}),
        ('--json', {'type': str}),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

class NoRegret:
    """Noisy, regularized projected-gradient play on a random graph game.

    Every player holds a mixed strategy over ``action_space`` actions and is
    connected to other players through a random undirected graph.  Each
    directed edge (i, j) carries a utility matrix; a player's gradient is the
    degree-averaged product of its edge matrices with its neighbours'
    strategies.  ``train`` repeatedly takes a shrunk gradient step, projects
    with ``Sparsemax``, adds Gaussian noise and projects again, while logging
    an averaged regret figure (labelled "Exploitability" in the output).
    """

    def __init__(self, args):
        """Seed the RNGs, build the game and derive the run's hyper-parameters.

        Args:
            args: ``argparse.Namespace``-like object providing ``seed``,
                ``prob``, ``num_players``, ``action_space`` and ``out_num``.
                ``train`` additionally reads ``json`` and calls ``vars`` on it.
        """
        # Keep the options on the instance so ``train`` does not rely on a
        # module-level ``args`` global existing.
        self.args = args

        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

        # Connection probabilities of at least 0.1 use the dense generator.
        self.is_dense = args.prob >= 0.1

        self.num_players = args.num_players
        self.prob = args.prob
        self.action_space = args.action_space

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.generate_random_graph(self.num_players, self.prob, self.device)

        # One (action x action) matrix per directed edge, initially uniform in [0, 1).
        self.utility_matrices = torch.rand(size=(self.num_edges * 2, self.action_space, self.action_space)).to(self.device)

        for i in range(self.num_players):
            if self.degrees[i].item() > 0:
                # Per-player mean matrix m; the player's edge utilities are
                # remapped to an interval centred on m that stays inside [0, 1]:
                # half-width min(m, 1 - m), lower end max(m - half-width, 0).
                utility_matrix_mean = torch.rand(self.action_space, self.action_space).to(self.device)

                multiplier = torch.min(torch.stack([utility_matrix_mean, 1.0 - utility_matrix_mean]), dim=0)[0]
                bias = (utility_matrix_mean - multiplier).clamp(min=0.0)
                owned = self.player_idx_i == i
                self.utility_matrices[owned] = multiplier * self.utility_matrices[owned] * 2.0 + bias

        # Rescale utilities from [0, 1] to [-1, 1].
        self.utility_matrices = self.utility_matrices * 2.0 - 1.0
        # Every player starts from the uniform strategy.
        self.strategies = torch.ones(self.num_players, self.action_space).to(self.device) / self.action_space

        # Harmonic mean of the positive degrees; NaN (no connected player) -> 1.
        self.harmonic_mean_degree = 1.0 / torch.mean(1.0 / self.degrees[self.degrees>0].float()).item()
        if np.isnan(self.harmonic_mean_degree):
            self.harmonic_mean_degree = 1.0

        A = args.action_space
        lN = np.log(args.num_players)
        N = args.num_players

        N_max = self.degrees.max().item()
        N_min = max(self.degrees.min().item(), 1)

        hm_pow = self.harmonic_mean_degree ** (4 / 9)
        self.clubsuit = 16 * (A ** 3) * (lN ** 2) / hm_pow / (N_min ** 2) + 4 * A / N
        self.clubsuit = max(self.clubsuit, 4 * A / hm_pow)
        self.clubsuit = min(self.clubsuit, 2 * A)  # capped at 2 * A

        if self.is_dense: # dense graph
            self.T = max(int(self.harmonic_mean_degree ** (8/9)), 1)
            self.eta = 1.0 / (self.T * ((self.clubsuit) ** (1/3)))
        else: # sparse graph
            log_max_degree = np.log(N_max)
            if log_max_degree > 0:
                horizon = int((1 - np.log(lN) / lN) * lN / log_max_degree)
            else:
                # Max degree <= 1: the formula would divide by zero.  Fall back
                # to a single round, matching the dense branch's minimum.
                horizon = 1
            # Like the dense branch, never run fewer than one round; a zero
            # horizon would make eta and sigma divide by zero below.
            self.T = max(horizon, 1)
        if N_max > 1:
            # get_pow caps N_max ** T so the power cannot overflow.
            self.spadesuit = 2 * A / N * min(2 * self.get_pow(N_max, self.T, N * (N_max - 1) * 20) / (N_max - 1), N)
        else:
            self.spadesuit = 2 * A / N * min(2 * self.T, N)
        if not self.is_dense:
            self.eta = 1.0 / (self.T * ((self.spadesuit) ** (1/3)))

        # Standard deviation of the Gaussian perturbation added each round.
        self.sigma = 0.3 / np.sqrt(self.T)

        # Stored alongside the arguments in the JSON output by ``train``.
        self.epsilon = (self.eta ** 2) / (self.sigma ** 2) * np.min([self.spadesuit, self.clubsuit]) * self.T

        # Per-player regularization weight; degree-0 players are treated as degree 1.
        self.tau = self.harmonic_mean_degree ** (5.0 / 9.0) / self.degrees.float().clamp(min=1.0) / lN
        # Log roughly ``out_num`` times over the run, at least every round.
        self.out_freq = max(self.T // args.out_num, 1)

        print(f"T: {self.T}")
        print(f"Learning Rate: {self.eta}")
        print(f"Noise: {self.sigma}")
        print(f"Regularizer: {torch.mean(self.tau).item()}")
        print(f"Clubsuit: {self.clubsuit}")
        print(f"Spadesuit: {self.spadesuit}")
        print(f"Epsilon: {self.epsilon}")
        print("=====================================")

    def get_pow(self, x, y, bound):
        """Return ``x ** y``, or ``bound`` when ``y * log(x) > log(bound)``.

        The comparison is done in log space so a huge power is never evaluated.
        Expects positive ``x`` and ``bound``.
        """
        if np.log(x) * y > np.log(bound):
            return bound
        return x ** y

    def generate_random_graph(self, num_nodes, p, device):
        """Sample an undirected graph and store its edge bookkeeping on ``self``.

        Sets ``adjacency_matrix``, ``num_edges`` (undirected edge count),
        ``player_idx_i`` / ``player_idx_j`` (endpoints of every directed edge,
        ``2 * num_edges`` entries each) and ``degrees``.  The dense path also
        sets ``edge_idx``.  Reads ``self.is_dense`` to pick the generator.

        Args:
            num_nodes: Number of players (graph nodes).
            p: Connection probability (dense) or target density (sparse).
            device: Torch device on which the tensors are created.
        """
        self.adjacency_matrix = None
        if self.is_dense:
            # Each unordered pair is connected independently with probability p:
            # draw the strict upper triangle, threshold, then symmetrize.
            upper_triu = torch.rand(num_nodes, num_nodes).to(device)
            upper_triu = torch.triu(upper_triu, diagonal=1)

            self.adjacency_matrix = (upper_triu > 1-p).float()

            self.adjacency_matrix = self.adjacency_matrix + self.adjacency_matrix.t()

            self.num_edges = torch.sum(self.adjacency_matrix).to(int).item() // 2

            # Number directed edges 0..2E-1 in row-major order; non-edges get -1.
            self.edge_idx = self.adjacency_matrix.reshape(-1).cumsum(0) - 1
            self.edge_idx = self.edge_idx.reshape([num_nodes, num_nodes])
            self.edge_idx[self.adjacency_matrix == 0] = -1

            self.player_idx_i, self.player_idx_j = torch.where(self.edge_idx != -1)
            self.degrees = self.adjacency_matrix.sum(dim=-1)

        else:
            # Every node gets c stubs; the stubs are shuffled and the first
            # half is paired with the second half.
            c = int(p * num_nodes) * 2
            edges = [node for node in range(num_nodes) for _ in range(c)]

            edges = torch.tensor(edges, device=device)
            perm = torch.randperm(edges.shape[0], device=device)
            edges = edges[perm].reshape([2, -1])

            # Drop self-loops.
            mask = edges[0] != edges[1]
            edges = edges[:, mask]

            # Put the smaller endpoint in row 0 so (a, b) and (b, a) coincide,
            # then remove duplicate pairs.
            edges = torch.sort(edges, dim=0)[0]
            edges = torch.unique(edges, sorted=True, dim=1)

            self.num_edges = edges.shape[1]
            self.player_idx_i = torch.cat([edges[0], edges[1]], dim=0)
            self.player_idx_j = torch.cat([edges[1], edges[0]], dim=0)

            values = torch.ones(edges.size(1), device=device)
            # torch.sparse_coo_tensor replaces the deprecated, CPU-typed
            # torch.sparse.FloatTensor constructor and honours ``device``.
            self.adjacency_matrix = torch.sparse_coo_tensor(
                edges, values, (num_nodes, num_nodes), device=device
            )

            self.adjacency_matrix = self.adjacency_matrix + self.adjacency_matrix.t()

            self.degrees = self.adjacency_matrix.sum(dim=-1).to_dense()

    def compute_avg_gradients(self, strategies):
        """Return each player's degree-averaged utility gradient.

        Args:
            strategies: Tensor of shape (num_players, action_space).

        Returns:
            Tensor shaped like ``strategies``.  Row i is the sum over i's
            directed edges of ``utility_matrix @ neighbour_strategy``, divided
            by i's degree (degree 0 is treated as 1, leaving a zero row).
        """
        neighbour_strategies = strategies[self.player_idx_j].unsqueeze(-1)
        gradients = torch.bmm(self.utility_matrices, neighbour_strategies).squeeze(-1)
        gradients_sum = torch.zeros_like(strategies)
        # Scatter-add every edge's contribution onto the player owning the edge.
        gradients_sum.index_add_(0, self.player_idx_i, gradients)
        gradients_sum = gradients_sum / self.degrees.unsqueeze(-1).clamp(min=1)

        return gradients_sum

    def simplex_projection(self, v):
        """Map ``v`` through ``self.sparsemax``, which ``train`` creates.

        NOTE(review): presumably a row-wise projection onto the probability
        simplex -- confirm against the project's ``sparse_max.Sparsemax``.
        """
        return self.sparsemax(v)

    def train(self):
        """Run ``self.T`` rounds of the noisy update and optionally dump a JSON log.

        ``self.output`` starts with the non-``None`` arguments plus ``epsilon``
        and gains a ``{'time', 'exp'}`` record every ``out_freq`` rounds.  It
        is written to ``args.json`` when that option is set.
        """
        self.sparsemax = Sparsemax()
        # Running sum (not mean) of the post-noise strategies.
        self.avg_strategies = torch.zeros_like(self.strategies)

        args_dict = {k: v for k, v in vars(self.args).items() if v is not None}
        args_dict['epsilon'] = self.epsilon
        self.output = [args_dict]

        sum_gradients = 0
        sum_utilities = 0

        with torch.no_grad():
            for t in tqdm(range(1, self.T+1)):

                gradients = self.compute_avg_gradients(self.strategies)
                sum_gradients = sum_gradients + gradients

                # Gradient step shrunk by the per-player factor (1 + eta * tau),
                # then projected.
                self.strategies = self.simplex_projection((self.strategies + self.eta * gradients) / (1 + self.eta * self.tau).unsqueeze(-1))
                # Utility is scored on the updated, pre-noise strategy against
                # this round's gradients.
                sum_utilities = sum_utilities + (self.strategies * gradients).sum(dim=-1)

                noise = torch.randn_like(self.strategies) * self.sigma

                self.strategies = self.simplex_projection(self.strategies + noise)
                self.avg_strategies += self.strategies

                if t % self.out_freq == 0:
                    # Gap between the best fixed action's cumulative gradient
                    # and the cumulative utility, clamped at 0, averaged over
                    # players and divided by the number of rounds so far.
                    max_utility = torch.max(sum_gradients, dim=-1)[0]
                    exp = (max_utility - sum_utilities).clamp(min=0.0).mean().item() / t
                    self.output.append({'time': t, 'exp': exp})
                    tqdm.write(f"Episode {t-1} Exploitability: {exp}")
        if self.args.json:
            with open(self.args.json, 'w') as f:
                json.dump(self.output, f)


if __name__ == '__main__':
    # Report whether a CUDA device is visible before the experiment is built.
    cuda_visible = torch.cuda.is_available()
    print(cuda_visible)

    # Parsed options are bound to the module-level name `args`.
    args = parse_args()

    # Build the game from the options, then run the training loop.
    alg = NoRegret(args)
    alg.train()