import torch
import numpy as np
import networkx as nx
import FederatedBanditAgent as fba
import FederatedBanditEnvironment as fbe
from tqdm import tqdm
from torch.utils.data import DataLoader
import os
import pickle
import matplotlib.pyplot as plt

# Experiment configuration.
seed = 5
n_agents = 16
n_arms = 20
horizon = 3000

# NOTE: PYTHONHASHSEED is read only at interpreter start-up, so this affects
# child processes started from here on, not str hashing in this process.
os.environ["PYTHONHASHSEED"] = str(seed)

# Ask the native math libraries for a single thread. numpy and torch are
# imported at the top of the file, before these assignments, so their thread
# pools may already be configured; torch is therefore also limited
# explicitly below. NOTE(review): numpy's BLAS (e.g. OpenBLAS) reads these at
# load time -- move these lines above the imports if that limit matters.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

# cuBLAS workspace setting required for deterministic cuBLAS results.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# Allow more than one OpenMP runtime to be loaded in the same process.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Apply the intent of the settings above directly. The environment variables
# may be read too late, and agents built without an explicit `rng` presumably
# fall back to the global generators.
torch.set_num_threads(1)
torch.manual_seed(seed)
np.random.seed(seed)

# Run on the GPU when one is visible, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# Generate the synthetic loss sequence once and cache it on disk; later runs
# reuse the cached copy.
# NOTE(review): the cache file name does not encode horizon / n_agents /
# n_arms, so a stale cache is reused silently after changing them.
# NOTE(review): weights_only=False unpickles arbitrary objects -- only load a
# cache file this script wrote itself.
if not os.path.exists('SyntheticDatasets.pth'):
    train_data = fbe.SyntheticDatasets(horizon, n_agents, n_arms, np.random.default_rng(42))
    torch.save(train_data, 'SyntheticDatasets.pth')
else:
    train_data = torch.load('SyntheticDatasets.pth', weights_only=False)

# One dataset item per batch, in dataset order (no shuffling).
train_loader = DataLoader(train_data, batch_size=1, shuffle=False)
# Baseline used for regret: indexed by round later on (presumably the
# cumulative loss of the best fixed arm up to each round).
best_cum_loss = train_data.cumloss_of_best_arm()

# Cache file for the communication graph. The exact name, including its
# space, is kept so existing cache files are still found.
path = 'SyntheticDatasets ' + 'Grid' + '.gpickle'

graph = None
# Test the variable `path`, not the literal string 'path'; the latter always
# missed the cache and rebuilt and overwrote the file on every run.
if os.path.exists(path):
    # NOTE(review): pickle.load can execute arbitrary code -- only load cache
    # files this script wrote itself.
    with open(path, 'rb') as f:
        graph = pickle.load(f)
    if graph.number_of_nodes() != n_agents:
        # Stale cache from a run with a different n_agents: rebuild below.
        graph = None

if graph is None:
    side = int(np.sqrt(n_agents))
    if side * side != n_agents:
        # A square grid needs exactly one node per agent.
        raise ValueError(
            f"Grid graph needs a perfect-square number of agents, got {n_agents}"
        )
    graph = nx.grid_graph([side, side])
    with open(path, "wb") as f:
        pickle.dump(graph, f)

# Communication topology derived from `graph`: its diameter D, the gossip
# matrix returned by CommNet.max_deg_gossip (presumably max-degree weights)
# and that matrix's spectral gap.
comm_net = fba.CommNet(graph)
D = comm_net.get_diameter()
gossip_numpy, spectral_gap = comm_net.max_deg_gossip(spectral_gap=True)
gossip = torch.tensor(gossip_numpy, device=device)


def _seeded_generator():
    """Return a new torch.Generator on ``device`` seeded with ``seed``.

    Every call builds a separate generator, so each agent that takes an
    ``rng`` starts from the same seed without sharing generator state.
    """
    return torch.Generator(device=device).manual_seed(seed)


# Keyword arguments are listed in the same order as before so the helper
# functions are evaluated in the same sequence.
FedExp3_agent = fba.FedExp3(
    n_agents,
    n_arms,
    gossip,
    device=device,
    rng=_seeded_generator(),
    expr_scheduler=fba.cube_FedExp3_scheduler(spectral_gap, horizon, n_agents),
)

FedFTRL_agent = fba.FedFTRL(
    n_agents,
    n_arms,
    W=gossip,
    graph=graph,
    lr=fba.cube_FedFTRL_lr(n_agents, D),
    gamma=fba.cube_FedFTRL_gamma(D, n_arms, spectral_gap, n_agents),
    D=D,
    device=device,
    rng=_seeded_generator(),
    spectral_gap=spectral_gap,
)

FTRL_agent = fba.FTRL(
    n_agents,
    n_arms,
    W=gossip,
    graph=graph,
    lr=fba.cube_FTRL_lr(n_agents, D),
    D=D,
    device=device,
    rng=_seeded_generator(),
    spectral_gap=spectral_gap,
)

Gossip_UCB = fba.Gossip_UCB(
    n_agents,
    n_arms,
    graph=graph,
    T=horizon,
    W=gossip,
    device=device,
)

