import argparse
import json
import logging
import re
import pdb

import tqdm
import torch
import numpy as np
import sklearn.feature_extraction.text
from torch.utils.data import DataLoader, SequentialSampler

#from utils_multiple_choice import processors, convert_examples_to_features, select_field
import utils_common
from utils_common import TaskType, get_task_type, get_task_processor, MODEL_CLASSES

# Module-level logger used by the helpers below for informational and warning messages.
logger = logging.getLogger(__name__)

def read_jsonl(input_file):
    """Parse a JSON Lines file into a list.

    Args:
        input_file: Path to a UTF-8 encoded file holding one JSON value per line.

    Returns:
        The decoded value of every line, in file order.

    Raises:
        json.JSONDecodeError: If any line (including a blank one) is not valid JSON.
    """
    with open(input_file, "r", encoding='utf-8') as handle:
        records = [json.loads(raw_line) for raw_line in handle.readlines()]
    return records

def make_identifier(single_example):
    """Build a hashable key identifying one raw example record.

    The key is the tuple of the record's ``''`` field (presumably the unnamed
    index column of the source CSV -- TODO confirm), its ``'fold-ind'`` field
    and its ``'video-id'`` field, in that order.

    Raises:
        KeyError: If any of those fields is missing.
    """
    key_fields = ('', 'fold-ind', 'video-id')
    return tuple(single_example[field] for field in key_fields)

def getcls_model_batch_textcls(model_type, model, batch):
    """Return the first-position hidden vector of every sequence in a batch.

    Args:
        model_type: ``'bert'`` or ``'roberta'``; selects ``model.bert`` or
            ``model.roberta`` as the encoder. Any other value raises ValueError.
        model: Model exposing the encoder attribute named by ``model_type``.
        batch: Sequence whose items 0 and 1 are the input ids and attention
            mask; item 2 (token type ids) is only read for ``'bert'``/``'xlnet'``.

    Returns:
        ``encoder(...)[0][:, 0]``: the vector at position 0 of the encoder's
        first output (the CLS/<s> token state for BERT/RoBERTa-style encoders).

    Raises:
        ValueError: If ``model_type`` is neither ``'bert'`` nor ``'roberta'``.
    """
    ids = batch[0]
    mask = batch[1]
    # XLM-style models don't use segment ids.
    segment_ids = batch[2] if model_type in ['bert', 'xlnet'] else None

    if model_type not in ('bert', 'roberta'):
        raise ValueError("Unknown model_type {}".format(model_type))
    encoder = getattr(model, model_type)

    hidden_states = encoder(
        ids,
        position_ids=None,
        token_type_ids=segment_ids,
        attention_mask=mask,
        head_mask=None,
    )[0]
    return hidden_states[:, 0]

def getcls_model_batch_multichoice(model_type, model, batch):
    """Return the encoder's pooled output for every choice of every example.

    The multiple-choice inputs are flattened over their leading two
    dimensions (example, choice) before the encoder call, and the encoder's
    second output is reshaped back to ``(num_examples, num_choices, -1)``.

    Args:
        model_type: ``'bert'`` or ``'roberta'``; selects ``model.bert`` or
            ``model.roberta``. Any other value raises ValueError.
        model: Model exposing the encoder attribute named by ``model_type``.
        batch: Sequence whose items 0 and 1 are input ids and attention mask
            with the choice count in dimension 1; item 2 (token type ids) is
            only read for ``'bert'``/``'xlnet'``.

    Returns:
        Tensor of shape ``(num_examples, num_choices, -1)``.

    Raises:
        ValueError: If ``model_type`` is neither ``'bert'`` nor ``'roberta'``.
    """
    ids = batch[0]
    mask = batch[1]
    # XLM-style models don't use segment ids.
    segment_ids = batch[2] if model_type in ['bert', 'xlnet'] else None

    num_choices = ids.shape[1]

    def _flatten(tensor):
        # Merge all leading dims so each choice becomes its own encoder row.
        return None if tensor is None else tensor.view(-1, tensor.size(-1))

    flat_ids = _flatten(ids)
    flat_segments = _flatten(segment_ids)
    flat_mask = _flatten(mask)

    if model_type not in ('bert', 'roberta'):
        raise ValueError("Unknown model_type {}".format(model_type))
    encoder = getattr(model, model_type)

    pooled = encoder(
        flat_ids,
        position_ids=None,
        token_type_ids=flat_segments,
        attention_mask=flat_mask,
        head_mask=None,
    )[1]
    return pooled.reshape((len(ids), num_choices, -1))

