#!/usr/bin/env python
# coding: utf-8

# # Fine-Tuning With SQuAD 2.0

# In[ ]:


import torch
import os
import requests
import json
import random
import functools
from torch.utils.data import DataLoader
from transformers import AdamW
from tqdm import tqdm
from transformers import BertTokenizerFast, XLNetTokenizerFast, RobertaTokenizerFast, AutoTokenizer
from transformers import BertForQuestionAnswering, AutoModelForQuestionAnswering, RobertaForQuestionAnswering, AlbertForQuestionAnswering

# In[4]:


# if not os.path.exists('../data/benchmarks/squad'):
#     os.mkdir('../data/benchmarks/squad')


# ---
# # Get and Prepare Data
# 
# ## Download SQuAD data

# In[5]:


# url = 'https://rajpurkar.github.io/SQuAD-explorer/dataset/'
# res = requests.get(f'{url}train-v2.0.json')


# # In[6]:


# for file in ['train-v2.0.json', 'dev-v2.0.json']:
#     res = requests.get(f'{url}{file}')
#     # write to file
#     with open(f'../data/benchmarks/squad/{file}', 'wb') as f:
#         for chunk in res.iter_content(chunk_size=4):
#             f.write(chunk)


# ## Read

# In[ ]:


def read_squad(path_base, mode='train', combined=False, sample_fraction=0.60):
    """Load SQuAD-format JSON and return a random subsample as parallel lists.

    Every entry of every question's ``'answers'`` list yields one
    (context, question, answer) triple. Questions whose ``'answers'`` list is
    empty therefore contribute nothing.

    Reads the module-level ``batch_size``: each file is trimmed to a multiple
    of it, all files are cut to the length of the shortest one, and the
    concatenation is trimmed to a multiple of it again. The random subsample
    is taken afterwards and is NOT re-trimmed, so the returned length is
    generally not a multiple of ``batch_size``.

    Args:
        path_base: Path of the JSON file to read when ``combined`` is false.
        mode: ``'train'`` selects the train files in combined mode; any other
            value selects the test files. Ignored when ``combined`` is false.
        combined: When true, ignore ``path_base`` and read the three built-in
            schema files instead.
        sample_fraction: Probability with which each example is kept.

    Returns:
        A ``(contexts, questions, answers)`` tuple of equally long lists.
    """
    if combined:
        split_file = 'train.json' if mode == 'train' else 'test.json'
        path_list = [
            f'src/data/fine_tuning/squad/{schema}/small/{split_file}'
            for schema in ('flight_delay', 'online_delivary', 'student_perf')
        ]
    else:
        path_list = [path_base]

    per_file = []
    for path in path_list:
        with open(path, 'rb') as f:
            squad_dict = json.load(f)

        contexts = []
        questions = []
        answers = []
        for group in squad_dict['data']:
            for passage in group['paragraphs']:
                context = passage['context']
                for qa in passage['qas']:
                    question = qa['question']
                    for answer in qa['answers']:
                        contexts.append(context)
                        questions.append(question)
                        answers.append(answer)

        # Slice by an explicit keep-length: `lst[:-0]` would be `lst[:0]` and
        # wipe a file whose size is already a multiple of batch_size.
        keep = len(answers) - len(answers) % batch_size
        per_file.append((contexts[:keep], questions[:keep], answers[:keep]))

    # Give every file the same number of examples before concatenating.
    min_len = min(len(file_answers) for _, _, file_answers in per_file)
    contexts_total = [c for file_contexts, _, _ in per_file for c in file_contexts[:min_len]]
    questions_total = [q for _, file_questions, _ in per_file for q in file_questions[:min_len]]
    answers_total = [a for _, _, file_answers in per_file for a in file_answers[:min_len]]

    usable = len(contexts_total) - len(contexts_total) % batch_size
    contexts_total = contexts_total[:usable]
    questions_total = questions_total[:usable]
    answers_total = answers_total[:usable]

    sampled_context = []
    sampled_questions = []
    sampled_answers = []
    # One random draw per example, in order, decides whether it is kept.
    for context, question, answer in zip(contexts_total, questions_total, answers_total):
        if random.random() < sample_fraction:
            sampled_context.append(context)
            sampled_questions.append(question)
            sampled_answers.append(answer)

    print(f"length of dataset: {len(sampled_answers)} From {len(answers_total)}")
    return sampled_context, sampled_questions, sampled_answers


# ---- Run configuration ----
schema_name = 'flight_delay'
model_name = 'albert'
combined = False
# Directory for this schema/model combination (or the combined run); it is
# created below if it does not exist yet.
if combined:
    model_path = f'models/squad/combined/{model_name}'
else:
    model_path = f'models/squad/{schema_name}/{model_name}'
batch_size = 256
max_seq_len = 128
# Defaults; the model-selection section further down reassigns both per model.
learning_rate = 5e-5
weight_decay = 0.1
# exist_ok avoids the check-then-create race of testing os.path.exists first.
os.makedirs(model_path, exist_ok=True)

