import pytorch_lightning as pl
from pathlib import Path
import pickle
import numpy as np
from typing import Literal
import requests

import torch
from torch.utils.data import Dataset, TensorDataset
from torch.utils.data import DataLoader

from symo.experiments.models import MLP

# Type alias for one data split: an (inputs, outputs) pair of tensors,
# as returned for the train and test splits by `mlp_teacher_data`.
Data = tuple[torch.Tensor, torch.Tensor]


def input_moments(
    generator: torch.Generator,
    dim: int,
    min_log_x: float = 0,
    max_log_x: float = -5,
    device: torch.device | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Generate mean and covariance matrix for input distribution."""

    diag = torch.logspace(min_log_x, max_log_x, dim, device=device)
    mat = torch.randn(dim, dim, generator=generator, device=device)
    q, _ = torch.linalg.qr(mat)
    sigma = (q * diag) @ q.T
    mean = torch.zeros(dim, device=device)

    return mean, sigma


def make_inputs(
    generator: torch.Generator,
    mean: torch.Tensor,
    sigma: torch.Tensor,
    num_points: int,
) -> torch.Tensor:
    """Draw ``num_points`` samples from the multivariate normal ``N(mean, sigma)``.

    ``MultivariateNormal.sample`` does not accept a generator, so the samples
    are drawn from the global RNG inside a forked RNG scope that is seeded
    from ``generator``. The seed is *drawn from* the generator (advancing its
    state) rather than read from ``generator.initial_seed()``; the initial
    seed never changes, so seeding with it would make every call with the same
    generator return the same sample stream.

    Args:
        generator: Source of randomness. If ``None``, samples are drawn
            directly from the global RNG, which is advanced as usual.
        mean: Mean vector of the distribution.
        sigma: Covariance matrix of the distribution.
        num_points: Number of samples to draw.

    Returns:
        The tensor returned by ``dist.sample((num_points,))``.
    """
    dist = torch.distributions.MultivariateNormal(mean, sigma, validate_args=False)
    sample_shape = (num_points,)

    if generator is None:
        return dist.sample(sample_shape)

    # Consume the generator so that successive calls get distinct, yet
    # reproducible, seeds.
    seed = int(
        torch.randint(
            0, 2**62, (1,), generator=generator, device=generator.device
        ).item()
    )

    # Fork so that seeding does not leak into the RNG state of the sampling
    # device once the block exits.
    with torch.random.fork_rng(devices=[mean.device] if mean.is_cuda else []):
        torch.manual_seed(seed)
        return dist.sample(sample_shape)


def mlp_teacher_data(
    generator: torch.Generator,
    mlp: MLP,
    num_train_points: int,
    num_test_points: int,
    device: torch.device | None = None,
) -> tuple[torch.Generator, tuple[Data, Data]]:
    """Build a train split and a test split labelled by a teacher MLP.

    Inputs are sampled via ``make_inputs`` from the distribution whose moments
    come from ``input_moments(generator, mlp.input_dim, device=device)``; the
    outputs are the teacher's predictions, computed without gradient tracking.

    Args:
        generator: Random generator passed to the moment and input samplers.
        mlp: Teacher network; its ``input_dim`` sets the input dimensionality.
        num_train_points: Number of training inputs to sample.
        num_test_points: Number of test inputs to sample.
        device: Device on which the input moments are created.

    Returns:
        ``(generator, (train, test))`` where each split is an
        ``(inputs, outputs)`` pair.
    """
    mean, sigma = input_moments(generator, mlp.input_dim, device=device)

    def labelled_split(count: int) -> tuple[torch.Tensor, torch.Tensor]:
        # Sample inputs, then label them with the teacher under no_grad.
        inputs = make_inputs(generator, mean, sigma, count)
        with torch.no_grad():
            outputs = mlp(inputs)
        return inputs, outputs

    train_split = labelled_split(num_train_points)
    test_split = labelled_split(num_test_points)

    return generator, (train_split, test_split)


class ShakespeareDataset(Dataset):
    def __init__(
        self,
        data_dir: str = "./.datasets",
        split: Literal["train", "test"] = "train",
        block_size: int = 256,
    ):
        self.data_dir = data_dir
        self.split = split
        self.block_size = block_size

        Path(data_dir).mkdir(exist_ok=True)

        self.input_file = Path(data_dir, "input.txt")
        self.meta_file = Path(data_dir, "meta.pkl")

        self._prepare_data()

        self.data = self.train_ids if split == "train" else self.val_ids

    def _prepare_data(self):
        """Download and prepare the Shakespeare dataset."""
        if not Path(self.input_file).exists():
            print("[Start] Downloading tiny Shakespeare dataset...")
            data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
            with open(self.input_file, "w", encoding="utf-8") as f:
                f.write(requests.get(data_url).text)
            print("[Done] Downloading tiny Shakespeare dataset.")

        with open(self.input_file, "r", encoding="utf-8") as f:
            data = f.read()

        chars = sorted(list(set(data)))
        self.vocab_size = len(chars)

        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}

        n = len(data)
        train_data = data[: int(n * 0.9)]
        val_data = data[int(n * 0.9) :]

        self.train_ids = np.array(self.encode(train_data), dtype=np.int64)
        self.val_ids = np.array(self.encode(val_data), dtype=np.int64)

        meta = {
            "vocab_size": self.vocab_size,
            "itos": self.itos,
            "stoi": self.stoi,
        }

        with open(self.meta_file, "wb") as f:
            pickle.dump(meta, f)

    def encode(self, s: str) -> list[int]:
        """Encode a string to list of integers."""
        return [self.stoi[c] for c in s]

    def decode(self, l: list[int] | torch.Tensor) -> str:
        """Decode a list of integers to string."""
        if isinstance(l, torch.Tensor):
            l = l.tolist()
        return "".join([self.itos[i] for i in l])

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.data) - self.block_size

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        s = idx
        e = idx + self.block_size
        x = torch.from_numpy(self.data[s:e].astype(np.int64))
        y = torch.from_numpy(self.data[s + 1 : e + 1].astype(np.int64))

        return x, y


class ShakespeareDataModule(pl.LightningDataModule):
    """Lightning data module for the character-level tiny Shakespeare dataset.

    Each split is materialised once as a ``TensorDataset`` of input/target
    windows. When ``preload_to_gpu`` is set and CUDA is available the tensors
    live on the GPU, in which case the dataloaders use no worker processes and
    no pinned memory.

    Args:
        data_dir: Directory holding the corpus and ``meta.pkl``; ``~`` is expanded.
        block_size: Length of each input/target window.
        batch_size: Batch size of every dataloader.
        num_workers: Worker processes used when the data stays on the CPU.
        pin_memory: Whether to pin memory when the data stays on the CPU.
        preload_to_gpu: Move the window tensors to CUDA when it is available.
        seed: Seed of the generator that shuffles the training dataloader.
    """

    def __init__(
        self,
        data_dir: str = "./.datasets",
        block_size: int = 256,
        batch_size: int = 64,
        num_workers: int = 4,
        pin_memory: bool = True,
        preload_to_gpu: bool = True,
        seed: int = 0,
    ):
        super().__init__()
        # Resolve the directory once and use the SAME expanded path for the
        # mkdir, the datasets and meta.pkl; `data_dir` stays as given.
        self._data_path = Path(data_dir).expanduser()
        self._data_path.mkdir(parents=True, exist_ok=True)

        self.data_dir = data_dir
        self.block_size = block_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.preload_to_gpu = preload_to_gpu

        self.train_dataset: TensorDataset | None = None
        self.val_dataset: TensorDataset | None = None
        self.test_dataset: TensorDataset | None = None

        # Vocabulary, loaded lazily from meta.pkl.
        self._vocab_size = None
        self._stoi = None
        self._itos = None

        # Set in `setup`.
        self.device = None
        self.seed = seed

    def _build_split(self, split: str) -> "ShakespeareDataset":
        """Instantiate the underlying dataset for ``split``."""
        return ShakespeareDataset(
            data_dir=str(self._data_path),
            split=split,
            block_size=self.block_size,
        )

    def _load_meta(self) -> None:
        """Load the vocabulary from ``meta.pkl`` into the private caches."""
        # NOTE(security): pickle.load is acceptable only because meta.pkl is
        # written locally by ShakespeareDataset; never point `data_dir` at
        # untrusted content.
        with open(self._data_path / "meta.pkl", "rb") as f:
            meta = pickle.load(f)
        self._vocab_size = meta["vocab_size"]
        self._stoi = meta["stoi"]
        self._itos = meta["itos"]

    def prepare_data(self):
        """Instantiate the dataset once so the corpus and ``meta.pkl`` exist."""
        self._build_split("train")

    def _create_tensor_dataset(self, dataset: "ShakespeareDataset") -> TensorDataset:
        """Materialise every window of ``dataset`` as two ``(N, block_size)`` tensors.

        The windows are built with ``Tensor.unfold`` instead of a per-sample
        Python loop. On CUDA the token stream is moved to the device first, so
        the full window tensors are allocated on the device only.
        """
        assert self.device is not None

        tokens = torch.as_tensor(dataset.data, dtype=torch.long)
        if self.device.type == "cuda":
            tokens = tokens.to(self.device)

        num_samples = tokens.numel() - self.block_size
        if num_samples <= 0:
            # Split too short for a single window.
            x_data = tokens.new_zeros((0, self.block_size))
            y_data = tokens.new_zeros((0, self.block_size))
            return TensorDataset(x_data, y_data)

        # Overlapping windows of block_size + 1 ids; row i starts at position i.
        windows = tokens.unfold(0, self.block_size + 1, 1)
        # Inputs drop the last id, targets drop the first. Clone so the result
        # owns its memory instead of viewing the token stream.
        x_data = windows[:, :-1].clone(memory_format=torch.contiguous_format)
        y_data = windows[:, 1:].clone(memory_format=torch.contiguous_format)

        return TensorDataset(x_data, y_data)

    def setup(self, stage: str | None = None):
        """Build the tensor datasets required by ``stage``.

        ``"fit"`` builds the train and validation sets, ``"validate"`` the
        validation set, ``"test"`` the test set, and ``None`` all of them.
        Validation and test both use the dataset's ``"test"`` split.
        """
        preload_to_gpu = self.preload_to_gpu and torch.cuda.is_available()
        self.device = torch.device("cuda" if preload_to_gpu else "cpu")

        if self._vocab_size is None:
            self._load_meta()

        if stage in ("fit", None):
            self.train_dataset = self._create_tensor_dataset(
                self._build_split("train")
            )

        need_val = stage in ("fit", "validate", None)
        need_test = stage in ("test", None)
        if need_val or need_test:
            # Same held-out split for both; build it once per call.
            heldout = self._create_tensor_dataset(self._build_split("test"))
            if need_val:
                self.val_dataset = heldout
            if need_test:
                self.test_dataset = heldout

    @property
    def use_workers(self):
        """Number of dataloader workers: 0 when the data is already on CUDA."""
        assert self.device is not None
        return 0 if self.device.type == "cuda" else self.num_workers

    @property
    def use_pin_memory(self):
        """Pinned-memory flag: always ``False`` when the data is already on CUDA."""
        assert self.device is not None
        return False if self.device.type == "cuda" else self.pin_memory

    def _make_loader(
        self,
        dataset: TensorDataset,
        shuffle: bool,
        generator: torch.Generator | None = None,
    ) -> DataLoader:
        """Create a dataloader with the module-wide batch and worker settings."""
        workers = self.use_workers
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=workers,
            pin_memory=self.use_pin_memory,
            persistent_workers=workers > 0,
            generator=generator,
        )

    def train_dataloader(self):
        """Return the shuffled training dataloader, seeded with ``self.seed``."""
        assert self.train_dataset is not None
        g = torch.Generator()
        g.manual_seed(self.seed)
        return self._make_loader(self.train_dataset, shuffle=True, generator=g)

    def val_dataloader(self):
        """Return the validation dataloader (no shuffling)."""
        assert self.val_dataset is not None
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the test dataloader (no shuffling)."""
        assert self.test_dataset is not None
        return self._make_loader(self.test_dataset, shuffle=False)

    @property
    def vocab_size(self):
        """Vocabulary size, loading ``meta.pkl`` on first access."""
        if self._vocab_size is None:
            self._load_meta()
        return self._vocab_size

    def encode(self, s: str) -> list[int]:
        """Encode a string to list of integers."""
        if self._stoi is None:
            self._load_meta()
        return [self._stoi[c] for c in s]

    def decode(self, l: list[int] | torch.Tensor) -> str:
        """Decode a list of integers to string."""
        if self._itos is None:
            self._load_meta()
        if isinstance(l, torch.Tensor):
            l = l.tolist()
        return "".join(self._itos[i] for i in l)
