import json
from typing import Optional, Dict, Any
import torch
import os
import numpy as np
from ...utils.distributed_ops import reduce_any
import torch.nn.functional as F
from ... import data_structures, utils


class LambadaTest:
    """Accumulate LAMBADA metrics over batches and report them via ``plot``.

    Reported metrics:
      * ``accuracy/total``: fraction of examples whose greedily predicted final
        space-separated word equals the reference final word.
      * ``accuracy/openai_last_token``: fraction of examples whose final token
        is the argmax prediction at the preceding position.
      * ``perplexity``: exp of the final-word cross-entropy (mean over the
        word's tokens), averaged over examples.
      * ``lm_loss`` / ``lm_perplexity``: token-level loss over each whole
        passage, ignoring positions at or beyond the example's ``in_len``.

    Counters are kept per process and combined with ``reduce_any`` when read
    (assumed to sum across workers -- TODO confirm against distributed_ops).
    """
    SUPPORTS_DISTRIBUTED = True

    def __init__(self, data, vocabulary: data_structures.vocabulary.Vocabulary, batch_dim: int = 1):
        """
        Args:
            data: raw dataset records indexable by example index; each record's
                ``"text"`` field holds the untokenized passage.
            vocabulary: tokenizer used to convert between strings and indices.
            batch_dim: batch axis of the network output. Only 1 is supported,
                i.e. tensors are time-major (time axis 0).

        Raises:
            NotImplementedError: if ``batch_dim`` is not 1.
        """
        if batch_dim != 1:
            raise NotImplementedError("Batch dim must be 1")

        self.batch_dim = batch_dim
        self.time_dim = 1 - self.batch_dim
        self.vocabulary = vocabulary
        # Keep the raw records so the reference last word can be read straight
        # from the original text instead of detokenizing the input every step.
        self.data = data

        self.n_ok = 0             # examples whose final word was predicted exactly
        self.n_tok_ok = 0         # examples whose final token was predicted exactly
        self.n_total = 0          # examples seen
        self.loss_total = 0       # sum over examples of final-word mean cross-entropy
        self.lm_loss = 0          # summed token-level LM loss over whole passages
        self.n_total_tokens = 0   # number of tokens contributing to lm_loss

    def step(self, net_out_logits: torch.Tensor, data: Dict[str, torch.Tensor]):
        """Accumulate statistics for one batch.

        Args:
            net_out_logits: network logits, indexed here as ``[time, batch, vocab]``.
                Predictions are the per-position argmax of this single tensor
                (teacher-forced, not autoregressive decoding).
            data: batch dict with ``"data"`` (time-major token ids), ``"in_len"``
                (per-example input length) and ``"index"`` (dataset indices).
        """
        indices = data["index"].cpu().numpy().tolist()
        last_words = [self.data[int(idx)]["text"].split(" ")[-1] for idx in indices]
        net_out = net_out_logits.argmax(-1)

        for b in range(net_out.shape[self.batch_dim]):
            in_l = int(data["in_len"][b])
            word = last_words[b]
            # The word's character count is used as an upper bound on its token
            # count (assumes every token spans at least one character).
            wlen = len(word)

            # Logits at position t predict token t + 1, so this window holds the
            # predictions for the last `wlen` input tokens.
            window = net_out[in_l - 1 - wlen: in_l - 1, b].cpu().numpy().tolist()
            predicted_ids = [int(tok) for tok in window]
            last_predicted = self.vocabulary.to_string(predicted_ids).split(" ")[-1]
            self.n_ok += int(last_predicted == word)

            # Tokenize with a leading space so the word is encoded as it appears
            # mid-sentence.
            last_word = self.vocabulary.sentence_to_indices(" " + word)
            self.n_tok_ok += int(last_word[-1] == net_out[in_l - 2, b])

            out_end = net_out_logits[in_l - 1 - len(last_word):in_l - 1, b]
            loss = F.cross_entropy(out_end, torch.tensor(last_word, device=out_end.device, dtype=torch.long))
            self.loss_total += loss.cpu().item()

            self.n_total += 1

        # Token-level LM loss over whole passages: logits at step t are scored
        # against token t + 1; targets at or beyond in_len - 1 are ignored.
        # Cast to int64 *before* masking so -100 is representable even if token
        # ids are stored in an unsigned dtype.
        # NOTE(review): assumes net_out_logits has data["data"].size(0) - 1 time
        # steps, otherwise the flattened shapes below do not line up.
        target = data["data"][1:].contiguous().long()
        pad_mask = torch.arange(target.size(0), device=target.device)[:, None] >= (data["in_len"][None] - 1)
        target = target.masked_fill(pad_mask, -100)
        self.lm_loss += F.cross_entropy(net_out_logits.flatten(end_dim=-2), target.flatten(), ignore_index=-100, reduction="sum").item()
        self.n_total_tokens += (data["in_len"] - 1).sum().item()

    @property
    def accuracy(self):
        """Final-word accuracy, combined across workers."""
        return reduce_any(self.n_ok) / reduce_any(self.n_total)

    def plot(self) -> Dict[str, Any]:
        """Return all metrics, each combined across workers with ``reduce_any``."""
        n_total = reduce_any(self.n_total)
        lm_loss = reduce_any(self.lm_loss) / reduce_any(self.n_total_tokens)

        return {
            "accuracy/total": self.accuracy,
            "accuracy/openai_last_token": reduce_any(self.n_tok_ok) / n_total,
            # Reduced like every other metric; previously only the local
            # worker's loss/count were used here.
            "perplexity": np.exp(reduce_any(self.loss_total) / n_total),
            "lm_loss": lm_loss,
            "lm_perplexity": np.exp(lm_loss)
        }

