import os
import time
import math
import pickle
from contextlib import nullcontext

import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

import bounds.get_pac_bounds as get_bounds
import yaml

from model import GPTConfig, GPT
import loralib as lora
import datetime
import itertools
from projectors import create_intrinsic_model
import bounds.quantize_fns as quantize
import bounds.get_pac_bounds as get_bounds

from torch.optim import SGD, Adam
from tqdm.auto import tqdm

import projectors

import ast

from datasets import load_dataset
from transformers import AutoTokenizer, PretrainedConfig, default_data_collator, DataCollatorWithPadding, AutoConfig, AutoModelForSequenceClassification, GPT2ForSequenceClassification
from peft import LoraConfig, get_peft_model, TaskType

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader



# ----------------------------------------------------------------------------
# Default configuration. Every module-level name of type int/float/bool/str/list
# defined before `config_keys` is computed (section 1 below) can be overridden by
# configurator.py and is recorded in `config`, which is logged to W&B and stored
# in the quantized checkpoint. Names whose default is None (best_checkpoint_path,
# task_name) and dict-valued names (task_to_keys) are NOT collected into
# `config_keys`.
# NOTE: several settings in this section (e.g. eval_interval, log_interval,
# eval_only, always_save_checkpoint, init_from, dataset, max_iters, grad_clip and
# the learning-rate decay schedule) are not read anywhere in this evaluation
# script; they only show up in the logged/saved `config`.
# ----------------------------------------------------------------------------

# Set EOT tokens for identifying a single document within openwebtext.
# {'<|endoftext|>': 50256}
EOT_TOKEN = 50256
create_new_output_dir = False

eval_interval = 25
log_interval = 1
# Maximum number of batches per split used by estimate_loss(); note that it is
# reassigned to 50 before estimate_loss() is first called.
eval_iters = 200
eval_only = False # if True, script exits right after the first eval
always_save_checkpoint = True # if True, always save a checkpoint after each eval
init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
eval_metrics_after_training = False # if True, evaluate metrics for LLM pretraining bounds
# Directory containing the trained artifacts (trainable_initparams.pt, names.pt,
# subspace_params_saved.pt). All outputs of this script (quantized checkpoint,
# metrics/bounds YAML, text files) are written here as well. Must be set.
best_checkpoint_path = None
# wandb logging
wandb_log = False # disabled by default
wandb_project = 'owt'
wandb_run_name = 'gpt2' # 'run' + str(time.time())
# data
dataset = 'openwebtext'
# gradient_accumulation_steps / batch_size / block_size only feed the
# informational tokens_per_iter print (and model_args); the GLUE DataLoaders
# below use a fixed batch size of 8.
gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
block_size = 1024
# model
# nanoGPT-style architecture settings: only recorded in `model_args`; the model
# actually evaluated is the Hugging Face GPT-2 sequence classifier.
n_layer = 12
n_head = 12
n_embd = 768
dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
bias = False # do we use bias inside LayerNorm and Linear layers?
use_mergedlinear = False
# adamw optimizer
learning_rate = 6e-4 # max learning rate
max_iters = 600000 # total number of training iterations
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
correct_bias = False
# (sic) spelling kept on purpose: the optimizer construction references this name.
adam_epislon = 1e-8
no_decay_bias = True
# learning rate decay settings
decay_lr = True # whether to decay the learning rate
warmup_iters = 2000 # how many steps to warm up for
lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
# TODO: check lora setting

# ===== LoRA Settings ===== #
# These flags/ranks describe the nanoGPT-style LoRA layout and are recorded in
# `model_args`; the PEFT adapter actually attached to GPT-2 uses its own
# hard-coded LoraConfig. `use_lora` additionally decides whether a LoRA-only
# state dict is saved with the quantized checkpoint.
# 0. Basics
use_lora=False # true if any LoRA layer is used
lora_alpha = 32
lora_dropout = 0.1
# fan_in_fan_out=True
# merge_weights=False

# 1. attention linear layer
attention_linear_use_lora = False
attention_linear_lora_r = 4
# attention_linear_enable_lora=(True, False, True)

# 2. linear head
linear_head_lora_r = 4
linear_head_enable_lora = False

# 3. MLP
MLP_lora_r = 1
MLP_enable_lora = False

# 4. token embedding
token_embedding_lora_r = 1
token_embedding_enable_lora = False

# 5. positional embedding
positional_embedding_lora_r = 1
positional_embedding_enable_lora = False

## Quantization parameters 
# > 0: wrap the model with create_intrinsic_model and quantize only its
# `subspace_params` vector; 0: quantize all trainable parameters directly.
intrinsic_dim=100000
use_kmeans=False
# Only used by the disabled quantization-aware fine-tuning path.
quant_lr=5e-5

## Bounds params
# GPT-2 vocabulary size; also enters the delta of the bpd bounds.
vocab_size = 50257
# Not read in this script.
random_sampling_with_replacement = True
# bound_samples = 10000
# Minimum number of training examples scored for the bound: the evaluation
# loop stops once at least this many have been seen.
sample_size = 10000
# Number of quantization levels; also part of every output file name.
levels = 11
# > 0 together with intrinsic_dim > 0 raises (QAT disabled); also part of
# every output file name.
max_quant_iters = 0
# `data_size` in corrected_bound().
data_size = 8823811
# Default total failure probability (epsilon) of the bound.
failure_prob=0.05

