import torch
from datasets import load_from_disk, load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification
from transformers import get_scheduler
import numpy as np
import copy
import os
import matplotlib.pyplot as plt
import pandas as pd

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hub identifier shared by the tokenizer here and every model built later.
bert_small = "prajjwal1/bert-small"
# IMDB is fetched from the Hugging Face hub; an offline copy can be used
# instead via `load_from_disk("./datasets")`.
raw_datasets = load_dataset("imdb")
tokenizer = AutoTokenizer.from_pretrained(bert_small)


def tokenize_function(examples):
    """Tokenize a batch of reviews into fixed-length (512-token) encodings.

    Every text in ``examples["text"]`` is truncated or padded to exactly
    512 tokens. Equal lengths let the DataLoaders below use the default
    collate without a custom padding collator.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, truncation=True, padding="max_length", max_length=512)
    return encoded


# Tokenize every split, keep only model inputs, and hand tensors to PyTorch.
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(["text"])
# The sequence-classification model reads its targets from the `labels` kwarg.
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")


full_train_dataset = tokenized_datasets["train"]
full_eval_dataset = tokenized_datasets["test"]

BS = 40
# NOTE: the Hugging Face `imdb` train split is stored grouped by label (all
# negatives, then all positives), so iterating it in stored order makes the
# first batches, the only ones cal_hessian() uses, single-class. Shuffle once
# with a fixed seed. The batches are then label-mixed, yet identical on every
# pass, so estimates for different epochs/optimizers use the same data.
# (DataLoader(shuffle=True) would draw a new order on every pass.)
TRAIN_SHUFFLE_SEED = 42
train_dataloader = DataLoader(
    full_train_dataset.shuffle(seed=TRAIN_SHUFFLE_SEED), shuffle=False, batch_size=BS
)
eval_dataloader = DataLoader(full_eval_dataset, batch_size=BS)

learning_rate_ada = 5e-5
learning_rate_sgd = 0.001

def load_model(path, method_name, learning_rate):
    """Rebuild the BERT-small classifier and its optimizer from a checkpoint.

    Args:
        path: checkpoint file for ``torch.load``. It must contain the keys
            ``'model_state_dict'`` and ``'optimizer_state_dict'``.
        method_name: ``'SGD'`` (momentum 0.9) or ``'Adam'``.
        learning_rate: lr passed to the optimizer constructor. NOTE: per
            ``torch.optim.Optimizer.load_state_dict``, the checkpoint's saved
            param-group hyperparameters replace the constructor values.

    Returns:
        ``(model, optimizer)``: the model on ``device`` with restored weights,
        and the optimizer with its state restored from the checkpoint.

    Raises:
        ValueError: if ``method_name`` is neither ``'SGD'`` nor ``'Adam'``.
    """
    optimizer_factories = {
        'SGD': lambda params: torch.optim.SGD(params, lr=learning_rate, momentum=0.9),
        'Adam': lambda params: torch.optim.Adam(params, lr=learning_rate),
    }
    # Validate before the expensive model construction. Previously an unknown
    # name crashed late with UnboundLocalError on `optimizer`.
    if method_name not in optimizer_factories:
        raise ValueError(
            f"Unknown method_name {method_name!r}; expected one of {sorted(optimizer_factories)}"
        )

    model = AutoModelForSequenceClassification.from_pretrained(bert_small, num_labels=2).to(device)
    # NOTE(review): depending on the torch version's `weights_only` default,
    # torch.load may unpickle arbitrary objects, so only load trusted files.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    optimizer = optimizer_factories[method_name](model.parameters())
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return model, optimizer


# Calculate H_ii of the coordinates in the index_list in the given layer
def cal_diag_hessian(loss, all_params, layer_list, layer, layer_type, index_list, H_G):
    """Accumulate sampled diagonal-Hessian entries and gradients for one layer.

    The probed tensor is ``all_params[layer_list[layer]]``. For every
    coordinate row ``r`` of ``index_list[layer]``, this adds d²loss/dθ² for
    that entry to ``H_G[layer][r][0]`` and dloss/dθ to ``H_G[layer][r][1]``.
    ``H_G`` is updated in place and nothing is returned.

    ``layer_type`` controls how many leading components of each coordinate
    row address a scalar entry: 'conv' uses 4, 'bn' uses 1, and anything
    else uses 2.

    Each coordinate costs one extra backward pass. A full Hessian row is
    formed and only its diagonal element is kept.
    """
    param = all_params[layer_list[layer]]
    # create_graph=True keeps the first derivative differentiable.
    first_grad = torch.autograd.grad(loss, param, create_graph=True)[0]

    if layer_type == 'conv':
        n_dims = 4
    elif layer_type == 'bn':
        n_dims = 1
    else:
        n_dims = 2

    for r, coord in enumerate(index_list[layer]):
        pos = tuple(coord[k] for k in range(n_dims))
        g_entry = first_grad[pos]
        # retain_graph so later coordinates can reuse the same graph.
        hess_row = torch.autograd.grad(g_entry, param, retain_graph=True)[0]
        H_G[layer][r][0] += hess_row[pos].item()
        H_G[layer][r][1] += g_entry.item()



def cal_hessian(epoch, model, optimizer, layer_list, index_list, method_name, save_path,
                num_batches=6):
    """Sample diagonal Hessian/gradient entries and append them, with optimizer
    state, to one text file per probed layer.

    Args:
        epoch: checkpoint epoch. It is used in the file names, and when
            ``epoch == 0`` the optimizer-state columns are written as zeros
            instead of being read.
        model: model called as ``model(**batch)``; its output must expose ``.loss``.
        optimizer: optimizer whose first param group holds the parameters
            that ``layer_list`` indexes into.
        layer_list: parameter indices into ``optimizer.param_groups[0]['params']``
            to probe. Each is treated as a 2-D ('fc') weight.
        index_list: integer array indexed as ``index_list[layer][r][0/1]``.
            Its ``shape[1]`` is the number of sampled coordinates per layer.
        method_name: 'Adam' or 'SGD' selects the extra columns. Any other
            value writes a single 0.
        save_path: existing directory for the output files.
        num_batches: how many leading ``train_dataloader`` batches are summed.
            The default of 6 matches the previous hard-coded limit.

    Each sampled coordinate produces one tab-separated line, appended to
    ``<save_path>/layer<param idx>_epoch<epoch>_<method_name>_diag.txt``:
      Adam:  H_ii, grad, exp_avg_sq, exp_avg   (zeros at epoch 0)
      SGD:   H_ii, grad, momentum_buffer       (zero at epoch 0)
      other: H_ii, grad, 0
    H_ii and grad are sums over the sampled batches, not means.
    """
    all_params = optimizer.param_groups[0]['params']
    layer_num = len(layer_list)
    row_num = index_list.shape[1]
    H_G = np.zeros((layer_num, row_num, 2))

    for j, batch in enumerate(train_dataloader):
        if j >= num_batches:
            break
        batch = {key: val.to(device) for key, val in batch.items()}
        loss = model(**batch).loss
        for layer in range(layer_num):
            cal_diag_hessian(loss, all_params, layer_list, layer, 'fc', index_list, H_G)

    for layer, param_idx in enumerate(layer_list):
        # Look state up by the parameter tensor itself. The old code indexed a
        # positional list of state_dict()['state'] values, which shifts if an
        # earlier parameter has no state, and rebuilt that dict for every
        # layer. `.get` avoids inserting empty entries into the defaultdict.
        state = optimizer.state.get(all_params[param_idx], {})
        if method_name == 'Adam' and epoch > 0:
            exp_avg_sq_list = state['exp_avg_sq']
            exp_avg_list = state['exp_avg']
        if method_name == 'SGD' and epoch > 0:
            momentum_buffer_list = state['momentum_buffer']

        txt_name = (save_path + '/layer' + str(param_idx) + '_epoch' + str(epoch)
                    + '_' + method_name + '_diag.txt')
        # `with` guarantees the file is closed even if a lookup below raises.
        with open(txt_name, 'a') as fo:
            for r in range(row_num):
                row_i, col_i = index_list[layer][r][0], index_list[layer][r][1]
                fields = [str(H_G[layer][r][0]), str(H_G[layer][r][1])]
                if method_name == 'Adam' and epoch > 0:
                    fields.append(str(exp_avg_sq_list[row_i][col_i].item()))
                    fields.append(str(exp_avg_list[row_i][col_i].item()))
                elif method_name == 'Adam' and epoch == 0:
                    fields += ['0', '0']
                elif method_name == 'SGD' and epoch > 0:
                    fields.append(str(momentum_buffer_list[row_i][col_i].item()))
                else:
                    fields.append('0')
                fo.write('\t'.join(fields) + '\n')


from timeit import default_timer as timer


# Checkpoint epochs to analyse; files are <path><method>_epoch<e>.pth.tar.
epoch_list = [0, 3, 5]

# Indices into optimizer.param_groups[0]['params'] of the layers to probe.
# fc_layers = [1, 7]
# fc_layers = [9, 11]
fc_layers = [21, 27, 37, 41, 43, 55, 59, 69]
# fc_layers = [37, 41]#[37, 41, 43, 55, 59, 69]
print(epoch_list)
print(fc_layers)
layer_num = len(fc_layers)
row_num = 200  # the number of coordinates we want to sample

init_model = AutoModelForSequenceClassification.from_pretrained(bert_small, num_labels=2).to(device)
init_optimizer = torch.optim.SGD(init_model.parameters(), lr=learning_rate_sgd)

# NOTE(review): init_model, init_optimizer and all_params are not used below;
# they are kept in case other code relies on these globals.
all_params = init_optimizer.param_groups[0]['params']
# Load the pre-sampled coordinates: CSV number k holds the (row, col) pairs for
# fc_layers[k]. Each file is assumed to be row_num rows x 2 columns; confirm.
index_fc = np.zeros((layer_num, row_num, 2), dtype=int)
for layer in range(layer_num):
    df_read = pd.read_csv('./rand_coord_BERT/rand_coord_layer' + str(layer) + '.csv')
    index_fc[layer] = df_read.values

path = './BERT_results/'
save_path = path + 'diagHessian_adaptGrad_' + str(row_num)
# makedirs also creates a missing parent (./BERT_results/), and exist_ok
# replaces the racy exists()-then-mkdir check.
os.makedirs(save_path, exist_ok=True)
for method_name in ['Adam', 'SGD']:
    # Nominal value only: load_model restores the checkpoint's saved param
    # groups (including lr). Use the configured constant, not a stray literal.
    nominal_lr = learning_rate_ada if method_name == 'Adam' else learning_rate_sgd
    for epoch in epoch_list:
        start_time = timer()
        load_path = path + method_name + '_epoch' + str(epoch) + '.pth.tar'
        model, optimizer = load_model(load_path, method_name, nominal_lr)
        cal_hessian(epoch, model, optimizer, fc_layers, index_fc, method_name, save_path)
        end_time = timer()
        print(method_name, end=', ')
        print(f"Epoch: {epoch}, Epoch time = {(end_time - start_time):.3f}s")
