"""
Sample from a trained model
"""
import os
from contextlib import nullcontext
import torch
import numpy as np
import io

from alpha_integrate.train.model import GPTConfig, GPT
from alpha_integrate.train.config.train_integrate import block_size
from alpha_integrate.synthetic_data.tokenization import tokenize, WORD2ID, PAD, END, detokenize
from alpha_integrate.synthetic_data.expression_tokenizer import ExpressionTokenizer

# Expression tokenizer instance; used further down via `e.seq_to_sp(...)` to turn a
# detokenized question into an expression for display.
e = ExpressionTokenizer()

# -----------------------------------------------------------------------------
# Default settings. Every name below is a module-level global so that the
# configurator executed at the end of this section can override it.
init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
out_dir = 'alpha_integrate/train/out-integrate' # ignored if init_from is not 'resume'
start = "" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_samples = 1000 # number of samples to draw
temperature = 1.0 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 10 # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
compile = False # use PyTorch 2.0 to compile the model to be faster
# Overrides from command line or config file.
# NOTE(security): this exec()s the configurator source in this module's namespace;
# only run it with a configurator (and any config files it reads) that you trust.
# The file is opened in a `with` block so the handle is closed deterministically.
with open('alpha_integrate/train/configurator.py') as _configurator_src:
    exec(_configurator_src.read())
# -----------------------------------------------------------------------------

# Seed every RNG this script draws from so that a given `seed` reproduces a run.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed) # generate_beginning() picks its test lines with np.random
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
# Map the dtype name chosen in the config section to the actual torch dtype.
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# No autocast on CPU; on CUDA run generation under autocast with the selected dtype.
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# model
# Build the model either from a checkpoint in `out_dir` or from a pretrained GPT-2
# variant; any other `init_from` value is rejected up front instead of failing
# later with a NameError on `model`.
if init_from == 'resume':
    # init from a model saved in a specific directory
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # Strip the '_orig_mod.' prefix from parameter names when present (presumably
    # left by saving a torch.compile-wrapped model) so the keys match a plain GPT.
    unwanted_prefix = '_orig_mod.'
    # Iterate over a snapshot of the keys because the dict is mutated in the loop.
    for key in list(state_dict):
        if key.startswith(unwanted_prefix):
            state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
    model.load_state_dict(state_dict)
elif init_from.startswith('gpt2'):
    # init from a given GPT-2 model
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))
else:
    raise ValueError(
        f"Unsupported init_from={init_from!r}: expected 'resume' or a name starting with 'gpt2'"
    )

model.eval()
model.to(device)
if compile:
    model = torch.compile(model) # requires PyTorch 2.0 (optional)

# Test-set file; every retained line is later passed to tokenize() by
# generate_beginning().
PATH = 'alpha_integrate/synthetic_data/final_steps_dataset/prim_ibp/test.txt'
with io.open(PATH, mode='r', encoding='utf-8') as f:
    # Keep everything except bare-newline / empty lines.
    lines = [line for line in f if line != '\n' and line != '']

# Integer ids of the special tokens, plus the vocabulary size.
END_TOKEN = WORD2ID[END]
PAD_TOKEN = WORD2ID[PAD]
VOCAB_SIZE = len(WORD2ID)

def generate_beginning(block_size):
    """Pick a random test line that fits in the context and split it into prompt/answer.

    Lines are drawn uniformly at random from the module-level `lines` until one
    tokenizes to fewer than `block_size` tokens.

    Args:
        block_size: exclusive upper bound on the tokenized length of the chosen line.

    Returns:
        A pair `(prompt, answer)` where `prompt` is `tokenized[:start_id + 1]` and
        `answer` is `tokenized[start_id + 1:end_id]`, with `tokenized`, `start_id`
        and `end_id` as returned by `tokenize(line)`.

    Raises:
        ValueError: if `lines` is empty, or if every line has been tried and none
            tokenizes to fewer than `block_size` tokens (the original loop would
            spin forever in that case).
    """
    if not lines:
        raise ValueError("no test lines loaded; cannot draw a sample")

    # Indices already found too long: lets us skip re-tokenizing them and detect
    # the "nothing fits" case without changing the random draws on the normal path.
    rejected = set()
    while True:
        idx = np.random.randint(0, len(lines))
        if idx in rejected:
            continue
        tokenized, start_id, end_id = tokenize(lines[idx])
        if len(tokenized) < block_size:
            return tokenized[:start_id+1], tokenized[start_id+1:end_id]
        rejected.add(idx)
        if len(rejected) == len(lines):
            raise ValueError(
                f"no test line tokenizes to fewer than block_size={block_size} tokens"
            )