# When share_rank is set, rank_lora overrides both the attention and the
# linear-head LoRA ranks.
share_rank=True
rank_lora=4

# Recorded in model_args / the run name; optimize_alpha also selects how the
# bpd bounds are computed, extended_proj selects the intrinsic projection mode.
optimize_alpha=0
apply_rope=0
extended_proj=0

test_old_ckpts=0

# Only affects wandb_run_name.
alpha_warmup = 0

# vera
use_vera = False

### for perturbations 
# Not read in this script.
perturb_word_order_window_size=0
apply_perturb_eval = 0

# GLUE
# task -> (first text column, second text column or None for single-sentence tasks)
task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}


# GLUE task to evaluate; must be a key of task_to_keys (the None default does
# not work). NOTE(review): None values are not collected into config_keys —
# confirm configurator.py can still override this name.
task_name = None
# 1: start from the pretrained 'gpt2' weights; 0: randomly initialised classifier.
load_pretrained_model = 1



################ 1. Getting the config ###################

# DDP settings
# Not used: this script never initialises a process group.
backend = 'nccl' # 'nccl', 'gloo', etc.
# system
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
# NOTE: shadows the builtin `compile`; not used in this script (no torch.compile call).
compile = True # use PyTorch 2.0 to compile the model to be faster
# -----------------------------------------------------------------------------
# Snapshot the names of all overridable defaults defined so far.
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str, list))]
# SECURITY NOTE: executes configurator.py from the current working directory with
# full interpreter privileges — only run with a trusted file.
exec(open('configurator.py').read()) # overrides from command line or config file
# `config` is a snapshot taken here; later reassignments must update it
# explicitly (as the rank/flag normalisation below does).
config = {k: globals()[k] for k in config_keys} # will be useful for logging

# This evaluation always runs as a single process: DDP is never initialised, so
# the distributed bookkeeping is pinned to its trivial values.
ddp, ddp_rank, ddp_world_size = False, 0, 1
master_process, seed_offset = True, 0

# Tokens per optimizer step in the LM-training setup; printed for reference only.
tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
print(f"tokens per iteration will be: {tokens_per_iter:,}")

# Tie the linear-head LoRA rank to the attention-layer rank.
config["linear_head_lora_r"] = attention_linear_lora_r
linear_head_lora_r = attention_linear_lora_r

# With share_rank, a single `rank_lora` overrides both ranks. This must run
# BEFORE the zero-rank checks below, so those checks see the ranks that are
# actually used: otherwise share_rank with rank_lora == 0 would leave the LoRA
# flags enabled at rank 0, and a configured rank of 0 would disable LoRA even
# though rank_lora replaces it.
if share_rank:
    config["attention_linear_lora_r"] = rank_lora
    config["linear_head_lora_r"] = rank_lora
    attention_linear_lora_r = rank_lora
    linear_head_lora_r = rank_lora

# A rank of 0 means "no LoRA on this layer": disable the matching flag so the
# recorded configuration stays consistent with the effective ranks.
if attention_linear_lora_r == 0:
    attention_linear_use_lora = False
    config["attention_linear_use_lora"] = attention_linear_use_lora

if linear_head_lora_r == 0:
    linear_head_enable_lora = False
    config["linear_head_enable_lora"] = linear_head_enable_lora

# Clear the global switch when no layer uses LoRA (this never turns it on).
if not (attention_linear_use_lora or linear_head_enable_lora or MLP_enable_lora
        or token_embedding_enable_lora or positional_embedding_enable_lora):
    use_lora = False
    config["use_lora"] = use_lora

# extended_proj forces biases on and selects the "film" projection mode; with
# optimize_alpha it also shrinks the micro-batch size to 4.
# NOTE(review): `intrinsic_mode` is not passed to create_intrinsic_model below,
# which hard-codes "rdkronqr" — confirm which mode the checkpoint was trained with.
if extended_proj:
    config["bias"] = True
    bias = True
    intrinsic_mode = "filmrdkronqr"
    if optimize_alpha:
        batch_size = 4
        config["batch_size"] = batch_size
else:
    intrinsic_mode = "rdkronqr"



# Fixed global seed; e.g. the shuffled train DataLoader draws its order from it.
torch.manual_seed(137)

# Permit TF32 kernels for matmuls and cuDNN.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Autocast context used by the model forward passes: a no-op on CPU, mixed
# precision in `dtype` otherwise. (float16 is paired with a GradScaler below.)
if 'cuda' in device:
    device_type = 'cuda'
else:
    device_type = 'cpu'
ptdtype = {
    'float32': torch.float32,
    'bfloat16': torch.bfloat16,
    'float16': torch.float16,
}[dtype]
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)


################ 2. Getting the data ###################


