import json
import os
import math
import pandas as pd
import numpy as np

import torch
import torch.nn.functional as F

from tqdm import tqdm

def compute_g2(train_clusters, delta, K, C):
    """Return C * ((sqrt(2) + 1) * sqrt(L) * S / N + 4 * n_T * L / N).

    Here L = log(2 * K / delta), n_T is the number of clusters, N is the
    total number of members over all clusters, and S is the sum of the
    square roots of the individual cluster sizes.

    Args:
        train_clusters: Mapping from cluster id to a sized collection of
            members; only the length of each value is used.
        delta: Divisor inside the logarithm (must be non-zero).
        K: Multiplier inside the logarithm.
        C: Overall scale factor applied to the result.

    Returns:
        The value of the expression above.

    Raises:
        ZeroDivisionError: If the clusters contain no members in total.
    """
    cluster_sizes = [len(members) for members in train_clusters.values()]
    num_clusters = len(cluster_sizes)
    total_members = sum(cluster_sizes)

    # Accumulate the square roots one by one, in cluster order, starting
    # from 0, so the floating-point rounding sequence stays fixed.
    sqrt_size_total = 0
    for size in cluster_sizes:
        sqrt_size_total += math.sqrt(size)

    log_factor = math.log(2 * K / delta)
    ratio = num_clusters * log_factor / total_members

    sqrt_part = (math.sqrt(2)+1) * math.sqrt(log_factor) * sqrt_size_total / total_members
    return C * (sqrt_part + 4 * ratio)

def compute_ab(train_clusters, delta, K):
    """Return the tuple (a, b, ssq) derived from the cluster sizes.

    With N the total number of members and N_i the size of cluster i:

        ssq = sum_i(N_i ** 2) / N ** 2
        a   = ssq / 2 + sqrt(2 / N * log(2 * K / delta))
        b   = 1 / (2 * N)

    Args:
        train_clusters: Mapping from cluster id to a sized collection of
            members; only the length of each value is used.
        delta: Divisor inside the logarithm (must be non-zero).
        K: Multiplier inside the logarithm.

    Returns:
        A 3-tuple ``(a, b, ssq)`` as defined above.

    Raises:
        ZeroDivisionError: If the clusters contain no members in total.
    """
    sizes = [len(members) for members in train_clusters.values()]
    total = sum(sizes)
    squared_total = sum(size * size for size in sizes)

    # Half of the normalised sum of squared sizes; doubling it gives ssq.
    half_ssq = squared_total / (2 * total*total)
    ssq = half_ssq*2

    a = half_ssq + math.sqrt(2/total * math.log(2*K/delta))
    b = 1/(2*total)
    return a, b, ssq

def compute_g(train_clusters, delta, K, C):
    """Return C * ((sqrt(2) + 1) * sqrt(t) + 2 * t), t = n_T * log(2K/delta) / N.

    n_T is the number of clusters and N is the total number of members
    over all clusters.

    Args:
        train_clusters: Mapping from cluster id to a sized collection of
            members; only the length of each value is used.
        delta: Divisor inside the logarithm (must be non-zero).
        K: Multiplier inside the logarithm.
        C: Overall scale factor applied to the result.

    Returns:
        The value of the expression above.

    Raises:
        ZeroDivisionError: If the clusters contain no members in total.
    """
    cluster_count = len(train_clusters)
    total_members = sum(len(members) for members in train_clusters.values())

    log_factor = math.log(2 * K / delta)
    ratio = cluster_count * log_factor / total_members

    root_coeff = math.sqrt(2)+1
    return C * (root_coeff * math.sqrt(ratio) + 2 * ratio)
    
# Column accumulators: one value is appended to each list for every
# (seed, K, gamma^-alpha, alpha) combination visited by the loops below.
seed_s = []
K_s = []
delta_s = []
alpha_s = []
gamma_s = []
gamma_pow_neg_alpha_s = []
a_s = []
b_s = []
g_s = []
ssq_s = []
u_s = []
term2_s = []
g2_s = []

delta = 0.01
C = 1
for seed in range(42, 47):
    for K in [100, 200, 300, 1000, 5000, 10000]:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(f"imagenet-train-clusters/seed_{seed}/{K}/train_group.json",
                  encoding="utf-8") as f:
            train_clusters = json.load(f)

        # These quantities depend only on (seed, K), so compute them once
        # per clustering rather than once per output row.
        g = compute_g(train_clusters, delta/2, K, C)
        g2 = compute_g2(train_clusters, delta/2, K, C)
        # Pass the K of the current clustering. This argument used to be the
        # literal 200, which made the log(2K/delta) term inside `a` identical
        # for every K in the sweep.
        # NOTE(review): compute_ab receives the full `delta` while
        # compute_g / compute_g2 receive `delta/2` — confirm this is intended.
        a, b, ssq = compute_ab(train_clusters, delta=delta, K=K)

        for gamma_pow_neg_alpha in [0.01, 0.02, 0.03, 0.04]:
            # Independent of alpha, so evaluate it once per gamma^-alpha.
            log_inv = math.log(1 / gamma_pow_neg_alpha)

            for alpha in range(5, 105, 5):
                # Recover gamma from the fixed value of gamma^-alpha:
                # gamma = (gamma^-alpha) ** (-1 / alpha).
                gamma = 1 / math.pow(gamma_pow_neg_alpha, 1/alpha)
                u = a * (gamma*gamma) + b*gamma
                term2 = C * math.sqrt(u * log_inv)

                seed_s.append(seed)
                K_s.append(K)
                delta_s.append(delta)
                alpha_s.append(alpha)
                gamma_s.append(gamma)
                gamma_pow_neg_alpha_s.append(gamma_pow_neg_alpha)
                a_s.append(a)
                b_s.append(b)
                g_s.append(g)
                ssq_s.append(ssq)
                u_s.append(u)
                term2_s.append(term2)
                g2_s.append(g2)

# Assemble one row per visited combination; the two "term2+..." columns are
# element-wise sums of the corresponding columns.
df = pd.DataFrame.from_dict({
    "seed": seed_s,
    "K": K_s,
    "delta": delta_s,
    "alpha": alpha_s,
    "gamma": gamma_s,
    "gamma^-alpha": gamma_pow_neg_alpha_s,
    "a": a_s,
    "b": b_s,
    "g": g_s,
    "ssq_ni/n": ssq_s,
    "u": u_s,
    "term2": term2_s,
    "term2+g": np.array(term2_s) + np.array(g_s),
    "g2": g2_s,
    "term2+g2": np.array(term2_s) + np.array(g2_s),
})

df.to_csv('total_results_neurips.csv', index=False)