import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import DatasetDict
from datasets import load_from_disk, load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AdamW  # NOTE(review): unused below and deprecated in transformers (torch.optim.AdamW replaces it); drop if it fails to import on newer releases
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import get_scheduler

# Run on CUDA when a GPU is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hub checkpoint shared by the tokenizer here and the classifier built later.
bert_small = "prajjwal1/bert-small"
# Offline alternative (local copy of the dataset):
#   raw_datasets = load_from_disk("./datasets")
raw_datasets = load_dataset("imdb")
tokenizer = AutoTokenizer.from_pretrained(bert_small)


def tokenize_function(examples):
    """Tokenize a batch of reviews for the BERT classifier.

    Intended for ``Dataset.map(..., batched=True)``, so ``examples["text"]``
    holds the batch's review strings. Each sequence is truncated and padded
    to exactly 512 tokens using the module-level ``tokenizer``.
    """
    texts = examples["text"]
    return tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=512,
    )


# Only the "train" and "test" splits are used below. The Hub "imdb" dataset also
# ships an unlabeled "unsupervised" split (~50k reviews); mapping the whole
# DatasetDict would tokenize and cache it for nothing, so keep just what we use.
used_splits = DatasetDict({split: raw_datasets[split] for split in ("train", "test")})

# Tokenize, drop the raw text, and rename "label" -> "labels", the keyword the
# Hugging Face model's forward() uses to compute its loss.
tokenized_datasets = (
    used_splits.map(tokenize_function, batched=True)
    .remove_columns(["text"])
    .rename_column("label", "labels")
)
# Yield torch tensors so batches can be moved straight to the device.
tokenized_datasets.set_format("torch")


full_train_dataset = tokenized_datasets["train"]
full_eval_dataset = tokenized_datasets["test"]

BS = 100
train_dataloader = DataLoader(full_train_dataset, shuffle=True, batch_size=BS)
eval_dataloader = DataLoader(full_eval_dataset, batch_size=BS)



def _save_checkpoint(model, optimizer, epoch, method_name, save_path):
    """Write model/optimizer state to ``<save_path><method_name>_epoch<epoch>.pth.tar``."""
    model_path = save_path + method_name + '_epoch' + str(epoch) + '.pth.tar'
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, model_path)


def train(num_epochs, model, optimizer, lr_scheduler, method_name, lr, save_path):
    """Train ``model`` for ``num_epochs`` passes over ``train_dataloader``.

    Relies on module-level names defined elsewhere in this script:
    ``train_dataloader``, ``device``, ``LOG_INTERVAL`` and ``progress_bar``.

    A checkpoint ``<save_path><method_name>_epoch<K>.pth.tar`` holds the
    model/optimizer state after K completed epochs: one is written before each
    epoch (``epoch0`` is the starting model) and one after the last epoch.
    On every batch whose index within the epoch is a multiple of
    ``LOG_INTERVAL``, the batch loss is printed, appended to the returned
    history, and appended to ``<save_path>loss_<method_name>_lr<lr>.txt``.

    Args:
        num_epochs: number of passes over ``train_dataloader``.
        model: model whose forward returns an object with a ``.loss``.
        optimizer: optimizer over ``model``'s parameters.
        lr_scheduler: scheduler stepped once per batch.
        method_name: tag used in checkpoint and loss-file names.
        lr: base learning rate; only used in the loss-file name.
        save_path: directory prefix (concatenated, so include a trailing '/').

    Returns:
        list[float]: the logged losses, in logging order.
    """
    model.train()
    loss_history = []
    loss_file = save_path + 'loss_' + method_name + '_lr' + str(lr) + '.txt'
    num_batches = len(train_dataloader)

    for epoch in range(num_epochs):
        _save_checkpoint(model, optimizer, epoch, method_name, save_path)
        # Context manager guarantees the log file is closed even if training raises.
        with open(loss_file, "a") as fo:
            for j, batch in enumerate(train_dataloader):
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                loss = outputs.loss
                loss.backward()

                optimizer.step()
                if j % LOG_INTERVAL == 0:
                    # .item() forces a host/device sync; only pay for it on logged steps.
                    cur_loss = loss.item()
                    # Read before lr_scheduler.step(): the lr used for this update.
                    cur_lr = lr_scheduler.get_last_lr()[0]
                    loss_history.append(cur_loss)
                    print('| epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f} | lr {}'.format(
                        epoch, j, num_batches, cur_loss, cur_lr))
                    fo.write(str(cur_loss) + '\n')

                lr_scheduler.step()
                optimizer.zero_grad()

        progress_bar.update(1)

    # The per-epoch checkpoints are all taken before training that epoch, so
    # without this the fully trained weights would never reach disk.
    _save_checkpoint(model, optimizer, num_epochs, method_name, save_path)
    return loss_history


learning_rate_ada = 5e-5
learning_rate_sgd = 0.001
beta1 = 0.9  # momentum coefficient for SGD+M
num_epochs = 10
LOG_INTERVAL = 10  # train() logs every LOG_INTERVAL-th batch of each epoch
save_path = './BERT_results/'
# makedirs also creates missing parent directories and tolerates the
# directory already existing (no exists()/mkdir() race).
os.makedirs(save_path, exist_ok=True)
ada_method = 'Adam'

# Both optimizers start from identical weights: one pretrained model, deep-copied.
model = AutoModelForSequenceClassification.from_pretrained(bert_small, num_labels=2)
model.to(device)
model_ada = copy.deepcopy(model)
model_sgd = copy.deepcopy(model)

optimizer_sgd = torch.optim.SGD(model_sgd.parameters(), lr=learning_rate_sgd, momentum=beta1)
optimizer_ada = torch.optim.Adam(model_ada.parameters(), lr=learning_rate_ada)

# Schedulers are stepped once per batch inside train(), so they span every batch of the run.
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler_ada = get_scheduler(
    "linear",
    optimizer=optimizer_ada,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)
lr_scheduler_sgd = get_scheduler(
    "linear",
    optimizer=optimizer_sgd,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

# Global bar read by train(), which advances it once per epoch.
progress_bar = tqdm(range(num_epochs))

loss_ada = train(num_epochs, model_ada, optimizer_ada, lr_scheduler_ada, ada_method, learning_rate_ada, save_path)
progress_bar.refresh()
progress_bar.reset()
loss_sgd = train(num_epochs, model_sgd, optimizer_sgd, lr_scheduler_sgd, 'SGD', learning_rate_sgd, save_path)
progress_bar.close()

loss_sgd_np = np.array(loss_sgd)
loss_ada_np = np.array(loss_ada)
# train() logs batch j of every epoch when j % LOG_INTERVAL == 0 and restarts j
# each epoch, so place each logged loss at its true global iteration. The old
# LOG_INTERVAL * k axis drifted whenever the batch count per epoch was not a
# multiple of LOG_INTERVAL.
steps_per_epoch = len(train_dataloader)
x_axis = np.array([
    epoch * steps_per_epoch + step
    for epoch in range(num_epochs)
    for step in range(0, steps_per_epoch, LOG_INTERVAL)
])
fig = plt.figure()
plt.title('Loss, SGD+M lr=' + str(learning_rate_sgd) + ', Adam lr=' + str(learning_rate_ada))
plt.ylabel("Loss")
plt.xlabel("Iteration")
plt.plot(x_axis, loss_sgd_np, label='SGD+M')
plt.plot(x_axis, loss_ada_np, label='Adam')
plt.legend()
fig.savefig(save_path + 'loss_SGD_lr' + str(learning_rate_sgd) + '_Adam_lr' + str(learning_rate_ada) + '.png')
plt.close(fig)