# Validate the GLUE task up front so a missing/unknown task_name fails with a
# clear message instead of deep inside `datasets`.
if task_name not in task_to_keys:
    raise ValueError(f"task_name must be one of {sorted(task_to_keys)}, got {task_name!r}")
# STS-B is a regression task (float scores, no class names), while the
# evaluation below (cross-entropy, argmax accuracy) is classification-only.
if task_name == "stsb":
    raise ValueError("task_name='stsb' is a regression task and is not supported by this script")

# Local Hugging Face cache shared by the dataset, config and tokenizer downloads.
hf_cache_dir = f'/scratch/{os.getenv("USER")}/cache'

raw_datasets = load_dataset("glue", task_name, cache_dir=hf_cache_dir)
label_list = raw_datasets["train"].features["label"].names
num_labels = len(label_list)
sentence1_key, sentence2_key = task_to_keys[task_name]

gptconfig = AutoConfig.from_pretrained('gpt2', num_labels=num_labels, finetuning_task=task_name, cache_dir=hf_cache_dir)
tokenizer = AutoTokenizer.from_pretrained('gpt2', use_fast=True, cache_dir=hf_cache_dir)
# GPT-2 ships without a pad token; reuse EOS so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token



# GPT-2 sequence classifier sized for this task's label set, either from the
# pretrained 'gpt2' weights or randomly initialised.
if load_pretrained_model:
    model = AutoModelForSequenceClassification.from_pretrained(
        'gpt2',
        config=gptconfig,
        cache_dir=f'/scratch/{os.getenv("USER")}/cache',
    )
else:
    model = GPT2ForSequenceClassification(config=gptconfig)

# GPT2ForSequenceClassification uses pad_token_id to find the last real token of
# each sequence; align it with the tokenizer, which pads with EOS.
model.config.pad_token_id = model.config.eos_token_id
print(f"trainable params (M) before applying lora = {model.num_parameters(only_trainable=True)/1e6} M")

# NOTE(review): the adapter hyper-parameters are hard-coded (independent of the
# rank_lora / lora_alpha / lora_dropout config values); they must match the
# adapter the loaded checkpoint was trained with — confirm against the training run.
peft_config = LoraConfig(
    r=8,
    target_modules=['c_attn', 'score'],
    lora_alpha=32,
    lora_dropout=0.1,
)

model = get_peft_model(model, peft_config)

# Reported in millions, like the pre-LoRA count above.
print(f"trainable params (M) after applying lora = {model.num_parameters(only_trainable=True)/1e6} M")
    

# Bookkeeping values written into the quantized checkpoint / log lines.
# best_val_loss is not read below (the checkpoint stores None for it).
iter_num = 0
best_val_loss = 1e9


# model init
# NOTE: `model_args` describes the nanoGPT-style architecture / LoRA layout. It is
# not used to build the evaluated model (the Hugging Face GPT-2 above); it is only
# stored in the quantized checkpoint.
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                  bias=bias, vocab_size=vocab_size, dropout=dropout,
                  use_lora=use_lora, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                  attention_linear_use_lora=attention_linear_use_lora, attention_linear_lora_r=attention_linear_lora_r, # attention_linear_enable_lora=attention_linear_enable_lora,
                  linear_head_lora_r=linear_head_lora_r, linear_head_enable_lora=linear_head_enable_lora,
                  MLP_lora_r = MLP_lora_r, MLP_enable_lora = MLP_enable_lora,
                  token_embedding_lora_r=token_embedding_lora_r, token_embedding_enable_lora=token_embedding_enable_lora,
                  positional_embedding_lora_r=positional_embedding_lora_r,
                  positional_embedding_enable_lora=positional_embedding_enable_lora,
                  use_mergedlinear=use_mergedlinear,intrinsic_dim=intrinsic_dim, 
                  sample_size=sample_size, data_size=data_size, failure_prob=failure_prob,
                  optimize_alpha=optimize_alpha, apply_rope=apply_rope, extended_proj=extended_proj,
                  test_old_ckpts=test_old_ckpts, use_vera=use_vera,
                  ) # start with model_args from command line


print(f"loading best training checkpoint from {best_checkpoint_path} for pretraining bound metrics eval")



# The run name encodes the objective (alpha warm-up vs. weighted loss), whether
# VeRA is used, and the key hyper-parameters as recorded in `config`.
_run_name_templates = {
    (True, True): "alpha_warmup_id{}_lr{}_vera_r{}_optimalpha{}_rope{}_extproj{}",
    (True, False): "alpha_warmup_id{}_lr{}_r{}_optimalpha{}_rope{}_extproj{}",
    (False, True): "weighted_loss_id{}_lr{}_vera_r{}_optimalpha{}_rope{}_extproj{}",
    (False, False): "weighted_loss_id{}_lr{}_r{}_optimalpha{}_rope{}_extproj{}",
}
_run_name_template = _run_name_templates[(bool(alpha_warmup), bool(use_vera))]
wandb_run_name = _run_name_template.format(
    *(config[key] for key in ("intrinsic_dim",
                              "learning_rate",
                              "rank_lora",
                              "optimize_alpha",
                              "apply_rope",
                              "extended_proj"))
)
config["wandb_run_name"] = wandb_run_name

