import os
import json
import torch
import time
import numpy as np
import argparse
import ot

# from your original code
# import data   # <-- no longer needed
from load_model import load_model
import torch.nn.functional as F
import sampling
from torch.utils.data import DataLoader
from model import utils as mutils

# new imports for the char-level dataset
from datasets import load_dataset
from tokenizers import Tokenizer, models, pre_tokenizers
from transformers import PreTrainedTokenizerFast


def create_directory_if_not_exists(directory_path):
    """Create ``directory_path`` (and any missing parents) if it is absent.

    Prints whether the directory was newly created or was already present.
    """
    if os.path.exists(directory_path):
        print(f"Directory '{directory_path}' already exists.")
        return
    os.makedirs(directory_path)
    print(f"Directory '{directory_path}' created.")


def cycle_loader(dataloader, sampler=None):
    """Yield items from ``dataloader`` endlessly, restarting it after each pass.

    ``sampler`` is optional and kept for compatibility: when given, its
    ``set_epoch`` is called with a random integer in ``[0, 100000)`` before
    every pass over ``dataloader``.
    """
    while True:
        if sampler is not None:
            sampler.set_epoch(np.random.randint(0, 100000))
        yield from dataloader


# ---------- Char-level tokenizer and dataset helpers (mirroring the training script) ----------

def build_char_tokenizer(vocab_path="char_vocab.json"):
    """Return the character-level tokenizer used during training.

    If ``vocab_path`` exists, the vocabulary (a char -> id JSON mapping) is read
    from it. Otherwise a default vocabulary is built from the fixed character
    list below followed by ``<pad>`` and ``<unk>``, and it is written to
    ``vocab_path`` so subsequent runs load the same mapping.

    The returned ``PreTrainedTokenizerFast`` additionally carries a
    ``custom_decoder(token_ids)`` attribute that joins the characters for the
    given ids; ids missing from the vocabulary decode to ``"<unk>"``.
    """
    if not os.path.exists(vocab_path):
        # Default vocabulary, identical to the training script's.
        base_chars = ['.', ' ', '-', '\n', '[', ']', '<', '>', '|', '}', '`']
        vocab_dict = {}
        for position, symbol in enumerate(base_chars):
            vocab_dict[symbol] = position
        for special in ("<pad>", "<unk>"):
            vocab_dict[special] = len(vocab_dict)
        with open(vocab_path, "w", encoding="utf-8") as handle:
            json.dump(vocab_dict, handle)
    else:
        with open(vocab_path, "r", encoding="utf-8") as handle:
            vocab_dict = json.load(handle)

    word_level_model = models.WordLevel(vocab_dict, unk_token="<unk>")
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=Tokenizer(word_level_model),
        unk_token="<unk>",
        pad_token="<pad>",
    )
    # Same pre-tokenizer configuration as in training.
    tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Split("", behavior="isolated")

    # Reverse mapping for the optional decoder (kept for parity with training).
    id_to_char = {idx: char for char, idx in vocab_dict.items()}

    def custom_decoder(token_ids):
        return "".join(id_to_char.get(tid, "<unk>") for tid in token_ids)

    tokenizer.custom_decoder = custom_decoder
    return tokenizer


def make_tokenize_and_group_fns(tokenizer, block_size):
    """Build the two batched ``Dataset.map`` callbacks for the char dataset.

    Returns ``(tokenize_function, group_texts)``:

    * ``tokenize_function(examples)`` calls ``tokenizer`` on
      ``examples["text"]`` without special tokens, truncating each entry to at
      most ``block_size`` ids and requesting no attention mask or token type
      ids.
    * ``group_texts(examples)`` concatenates all id lists in
      ``examples["input_ids"]`` and slices the stream into consecutive chunks
      of exactly ``block_size`` ids; a trailing remainder shorter than
      ``block_size`` is discarded.
    """

    def tokenize_function(examples):
        call_options = dict(
            add_special_tokens=False,
            truncation=True,
            max_length=block_size,
            return_attention_mask=False,
            return_token_type_ids=False,
        )
        return tokenizer(examples["text"], **call_options)

    def group_texts(examples):
        stream = [tok for seq in examples["input_ids"] for tok in seq]
        full_blocks = len(stream) // block_size
        chunks = [
            stream[k * block_size : (k + 1) * block_size]
            for k in range(full_blocks)
        ]
        return {"input_ids": chunks}

    return tokenize_function, group_texts