print(f"schema name: {schema_name}, model path: {model_path}, batch size: {batch_size}")

# train_contexts, train_questions, train_answers = read_squad('src/data/fine_tuning/squad/train-v2.0.json')
train_contexts, train_questions, train_answers = read_squad(f'src/data/fine_tuning/squad/{schema_name}/small/train.json', combined=combined)
val_contexts, val_questions, val_answers = read_squad(f'src/data/fine_tuning/squad/{schema_name}/small/test.json', mode='test', combined=combined)



def add_end_idx(answers, contexts):
    """Add an ``'answer_end'`` character offset to every answer dict, in place.

    The nominal end is ``answer_start + len(text)``. If the context does not
    contain the answer text at that span, the span is retried shifted 1 and
    then 2 characters to the left; the smallest matching shift wins and both
    ``'answer_start'`` and ``'answer_end'`` are corrected.

    If nothing matches, the nominal span is kept so that ``'answer_end'`` is
    always present afterwards.

    Args:
        answers: Dicts with ``'text'`` and ``'answer_start'`` keys; mutated.
        contexts: Context strings, parallel to ``answers``.
    """
    for answer, context in zip(answers, contexts):
        # gold_text is the answer we expect to find in the context
        gold_text = answer['text']
        start_idx = answer['answer_start']
        end_idx = start_idx + len(gold_text)

        # Start from the nominal span; it stays in place when no shift matches.
        answer['answer_end'] = end_idx
        if context[start_idx:end_idx] == gold_text:
            continue

        # Annotated offsets are sometimes off by a character or two.
        for shift in (1, 2):
            if shift > start_idx:
                # A negative start would wrap around to the end of the string.
                break
            if context[start_idx - shift:end_idx - shift] == gold_text:
                answer['answer_start'] = start_idx - shift
                answer['answer_end'] = end_idx - shift
                break



# Attach 'answer_end' to every answer dict (in place) for both splits.
add_end_idx(train_answers, train_contexts)
add_end_idx(val_answers, val_contexts)


# Pick the pretrained checkpoint, its tokenizer and the per-model
# hyperparameters. The checkpoint id lives in its own variable so that
# `model_path` keeps pointing at the output directory created earlier, which
# is where the fine-tuned model is saved after training.
if model_name == 'bert':
    pretrained_name = 'bert-base-uncased'
    learning_rate = 7.5e-5
    weight_decay = 0.25
    tokenizer = BertTokenizerFast.from_pretrained(pretrained_name)
    model = BertForQuestionAnswering.from_pretrained(pretrained_name)
elif model_name == 'albert':
    pretrained_name = 'albert-base-v2'
    learning_rate = 7.5e-5
    weight_decay = 0.25
    tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
    model = AlbertForQuestionAnswering.from_pretrained(pretrained_name)
elif model_name == 'xlnet':
    pretrained_name = 'xlnet-base-cased'
    learning_rate = 1e-4
    weight_decay = 0.25
    tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
    model = AutoModelForQuestionAnswering.from_pretrained(pretrained_name)
elif model_name == 'roberta':
    pretrained_name = 'roberta-base'
    learning_rate = 7.5e-5
    weight_decay = 0.01
    tokenizer = RobertaTokenizerFast.from_pretrained(pretrained_name)
    model = RobertaForQuestionAnswering.from_pretrained(pretrained_name)
else:
    # Fail here instead of with a NameError on `tokenizer` further down.
    raise ValueError(f"unsupported model_name: {model_name!r}")

# Encode (context, question) pairs, truncated/padded to at most max_seq_len tokens.
train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding=True, max_length=max_seq_len)
val_encodings = tokenizer(val_contexts, val_questions, truncation=True, padding=True, max_length=max_seq_len)



def _context_char_end(encodings, batch_index):
    """Return the exclusive character offset where the tokenized context ends.

    Looks for the last token of example ``batch_index`` whose sequence id is 0
    and returns the end of its character span; returns 0 if there is none.
    """
    sequence_ids = encodings.sequence_ids(batch_index)
    for token_index in range(len(sequence_ids) - 1, -1, -1):
        if sequence_ids[token_index] == 0:
            span = encodings.token_to_chars(batch_index, token_index)
            if span is not None:
                return span.end
    return 0