# Optional Weights & Biases logging; wandb is imported lazily so it is only
# required when logging is enabled.
if wandb_log and master_process:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name, config=config)
    

# TODO: get back to this
# The checkpoint directory is mandatory: every artifact below is read from (and
# later written to) it, and os.path.join(None, ...) would only fail with an
# opaque TypeError.
if best_checkpoint_path is None:
    raise ValueError("best_checkpoint_path must be set (e.g. through configurator.py overrides)")

# The trainable tensors and their parameter names are stored as two parallel
# sequences; zip them back into a partial state dict.
# NOTE(review): torch.load unpickles arbitrary objects — only use trusted checkpoints.
lora_ckpt_path = os.path.join(best_checkpoint_path, "trainable_initparams.pt")
lora_checkpoint = torch.load(lora_ckpt_path, map_location=device)

lora_name_ckpt_path = os.path.join(best_checkpoint_path, "names.pt")
lora_name_checkpoint = torch.load(lora_name_ckpt_path, map_location=device)

# zip() silently truncates to the shorter sequence; make a length mismatch visible.
if len(lora_name_checkpoint) != len(lora_checkpoint):
    print(f"WARNING: {lora_name_ckpt_path} lists {len(lora_name_checkpoint)} names but "
          f"{lora_ckpt_path} holds {len(lora_checkpoint)} tensors; extra entries are ignored")
lora_dict = dict(zip(lora_name_checkpoint, lora_checkpoint))

# strict=False is needed because only the trainable subset is provided, but it
# also silently skips names that do not exist in `model` — those parameters
# would keep their freshly initialised values. Report them.
lora_load_result = model.load_state_dict(lora_dict, strict=False)
if lora_load_result.unexpected_keys:
    print(f"WARNING: {len(lora_load_result.unexpected_keys)} checkpoint entries do not match any "
          f"model parameter and were NOT loaded, e.g. {lora_load_result.unexpected_keys[:5]}")


def _print_trainable_parameter_names(net, banner):
    """Print ``banner``, every parameter name of ``net`` that requires grad, then ``banner`` again."""
    print(banner)
    trainable_names = (name for name, param in net.named_parameters() if param.requires_grad)
    for param_name in trainable_names:
        print(param_name)
    print(banner)


if intrinsic_dim > 0:
    _print_trainable_parameter_names(
        model, "\n# === final trainable parameters (before applying subspace) === #\n"
    )

    # Wrap the model so its trainable weights are driven by a `subspace_params`
    # vector of size intrinsic_dim (see projectors.create_intrinsic_model).
    # NOTE(review): the mode is hard-coded to "rdkronqr" although the config
    # section computes `intrinsic_mode` ("filmrdkronqr" when extended_proj is
    # set) — confirm this matches the mode used at training time.
    model = create_intrinsic_model(
        base_net=model,
        ckpt_path=None,
        intrinsic_mode="rdkronqr",
        intrinsic_dim=intrinsic_dim,
        seed=137,
        device=device,
    )

    _print_trainable_parameter_names(
        model, "\n# === final trainable parameters (after applying subspace) === #\n"
    )

# dataset
# GLUE labels are already integer class ids, so no remapping table is needed.
label_to_id = None

# Pad every example to max_length so the default collator can stack batches
# directly; otherwise leave padding to the collator.
pad_to_max_length = True
if pad_to_max_length:
    padding = "max_length"
else:
    padding = False

def preprocess_function(examples):
    """Tokenize a batch of GLUE examples and attach an integer ``labels`` column.

    Single-sentence tasks tokenize only ``sentence1_key``; pair tasks pass both
    columns to the tokenizer. Sequences are truncated to 1024 tokens and padded
    according to the module-level ``padding`` mode.
    """
    if sentence2_key is None:
        texts = (examples[sentence1_key],)
    else:
        texts = (examples[sentence1_key], examples[sentence2_key])
    encoded = tokenizer(*texts, padding=padding, max_length=1024, truncation=True)

    if "label" not in examples:
        return encoded

    raw_labels = examples["label"]
    # The model expects the target column to be called "labels"; remap through
    # label_to_id only when a mapping is configured (GLUE does not need one).
    if label_to_id is None:
        encoded["labels"] = raw_labels
    else:
        encoded["labels"] = [label_to_id[label] for label in raw_labels]
    return encoded

# Tokenize every split, dropping the raw columns (named after the train split).
processed_datasets = raw_datasets.map(
    preprocess_function,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
)
train_dataset = processed_datasets["train"]

# MNLI ships matched and mismatched dev sets; evaluate on the matched one.
if task_name == "mnli":
    eval_split_name = "validation_matched"
else:
    eval_split_name = "validation"
eval_dataset = processed_datasets[eval_split_name]

# DataLoaders creation:
# Examples are already padded to a fixed length when pad_to_max_length is set, so
# plain stacking suffices; otherwise pad each batch dynamically (to a multiple of
# 8 when use_fp16 is set).
use_fp16 = True
data_collator = (
    default_data_collator
    if pad_to_max_length
    else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if use_fp16 else None)
)

