import torch
import argparse

from load_model import load_model
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
import torch.nn.functional as F
import sampling
import os
from ipdb import set_trace as debug

def _batch_perplexity(eval_model, samples):
    """Return the mean per-sequence perplexity of ``samples`` under ``eval_model``.

    Args:
        eval_model: Causal language model whose forward output exposes
            ``.logits``.
        samples: Integer token ids; the last dimension is treated as the
            sequence dimension (assumed ``(batch, seq_len)`` - TODO confirm
            against the sampler).

    Returns:
        A scalar tensor: per-token negative log-likelihoods are averaged over
        each sequence, exponentiated, then averaged over the batch.
    """
    # F.cross_entropy expects the class (vocabulary) dimension right after the
    # batch dimension, so swap the last two axes of the logits.
    logits = eval_model(samples).logits.transpose(-1, -2)
    # Position t predicts token t + 1: drop the last prediction and the first
    # target so that predictions and targets line up.
    token_nll = F.cross_entropy(logits[..., :-1], samples[..., 1:], reduction="none")
    return token_nll.mean(dim=-1).exp().mean()


def _batch_entropy(samples):
    """Return the mean token-frequency entropy (in nats) over ``samples``.

    For every sequence the empirical distribution of its token ids is built
    from ``torch.unique`` counts and its entropy is summed with
    ``torch.special.entr``; the result is averaged over the sequences.
    """
    entropy_sum = 0.0
    for sequence in samples:
        _, counts = torch.unique(sequence, return_counts=True, sorted=False)
        entropy_sum += torch.special.entr(counts.float() / counts.sum()).sum().item()
    return entropy_sum / samples.shape[0]


@torch.no_grad()
def main():
    """Generate samples, score them with GPT-2 Large and write the results.

    Three files are written into ``--name_dir`` (created if missing): the
    decoded samples, their generative perplexity under ``gpt2-large`` and
    their token-frequency entropy. Both metrics are averaged over all
    sampling batches.
    """
    parser = argparse.ArgumentParser(description="Generate some samples")
    parser.add_argument("--model_path", default="louaaron/sedd-medium", type=str)
    # Required: the output directory is used unconditionally below, so fail at
    # argument parsing instead of after the models have been loaded.
    parser.add_argument("--name_dir", type=str, required=True)
    parser.add_argument("--specific_checkpoint", default=None, type=str)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--sampling_batch_size", type=int, default=1)
    parser.add_argument("--steps", type=int, default=1024)
    args = parser.parse_args()

    # Validate before any expensive work; both cases would otherwise end in a
    # ZeroDivisionError.
    if args.sampling_batch_size < 1:
        parser.error("--sampling_batch_size must be a positive integer")
    # Floor division: a trailing partial batch is not sampled.
    batches = args.batch_size // args.sampling_batch_size
    if batches < 1:
        parser.error("--batch_size must be at least --sampling_batch_size")

    device = torch.device('cuda')
    model, graph, noise = load_model(args.model_path, device, args.specific_checkpoint)
    model.eval()
    noise.eval()
    tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
    # Loaded once: the evaluation model does not change between batches.
    eval_model = GPT2LMHeadModel.from_pretrained("gpt2-large").to(device).eval()

    os.makedirs(args.name_dir, exist_ok=True)
    file_name_perplexity = os.path.join(args.name_dir, f"perplexity_samples_inference_steps{args.steps}.txt")
    file_name_entropy = os.path.join(args.name_dir, f"entropy_samples_inference_steps{args.steps}.txt")
    file_name = os.path.join(args.name_dir, f"samples_inference_steps{args.steps}.txt")

    # The second shape entry (1024) is presumably the sequence length -
    # TODO confirm against sampling.get_pc_sampler.
    sampling_fn = sampling.get_pc_sampler(
        graph, noise, (args.sampling_batch_size, 1024), 'analytic', args.steps, device=device
    )

    # NOTE: both accumulators end up as tensors, so the metric files contain
    # the tensor repr (e.g. "tensor(..., device='cuda:0')"). This is kept
    # as-is so that the output format does not change.
    total_perplexity = 0
    total_entropy = torch.tensor(0.0, device=device)
    text_samples = []

    for _ in range(batches):
        samples = sampling_fn(model)
        text_samples.extend(tokenizer.batch_decode(samples))

        total_perplexity += _batch_perplexity(eval_model, samples)
        # Accumulate the per-batch mean; normalising the running total inside
        # the loop would shrink the contribution of earlier batches.
        total_entropy += _batch_entropy(samples)

    total_perplexity /= batches
    total_entropy /= batches

    # Decoded text can contain arbitrary Unicode, so do not depend on the
    # locale's default encoding.
    with open(file_name_perplexity, 'w', encoding="utf-8") as file:
        file.write(f"Generative Perplexity: {total_perplexity}.")

    with open(file_name_entropy, 'w', encoding="utf-8") as file:
        file.write(f"Entropy: {total_entropy}.")

    with open(file_name, 'w', encoding="utf-8") as file:
        for sentence in text_samples:
            file.write(sentence + "\n")
            file.write("============================================================================================\n")

       

# Run the sampling and evaluation pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()