def get_char_dataset(file_path, cache_dir, block_size, tokenizer):
    """Build the char-level evaluation dataset from a plain-text file.

    Mirrors the training ``get_dataset()`` but takes explicit arguments
    instead of a config object. ``file_path`` is loaded with the Hugging Face
    ``"text"`` builder, tokenized with ``tokenizer``, concatenated and cut into
    blocks of ``block_size`` ids (see ``make_tokenize_and_group_fns``),
    flattened when ``flatten()`` is available, and formatted so that items are
    PyTorch tensors under the ``"input_ids"`` key.
    """
    raw_splits = load_dataset(
        "text",
        data_files={"train": file_path},
        cache_dir=cache_dir,
    )
    text_dataset = raw_splits["train"]

    tokenize_fn, group_fn = make_tokenize_and_group_fns(tokenizer, block_size)

    token_dataset = text_dataset.map(
        tokenize_fn,
        batched=True,
        remove_columns=["text"],
    )
    block_dataset = token_dataset.map(group_fn, batched=True)

    # Parity with the training script; some `datasets` versions may lack flatten().
    try:
        block_dataset = block_dataset.flatten()
    except AttributeError:
        pass

    block_dataset.set_format(type="torch", columns=["input_ids"])
    return block_dataset


# ---------- J1 evaluation loop ----------

def _pairwise_hamming_cost(sourcebatch, databatch):
    """Return normalized pairwise Hamming distances between two token batches.

    Entry ``[i, j]`` is the fraction of positions at which ``sourcebatch[i]``
    and ``databatch[j]`` differ (Hamming distance divided by the sequence
    length), so costs lie in ``[0, 1]`` independent of the batch size.
    Assumes both inputs are ``[B, L]`` integer tensors on the same device.
    """
    seq_len = databatch.shape[-1]
    # [B, 1, L] != [1, B, L] -> [B, B, L]; count mismatches along L.
    mismatches = (sourcebatch.unsqueeze(1) != databatch.unsqueeze(0)).sum(dim=2)
    return mismatches.float() / seq_len


def _ot_match_batches(sourcebatch, databatch):
    """Reorder ``databatch`` so row ``i`` is the exact-OT partner of ``sourcebatch[i]``.

    Solves the discrete optimal-transport problem between the two batches with
    uniform marginals and the normalized Hamming cost using ``ot.emd``. With
    equal batch sizes the plan is expected to be a permutation matrix scaled
    by ``1/B``, so the per-row argmax recovers the matching.

    Returns ``(sourcebatch, matched_databatch)``.
    """
    batch_size = databatch.shape[0]
    cost = _pairwise_hamming_cost(sourcebatch, databatch).double().cpu().numpy()
    a = np.full(batch_size, 1.0 / batch_size, dtype=np.float64)
    b = np.full(batch_size, 1.0 / batch_size, dtype=np.float64)
    plan = ot.emd(a, b, cost)  # [B, B] transport plan
    col_idx = torch.from_numpy(plan.argmax(axis=1)).to(databatch.device)
    return sourcebatch, databatch[col_idx]


def J1(args, train_iter, device, noise, graph, model):
    """Evaluate the J1 loss of ``model`` over ``args.no_batches`` batches.

    For every batch this:
      1. takes ``next(train_iter)["input_ids"]`` as the data batch,
      2. draws a source batch of the same shape via ``graph.sample_limit``,
      3. pairs source and data rows by exact optimal transport on the
         normalized Hamming cost (on failure, keeps the independent pairing
         for that batch),
      4. samples ``t`` uniformly in ``[sampling_eps, 1 - sampling_eps]`` and
         builds the perturbed batch with ``graph.sample_transition``,
      5. scores it with the model and records the scalar loss.

    Every 100 iterations ``exp(running_mean_loss / args.length)`` is printed.
    Every 1000 iterations, and once more when the loop finishes, the per-batch
    losses divided by ``args.length`` are saved to
    ``<model_path>_eval/lossesJ1_dataset_<dataset>_length_<length>.npy``.
    The loop stops early if that running value overflows to ``inf``.

    ``noise`` is not used; it is accepted for interface compatibility.

    Returns the list of per-batch losses.
    """
    sampling_eps = 0.0001
    use_ot_pairing = True
    losses = []
    running_sum = 0.0  # avoids re-averaging the whole loss list every iteration

    save_path = (
        args.model_path
        + "_eval/"
        + "lossesJ1_"
        + "dataset_"
        + args.dataset
        + "_length_"
        + str(args.length)
        + ".npy"
    )

    def save_losses():
        np.save(save_path, np.array(losses) / args.length)

    # The score function does not depend on the batch, so build it once.
    # NOTE(review): train='True' (a string) is kept from the original code; if
    # get_score_fn uses it to put the model in training mode, dropout would be
    # active during evaluation — confirm this is intended.
    log_score_fn = mutils.get_score_fn(model, train='True', sampling=False)

    for i in range(args.no_batches):
        databatch = next(train_iter)['input_ids'].to(device)
        sourcebatch = graph.sample_limit(databatch.shape).to(databatch)

        t = (1 - 2 * sampling_eps) * torch.rand(databatch.shape[0], device=databatch.device) + sampling_eps

        if use_ot_pairing:
            try:
                sourcebatch, databatch = _ot_match_batches(sourcebatch, databatch)
            except Exception as err:
                print(f'Error calculating optimal transport ({err!r}), continuing with independent sampling this batch')

        perturbed_batch = graph.sample_transition(sourcebatch, databatch, t[:, None])
        mask = perturbed_batch != databatch

        log_score = log_score_fn(perturbed_batch, t)
        probs = torch.softmax(log_score, dim=-1)

        # squeeze(-1) rather than squeeze(): keeps [B, L] even when B == 1 or
        # L == 1, so the mask multiplication cannot broadcast to a wrong shape.
        p_data = torch.gather(probs, dim=-1, index=databatch.unsqueeze(-1)).squeeze(-1)
        p_perturbed = torch.gather(probs, dim=-1, index=perturbed_batch.unsqueeze(-1)).squeeze(-1)

        loss1 = mask * (torch.log(p_data + 0.0000000001) + 1)
        loss2 = 1 - p_perturbed
        loss = -loss1 + loss2
        loss = (loss / (1 - t[:, None])).sum(-1).mean()

        loss_value = loss.detach().cpu().numpy()
        losses.append(loss_value)
        running_sum += float(loss_value)
        running_metric = np.exp(running_sum / (i + 1) / args.length)

        if np.isinf(running_metric):
            break
        if i % 100 == 0:
            print("iter:", i)
            print("mean:", running_metric)
        if i % 1000 == 0:
            save_losses()

    # Persist losses recorded since the last periodic checkpoint (or before an early stop).
    if losses:
        save_losses()
    return losses