# NOTE: the loader batch size is fixed at 8, independent of the `batch_size` config.
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=data_collator)
eval_dataloader = DataLoader(eval_dataset, batch_size=8, collate_fn=data_collator)


# Restore the trained subspace coordinates into the model's `subspace_params`
# entry. This is a strict load, so any other missing/unexpected key raises.
# NOTE(review): the load is unconditional, but a `subspace_params` entry only
# appears to exist when intrinsic_dim > 0 (create_intrinsic_model above); with
# intrinsic_dim == 0 this would presumably raise — confirm whether that
# configuration is meant to be supported here.
subspace_ckpt_path = os.path.join(best_checkpoint_path, "subspace_params_saved.pt")
subspace_state = {"subspace_params": torch.load(subspace_ckpt_path, map_location=device)}
model.load_state_dict(subspace_state)
model.to(device)

# initialize a GradScaler. If enabled=False scaler is a no-op
# NOTE: neither `scaler` nor `optimizer` is used by the evaluation below.
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

# `configure_optimizers` is a helper of the nanoGPT-style model; the Hugging Face
# GPT-2 / PEFT model built above does not appear to define it (the intrinsic
# wrapper may). Since this script never steps the optimizer, build it only when
# the helper exists instead of crashing with AttributeError.
if hasattr(model, "configure_optimizers"):
    optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type, correct_bias, adam_epislon, no_decay_bias)
else:
    optimizer = None

checkpoint = None # free up memory


################ 4. Defining some utils for later .. ###################

# helps estimate an arbitrarily accurate loss over either split using many batches

estimate_loss_split_list = ['train', 'val']


@torch.no_grad()
def estimate_loss():
    """Estimate the mean cross-entropy on the train and validation loaders.

    Averages the per-batch loss over at most ``eval_iters`` batches per split.
    The model is switched to eval mode for the duration and back to train mode
    afterwards.

    Returns:
        dict mapping 'train' and 'val' to a 0-dim float tensor with the mean
        per-batch loss (NaN if a loader yields no batches).
    """
    out = {}
    model.eval()
    for split_name, dataloader in zip(estimate_loss_split_list, [train_dataloader, eval_dataloader]):
        batch_losses = []
        for step, batch in enumerate(dataloader):
            if step >= eval_iters:
                break
            labels = batch['labels'].to(device)
            outputs = model(input_ids=batch['input_ids'].to(device),
                            attention_mask=batch['attention_mask'].to(device),
                            labels=labels)
            batch_losses.append(F.cross_entropy(outputs.logits, labels).item())
        # Average only over the batches actually evaluated: a loader with fewer
        # than eval_iters batches must not be diluted by zero-filled slots.
        out[split_name] = torch.tensor(batch_losses).mean()
    model.train()
    return out
    

################ [PENDING] 5. Evaluation of the model without the quantized weights ###################

eval_iters=50

losses = estimate_loss()

print("EVALUATING THE MODEL BEFORE QUANTIZATION")

print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

################ 6. Quantization ###################


# Quantize the trained vector and measure its compressed size.
#  - intrinsic_dim > 0: only the subspace coordinates (`subspace_params`) are quantized.
#  - intrinsic_dim == 0: all trainable parameters are flattened into one vector;
#    `names` and `vector` stay at module level so the quantized values can be
#    written back per parameter afterwards.
# `message_len` comes from quantize.quantize_vector; downstream it is treated as
# a number of bits (it is later multiplied by ln 2 to obtain nats).
if max_quant_iters > 0 and intrinsic_dim > 0:
    # Quantization-aware fine-tuning is deliberately disabled; the commented-out
    # code below is the disabled implementation.
    raise ValueError("not using this due to stability issue")
    # vector = model.subspace_params.cpu().data.numpy()
    # cluster_fn = quantize.get_random_symbols_and_codebook
    # if use_kmeans:
    #     cluster_fn = quantize.get_kmeans_symbols_and_codebook
    # _, centroids = cluster_fn(vector, levels=levels, codebook_dtype=np.float16)
    # centroids = torch.tensor(centroids, dtype=torch.float32)
    # centroids = centroids.to(device)
    # quantizer_fn = quantize.Quantize().apply
    # qw = quantize.QuantizingWrapper(model, quantizer=quantizer_fn, centroids=centroids)
    # optim = SGD(
    #     [qw.subspace_params, qw.centroids],
    #     lr = quant_lr, momentum=0.9)

    # for e in tqdm(range(max_quant_iters)):
    #     qw.train()
    #     optim.zero_grad()
    #     X, Y, ix = get_batch('train')
    #     logits, loss = qw(X, Y)
    #     loss.backward()
    #     optim.step()
    #     if e % 10 == 0:
    #         metrics = {"iter": e, "ix": ix, "mini_loss": loss.detach().item()}
    #         print(metrics)
    # quantized_vec = qw.quantizer(qw.subspace_params, qw.centroids)
    # quantized_vec = quantized_vec.cpu().detach().numpy()
    # vec = (qw.centroids.unsqueeze(-2) - qw.subspace_params.unsqueeze(-1))**2.0
    # symbols = torch.min(vec, -1)[-1]
    # symbols = symbols.cpu().detach().numpy()
    # centroids = qw.centroids.cpu().detach().numpy()
    # # centroids = centroids.astype(np.float16)
    # probabilities = np.array([np.mean(symbols == i) for i in range(levels)])
    # _, coded_symbols_size = quantize.do_arithmetic_encoding(symbols, probabilities,
    #                                                qw.centroids.shape[0])
    # message_len = quantize.get_message_len(
    #     coded_symbols_size=coded_symbols_size,
    #     codebook=centroids,
    #     max_count=len(symbols),
    # )
