import os
import json
import time
import argparse

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import ot
from datasets import load_dataset
from tokenizers import Tokenizer, models, pre_tokenizers
from transformers import PreTrainedTokenizerFast

from load_model import load_model
from model import utils as mutils


def create_directory_if_not_exists(directory_path):
    """Ensure ``directory_path`` exists as a directory, printing what happened.

    Intermediate directories are created as needed. Creation uses
    ``exist_ok=True`` so a directory created concurrently (between the
    ``isdir`` check and the ``makedirs`` call) is not treated as an error.

    Raises:
        FileExistsError: if ``directory_path`` exists but is not a directory.
    """
    if os.path.isdir(directory_path):
        print(f"Directory '{directory_path}' already exists.")
        return
    # exist_ok=True removes the check-then-create race; it still raises if the
    # path is an existing non-directory file, surfacing that problem early.
    os.makedirs(directory_path, exist_ok=True)
    print(f"Directory '{directory_path}' created.")


def cycle_loader(dataloader, sampler=None):
    """Yield items from ``dataloader`` forever, restarting it each time it is exhausted.

    If ``sampler`` is given, ``sampler.set_epoch`` is called with a random
    integer in [0, 100000) before every pass. The hook exists only for
    compatibility; this script does not use a DistributedSampler.
    """
    while True:
        if sampler is not None:
            epoch = np.random.randint(0, 100000)
            sampler.set_epoch(epoch)
        yield from dataloader


# ---------- tokenizer + dataset code (mirrors the current training script) ----------

def build_char_tokenizer(vocab_path="char_vocab.json"):
    """Return the character-level tokenizer used at training time.

    If ``vocab_path`` exists, the vocabulary is loaded from it. Otherwise the
    vocabulary is built from the fixed character list below: a character's
    id is its position in the list, followed by ``<pad>`` and then ``<unk>``.
    It is then written to ``vocab_path`` so later runs reuse the same ids.

    The returned ``PreTrainedTokenizerFast`` also has a ``custom_decoder``
    attribute. It maps a sequence of ids back to a string; unknown ids become
    ``"<unk>"``.
    """
    if not os.path.exists(vocab_path):
        # Order matters: ids are list positions and must match the training script.
        custom_chars = [
            ' ', 'e', 't', 'o', 'a', 'h', 'n', 's', 'r', 'i', 'l', 'd', '\n', 'u', 'm',
            'y', ',', '.', 'w', 'f', 'c', 'g', 'I', 'p', 'b', 'A', 'E', 'T', 'v', 'S',
            'O', "'", 'k', 'R', 'N', 'L', 'C', 'H', ';', 'W', 'M', 'B', 'D', 'U', 'F',
            'G', 'P', '?', 'Y', '!', '-', 'K', 'x', 'V', 'j', 'q', '[', ']', 'J', ':',
            'Q', 'z', '9', '1', '(', ')', 'Z', 'X', '<', '"', '>', '2', '3', '0', '4',
            '5', '_', '6', '7', '8', '|', '&', '}', '`'
        ]

        vocab_dict = {}
        for position, character in enumerate(custom_chars):
            vocab_dict[character] = position
        # Special tokens take the next free ids, <pad> first, then <unk>.
        for special in ("<pad>", "<unk>"):
            vocab_dict[special] = len(vocab_dict)

        with open(vocab_path, "w", encoding="utf-8") as handle:
            json.dump(vocab_dict, handle)
    else:
        with open(vocab_path, "r", encoding="utf-8") as handle:
            vocab_dict = json.load(handle)

    word_level = models.WordLevel(vocab_dict, unk_token="<unk>")
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=Tokenizer(word_level),
        unk_token="<unk>",
        pad_token="<pad>",
    )

    # Split on the empty pattern with "isolated" behaviour. This is intended
    # to hand each character to WordLevel separately (mirrors training).
    tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Split("", behavior="isolated")

    # Reverse lookup for the debugging decoder attached below.
    id_to_char = {idx: char for char, idx in vocab_dict.items()}

    def decode_ids(token_ids):
        return "".join(id_to_char.get(token_id, "<unk>") for token_id in token_ids)

    tokenizer.custom_decoder = decode_ids

    return tokenizer


