"""Zero-shot QA datasets."""

import json
import numpy as np
from abc import ABC, abstractmethod

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler

from megatron.core import mpu
from megatron.training import get_args, print_rank_0, get_tokenizer

from .utils import swap_letters_and_numbers

def build_data_loader(dataset, micro_batch_size, num_workers):
    """Wrap ``dataset`` in a DataLoader split across data-parallel ranks.

    Args:
        dataset: Map-style dataset to iterate over.
        micro_batch_size: Local (per GPU) batch size, not the global one.
        num_workers: Number of worker processes handed to the DataLoader.

    Returns:
        A ``DataLoader`` driven by a ``DistributedSampler`` configured with
        the data-parallel world size and rank reported by ``mpu``.
    """
    # One sampler replica per data-parallel rank; all other sampler options
    # are left at their library defaults.
    dp_sampler = DistributedSampler(
        dataset,
        num_replicas=mpu.get_data_parallel_world_size(),
        rank=mpu.get_data_parallel_rank(),
    )

    # Ordering is delegated to the sampler, hence shuffle=False here.
    return DataLoader(
        dataset,
        batch_size=micro_batch_size,
        sampler=dp_sampler,
        shuffle=False,
        num_workers=num_workers,
        drop_last=False,
        pin_memory=True,
        collate_fn=None,
    )

def build_dataset(task):
    """Return the evaluation dataset that corresponds to ``task``.

    Recognized task names are ``'PIQA'``, ``'HELLASWAG'`` and ``'ARC-E'``.

    Raises:
        NotImplementedError: If ``task`` is not one of the names above.
    """
    # Guard clauses: each recognized task hands off to its dedicated builder.
    if task == 'PIQA':
        return _build_piqa_dataset()
    if task == 'HELLASWAG':
        return _build_hellaswag_dataset()
    if task == "ARC-E":
        return _build_arce_dataset()

    message = 'dataset for {} task is not implemented.'.format(task)
    raise NotImplementedError(message)

