from pyexpat import model
import torch
import os
import json
from torch.utils.data import DataLoader
from transformers import AdamW
from tqdm import tqdm
from transformers import BertTokenizerFast, XLNetTokenizerFast, RobertaTokenizerFast, AlbertTokenizerFast
from transformers import BertForQuestionAnswering, XLNetForQuestionAnswering, RobertaForQuestionAnswering, AlbertForQuestionAnswering, AutoModelForQuestionAnswering
import numpy as np
from src.utils.ft_pytorch.squad_util import *

def load_model(model_name, model_path):
    """Load the tokenizer and question-answering model for ``model_name``.

    Args:
        model_name: One of ``'bert'``, ``'xlnet'``, ``'roberta'`` or
            ``'albert'``.
        model_path: Value forwarded to each class's ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == 'bert':
        tokenizer = BertTokenizerFast.from_pretrained(model_path)
        model = BertForQuestionAnswering.from_pretrained(model_path)
    elif model_name == 'xlnet':
        # AutoTokenizer is not explicitly imported at the top of this module,
        # so import it here to keep this branch from failing with NameError.
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForQuestionAnswering.from_pretrained(model_path)
    elif model_name == 'roberta':
        tokenizer = RobertaTokenizerFast.from_pretrained(model_path)
        model = RobertaForQuestionAnswering.from_pretrained(model_path)
    elif model_name == 'albert':
        tokenizer = AlbertTokenizerFast.from_pretrained(model_path)
        model = AlbertForQuestionAnswering.from_pretrained(model_path)
    else:
        # Without this branch an unknown name reached the return statement
        # with both locals unbound and surfaced as an UnboundLocalError.
        raise ValueError(
            f"Unsupported model_name {model_name!r}; "
            "expected one of 'bert', 'xlnet', 'roberta', 'albert'"
        )

    return tokenizer, model


def read_squad(path, batch_size):
    """Read a SQuAD-format JSON file into parallel example lists.

    One example is produced per answer entry, so a question with several
    answers contributes several (context, question, answer) triples.
    When a question carries a ``'plausible_answers'`` key, its entries are
    used instead of ``'answers'``.

    The lists are trimmed so that their length is a multiple of
    ``batch_size``; trailing examples that would form a partial batch are
    dropped.

    Args:
        path: Path to the JSON file.
        batch_size: Batch size used to trim the lists (must be non-zero).

    Returns:
        A ``(contexts, questions, answers)`` tuple of equally long lists.
    """
    with open(path, 'rb') as f:
        squad_dict = json.load(f)

    contexts = []
    questions = []
    answers = []
    for group in squad_dict['data']:
        for passage in group['paragraphs']:
            context = passage['context']
            for qa in passage['qas']:
                question = qa['question']
                # The key was previously computed but never used: the loop
                # always read qa['answers'].
                access = 'plausible_answers' if 'plausible_answers' in qa else 'answers'
                for answer in qa[access]:
                    contexts.append(context)
                    questions.append(question)
                    answers.append(answer)

    # Slice with an explicit length rather than a negative offset: when the
    # remainder is 0, ``lst[:-0]`` is ``lst[:0]`` and would discard everything.
    keep = len(answers) - len(answers) % batch_size
    return contexts[:keep], questions[:keep], answers[:keep]

def add_end_idx(answers, contexts):
    """Add an ``'answer_end'`` character index to every answer, in place.

    The nominal end is ``answer_start + len(text)``. If the context does not
    contain the answer text at that span, the span is shifted left by one and
    then two characters; the first shift that matches also rewrites
    ``'answer_start'``. If nothing matches, the nominal end is kept so that
    ``'answer_end'`` is always present for later processing.

    Args:
        answers: List of dicts with ``'text'`` and ``'answer_start'`` keys;
            each dict is mutated.
        contexts: List of context strings, parallel to ``answers``.
    """
    for answer, context in zip(answers, contexts):
        gold_text = answer['text']
        start_idx = answer['answer_start']
        end_idx = start_idx + len(gold_text)

        # Record the nominal end up front; it is overwritten below only if a
        # shifted span matches. Previously an unmatched answer was left
        # without 'answer_end' at all.
        answer['answer_end'] = end_idx
        if context[start_idx:end_idx] == gold_text:
            continue

        for n in (1, 2):
            if start_idx - n < 0:
                # A negative start would wrap around to the end of the string.
                break
            if context[start_idx - n:end_idx - n] == gold_text:
                # The annotation is off by n characters.
                answer['answer_start'] = start_idx - n
                answer['answer_end'] = end_idx - n
                # Prefer the smallest shift; do not let a later one override.
                break

def add_token_positions(encodings, answers, tokenizer, max_len):
    """Store token-level start/end positions for every answer in ``encodings``.

    For example ``i`` the character offsets ``answer_start`` and
    ``answer_end`` are mapped through ``encodings.char_to_token``.

    * A start that maps to ``None`` is replaced by ``max_len - 1``.
    * An end that maps to ``None`` is retried at ``answer_end + 1``,
      ``answer_end + 2``, ... until a token is found. Once the probed
      character offset exceeds ``tokenizer.model_max_length`` the search
      stops and ``max_len`` is stored instead.

    The results are written back with ``encodings.update`` under the keys
    ``'start_positions'`` and ``'end_positions'``.

    NOTE(review): the stop condition compares a character offset with
    ``tokenizer.model_max_length`` -- confirm this bound is what is intended.
    """
    token_starts = []
    token_ends = []
    for row, answer in enumerate(answers):
        first_tok = encodings.char_to_token(row, answer['answer_start'])
        last_tok = encodings.char_to_token(row, answer['answer_end'])

        if first_tok is None:
            # No token for the start character: fall back to the last slot.
            first_tok = max_len - 1

        # Probe successive characters after the end offset until one maps
        # to a token or the bound is exceeded.
        shift = 1
        while last_tok is None:
            probe = answer['answer_end'] + shift
            if probe > tokenizer.model_max_length:
                last_tok = max_len
                break
            last_tok = encodings.char_to_token(row, probe)
            shift += 1

        token_starts.append(first_tok)
        token_ends.append(last_tok)

    encodings.update({'start_positions': token_starts, 'end_positions': token_ends})

class SquadDataset(torch.utils.data.Dataset):
    """Dataset view over tokenizer encodings.

    ``encodings`` must provide ``.items()`` yielding (name, per-example
    sequence) pairs and an ``input_ids`` attribute whose length is the
    number of examples.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        # One entry in input_ids per example.
        return len(self.encodings.input_ids)

    def __getitem__(self, idx):
        # Build one tensor per encoding field for the requested example.
        sample = {}
        for name, column in self.encodings.items():
            sample[name] = torch.tensor(column[idx])
        return sample