def make_tokenize_and_group_fns(tokenizer, block_size):
    """Build the two batched ``datasets.map`` callbacks used to prepare the data.

    Returns ``(tokenize_function, group_texts)``:

    * ``tokenize_function(examples)`` tokenizes ``examples["text"]`` without
      special tokens, truncating each text to ``block_size`` ids and omitting
      attention masks / token type ids.
    * ``group_texts(examples)`` concatenates all ``examples["input_ids"]`` in
      the batch and cuts them into consecutive ``block_size``-long chunks.
      Any trailing remainder shorter than ``block_size`` is dropped.
    """
    tokenizer_kwargs = dict(
        add_special_tokens=False,
        truncation=True,
        max_length=block_size,
        return_attention_mask=False,
        return_token_type_ids=False,
    )

    def tokenize_function(examples):
        return tokenizer(examples["text"], **tokenizer_kwargs)

    def group_texts(examples):
        flat_ids = [tok for seq in examples["input_ids"] for tok in seq]
        n_full_blocks = len(flat_ids) // block_size
        chunks = [
            flat_ids[k * block_size:(k + 1) * block_size]
            for k in range(n_full_blocks)
        ]
        return {"input_ids": chunks}

    return tokenize_function, group_texts


def get_char_dataset(file_path, cache_dir, block_size, tokenizer):
    """Load ``file_path`` as a torch-formatted dataset of ``block_size``-long id blocks.

    Mirrors the training ``get_dataset()`` but takes explicit arguments
    instead of a cfg. The HF ``text`` builder provides the raw examples.
    They are tokenized with ``tokenizer``, then concatenated and chunked per
    batched ``map`` call by ``group_texts``. Only the ``input_ids`` column is
    exposed, formatted as torch tensors.
    """
    raw_splits = load_dataset(
        "text",
        data_files={"train": file_path},
        cache_dir=cache_dir,
    )
    dataset = raw_splits["train"]

    tokenize_function, group_texts = make_tokenize_and_group_fns(tokenizer, block_size)

    dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
    dataset = dataset.map(group_texts, batched=True)

    # Training called .flatten(); mirror it, tolerating datasets versions without it.
    try:
        dataset = dataset.flatten()
    except AttributeError:
        pass

    dataset.set_format(type="torch", columns=["input_ids"])
    return dataset


# ---------- J1 evaluation ----------

def _ot_match_hamming(sourcebatch, databatch, device):
    """Re-pair data rows with source rows via exact OT on a length-normalized Hamming cost.

    Args:
        sourcebatch: [B, L] token tensor (from ``graph.sample_limit``).
        databatch: [B, L] token tensor of data sequences.
        device: device the transport plan is moved to.

    Returns:
        ``(sourcebatch, matched_databatch)``. Row ``i`` of the matched data is
        the data sequence the optimal plan assigns to source row ``i``.
    """
    B, L = databatch.shape
    # M[i, j] = number of positions where source i and data j differ.  [B, B]
    M = (sourcebatch.unsqueeze(1) != databatch.unsqueeze(0)).sum(dim=2).float()
    # Normalize by the sequence length L (M.size(1) would be the batch size B).
    M = M / L

    # Exact OT (EMD) with uniform marginals.
    a_np = np.full(B, 1.0 / B, dtype=np.float64)
    b_np = np.full(B, 1.0 / B, dtype=np.float64)
    P_np = ot.emd(a_np, b_np, M.double().cpu().numpy())
    plan = torch.from_numpy(P_np).to(device=device, dtype=torch.float32)  # [B, B]

    # With uniform marginals an optimal vertex solution is a permutation matrix
    # scaled by 1/B, so the row-wise argmax recovers the assignment.
    col_idx = plan.argmax(dim=1)  # [B]
    return sourcebatch, databatch[col_idx]


