import json
import os
import random
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from datasets import load_dataset
from datasets.arrow_dataset import Dataset
from jaxtyping import Float

from code_demeanor.logger import logger


def save_list_jsonl(data_list: List[str], file_path: str):
    """Write each item of ``data_list`` to ``file_path`` as one JSON value per line.

    The file is opened in write mode, so any existing content is replaced.
    """
    with open(file_path, "w") as out_file:
        out_file.writelines(json.dumps(entry) + "\n" for entry in data_list)


def load_list_jsonl(file_path: str) -> List[str]:
    """Read a JSONL file and return its decoded entries in file order.

    Args:
        file_path: Path of the JSONL file to read.

    Returns:
        A list with one decoded JSON value per non-blank line.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        # Blank lines (e.g. a stray trailing newline) carry no record, so
        # skip them instead of letting json.loads fail on an empty string.
        return [json.loads(line) for line in f if line.strip()]


def read_jsonl_file(jsonl_path: str):
    """Lazily yield the decoded JSON value of every non-blank line in a JSONL file.

    This is a generator: the file is opened on first iteration and stays open
    until the generator is exhausted or closed.

    Args:
        jsonl_path: Path of the JSONL file to read.

    Yields:
        One decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            # Blank lines (e.g. a stray trailing newline) carry no record.
            if not line.strip():
                continue
            yield json.loads(line)


def write_jsonl(file_path, data, append: bool = False):
    """Serialize each item of ``data`` as one JSON line in ``file_path``.

    With ``append=True`` the lines are added to the end of the file; otherwise
    the file is overwritten.
    """
    if append:
        mode = "a"
    else:
        mode = "w"
    with open(file_path, mode) as handle:
        for record in data:
            handle.write(f"{json.dumps(record)}\n")


def load_data(
    path: str, split_ratio: float = 0.8, max_samples: Optional[int] = None
) -> Tuple[Dataset, Dataset]:
    """Load a JSON dataset, shuffle it with a fixed seed and cut it in two.

    Args:
        path: Passed to ``load_dataset`` as ``data_files``.
        split_ratio: Fraction of the (possibly truncated) samples that goes
            into the first returned split.
        max_samples: If given, keep at most this many shuffled samples before
            splitting.

    Returns:
        ``(train_split, test_split)``.
    """
    dataset = load_dataset("json", data_files=path, split="train")
    # Fixed seed so the same file always yields the same ordering.
    dataset = dataset.shuffle(seed=42)

    if max_samples is not None:
        keep = min(max_samples, len(dataset))
        dataset = dataset.select(range(keep))

    total = len(dataset)
    boundary = int(total * split_ratio)

    head = dataset.select(range(boundary))
    tail = dataset.select(range(boundary, total))
    return head, tail


def save_tensor(tensor: torch.Tensor, file_path: str, append: bool = False):
    """Persist a tensor to ``file_path`` as (part of) a list of tensors.

    The tensor is detached and moved to the CPU first. With ``append=True``
    and an existing file, the stored content is loaded, wrapped in a list if
    it is not one already, and the new tensor is added at the end. In every
    other case the file is (re)written with a one-element list.
    """
    cpu_tensor = tensor.detach().cpu()

    stored = []
    if append and os.path.exists(file_path):
        previous = torch.load(file_path)
        # Older files may hold a bare object rather than a list.
        stored = previous if isinstance(previous, list) else [previous]

    stored.append(cpu_tensor)
    torch.save(stored, file_path)


def load_tensor(file_path: str, device: str = "cpu") -> torch.Tensor:
    """Load the object stored at ``file_path`` via ``torch.load`` onto ``device``.

    The result is whatever was serialized. Note that files written by
    ``save_tensor`` hold a list of tensors rather than a single tensor.
    """
    target = torch.device(device)
    return torch.load(file_path, map_location=target)


def save_tensor_jsonl(tensor: torch.Tensor, file_path: str, append: bool = False):
    """
    Write a PyTorch tensor as a single JSON line in a JSONL file.

    The tensor is detached, moved to the CPU and converted to nested Python
    lists before being serialized.
    If append=True, the line is added to the end of the file.
    If append=False, the file is replaced and ends up holding only this tensor.
    """
    values = tensor.detach().cpu().tolist()

    with open(file_path, "a" if append else "w", encoding="utf-8") as out:
        # One tensor per line; the newline terminates the entry.
        out.write(json.dumps(values) + "\n")


def load_tensors_jsonl(file_path: str, device: str = "cpu") -> list[torch.Tensor]:
    """
    Read every tensor stored in a JSONL file.

    Each non-blank line is decoded as JSON and turned into a tensor on
    ``device``. Returns the tensors as a list, in file order.
    """
    with open(file_path, "r", encoding="utf-8") as source:
        return [
            torch.tensor(json.loads(raw), device=device)
            for raw in source
            if raw.strip()  # blank lines hold no tensor
        ]