def prepare_data(tokenizer, path, split='train', batch_size=128, max_length=128):
    """Build a shuffled ``DataLoader`` from a SQuAD-format JSON file.

    Steps: read the examples with ``read_squad``, add character end indices
    with ``add_end_idx``, tokenize (context, question) pairs with truncation
    and padding, attach token positions with ``add_token_positions``, and
    wrap the result in a ``SquadDataset``.

    Args:
        tokenizer: Callable tokenizer applied to the context/question lists.
        path: Path to the SQuAD-format JSON file.
        split: Accepted but not used by this function.
        batch_size: Passed to ``read_squad`` and to the ``DataLoader``.
        max_length: Maximum length passed to the tokenizer and to
            ``add_token_positions``.

    Returns:
        A ``DataLoader`` over the encoded examples with ``shuffle=True``.
    """
    print("##################__Prepare Data__##################")

    contexts, questions, answers = read_squad(path, batch_size)
    add_end_idx(answers, contexts)

    tokenize_options = dict(truncation=True, padding=True, max_length=max_length)
    encodings = tokenizer(contexts, questions, **tokenize_options)
    add_token_positions(encodings, answers, tokenizer, max_length)

    return DataLoader(SquadDataset(encodings), batch_size=batch_size, shuffle=True)

def training_loop(model, model_path, train_loader, num_of_epoch=1, learning_rate=5e-5):
    """Fine-tune ``model`` on ``train_loader`` and save it to ``model_path``.

    All visible CUDA devices are used through ``torch.nn.DataParallel``;
    when none is visible, training runs on the CPU.

    Args:
        model: Question-answering model whose forward pass accepts
            ``input_ids``, ``attention_mask``, ``start_positions`` and
            ``end_positions`` and returns the loss as its first output.
        model_path: Destination passed to ``model.save_pretrained``.
        train_loader: Iterable of batches with the keys ``'input_ids'``,
            ``'attention_mask'``, ``'start_positions'`` and
            ``'end_positions'``.
        num_of_epoch: Number of passes over ``train_loader``.
        learning_rate: Learning rate for the AdamW optimizer.
    """
    gpu_ids = list(range(torch.cuda.device_count()))
    print(gpu_ids)
    # Indexing gpu_ids[0] unconditionally raises IndexError on a CPU-only
    # machine, so pick the device explicitly.
    device = torch.device('cuda', gpu_ids[0]) if gpu_ids else torch.device('cpu')

    # Keep a handle on the unwrapped model: the DataParallel wrapper does not
    # expose save_pretrained, so saving has to go through the original module
    # (which shares its parameters with the wrapper).
    base_model = model
    if gpu_ids:
        model = torch.nn.DataParallel(model, device_ids=gpu_ids)
    model.to(device)

    optim = AdamW(model.parameters(), lr=learning_rate)

    for epoch in range(num_of_epoch):
        model.train()
        # tqdm provides the progress bar
        loop = tqdm(train_loader, leave=True)
        loop.set_description(f'Epoch {epoch}')
        for batch in loop:
            # clear gradients left over from the previous step
            optim.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            start_positions = batch['start_positions'].to(device)
            end_positions = batch['end_positions'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask,
                            start_positions=start_positions,
                            end_positions=end_positions)
            loss = outputs[0]
            # With several replicas the loss holds one value per device;
            # mean() reduces it to a scalar (and is a no-op otherwise).
            loss.mean().backward()
            optim.step()

    base_model.save_pretrained(model_path)


