import os
import torch
import numpy as np
import bounds.get_pac_bounds as get_bounds
import yaml



### Hyperparameters
# --- Mode switches ---------------------------------------------------------
# True: build the per-rank statistics YAML from the raw text dumps.
# False: read the cached YAML and compute the bounds.
eval_after_training = False
# "acc" selects the top-k accuracy path; any other value takes the bpd path.
bound_type = "bpd"

# --- Paths -----------------------------------------------------------------
# Directory joined with every file name this script reads or writes.
best_checkpoint_path = "[TODO]"

# --- Data layout -----------------------------------------------------------
batch_size = 8       # sequences per logged line
total_ddp = 8        # number of per-rank dump files
seq_length = 1024.0  # kept as a float; used as a divisor and in the full-batch check
vocab_size = 50257   # used for the alpha / vocab_size smoothing term and log(vocab_size)
intrinsic_dim = 50000  # not referenced elsewhere in the visible code

# --- Smoothing -------------------------------------------------------------
# Mixing weights alpha in (1 - alpha) * p + alpha / vocab_size.
# alpha = 0 leaves the probabilities unsmoothed.
alpha_array = [
    0.,
    0.0001,
    0.001,
    0.005,
    0.01,
    0.05,
    0.1,
    0.5,
]
###############

# This section runs in one of two modes, selected by `eval_after_training`:
#   * True  -> read the per-rank text dumps line by line, maintain running
#              per-sequence means and cache them in a YAML file;
#   * False -> read the cached YAML and turn the statistics into Catoni bounds.
# Within each mode, `bound_type == "acc"` selects the top-k accuracy path and
# any other value selects the bpd (smoothed negative log-probability) path.

# Tracked top-k cut-offs: k = 1..10, plus 50 and 100.
_TOP_K_VALUES = tuple(range(1, 10 + 1)) + (50, 100)
# Number of input lines between progress prints / intermediate YAML dumps.
_PROGRESS_EVERY = 1000


def _split_list_line(line):
    """Return the comma-separated fields of one ``[a, b, ...]`` text line.

    Surrounding whitespace and the enclosing brackets are removed regardless
    of the line terminator, so a final line without a newline (or a CRLF
    line) parses the same as any other.  A blank line or an empty ``[]``
    yields an empty list.
    """
    body = line.strip().strip("[]").strip()
    return body.split(",") if body else []


def _fold_batch_into_mean(stats, key, batch_total):
    """Update the running per-sequence mean ``stats[key]`` with one batch.

    ``batch_total`` is the metric summed over the ``batch_size`` sequences of
    the batch.  ``stats["n_train"]`` is the number of sequences already
    folded in; the caller advances it once per batch, after every key has
    been updated.  The result is stored as a plain ``float`` so that
    ``yaml.safe_dump`` can serialise it.
    """
    seen = stats["n_train"]
    stats[key] = float((stats[key] * seen + batch_total) / (seen + batch_size))


def _save_yaml(stats, basename):
    """Write ``stats`` as YAML to ``basename`` inside ``best_checkpoint_path``."""
    with open(os.path.join(best_checkpoint_path, basename), 'w') as f:
        yaml.safe_dump(stats, f, indent=2)


def _load_prefix_message_len():
    """Return ``prefix_message_len`` stored in ``quant_ckpt.pt``.

    Only this one field is read, so the checkpoint is mapped to the CPU
    rather than to whatever device it was saved from.
    """
    ckpt = torch.load(os.path.join(best_checkpoint_path, "quant_ckpt.pt"),
                      map_location="cpu")
    return ckpt['prefix_message_len']