def main():
    """Command-line entry point: build the char-level eval set and run J1.

    Parses CLI arguments, derives ``args.no_batches`` from
    ``--perturbed_points_nr`` and ``--batch_size``, builds the tokenizer and
    block dataset from ``--eval_file``, loads the model from ``--model_path``
    and runs ``J1`` under ``torch.no_grad()``. Prints a message and returns
    early if the arguments yield no batches or the dataset has no blocks.
    """
    parser = argparse.ArgumentParser(description="J1 evaluation")

    parser.add_argument(
        "--model_path", default="x", type=str
    )
    # Only used as a name tag in the saved file name (e.g. "shakespeare").
    parser.add_argument("--dataset", default="shakespeare", type=str)
    parser.add_argument("--length", default=128, type=int)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument(
        "--perturbed_points_nr", type=int, default=128 * (50 * 1024)
    )
    parser.add_argument("--cache_dir", type=str, default="data")
    # NOTE(review): --J appears unused in this file.
    parser.add_argument("--J", type=str, default="1")

    # Path to the plain-text file to evaluate on (same format as training).
    parser.add_argument(
        "--eval_file",
        type=str,
        default="shakespeare_test_conv.txt",
        help="Plain text file to evaluate on (e.g. test set).",
    )

    args = parser.parse_args()

    # Floor division only raises for batch_size == 0; a perturbed_points_nr
    # smaller than batch_size silently gives 0 batches, so check both explicitly.
    if args.batch_size <= 0:
        print("args.batch_size should be a positive integer")
        return
    args.no_batches = args.perturbed_points_nr // args.batch_size
    if args.no_batches < 1:
        print(
            "args.perturbed_points_nr should be at least args.batch_size"
        )
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    create_directory_if_not_exists(args.model_path + "_eval/")

    # Tokenizer and dataset are built the same way as in training.
    tokenizer = build_char_tokenizer()  # uses / creates char_vocab.json
    train_set = get_char_dataset(
        file_path=args.eval_file,
        cache_dir=args.cache_dir,
        block_size=args.length,
        tokenizer=tokenizer,
    )

    print(f"The size of the dataset: {len(train_set)}")
    if len(train_set) == 0:
        # A shuffled DataLoader rejects an empty dataset, and an empty loader
        # would make cycle_loader loop forever without yielding.
        print(
            f"No complete blocks of length {args.length} in '{args.eval_file}'; nothing to evaluate."
        )
        return

    train_loader = cycle_loader(
        DataLoader(
            train_set,
            batch_size=args.batch_size,
            shuffle=True,  # matches the training setup
            num_workers=4,
            pin_memory=device.type == "cuda",  # pinning only helps host-to-GPU copies
            persistent_workers=True,
        )
    )
    train_iter = iter(train_loader)

    with torch.no_grad():
        model, graph, noise = load_model(args.model_path, device)
        J1(args, train_iter, device, noise, graph, model)


if __name__ == "__main__":
    # Run the J1 evaluation only when executed as a script, not when imported.
    main()