class QADataset(ABC, Dataset):
    """Base class for multiple-choice QA datasets stored as JSON lines.

    Every non-blank line of the input file is parsed as one JSON record and
    handed to :meth:`process_sample_from_single_line`, which expands it into
    candidate texts (one per answer) plus the label of the question. Each
    candidate text becomes one dataset item, so the dataset length is
    ``number of kept questions * answers_per_questions``.

    Args:
        path: Path to the UTF-8 encoded JSON-lines file.
        pad_idx: Token id used to pad sequences up to ``seq_len + 1``.
        tokenizer: Object exposing ``tokenize(text)``; the result is used as a
            list of token ids (sliced, indexed and concatenated).
        seq_len: Target sequence length; items are padded to ``seq_len + 1``
            tokens so that inputs and shifted targets both have ``seq_len``.
        answers_per_questions: Number of candidates every question must have.
            Records producing a different number of candidates are skipped.
        strict: Selects how :meth:`get_tokens` splits off the final word.
    """

    def __init__(self, path, pad_idx, tokenizer, seq_len, answers_per_questions, strict=False):
        self.seq_len = seq_len
        self.pad_idx = pad_idx
        self.tokenizer = tokenizer
        self.answers_per_questions = answers_per_questions
        self.strict = strict

        # Parallel lists indexed by item: context tokens and trailing tokens.
        self.tokens = []
        self.labels = []
        # One entry per kept question (not per item).
        self.qa_labels = []

        with open(path, 'r', encoding='utf-8') as f:
            # Iterate the file lazily instead of materializing every line.
            for line in f:
                # A blank line (typically a trailing newline at end of file)
                # is not valid JSON; skip it rather than abort the load.
                if not line.strip():
                    continue

                data = json.loads(line)
                samples, qa_label = self.process_sample_from_single_line(data)

                # Keep item index -> question index arithmetic valid by only
                # accepting questions with the expected number of candidates.
                if len(samples) != self.answers_per_questions:
                    continue

                self.qa_labels.append(qa_label)
                for sample in samples:
                    tokens, labels = self.get_tokens(sample)
                    self.tokens.append(tokens)
                    self.labels.append(labels)

    @abstractmethod
    def process_sample_from_single_line(self, data):
        """Expand one parsed JSON record into candidate texts and a label.

        Args:
            data: The object produced by ``json.loads`` for one input line.

        Returns:
            A ``(samples, qa_label)`` pair where ``samples`` is a list of
            candidate strings and ``qa_label`` is stored per question.
        """

    def __len__(self):
        """Return the number of items (candidate texts), not questions."""
        return len(self.tokens)

    def __getitem__(self, idx):
        """Return the padded tokens, masks and question label for item ``idx``."""
        tokens, loss_mask, ae_mask = self.process_tokens_and_loss_mask(idx)

        # Consecutive groups of ``answers_per_questions`` items share a label.
        qa_label = self.qa_labels[idx // self.answers_per_questions]

        return {
            'qa_sample_id': idx,
            'qa_label': qa_label,
            'text': tokens,
            'loss_mask': loss_mask,
            'ae_mask': ae_mask,
        }

    def get_answers_per_questions(self):
        """Return the number of candidate answers per question."""
        return self.answers_per_questions

    def get_tokens(self, text):
        """Tokenize ``text`` and split it into ``(leading, trailing)`` tokens.

        In non-strict mode the whole text is tokenized and the final token is
        split off. In strict mode the last whitespace-separated word is cut
        from the text and tokenized separately with a leading space, so the
        trailing part may hold several tokens.
        """
        if not self.strict:
            tokens = self.tokenizer.tokenize(text)
            return tokens[:-1], [tokens[-1]]

        last_word = text.split()[-1]
        # rfind: locate the final occurrence so repeated words are handled.
        start_idx = text.rfind(last_word)
        beginning_tokens = self.tokenizer.tokenize(text[:start_idx].strip())
        last_tokens = self.tokenizer.tokenize(' ' + last_word)
        return beginning_tokens, last_tokens

    def process_tokens_and_loss_mask(self, idx):
        """Build the padded token array and shifted masks for item ``idx``.

        Returns:
            ``(tokens, loss_mask, ae_mask)`` as numpy arrays. ``tokens`` holds
            the leading and trailing tokens followed by padding. Both masks
            derive from one mask that is 1 on real tokens and 0 on padding:
            ``ae_mask`` drops its last entry and ``loss_mask`` drops its first.

        Note:
            Sequences longer than ``seq_len + 1`` are returned as-is; nothing
            here truncates them.
        """
        # Concatenation creates a fresh list, so the stored lists stay intact.
        tokens = self.tokens[idx] + self.labels[idx]
        full_mask = [1] * len(tokens)

        num_pad = self.seq_len + 1 - len(tokens)
        if num_pad > 0:
            full_mask += [0] * num_pad
            tokens += [self.pad_idx] * num_pad

        ae_mask = np.array(full_mask[:-1])
        loss_mask = np.array(full_mask[1:])

        return np.array(tokens), loss_mask, ae_mask

class PIQADataset(QADataset, ABC):
    """Two-candidate QA dataset built from PIQA-style JSON lines.

    Each record must provide a ``goal``, two candidate solutions (``sol1``
    and ``sol2``) and a ``label``. Every record yields two candidate texts,
    one per solution.
    """

    def __init__(self, path, pad_idx, tokenizer, seq_len, strict=False):
        super().__init__(path, pad_idx, tokenizer, seq_len, answers_per_questions=2, strict=strict)

    def process_sample_from_single_line(self, data):
        """Return the two ``"<goal> <solution>"`` texts and the integer label.

        Args:
            data: Parsed JSON record with ``goal``, ``sol1``, ``sol2`` and
                ``label`` keys.

        Returns:
            ``([goal + sol1 text, goal + sol2 text], int(label))``.
        """
        goal = data['goal']
        samples = [f"{goal} {data['sol1']}", f"{goal} {data['sol2']}"]

        return samples, int(data['label'])

class HELLASWAGDataset(QADataset, ABC):
    """Four-candidate QA dataset built from HellaSwag-style JSON lines.

    Each record must provide ``activity_label``, ``ctx``, a sequence of
    ``endings`` and a ``label``.
    """

    def __init__(self, path, pad_idx, tokenizer, seq_len, strict=False):
        super().__init__(
            path,
            pad_idx,
            tokenizer,
            seq_len,
            answers_per_questions=4,
            strict=strict,
        )

    def process_sample_from_single_line(self, data):
        """Return one ``"<activity>: <ctx> <ending>"`` text per ending.

        Returns:
            ``(samples, int(label))`` where ``samples`` has one entry for
            each element of ``data['endings']``.
        """
        # The shared prefix is identical for every candidate ending.
        prefix = f"{data['activity_label']}: {data['ctx']}"
        candidates = data['endings']
        answer = int(data['label'])

        return [f"{prefix} {ending}" for ending in candidates], answer

class ARCEDataset(QADataset, ABC):
    """Four-candidate QA dataset built from ARC-Easy-style JSON lines.

    Each record must provide ``question``, ``choices['text']`` and
    ``answerKey``. Records whose number of choices is not four are dropped
    by the base-class loader.
    """

    def __init__(self, path, pad_idx, tokenizer, seq_len, strict=False):
        super().__init__(
            path,
            pad_idx,
            tokenizer,
            seq_len,
            answers_per_questions=4,
            strict=strict,
        )

    def process_sample_from_single_line(self, data):
        """Return one ``"Question:<q> Answer:<choice>"`` text per choice.

        The label is whatever ``swap_letters_and_numbers`` returns for the
        record's ``answerKey``.
        """
        prompt = data['question']
        choices = data['choices']['text']
        # NOTE(review): presumably maps letter/number answer keys onto a
        # common index form -- confirm against .utils.
        answer = swap_letters_and_numbers(data['answerKey'])

        candidates = [f"Question:{prompt} Answer:{choice}" for choice in choices]
        return candidates, answer
    
def _build_piqa_dataset():
    """Construct the PIQA validation dataset from the global arguments.

    Reads the single path in ``args.valid_data`` and pads items according to
    ``args.inference_max_seq_length``; the tokenizer's ``eod`` id is used as
    the padding token.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    valid_paths = args.valid_data
    assert len(valid_paths) == 1

    print_rank_0('> building PIQA dataset from {} ...'.format(valid_paths))

    max_len = args.inference_max_seq_length
    val_dataset = PIQADataset(valid_paths[0], tokenizer.eod, tokenizer, max_len, strict=False)

    # Items are per candidate answer; dividing gives the question count.
    num_evals = len(val_dataset)
    num_questions = num_evals / val_dataset.get_answers_per_questions()
    print_rank_0(f' > found {num_questions} samples of length {max_len} with {num_evals} evaluations.')

    return val_dataset

def _build_hellaswag_dataset():
    """Construct the HellaSwag validation dataset from the global arguments.

    Reads the single path in ``args.valid_data`` and pads items according to
    ``args.inference_max_seq_length``; the tokenizer's ``eod`` id is used as
    the padding token.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    valid_paths = args.valid_data
    assert len(valid_paths) == 1

    print_rank_0('> building HELLASWAG dataset from {} ...'.format(valid_paths))

    max_len = args.inference_max_seq_length
    val_dataset = HELLASWAGDataset(valid_paths[0], tokenizer.eod, tokenizer, max_len, strict=False)

    # Items are per candidate ending; dividing gives the question count.
    num_evals = len(val_dataset)
    num_questions = num_evals / val_dataset.get_answers_per_questions()
    print_rank_0(f' > found {num_questions} samples of length {max_len} with {num_evals} evaluations.')

    return val_dataset

def _build_arce_dataset():
    """Construct the ARC-E validation dataset from the global arguments.

    Reads the single path in ``args.valid_data`` and pads items according to
    ``args.inference_max_seq_length``; the tokenizer's ``eod`` id is used as
    the padding token.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    valid_paths = args.valid_data
    assert len(valid_paths) == 1

    print_rank_0('> building ARC-E dataset from {} ...'.format(valid_paths))

    max_len = args.inference_max_seq_length
    val_dataset = ARCEDataset(valid_paths[0], tokenizer.eod, tokenizer, max_len, strict=False)

    # Items are per candidate choice; dividing gives the question count.
    num_evals = len(val_dataset)
    num_questions = num_evals / val_dataset.get_answers_per_questions()
    print_rank_0(f' > found {num_questions} samples of length {max_len} with {num_evals} evaluations.')

    return val_dataset