else:
    # NOTE: with intrinsic_dim == 0 a positive max_quant_iters also lands here, so
    # it is ignored apart from appearing in the output file names.
    if intrinsic_dim > 0:
        # Unwrap a DDP container if present.
        module = model.module if isinstance(model,
                                        torch.nn.parallel.DistributedDataParallel) else model
        vector = module.subspace_params.cpu().data.numpy()
        quantized_vec, message_len = quantize.quantize_vector(vector, levels=levels, use_kmeans=use_kmeans)
    else:
        # Collect every trainable parameter (name, tensor) and quantize them as
        # one flat vector.
        aux = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
        names, vector = zip(*aux)
        fvector = projectors.flatten(vector).cpu().data.numpy()
        quantized_vec, message_len = quantize.quantize_vector(fvector, levels=levels, use_kmeans=use_kmeans)
        ## free memory 
        fvector = None 
        

# Write the quantized values back into the model.
if intrinsic_dim > 0:
    module = model.module if ddp else model
    module.subspace_params.data = torch.tensor(quantized_vec).float().to(device)
else:
    unfquantized_vec = projectors.unflatten_like(torch.tensor(quantized_vec), vector)
    ## free memory  
    quantized_vec, vector = None, None
    # Index the quantized tensors by parameter name once, instead of rescanning
    # all (name, tensor) pairs for every model parameter (O(P*N) -> O(P + N)).
    quantized_by_name = dict(zip(names, unfquantized_vec))
    for n, p in model.named_parameters():
        quantp = quantized_by_name.get(n)
        if quantp is not None:
            # as_tensor + clone copies just like torch.tensor(...) did, without the
            # copy-construct warning torch.tensor emits for tensor inputs.
            p.data = torch.as_tensor(quantp).detach().clone().float().to(device)

# Total code length in bits, plus 2*log2(message_len) bits — presumably the
# overhead of a prefix-free encoding of the message length.
prefix_message_len = message_len + 2 * np.log2(message_len) if message_len > 0 else 0

################ 7. Evaluation of the model with the quantized weights ###################

losses = estimate_loss()

print("EVALUATING THE MODEL AFTER QUANTIZATION")

print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")


################ 8. Saving everything for the bound computation ###################


# Persist the quantized model together with the prefix-code length required by
# the bound computation (overwrites any previous file with the same levels/iters).
raw_model = model.module if ddp else model # unwrap DDP container if needed

raw_model_state_dict = raw_model.state_dict()
# The LoRA-only weights are stored separately when any LoRA layer is enabled.
lora_state_dict = lora.lora_state_dict(model) if use_lora else None

checkpoint = {
    'raw_model': raw_model_state_dict,
    'lora_model': lora_state_dict,
    'optimizer': None,
    'model_args': model_args,
    'iter_num': iter_num,
    'best_val_loss': None,
    'config': config,
    # Store a plain Python float rather than the NumPy scalar np.log2 produces, so
    # the file stays loadable with torch.load(weights_only=True) — the default
    # since PyTorch 2.6.
    'prefix_message_len': float(prefix_message_len),
}

torch.save(checkpoint, os.path.join(best_checkpoint_path, f'quant_ckpt_levels{levels}_iters{max_quant_iters}.pt'))

################ 9. Defining some eval utils for later .. ###################

def sum_k_elements(row, k_th):
    """Return the sum of ``row``'s first ``k_th + 1`` entries (indices 0..k_th inclusive)."""
    leading = row[: k_th + 1]
    return leading.sum()


