import numpy as np
from collections import Counter
from datasets import load_dataset
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import os
import re
import random
from tqdm import tqdm
from models.ae import AutoEncoder
from models.vae import VAE
from models.paretovae import ParetoVAE
from models.laplacevae import LaplaceVAE
from models.t3vae import T3VAE
import torch.optim as optim
import wandb


def set_random_seed(seed=42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU, plus all CUDA devices
    when CUDA is available), then switches cuDNN to deterministic mode with
    autotuning (benchmark) disabled.

    Args:
        seed: Integer seed shared by all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Trade cuDNN speed for run-to-run repeatability.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


class BowDataset(Dataset):
    """Torch ``Dataset`` view over a matrix of bag-of-words vectors.

    Each item is one row of ``bow_vectors`` converted with
    ``torch.from_numpy``, so the returned tensor shares memory with the
    underlying NumPy row.
    """

    def __init__(self, bow_vectors):
        # Keep a reference only; nothing is copied here.
        self.bow_vectors = bow_vectors

    def __len__(self):
        """Return the number of stored vectors."""
        return len(self.bow_vectors)

    def __getitem__(self, idx):
        """Return the vector at position ``idx`` as a tensor."""
        row = self.bow_vectors[idx]
        return torch.from_numpy(row)


class WordFrequencyAnalysis:

    def __init__(
        self,
        latent_dim=128,
        nu=3.1,
        dataset="wikitext-2-raw-v1",
        normalization="batchnorm",
        optimizer="adam",
        use_wandb=True,
        min_freq=1,
        tail_freq=10,
        loss_type="mse",
        epochs=10,
        batch_size=512,
        learning_rate=1e-3,
        seed=42,
        model_name=None,
    ):
        """Store the configuration, load WikiText and prepare the BoW data.

        Construction is eager: the dataset is loaded, the vocabulary is built
        and the dense bag-of-words matrix is created before this returns.

        Args:
            latent_dim: Latent size forwarded to the model constructor.
            nu: ``nu`` value forwarded to the model constructor.
            dataset: Configuration name passed to
                ``load_dataset("wikitext", dataset)``.
            normalization: Normalization name forwarded to the model.
            optimizer: Key selecting the optimizer in ``train()``.
            use_wandb: Whether metrics are logged to wandb.
            min_freq: Minimum corpus count for a word to enter the vocabulary.
            tail_freq: Maximum corpus count for a word to be counted as rare.
            loss_type: Forwarded to the model as ``reconstruction``.
            epochs: Number of training epochs.
            batch_size: Batch size for training and evaluation.
            learning_rate: Optimizer learning rate.
            seed: Recorded in the wandb config and in the results path; this
                class does not seed any generator itself.
            model_name: Key selecting the model class in ``train()``.
        """
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        # Run configuration.
        self.latent_dim = latent_dim
        self.nu = nu
        self.dataset_name = dataset
        self.optimizer = optimizer
        self.use_wandb = use_wandb
        self.min_freq = min_freq
        self.tail_freq = tail_freq
        self.loss_type = loss_type
        self.normalization = normalization
        self.epochs = epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.seed = seed
        self.model_name = model_name

        # Best metrics seen so far; updated by analyze_and_visualize().
        self.best_head_jaccard = 0
        self.best_tail_jaccard = 0
        self.best_head_overlap = 0
        self.best_tail_overlap = 0
        self.best_epoch = 0

        # exist_ok avoids the check-then-create race of exists() + makedirs()
        # when several runs start at the same time.
        os.makedirs("results", exist_ok=True)

        self.dataset = load_dataset("wikitext", dataset)
        self.build_vocab()  # Must run before create_bow_dataset().
        self.create_bow_dataset()

    def build_vocab(self):
        """Build the vocabulary and word/index maps from the training split.

        Sets:
            texts: Non-blank lines of ``self.dataset["train"]["text"]``.
            vocab: Words whose corpus count is at least ``min_freq``, in
                first-seen order.
            word_to_idx / idx_to_word: Mappings between words and their
                position in ``vocab``.
            tail_count: Number of words whose count lies in
                ``[min_freq, tail_freq]`` (the "rare" words).
            head_count: Set equal to ``tail_count`` so the head and tail
                comparisons use the same number of words.
        """
        self.texts = [
            text for text in self.dataset["train"]["text"] if text.strip()
        ]

        word_counts = Counter()
        for text in tqdm(
            self.texts,
            desc="Building vocab from all texts",
        ):
            word_counts.update(self.tokenize(text))

        self.vocab = [
            word
            for word, count in word_counts.items()
            if count >= self.min_freq
        ]

        self.word_to_idx = {word: i for i, word in enumerate(self.vocab)}
        self.idx_to_word = dict(enumerate(self.vocab))

        self.tail_count = sum(
            1
            for count in word_counts.values()
            if self.min_freq <= count <= self.tail_freq
        )
        self.head_count = self.tail_count

        # The previous message interleaved the thresholds and the count in a
        # way that could not be read; state each number with its meaning.
        print(
            f"Vocabulary size: {len(self.vocab)} "
            f"(frequency >= {self.min_freq}), including "
            f"{self.tail_count} rare words "
            f"(frequency <= {self.tail_freq})"
        )

    def tokenize(self, text):
        """Split ``text`` into lowercase tokens made only of the letters a-z.

        The text is lowercased and every maximal run of ``a``-``z`` becomes
        one token; any other character (digits, punctuation, whitespace,
        non-ASCII letters) acts as a separator.

        Args:
            text: Input string.

        Returns:
            List of tokens in order of appearance; empty if there are none.
        """
        # A single scan for [a-z]+ yields the same tokens as replacing
        # non-letters with spaces, collapsing whitespace and then matching
        # \b\w+\b, but without building two intermediate strings.
        return re.findall(r"[a-z]+", text.lower())

    def create_bow_dataset(self):
        """Build the dense document-by-vocabulary count matrix.

        Fills ``self.bow_vectors`` (float32, one row per entry of
        ``self.texts``, one column per vocabulary word) with the number of
        times each vocabulary word occurs in each document. Tokens that are
        not in the vocabulary are ignored.
        """
        documents = self.texts
        print(len(documents), "texts found (using same texts as vocab).")

        self.bow_vectors = np.zeros(
            (len(documents), len(self.vocab)), dtype=np.float32
        )

        column_of = self.word_to_idx
        progress = tqdm(documents, desc="Creating BoW vectors")
        for row, document in enumerate(progress):
            for token in self.tokenize(document):
                column = column_of.get(token)
                if column is None:
                    # Word was filtered out of the vocabulary.
                    continue
                self.bow_vectors[row, column] += 1

        print("BoW dataset created.")

    def train(self):
        """Train the selected model on the BoW matrix, evaluating every epoch.

        For each epoch the loss is ``recon_loss + w * reg_loss`` with
        ``w = 2 * epoch / self.epochs`` (a linear ramp that reaches 2 in the
        last epoch). After each epoch the averaged losses are printed,
        optionally logged to wandb, and ``analyze_and_visualize`` is called.

        Training stops early when
        * a NaN loss is encountered (the method returns immediately), or
        * more than 50 epochs have passed since ``self.best_epoch`` was last
          updated.

        Raises:
            ValueError: If ``self.model_name`` or ``self.optimizer`` is not
                one of the supported keys. Both are checked before the wandb
                run is started.
        """
        model_classes = {
            "vae": VAE,
            "pareto": ParetoVAE,
            "laplace": LaplaceVAE,
            "t3": T3VAE,
            "ae": AutoEncoder,
        }
        # Optimizer class plus the extra keyword arguments for each key.
        optimizer_specs = {
            "adam-wd": (optim.Adam, {"weight_decay": 1e-5}),
            "adamw-wd": (optim.AdamW, {"weight_decay": 1e-5}),
            "adam": (optim.Adam, {}),
            "adamw": (optim.AdamW, {}),
            "rmsprop": (optim.RMSprop, {}),
            "nadam": (optim.NAdam, {}),
            "radam": (optim.RAdam, {}),
        }

        # An unknown model name (including the constructor default of None)
        # used to surface as an UnboundLocalError further down.
        if self.model_name not in model_classes:
            raise ValueError(
                f"Unsupported model name: {self.model_name!r}; "
                f"expected one of {sorted(model_classes)}"
            )
        if self.optimizer not in optimizer_specs:
            raise ValueError("Unsupported optimizer type")

        if self.use_wandb:
            wandb.init(
                project="iclr2026",
                group="word_frequency",
                config={
                    "model": self.model_name,
                    "loss_type": self.loss_type,
                    "epochs": self.epochs,
                    "batch_size": self.batch_size,
                    "optimizer": self.optimizer,
                    "learning_rate": self.learning_rate,
                    "latent_dim": self.latent_dim,
                    "dataset": self.dataset_name,
                    "min_freq": self.min_freq,
                    "vocab_size": len(self.vocab),
                    "nu": self.nu,
                    "normalization": self.normalization,
                    "tail_freq": self.tail_freq,
                    "seed": self.seed,
                },
            )

        def abort_on_nan(nan_epoch, message):
            """Report a NaN loss and close the wandb run before returning."""
            print(message)
            if self.use_wandb:
                wandb.log({"nan_detected": True, "epoch_nan": nan_epoch})
                wandb.finish()

        dataloader = DataLoader(
            BowDataset(self.bow_vectors),
            batch_size=self.batch_size,
            shuffle=True,
        )

        self.model = model_classes[self.model_name](
            nu=self.nu,
            input_shape=[len(self.vocab)],
            latent_dim=self.latent_dim,
            reconstruction=self.loss_type,
            mode="mlp",
            normalization=self.normalization,
            activation="relu",
            device=self.device,
        ).to(self.device)

        optimizer_cls, optimizer_kwargs = optimizer_specs[self.optimizer]
        optimizer = optimizer_cls(
            self.model.parameters(),
            lr=self.learning_rate,
            **optimizer_kwargs,
        )

        for epoch in range(1, self.epochs + 1):
            if self.best_epoch > 0 and self.best_epoch + 50 < epoch:
                print(
                    f"No improvement for 50 epochs since epoch "
                    f"{self.best_epoch}. Stopping training."
                )
                break

            # analyze_and_visualize() puts the model in eval mode at the end
            # of every epoch; switch back or later epochs train in eval mode.
            self.model.train()

            # Constant within an epoch, so compute it once.
            regularization_weight = 2 * epoch / self.epochs
            total_loss = 0.0
            total_recon_loss = 0.0
            total_reg_loss = 0.0

            for data in tqdm(dataloader, desc=f"Epoch {epoch}/{self.epochs}"):
                data = data.to(self.device)
                optimizer.zero_grad()
                z, recon, derivatives = self.model(data)
                recon_loss = self.model.reconstruction(recon, data)
                reg_loss = self.model.regularization(derivatives)
                loss = recon_loss + regularization_weight * reg_loss

                if torch.isnan(loss).any():
                    abort_on_nan(
                        epoch,
                        f"NaN loss detected at epoch {epoch}. "
                        "Stopping training.",
                    )
                    return

                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                total_recon_loss += recon_loss.item()
                total_reg_loss += reg_loss.item()

            num_batches = len(dataloader)
            avg_loss = total_loss / num_batches
            avg_recon_loss = total_recon_loss / num_batches
            avg_reg_loss = total_reg_loss / num_batches

            # Accumulated sums can still become NaN (e.g. inf + -inf) even
            # when no single batch loss was NaN.
            if np.isnan(avg_recon_loss):
                abort_on_nan(
                    epoch,
                    f"NaN in reconstruction losses at epoch {epoch}. "
                    "Stopping training.",
                )
                return
            if np.isnan(avg_reg_loss):
                abort_on_nan(
                    epoch,
                    f"NaN in regularization losses at epoch {epoch}. "
                    "Stopping training.",
                )
                return

            # avg_loss already contains the weighted regularization term, so
            # it is reported as-is (adding the term again double-counted it).
            print(
                f"Epoch [{epoch}/{self.epochs}]",
                f"Recon Loss: {avg_recon_loss:.4f}",
                f"Reg Loss: {avg_reg_loss:.4f}",
                f"Total Loss: {avg_loss:.4f}",
                f"Reg Weight: {regularization_weight:.4f}",
            )

            if self.use_wandb:
                wandb.log(
                    {
                        "epoch": epoch,
                        "total_loss": avg_loss,
                        "reconstruction_loss": avg_recon_loss,
                        "regularization_loss": avg_reg_loss,
                        "reg_weight": regularization_weight,
                    }
                )

            self.analyze_and_visualize(
                epoch=epoch,
            )

        print("Training complete.")

        if self.use_wandb:
            wandb.finish()

    def analyze_and_visualize(self, epoch):
        """Reconstruct the whole corpus and compare word-frequency statistics.

        The model is run over ``self.bow_vectors`` in batches without
        gradients. The reconstructions are summed per vocabulary word and
        compared with the original per-word totals through the Zipf and
        histogram plots and the head/tail overlap metrics. Metrics are logged
        to wandb when enabled. Whenever any of the four tracked "best" values
        improves, ``self.best_epoch`` is set to ``epoch`` and the current
        metrics are logged under the ``best/`` prefix.

        Args:
            epoch: Current epoch number; used in file names and as the wandb
                step.
        """
        # Assumes self.model is a torch.nn.Module (it is already used through
        # .eval(), .to() and .parameters()). Remember the mode so evaluation
        # does not leave the model in eval mode for the caller.
        was_training = self.model.training
        self.model.eval()

        reconstructed_bow = np.zeros_like(self.bow_vectors)
        num_docs = len(self.bow_vectors)

        try:
            with torch.no_grad():
                for start_idx in tqdm(range(0, num_docs, self.batch_size)):
                    end_idx = min(start_idx + self.batch_size, num_docs)

                    batch_bow = torch.from_numpy(
                        self.bow_vectors[start_idx:end_idx]
                    ).to(self.device)

                    _, batch_recon, _ = self.model(batch_bow)

                    reconstructed_bow[start_idx:end_idx] = (
                        batch_recon.cpu().numpy()
                    )
        finally:
            self.model.train(was_training)

        # Corpus-level frequency of every vocabulary word.
        original_freq = self.bow_vectors.sum(axis=0)
        reconstructed_freq = reconstructed_bow.sum(axis=0)

        self._save_individual_plots(epoch, original_freq, reconstructed_freq)

        head_k_metrics = self.compare_head_k_words(
            epoch,
            original_freq,
            reconstructed_freq,
        )

        tail_k_metrics = self.compare_tail_k_words(
            epoch,
            original_freq,
            reconstructed_freq,
        )

        if self.use_wandb:
            head_key = f"head_{self.head_count}"
            tail_key = f"tail_{self.tail_count}"
            wandb.log(
                {
                    f"{head_key}_overlap_ratio": head_k_metrics[
                        "overlap_ratio"
                    ],
                    f"{head_key}_jaccard_similarity": head_k_metrics[
                        "jaccard_sim"
                    ],
                    f"{tail_key}_overlap_ratio": tail_k_metrics[
                        "overlap_ratio"
                    ],
                    f"{tail_key}_jaccard_similarity": tail_k_metrics[
                        "jaccard_sim"
                    ],
                    f"{tail_key}_p_value": tail_k_metrics["p_value"],
                    f"{head_key}_p_value": head_k_metrics["p_value"],
                },
                step=epoch,
            )

        # Each tracked best value is compared independently; an improvement
        # in any of them marks this epoch as the best one so far.
        tracked = (
            ("best_head_overlap", head_k_metrics["overlap_ratio"]),
            ("best_head_jaccard", head_k_metrics["jaccard_sim"]),
            ("best_tail_overlap", tail_k_metrics["overlap_ratio"]),
            ("best_tail_jaccard", tail_k_metrics["jaccard_sim"]),
        )
        improved = False
        for attr_name, value in tracked:
            if getattr(self, attr_name) < value:
                setattr(self, attr_name, value)
                improved = True

        if improved:
            self.best_epoch = epoch
            if self.use_wandb:
                # Logged once per epoch; the payload is the same regardless
                # of which metric improved.
                wandb.log(
                    {
                        "best/head_overlap_ratio": head_k_metrics[
                            "overlap_ratio"
                        ],
                        "best/head_jaccard_similarity": head_k_metrics[
                            "jaccard_sim"
                        ],
                        "best/tail_overlap_ratio": tail_k_metrics[
                            "overlap_ratio"
                        ],
                        "best/tail_jaccard_similarity": tail_k_metrics[
                            "jaccard_sim"
                        ],
                    },
                    step=epoch,
                )

    def plot_zipf_distribution(self, epoch, original_freq, reconstructed_freq):
        """Save a 2x2 overview figure comparing the two frequency vectors.

        Panels: rank-frequency curve on log-log axes (top-left), histogram of
        positive frequencies (top-right), the first 100 ranks on linear axes
        (bottom-left) and an original-vs-reconstructed scatter on log-log
        axes (bottom-right). The figure is written to
        ``<results_dir>/zipf_distribution_<epoch>.png``.

        Panels that need positive frequencies are left empty when there are
        none, instead of failing on ``np.max`` of an empty array.

        Args:
            epoch: Epoch number used in the title and file name.
            original_freq: Per-word frequency array of the original corpus.
            reconstructed_freq: Per-word frequency array of the
                reconstruction, indexed like ``original_freq``.
        """
        fig, axes = plt.subplots(2, 2, figsize=(16, 12))

        sorted_original_freq = np.sort(original_freq)[::-1]
        sorted_reconstructed_freq = np.sort(reconstructed_freq)[::-1]

        # Top-left: rank-frequency curves.
        ax1 = axes[0, 0]
        for label, ranked in (
            ("Original", sorted_original_freq),
            ("Reconstructed", sorted_reconstructed_freq),
        ):
            ax1.plot(
                range(1, len(ranked) + 1),
                ranked,
                label=label,
                alpha=0.7,
                linewidth=2,
            )
        ax1.set_xscale("log")
        ax1.set_yscale("log")
        ax1.set_title("Zipf's Law Distribution (Log-Log Scale)")
        ax1.set_xlabel("Rank")
        ax1.set_ylabel("Frequency")
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Top-right: histogram of strictly positive frequencies.
        ax2 = axes[0, 1]
        positive_series = [
            ("Original", original_freq[original_freq > 0]),
            ("Reconstructed", reconstructed_freq[reconstructed_freq > 0]),
        ]
        # np.max raises on an empty array, so drop series without data.
        positive_series = [
            (label, values) for label, values in positive_series if values.size
        ]
        if positive_series:
            max_freq = max(np.max(values) for _, values in positive_series)
            bins = np.logspace(0, np.log10(max_freq + 1), 75)
            for label, values in positive_series:
                ax2.hist(
                    values,
                    bins=bins,
                    alpha=0.6,
                    label=label,
                    density=True,
                )
        ax2.set_xscale("log")
        ax2.set_title("Frequency Distribution Histogram")
        ax2.set_xlabel("Word Frequency")
        ax2.set_ylabel("Density")
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # Bottom-left: the highest-ranked words on linear axes.
        ax3 = axes[1, 0]
        head_n = min(100, len(sorted_original_freq))
        x_pos = range(1, head_n + 1)
        ax3.plot(
            x_pos,
            sorted_original_freq[:head_n],
            "o-",
            label="Original",
            alpha=0.7,
            markersize=3,
        )
        ax3.plot(
            x_pos,
            sorted_reconstructed_freq[:head_n],
            "s-",
            label="Reconstructed",
            alpha=0.7,
            markersize=3,
        )
        ax3.set_title(f"head {head_n} Words - Linear Scale")
        ax3.set_xlabel("Rank")
        ax3.set_ylabel("Frequency")
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # Bottom-right: per-word scatter, restricted to words that are
        # positive in both vectors so they can be shown on log axes.
        ax4 = axes[1, 1]
        mask = (original_freq > 0) & (reconstructed_freq > 0)
        if mask.any():
            orig_scatter = original_freq[mask]
            recon_scatter = reconstructed_freq[mask]
            ax4.scatter(orig_scatter, recon_scatter, alpha=0.6, s=20)

            max_val = max(np.max(orig_scatter), np.max(recon_scatter))
            ax4.plot(
                [0, max_val],
                [0, max_val],
                "r--",
                alpha=0.8,
                label="Perfect Correlation",
            )
        ax4.set_xscale("log")
        ax4.set_yscale("log")
        ax4.set_title("Original vs Reconstructed Frequencies")
        ax4.set_xlabel("Original Frequency")
        ax4.set_ylabel("Reconstructed Frequency")
        ax4.legend()
        ax4.grid(True, alpha=0.3)

        plt.tight_layout()

        fig.suptitle(
            f"Word Frequency Analysis - Epoch {epoch} ({self.model.name})",
            fontsize=16,
            y=0.98,
        )

        results_dir = self.get_results_dir()
        os.makedirs(results_dir, exist_ok=True)
        fig.savefig(
            f"{results_dir}/zipf_distribution_{epoch}.png",
            dpi=300,
            bbox_inches="tight",
        )
        plt.close(fig)

    def _save_individual_plots(self, epoch, original_freq, reconstructed_freq):
        """Save the Zipf curve and the frequency histogram as separate files.

        Writes ``zipf_only_<epoch>.png`` and ``histogram_<epoch>.png`` into
        the results directory (created if missing).

        The histogram only draws series that contain at least one positive
        frequency. Previously an all-non-positive reconstruction made
        ``np.max`` fail on an empty array and aborted the training run.

        Args:
            epoch: Epoch number used in titles and file names.
            original_freq: Per-word frequency array of the original corpus.
            reconstructed_freq: Per-word frequency array of the
                reconstruction, indexed like ``original_freq``.
        """
        results_dir = self.get_results_dir()
        os.makedirs(results_dir, exist_ok=True)

        series = (
            ("Original", original_freq),
            ("Reconstructed", reconstructed_freq),
        )

        # --- Rank-frequency curve on log-log axes ---
        fig, ax = plt.subplots(figsize=(10, 6))
        for label, freq in series:
            ranked = np.sort(freq)[::-1]
            ax.plot(
                range(1, len(ranked) + 1),
                ranked,
                label=label,
                alpha=0.7,
                linewidth=2,
            )
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_title(f"Zipf's Law Distribution - Epoch {epoch}")
        ax.set_xlabel("Rank")
        ax.set_ylabel("Frequency")
        ax.legend()
        ax.grid(True, alpha=0.3)
        fig.savefig(
            f"{results_dir}/zipf_only_{epoch}.png",
            dpi=300,
            bbox_inches="tight",
        )
        plt.close(fig)

        # --- Histogram of strictly positive frequencies ---
        fig, ax = plt.subplots(figsize=(10, 6))
        positive_series = [
            (label, freq[freq > 0]) for label, freq in series
        ]
        # np.max raises on an empty array, so drop series without data.
        positive_series = [
            (label, values) for label, values in positive_series if values.size
        ]
        if positive_series:
            max_freq = max(np.max(values) for _, values in positive_series)
            bins = np.logspace(0, np.log10(max_freq + 1), 75)
            for label, values in positive_series:
                ax.hist(
                    values,
                    bins=bins,
                    alpha=0.6,
                    label=label,
                    density=True,
                )
        ax.set_xscale("log")
        ax.set_title(f"Word Frequency Histogram - Epoch {epoch}")
        ax.set_xlabel("Word Frequency")
        ax.set_ylabel("Density")
        ax.legend()
        ax.grid(True, alpha=0.3)
        fig.savefig(
            f"{results_dir}/histogram_{epoch}.png",
            dpi=300,
            bbox_inches="tight",
        )
        plt.close(fig)

    def compare_head_k_words(
        self, epoch, original_freq, reconstructed_freq, k=50
    ):
        """Compare the most frequent words of the original and reconstruction.

        The ``self.head_count`` highest-frequency words of each vector are
        compared by overlap ratio and Jaccard similarity, and their index
        sets are passed to ``mmd_linear_bootstrap_test``. Results are printed,
        appended to ``overlap_metrics.txt`` in the results directory, and a
        bar chart of the 30 most frequent words of each vector is saved as
        ``head_k_words_<epoch>.png``.

        Args:
            epoch: Epoch number used in the log entry and file name.
            original_freq: Per-word frequency array of the original corpus.
            reconstructed_freq: Per-word frequency array of the
                reconstruction, indexed like ``original_freq``.
            k: Unused; kept for backward compatibility. The number of words
                compared is ``self.head_count``.

        Returns:
            Dict with ``overlap_ratio``, ``p_value`` (``None`` when the test
            has too few samples), ``jaccard_sim`` and ``common_words_count``.
            Both ratios are ``0.0`` when ``self.head_count`` is 0.
        """
        # Word indices from most to least frequent; computed once and reused
        # for both the metrics and the plot.
        original_order = np.argsort(original_freq)[::-1]
        reconstructed_order = np.argsort(reconstructed_freq)[::-1]

        original_head_k_words = [
            self.idx_to_word[i] for i in original_order[: self.head_count]
        ]
        reconstructed_head_k_words = [
            self.idx_to_word[i] for i in reconstructed_order[: self.head_count]
        ]

        original_set = set(original_head_k_words)
        reconstructed_set = set(reconstructed_head_k_words)
        common_words = original_set & reconstructed_set
        union_size = len(original_set | reconstructed_set)

        # head_count is 0 when no word falls in [min_freq, tail_freq];
        # report zero similarity instead of dividing by zero.
        overlap_ratio = (
            len(common_words) / self.head_count if self.head_count else 0.0
        )
        jaccard_sim = len(common_words) / union_size if union_size else 0.0

        print(
            f"Jaccard similarity (head {self.head_count}): {jaccard_sim:.4f}"
        )
        print(
            f"  head {self.head_count} overlap: {len(common_words)}/{self.head_count} "
            f"({overlap_ratio:.2%})"
        )

        _, p_value, _ = self.mmd_linear_bootstrap_test(
            reconstructed_head_k_words, original_head_k_words
        )

        # Append the overlap metrics to the per-run log file.
        results_dir = self.get_results_dir()
        log_file = f"{results_dir}/overlap_metrics.txt"
        os.makedirs(results_dir, exist_ok=True)
        with open(log_file, "a") as f:
            f.write(f"head {self.head_count} words (Epoch {epoch}):\n")
            f.write(
                f"  head {self.head_count} overlap: {len(common_words)}/{self.head_count} "
                f"({overlap_ratio:.2%})\n"
            )
            f.write(f"  Jaccard similarity: {jaccard_sim:.4f}\n")
            f.write(f"  Common words count: {len(common_words)}\n")
            f.write(f"  Common words: {sorted(common_words)[:20]}...\n")

        # Bar charts are limited to 30 words to stay readable.
        original_plot_indices = original_order[:30]
        reconstructed_plot_indices = reconstructed_order[:30]
        original_plot_words = [
            self.idx_to_word[i] for i in original_plot_indices
        ]
        reconstructed_plot_words = [
            self.idx_to_word[i] for i in reconstructed_plot_indices
        ]

        plt.figure(figsize=(12, 10))

        # Titles state the number of bars actually drawn; they used to claim
        # head_count words although at most 30 are plotted.
        plt.subplot(2, 1, 1)
        plt.bar(
            original_plot_words,
            original_freq[original_plot_indices],
        )
        plt.title(
            f"head {len(original_plot_words)} Word Frequencies (Original)"
        )
        plt.xticks(rotation=45)

        plt.subplot(2, 1, 2)
        plt.bar(
            reconstructed_plot_words,
            reconstructed_freq[reconstructed_plot_indices],
        )
        plt.title(
            f"head {len(reconstructed_plot_words)} Word Frequencies "
            "(Reconstructed)"
        )
        plt.xticks(rotation=45)

        plt.tight_layout()
        plt.savefig(f"{results_dir}/head_k_words_{epoch}.png")
        plt.close()

        return {
            "overlap_ratio": overlap_ratio,
            "p_value": p_value,
            "jaccard_sim": jaccard_sim,
            "common_words_count": len(common_words),
        }

    def compare_tail_k_words(
        self,
        epoch,
        original_freq,
        reconstructed_freq,
    ):
        """
        Compare tail self.tail_count words between original and reconstructed distributions
        """
        # Get tail self.tail_count words (non-zero frequencies only)
        non_zero_original_indices = np.where(original_freq > 0)[0]
        non_zero_recon_indices = np.where(reconstructed_freq > 0)[0]

        if len(non_zero_original_indices) > self.tail_count:
            original_tail_k_indices = non_zero_original_indices[
                np.argsort(original_freq[non_zero_original_indices])[
                    : self.tail_count
                ]
            ]
        else:
            original_tail_k_indices = np.argsort(original_freq)[
                : self.tail_count
            ]

        if len(non_zero_recon_indices) > self.tail_count:
            recon_tail_k_indices = non_zero_recon_indices[
                np.argsort(reconstructed_freq[non_zero_recon_indices])[
                    : self.tail_count
                ]
            ]
        else:
            recon_tail_k_indices = np.argsort(reconstructed_freq)[
                : self.tail_count
            ]

        original_tail_k_words = [
            self.idx_to_word[i] for i in original_tail_k_indices
        ]
        recon_tail_k_words = [
            self.idx_to_word[i] for i in recon_tail_k_indices
        ]

        # Calculate overlap
        common_words = set(original_tail_k_words) & set(recon_tail_k_words)
        overlap_ratio = len(common_words) / self.tail_count

        # Additional overlap analysis
        original_set = set(original_tail_k_words)
        recon_set = set(recon_tail_k_words)

        # Jaccard similarity
        union_size = len(original_set.union(recon_set))
        jaccard_sim = len(common_words) / union_size
        print(
            f"Jaccard similarity (tail {self.tail_count}): {jaccard_sim:.4f}"
        )
        print(
            f"  tail {self.tail_count} overlap: {len(common_words)}/{self.tail_count} "
            f"({overlap_ratio:.2%})"
        )

        # Log overlap metrics to file
        results_dir = self.get_results_dir()
        log_file = f"{results_dir}/overlap_metrics.txt"
        os.makedirs(results_dir, exist_ok=True)
        with open(log_file, "a") as f:
            f.write(f"tail {self.tail_count} words (Epoch {epoch}):\n")
            f.write(
                f"  tail {self.tail_count} overlap: {len(common_words)}/{self.tail_count} "
                f"({overlap_ratio:.2%})\n"
            )
            f.write(f"  Jaccard similarity: {jaccard_sim:.4f}\n")
            f.write(f"  Common words count: {len(common_words)}\n")
            f.write(f"  Common words: {sorted(list(common_words))[:20]}...\n")

        mmd_stat, p_value, mmd_bootstrap_list = self.mmd_linear_bootstrap_test(
            recon_tail_k_words, original_tail_k_words
        )

        return {
            "overlap_ratio": overlap_ratio,
            "p_value": p_value,
            "jaccard_sim": jaccard_sim,
            "common_words_count": len(common_words),
        }

    def get_results_dir(self):
        results_dir = (
            f"results/word_frequency/"
            f"{self.model.name}/{self.loss_type}/{self.learning_rate}/"
            f"{self.latent_dim}/{self.batch_size}/"
            f"{self.min_freq}_{self.tail_freq}/"
            f"{self.head_count}/{self.seed}"
        )
        return results_dir

    def mmd_linear(self, z_hat, z, sigma2_k=None):
        n = min([int(z.shape[0] / 2), int(z_hat.shape[0] / 2)])
        z_hat_1 = z_hat[0:n]
        z_hat_2 = z_hat[n : 2 * n]

        z_1 = z[0:n]
        z_2 = z[n : 2 * n]

        term_1 = (z_hat_1 - z_hat_2).pow(2).sum(1)
        term_2 = (z_1 - z_2).pow(2).sum(1)
        term_3 = (z_hat_1 - z_2).pow(2).sum(1)
        term_4 = (z_hat_2 - z_1).pow(2).sum(1)

        if sigma2_k is None:
            sigma2_k = torch.cat([term_1, term_2, term_3, term_4]).topk(2 * n)[
                0
            ][-1]

        res1 = torch.mean(torch.exp(-term_1 / 2.0 / sigma2_k))
        res2 = torch.mean(torch.exp(-term_2 / 2.0 / sigma2_k))
        res3 = torch.mean(torch.exp(-term_3 / 2.0 / sigma2_k))
        res4 = torch.mean(torch.exp(-term_4 / 2.0 / sigma2_k))
        return res1 + res2 - res3 - res4

    def make_masking(self, n):
        indice = np.arange(0, 2 * n)
        mask = np.zeros(2 * n, dtype=bool)
        rand_indice = np.random.choice(2 * n, n, replace=False)
        mask[rand_indice] = True

        return indice[mask], indice[~mask]

    def mmd_linear_bootstrap_test(self, z_hat, z, sigma2=None, iteration=1999):
        z = (
            torch.tensor(
                [self.word_to_idx[word] for word in z],
                device=self.device,
            )
            .float()
            .view(-1, 1)
        )

        z_hat = (
            torch.tensor(
                [self.word_to_idx[word] for word in z_hat],
                device=self.device,
            )
            .float()
            .view(-1, 1)
        )

        n = min([int(z.shape[0] / 2), int(z_hat.shape[0] / 2)])
        # n = min([int(len(z) / 2), int(len(z_hat) / 2)])
        if n == 0:
            print(
                "There is no such a sample. It may be due to an insufficient training. "
            )
            return None, None, None

        z_hat_1 = z_hat[0:n]
        z_hat_2 = z_hat[n : 2 * n]

        z_1 = z[0:n]
        z_2 = z[n : 2 * n]

        term_1 = (z_hat_1 - z_hat_2).pow(2).sum(1)
        term_2 = (z_1 - z_2).pow(2).sum(1)
        term_3 = (z_hat_1 - z_2).pow(2).sum(1)
        term_4 = (z_hat_2 - z_1).pow(2).sum(1)

        if sigma2 is None:
            sigma2 = torch.cat([term_1, term_2, term_3, term_4]).topk(2 * n)[
                0
            ][-1]

        mmd_stat = self.mmd_linear(z_hat, z, sigma2).item()
        mmd_bootstrap_list = []
        full_data = torch.cat([z_hat[0 : 2 * n], z[0 : 2 * n]], dim=0)

        for _ in range(iteration):
            ind_1, ind_2 = self.make_masking(2 * n)
            z_1 = full_data[ind_1]
            z_2 = full_data[ind_2]
            mmd_bootstrap_list.append(self.mmd_linear(z_1, z_2, sigma2).item())

        sum([int(stat > mmd_stat) for stat in mmd_bootstrap_list])
        p_value = (
            1 + sum([int(stat > mmd_stat) for stat in mmd_bootstrap_list])
        ) / (1 + iteration)

        return mmd_stat, p_value, mmd_bootstrap_list


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Word Frequency Analysis with Various VAE Models"
    )

    # Command-line options as (flag, add_argument keyword arguments), kept in
    # the order they should appear in --help.
    cli_options = [
        (
            "--model-name",
            dict(
                type=str,
                default="ae",
                choices=["ae", "vae", "t3", "pareto", "laplace"],
                help="Model type to use",
            ),
        ),
        (
            "--loss-type",
            dict(
                type=str,
                default="l2",
                choices=["l1", "l2", "mae", "mse"],
                help="Reconstruction loss type",
            ),
        ),
        (
            "--latent-dim",
            dict(type=int, default=128, help="Latent dimension size"),
        ),
        (
            "--nu",
            dict(
                type=float,
                default=3.1,
                help="Degrees of freedom for ParetoVAE",
            ),
        ),
        (
            "--epochs",
            dict(type=int, default=50, help="Number of training epochs"),
        ),
        (
            "--batch-size",
            dict(type=int, default=512, help="Batch size for training"),
        ),
        (
            "--learning-rate",
            dict(
                type=float,
                default=1e-3,
                help="Learning rate for optimizer",
            ),
        ),
        (
            "--normalization",
            dict(
                type=str,
                default="batchnorm",
                choices=[
                    "batchnorm",
                    "layernorm",
                    "instancenorm",
                    "groupnorm",
                    "none",
                ],
                help="Normalization technique to use",
            ),
        ),
        (
            "--dataset",
            dict(
                type=str,
                default="wikitext-2-raw-v1",
                choices=["wikitext-2-raw-v1", "wikitext-103-raw-v1"],
                help="WikiText dataset version to use",
            ),
        ),
        (
            "--min-freq",
            dict(
                type=int,
                default=1,
                help="Minimum frequency for words to include in vocabulary",
            ),
        ),
        (
            "--tail-freq",
            dict(
                type=int,
                default=10,
                help="Frequency threshold to consider a word as tail",
            ),
        ),
        (
            "--optimizer",
            dict(
                type=str,
                default="adam",
                choices=[
                    "adam",
                    "adamw",
                    "adam-wd",
                    "adamw-wd",
                    "rmsprop",
                    "nadam",
                    "radam",
                ],
                help="Optimizer to use",
            ),
        ),
        (
            "--seed",
            dict(
                type=int,
                default=42,
                help="Random seed for reproducibility",
            ),
        ),
        (
            "--no-wandb",
            dict(action="store_true", help="Disable wandb logging"),
        ),
    ]
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Seed before constructing the analysis so data preparation and model
    # initialisation both run under the requested seed.
    set_random_seed(args.seed)

    analysis = WordFrequencyAnalysis(
        model_name=args.model_name,
        loss_type=args.loss_type,
        latent_dim=args.latent_dim,
        nu=args.nu,
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        normalization=args.normalization,
        dataset=args.dataset,
        min_freq=args.min_freq,
        tail_freq=args.tail_freq,
        optimizer=args.optimizer,
        seed=args.seed,
        use_wandb=not args.no_wandb,
    )

    analysis.train()
