import torch
from config import config
from llmtelora import Llmtelora, build_nanogpt_config
from data import Data
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import time
import argparse

# Fixed story text used as conditioning context for the demos:
# LlmteloraDemo.show_demo encodes a 2000-character prefix of it
# (tiny_story[:2000]) followed by a space and the prompt.
# Note the literal starts with a newline and 4 spaces of indentation,
# which become part of the encoded context.
tiny_story = """
    Once upon a time, in a quaint village nestled between rolling hills and dense forests, there lived a curious young girl named Lily. She had always been fascinated by the old stories her grandmother told her about magical creatures that supposedly inhabited the nearby woods. Despite warnings from the villagers to stay away, Lily couldn't resist the allure of adventure.
One crisp autumn morning, as golden leaves danced in the breeze, Lily decided to explore the forest. She packed a small bag with some bread, cheese, and her favorite book of fairy tales. With a mix of excitement and trepidation, she set off down the winding path that led into the heart of the woods.
As she ventured deeper into the forest, the trees grew taller and closer together, their branches forming a canopy that filtered the sunlight into dappled patterns on the forest floor. Lily marveled at the beauty around her – the vibrant mushrooms, the delicate wildflowers, and the moss-covered stones that seemed to whisper ancient secrets.
Suddenly, she heard a faint melody carried on the wind. Intrigued, Lily followed the sound, her heart racing with anticipation. As she pushed through a thick curtain of vines, she stumbled into a hidden clearing bathed in ethereal light. There, to her amazement, she saw a circle of tiny, glowing beings dancing in the air.
They were fairies, just like in her grandmother's stories! Their wings shimmered with iridescent colors, and they left trails of sparkling dust as they flitted about. Lily gasped in wonder, alerting the fairies to her presence. At first, they scattered in alarm, but then one brave fairy approached her.
"Hello, human child," the fairy said in a voice that sounded like tinkling bells. "What brings you to our sacred grove?"
Lily, overcoming her initial shock, replied, "I'm Lily, and I've always dreamed of meeting magical creatures like you. I mean no harm; I only wish to learn and maybe become friends."
The fairy, intrigued by Lily's sincerity, introduced herself as Thistle. Soon, the other fairies gathered around, their curiosity overcoming their fear. They began to share stories of their magical world with Lily, telling her about the different creatures that inhabited the forest and the ancient magic that flowed through the land.
As the day wore on, Lily realized she needed to return home before her parents worried. The fairies, having grown fond of their new human friend, gifted her with a small, enchanted acorn. "Keep this with you," Thistle explained, "and you'll always be able to find your way back to us."
Lily thanked her new friends and promised to visit again soon. As she made her way home, her mind whirled with all she had learned. She couldn't wait to tell her grandmother about her adventure, knowing that the old woman would believe her when no one else would.
From that day forward, Lily became a bridge between the human world and the magical realm of the forest. She visited her fairy friends often, learning their ways and helping to protect their home from those who might harm it. As she grew older, Lily became known in her village as a wise woman, respected for her knowledge of herbs and her uncanny ability to solve problems.
Years passed, and Lily, now a grandmother herself, would sit by the fireplace and tell her own grandchildren about the wonders of the magical forest. She would show them the enchanted acorn, still glowing with fairy magic after all this time, and encourage them to keep their hearts open to the magic that exists in the world around them.
And so, the cycle continued, with each generation discovering anew the magic and wonder that lies just beyond the edge of the ordinary world, waiting for those curious and brave enough to seek it out.
"""