if eval_after_training:
    # A line is a complete batch when it holds seq_length values for each of
    # the batch_size sequences; any other line is skipped.
    values_per_batch = batch_size * seq_length

    if bound_type == "acc":
        all_dict = {}

        for ddp_rank in range(total_ddp):
            dict_top_k = {f'top_{k}_acc': 0 for k in _TOP_K_VALUES}
            dict_top_k["n_train"] = 0
            # Register the rank up front so it is present in the YAML even if
            # its dump file contains no complete batch.
            all_dict[f"dict_top_k_ddp{ddp_rank}"] = dict_top_k

            filename = os.path.join(best_checkpoint_path, f"train_indices_{ddp_rank}_top_k_indices.txt")
            with open(filename) as file:
                for i, line in enumerate(file):
                    fields = _split_list_line(line)
                    if len(fields) == values_per_batch:
                        idx = torch.tensor([int(x) for x in fields])
                        # torch.unique returns the distinct values in sorted
                        # order together with their occurrence counts.
                        _, indices_counts = torch.unique(idx, return_counts=True)

                        for k in _TOP_K_VALUES:
                            # NOTE(review): this sums the counts of the k smallest
                            # distinct values present in the batch.  That equals
                            # "number of entries < k" only if every value
                            # 0..k-1 occurs in the batch -- confirm against the
                            # code that writes these files.
                            batch_total = indices_counts[:k].sum().item() / seq_length
                            _fold_batch_into_mean(dict_top_k, f'top_{k}_acc', batch_total)

                        dict_top_k["n_train"] += batch_size

                    if i % _PROGRESS_EVERY == 0:
                        # Periodic checkpoint instead of rewriting the file on
                        # every single input line.
                        _save_yaml(all_dict, 'all_dict_top_k_acc.yml')
                        print("All dict: \n", all_dict)

            # Final state for this rank.
            _save_yaml(all_dict, 'all_dict_top_k_acc.yml')
    else:
        all_dict = {}
        for ddp_rank in range(total_ddp):
            dict_bdp = {f'bpd_alpha_{alpha}': 0 for alpha in alpha_array}
            dict_bdp["n_train"] = 0
            all_dict[f"dict_bdp_ddp{ddp_rank}"] = dict_bdp

            filename = os.path.join(best_checkpoint_path, f"train_indices_{ddp_rank}_selected_log_prob_scores.txt")
            with open(filename) as file:
                for i, line in enumerate(file):
                    fields = _split_list_line(line)
                    if len(fields) != values_per_batch:
                        continue

                    # Parse and exponentiate once per line; every alpha reuses it.
                    probs = np.exp(np.array([float(x) for x in fields]))
                    for alpha in alpha_array:
                        # Log of the probability mixed with a uniform
                        # distribution over the vocabulary.
                        smoothed_log_probs = np.log((1 - alpha) * probs + alpha / vocab_size)
                        batch_total = -smoothed_log_probs.sum() / seq_length
                        _fold_batch_into_mean(dict_bdp, f'bpd_alpha_{alpha}', batch_total)

                    dict_bdp["n_train"] += batch_size

                    if i % _PROGRESS_EVERY == 0:
                        _save_yaml(all_dict, 'all_dict_bdp.yml')
                        print("All dict: \n", all_dict)

            # Final state for this rank.
            _save_yaml(all_dict, 'all_dict_bdp.yml')

else:
    if bound_type == "acc":
        with open(os.path.join(best_checkpoint_path, 'all_dict_top_k_acc.yml'), 'r') as f:
            all_dict = yaml.safe_load(f)

        metric_keys = [f'top_{k}_acc' for k in _TOP_K_VALUES]

        # Pool the per-rank means, weighting each rank by its sequence count.
        dict_top_k = {key: 0 for key in metric_keys}
        dict_top_k["n_train"] = 0
        for i in range(total_ddp):
            rank_stats = all_dict[f'dict_top_k_ddp{i}']
            for key in metric_keys:
                dict_top_k[key] += rank_stats[key] * rank_stats["n_train"]
            dict_top_k["n_train"] += rank_stats["n_train"]

        if dict_top_k["n_train"] == 0:
            raise ValueError("all_dict_top_k_acc.yml contains no training samples (n_train == 0)")

        for key in metric_keys:
            dict_top_k[key] /= dict_top_k["n_train"]

        misc_extra_bits = 4
        # Loaded once; it does not depend on k.
        prefix_message_len = _load_prefix_message_len()
        # Bits -> nats.
        divergence = (prefix_message_len + misc_extra_bits) * np.log(2)
        train_size = dict_top_k['n_train']

        # Only the accuracy entries get a bound; "n_train" is a sample count,
        # not an accuracy, and must not be fed in as one.
        top_k_bounds = {}
        for key in metric_keys:
            quant_train_acc = dict_top_k[key]
            top_k_bounds[key] = get_bounds.compute_catoni_bound(train_error=1. - quant_train_acc,
                                                                divergence=divergence,
                                                                sample_size=train_size)

        print("top_k_bounds: ", top_k_bounds)

    else:
        # NOTE(review): only rank 0's statistics are used here, whereas the
        # "acc" branch pools all ranks -- confirm this is intended.
        ddp_rank = 0

        bounds = {}

        misc_extra_bits = 8

        with open(os.path.join(best_checkpoint_path, 'all_dict_bdp.yml'), 'r') as f:
            all_dict = yaml.safe_load(f)

        prefix_message_len = _load_prefix_message_len()

        rank_stats = all_dict[f"dict_bdp_ddp{ddp_rank}"]
        train_size = rank_stats["n_train"]
        if not train_size:
            raise ValueError("all_dict_bdp.yml contains no training samples (n_train == 0)")

        # Bits -> nats; identical for every alpha.
        divergence = (prefix_message_len + misc_extra_bits) * np.log(2)
        random_guess_error = np.log(vocab_size)

        for alpha in alpha_array:
            if not alpha > 0:
                # log(vocab_size / alpha) is undefined for alpha == 0.
                continue

            # The smoothed probability is at least alpha / vocab_size, so the
            # per-token loss is at most log(vocab_size / alpha); dividing by it
            # rescales the training error to [0, 1] before the bound is applied.
            normalizing_factor = np.log(vocab_size/alpha)

            train_error = rank_stats[f"bpd_alpha_{alpha}"] / normalizing_factor

            bound = get_bounds.compute_catoni_bound(train_error=train_error, divergence=divergence,
                                                    sample_size=train_size) * normalizing_factor

            bounds[f"bound_alpha_{alpha}"] = bound

            print(f"For alpha = {alpha}, the error bound is: {bound} vs. the random guess error is: {random_guess_error}")

            if bound < random_guess_error:
                print("HENCE, THE BOUND IS NON-VACUOUS")
            else:
                print("HENCE, THE BOUND IS VACUOUS")
        

    
                