def add_token_positions(encodings, answers):
    """Convert character-level answer spans into token positions, in place.

    Adds ``'start_positions'`` and ``'end_positions'`` lists to ``encodings``.
    Reads the module-level ``max_seq_len``.

    * Start: the token at ``answer_start``; ``max_seq_len - 1`` when that
      character has no token.
    * End: the token at ``answer_end``. When that character has no token, the
      following characters are probed one by one and the first token found is
      used. The probe stops at the end of the tokenized context; if it finds
      nothing, ``max_seq_len - 1`` is used.

    The probe is bounded by character offsets of the encoding itself rather
    than by ``tokenizer.model_max_length``, which counts tokens and can be
    effectively unbounded for some tokenizers.

    NOTE(review): ``answer_end`` looks like an exclusive offset (start + text
    length), so the end position is usually the token following the answer.
    Confirm this is intended.

    Args:
        encodings: Tokenizer output providing ``char_to_token``,
            ``sequence_ids``, ``token_to_chars`` and ``update``.
        answers: Dicts with ``'answer_start'`` and ``'answer_end'`` keys,
            parallel to the examples in ``encodings``.
    """
    start_positions = []
    end_positions = []
    for i, answer in enumerate(answers):
        start_token = encodings.char_to_token(i, answer['answer_start'])
        if start_token is None:
            start_token = max_seq_len - 1

        end_char = answer['answer_end']
        end_token = encodings.char_to_token(i, end_char)
        if end_token is None:
            # No token at this character: walk forward to the next one.
            char_limit = _context_char_end(encodings, i)
            probe = end_char + 1
            while end_token is None and probe < char_limit:
                end_token = encodings.char_to_token(i, probe)
                probe += 1
            if end_token is None:
                end_token = max_seq_len - 1

        start_positions.append(start_token)
        end_positions.append(end_token)

    # update our encodings object with the new token-based start/end positions
    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})

# apply function to our data
add_token_positions(train_encodings, train_answers)
add_token_positions(val_encodings, val_answers)





class SquadDataset(torch.utils.data.Dataset):
    """Dataset view over a tokenizer encoding.

    Item ``idx`` is a dict that maps every key of the encoding to a tensor
    built from that key's ``idx``-th entry.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        # One example per row of input ids.
        return len(self.encodings.input_ids)

    def __getitem__(self, idx):
        sample = {}
        for field, column in self.encodings.items():
            sample[field] = torch.tensor(column[idx])
        return sample

# Wrap the tokenized splits so a DataLoader can index them.
train_dataset = SquadDataset(train_encodings)
val_dataset = SquadDataset(val_encodings)


# setup GPU/CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
gpu_ids = list(range(torch.cuda.device_count()))
print(gpu_ids)
model = torch.nn.DataParallel(model, device_ids=gpu_ids)

# Place the model via `device`, not `gpu_ids[0]`: on a machine without CUDA
# `gpu_ids` is empty and indexing it raises IndexError.
model.to(device)
# activate training mode of model
model.train()
# initialize adam optimizer with weight decay (reduces chance of overfitting)
optim = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# initialize data loader for training data
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


num_of_epoch = 1

for epoch in range(num_of_epoch):
    # set model to train mode
    model.train()
    # setup loop (we use tqdm for the progress bar)
    loop = tqdm(train_loader, leave=True)
    # the label is constant within an epoch, so set it once
    loop.set_description(f'Epoch {epoch}')
    for batch in loop:
        # reset gradients accumulated by the previous step
        optim.zero_grad()
        # pull all the tensor batches required for training
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)
        # train model on batch and return outputs (incl. loss)
        outputs = model(input_ids, attention_mask=attention_mask,
                        start_positions=start_positions,
                        end_positions=end_positions)
        # extract loss
        loss = outputs[0]
        # .mean() reduces the loss to a scalar when it holds more than one value
        loss.mean().backward()
        # update parameters
        optim.step()
        # loop.set_postfix(loss=loss.mean().item())


# ## Save Model

# In[ ]:


# If the model object exposes a `.module` attribute (it was wrapped above),
# save that inner model; otherwise save the object itself.
model_to_save = getattr(model, "module", model)
model_to_save.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)


# In[ ]:


# switch model out of training mode
model.eval()

val_loader = DataLoader(val_dataset, batch_size=batch_size)

# Count hits over examples instead of averaging per-batch accuracies, so a
# shorter final batch is not over-weighted.
start_hits = 0
end_hits = 0
num_examples = 0

# initialize loop for progress bar
loop = tqdm(val_loader)
# we don't need to calculate gradients as we're not training
with torch.no_grad():
    for batch in loop:
        # pull batched items from loader
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_true = batch['start_positions'].to(device)
        end_true = batch['end_positions'].to(device)
        # make predictions
        outputs = model(input_ids, attention_mask=attention_mask)
        # pull preds out
        start_pred = torch.argmax(outputs['start_logits'], dim=1)
        end_pred = torch.argmax(outputs['end_logits'], dim=1)
        # tally correct start and end predictions separately
        start_hits += (start_pred == start_true).sum().item()
        end_hits += (end_pred == end_true).sum().item()
        num_examples += len(start_pred)

# Mean of start-position and end-position accuracy; 0.0 for an empty set.
# NOTE(review): start and end are scored independently, so this is not a
# span-level exact-match score despite the label printed below.
acc = (start_hits + end_hits) / (2 * num_examples) if num_examples else 0.0


print(f"Accuracy in Exact Match: {acc}")
# print("T/F\tstart\tend\n")
# for i in range(len(start_true)):
#     print(f"true\t{start_true[i]}\t{end_true[i]}\n"
#           f"pred\t{start_pred[i]}\t{end_pred[i]}\n")