Drrb_bandit_agent = fba.DRRB_bandit(
    n_agents,
    n_arms,
    graph=graph,
    delta=0.01,
    T=horizon,
    W=gossip,
    device=device,
    spectral_gap=spectral_gap,
    D=D,
)

# Regret curves: one entry per round, holding the mean (over agents) of the
# cumulative loss minus best_cum_loss for that round.
Gossip_UCB_regrets, FTRL_regrets, FedExp3_regrets, FedFTRL_regrets, Drrb_bandit_regrets = [], [], [], [], []

# Per-agent cumulative loss of each algorithm.
Gossip_UCB_cum_loss = torch.zeros(n_agents, device=device)
FedExp3_cum_loss = torch.zeros(n_agents, device=device)
FedFTRL_cum_loss = torch.zeros(n_agents, device=device)
FTRL_cum_loss = torch.zeros(n_agents, device=device)
Drrb_bandit_cum_loss = torch.zeros(n_agents, device=device)


def _round_loss(mean_loss, actions):
    """Return ``mean_loss @ actions.float().T``, the loss charged this round.

    Args:
        mean_loss: the round's loss matrix averaged over its first dimension.
        actions: an algorithm's action tensor for the round; it is cast to
            float so it can be multiplied with ``mean_loss``.

    Returns:
        The product tensor, added to that algorithm's per-agent cumulative
        loss.
    """
    # assumes actions is (n_agents, n_arms) with one indicator row per agent,
    # giving one loss per agent -- TODO confirm against fba.
    return torch.matmul(mean_loss, torch.transpose(actions.float(), 1, 0))


rounds = len(train_loader)
for i, loss_matrix in tqdm(enumerate(train_loader), total=rounds):
    # Drop the size-1 batch dimension added by DataLoader(batch_size=1).
    L_t = torch.squeeze(loss_matrix, 0).to(device)
    # Computed once and shared by all five algorithms this round.
    mean_loss = torch.mean(L_t, dim=0)

    # Every algorithm chooses its actions before any of them is updated.
    FedExp3_actions, FedExp3_probs = FedExp3_agent.action()
    FedFTRL_actions, FedFTRL_probs = FedFTRL_agent.action()
    FTRL_actions, FTRL_probs = FTRL_agent.action()
    Gossip_UCB_actions = Gossip_UCB.action()
    Drrb_bandit_actions = Drrb_bandit_agent.action()

    FedExp3_cum_loss += _round_loss(mean_loss, FedExp3_actions)
    FedFTRL_cum_loss += _round_loss(mean_loss, FedFTRL_actions)
    FTRL_cum_loss += _round_loss(mean_loss, FTRL_actions)
    Gossip_UCB_cum_loss += _round_loss(mean_loss, Gossip_UCB_actions)
    Drrb_bandit_cum_loss += _round_loss(mean_loss, Drrb_bandit_actions)

    # The first three learners are fed the losses L_t; Gossip_UCB and DRRB are
    # fed 1 - L_t (presumably as rewards).
    FedExp3_agent.update(L_t, FedExp3_actions, FedExp3_probs)
    FedFTRL_agent.update(L_t, FedFTRL_actions, FedFTRL_probs)
    FTRL_agent.update(L_t, FTRL_actions, FTRL_probs)
    Gossip_UCB.update(1 - L_t, Gossip_UCB_actions)
    Drrb_bandit_agent.update(1 - L_t, Drrb_bandit_actions)

    FedExp3_regrets.append(torch.mean(FedExp3_cum_loss).cpu().numpy() - best_cum_loss[i])
    FedFTRL_regrets.append(torch.mean(FedFTRL_cum_loss).cpu().numpy() - best_cum_loss[i])
    FTRL_regrets.append(torch.mean(FTRL_cum_loss).cpu().numpy() - best_cum_loss[i])
    Gossip_UCB_regrets.append(torch.mean(Gossip_UCB_cum_loss).cpu().numpy() - best_cum_loss[i])
    Drrb_bandit_regrets.append(torch.mean(Drrb_bandit_cum_loss).cpu().numpy() - best_cum_loss[i])

# Plot each algorithm's average cumulative regret against the round index
# (1-based, one point per round of the training loop).
xs = np.arange(1, rounds + 1)
plt.figure(figsize=(10, 6))
plt.plot(xs, FedExp3_regrets, 'b-', linewidth=2, label='FedExp3')
plt.plot(xs, FedFTRL_regrets, 'r-', linewidth=2, label='FedFTRL')
plt.plot(xs, FTRL_regrets, 'g-', linewidth=2, label='FTRL')
plt.plot(xs, Gossip_UCB_regrets, 'y-', linewidth=2, label='Gossip_UCB')
# Label corrected from the typo 'Drbb-bandit' to match the DRRB algorithm name.
plt.plot(xs, Drrb_bandit_regrets, 'c-', linewidth=2, label='DRRB-bandit')
plt.xlabel('Rounds', fontsize=12)
plt.ylabel('Average Regret', fontsize=12)
# Follow the actual number of rounds instead of a hard-coded 3000.
plt.xlim(0, rounds)
# NOTE(review): fixed y-range; curves above 120 are clipped.
plt.ylim(0, 120)
plt.title('average cumulative regret\nenv: Synthetic Datasets, graph: Grid', fontsize=16)
plt.legend(fontsize=12)
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()