# Per-token ranking / probability metrics for next-token prediction.
@torch.no_grad()
def compute_pretrain_bound_metrics(X, Y):
    """Compute per-token ranking and probability metrics for one batch.

    NOTE(review): this calls ``model(X, Y)`` expecting a 4-tuple
    ``(logits, loss, _, alphas)`` (the nanoGPT-style LM interface); the Hugging
    Face classifier used in this script returns a ModelOutput instead, and this
    function is not called anywhere in this file.

    Args:
        X: input token ids, passed straight to the model.
        Y: target token ids; flattened with ``Y.view(-1)``, so its element count
            must equal the number of rows of the flattened logits.

    Returns:
        ``(loss, top_k_indices, selected_log_prob_scores, percentile_vec,
        selected_prob_scores, alphas)`` where, for each target token,
        ``top_k_indices`` is its 0-based rank in descending probability order,
        ``selected_prob_scores`` its probability, ``selected_log_prob_scores``
        its log-probability, and ``percentile_vec`` the cumulative probability
        of all tokens ranked at or above it.
    """
    if intrinsic_dim == 0:
        with ctx:
            logits, loss, _, alphas = model(X, Y)
    else:
        logits, loss, _, alphas = model(X, Y)

    flat_logits = logits.view(-1, logits.size(-1))
    targets = Y.view(-1)
    rows = torch.arange(flat_logits.shape[0], device=flat_logits.device)

    softmax_matrix = torch.nn.functional.softmax(flat_logits, dim=-1)
    sorted_softmax_matrix, indices_softmax_matrix = torch.sort(softmax_matrix, dim=-1, descending=True)

    # 1. rank (0-based) of the target token in descending-probability order
    top_k_indices = torch.argmax((indices_softmax_matrix == targets.unsqueeze(1)).long(), dim=-1)

    # 2. probability and log-probability of the target token. log_softmax avoids
    #    the -inf that log(softmax) yields when a probability underflows to 0.
    selected_prob_scores = softmax_matrix[rows, targets]
    selected_log_prob_scores = torch.nn.functional.log_softmax(flat_logits, dim=-1)[rows, targets]

    # 3. cumulative mass up to and including the target's rank, gathered from a
    #    single cumsum instead of one Python-level slice-and-sum per token.
    cumulative_mass = torch.cumsum(sorted_softmax_matrix.float(), dim=-1)
    percentile_vec = cumulative_mass.gather(1, top_k_indices.unsqueeze(1)).squeeze(1)

    return loss.item(), top_k_indices, selected_log_prob_scores, percentile_vec, selected_prob_scores, alphas

################ 10. Beginning of the evaluation ###################

# Create the per-run text output files, truncating any that already exist.
for _output_path in (
    os.path.join(best_checkpoint_path, f'ix_levels{levels}_iters{max_quant_iters}.txt'),
    os.path.join(best_checkpoint_path, f'loss_levels{levels}_iters{max_quant_iters}.txt'),
    os.path.join(best_checkpoint_path, f'top_k_indices_levels{levels}_iters{max_quant_iters}.txt'),
    os.path.join(best_checkpoint_path, f'selected_log_prob_scores_levels{levels}_iters{max_quant_iters}.txt'),
    os.path.join(best_checkpoint_path, f'selected_prob_scores_levels{levels}_iters{max_quant_iters}.txt'),
    os.path.join(best_checkpoint_path, f'percentile_vec_levels{levels}_iters{max_quant_iters}.txt'),
    os.path.join(best_checkpoint_path, f'alphas_levels{levels}_iters{max_quant_iters}.txt'),
):
    with open(_output_path, 'w'):
        pass

# Candidate alpha values for the bits-per-dimension bounds (their count also
# enters the bpd complexity term).
alpha_array = [0.0001, 0.001, 0.005, 0.01, 0.05, 0.1, 0.25, 0.5]

assert ddp_world_size == 1, "multi-GPU bound computations are not yet properly implemented"


with torch.no_grad():
    model.eval()

    curr_iter_i = 0

    # Running metrics, rewritten to YAML after every batch. Keys containing
    # "acc" or "bpd" are turned into bounds in section 11, so other metrics must
    # avoid those substrings. (Top-k accuracy and bpd metrics are disabled.)
    metrics_dict = {}
    metrics_dict['finetune_acc'] = 0
    metrics_dict["n_train"] = 0
    metrics_dict["curr_iter_i"] = 0
    metrics_dict["delta"] = 0.

    metrics_path = os.path.join(best_checkpoint_path, f'metrics_levels{levels}_iters{max_quant_iters}.yml')
    # Number of correctly classified training examples seen so far.
    n_correct = 0

    for batch in train_dataloader:
        # Stop once at least `sample_size` training examples have been scored.
        if sample_size <= metrics_dict["n_train"]:
            break

        labels = batch['labels'].to(device)
        outputs = model(input_ids=batch['input_ids'].to(device),
                        attention_mask=batch['attention_mask'].to(device),
                        labels=labels)

        # Count the examples actually in this batch: the loader's last batch can
        # be smaller than `train_dataloader.batch_size`, and n_train becomes the
        # sample size of the bound, so over-counting would make the bound look
        # tighter than justified.
        n_correct += torch.sum(labels == torch.argmax(outputs.logits, dim=-1)).item()
        metrics_dict["n_train"] += labels.shape[0]
        metrics_dict['finetune_acc'] = n_correct / metrics_dict["n_train"]
        metrics_dict["curr_iter_i"] += 1

        with open(metrics_path, 'w') as f:
            yaml.safe_dump(metrics_dict, f, indent=2)

        if wandb_log:
            wandb.log(metrics_dict)

        if curr_iter_i % 100 == 0:
            print("\n".join("{}\t{}".format(k, v) for k, v in metrics_dict.items()))

        curr_iter_i += 1
        
################ 11. Compute the bounds ###################