class LlmteloraDemo(Llmtelora):
    """Demo runner that collects speculative-decoding statistics for prompts.

    Expects ``self.data`` (with an ``encode`` method) and ``self.model``
    (with ``generate_w_speculative``) to be set before use; the script's
    ``__main__`` block calls ``set_data()`` and ``set_model()`` first.
    """

    def run_demos(self, prompts, output_length=200, temperature=1.0):
        """Run ``show_demo`` on every prompt and pool the returned statistics.

        Args:
            prompts: Iterable of prompt strings.
            output_length: Number of tokens to generate per prompt.
            temperature: Sampling temperature forwarded to generation.

        Returns:
            A NumPy array made by concatenating the per-prompt statistics
            returned by ``show_demo``; an empty array when ``prompts`` is empty.
        """
        per_prompt_stats = []
        start_time = time.perf_counter()
        # Wrap ``prompts`` itself (not an enumerate object) so tqdm can read
        # its length and show a progress bar with a proper total.
        for i, prompt in enumerate(tqdm(prompts, desc='Running demos'), 1):
            stats = self.show_demo(prompt, output_length, i, temperature)
            # atleast_1d guards against a scalar result, which
            # np.concatenate would reject.
            per_prompt_stats.append(np.atleast_1d(stats))
        runtime = time.perf_counter() - start_time
        print(f"Total runtime for all prompts: {runtime:.2f} seconds")

        # Concatenate once at the end instead of re-copying the running
        # array on every iteration (quadratic in the total size).
        if not per_prompt_stats:
            return np.array([])
        return np.concatenate(per_prompt_stats)

    def show_demo(self, prompt, output_length=200, demo_number=1, temperature=1.0,
                  context_chars=2000):
        """Generate a continuation of ``prompt`` with speculative decoding.

        Args:
            prompt: Prompt text.
            output_length: Number of tokens to generate.
            demo_number: 1-based index of this demo; accepted for interface
                compatibility but currently unused.
            temperature: Sampling temperature.
            context_chars: Number of leading characters of the module-level
                ``tiny_story`` to prepend (followed by a space) to the prompt
                as conditioning context. A falsy value (e.g. ``0``) encodes
                the prompt alone.

        Returns:
            The statistics object returned by
            ``generate_w_speculative(..., return_stats=True)`` (presumably
            per-step matched-token counts — confirm against the model); the
            generated output itself is discarded.
        """
        # NOTE(review): the encoded context + prompt + output_length must fit
        # in the model's block_size (main sets 1024); whether a 2000-char
        # prefix fits depends on the tokenizer in ``self.data`` — confirm.
        if context_chars:
            text = tiny_story[:context_chars] + ' ' + prompt
        else:
            text = prompt
        text_inp = self.data.encode(text)
        _output, matched_tokens_list = self.model.generate_w_speculative(
            text_inp,
            output_length,
            return_stats=True,
            temperature=temperature,
        )
        return matched_tokens_list
    

def plot_histogram(data, title, xlabel, ylabel, filename, bins=20, density=True):
    """Render a histogram of ``data`` and save it to ``filename``.

    Args:
        data: Values to histogram (anything ``Axes.hist`` accepts).
        title: Plot title.
        xlabel: X-axis label.
        ylabel: Y-axis label.
        filename: Output path passed to ``Figure.savefig``.
        bins: Bin count (or bin edges) forwarded to ``hist``.
        density: If True, normalize so the histogram integrates to 1, making
            the y-axis a probability density rather than a raw count.
    """
    fig, ax = plt.subplots()
    try:
        ax.hist(data, bins=bins, edgecolor='black', density=density)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        fig.savefig(filename)
    finally:
        # Release the figure even if plotting or saving fails, so repeated
        # calls don't accumulate open figures.
        plt.close(fig)

if __name__ == '__main__':
    args = config()

    # Override the configured block_size for this demo (presumably the
    # model's context length — confirm against build_nanogpt_config).
    args.block_size = 1024
    demo = LlmteloraDemo(args)
    demo.set_data()

    # List of prompts to test
    prompts = [
        "Once upon a time there was a",
        "One day, Jane and her",
        "It was a sunny day and Jimmy",
        "One summer morning, a little boy",
        "John and Sarah were playing",
        "One day, he decided to climb",
        "The students were excited to learn",
        "The rain continued for several days",
        "The children were playing in the park",
        "The old man was walking down the street"
    ]

    if args.from_pretrained:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        params = torch.load(args.from_pretrained, map_location=torch.device(args.device))
        print('Params are loaded from pretrained')
    else:
        params = None

    demo.set_model(params)

    matched_tokens_stats = demo.run_demos(prompts, temperature=args.temp)
    # Empirical distribution: fraction of occurrences of each distinct value.
    uniq, counts = np.unique(matched_tokens_stats, return_counts=True)
    fractions = counts / np.sum(counts)

    # plot_histogram normalizes with density=True, so the y-axis is a
    # probability density, not a frequency count.
    plot_histogram(
        matched_tokens_stats,
        f'Histogram of matched tokens (temp={args.temp})',
        'Matched Tokens',
        'Density',
        os.path.join(demo.fold, f'matched_tokens_temp_{args.temp}.png')
    )

    print(args)
    # Print each distinct value with its fraction so the numbers are
    # interpretable (the bare fraction array alone loses the values).
    for value, fraction in zip(uniq, fractions):
        print(f'{value}: {fraction:.4f}')