def convert_dataset_to_frozencls(model, dataset, task_type, batch_size=1, device=None, model_type=None):
    """Encode a tokenized dataset into frozen (no-grad) CLS embeddings.

    Args:
        model: Model whose ``bert``/``roberta`` encoder produces the vectors.
            It is put in eval mode for the pass; its previous training mode is
            restored afterwards, even if encoding fails.
        dataset: Tokenized dataset; per batch, items 0-2 are passed to the
            ``getcls_model_batch_*`` helper, item 3 is read as labels and
            item 4 as dataset indices.
        task_type: ``TaskType.MULTIPLE_CHOICE`` or ``TaskType.TEXT_CLASSIFICATION``.
        batch_size: Examples per forward pass.
        device: Device for the forward pass; defaults to CPU.
        model_type: ``'bert'`` or ``'roberta'``, forwarded to the helper.

    Returns:
        A ``TensorDataset`` of CPU tensors ``(cls_vectors, labels, dataset_idxs)``
        in dataset order. For multiple choice each example's vector keeps a
        leading per-choice dimension. If ``dataset`` already looks embedded
        (float32 first field) it is returned unchanged.

    Raises:
        ValueError: If ``task_type`` is not supported.
    """
    if dataset[0][0].dtype == torch.float32:
        # Already looks like frozencls (TODO: Make this check more robust)
        logger.warning('convert_dataset_to_frozencls: dataset already looks embedded, so doing nothing')
        return dataset

    if task_type == TaskType.MULTIPLE_CHOICE:
        getcls_fn = getcls_model_batch_multichoice
    elif task_type == TaskType.TEXT_CLASSIFICATION:
        getcls_fn = getcls_model_batch_textcls
    else:
        raise ValueError("Unknown task type")

    if device is None:
        device = torch.device('cpu')

    dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=batch_size)
    cls_chunks, label_chunks, idx_chunks = [], [], []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in tqdm.tqdm(dataloader, desc='pre-cache'):
                batch = tuple(t.to(device) for t in batch)
                cls_chunks.append(getcls_fn(model_type, model, batch).cpu())
                label_chunks.append(batch[3].cpu())
                idx_chunks.append(batch[4].cpu())
    finally:
        # Restore the caller's mode even if encoding raised.
        model.train(was_training)

    # Concatenating whole batches is equivalent to stacking per-example rows.
    return torch.utils.data.TensorDataset(
        torch.cat(cls_chunks),
        torch.cat(label_chunks),
        torch.cat(idx_chunks),
    )

def main(args):
    """Filter a saved augmented dataset down to the examples listed in a JSONL file.

    Legacy entry point: the ``__main__`` block calls ``main2``, and the
    argument parser does not currently define ``--input_augmented_data``.

    Records of ``args.input_augmented_data`` (loaded with ``torch.load``) are
    kept when their ``make_identifier`` key matches a record of the UTF-8 JSONL
    file ``args.input_datafile``. Blank JSONL lines are skipped. The kept
    records are written to ``args.output_datafile`` with ``torch.save``.
    """
    print('Loading data...')
    # NOTE(review): torch.load unpickles the file -- only point it at trusted data.
    aug_data = torch.load(args.input_augmented_data)

    # Explicit UTF-8 (matching read_jsonl) instead of the platform default encoding.
    with open(args.input_datafile, encoding='utf-8') as f:
        json_data = [json.loads(line) for line in f if line.strip()]

    print('Processing...')
    json_keys = {make_identifier(ex) for ex in json_data}
    new_aug_data = [ex for ex in aug_data if make_identifier(ex) in json_keys]

    print('Saving...')
    torch.save(new_aug_data, args.output_datafile)