def set_seed(seed: int = 42):
    """
    Seed PyTorch (CPU and all CUDA devices), NumPy and Python's ``random``
    with the same value, and switch cuDNN to deterministic / non-benchmark
    mode. Returns the seed that was applied.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    for seed_fn in (np.random.seed, random.seed):
        seed_fn(seed)
    return seed


def get_device(verbose: bool = True) -> str:
    """
    Pick the device string to use: "cuda" when CUDA is available, otherwise
    "mps" when the MPS backend is available, otherwise "cpu".

    When ``verbose`` is true the choice is reported through the module logger.
    """
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"

    if verbose:
        logger.info("Getting device.", device=device)
    return device


def kl_div(original_logits: torch.Tensor, new_logits: torch.Tensor):
    """KL(softmax(original_logits) || softmax(new_logits)) over the last dimension.

    Both arguments are unnormalized logits. The pointwise terms are summed
    over the last dimension, so the result has the input shape minus that
    dimension.
    """
    F = torch.nn.functional
    target_probs = F.softmax(original_logits, dim=-1)
    approx_log_probs = F.log_softmax(new_logits, dim=-1)
    # F.kl_div takes the approximating distribution in log space first and
    # the reference distribution (as probabilities) second.
    pointwise = F.kl_div(approx_log_probs, target_probs, reduction="none")
    return pointwise.sum(dim=-1)


def js_div(original_logits: torch.Tensor, new_logits: torch.Tensor):
    """Jensen-Shannon divergence between the softmax distributions of two logit tensors.

    With P = softmax(original_logits), Q = softmax(new_logits) and
    M = (P + Q) / 2, this returns 0.5 * KL(P || M) + 0.5 * KL(Q || M),
    reduced over the last dimension. The value is symmetric in its arguments
    and lies in [0, ln 2] (natural-log units).

    Args:
        original_logits: Unnormalized logits; distributions run along the last dim.
        new_logits: Unnormalized logits, broadcastable against ``original_logits``.

    Returns:
        Tensor with the broadcast input shape minus the last dimension.
    """
    import math

    functional = torch.nn.functional
    log_p = functional.log_softmax(original_logits, dim=-1)
    log_q = functional.log_softmax(new_logits, dim=-1)

    # log M = log((P + Q) / 2), computed in log space. The mixture is a
    # probability distribution, not logits, so it must not be pushed through
    # another (log_)softmax - doing so yields a different distribution.
    log_m = torch.logaddexp(log_p, log_q) - math.log(2.0)

    def _kl_to_mixture(log_probs: torch.Tensor) -> torch.Tensor:
        """KL(exp(log_probs) || M), summed over the last dimension."""
        probs = log_probs.exp()
        terms = probs * (log_probs - log_m)
        # Entries with zero probability contribute 0; masking also discards
        # the NaN from (-inf) - (-inf) when both inputs have zero mass there.
        terms = torch.where(probs > 0, terms, torch.zeros_like(terms))
        return terms.sum(dim=-1)

    return (_kl_to_mixture(log_p) + _kl_to_mixture(log_q)) / 2


def send_to_device(tensor: torch.Tensor, device: str) -> torch.Tensor:
    """Send a tensor to the requested accelerator, or to the CPU otherwise.

    Args:
        tensor: Tensor to move.
        device: Device specification. ``"cuda"`` / ``"mps"``, indexed forms
            such as ``"cuda:0"``, and the equivalent ``torch.device`` objects
            select that accelerator. Anything else (other device types,
            unparseable strings, non-device values) selects the CPU.

    Returns:
        The tensor on the selected device.
    """
    if isinstance(device, (str, torch.device)):
        try:
            # Normalizing through torch.device makes "cuda:0" and
            # torch.device("cuda") behave like the plain "cuda" string
            # instead of silently falling through to the CPU branch.
            target = torch.device(device)
        except (RuntimeError, ValueError):
            # Not a valid device string: keep the CPU fallback.
            target = None
        if target is not None and target.type in ("cuda", "mps"):
            return tensor.to(target)
    return tensor.cpu()


def project_onto_direction(
    H: Float[torch.Tensor, "n d_1"],
    direction: Float[torch.Tensor, "d_2"],
    device: torch.device = get_device(verbose=False),
) -> Float[torch.Tensor, "n"]:
    """Project the rows of H (n, d) onto a direction vector (d,).

    Args:
        H: Matrix whose rows are projected. Non-tensor inputs are converted
            with ``torch.Tensor``.
        direction: Direction vector. Non-tensor inputs are converted with
            ``torch.Tensor``.
        device: Device both operands are moved to when at least one of them
            had to be converted. The default is evaluated once, when the
            module is imported. When both inputs are already tensors they are
            used as given.

    Returns:
        Tensor of shape (n,): the dot product of every row of H with
        ``direction``, divided by the norm of ``direction``.

    Raises:
        AssertionError: If the norm of ``direction`` is infinite.
    """
    h_is_tensor = isinstance(H, torch.Tensor)
    direction_is_tensor = isinstance(direction, torch.Tensor)

    # Convert each argument based on its own type. Previously H was only
    # converted when `direction` was not a tensor, so a non-tensor H paired
    # with a tensor direction reached matmul unconverted.
    if not h_is_tensor:
        H = torch.Tensor(H)
    if not direction_is_tensor:
        direction = torch.Tensor(direction)

    # If anything was converted, place both operands on `device` so that the
    # matmul below sees them on the same device.
    if not (h_is_tensor and direction_is_tensor):
        H = send_to_device(H, device)
        direction = send_to_device(direction, device)

    # Magnitude of the direction vector.
    mag = torch.norm(direction)
    assert not torch.isinf(mag).any()

    # H.matmul(direction) is the dot product of each row of H with the
    # direction; dividing by the magnitude expresses it in units of the
    # unit-length direction. A zero-magnitude direction is not guarded
    # against and yields non-finite values.
    projection = H.matmul(direction) / mag
    return projection


def recenter(x, mean=None, device=get_device(verbose=False)):
    """Subtract a mean from ``x`` after moving both to ``device``.

    ``x`` (and ``mean``, when supplied) are converted with ``torch.Tensor``.
    When ``mean`` is omitted, the mean of ``x`` along axis 0 (dimension kept)
    is used. The default ``device`` is evaluated once, when the module is
    imported.
    """
    data = send_to_device(torch.Tensor(x), device)

    if mean is not None:
        offset = send_to_device(torch.Tensor(mean), device)
    else:
        offset = send_to_device(torch.mean(data, axis=0, keepdims=True), device)

    return data - offset


def _fro_norm_per_batch(t: torch.Tensor) -> torch.Tensor:
    """
    Frobenius (L2) norm over all dims except batch.
    Input shape: [batch, ...]
    Returns: [batch]
    """
    return t.reshape(t.size(0), -1).norm(p=2, dim=-1)


def _stat_features(x: torch.Tensor) -> torch.Tensor:
    """
    x: [B, ...]  (everything after B is reduced)
    returns: [B, N_STATS] on same device/dtype as x
    """
    B = x.shape[0]
    flat = x.reshape(B, -1)
    # cast to float32 for numerical stability on fp16 models
    f = flat.float()

    mean = f.mean(dim=1)
    var = f.var(dim=1, unbiased=False)
    std = var.sqrt()

    # centralized moments
    if f.shape[1] > 0:
        z = f - mean[:, None]
        m3 = (z**3).mean(dim=1)
        m4 = (z**4).mean(dim=1)
    else:
        m3 = torch.zeros_like(mean)
        m4 = torch.zeros_like(mean)

    # guard zero std to avoid NaNs
    eps = 1e-12
    inv_std = 1.0 / torch.clamp(std, min=eps)

    skew = m3 * (inv_std**3)
    kurt = m4 * (inv_std**4)

    rmin = f.amin(dim=1)
    rmax = f.amax(dim=1)
    rng = rmax - rmin

    out = torch.stack([mean, std, rng, skew, kurt], dim=1)
    return out.to(x.dtype)


def read_yaml_config(file_path: str) -> Dict:
    """Parse the YAML file at ``file_path`` with ``yaml.safe_load`` and return the result."""
    # Imported here so that yaml is only needed when a config is actually read.
    import yaml

    with open(file_path, "r") as stream:
        return yaml.safe_load(stream)


def get_dataset(
    samples_path: str,
    test_samples_path: Optional[str] = None,
    max_samples: Optional[int] = None,
) -> Tuple[List, List, List, List]:
    """
    Load train/test inputs and labels from JSONL file(s).

    Args:
        samples_path: JSONL file with the samples. Used as the training set
            when ``test_samples_path`` is given, otherwise split into train
            and test parts by ``load_data``.
        test_samples_path: Optional JSONL file holding the test samples.
        max_samples: Forwarded to ``load_data``. It only takes effect when
            ``test_samples_path`` is not given; with separate files both are
            loaded in full.

    Returns:
        ``(train_inputs, train_labels, test_inputs, test_labels)``, built from
        the ``"input"`` and ``"label"`` fields of every sample.
    """
    if test_samples_path:
        # Load training and testing samples from separate files
        train_samples = load_dataset("json", data_files=samples_path, split="train")
        test_samples = load_dataset("json", data_files=test_samples_path, split="train")
    else:
        # Load a single file and let load_data shuffle and split it
        train_samples, test_samples = load_data(samples_path, max_samples=max_samples)

    # Extract inputs and labels
    train_inputs = [sample["input"] for sample in train_samples]
    train_labels = [sample["label"] for sample in train_samples]
    test_inputs = [sample["input"] for sample in test_samples]
    test_labels = [sample["label"] for sample in test_samples]
    return train_inputs, train_labels, test_inputs, test_labels