# Extra bits charged for hyper-parameter selection, added to the divergence as
# ceil(log2(#configurations searched)). Keyed by (uses subspace, uses LoRA).
# Every setting currently counts a single configuration, i.e. 0 extra bits; the
# counts used previously are kept alongside for reference.
_num_hparam_configs = {
    (False, False): 1,  # no LoRA, no subspace  (previously 2*3: two LoRA ranks & 3 learning rates)
    (False, True): 1,   # LoRA, no subspace     (previously 2*2*3)
    (True, False): 1,   # subspace, no LoRA     (previously 2*6*3)
    (True, True): 1,    # LoRA + subspace       (previously 4*6*2*3)
}
misc_extra_bits = np.ceil(np.log2(
    _num_hparam_configs[(intrinsic_dim != 0, attention_linear_lora_r != 0)]
))

        
# Prefix-code length of the quantized model, computed right before the quantized
# checkpoint was saved and not modified since. Reusing the in-memory value avoids
# re-loading the whole checkpoint (full model state dict included) just to read
# back one scalar — a reload that also fails under torch.load(weights_only=True),
# the default since PyTorch 2.6, whenever the stored value is a NumPy scalar.
prefix_message_len = float(prefix_message_len)
# The bound's sample size is the number of training examples actually scored.
sample_size = metrics_dict["n_train"]

def corrected_bound(train_error, div, data_size, sample_size, delta = 1, epsilon=failure_prob):
    """Return ``train_error`` plus a ``delta``-scaled complexity term (PAC-style bound).

    With m = ``sample_size``, n = ``data_size`` and r = m / (m + n), the failure
    probability ``epsilon`` is split between two concentration terms::

        train_error + delta * ( sqrt((div - ln(r * epsilon)) / (2 n))
                              + sqrt(-ln((1 - r) * epsilon) / (2 m)) )

    Args:
        train_error: error (or bpd) measured on the ``sample_size`` examples.
        div: complexity term in nats (callers pass a bit count times ln 2).
        data_size: n in the formula above.
        sample_size: m, the number of examples ``train_error`` was measured on.
        delta: multiplier on the complexity term (1 for accuracy bounds).
        epsilon: total failure probability; the default is the module-level
            ``failure_prob`` captured when this function is defined.
    """
    r = sample_size / (sample_size + data_size)
    full_data_term = np.sqrt((div - np.log(r * epsilon)) / (2 * data_size))
    sample_term = np.sqrt(-np.log((1 - r) * epsilon) / (2 * sample_size))
    return train_error + delta * (full_data_term + sample_term)

bounds_dict = {}
bounds_dict["prefix_message_len"] = float(prefix_message_len)

best_bpd_bound = np.inf

# Extra bits charged on top of `misc_extra_bits` for the bpd bounds (one alpha is
# selected among `alpha_array`). Computed once: accumulating it inside the loop
# would give every successive bpd metric — and any "acc" metric visited after
# one — a larger, iteration-order-dependent penalty.
# NOTE(review): ceil(len(alpha_array)) bits is conservative; a union bound over
# the candidates would only need ceil(log2(len(alpha_array))) bits.
bpd_extra_bits = misc_extra_bits + np.ceil(len(alpha_array))

for k in metrics_dict.keys():
    if "acc" in k:
        # Accuracy bound: 0-1 error with the complexity term scaled by 1.
        train_error = 1. - metrics_dict[k]

        divergence = (prefix_message_len + misc_extra_bits) * np.log(2)  # bits -> nats

        bounds_dict["acc_divergence"] = float(divergence)

        bounds_dict[f"bound_{k}"] = float(corrected_bound(train_error=train_error,
                                              div=divergence,
                                              data_size=data_size,     # TODO: make 
                                              sample_size=sample_size,
                                              delta=1.))

    elif "bpd" in k:
        divergence = (prefix_message_len + bpd_extra_bits) * np.log(2)  # bits -> nats

        bounds_dict["bpd_divergence"] = float(divergence)

        train_error = metrics_dict[k]

        if optimize_alpha:
            delta = metrics_dict["delta"]
        else:
            # Keys look like "bpd_alpha_<alpha>"; delta is presumably the range
            # of the per-token bpd for that alpha.
            alpha = float(k.replace("bpd_alpha_", ""))
            delta = np.log2(1 + (1 - alpha) * vocab_size / alpha)

        bounds_dict[f"bound_{k}"] = float(corrected_bound(train_error=train_error,
                                              div=divergence,
                                              data_size=data_size,
                                              sample_size=sample_size,
                                              delta=delta))
        if best_bpd_bound > bounds_dict[f"bound_{k}"]:
            best_bpd_bound = bounds_dict[f"bound_{k}"]
            
# Report the bounds: tab-separated on stdout, optionally to W&B, and as YAML next
# to the checkpoint. (best_bpd_bound is computed above but not recorded.)
report = "\n".join(f"{name}\t{value}" for name, value in bounds_dict.items())
print(report)

if wandb_log:
    wandb.log(bounds_dict)

bounds_path = os.path.join(best_checkpoint_path, f'bounds_levels{levels}_iters{max_quant_iters}.yml')
with open(bounds_path, 'w') as f:
    yaml.safe_dump(bounds_dict, f, indent=2)
        

        