def _j1_batch_loss(log_score, databatch, perturbed_batch, t):
    """Return the batch-mean J1 loss for one batch.

    With p = softmax(log_score) over the last (vocab) axis, each position contributes:
      * where x_t != data:  -(log p(data) + 1) + (1 - p(x_t))
      * where x_t == data:                        (1 - p(x_t))
    Each position is divided by (1 - t). The result is summed over the
    sequence and averaged over the batch.
    """
    probs = torch.softmax(log_score, dim=-1)
    mask = perturbed_batch != databatch
    # squeeze(-1) rather than squeeze() so B == 1 or L == 1 cannot collapse extra dims.
    p_data = torch.gather(probs, dim=-1, index=databatch.unsqueeze(-1)).squeeze(-1)
    p_perturbed = torch.gather(probs, dim=-1, index=perturbed_batch.unsqueeze(-1)).squeeze(-1)

    loss1 = mask * (torch.log(p_data + 1e-10) + 1)  # 1e-10 guards log(0)
    loss2 = 1 - p_perturbed
    loss = -loss1 + loss2
    return (loss / (1 - t[:, None])).sum(-1).mean()


def _save_j1_losses(args, losses):
    """Save the per-batch losses divided by ``args.length`` into ``<model_path>_eval/``."""
    out_path = os.path.join(
        args.model_path + "_eval/",
        f"lossesJ1_dataset_{args.dataset}_length_{args.length}.npy",
    )
    np.save(out_path, np.array(losses) / args.length)


def J1(args, train_iter, device, noise, graph, model):
    """Evaluate the J1 objective on ``args.no_batches`` batches drawn from ``train_iter``.

    Each batch goes through these steps:
      1. Sample a source batch with ``graph.sample_limit`` and draw
         t ~ Uniform(eps, 1 - eps).
      2. Optionally re-pair source and data with exact OT (``args.hamming``,
         default False).
      3. Perturb with ``graph.sample_transition``.
      4. Score the perturbed batch with the model.

    Every 100 batches, exp(running mean loss / length) is printed. Losses
    are saved every 1000 batches and once more when the loop ends. The
    final save also happens when the loop stops early because that running
    value overflowed to inf.

    Args:
        args: namespace with ``no_batches``, ``length``, ``model_path``,
            ``dataset`` and optionally ``hamming``.
        train_iter: iterator yielding dicts with an ``"input_ids"`` tensor [B, L].
        device: torch device batches are moved to.
        noise: unused; kept for interface compatibility.
        graph: provides ``sample_limit`` and ``sample_transition``.
        model: score model passed to ``mutils.get_score_fn``.

    Returns:
        np.ndarray of per-batch losses divided by ``args.length`` (the array
        that is written to disk).
    """
    sampling_eps = 0.0001
    # Optional OT pairing; args without a ``hamming`` attribute keep the old default (off).
    hamming = bool(getattr(args, "hamming", False))

    # Build the score fn once. train=False is passed (the previous code passed
    # the string "True") so evaluation does not run the model in training mode.
    # NOTE(review): assumes get_score_fn uses this flag to toggle model.train()/eval();
    # confirm in model/utils.py.
    log_score_fn = mutils.get_score_fn(model, train=False, sampling=False)

    losses = []
    running_sum = 0.0  # running total so the mean is O(1) per batch, not O(n)

    for i in range(args.no_batches):
        # batch is a dict with key "input_ids" from HF datasets
        databatch = next(train_iter)["input_ids"].to(device)  # [B, L]
        sourcebatch = graph.sample_limit(databatch.shape).to(databatch)

        # t ~ Uniform(eps, 1 - eps)
        t = (1 - 2 * sampling_eps) * torch.rand(databatch.shape[0], device=databatch.device) + sampling_eps

        if hamming:
            try:
                sourcebatch, databatch = _ot_match_hamming(sourcebatch, databatch, device)
            except Exception as e:
                # Best-effort: fall back to the independent pairing for this batch.
                print("Error calculating optimal transport, continuing with independent sampling this batch")
                print(e)

        perturbed_batch = graph.sample_transition(sourcebatch, databatch, t[:, None])
        log_score = log_score_fn(perturbed_batch, t)  # [..., vocab]
        loss = _j1_batch_loss(log_score, databatch, perturbed_batch, t)

        loss_value = loss.detach().cpu().numpy()
        losses.append(loss_value)
        running_sum += float(loss_value)

        # Overflow to inf is the stop signal below, so silence numpy's warning.
        with np.errstate(over="ignore"):
            running_ppl = np.exp(running_sum / len(losses) / args.length)
        if np.isinf(running_ppl):
            break

        if i % 100 == 0:
            print("iter:", i)
            print("mean:", running_ppl)

        if i % 1000 == 0:
            _save_j1_losses(args, losses)

    # Persist everything, including batches after the last periodic checkpoint.
    _save_j1_losses(args, losses)
    return np.array(losses) / args.length


