import os
import re
import json
import random
import time
import itertools

import torch as t
import torch.nn.functional as F

from dataclasses import dataclass




def measure_execution_time(experiment_fn, env_dict):
    """Run ``experiment_fn(env_dict)`` once and return the elapsed time in seconds.

    The return value of ``experiment_fn`` is discarded. Timing uses
    ``time.perf_counter()``, the clock intended for measuring durations; unlike
    ``time.time()`` it is not affected by system clock adjustments.

    Args:
        experiment_fn: callable invoked with ``env_dict`` as its only argument.
        env_dict: argument forwarded unchanged to ``experiment_fn``.

    Returns:
        Elapsed time in seconds, as a float.
    """
    start_time = time.perf_counter()
    experiment_fn(env_dict)
    return time.perf_counter() - start_time


def logit_difference(logits, correct_labels, wrong_labels, return_one_element: bool=True) -> t.Tensor:
    """Return the negated correct-minus-wrong logit difference at the last position.

    Only the final sequence position (``logits[:, -1]``) is used. The sign is flipped,
    so a larger margin of the correct logits over the wrong ones gives a smaller value.

    Args:
        logits: tensor indexed as ``[:, -1, labels]``; presumably (batch, position, vocab)
            — TODO confirm.
        correct_labels: index/indices into the last axis selecting the "correct" logits.
        wrong_labels: index/indices into the last axis selecting the "wrong" logits.
        return_one_element: if True, return the scalar ``-(mean(correct) - mean(wrong))``;
            otherwise return ``-(correct - wrong)`` flattened to 1-D.

    NOTE(review): the labels index the last axis for *every* batch row, so a per-example
    label tensor of shape (batch,) or (batch, 1) selects all (row, label) pairs
    (e.g. a (batch, batch, 1) result) rather than one label per row — confirm intended.
    """
    final_step = logits[:, -1]
    right = final_step[:, correct_labels]
    wrong = final_step[:, wrong_labels]

    if not return_one_element:
        return -(right - wrong).view(-1)
    return -(right.mean() - wrong.mean())


def _tokenize_ids(tokenizer, text):
    """Tokenize ``text`` without padding and return the tokenizer's ``input_ids`` tensor."""
    return tokenizer(text, return_tensors="pt", padding=False).input_ids


def load_batch(dataset, num_examples, tokenizer, seed=12, pad_to_length=True, length=64):
    """Load, shuffle and tokenize ``num_examples`` clean/patch examples from a JSON-lines file.

    Each non-blank line of ``dataset`` must be a JSON object with the fields
    ``clean_prefix``, ``patch_prefix``, ``clean_answer`` and ``patch_answer``.

    Args:
        dataset: path to the JSON-lines file.
        num_examples: number of examples to place in the batch.
        tokenizer: called as ``tokenizer(text, return_tensors="pt", padding=False)``; the
            result's ``input_ids`` is expected to be a 2-D tensor of shape (1, n_tokens)
            (presumably a Hugging Face tokenizer — confirm against callers).
        seed: passed to ``random.seed`` before shuffling the lines. This reseeds the
            global ``random`` module.
        pad_to_length: if True, left-pad both prefixes with ``tokenizer.pad_token_id``
            to exactly ``length`` tokens.
        length: number of prefix tokens per row. Examples whose clean or patch prefix
            does not end up exactly this long are skipped.

    Returns:
        dict with ``clean_prefix_batch`` / ``patch_prefix_batch`` (long tensors of shape
        (num_examples, length)) and ``clean_answer_batch`` / ``patch_answer_batch``
        (long tensors of shape (num_examples, 1)).

    Raises:
        ValueError: if a prefix is longer than ``length`` while padding, or if fewer
            than ``num_examples`` usable examples are found.
    """
    # `with` guarantees the file is closed; JSON text is UTF-8.
    with open(dataset, encoding="utf-8") as f:
        dataset_items = f.readlines()
    random.seed(seed)
    random.shuffle(dataset_items)

    # Output format: zero-initialised buffers, filled one row per accepted example.
    examples_batch = dict(
        clean_prefix_batch=t.zeros((num_examples, length), dtype=t.long),
        patch_prefix_batch=t.zeros((num_examples, length), dtype=t.long),
        clean_answer_batch=t.zeros((num_examples, 1), dtype=t.long),
        patch_answer_batch=t.zeros((num_examples, 1), dtype=t.long),
    )

    if pad_to_length:
        # Kept for backward compatibility: the caller's tokenizer was always left in this
        # state. It does not affect this function, because tokenization uses
        # padding=False and the prefixes are left-padded manually below.
        tokenizer.padding_side = 'right'

    i = 0
    for line in dataset_items:
        # Checked before filling slot `i`, so num_examples == 0 yields an empty batch.
        if i >= num_examples:
            break
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing newline at end of file
        data = json.loads(line)

        clean_prefix = _tokenize_ids(tokenizer, data["clean_prefix"])
        patch_prefix = _tokenize_ids(tokenizer, data["patch_prefix"])
        clean_answer = _tokenize_ids(tokenizer, data["clean_answer"])
        patch_answer = _tokenize_ids(tokenizer, data["patch_answer"])

        if pad_to_length:
            clean_pad_length = length - clean_prefix.shape[1]
            patch_pad_length = length - patch_prefix.shape[1]
            if clean_pad_length < 0 or patch_pad_length < 0:
                raise ValueError(
                    f"example too long for length={length}: clean prefix has "
                    f"{clean_prefix.shape[1]} tokens, patch prefix has {patch_prefix.shape[1]}"
                )
            # F.pad takes (left, right) for the last dimension: pad on the left only.
            clean_prefix = F.pad(clean_prefix, (clean_pad_length, 0), value=tokenizer.pad_token_id)
            patch_prefix = F.pad(patch_prefix, (patch_pad_length, 0), value=tokenizer.pad_token_id)

        # Skip examples whose clean OR patch prefix does not fit a batch row. Without
        # padding, a mismatched patch prefix would otherwise fail on assignment below.
        if length and (clean_prefix.shape[1] != length or patch_prefix.shape[1] != length):
            continue

        # Answers are truncated to their first token. Usually multi-token answers would be
        # discarded, but here all models only need to run on the same number of tokens.
        # Answers that tokenize to nothing cannot fill a slot, so they are skipped.
        if clean_answer.shape[1] == 0 or patch_answer.shape[1] == 0:
            continue
        clean_answer = clean_answer[:, :1]
        patch_answer = patch_answer[:, :1]

        examples_batch["clean_prefix_batch"][i] = clean_prefix
        examples_batch["patch_prefix_batch"][i] = patch_prefix
        examples_batch["clean_answer_batch"][i] = clean_answer
        examples_batch["patch_answer_batch"][i] = patch_answer
        i += 1

    if i < num_examples:
        raise ValueError(
            f"only {i} usable examples found in {dataset!r}; {num_examples} requested"
        )

    return examples_batch