def evaluation_loop(model, tokenizer, val_loader, batch_size, file):
    """Evaluate ``model`` on ``val_loader`` and report the metrics.

    For every batch the predicted start/end token indices are the argmax of
    the start/end logits. The reported accuracy is the mean, over all
    batches, of the fraction of matching start indices and the fraction of
    matching end indices. Predicted and reference answer strings are decoded
    from ``input_ids[start:end]``, split on whitespace, and passed to
    ``compute_f1`` (defined outside this block).

    All visible CUDA devices are used through ``torch.nn.DataParallel``;
    when none is visible, evaluation runs on the CPU.

    Args:
        model: Question-answering model whose output exposes
            ``'start_logits'`` and ``'end_logits'``.
        tokenizer: Tokenizer used to turn token ids back into text.
        val_loader: Iterable of batches with the keys ``'input_ids'``,
            ``'attention_mask'``, ``'start_positions'`` and
            ``'end_positions'``.
        batch_size: Accepted but not used by this function.
        file: Writable text file object that receives the metric summary.
    """
    gpu_ids = list(range(torch.cuda.device_count()))
    print(gpu_ids)
    # Indexing gpu_ids[0] unconditionally raises IndexError on a CPU-only
    # machine, so pick the device explicitly.
    device = torch.device('cuda', gpu_ids[0]) if gpu_ids else torch.device('cpu')
    if gpu_ids:
        model = torch.nn.DataParallel(model, device_ids=gpu_ids)
    model.to(device)

    model.eval()

    acc = []
    predictions = []
    ground_truths = []

    loop = tqdm(val_loader)
    for batch in loop:
        # no gradients are needed for evaluation
        with torch.no_grad():
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            start_true = batch['start_positions'].to(device)
            end_true = batch['end_positions'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask)
            start_pred = torch.argmax(outputs['start_logits'], dim=1)
            end_pred = torch.argmax(outputs['end_logits'], dim=1)
            # fraction of exactly matching start / end indices in this batch
            acc.append(((start_pred == start_true).sum()/len(start_pred)).item())
            acc.append(((end_pred == end_true).sum()/len(end_pred)).item())

        # Move everything to host memory once per batch instead of once per
        # example inside the loop below.
        ids_np = input_ids.cpu().numpy()
        start_pred_np = start_pred.cpu().numpy()
        end_pred_np = end_pred.cpu().numpy()
        start_true_np = start_true.cpu().numpy()
        end_true_np = end_true.cpu().numpy()

        for i in range(len(ids_np)):
            output_tokens = ids_np[i][start_pred_np[i]:end_pred_np[i]]
            gt_tokens = ids_np[i][start_true_np[i]:end_true_np[i]]
            answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(output_tokens))
            ground_truth = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(gt_tokens))
            predictions.append(answer.split())
            ground_truths.append(ground_truth.split())

    # average the per-batch accuracies
    acc = sum(acc)/len(acc)
    f1, precision, recall = compute_f1(predictions, ground_truths)
    print(f"Accuracy in Exact Match: {acc}\nAccuracy in F1: {f1}, \t Precision: {precision}, \t Recall: {recall}\n")
    file.write(f"\nAccuracy in Exact Match: {acc}\nAccuracy in F1: {f1}, \t Precision: {precision}, \t Recall: {recall}\n")