def init_tokenizer(args):
    """Instantiate the pretrained tokenizer selected by ``args``.

    The tokenizer class is the third entry of
    ``MODEL_CLASSES[get_task_type(args.task_name)][args.model_type]``. It is
    loaded from ``args.tokenizer_name`` when that is non-empty, otherwise from
    ``args.model_name_or_path``. An empty ``args.cache_dir`` is passed as None.
    """
    model_entry = MODEL_CLASSES[get_task_type(args.task_name)][args.model_type]
    _config_class, _model_class, tokenizer_class = model_entry
    source = args.tokenizer_name or args.model_name_or_path
    return tokenizer_class.from_pretrained(
        source,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir or None,
    )

def init_model(args, label_list):
    """Load the pretrained config, tokenizer and model described by ``args``.

    Config and model classes come from
    ``MODEL_CLASSES[get_task_type(args.task_name)][args.model_type]``.
    - The config is loaded from ``args.config_name`` (or
      ``args.model_name_or_path``) with ``num_labels=len(label_list)``.
    - The model is loaded from ``args.model_name_or_path``, using TF weights
      when that path contains ``".ckpt"``.
    - The model is then moved to ``args.device``.

    Returns:
        ``(model, tokenizer, config)``.

    Raises:
        ValueError: If ``args.model_name_or_path`` is empty or None (it defaults
            to None on the command line but every model-based method needs it).
    """
    # Fail early with a clear message instead of a TypeError from `".ckpt" in None`.
    if not args.model_name_or_path:
        raise ValueError("--model_name_or_path is required for model-based embedding methods")

    config_class, model_class, _tokenizer_class = MODEL_CLASSES[get_task_type(args.task_name)][args.model_type]
    cache_dir = args.cache_dir if args.cache_dir else None
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=len(label_list),
        finetuning_task=args.task_name,
        cache_dir=cache_dir,
    )
    tokenizer = init_tokenizer(args)
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=".ckpt" in args.model_name_or_path,
        config=config,
        cache_dir=cache_dir,
    )

    model.to(args.device)

    return model, tokenizer, config

def examples_to_dataset(args, examples, label_list, tokenizer):
    """Tokenize ``examples`` into a model-ready dataset.

    Logs the example count, then forwards ``args.task_name``,
    ``args.model_type`` and ``args.max_seq_length``, followed by the remaining
    arguments, to ``utils_common.examples_to_dataset``.
    """
    example_count = str(len(examples))
    logger.info("Training number: %s", example_count)
    convert = utils_common.examples_to_dataset
    return convert(
        args.task_name,
        args.model_type,
        args.max_seq_length,
        examples,
        label_list,
        tokenizer,
    )

def embed_examples_clsfrozen_from_args(args, examples, label_list):
    """Load the model named in ``args`` and return frozen CLS embeddings of ``examples``.

    The examples are tokenized with the model's tokenizer and encoded in
    batches of 8 on ``args.device``. The result is whatever
    ``embed_examples_clsfrozen`` returns: one vector per example.
    """
    model, tokenizer, _config = init_model(args, label_list)
    dataset = examples_to_dataset(args, examples, label_list, tokenizer)
    return embed_examples_clsfrozen(
        model,
        dataset,
        get_task_type(args.task_name),
        batch_size=8,
        device=args.device,
        model_type=args.model_type,
    )

def embed_examples_clsfrozen(model, dataset, task_type, **kwargs):
    """Return one frozen CLS vector per example of ``dataset``.

    Extra keyword arguments are forwarded to ``convert_dataset_to_frozencls``.
    For multiple-choice tasks the per-choice vectors of an example are averaged
    over the choice dimension (dim 0) so every example yields one vector.
    """
    frozen = convert_dataset_to_frozencls(model, dataset, task_type, **kwargs)
    average_choices = task_type == TaskType.MULTIPLE_CHOICE

    vectors = []
    for row in frozen:
        vec = row[0]
        vectors.append(vec.mean(dim=0) if average_choices else vec)
    return vectors

