import numpy as np
from tqdm import tqdm
from typing import Dict, List
import torch
import pickle
import argparse
import os

from model import NPGame

def parse_args(argv=None):
    """Parse the command-line options of the N-player game experiment.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Leaving it as ``None`` keeps the usual
            command-line behaviour.

    Returns:
        argparse.Namespace: The parsed options. ``n_local_step``, ``lam`` and
        ``lr`` are ``None`` when the corresponding flag is omitted; the two
        list-valued flags are an empty list when given without any value.
    """
    parser = argparse.ArgumentParser(description="N-Player Game Experiment")

    # Run configuration.
    parser.add_argument('--gpu', type=str, default='0', help='GPU number to use')
    parser.add_argument('--n_comm', type=int, default=int(1e2), help='Number of communication rounds')
    parser.add_argument('--n_local_step', type=int, nargs='*', default=None,
                        help='List of local steps, e.g. --n_local_step 1 5 20')
    parser.add_argument('--random_seed', type=int, default=1024, help='Random seed')
    parser.add_argument('--n_dim', type=int, default=10, help='Dimension of the problem')
    parser.add_argument('--n_data', type=int, default=100, help='Number of data samples used to generate objective functions')
    parser.add_argument('--n_batch', type=int, default=100, help='Minibatch size')
    parser.add_argument('--n_trial', type=int, default=1, help='Number of trials (for stochastic setting)')
    parser.add_argument('--n_player', type=int, default=5, help='Number of game players')
    parser.add_argument('--lam', type=float, nargs='*', default=None,
                        help='List of lambda values, e.g. --lam 0 200')
    parser.add_argument('--lr', type=float, default=None, help='SGD stepsize')
    parser.add_argument('--output_file', type=str, default='results/output.pkl', help='Output filename for results')

    # Constants forwarded to the game constructor.
    parser.add_argument('--L_A', type=float, default=1.0, help='Lipschitz constant of A matrix')
    parser.add_argument('--mu_A', type=float, default=0.01, help='Strong convexity constant of A matrix')
    parser.add_argument('--L_B', type=float, default=1.0, help='Lipschitz constant of B matrix')
    parser.add_argument('--mu_B', type=float, default=0.0, help='Strong convexity constant of B matrix')

    return parser.parse_args(argv)


def _detach(value):
    """Return *value* detached from the autograd graph when it is a tensor."""
    return value.detach() if isinstance(value, torch.Tensor) else value


def _default_local_steps(ell, Lmax, lam):
    """Return the local-step count used when --n_local_step is not supplied."""
    return int(max(np.sqrt(4 * ell / lam), 16 * (1 + Lmax / lam) ** 2)) + 1


def _local_update(game, player, x, lam, lr, n_local_step, index):
    """Run one player's local gradient steps and return that player's new row.

    Starting from a copy of ``x``, each step differentiates the player's
    objective plus ``(lam / 2) * ||x_local - x||^2`` and moves only row
    ``player``. ``index`` is either ``None`` (no minibatch index is passed to
    the objective) or a list holding one index list per local step.
    """
    x_local = x.clone().detach().requires_grad_(True)
    x_init = x.clone().detach()

    for local_step in range(n_local_step):
        x_local.grad = None
        if index is not None:
            loss = game.objective_function(player, x_local, index=index[local_step])
        else:
            loss = game.objective_function(player, x_local)
        # Out-of-place add: never mutates the tensor returned by the game.
        loss = loss + (lam / 2) * torch.norm(x_local - x_init, p=2) ** 2
        loss.backward()
        with torch.no_grad():
            x_local[player] -= lr * x_local.grad[player]

    return x_local[player].detach()


def _run_trial(game, x_start, init_dist, lam, n_local_step, lr, n_comm, n_batch, n_data):
    """Run ``n_comm`` communication rounds and return the relative errors.

    The returned list has ``n_comm + 1`` entries: the initial ratio followed
    by ``game.opt_dist(x) / init_dist`` after every round.
    """
    n_player = x_start.shape[0]
    trial_errors = [init_dist / init_dist]
    x = x_start.clone().detach().requires_grad_(True)

    for _ in tqdm(range(n_comm)):
        # One minibatch per local step, drawn once per round and shared by all players.
        index = None
        if n_batch < n_data:
            index = [list(np.random.choice(n_data, n_batch, replace=False)) for _ in range(n_local_step)]

        # Every player starts from the same pre-round iterate x.
        x_new = torch.stack([
            _local_update(game, player, x, lam, lr, n_local_step, index)
            for player in range(n_player)
        ], dim=0)

        with torch.no_grad():
            x.copy_(x_new)

        # Detach so the recorded metric does not keep an autograd graph alive.
        trial_errors.append(_detach(game.opt_dist(x)) / init_dist)

    return trial_errors