class Lambada:
    """LAMBADA test set (OpenAI GPT-2 jsonl release), tokenized on access.

    The jsonl file is downloaded once into ``<cache_dir>/<class name>/``, with
    the download guarded by a lock file. Each record's ``"text"`` is tokenized
    in ``__getitem__``.
    """
    URL = "https://openaipublic.blob.core.windows.net/gpt-2/data/lambada_test.jsonl"

    def __init__(self, vocabulary: data_structures.vocabulary.Vocabulary,
                 cache_dir: str = "./cache", sep: Optional[str] = None) -> None:
        """
        Args:
            vocabulary: tokenizer used to convert passages to indices.
            cache_dir: root directory for the downloaded dataset.
            sep: optional string prepended to every passage before tokenizing.
        """
        self.vocabulary = vocabulary
        self.sep = sep

        # Storage dtype for token ids. The original code also listed np.uint8
        # for vocabularies of <= 256 entries, but a following `if` always
        # overwrote it, so it was never used. It is deliberately not restored:
        # switching to an unsigned dtype would change the stored ids' dtype and
        # break any consumer that writes negative sentinels such as -100 into
        # them.
        if len(self.vocabulary) < 32768:
            self.dtype = np.int16
        else:
            self.dtype = np.int32

        self.cache_dir = f"{cache_dir}/{self.__class__.__name__}/"
        os.makedirs(self.cache_dir, exist_ok=True)

        in_file = os.path.join(self.cache_dir, "lambada_test.jsonl")
        with utils.LockFile(os.path.join(self.cache_dir, "lock")):
            if not os.path.isfile(in_file):
                utils.download(self.URL, self.cache_dir, ignore_if_exists=True)

        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        # Blank lines (e.g. a trailing empty line) are skipped rather than
        # crashing json.loads.
        with open(in_file, "r", encoding="utf-8") as f:
            self.data = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{"data": token ids, "in_len": number of tokens, "index": idx}``."""
        d = self.data[idx]["text"]
        if self.sep is not None:
            d = self.sep + d
        d = self.vocabulary.sentence_to_indices(d)
        return {
            "data": np.array(d, dtype=self.dtype),
            "in_len": len(d),
            "index": idx
        }

    def start_test(self):
        """Create a fresh metric accumulator bound to this dataset's raw records."""
        return LambadaTest(self.data, vocabulary=self.vocabulary)


class Lambada1024(Lambada):
    """LAMBADA variant that left-pads each passage with ``filler_text``.

    Each example keeps only its last ``CONTEXT_LEN`` tokens, so every example
    fills (up to) a fixed-size context window ending in the target word.
    """

    # Maximum number of tokens kept per example (taken from the end of the text).
    CONTEXT_LEN = 1024

    def __getitem__(self, idx):
        """Return ``{"data": last CONTEXT_LEN token ids, "in_len": their count, "index": idx}``."""
        d = filler_text + self.data[idx]["text"]
        if self.sep is not None:
            # NOTE(review): sep is prepended before the left-truncation below,
            # so it is dropped whenever the text exceeds CONTEXT_LEN tokens --
            # confirm this is intended.
            d = self.sep + d
        tokens = np.array(self.vocabulary.sentence_to_indices(d), dtype=self.dtype)[-self.CONTEXT_LEN:]
        return {
            "data": tokens,
            # Real length of the kept window: equals CONTEXT_LEN whenever the
            # padded text is long enough, and never points past the data when
            # it is not.
            "in_len": len(tokens),
            "index": idx
        }