def embed_examples_dense_surprisal(model, dataset, task_type, **kwargs):
    """Embed each example as its vector of per-token surprisals.

    Extra keyword arguments are forwarded to ``convert_dataset_to_dense_surprisal``.

    For multiple choice, an example's vector is the per-position sum of its
    choices' surprisals divided by how many choices attend that position. The
    divisor is clamped to 1 where no choice does. Its text length is the
    longest choice length, as a Python number. For text classification, the
    surprisal vector and stored text-length tensor are used as-is.

    Returns:
        ``(embeddings, text_lens)``: ``embeddings`` is ``torch.vstack`` of the
        per-example vectors; ``text_lens`` is a list with one entry per example.

    Raises:
        ValueError: If ``task_type`` is not supported.
    """
    surprisal_dataset = convert_dataset_to_dense_surprisal(model, dataset, task_type, **kwargs)

    if task_type == TaskType.TEXT_CLASSIFICATION:
        vectors = [rec[0] for rec in surprisal_dataset]
        lengths = [rec[3] for rec in surprisal_dataset]
    elif task_type == TaskType.MULTIPLE_CHOICE:
        vectors, lengths = [], []
        for rec in tqdm.tqdm(surprisal_dataset, desc='embed_examples'):
            # Number of choices attending each position.
            coverage = rec[4].sum(dim=0)
            # Avoid division by zero at positions masked in every choice
            # (their surprisals are expected to be 0 there).
            coverage[coverage == 0] = 1
            vectors.append(rec[0].sum(dim=0) / coverage)
            lengths.append(max(rec[3]).item())
    else:
        raise ValueError("Unknown task type")

    return torch.vstack(vectors), lengths

def convert_dataset_to_dense_surprisal(model, dataset, task_type, batch_size=1, device=None, model_type=None):
    """Compute dense per-token surprisal vectors for every example in ``dataset``.

    The "surprisal" of a token is its MLM cross-entropy loss when the model
    sees the *unmasked* input, so each vector has one entry per input
    position. Padding positions (attention mask 0) are excluded via
    ``ignore_index`` and get 0. They are called "dense" because every token's
    value is kept (the ALPS method that uses them only samples a random subset
    of tokens).

    Args:
        model: Masked-LM model called as ``model(input_ids=..., attention_mask=...)``.
            Its first output is used as vocabulary logits. A warning is logged
            if it has no ``lm_head`` attribute. It is put in eval mode for the
            pass, and its previous mode is restored afterwards, even on error.
        dataset: Tokenized dataset. Per batch, items 0, 1, 3 and 4 are read as
            input ids, attention mask, labels and dataset indices.
        task_type: ``TaskType.MULTIPLE_CHOICE`` (inputs carry the choice count
            in dimension 1) or ``TaskType.TEXT_CLASSIFICATION``.
        batch_size: Examples per forward pass.
        device: Device for the forward pass; defaults to CPU.
        model_type: Unused; accepted for signature parity with
            ``convert_dataset_to_frozencls``.

    Returns:
        A ``TensorDataset`` of CPU tensors
        ``(surprisals, labels, dataset_idxs, text_lengths, attention_masks)``
        in dataset order. For multiple choice each example's tensors keep a
        leading per-choice dimension. If ``dataset`` already looks embedded
        (float32 first field) it is returned unchanged.

    Raises:
        ValueError: If ``task_type`` is not supported.
    """
    if dataset[0][0].dtype == torch.float32:
        # Already looks embedded (TODO: Make this check more robust)
        logger.warning('convert_dataset_to_dense_surprisal: dataset already looks embedded, so doing nothing')
        return dataset

    if not hasattr(model, 'lm_head'):
        logger.warning('convert_dataset_to_dense_surprisal expects model to be a pretrained MLM, but no "lm_head" attribute was found on this one')

    is_multichoice = task_type == TaskType.MULTIPLE_CHOICE
    if not is_multichoice and task_type != TaskType.TEXT_CLASSIFICATION:
        raise ValueError("Unknown task type")

    if device is None:
        device = torch.device('cpu')

    dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=batch_size)
    loss_fn = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=-100)
    surprisal_chunks, label_chunks, idx_chunks, length_chunks, mask_chunks = [], [], [], [], []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in tqdm.tqdm(dataloader, desc='pre-cache surprisals'):
                batch = tuple(t.to(device) for t in batch)
                input_ids = batch[0]
                attention_mask = batch[1]

                num_choices = None
                if is_multichoice:
                    # Score each choice as an independent sequence.
                    num_choices = input_ids.shape[1]
                    input_ids = input_ids.view(-1, input_ids.size(-1))
                    attention_mask = attention_mask.view(-1, attention_mask.size(-1))

                logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
                # CrossEntropyLoss wants class scores in dim 1; assuming logits
                # are (N, seq, vocab), this makes them (N, vocab, seq).
                logits = logits.transpose(-1, -2)

                # Targets are the inputs themselves; padding is ignored.
                targets = input_ids.clone()
                targets[attention_mask == 0] = -100
                surprisals = loss_fn(logits, targets)

                if is_multichoice:
                    surprisals = surprisals.view(-1, num_choices, surprisals.size(-1))
                    attention_mask = attention_mask.view(-1, num_choices, attention_mask.size(-1))

                # Move every field to CPU. Previously labels (and, for text
                # classification, lengths/masks) stayed on `device`, mixing
                # devices in the result and pinning device memory.
                surprisal_chunks.append(surprisals.cpu())
                label_chunks.append(batch[3].cpu())
                idx_chunks.append(batch[4].cpu())
                length_chunks.append(attention_mask.sum(dim=-1).cpu())
                mask_chunks.append(attention_mask.cpu())
    finally:
        model.train(was_training)

    # Concatenating whole batches is equivalent to stacking per-example rows.
    return torch.utils.data.TensorDataset(
        torch.cat(surprisal_chunks),
        torch.cat(label_chunks),
        torch.cat(idx_chunks),
        torch.cat(length_chunks),
        torch.cat(mask_chunks),
    )