def main():
    """Run the N-player game experiment and pickle the relative-error curves.

    For every lambda and every local-step count the experiment is repeated
    ``--n_trial`` times. The results are written to ``--output_file`` as a
    dict with keys ``"relative_errors"``, ``"Lmax"``, ``"mu"`` and ``"ell"``.

    Raises:
        ValueError: If a user-supplied lambda is not positive while the local
            step count or the step size has to be derived from it.
    """
    import warnings

    args = parse_args()

    # An empty list (flag given without values) is treated like an omitted flag.
    use_default_steps = not args.n_local_step
    use_default_lr = args.lr is None

    # Fail before any computation: both derived quantities divide by lambda.
    if args.lam and (use_default_steps or use_default_lr):
        non_positive = [lam for lam in args.lam if lam <= 0]
        if non_positive:
            raise ValueError(
                f"--lam values {non_positive} are not positive; pass both "
                f"--n_local_step and --lr explicitly to use them"
            )

    device = torch.device('cuda:' + args.gpu if torch.cuda.is_available() else 'cpu')

    # Create the output directory up front so a bad path cannot discard finished
    # results. dirname is '' for a bare filename, which os.makedirs rejects.
    output_file = args.output_file
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Seed
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    # Initialize game
    game = NPGame(args.n_player, args.n_dim, args.n_data,
                  args.L_A, args.mu_A, args.L_B, args.mu_B, device=device)
    Lmax, mu, L = game.Lmax, game.mu, game.L
    ell = (L**2)/mu

    print(f'ell = {ell}')
    print(f'mu = {mu}')
    print(f'Lmax = {Lmax}')
    print(f'L = {L}')
    print(f'Condition Number = {ell/mu}')

    x_start = torch.stack([
        torch.randn(args.n_dim, requires_grad=True) for _ in range(args.n_player)
    ], dim=0).to(device)
    init_dist = _detach(game.opt_dist(x_start))

    relative_errors: Dict[float, Dict[int, List[torch.Tensor]]] = {}

    lam_values = args.lam
    if not lam_values:
        # Plain float so the result key is an ordinary Python number.
        lam_values = [float(4 * ell + 4 * Lmax * np.sqrt(ell/mu))]

    for lam in tqdm(lam_values):
        relative_errors[lam] = {}

        if use_default_steps:
            local_step_choices = [_default_local_steps(ell, Lmax, lam)]
        else:
            local_step_choices = args.n_local_step

        for n_local_step in tqdm(local_step_choices, desc=f"Lambda {lam}"):
            relative_errors[lam][n_local_step] = []

            if use_default_lr:
                # Step size choice in Corollary 5.5
                lr = float(2 * np.log(n_local_step) / lam / n_local_step)
                if lr == 0:
                    # log(1) == 0, so a single local step gives a zero step size.
                    warnings.warn(
                        f"default step size is 0 for n_local_step={n_local_step}; "
                        f"the iterates will not move (pass --lr to override)"
                    )
            else:
                lr = args.lr

            for _ in tqdm(range(args.n_trial), leave=False, desc=f"Tau {n_local_step}"):
                trial_errors = _run_trial(
                    game, x_start, init_dist, lam, n_local_step, lr,
                    args.n_comm, args.n_batch, args.n_data,
                )

                print(f"Lambda: {lam}, Tau: {n_local_step} Completed with Error {trial_errors[-1]}")
                relative_errors[lam][n_local_step].append(torch.tensor(trial_errors))

    # Save all results together
    data_to_save = {
        "relative_errors": relative_errors,
        "Lmax": Lmax,
        "mu": mu,
        "ell": ell
    }

    with open(output_file, 'wb') as f:
        pickle.dump(data_to_save, f)

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
