import torch
import argparse

from load_model import load_model
from transformers import GPT2TokenizerFast
import torch.nn.functional as F
import sampling
import numpy as np
import json
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, trainers
from transformers import PreTrainedTokenizerFast


def _parse_bool(value):
    """Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty value -- including ``"False"`` and ``"0"`` -- turns
    into ``True``.  This parser accepts the usual textual spellings instead.

    Args:
        value: The raw argument string (or an actual ``bool``, which is
            returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"1", "true", "t", "yes", "y", "on"}:
        return True
    # The empty string stays falsy, exactly as ``bool("")`` was before.
    if text in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_char_tokenizer(vocab_path="char_vocab.json"):
    """Build the character-level tokenizer and write its vocabulary to disk.

    The vocabulary is the fixed character list below followed by ``<pad>``
    and ``<unk>``; it is dumped as JSON to ``vocab_path``.  A
    ``custom_decoder`` callable mapping ids back to text is attached to the
    returned tokenizer.

    Returns:
        A ``(tokenizer, vocab_dict)`` tuple.
    """
    custom_chars = [
        ' ', 'e', 't', 'o', 'a', 'h', 'n', 's', 'r', 'i', 'l', 'd', '\n', 'u', 'm',
        'y', ',', '.', 'w', 'f', 'c', 'g', 'I', 'p', 'b', 'A', 'E', 'T', 'v', 'S',
        'O', "'", 'k', 'R', 'N', 'L', 'C', 'H', ';', 'W', 'M', 'B', 'D', 'U', 'F',
        'G', 'P', '?', 'Y', '!', '-', 'K', 'x', 'V', 'j', 'q', '[', ']', 'J', ':',
        'Q', 'z', '9', '1', '(', ')', 'Z', 'X', '<', '"', '>', '2', '3', '0', '4',
        '5', '_', '6', '7', '8', '|', '&', '}', '`'
    ]

    # Ids follow list order; the two special tokens take the next free ids.
    vocab_dict = {char: idx for idx, char in enumerate(custom_chars)}
    vocab_dict["<pad>"] = len(vocab_dict)
    vocab_dict["<unk>"] = len(vocab_dict)

    with open(vocab_path, "w", encoding="utf-8") as f:
        json.dump(vocab_dict, f)

    tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=Tokenizer(models.WordLevel(vocab_dict, unk_token="<unk>")),
        unk_token="<unk>",
        pad_token="<pad>",
    )

    # NOTE(review): this is meant to isolate every character as its own
    # token.  An empty-string pattern may not split at all in some
    # `tokenizers` releases -- check the ids printed by main() to confirm.
    tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Split("", behavior="isolated")

    id_to_char = {idx: char for char, idx in vocab_dict.items()}

    def custom_decoder(token_ids):
        """Join the characters for ``token_ids``; unknown ids become '<unk>'."""
        return ''.join(id_to_char.get(token_id, '<unk>') for token_id in token_ids)

    tokenizer.custom_decoder = custom_decoder
    return tokenizer, vocab_dict


def _jump_stats(jumps):
    """Return ``(mean, standard_error)`` over every element of ``jumps``.

    The standard error is the population standard deviation divided by the
    square root of the total element count.
    """
    arr = np.array(jumps)
    return arr.mean(), arr.std() / np.sqrt(arr.size)


def main():
    """Load a model, sanity-check the char tokenizer, and track sampler jumps.

    Repeatedly draws batches from the predictor-corrector sampler and prints
    the running mean and standard error of the returned switch history.
    """
    parser = argparse.ArgumentParser(description="Generate some samples")
    parser.add_argument("--model_path", default="exp_local/openwebtext/x", type=str)
    # `type=bool` would treat "--visualize False" as True; parse the text
    # explicitly.  A bare "--visualize" (no value) enables the flag.
    parser.add_argument(
        "--visualize", default=False, type=_parse_bool, nargs="?", const=True
    )
    # --dataset is accepted but not read anywhere in this function.
    parser.add_argument("--dataset", default="wikitext103", type=str)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--steps", type=int, default=1024)
    args = parser.parse_args()

    device = torch.device('cuda')
    model, graph, noise = load_model(args.model_path, device)

    tokenizer, vocab_dict = _build_char_tokenizer()

    # Round-trip the vocabulary characters (special tokens excluded) through
    # the tokenizer and print both directions for inspection.
    plain_chars = [tok for tok in vocab_dict if tok not in ("<pad>", "<unk>")]
    encoding = tokenizer(''.join(plain_chars), add_special_tokens=False)
    print(encoding.input_ids)
    print(tokenizer.custom_decoder(encoding.input_ids))

    loop_max = 1000
    jumps = []

    for loop in range(loop_max):
        print(loop)

        # Rebuilt every iteration, as before.  NOTE(review): probably
        # hoistable if get_pc_sampler holds no state -- confirm first.
        sampling_fn = sampling.get_pc_sampler(
            graph, noise, (args.batch_size, 128), 'euler', args.steps, args.visualize, device=device
        )

        samples, switch_hist = sampling_fn(model)
        jumps.append(switch_hist)

        mean_jumps, stderr_jumps = _jump_stats(jumps)
        print('Average Jumps:', mean_jumps)
        print('Std Jumps:', stderr_jumps)

# Run the sampling loop only when this file is executed as a script.
if __name__ == "__main__":
    main()