def example_to_string(example, task_type):
    """Flatten an example into one plain-text string.

    - Multiple choice: ``"<context> <question> <ending_1> ... <ending_k>"``.
      The example's contexts must all be equal (checked with ``assert``).
    - Text classification: ``example.text``.

    Raises:
        ValueError: If ``task_type`` is not supported.
    """
    if task_type == TaskType.TEXT_CLASSIFICATION:
        return example.text
    if task_type != TaskType.MULTIPLE_CHOICE:
        raise ValueError("Unknown task type")

    shared_context = example.contexts[0]
    assert all(ctx == shared_context for ctx in example.contexts)
    question = example.question
    joined_endings = ' '.join(example.endings)
    return f'{shared_context} {question} {joined_endings}'

def embed_examples_modelwordembed(args, examples, label_list):
    """Embed each example as the mean of its input word-embedding vectors.

    Only multiple-choice tasks are supported. For each example:
    1. Flatten it with ``example_to_string``.
    2. Tokenize it with the tokenizer returned by ``init_model``.
    3. Look up the tokens in ``model.roberta.embeddings.word_embeddings``
       (so the model must expose a ``roberta`` encoder).
    4. Average the vectors over tokens.

    Args:
        args: Parsed command-line namespace (``task_name``, ``device`` and the
            model-loading fields read by ``init_model``).
        examples: Multiple-choice examples (``contexts``, ``question``, ``endings``).
        label_list: Label list, used to size the model's classification head.

    Returns:
        A list with one 1-D CPU tensor per example.

    Raises:
        ValueError: If the task is not multiple choice.
    """
    task_type = get_task_type(args.task_name)
    if task_type != TaskType.MULTIPLE_CHOICE:
        raise ValueError("roberta_word_avg only supported for MULTIPLE_CHOICE currently")

    model, tokenizer, _config = init_model(args, label_list)
    # init_model moved the model to args.device; inputs must live there too.
    # (The original referenced an undefined `device` name here.)
    device = args.device

    embeddings = []
    with torch.no_grad():
        for example in tqdm.tqdm(examples, desc='embed examples'):
            text = example_to_string(example, task_type)
            token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
            input_ids = torch.tensor(token_ids, dtype=torch.long).to(device)
            # (1, num_tokens) -> (1, num_tokens, dim) -> mean over tokens -> (dim,)
            token_vectors = model.roberta.embeddings.word_embeddings(input_ids.unsqueeze(0))
            embeddings.append(token_vectors.mean(dim=1)[0].cpu())

    return embeddings