def main(schema_name='flight_delay', model_name='bert', combined=False):
    """Evaluate every configured model on every configured schema.

    The schema names (``"schema"``) and model names (``"ner_models"``) are
    read from ``src/core/evaluation/global_config.json``. For each pair the
    model is loaded, the schema's SQuAD-format test data is prepared, and
    the evaluation summary is printed and appended to
    ``src/core/ensembles/squad_results.txt``.

    Args:
        schema_name: Overwritten by the schema loop; the passed value is
            not used.
        model_name: Overwritten by the model loop; the passed value is not
            used.
        combined: When true, models are loaded from
            ``models/squad/combined/<model>`` instead of
            ``models/squad/<schema>/<model>``.
    """
    CONF_LOC = "src/core/evaluation/global_config.json"
    # A context manager guarantees the results file is closed even when
    # loading or evaluating a model raises.
    with open("src/core/ensembles/squad_results.txt", 'a') as file:
        with open(CONF_LOC) as config_file:
            configuration = json.load(config_file)

        SCHEMA_NAMES = configuration["schema"]
        MODEL_LIST = configuration["ner_models"]

        for schema_name in SCHEMA_NAMES:
            print(f"**************__{schema_name}_**************")
            file.write(f"**************__{schema_name}_**************\n")
            for model_name in MODEL_LIST:
                print(f"**************__{model_name}_**************")
                file.write(f"**************__{model_name}_**************\n")

                model_path = f'models/squad/combined/{model_name}' if combined else f'models/squad/{schema_name}/{model_name}'
                data_path = f"src/data/test_data/squad_format/{schema_name}.json"
                batch_size = 32
                max_seq_len = 128
                print(model_path)
                tokenizer, model = load_model(model_name, model_path)
                evaluation_dataloader = prepare_data(tokenizer, data_path, split='train', batch_size=batch_size, max_length=max_seq_len)
                evaluation_loop(model, tokenizer, evaluation_dataloader, batch_size, file)

# main(schema_name='student_perf', model_name='roberta', combined=False)