"""
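This file loads the IMDB movie-review dataset into GPT2, runs the training loop, then plots results
(the plotting calls at the bottom of the file are currently commented out).
Note: absolute paths are used; set DATAPATH (and out_dir) in the settings section below the imports.
`train` and `validation` are reusable for other models (take care with CUDA / device placement).
This file runs experiments with: fine-tuning for 1-4 epochs, and the baseline GPT2ForSequenceClassification.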
"""

import io
import os
import numpy as np
import torch
from imdb_load import MovieReviewsDataset, Gpt2ClassificationCollator, MovieReviewsDataset_simple, Gpt2ClassificationCollator_simple
from gpt_utils import plot_dict 
import time # to do add time t0 and t1, then dt = t1-t0 in train loop
from torch.utils.data import DataLoader
# `pip install ftfy` and import it `from ftfy import fix_text`. May not be needed, depends on dataset.
from sklearn.metrics import classification_report, accuracy_score
from transformers import (set_seed,
                          GPT2Config,
                          GPT2Tokenizer,
                          AutoTokenizer,
                          AdamW, 
                          GPT2Model,
                          get_linear_schedule_with_warmup,
                          GPT2ForSequenceClassification) 


set_seed(123)  # fixed seed so repeated runs are comparable
# SET data path---------------------------------------------------------------------------
DATAPATH = '/home/miria/jaxopt/GPT2/data/aclImdb/'
epochs = 1  # number of full passes over the training set
batch_size = 1500  # examples per batch, used for both the train and validation dataloaders
max_length = 60  # handed to the collator as `max_sequence_len`
# Prefer the GPU, but fall back to the CPU so the script still runs on a
# machine without CUDA instead of failing when the model is moved to `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("using device:", device)
# SET model path gpt2, gpt2-medium, gpt2-large, gpt2-xl, or path to your own model--------
model_name_or_path = "gpt2"
labels_ids = {'neg': 0, 'pos': 1}  # assign 0=negative and 1=positive
out_dir = '/home/miria/jaxopt/GPT2/checkpoints/'  # where the final checkpoint is written
# ----------------------------------------------------------------------------------------
# Config, tokenizer and model are all loaded from `model_name_or_path`, so
# switching to another GPT2 flavor (or a local checkpoint) is a one-line
# change in the settings above rather than three hard-coded "gpt2" strings.
print('Load selected GPT2 flavor model config...')
# The config comes from HF (how many decoder blocks, etc.); the size of the
# classification head follows the label mapping.
model_config = GPT2Config.from_pretrained(model_name_or_path, num_labels=len(labels_ids))

print('Load GPT2 tokenizer...')
tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
# tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
# Pad on the left so the right-most position of every sequence is a real token.
# NOTE(review): GPT2ForSequenceClassification classifies from the last non-pad
# token, located via `pad_token_id` -- confirm the installed transformers
# version handles left padding as intended.
tokenizer.padding_side = "left"
# GPT2 ships without a padding token, so the end-of-sequence token is reused.
tokenizer.pad_token = tokenizer.eos_token

#----------------------------------------------------------------------------------------
print('Loading model...')
#model = GPT2Model.from_pretrained("openai-community/gpt2")
# MODEL is GPT2 with a classification head
model = GPT2ForSequenceClassification.from_pretrained(model_name_or_path, config=model_config)

# Debugging aid: `print(model.state_dict().keys())` lists every weight in the model.

# Resize the embedding matrix to the tokenizer's vocabulary size. Reusing EOS
# as the pad token adds no new entries, so this is expected to be a no-op
# here; it only matters if extra tokens get added to the tokenizer.
model.resize_token_embeddings(len(tokenizer))
# Tell the model which id is padding (same id as EOS, mirroring the tokenizer).
model.config.pad_token_id = model.config.eos_token_id

# Load model to device.
model.to(device)
print('Model loaded to `%s`'%device)


#----------------------------------------------------------------------------------------
# Deal with data: encoding, batches, collator.
# Use Gpt2ClassificationCollator_simple and MovieReviewsDataset_simple instead
# for getting weights without pre-training or finetuning.
gpt2_classificaiton_collator = Gpt2ClassificationCollator(use_tokenizer=tokenizer,
                                                          labels_encoder=labels_ids,
                                                          max_sequence_len=max_length)

# Split folders are joined with os.path.join so DATAPATH works with or
# without a trailing slash.
# NOTE(review): presumably `path` can point at a sub-folder such as
# 'train/pos/' or 'train/unlabeled/' to load only that subset -- confirm
# against MovieReviewsDataset.
print('Dealing with Train...')
train_dataset = MovieReviewsDataset(path=os.path.join(DATAPATH, 'train/'), use_tokenizer=tokenizer)
print('Created `train_dataset` with %d examples!'%len(train_dataset))

# Move pytorch dataset into dataloader (shuffled every epoch for training).
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=gpt2_classificaiton_collator)
print('Created `train_dataloader` with %d batches!'%len(train_dataloader))

print('Dealing with Validation...')
valid_dataset = MovieReviewsDataset(path=os.path.join(DATAPATH, 'test/'), use_tokenizer=tokenizer)
print('Created `valid_dataset` with %d examples!'%len(valid_dataset))

# Move pytorch dataset into dataloader (fixed order for evaluation).
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, collate_fn=gpt2_classificaiton_collator)
print('Created `eval_dataloader` with %d batches!'%len(valid_dataloader))

print('All datasets created!')

#----------------------------------------------------------------------------------------

# training loop function: goes through each batch, does the back prop and updates the model
def train(dataloader, optimizer_, scheduler_, device_, model_=None):
  """Run a single training pass (one epoch) over `dataloader`.

  Arguments:
      dataloader (:obj:`torch.utils.data.dataloader.DataLoader`):
          Parsed data in batches. Each batch is a dict of CPU tensors that
          contains a `labels` entry and is passed to the model as keyword
          arguments. Must support `len()`.
      optimizer_ (:obj:`transformers.optimization.AdamW`):
          Optimizer used for training.
      scheduler_ (:obj:`torch.optim.lr_scheduler.LambdaLR`):
          PyTorch scheduler, stepped once per batch.
      device_ (:obj:`torch.device`):
          Device used to load tensors before feeding to model.
      model_ (:obj:`torch.nn.Module`, `optional`):
          Model to train. Defaults to the module-level `model`, which keeps
          existing four-argument calls working. The model is expected to
          already live on `device_`.

  Returns:
      :obj:`Tuple[List[int], List[int], float]`: (True Labels, Predicted
        Labels, Train Average Loss), the loss being averaged over batches.

  Raises:
      ZeroDivisionError: if `dataloader` yields no batches.
  """

  # Fall back to the global model only when none is passed in explicitly.
  net = model if model_ is None else model_

  predictions_labels = []
  true_labels = []
  total_loss = 0.0

  # Put the model into training mode.
  net.train()

  # For each batch of training data...
  for batch in dataloader:

    # Keep track of true labels (read while still on CPU) for evaluation.
    true_labels += batch['labels'].numpy().flatten().tolist()

    # Cast every tensor to long and move the batch to the device.
    batch = {k: v.type(torch.long).to(device_) for k, v in batch.items()}

    # Zero out gradients before performing a backward pass.
    net.zero_grad()

    # Forward pass. The loss comes first in the outputs because `labels` is
    # part of the batch; this is where classification differs from generation.
    outputs = net(**batch)
    loss, logits = outputs[:2]

    # Accumulate the training loss over all of the batches.
    total_loss += loss.item()

    # Back prop, clip the gradient norm to 1.0, then update the weights
    # and advance the learning-rate schedule.
    loss.backward()
    torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
    optimizer_.step()
    scheduler_.step()

    # Move logits to CPU and turn them into predicted label ids.
    logits = logits.detach().cpu().numpy()
    predictions_labels += logits.argmax(axis=-1).flatten().tolist()

  # Average loss over the training batches.
  avg_epoch_loss = total_loss / len(dataloader)

  return true_labels, predictions_labels, avg_epoch_loss


# inference (only doing predictions)
def validation(dataloader, device_, model_=None):
  """Evaluate the model on a separate set of data without updating it.

  Arguments:
    dataloader (:obj:`torch.utils.data.dataloader.DataLoader`):
          Parsed data in batches. Each batch is a dict of CPU tensors that
          contains a `labels` entry and is passed to the model as keyword
          arguments. Must support `len()`.
    device_ (:obj:`torch.device`):
          Device used to load tensors before feeding to model.
    model_ (:obj:`torch.nn.Module`, `optional`):
          Model to evaluate. Defaults to the module-level `model`, which
          keeps existing two-argument calls working.
  Returns:
    :obj:`Tuple[List[int], List[int], float]`: (True Labels, Predicted
        Labels, Validation Average Loss), the loss being averaged over batches.

  Raises:
    ZeroDivisionError: if `dataloader` yields no batches.
  """

  # Fall back to the global model only when none is passed in explicitly.
  net = model if model_ is None else model_

  predictions_labels = []
  true_labels = []
  total_loss = 0.0

  # Evaluation mode.
  net.eval()

  # Inference only: no gradients are needed anywhere in this loop.
  with torch.no_grad():
    for batch in dataloader:

      # Add original labels (read while still on CPU).
      true_labels += batch['labels'].numpy().flatten().tolist()

      # Cast every tensor to long and move the batch to the device.
      batch = {k: v.type(torch.long).to(device_) for k, v in batch.items()}

      # Forward pass; the loss is returned because `labels` is in the batch.
      outputs = net(**batch)
      loss, logits = outputs[:2]

      total_loss += loss.item()

      # Get the predictions list from the logits (raw scores).
      logits = logits.detach().cpu().numpy()
      predictions_labels += logits.argmax(axis=-1).flatten().tolist()

  avg_epoch_loss = total_loss / len(dataloader)

  return true_labels, predictions_labels, avg_epoch_loss


# Optimizer setup.
# AdamW here is the class exported by `transformers`, not `torch.optim`.
# NOTE(review): transformers has deprecated its own AdamW in favour of
# torch.optim.AdamW -- confirm against the installed version.
optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)

# The schedule spans every optimizer step of the run:
# batches per epoch times number of epochs.
total_steps = epochs * len(train_dataloader)

# Linear decay with no warmup (0 is the default value in run_glue.py).
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

# Per-epoch history, kept for plotting the learning curves.
all_loss = {'train_loss': [], 'val_loss': []}
all_acc = {'train_acc': [], 'val_acc': []}

# One iteration per epoch: train, validate, checkpoint on the final epoch, log.
print('Epoch')
for epoch in range(epochs):
  print()
  print('Training on batches...')
  # One full pass over the training set.
  train_labels, train_predict, train_loss = train(train_dataloader, optimizer, scheduler, device)
  train_acc = accuracy_score(train_labels, train_predict)

  # Predictions from the model on the validation data.
  print('Validation on batches...')
  valid_labels, valid_predict, val_loss = validation(valid_dataloader, device)
  val_acc = accuracy_score(valid_labels, valid_predict)

  # Only the final epoch writes a checkpoint.
  if epoch + 1 == epochs:
    print("This is the last epoch. Saving model checkpoint...")
    print(f"saving checkpoint to {out_dir}")
    # Alternative: torch.save(model.state_dict(), os.path.join(out_dir, 'my_ckpt.pt'))
    model.save_pretrained(out_dir)

  # Report loss and accuracy to see how training evolves.
  epoch_metrics = (train_loss, val_loss, train_acc, val_acc)
  print("  train_loss: %.5f - val_loss: %.5f - train_acc: %.5f - valid_acc: %.5f"%epoch_metrics)
  print()

  # Record this epoch's values in the history.
  all_loss['train_loss'].append(train_loss)
  all_acc['train_acc'].append(train_acc)
  all_loss['val_loss'].append(val_loss)
  all_acc['val_acc'].append(val_acc)



# plots
# plot_dict(all_loss, use_xlabel='Epochs', use_ylabel='Value', use_linestyles=['-', '--'])
# plot_dict(all_acc, use_xlabel='Epochs', use_ylabel='Value', use_linestyles=['-', '--'])

# Get predictions from the model on the validation data
# true_labels, predictions_labels, avg_epoch_loss = validation(valid_dataloader, device)

# # create and print the evaluation report.
# evaluation_report = classification_report(true_labels, predictions_labels, labels=list(labels_ids.values()), target_names=list(labels_ids.keys()))
# print(evaluation_report)