def embed_examples_tokenids(args, examples, label_list):
    """Represent each example as its list of lowercase alphanumeric tokens.

    Each example is flattened with ``example_to_string`` and lowercased. It is
    then split on runs of characters outside ``[a-zA-Z0-9]``, and empty
    pieces are discarded. ``label_list`` is unused; it keeps the signature
    uniform with the other ``embed_examples_*`` functions.

    Returns:
        A list of token lists, one per example.
    """
    task_type = get_task_type(args.task_name)
    splitter = re.compile(r'[^a-zA-Z0-9]+')

    def _tokenize(example):
        lowered = example_to_string(example, task_type).lower()
        return list(filter(None, splitter.split(lowered)))

    return [_tokenize(example) for example in tqdm.tqdm(examples, desc='embed examples')]

def embed_examples_tfidf(args, examples, label_list):
    """Embed the examples as dense, L2-normalized TF-IDF vectors.

    A ``TfidfVectorizer`` (lowercasing, IDF weighting, ``min_df=3``) is fitted
    on the ``example_to_string`` text of all examples. ``label_list`` is
    unused.

    Returns:
        A float32 tensor with one row per example and one column per kept term.
    """
    task_type = get_task_type(args.task_name)
    corpus = [example_to_string(ex, task_type) for ex in examples]

    print('Fitting tfidf...')
    tfidf = sklearn.feature_extraction.text.TfidfVectorizer(
        lowercase=True,
        use_idf=True,
        min_df=3,
        norm='l2',
    )
    doc_term_matrix = tfidf.fit_transform(corpus)
    return torch.tensor(doc_term_matrix.toarray(), dtype=torch.float32)


def main2(args):
    """Embed every training example of ``args.input_datafile`` and save the annotated records.

    Steps:
    1. Load the examples with the task processor and the raw records with
       ``read_jsonl`` from the same file.
    2. Compute embeddings with ``args.embedding_method`` and store each one
       under the ``'roberta_cls'`` key of its record (for every method).
    3. Save the record list to ``args.output_datafile`` with ``torch.save``.

    Side effect: lowercases ``args.task_name`` in place.

    Raises:
        ValueError: On an unknown embedding method, or when the processor
            yields a different number of examples than the JSONL file has
            records (pairing them by index would then misalign embeddings).
    """
    embedders = {
        'roberta_cls': embed_examples_clsfrozen_from_args,
        'roberta_word_avg': embed_examples_modelwordembed,
        'simple_tokens': embed_examples_tokenids,
        'tfidf': embed_examples_tfidf,
    }

    args.task_name = args.task_name.lower()
    processor = get_task_processor(args.task_name)

    label_list = processor.get_labels()
    examples = processor.get_examples(args.input_datafile, datasplit='train')
    examples_jsonl = read_jsonl(args.input_datafile)
    if len(examples) != len(examples_jsonl):
        raise ValueError(
            f"Processor produced {len(examples)} examples but {args.input_datafile} "
            f"has {len(examples_jsonl)} records; cannot pair embeddings by index"
        )

    try:
        embed_fn = embedders[args.embedding_method]
    except KeyError:
        raise ValueError("Invalid embedding method") from None
    embeddings = embed_fn(args, examples, label_list)

    # Key name stays 'roberta_cls' for every method -- presumably downstream
    # consumers read this key; verify before renaming.
    for idx, record in enumerate(examples_jsonl):
        record['roberta_cls'] = embeddings[idx]

    print('Saving...')
    torch.save(examples_jsonl, args.output_datafile)



if __name__ == '__main__':
    # Command-line entry point: parse options, pick a device, and run main2.
    parser = argparse.ArgumentParser()
    parser.add_argument('--embedding_method', required=True, choices=['roberta_word_avg', 'roberta_cls', 'simple_tokens', 'tfidf'], help='How to generate embeddings')
    parser.add_argument('--input_datafile', required=True, help='JSONL file with the examples to embed')
    parser.add_argument('--output_datafile', required=True, help='Path to save the augmented data in torch format')
    parser.add_argument('--model_name_or_path', default=None, help='huggingface model name or path (required for model-based embedding)')
    parser.add_argument('--cache_dir', default='', help='Directory for caching downloaded huggingface models (optional)')
    parser.add_argument('--task_name', default='swag', help='task name for type of examples')
    parser.add_argument('--model_type', default='roberta', help='type of model')
    parser.add_argument("--tokenizer_name", default="", type=str, help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument("--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.")
    parser.add_argument("--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name")
    parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded.")
    args = parser.parse_args()
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.model_type = args.model_type.lower()
    main2(args)