# get sample beginnings
# Draw all (prompt, answer) pairs up front, then split them into parallel lists.
drawn_pairs = [generate_beginning(block_size) for _ in range(num_samples)]
start_ids = [prompt for prompt, _ in drawn_pairs]
answers = [target for _, target in drawn_pairs]

# One LongTensor per prompt, each with a leading batch dimension of 1.
x = [torch.tensor(ids, dtype=torch.long, device=device)[None, ...] for ids in start_ids]
#x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])

# Decoding settings: BEAM_SIZE is passed to generate_beam in the evaluation loop;
# sum_p only appears in the disabled nucleus-sampling calls.
BEAM_SIZE = 10
sum_p = 0.95

# Disabled single-prompt debugging snippets, kept for reference.
'''
y = model.generate_beam(x[0][0], max_new_tokens, BEAM_SIZE, END_TOKEN, PAD_TOKEN, temperature=temperature, top_k=top_k)
print(y)
raise Exception
'''
#y = model.nucleus_sampling(x[0][0], max_new_tokens, END_TOKEN, sum_p = sum_p, temperature=temperature, top_k=top_k)
#raise Exception


# run generation
# For every prompt, run beam search and count the prompt as solved when any of the
# returned beams reproduces the reference answer token-for-token. A running
# accuracy is printed after each prompt and the final accuracy at the end.
correct = 0
with torch.no_grad(), ctx:
    for k in range(num_samples):
        # The prompt minus its last token is the "question"; dropping its first
        # token as well gives the expression that is detokenized for display.
        question = x[k][0].tolist()[:-1]
        expression = question[1:]
        detokenized_expression = detokenize(expression)
        #print(detokenized_expression)
        sympy_expression = e.seq_to_sp(detokenized_expression)
        print(f"Question: {sympy_expression}")
        answer = answers[k]
        #print(f"Answer: {detokenize(answer)}")
        # NOTE(review): the prompt holds len(question) + 1 tokens, so this budget
        # allows prompt + generation to reach block_size + 1 tokens; confirm that
        # generate_beam tolerates that.
        samples = model.generate_beam(x[k][0], block_size - len(question), BEAM_SIZE, END_TOKEN, PAD_TOKEN, temperature=temperature, top_k=top_k).cpu()
        #samples = model.nucleus_sampling(x[k][0], len(x[k][0]) + 1, END_TOKEN, BEAM_SIZE, sum_p = sum_p, temperature=temperature, top_k=top_k)
        #print(f"Question Length: {len(x[k][0])}")
        #print(f"Sample Shape: {samples.shape}")
        #print(question)
        #print(samples)
        # Disabled debugging aid:
        # print(detokenize(question))
        # print('---------------')
        # print(detokenize(samples.cpu().tolist()[0]))
        # print('---------------')
        # print(detokenize(answer))
        # raise Exception

        # Generated tokens start right after the full prompt (question + 1 token).
        answer_start = len(question) + 1
        for i, y in enumerate(samples):
            sample = y.tolist()[answer_start:]
            try:
                endtoken = sample.index(END_TOKEN)
            except ValueError:
                # This beam never emitted END_TOKEN, so it cannot be a finished answer.
                continue
            # Keep everything up to and including END_TOKEN.
            # NOTE(review): a match therefore requires `answer` to end with END_TOKEN
            # as well; confirm against tokenize()'s end_id convention.
            sample = sample[:endtoken+1]
            print(f"Question tokens: {detokenize(question)}")
            print(f"Sample {i}: {detokenize(sample)}")

            # Exact match: same length and identical tokens position by position.
            if len(sample) != len(answer):
                continue
            if all(got == want for got, want in zip(sample, answer)):
                correct += 1
                break

        print(f"Accuracy: {(correct/(k+1)):.3%}, Num Samples: {k+1}   ", end = '\r')

# Guard the final ratio so that num_samples == 0 reports 0% instead of raising
# ZeroDivisionError.
final_accuracy = correct / num_samples if num_samples else 0.0
print(f"Accuracy: {final_accuracy:.3%}, Num Samples: {num_samples}")