import torch
import argparse
import os
from load_model import load_model
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
from transformers import AutoModelForCausalLM, AutoTokenizer

import torch.nn.functional as F
import sampling
import numpy as np
import time
import math


def create_directory_if_not_exists(directory_path):
    """Create ``directory_path`` (and any missing parents) if it is absent.

    Prints whether the directory was created or already existed.

    Uses try/except (EAFP) instead of an ``os.path.exists`` pre-check. With a
    pre-check, a directory created by another process between the check and
    ``os.makedirs`` would raise ``FileExistsError``. Here it is simply
    reported as already existing.

    Args:
        directory_path: Path of the directory to create.
    """
    try:
        os.makedirs(directory_path)
    except FileExistsError:
        print(f"Directory '{directory_path}' already exists.")
    else:
        print(f"Directory '{directory_path}' created.")

def main():
    parser = argparse.ArgumentParser(description="Generate some samples")
    parser.add_argument("--model_path", default="x", type=str)
    parser.add_argument("--dataset", default="wikitext103", type=str)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--length", type=int, default=128)
    parser.add_argument("--no_batches", type=int, default=1000)
    parser.add_argument("--steps", type=int, default=8)
    parser.add_argument("--mode", type=str, default='gen')
    args = parser.parse_args()

    
    device = torch.device('cuda')
    model, graph, noise = load_model(args.model_path, device)
    sampling_fn = sampling.get_pc_sampler(
        graph, noise, (args.batch_size, args.length), 'euler', args.steps, device=device
    )

    #UNCOMMENT TO TEST USING LLAMA (Hugging face account required). ALSO NEED TO SET args.mode to llama. 

    # model_id = "meta-llama/Llama-3.1-8B"
    # tokenizer_llama = AutoTokenizer.from_pretrained(model_id)
    # tokenizer_llama.pad_token = tokenizer_llama.eos_token

    tokenizer_gpt = GPT2TokenizerFast.from_pretrained('gpt2')
    tokenizer_gpt.pad_token = tokenizer_gpt.eos_token


    with torch.no_grad():
        create_directory_if_not_exists(args.model_path+'_eval/')
        
        if args.mode=='gen':
            all_samples = []
            for i in range(args.no_batches):
                print(i)
                st=time.time()
                s, _ = sampling_fn(model)
                cond = ((s<50257)*1).prod()
                if cond:
                    print('Time to generate a batch:', time.time()-st)
                    s = s.detach().cpu()
                    all_samples.append(s.to(device))
                    file_name = '_eval/' +'totalpoints_'+str(args.no_batches*args.batch_size)+'_steps_'+str(args.steps)+'_length_'+str(args.length)+ '_all_samples.pth'
                    torch.save(torch.stack(all_samples).reshape(-1, args.length), args.model_path+file_name)

        elif args.mode == 'ent':

            test_batch_size = 16
            file_name = (
                '_eval/' + 'totalpoints_' + str(args.no_batches * args.batch_size)
                + '_steps_' + str(args.steps)
                + '_length_' + str(args.length)
                + '_all_samples.pth'
            )
            all_samples = torch.load(args.model_path + file_name)      # [N, length]
            all_samples = torch.chunk(all_samples, all_samples.shape[0] // test_batch_size)

            total_entropy = []          # one H per generated sequence
            i = 0                       # how many mini-batches processed

            for s in all_samples:       # s: [B, length]   where B = test_batch_size
                ids_np = s.cpu().numpy()          # work on CPU in NumPy for speed
                batch_H = []

                # ---------- per-sequence entropy ----------
                #   H = −Σ_k p_k log p_k   (natural log → nats)
                for ids in ids_np:                   # ids: shape (length,)
                    L = ids.shape[0]
                    if L == 0:                       # (shouldn’t happen here)
                        batch_H.append(0.0)
                        continue

                    # Fast histogram with NumPy bincount (counts includes zeros)
                    counts = np.bincount(ids, minlength=ids.max() + 1)
                    probs  = counts[counts > 0] / L   # remove zeros before log
                    H = -(probs * np.log(probs)).sum()   # units = nats
                    batch_H.append(H / np.log(2))        # convert to bits if you prefer

                # ---------- bookkeeping ----------
                i += 1
                total_entropy += batch_H

                #print(i)
                total_np = np.array(total_entropy, dtype=np.float64)
                mean_H   = total_np.mean()
                std_err  = total_np.std() / math.sqrt(i * test_batch_size)

            print('Entropy mean (bits):', mean_H)
            print('Entropy std-error  :', std_err)

            # ---------- save ----------
            np.save(
                args.model_path + '_eval/'
                + 'Entropy_'
                + 'totalpoints_' + str(args.no_batches * args.batch_size)
                + '_steps_'       + str(args.steps)
                + '_length_'      + str(args.length)
                + '.npy',
                np.array(total_entropy, dtype=np.float32)
            )



        elif args.mode=='gpt':
            test_batch_size = 32
            file_name = '_eval/' +'totalpoints_'+str(args.no_batches*args.batch_size)+'_steps_'+str(args.steps)+'_length_'+str(args.length)+ '_all_samples.pth'
            all_samples = torch.load(args.model_path+file_name)
            all_samples = torch.chunk(all_samples, all_samples.shape[0]//test_batch_size)
            total_perplexity = []
            eval_gpt = GPT2LMHeadModel.from_pretrained("gpt2-large").to(device).eval()
            i=0
            for s in all_samples:
                try:
                #GPT

                    _, logits = eval_gpt(s, labels=s)[:2]
                except:
                    print('fail')
                    continue
                    
                i+=1
                print(i)
                perplexity = F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)), s[..., 1:].reshape(-1), reduction='none')
                perplexity = perplexity.reshape(-1, args.length-1).mean(-1).exp().to(torch.float64).cpu().numpy().tolist()
                total_perplexity += perplexity 
                print('GPT2 perplexity mean:', np.array(total_perplexity).mean())
                print('GPT2 perplexity std:', np.array(total_perplexity).std()/torch.sqrt(torch.tensor(i*test_batch_size)))
                np.save(args.model_path+'_eval/'+'GenPerplexity_'+'totalpoints_'+str(args.no_batches*args.batch_size)+'_steps_'+str(args.steps)+'_length_'+str(args.length)+'.npy', np.array(total_perplexity))


        # elif args.mode=='llama':
        #     test_batch_size = 16
        #     file_name = '_eval/' +'totalpoints_'+str(args.no_batches*args.batch_size)+'_steps_'+str(args.steps)+'_length_'+str(args.length)+ '_all_samples.pth'
        #     all_samples = torch.load(args.model_path+file_name)
        #     all_samples = torch.chunk(all_samples, all_samples.shape[0]//test_batch_size)
        #     total_llama_perplexity = []
        #     eval_llama = AutoModelForCausalLM.from_pretrained(
        #         model_id,
        #         torch_dtype=torch.bfloat16).to(device).eval()
        #     with torch.inference_mode():
        #         i=0
        #         for s in all_samples:
        #             #LLAMA
        #             i+=1
        #             print(i)
        #             text_samples = tokenizer_gpt.batch_decode(s)

        #             encoded_inputs = tokenizer_llama(
        #                 text_samples,
        #                 return_tensors="pt",  
        #                 padding=True,         
        #                 #truncation=True       
        #             )
        #             # Extract input_ids and attention_mask
        #             input_ids = encoded_inputs['input_ids'].to(device)
        #             attention_mask = encoded_inputs['attention_mask'].to(device)


        #             # Shift input_ids by one to create labels and set padding tokens to -100
        #             labels = input_ids.clone()
        #             labels[labels == tokenizer_llama.pad_token_id] = 50000 
        #             labels = labels.to(device)

        #             # Get model outputs
        #             outputs = eval_llama(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                    
        #             llama_logits = outputs.logits

        #             # Calculate the loss for each token in the batch without reduction
        #             llama_perplexity = F.cross_entropy(llama_logits[:, :-1].reshape(-1, llama_logits.size(-1)), labels[..., 1:].reshape(-1), reduction='none')
        #             llama_perplexity = llama_perplexity.reshape(-1, input_ids.shape[1]-1)
        #             llama_perplexity = llama_perplexity*attention_mask[..., 1:]
        #             llama_perplexity = llama_perplexity.sum(-1)/attention_mask[..., 1:].sum(-1)
                    
        #             llama_perplexity=llama_perplexity.exp().to(torch.float64).cpu().numpy().tolist()

        #             total_llama_perplexity += llama_perplexity 
        #             np.save(args.model_path+'_eval/'+'LlamaGenPerplexity_'+'totalpoints_'+str(args.no_batches*args.batch_size)+'_steps_'+str(args.steps)+'_length_'+str(args.length)+'.npy', np.array(total_llama_perplexity))
        #             del llama_logits, outputs, input_ids, attention_mask, labels
        #             torch.cuda.empty_cache()
        #             print('Llama perplexity mean:', np.array(total_llama_perplexity).mean())
        #             print('Llama perplexity std:', np.array(total_llama_perplexity).std()/torch.sqrt(torch.tensor(i*test_batch_size)))


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()