# Neutral padding prose prepended to each passage by Lambada1024.__getitem__,
# which then keeps only the last 1024 tokens; this makes every example fill the
# window. The exact text is part of the evaluation input -- do not edit it.
filler_text = """In the early hours of a misty morning, the town of Eldervale awakened slowly beneath a soft veil of fog. The ancient cobblestone streets, lined with quaint lampposts and flowering window boxes, echoed the gentle murmurs of residents beginning their daily rituals. Neighbors greeted one another with warm smiles, their voices carrying stories of days gone by. Every corner of the town whispered hints of history, from the venerable clock tower that marked the passage of time to the ivy-clad facades of old inns. In this quiet haven, nature and human endeavor coexisted harmoniously, blending modern aspirations with traditions rooted in centuries of collective memory. The air, fresh and invigorating, filled the heart with both anticipation and calm, as if the day itself promised an unfolding narrative of subtle adventures and familiar comforts. Every stone in the pavement seemed to carry a memory, and the soft chatter of early market sellers intermingled with the rustling leaves overhead. The horizon, a tapestry of pale blues and gentle oranges, hinted at the possibilities that awaited as the day advanced. Even the quiet murmur of the nearby river contributed to the serene symphony that defined the town’s character. In the heart of Eldervale, life moved with deliberate grace, as residents balanced the demands of modernity with the comfort of longstanding traditions. The market square bustled with activity, where vendors displayed an array of colorful produce and handcrafted wares. Children laughed and played near the fountain, their carefree spirits infusing the scene with vitality. Meanwhile, elders sat on benches beneath ancient oak trees, recounting tales of yesteryears with voices rich in wisdom and experience. The architecture of the town told a story of generations; weathered bricks and ornate carvings celebrated both the passage of time and the resilience of the community. 
In every conversation, there was a sense of belonging and shared destiny, as if the fabric of life was woven with threads of history, hope, and the simple joys of everyday existence. A gentle breeze carried the scent of freshly baked bread and blooming flowers, binding the senses to the charm of a life well-lived. Beyond the boundaries of the town, the countryside unfolded in a mosaic of vibrant landscapes and quiet solitude. Rolling meadows, dotted with wildflowers and ancient stone markers, stretched toward distant mountains that touched the sky. Rustic farms and winding country roads told stories of toil and perseverance, where nature’s bounty was both a gift and a challenge. In these open spaces, time appeared to slow, inviting travelers to pause and reflect on the simple elegance of existence. Birds soared on gentle breezes while the earth, rich with history and nourished by countless seasons, whispered secrets of renewal and growth. Each element of the rural scene, from the glistening dew on blades of grass to the sturdy silhouette of an old barn, contributed to a harmonious portrait of life that balanced the quiet rhythms of nature with the enduring spirit of human endeavor. Within the vibrant tapestry of urban and rural life, artistic expression blossomed as a tribute to both heritage and innovation. In cozy cafes and lively galleries, painters, poets, and musicians gathered to share their visions, infusing the air with creative energy and thoughtful dialogue. Murals on brick walls captured the essence of the town’s soul, blending abstract forms with recognizable landscapes that echoed the past. The sound of a violin, echoing softly through a narrow alley, interwove with the rhythmic pulse of footsteps and whispered conversations. Creativity was celebrated not as an isolated act but as a communal ritual, where every brushstroke and note resonated with a deep sense of purpose. 
In every artful creation, there was a fusion of tradition and modernity, a dance of colors and ideas that transcended the boundaries of time and place, inviting all who witnessed it to experience a moment of shared wonder. As the day progressed, the interplay of light and shadow transformed the landscape into a living canvas. In sun-dappled courtyards and along meandering paths, individuals found solace in quiet contemplation and spontaneous encounters. The rustle of leaves, the distant hum of conversation, and the soft clinking of porcelain in a busy tea room all converged to create an atmosphere of serene introspection. Amid these moments, the boundaries between the ordinary and the extraordinary blurred, allowing a sense of magic to infuse even the most mundane experiences. Artists captured these ephemeral instants with words and images, striving to preserve the fleeting beauty of a transient world. Each moment was a delicate balance of warmth and melancholy, a reflection of the intricate dance between human emotion and the ever-changing rhythms of nature. As twilight descended upon the horizon, the town and countryside alike embraced a quiet transformation. The fading light bathed buildings and fields in a gentle glow, evoking feelings of nostalgia and contemplation. In the cool of the evening, families gathered in communal spaces, sharing stories and dreams beneath a tapestry of stars. The atmosphere was imbued with a reflective quality, as if the world itself paused to honor the ephemeral beauty of each passing moment. In the soft murmur of nighttime conversations and the rhythmic chirping of crickets, one could sense an underlying promise of renewal. Every whispered word, every shared glance, and every silent smile carried the weight of memories and the hope of tomorrow. In this serene interlude between day and night, the essence of life was celebrated in its quiet, profound simplicity. 
In the final moments before the deep embrace of night, the world seemed to hold its breath in quiet anticipation. The enduring pulse of life, both in bustling streets and secluded fields, resonated with the promise of a new day. Each moment was a silent testament to the beauty of existence, a gentle reminder that every ending carried the seed of a fresh beginning. As stars emerged one by one, their light mingling with the lingering warmth of dusk, the eternal cycle of hope and renewal was unmistakably affirmed. Embracing the tranquil energy of the twilight, the soul found solace in the profound unity of nature and human aspiration, gently bridging the gap between yesterday and tomorrow."""