def main():
    """Parse CLI arguments, build the char-level eval set, load the model and run J1.

    Exits with an argparse error in these cases:
      * ``--batch_size`` is not positive.
      * ``--perturbed_points_nr`` is smaller than ``--batch_size``, which
        would otherwise give zero evaluation batches.
      * ``--num_workers`` is negative.
    Also exits when the eval file yields no complete ``--length``-token block.
    """
    parser = argparse.ArgumentParser(description="J1 evaluation for char-level Shakespeare")

    parser.add_argument(
        "--model_path",
        default="x",
        type=str,
        help="Path where the trained model (and cfg) are stored.",
    )
    parser.add_argument("--dataset", default="shakespeare", type=str)
    parser.add_argument("--length", default=128, type=int)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument(
        "--perturbed_points_nr", type=int, default=128 * (50 * 1024)
    )
    parser.add_argument("--cache_dir", type=str, default="data")
    parser.add_argument("--J", type=str, default="1")  # currently not read below

    # evaluation file (matches training: shakespeare_test.txt)
    parser.add_argument(
        "--eval_file",
        type=str,
        default="shakespeare_test.txt",
        help="Plain text file to evaluate on (e.g. test set).",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="DataLoader worker processes (0 loads data in the main process).",
    )

    args = parser.parse_args()

    # Floor division does not raise when perturbed_points_nr < batch_size; it
    # silently yields 0 batches. So validate explicitly.
    if args.batch_size <= 0:
        parser.error("--batch_size must be a positive integer")
    if args.perturbed_points_nr < args.batch_size:
        parser.error("--perturbed_points_nr should be at least --batch_size")
    if args.num_workers < 0:
        parser.error("--num_workers must be non-negative")
    args.no_batches = args.perturbed_points_nr // args.batch_size

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    create_directory_if_not_exists(args.model_path + "_eval/")

    # Build tokenizer + dataset exactly like training
    tokenizer = build_char_tokenizer()  # uses / creates char_vocab.json
    eval_set = get_char_dataset(
        file_path=args.eval_file,
        cache_dir=args.cache_dir,
        block_size=args.length,
        tokenizer=tokenizer,
    )

    print(f"The size of the evaluation dataset: {len(eval_set)}")
    if len(eval_set) == 0:
        raise SystemExit(
            f"No complete {args.length}-token blocks could be built from '{args.eval_file}'."
        )

    # cycle_loader returns a generator, which is already an iterator.
    eval_iter = cycle_loader(
        DataLoader(
            eval_set,
            batch_size=args.batch_size,
            shuffle=True,  # matches the training DataLoader
            num_workers=args.num_workers,
            # Pinned memory only speeds up host->GPU copies; skip it on CPU.
            pin_memory=device.type == "cuda",
            # persistent_workers is only valid with worker processes.
            persistent_workers=args.num_workers > 0,
        )
    )

    with torch.no_grad():
        model, graph, noise = load_model(args.model_path, device)
        J1(args, eval_iter, device, noise, graph, model)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
