"""
Reconstruct texts from the original and perturbed features of a language model.
Please refer to the README for more details on how to run this script.
"""
from typing import Any, Optional, Callable
import argparse
import math
import os
import traceback
import json
import yaml
import tqdm
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch import optim
import wandb

from metamer.reconstruct.experiment_db import JSONExperimentDB
from metamer.icnn_replication.evaluation import (
    cosine_distance, correlation_distance, l2_distance
)

EPSILON = 1e-8  # added to the std product in _corr_distance to avoid division by zero
TOLERANCE = 0.0001  # default tolerance for the error in correlation distance (add_noise's `tol`)

def parse_output_dir(config):
    """
    Build the results directory for an experiment.

    The path is output/readout_language/results/<model>/<dataset>/<exp_name>,
    where <model> is config['model']['pretrained'] with '/' replaced by '_'.
    """
    path_parts = (
        'output',
        'readout_language',
        'results',
        config['model']['pretrained'].replace('/', '_'),
        config['data']['dataset_name'],
        config['exp_name'],
    )
    return os.path.join(*path_parts)


def parse_dtype(config):
    """
    Resolve the torch dtype named by config["dtype"].

    Accepted names are "bf16", "fp16" and "float32"; a missing entry means "float32".

    Raises:
        ValueError: if the name is not one of the accepted values.
    """
    requested = config.get("dtype", "float32")
    supported = (
        ("bf16", torch.bfloat16),
        ("fp16", torch.float16),
        ("float32", torch.float32),
    )
    for name, torch_dtype in supported:
        if requested == name:
            return torch_dtype
    raise ValueError(f"Unsupported dtype: {requested}")


def initialize_experiment_db(config, output_dir):
    """
    Create the JSON experiment database in ``output_dir`` and register every
    parameter combination from ``config``.

    One experiment is added per (layer_idx, text_name, target_corr_dist,
    noise_seed) combination, iterated in that nesting order. A
    ``target_corr_dist`` of 0.0 means "no noise", so it gets a single
    experiment with ``noise_seed=None`` instead of one per seed.

    Args:
        config: reads config['layer_indices'], config['data']['text_names'],
            config['noise']['target_corr_dists'] and, for non-zero distances,
            config['noise']['noise_seeds'].
        output_dir: directory that holds experiment_db.json and experiment_db.lock.

    Returns:
        The JSONExperimentDB instance with the experiments added.
    """
    experiment_db = JSONExperimentDB(
        exp_db_path=os.path.join(output_dir, 'experiment_db.json'),
        exp_db_lock_path=os.path.join(output_dir, 'experiment_db.lock'),
        param_clms=['layer_idx', 'text_name', 'target_corr_dist', 'noise_seed'],
        encoder_cls=FlatListEncoder,
    )

    new_experiments = []
    for layer_idx in config['layer_indices']:
        for text_name in config['data']['text_names']:
            for target_corr_dist in config['noise']['target_corr_dists']:
                # corr dist 0 == no noise == a single run without a noise seed
                if target_corr_dist == 0.0:
                    noise_seeds = [None]
                else:
                    noise_seeds = config['noise']['noise_seeds']
                for noise_seed in noise_seeds:
                    new_experiments.append({
                        'layer_idx': layer_idx,
                        'text_name': text_name,
                        'target_corr_dist': target_corr_dist,
                        'noise_seed': noise_seed,
                    })
    experiment_db.add_experiments(new_experiments)
    return experiment_db


def get_tokens(tokenizer, config, parameters: list[dict[str, Any]], max_length: int, device):
    """
    Obtain the original tokens and masks.

    Each sample's text is read (as UTF-8) from ``<config['data']['text_dir']>/<text_name>.txt``.

    Args:
        tokenizer: The tokenizer to use. It is called with padding='max_length',
            truncation=True and return_tensors='pt', and its ``all_special_ids``
            and ``batch_decode`` are used.
        config: Configuration dictionary; reads config['data']['text_dir'].
        parameters: A list of dictionaries with the parameters for each sample;
            each must contain 'text_name'.
        max_length: Maximum length of the tokenized sequences.
        device: Device the returned tensors are moved to.
    Returns:
        true_tokens (torch.Tensor): The original tokens. (batch, seq_len).
        attention_mask (torch.Tensor): The attention mask for the tokens. (batch, seq_len).
        normal_token_mask (torch.Tensor): True for tokens that are neither padding
            nor special tokens. (batch, seq_len).
        info (dict): 'true_texts' (decoded without special tokens), 'true_tokens'
            and 'normal_token_mask' as Python lists.
    Raises:
        FileNotFoundError: if a text file does not exist.
    """
    # load the original texts
    text_dir = config['data']['text_dir']
    original_texts = []
    for param in parameters:
        text_path = os.path.join(text_dir, f"{param['text_name']}.txt")
        if not os.path.exists(text_path):
            raise FileNotFoundError(f"Text file {text_path} does not exist.")
        # explicit encoding: the platform default (e.g. cp1252) would misread UTF-8 text
        with open(text_path, 'r', encoding='utf-8') as f:
            original_texts.append(f.read())

    # tokenize
    inputs = tokenizer(original_texts, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')
    true_tokens = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # create mask for normal tokens (not padding or special tokens)
    special_tokens = torch.tensor(
        tokenizer.all_special_ids, device=true_tokens.device, dtype=true_tokens.dtype
    )  # shape (n_special,)
    # Compare each token against all special tokens
    is_special_token = (true_tokens[..., None] == special_tokens).any(-1)  # shape (batch_size, seq_len)
    is_normal_token = ~is_special_token & attention_mask.bool()

    # decode the texts to get the tokenized texts
    decoded_texts = tokenizer.batch_decode(true_tokens, skip_special_tokens=True)
    info = {
        'true_texts': decoded_texts,
        'true_tokens': true_tokens.tolist(),
        'normal_token_mask': is_normal_token.tolist()
    }
    return true_tokens, attention_mask, is_normal_token, info


def _extract_embeddings(model: nn.Module, input_ids: torch.Tensor):
    return model.get_input_embeddings()(input_ids)


def _extract_features(
        model: nn.Module, 
        input_ids: Optional[torch.Tensor] = None, 
        input_embeds: Optional[torch.Tensor] = None, 
        attention_mask: Optional[torch.Tensor] = None, 
        layer_idx: Optional[int] = None,
    ) -> torch.Tensor:
    """
    Extract features from the model.

    Returns:
        features (torch.Tensor): The features extracted from the specified layer. (batch, seq_len, hidden_size)
    """
    if input_ids is not None:
        out = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
    elif input_embeds is not None:
        out = model(inputs_embeds=input_embeds, attention_mask=attention_mask, output_hidden_states=True)
    else:
        raise ValueError("Either input_ids or input_embeds must be provided.")
    # out: tuple[torch.Tensor]
    features = out.hidden_states[1:]  # remove the input embedding
    features = features[layer_idx]
    return features


def _corr_distance(a: torch.Tensor, b: torch.Tensor) -> float:
    """
    Return 1 - Pearson correlation of the flattened tensors ``a`` and ``b``.

    Population (biased) standard deviations are used, and EPSILON is added to
    their product to avoid division by zero.
    """
    def _centered(t: torch.Tensor) -> torch.Tensor:
        flat = t.flatten()
        return flat - flat.mean()

    a_c = _centered(a)
    b_c = _centered(b)
    covariance = (a_c * b_c).mean()
    scale = a_c.std(unbiased=False) * b_c.std(unbiased=False) + EPSILON
    correlation = (covariance / scale).item()
    return 1.0 - correlation


def add_noise(
    x: torch.Tensor,
    d_c: float,
    tol: float = TOLERANCE,
    max_iter: int = 30,
    seed: Optional[int] = None
):
    """
    Add Gaussian noise to the given feature x so that corr-distance(x, x+noise) ≈ d_c within `tol` error.

    A single standard-normal noise tensor is drawn once; only its scale is
    searched. The scale is first bracketed by repeated doubling, then found by
    bisection.

    Parameters
    ----------
    x : torch.Tensor
        Original tensor.
    d_c : float
        Desired correlation distance (0 ≤ d_c < 1).
    tol : float, optional
        Allowed absolute error |d_actual - d_c|.
    max_iter : int, optional
        Maximum iterations of each phase of the root-finder (bracketing and bisection).
    seed: int, optional
        If given, ``torch.manual_seed(seed)`` is called (this reseeds the global
        RNG) before the noise is drawn. If None, uses the current random state.

    Returns
    -------
    torch.Tensor
        x + std * noise whose correlation distance to x is within `tol` of d_c,
        or, if that is never reached, the closest value evaluated during the search.
    float
        Actual correlation distance of the returned tensor.
    float
        Scale of the noise vector used to perturb the original tensor.
    """
    if seed is not None:
        torch.manual_seed(seed)
    assert 0.0 <= d_c < 1.0, "0 <= d_c < 1 required."
    assert tol > 0.0, "tol must be positive."

    if d_c <= tol:  # no noise needed
        return x.clone(), 0.0, 0.0

    # Initial guess: for independent noise of std s, corr(x, x + noise) ≈ sd(x) / sqrt(var(x) + s^2),
    # so s = sd(x) * sqrt(1 / r^2 - 1) targets correlation r.
    var_x = x.var(unbiased=False)
    target_r = 1.0 - d_c                       # desired correlation
    initial_std = math.sqrt(var_x) * math.sqrt(1.0 / target_r**2 - 1.0)

    noise = torch.randn_like(x)
    best = None  # (|d - d_c|, std, d) for the closest distance evaluated so far

    def dist(std: float) -> float:
        nonlocal best
        d = _corr_distance(x, x + noise * std)
        err = abs(d - d_c)
        if best is None or err < best[0]:
            best = (err, std, d)
        return d

    # bracket the solution
    lo, hi = 0.0, initial_std
    if dist(hi) < d_c:                              # initial guess too small
        for _ in range(max_iter):
            lo = hi
            hi *= 2.0
            if dist(hi) >= d_c:
                break

    # bisection
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        d_mid = dist(mid)
        if abs(d_mid - d_c) <= tol:             # tolerance satisfied
            return x + noise * mid, d_mid, mid
        if d_mid < d_c:                         # need more noise
            lo = mid
        else:                                   # need less noise
            hi = mid

    # Tolerance never met (or max_iter == 0): return the closest scale evaluated.
    # `best` is always set because dist(hi) is evaluated above.
    _, best_std, best_d = best
    return x + noise * best_std, best_d, best_std


def get_features(parameters, model, layer_idx, true_tokens, attention_mask, device):
    """
    Get the true embeddings, features, and target features from the model for the given tokens.

    Target features are the true features of each sample perturbed with
    Gaussian noise by ``add_noise`` to reach that sample's
    ``target_corr_dist``. The sample's ``noise_seed`` (if present) seeds the noise.

    Args:
        parameters: one dict per sample; reads 'target_corr_dist' and optionally 'noise_seed'.
        model: language model to extract embeddings and hidden states from.
        layer_idx: layer whose hidden states are used (see _extract_features).
        true_tokens: token ids, row i belonging to parameters[i].
        attention_mask: attention mask for ``true_tokens``.
        device: device the returned tensors are moved to.

    Returns:
        true_embeddings, true_features, target_features, noise_info, where
        noise_info holds per-sample true/target correlation, cosine and L2
        distances and the noise stds used.
    """
    with torch.no_grad():
        true_embeddings = _extract_embeddings(model, true_tokens)
        true_features = _extract_features(model, input_ids=true_tokens, attention_mask=attention_mask, layer_idx=layer_idx)

    # noise true features to obtain target features
    target_features = []
    true_target_corr_dists = []
    noise_stds = []
    for i, param in enumerate(parameters):
        # Seed the noise with the experiment's noise_seed; otherwise the noise depends on
        # the global RNG state and runs differing only in noise_seed are not distinct/reproducible.
        tf, d, std = add_noise(
            true_features[i],
            param['target_corr_dist'],
            seed=param.get('noise_seed'),
        )
        target_features.append(tf)
        true_target_corr_dists.append(d)
        noise_stds.append(std)
    target_features = torch.stack(target_features, dim=0).to(device)
    true_features = true_features.to(device)
    true_embeddings = true_embeddings.to(device)
    noise_info = {
        'true_target_correlation_distance': true_target_corr_dists,
        'true_target_cosine_distance': cosine_distance(true_features, target_features).tolist(),
        'true_target_l2_distance': l2_distance(true_features, target_features).tolist(),
        'noise_stds': noise_stds,
    }
    return true_embeddings, true_features, target_features, noise_info


class CosineLoss:
    """
    Cosine loss (1 - cosine similarity) returning one value per sample.

    Modes:
        - 'token': 1 - cos is computed for every token over the last axis and
          then averaged over the sequence axis.
        - 'layer': each sample is flattened to one (seq_len * dim) vector and
          1 - cos is computed between those vectors.
    """
    def __init__(self, mode: str = 'token'):
        """
        Args:
            mode (str): 'token' or 'layer'

        Raises:
            ValueError: for any other mode.
        """
        if mode in ['token', 'layer']:
            self.mode = mode
        else:
            raise ValueError(f"Invalid mode '{mode}'. Must be 'token' or 'layer'.")

    @staticmethod
    def _similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Cosine similarity along the last axis (both inputs L2-normalised first)."""
        return (F.normalize(a, dim=-1) * F.normalize(b, dim=-1)).sum(dim=-1)

    def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Compute the cosine loss.

        Args:
            x (torch.Tensor): shape (batch, seq_len, dim)
            y (torch.Tensor): shape (batch, seq_len, dim)

        Returns:
            torch.Tensor: shape (batch,), sample-wise loss
        """
        if self.mode != 'token':
            # layer-wise: a single flattened vector per sample
            batch_size = x.size(0)
            flat_x = x.reshape(batch_size, -1)
            flat_y = y.reshape(batch_size, -1)
            return 1.0 - self._similarity(flat_x, flat_y)
        # token-wise: per-token loss, averaged over the sequence
        per_token_loss = 1.0 - self._similarity(x, y)  # (batch, seq_len)
        return per_token_loss.mean(dim=-1)


class MSELoss:
    """
    Sample-wise mean squared error.

    The element-wise squared error is averaged over dims 1 and 2, which gives
    one value per sample for (batch, seq_len, dim) inputs.
    """
    def __init__(self):
        # element-wise criterion; the per-sample reduction happens in __call__
        self.loss_func = torch.nn.MSELoss(reduction='none')

    def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Sample-wise MSE loss, shape (batch,).
        """
        squared_error = self.loss_func(x, y)
        per_sample = squared_error.mean(dim=(1, 2))
        return per_sample


class MSECosineLoss:
    """
    Sample-wise MSE loss plus a weighted sample-wise cosine loss.
    """
    def __init__(self, mode, cos_weight: float = 1.0):
        """
        Args:
            mode: 'token' or 'layer' selects CosineLoss with that mode;
                'both' selects CosineDoubleLoss (layer-wise + token-wise).
            cos_weight: weight applied to the cosine term.

        Raises:
            ValueError: if ``mode`` is not 'token', 'layer' or 'both'.
        """
        # Validate up front: an unknown mode used to leave `self.cosine` unset,
        # surfacing only as an AttributeError on the first call.
        if mode not in ['token', 'layer', 'both']:
            raise ValueError(f"Invalid mode '{mode}'. Must be 'token', 'layer' or 'both'.")
        self.cos_weight = cos_weight
        self.mse = MSELoss()
        if mode == 'both':
            self.cosine = CosineDoubleLoss()
        else:
            self.cosine = CosineLoss(mode=mode)

    def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return mse(x, y) + cos_weight * cosine(x, y), shape (batch,)."""
        mse_loss = self.mse(x, y)
        cos_loss = self.cosine(x, y)
        return mse_loss + self.cos_weight * cos_loss

class CosineDoubleLoss:
    """
    Per-sample sum of the layer-wise and the token-wise cosine losses.
    """
    def __init__(self):
        self.cos_layer = CosineLoss(mode='layer')
        self.cos_token = CosineLoss(mode='token')

    def __call__(self, x, y):
        """Return cos_layer(x, y) + cos_token(x, y), shape (batch,)."""
        layer_term = self.cos_layer(x, y)
        token_term = self.cos_token(x, y)
        return layer_term + token_term


def load_loss_function(config: dict):
    """
    Build the loss named by config['loss_func'].

    A missing 'loss_func' entry selects MSELoss.

    Raises:
        ValueError: if the name is not recognised.
    """
    if 'loss_func' not in config:
        return MSELoss()  # default loss function

    requested = config['loss_func']
    # (name, factory) pairs; factories are lambdas so only the selected loss is built
    factories = (
        ('cosine-token-wise', lambda: CosineLoss(mode='token')),
        ('cosine-layer-wise', lambda: CosineLoss(mode='layer')),
        ('mse+cosine-layer-wise', lambda: MSECosineLoss(mode='layer')),
        ('mse+cosine-token-wise', lambda: MSECosineLoss(mode='token')),
        ('cosine-token-wise+cosine-layer-wise', lambda: CosineDoubleLoss()),
        ('mse+cosine', lambda: MSECosineLoss(mode='both')),
    )
    for name, build in factories:
        if requested == name:
            return build()
    raise ValueError(f"Unknown loss function: {requested}")


class LMFeatureInversionPipeline:
    """
    Gradient-based inversion of language-model features.

    A logit matrix over the vocabulary is optimized so that the features the
    model produces at ``layer_idx`` for the embeddings induced by those logits
    match a target feature tensor.
    """
    def __init__(
            self,
            model: nn.Module,
            layer_idx: int,
            recon_logits: torch.Tensor,
            num_iterations: int,
            optimizer: torch.optim.Optimizer,
            loss_func: Optional[Callable],
            scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
            use_one_hot: bool = False,
            use_gumbel: bool = False,
            temperature: float = 1.0,
            temp_decay: bool = False,
            temp_min: float = 0.01,
            entropy_weight: Optional[float] = None,
            token_eval_metric: Optional[Callable] = None,
            embed_eval_metric: Optional[Callable] = None,
            feature_eval_metric: Optional[Callable] = None,
            use_wandb: bool = False,
            wandb_log_interval: int = 1,
            eval_interval: int = 1,
        ):
        """
        Initialize the feature inversion pipeline for a language model.

        Args:
            model: The language model; its input-embedding weight is read here.
            layer_idx: Layer whose features are matched (see _extract_features).
            recon_logits: The logit matrix to optimize, shape (batch_size, n_tokens, vocab_size).
            num_iterations: Number of optimization steps.
            optimizer: Optimizer stepped once per iteration; presumably built over
                ``recon_logits`` (TODO confirm at the call site).
            loss_func: Callable(target_features, recon_features) returning a per-sample loss (batch,).
            scheduler: Optional LR scheduler stepped once per iteration.
            use_one_hot: If True, feed the argmax (one-hot) tokens forward while the
                gradient flows through the temperature softmax (straight-through).
            use_gumbel: If True (and use_one_hot is False), use straight-through Gumbel-softmax.
            temperature: The temperature for softmax on logits.
            temp_decay: Whether to linearly decay the temperature towards temp_min.
            temp_min: The minimum temperature to decay to.
            entropy_weight: If not None, add entropy_weight * entropy of the token
                distributions to the loss as a regularization term.
            token_eval_metric: Optional callable(logits, token_probs) -> dict of per-sample lists.
            embed_eval_metric: Optional callable(recon_embeds) -> dict of per-sample lists.
            feature_eval_metric: Optional callable(recon_features) -> dict of per-sample lists.
            use_wandb: Whether to log results to wandb.
            wandb_log_interval: Log to wandb every this many steps.
            eval_interval: Run the eval metrics every this many steps (and on the last step).
        """
        self.model = model
        self.layer_idx = layer_idx
        self.recon_logits = recon_logits  # (batch_size, n_tokens, vocab_size)
        self.num_iterations = num_iterations
        self.optimizer = optimizer
        self.loss_func = loss_func
        self.scheduler = scheduler
        self.use_one_hot = use_one_hot
        self.use_gumbel = use_gumbel
        self.temperature = temperature
        self.temp_decay = temp_decay
        self.temp_min = temp_min
        self.entropy_weight = entropy_weight
        self.token_eval_metric = token_eval_metric
        self.embed_eval_metric = embed_eval_metric
        self.feature_eval_metric = feature_eval_metric
        self.use_wandb = use_wandb
        self.wandb_log_interval = wandb_log_interval
        self.eval_interval = eval_interval
        self.embed_weight = model.get_input_embeddings().weight  # (vocab_size, emb_dim)

    def __call__(
            self, 
            target_features: torch.Tensor, 
            attention_mask: torch.Tensor, 
            wandb_names: Optional[list[str]] = None,
        ):
        """
        Run the optimization loop.

        Args:
            target_features: Features to match; passed to loss_func with the current reconstruction's features.
            attention_mask: Forwarded to the model on every forward pass.
            wandb_names: Per-sample names used as wandb keys (see log_wandb).

        Returns:
            list[dict[str, list]]: One dict per iteration; every value holds one entry per sample.
        """
        history = []  # list[dict[str, list[float]]] # store evaluation metrics for each iteration
        pbar = tqdm.tqdm(range(self.num_iterations), total=self.num_iterations, dynamic_ncols=True)
        for step in pbar:
            self.optimizer.zero_grad()
            t = self.temperature if not self.temp_decay else self.get_temp(step)
            token_probs, recon_embeds = self.get_recon_embeds(t)  # (batch_size, n_tokens, vocab_size) and (batch_size, n_tokens, emb_dim)

            # shape (batch_size, n_tokens, emb_dim)
            recon_features = _extract_features(self.model, input_embeds=recon_embeds, attention_mask=attention_mask, layer_idx=self.layer_idx)
            per_sample_loss = self.loss_func(target_features, recon_features)  # (batch,)
            total_loss = per_sample_loss.sum()  # scalar loss for the batch

            if self.entropy_weight is not None:
                token_entropy = -(token_probs * token_probs.clamp(min=1e-9).log()).sum(dim=-1)
                per_sample_entropy = token_entropy.mean(dim=-1)  # average over tokens per sample
                entropy_loss = per_sample_entropy.sum()  # scalar
                total_loss = total_loss + self.entropy_weight * entropy_loss

            total_loss.backward()  # sum, not mean: with mean, the gradient would depend on the batch size
            self.optimizer.step()
            if self.scheduler is not None:
                self.scheduler.step()

            # calculate and store metrics
            results = {'loss': per_sample_loss.tolist()}
            if self.temperature is not None:
                results['temperature'] = [t] * len(per_sample_loss)
            if self.entropy_weight is not None:
                results['entropy_loss'] = per_sample_entropy.tolist()
            # calculate metrics
            if step % self.eval_interval == 0 or step == self.num_iterations - 1:
                # NOTE(review): self.recon_logits has already been updated by optimizer.step(),
                # whereas recon_embeds / recon_features / token_probs are from before the step.
                results.update(self.evaluate(self.recon_logits, recon_embeds, recon_features, token_probs=token_probs))
            history.append(results)

            if self.use_wandb and step % self.wandb_log_interval == 0:
                self.log_wandb(step, results, wandb_names)
        return history

    def get_recon_embeds(self, t: float) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Turn the current logits into token distributions and input embeddings.

        Args:
            t: softmax temperature.

        Returns:
            token_probs (batch_size, n_tokens, vocab_size) and recon_embeds
            (batch_size, n_tokens, emb_dim) = (token weights) @ embedding matrix.
        """
        if self.use_one_hot:
            # one-hot embedding and use trick similar to Gumbel Softmax to 
            # obtain the gradient on the logits
            index = self.recon_logits.argmax(dim=-1, keepdim=True)  # (batch, seq_len, 1)
            one_hot = torch.zeros_like(self.recon_logits).scatter_(-1, index, 1.0)

            # softmax with temperature for gradient
            soft_logit = F.softmax(self.recon_logits / t, dim=-1)

            # use the trick to obtain the gradient on the logits
            # equal to one-hot but the gradient flow through the softmax
            y_st = (one_hot - soft_logit).detach() + soft_logit
            token_probs = soft_logit
            recon_embeds = y_st @ self.embed_weight  # (batch_size, n_tokens, emb_dim)
        elif self.use_gumbel:
            # use gumbel softmax
            # Sample gumbel noise
            gumbels = -torch.empty_like(self.recon_logits).exponential_().log()
            y = (self.recon_logits + gumbels) / t
            y_soft = F.softmax(y, dim=-1)
            # Straight-through trick
            index = y_soft.argmax(dim=-1, keepdim=True)  # shape: (B, T, 1)
            y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
            y_st = (y_hard - y_soft).detach() + y_soft  # stop gradient to hard version
            token_probs = y_st
            recon_embeds = y_st @ self.embed_weight  # use exact embedding of one-hot tokens
        else:
            # softmax on the logits
            token_probs =  F.softmax(self.recon_logits / t, dim=-1)
            recon_embeds = token_probs @ self.embed_weight
        return token_probs, recon_embeds  # (batch_size, n_tokens, vocab_size) and (batch_size, n_tokens, emb_dim)

    def evaluate(self, logits, recon_embeds, recon_features, token_probs):
        """
        Run the configured evaluation metrics and merge their result dicts.

        Metrics are computed under ``torch.no_grad()``: they are only logged and
        never back-propagated, so recording an autograd graph for them (the
        metric callables may do large matrix products) would only cost memory.
        """
        eval_results = {}
        with torch.no_grad():
            if self.token_eval_metric is not None:
                eval_results.update(self.token_eval_metric(logits, token_probs))
            if self.embed_eval_metric is not None:
                eval_results.update(self.embed_eval_metric(recon_embeds))
            if self.feature_eval_metric is not None:
                eval_results.update(self.feature_eval_metric(recon_features))
        return eval_results

    def log_wandb(self, step: int, results: dict, wandb_names: Optional[list[str]] = None):
        """Log ``results`` to wandb, grouped as {sample name: {metric: value}}."""
        if wandb_names is None:
            wandb_names = [f"stimulus_{i}" for i in range(len(results['loss']))]
        # aggregate the results by stimulus
        log = {}  # name -> metric -> value
        for i, name in enumerate(wandb_names):
            log[name] = {}
            for metric_name, values in results.items():
                log[name][metric_name] = values[i]
        wandb.log(log, step=step)

    def get_temp(self, step: int):
        """
        Temperature for ``step``: linear decay from ``temperature`` (step 0)
        towards ``temp_min``, clipped at ``temp_min``; constant if decay is off.
        """
        if self.temperature is None:
            raise ValueError("Temperature is not set. Please set temperature or temp_decay to True.")
        if not self.temp_decay:
            return self.temperature
        # correct: decay over time
        temp = self.temperature - (self.temperature - self.temp_min) * (step / self.num_iterations)
        return max(temp, self.temp_min)


class TokenMetrics:
    """
    Evaluate the reconstructed logits against the true tokens.

    Every accuracy is computed per sample twice: over all non-padding tokens
    (``attention_mask``) and over normal tokens only (``normal_token_mask``:
    non-padding, non-special).

    Metrics (each a list with one entry per sample):
        - recon_tokens: argmax token ids of the logits
        - top_1_accuracy, top_1_accuracy_normal_tokens
        - top_1_error_rate, top_1_error_rate_normal_tokens
        - top_5_accuracy, top_5_accuracy_normal_tokens
        - top_5_error_rate, top_5_error_rate_normal_tokens
        - proj_nearest_tokens, proj_nearest_token_accuracy,
          proj_nearest_token_accuracy_normal_tokens (only when embed_weight is not None)
    """
    def __init__(self, tokenizer, true_tokens, attention_mask, normal_token_mask, embed_weight):
        self.tokenizer = tokenizer
        self.true_tokens = true_tokens
        self.attention_mask = attention_mask
        self.normal_token_mask = normal_token_mask
        self.embed_weight = embed_weight  # (vocab_size, emb_dim), or None to skip the projection metrics

        # number of tokens per sample (denominators of the accuracies)
        self.num_non_padding_tokens = attention_mask.sum(dim=-1)  # (batch_size,)
        self.num_normal_tokens = normal_token_mask.sum(dim=-1)  # (batch_size,)

    def _accuracies(self, correct: torch.Tensor) -> tuple[list, list]:
        """
        Per-sample accuracy of a boolean hit tensor.

        Args:
            correct: (batch_size, n_tokens) bool, True where the prediction matches the true token.
        Returns:
            (accuracy over non-padding tokens, accuracy over normal tokens), each a list per sample.
        """
        n_non_padding = (correct & self.attention_mask.bool()).sum(dim=-1)  # (batch_size,)
        n_normal = (correct & self.normal_token_mask.bool()).sum(dim=-1)  # (batch_size,)
        return (
            (n_non_padding / self.num_non_padding_tokens).tolist(),
            (n_normal / self.num_normal_tokens).tolist(),
        )

    def __call__(self, recon_logits, token_probs) -> dict[str, list[float]]:
        """
        Args:
            recon_logits: (batch_size, n_tokens, vocab_size)
            token_probs: (batch_size, n_tokens, vocab_size) token weights used
                for the embedding-projection metrics.
        Returns:
            dict[str, list[float]]: A dictionary with evaluation metrics. 
                i-th element of each list corresponds to the i-th sample.
        """
        metrics = {}  # dict[str, list[float]]

        # top-1 reconstruction (decoded texts are not stored: they broke the CSV output)
        top1_tokens = recon_logits.argmax(dim=-1)    # (batch_size, n_tokens)
        metrics['recon_tokens'] = top1_tokens.tolist()

        # top-1 accuracy / error rate
        acc, acc_normal = self._accuracies(top1_tokens == self.true_tokens)
        metrics['top_1_accuracy'] = acc
        metrics['top_1_accuracy_normal_tokens'] = acc_normal
        metrics['top_1_error_rate'] = [1 - a for a in acc]
        metrics['top_1_error_rate_normal_tokens'] = [1 - a for a in acc_normal]

        # top-5 accuracy / error rate (k clipped so vocabularies smaller than 5 do not crash topk)
        k = min(5, recon_logits.size(-1))
        top5_tokens = torch.topk(recon_logits, k=k, dim=-1).indices  # (batch_size, n_tokens, k)
        correct_top5 = (top5_tokens == self.true_tokens.unsqueeze(-1)).any(dim=-1)  # (batch_size, n_tokens)
        acc, acc_normal = self._accuracies(correct_top5)
        metrics['top_5_accuracy'] = acc
        metrics['top_5_accuracy_normal_tokens'] = acc_normal
        metrics['top_5_error_rate'] = [1 - a for a in acc]
        metrics['top_5_error_rate_normal_tokens'] = [1 - a for a in acc_normal]

        # if embed_weight is provided, evaluate the tokens whose embeddings best match the reconstruction
        if self.embed_weight is not None:
            recon_embeds = token_probs @ self.embed_weight  # (batch_size, n_tokens, emb_dim)
            # NOTE: "nearest" here is the largest inner product, not cosine or L2 distance
            dot = recon_embeds @ self.embed_weight.T  # (batch_size, n_tokens, vocab_size)
            nearest_tokens = dot.argmax(dim=-1)  # (batch_size, n_tokens)
            metrics['proj_nearest_tokens'] = nearest_tokens.tolist()
            acc, acc_normal = self._accuracies(nearest_tokens == self.true_tokens)
            metrics['proj_nearest_token_accuracy'] = acc
            metrics['proj_nearest_token_accuracy_normal_tokens'] = acc_normal

        return metrics


class EmbedMetrics:
    """
    Compare reconstructed embeddings with the true embeddings, per sample,
    using only the positions where the attention mask is set.

    Metrics:
        - embed_cosine_distance
        - embed_correlation_distance
        - embed_l2_distance
    """
    def __init__(self, true_embeddings, attention_mask):
        """
        Args:
            true_embeddings: (batch_size, n_tokens, emb_dim)
            attention_mask: (batch_size, n_tokens) mask for non-padding tokens.
        """
        self.true_embeddings = true_embeddings
        self.attention_mask = attention_mask.bool()

    def __call__(self, recon_embeddings) -> dict[str, list[float]]:
        """
        Args:
            recon_embeddings: (batch_size, n_tokens, emb_dim)
        Returns:
            dict[str, list[float]]: one list per metric; entry i belongs to sample i.
        """
        distance_fns = (
            ('embed_cosine_distance', cosine_distance),
            ('embed_correlation_distance', correlation_distance),
            ('embed_l2_distance', l2_distance),
        )
        results = {key: [] for key, _ in distance_fns}
        for idx in range(recon_embeddings.shape[0]):
            keep = self.attention_mask[idx]  # bool (n_tokens,)
            # add a leading batch axis of size 1 for the distance helpers
            true_part = self.true_embeddings[idx][keep].unsqueeze(0)
            recon_part = recon_embeddings[idx][keep].unsqueeze(0)
            for key, distance in distance_fns:
                results[key].append(distance(true_part, recon_part).item())
        return results


class FeatureMetrics:
    """
    Compare reconstructed features with both the true and the target features,
    per sample, using only the positions where the attention mask is set.

    Metrics:
        - true_recon_cosine_distance
        - true_recon_correlation_distance
        - true_recon_l2_distance
        - target_recon_cosine_distance
        - target_recon_correlation_distance
        - target_recon_l2_distance
    """
    def __init__(self, true_features: torch.Tensor, target_features: torch.Tensor, attention_mask: torch.Tensor):
        self.true_features = true_features
        self.target_features = target_features
        self.attention_mask = attention_mask.bool()

    def __call__(self, recon_features: torch.Tensor) -> dict[str, list[float]]:
        """
        Args:
            recon_features: (batch_size, n_tokens, hidden_size)
        Returns:
            dict[str, list[float]]: one list per metric; entry i belongs to sample i.
        """
        metrics = {
            "true_recon_cosine_distance": [],
            "true_recon_correlation_distance": [],
            "true_recon_l2_distance": [],
            "target_recon_cosine_distance": [],
            "target_recon_correlation_distance": [],
            "target_recon_l2_distance": [],
        }
        for idx in range(recon_features.size(0)):
            keep = self.attention_mask[idx]
            # leading batch axis of size 1 for the distance helpers
            true_part = self.true_features[idx][keep].unsqueeze(0)
            recon_part = recon_features[idx][keep].unsqueeze(0)
            target_part = self.target_features[idx][keep].unsqueeze(0)
            comparisons = (
                ("true_recon_cosine_distance", cosine_distance, true_part),
                ("true_recon_correlation_distance", correlation_distance, true_part),
                ("true_recon_l2_distance", l2_distance, true_part),
                ("target_recon_cosine_distance", cosine_distance, target_part),
                ("target_recon_correlation_distance", correlation_distance, target_part),
                ("target_recon_l2_distance", l2_distance, target_part),
            )
            for key, distance, reference in comparisons:
                metrics[key].append(distance(reference, recon_part).item())
        return metrics
        

def resolve_wandb(config: dict[str, Any], parameters: list[dict[str, Any]]):
    """
    Start a wandb run when config['wandb'] is truthy, and build per-sample log names.

    Args:
        config (dict[str, Any]): Configuration dictionary.
        parameters (list[dict[str, Any]]): List of parameters for each experiment.
    Returns:
        use_wandb (bool): Whether wandb logging is enabled.
        prefix (list[str] | None): One name per sample
            ("<text_name>_distance<d>_seed<s>"), or None when wandb is off.
    """
    if not config.get('wandb', False):
        return False, None

    model_name = config['model']['pretrained'].replace('/', '_')
    project = 'language_readout_' + model_name + '_' + config['data']['dataset_name']
    # one run per batch; the layer is assumed to be shared by every sample in it
    run_name = f'{config["exp_name"]}_layer_{parameters[0]["layer_idx"]}'
    wandb.init(project=project, name=run_name, config=config)

    sample_names = []
    for p in parameters:
        sample_names.append(p['text_name'] + f'_distance{p["target_corr_dist"]}' + f'_seed{p["noise_seed"]}')
    return True, sample_names


class FlatListEncoder(json.JSONEncoder):
    """
    JSON encoder that writes a top-level dict one key per line, keeping list values on one line.

    For example ``{"a": [1, 2], "b": "x"}`` is written as::

        {
            "a": [1, 2],
            "b": "x"
        }

    Objects that are not dicts use the standard ``json.JSONEncoder`` output. Keys and values
    inside the dict are serialized with this encoder's ``ensure_ascii``, ``allow_nan`` and
    ``default`` settings. Non-string keys are coerced the same way the standard encoder does
    (int/float/bool/None -> their JSON literal as a string). ``skipkeys`` is honored for any
    other key type.
    """

    def _dumps(self, value) -> str:
        """Serialize one value compactly, using this encoder's options."""
        return json.dumps(
            value,
            ensure_ascii=self.ensure_ascii,
            allow_nan=self.allow_nan,
            default=self.default,
        )

    def _key_text(self, key) -> Optional[str]:
        """Return the string form of a dict key, or None when ``skipkeys`` drops it."""
        if isinstance(key, str):
            return key
        # bool is a subclass of int; json.dumps gives 'true'/'false'/'null'/'1'/'1.5'/'NaN',
        # matching the standard encoder's key coercion
        if key is None or isinstance(key, (int, float)):
            return self._dumps(key)
        if self.skipkeys:
            return None
        raise TypeError(f'keys must be str, int, float, bool or None, not {key.__class__.__name__}')

    def iterencode(self, obj, _one_shot=False):
        """Encode the given object and yield each string representation as available."""
        if not isinstance(obj, dict):
            # Fallback to default behavior
            yield from super().iterencode(obj, _one_shot)
            return

        # build every line first so an encoding error never leaves a half-written object
        lines = []
        for key, value in obj.items():
            key_text = self._key_text(key)
            if key_text is None:
                continue
            if isinstance(value, list):
                value_text = '[' + ', '.join(self._dumps(item) for item in value) + ']'
            else:
                value_text = self._dumps(value)
            lines.append(f'    {self._dumps(key_text)}: {value_text}')

        yield '{\n'
        yield ',\n'.join(lines)
        yield '\n}'


def save_results(
    output_dir: str,
    parameters: list[dict[str, Any]],
    info: dict[str, list],
    noise_info: dict[str, list],
    history: list[dict[str, list[float]]],
    tokenizer,
):
    """
    Save the history, summary and reconstructed text of every run in a finished batch.

    Each run ``i`` gets its own directory
    ``<output_dir>/layer_<layer_idx>/corr_dist_<target_corr_dist>/noise_seed_<noise_seed>/text_<text_name>``.
    It contains ``history.csv`` (one row per history entry), ``summary.json`` and ``text.txt``.

    Args:
        output_dir (str): Base output directory where results will be saved.
        parameters (list[dict[str, Any]]): Parameters of each run. The dicts are copied, not modified.
        info (dict[str, list]): Per-run values (indexed by run) from tokenization. The
            ``'true_texts'`` entry is left out of the summary.
        noise_info (dict[str, list]): Per-run values (indexed by run) describing the applied noise.
        history (list[dict[str, list[float]]]): Logged steps. Each maps a metric name to its
            per-run values. Must be non-empty, and its last entry must contain ``'recon_tokens'``.
        tokenizer: Tokenizer used to decode the final ``recon_tokens`` into text.

    Returns:
        list[dict[str, Any]]: One summary per run. Each combines the run's parameters, its
        ``info`` and ``noise_info`` values and the last history entry.
    """
    summaries = []
    for i, run_params in enumerate(parameters):
        # full history for this run
        run_history = [{name: values[i] for name, values in step.items()} for step in history]

        # short summary dictionary for this run: parameters + info + noise info + final metrics.
        # Copy the parameters so the caller's dicts (which main() later passes back to the
        # experiment database) are not silently extended with results.
        summary = dict(run_params)
        # raw texts are skipped: the original note says they may break the JSON format
        summary.update({name: values[i] for name, values in info.items() if name != 'true_texts'})
        summary.update({name: values[i] for name, values in noise_info.items()})
        summary.update(run_history[-1])
        summaries.append(summary)

        # decode the final reconstructed tokens into text
        recon_text = tokenizer.decode(summary['recon_tokens'], skip_special_tokens=True)

        run_path = os.path.join(
            output_dir,
            f'layer_{summary["layer_idx"]}',
            f'corr_dist_{summary["target_corr_dist"]}',
            f'noise_seed_{summary["noise_seed"]}',
            f'text_{summary["text_name"]}'
        )
        os.makedirs(run_path, exist_ok=True)

        # save history as csv
        pd.DataFrame(run_history).to_csv(os.path.join(run_path, 'history.csv'), index=False, escapechar='\\')
        # save summary; explicit UTF-8 so output does not depend on the platform locale
        with open(os.path.join(run_path, 'summary.json'), 'w', encoding='utf-8') as f:
            json.dump(summary, f, cls=FlatListEncoder)
        # save reconstructed text (may contain non-ASCII characters)
        with open(os.path.join(run_path, 'text.txt'), 'w', encoding='utf-8') as f:
            f.write(recon_text)
    return summaries


def main(config, device):
    """
    Run every pending reconstruction experiment described by ``config``.

    Layers in ``config['layer_indices']`` are processed one at a time. For each layer, batches
    of up to ``config['batch_size']`` experiments are taken from the experiment database. Each
    batch is optimized with ``LMFeatureInversionPipeline``, saved with ``save_results`` and
    marked ``'finished'``. A batch that raises is marked ``'error'`` and the loop continues. On
    ``KeyboardInterrupt`` the batch is set back to ``'pending'`` and the interrupt is re-raised.

    Args:
        config (dict): Experiment configuration (e.g. loaded from the YAML config file).
        device (str): Torch device to run on, e.g. ``'cuda'``.
    """
    # load model and tokenizer
    dtype = parse_dtype(config)
    model = AutoModelForCausalLM.from_pretrained(config['model']['pretrained'], torch_dtype=dtype)
    tokenizer = AutoTokenizer.from_pretrained(config['model']['pretrained'])
    if tokenizer.pad_token is None:
        # set pad token to eos token if not set
        tokenizer.pad_token = tokenizer.eos_token
    model = model.eval()
    model = model.to(device)
    vocab_size = model.get_input_embeddings().weight.shape[0]

    # output directory
    output_dir = parse_output_dir(config)  # base output directory of all experiments under this config
    os.makedirs(output_dir, exist_ok=True)
    # save the config to the output directory
    with open(os.path.join(output_dir, 'config.yaml'), 'w', encoding='utf-8') as f:
        yaml.dump(config, f)

    # initialize the experiment database
    exp_db = initialize_experiment_db(config, output_dir)

    # run experiments: different layers are always separated for simplicity
    for layer_idx in config['layer_indices']:
        while True:
            parameters = exp_db.get_next_experiments(config['batch_size'], layer_idx=layer_idx)
            if not parameters:
                print(f"No more parameters to process for layer {layer_idx}.")
                break

            # Bound before the try: the handlers below read it even if tokenization/feature
            # extraction fails before wandb is resolved for this batch.
            use_wandb = False
            try:
                # get original tokens and attention mask, true_token_mask
                true_tokens, attention_mask, normal_token_mask, info = get_tokens(tokenizer, config, parameters, config['max_length'], device)
                # get true_embeddings, true_features, target_features
                true_embeddings, true_features, target_features, noise_info = get_features(parameters, model, layer_idx, true_tokens, attention_mask, device)
                # wandb setup
                use_wandb, wandb_names = resolve_wandb(config, parameters)

                # create logit matrix, which is the target of the optimization
                # (batch, n_tokens, vocab_size)
                batch_size, n_tokens = true_tokens.shape
                recon_logits = torch.randn(batch_size, n_tokens, vocab_size, requires_grad=True, device=device, dtype=dtype)
                optimizer = AdamW([recon_logits], lr=config['optimizer']['lr'])
                if 'scheduler' in config['optimizer']:
                    scheduler = optim.lr_scheduler.LinearLR(optimizer, **config['optimizer']['scheduler'], total_iters=config['pipeline']['num_iterations'])
                else:
                    scheduler = None
                loss_func = load_loss_function(config)

                # build the inversion pipeline that optimizes recon_logits
                pipeline = LMFeatureInversionPipeline(
                    model=model,
                    layer_idx=layer_idx,
                    recon_logits=recon_logits,
                    optimizer=optimizer,
                    loss_func=loss_func,
                    scheduler=scheduler,
                    token_eval_metric=TokenMetrics(tokenizer, true_tokens, attention_mask, normal_token_mask, embed_weight=model.get_input_embeddings().weight),
                    embed_eval_metric=EmbedMetrics(true_embeddings, attention_mask),
                    feature_eval_metric=FeatureMetrics(true_features, target_features, attention_mask),
                    use_wandb=use_wandb,
                    **config['pipeline']
                )
                history = pipeline(
                    target_features=target_features,
                    attention_mask=attention_mask,
                    wandb_names=wandb_names
                )
                # save the results
                summaries = save_results(output_dir, parameters, info, noise_info, history, tokenizer=tokenizer)

                # update experiment database
                exp_db.update_experiment_info(summaries)
                exp_db.mark_experiment_status(parameters, 'finished')

                if use_wandb:
                    wandb.finish()

            except KeyboardInterrupt:
                print("Interrupted. Reverting status to 'pending'.")
                exp_db.mark_experiment_status(parameters, 'pending')
                if use_wandb:
                    wandb.finish()
                raise
            except Exception as e:
                print(f'Error: {e}. Marking status as "error".')
                traceback.print_exc()
                exp_db.mark_experiment_status(parameters, 'error')
                if use_wandb:
                    wandb.finish()


def find_max_batch_size(config: dict, device: str, limit: int = 704, start: int = 60):
    """
    Find the maximum batch size that can be processed on a given device, for each layer.

    For every layer in ``config['layer_indices']`` the batch size is doubled from ``start``
    until a run fails or ``limit`` is reached. A binary search between the last size that fit
    and the first that did not then finds the exact value. Each probe runs the pipeline for
    32 iterations against a temporary experiment database, which is deleted afterwards.

    Args:
        config (dict): Experiment configuration. It is not modified; overrides are applied to a
            deep copy.
        device (str): Torch device to probe, e.g. ``'cuda'``.
        limit (int): Largest batch size that will be tried.
        start (int): First batch size that will be tried.

    Returns:
        dict: Layer index -> maximum batch size that fits (0 if not even batch size 1 fits).

    Raises:
        ValueError: If ``start`` or ``limit`` is smaller than 1.
    """
    import copy
    import shutil
    import tempfile

    if start < 1 or limit < 1:
        raise ValueError(f'start and limit must be positive, got start={start}, limit={limit}')

    # work on a private copy so the overrides below do not leak into the caller's config
    config = copy.deepcopy(config)

    print('Loading model and tokenizer...')
    dtype = parse_dtype(config)
    model = AutoModelForCausalLM.from_pretrained(config['model']['pretrained'], torch_dtype=dtype)
    tokenizer = AutoTokenizer.from_pretrained(config['model']['pretrained'])
    if tokenizer.pad_token is None:
        # set pad token to eos token if not set
        tokenizer.pad_token = tokenizer.eos_token
    model = model.eval()
    model = model.to(device)
    print('Model and tokenizer loaded.')

    # override config parameters
    config['exp_name'] = 'find_max_batch_size'
    # NOTE(review): presumably a short run is enough to reach peak memory — confirm
    config['pipeline']['num_iterations'] = 32

    results = {}
    for layer_idx in config['layer_indices']:
        # throw-away experiment database for this layer
        output_dir = tempfile.mkdtemp(prefix='find_max_batch_size_')
        try:
            exp_db = initialize_experiment_db(config, output_dir)

            def fits(batch_size, exp_db=exp_db, layer_idx=layer_idx):
                return batch_size_fits(config, exp_db, batch_size, device, layer_idx, model, tokenizer, dtype)

            max_batch_size = _search_max_batch_size(fits, start, limit)
        finally:
            # best-effort cleanup of the temporary database files
            shutil.rmtree(output_dir, ignore_errors=True)

        print(f"[RESULT] Maximum batch size that fits layer_{layer_idx}: {max_batch_size}")
        results[layer_idx] = max_batch_size
    return results


def _search_max_batch_size(fits: Callable[[int], bool], start: int, limit: int) -> int:
    """
    Return the largest size in ``[1, limit]`` for which ``fits`` is True, or 0 if none is.

    Assumes ``fits`` is monotone: if a size fits, every smaller size fits too. Uses an
    exponential phase starting at ``min(start, limit)``, followed by a binary search.
    """
    low = 0  # largest size known to fit (0 = none confirmed yet)
    high = min(start, limit)  # next candidate

    # Exponential phase: double until a size fails or the limit itself fits.
    while fits(high):
        low = high
        if high >= limit:
            return low
        high = min(high * 2, limit)
    high -= 1  # the last probed size did not fit

    # Binary search phase: the answer lies in [low, high].
    while low < high:
        mid = (low + high + 1) // 2
        if fits(mid):
            low = mid
        else:
            high = mid - 1
    return low


def batch_size_fits(config, exp_db, batch_size, device, layer_idx, model, tokenizer, dtype) -> bool:
    """
    Check whether a reconstruction run with ``batch_size`` samples of layer ``layer_idx`` succeeds.

    A batch is taken from ``exp_db`` and the pipeline is run once, for
    ``config['pipeline']['num_iterations']`` iterations, without wandb and without saving
    results. The taken experiments are always set back to ``'pending'``, so the same database
    can be probed again.

    Args:
        config (dict): Experiment configuration.
        exp_db: Experiment database that provides the batch.
        batch_size (int): Number of experiments to request.
        device (str): Torch device to run on.
        layer_idx (int): Layer whose features are reconstructed.
        model: Loaded language model.
        tokenizer: Tokenizer matching ``model``.
        dtype (torch.dtype): Dtype of the optimized logits.

    Returns:
        bool: False if the run raised a ``RuntimeError`` (CUDA out-of-memory errors are
        ``RuntimeError`` subclasses) or another exception whose message mentions memory.
        True otherwise.

    Raises:
        KeyboardInterrupt: Re-raised after the experiments are reset to ``'pending'``.
        Exception: Any non-``RuntimeError`` whose message does not mention memory.
    """
    import gc

    # release memory left over from the previous probe
    gc.collect()
    torch.cuda.empty_cache()

    print('Testing batch size:', batch_size)
    fits = True
    parameters = exp_db.get_next_experiments(batch_size, layer_idx=layer_idx)
    if len(parameters) < batch_size:
        print('WARNING: batch size hit the limit of the samples in the database.')
        print('Actual batch size:', len(parameters))
    vocab_size = model.get_input_embeddings().weight.shape[0]
    try:
        # get original tokens and attention mask, true_token_mask
        true_tokens, attention_mask, normal_token_mask, _ = get_tokens(tokenizer, config, parameters, config['max_length'], device)
        # get true_embeddings, true_features, target_features
        true_embeddings, true_features, target_features, _ = get_features(parameters, model, layer_idx, true_tokens, attention_mask, device)

        # create logit matrix, which is the target of the optimization
        # (batch, n_tokens, vocab_size)
        # batch_size is rebound to the actual size, so the final message reports what was run
        batch_size, n_tokens = true_tokens.shape
        recon_logits = torch.randn(batch_size, n_tokens, vocab_size, requires_grad=True, device=device, dtype=dtype)
        optimizer = AdamW([recon_logits], lr=config['optimizer']['lr'])
        if 'scheduler' in config['optimizer']:
            scheduler = optim.lr_scheduler.LinearLR(optimizer, **config['optimizer']['scheduler'], total_iters=config['pipeline']['num_iterations'])
        else:
            scheduler = None
        loss_func = load_loss_function(config)

        # build the inversion pipeline and run it; only success/failure matters here
        pipeline = LMFeatureInversionPipeline(
            model=model,
            layer_idx=layer_idx,
            recon_logits=recon_logits,
            optimizer=optimizer,
            loss_func=loss_func,
            scheduler=scheduler,
            token_eval_metric=TokenMetrics(tokenizer, true_tokens, attention_mask, normal_token_mask, embed_weight=model.get_input_embeddings().weight),
            embed_eval_metric=EmbedMetrics(true_embeddings, attention_mask),
            feature_eval_metric=FeatureMetrics(true_features, target_features, attention_mask),
            use_wandb=False,
            **config['pipeline']
        )
        pipeline(
            target_features=target_features,
            attention_mask=attention_mask,
            wandb_names=None
        )
    except KeyboardInterrupt:
        print("Interrupted. Reverting status to 'pending'.")
        raise
    except RuntimeError as e:
        # includes torch.cuda.OutOfMemoryError
        print(e)
        fits = False
    except Exception as e:
        print(f'Error: {e}.')
        traceback.print_exc()
        if 'memory' in str(e).lower():
            fits = False
        else:
            raise
    finally:
        # undo the experiment db status change anyway
        exp_db.mark_experiment_status(parameters, 'pending')

    if fits:
        print('Batch size', batch_size, 'fits.')
    else:
        print('Batch size', batch_size, 'does not fit.')
    return fits


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Reconstruct texts from the original and perturbed features.")
    parser.add_argument("config_path", type=str, help="Path to the configuration file.")
    parser.add_argument("--device", type=str, default='cuda', help="Torch device to run on (default: cuda).")
    parser.add_argument("--find_max_batch_size", action='store_true', help="Find the maximum batch size that fits on the device.")
    args = parser.parse_args()

    # read YAML as UTF-8 regardless of the platform's locale encoding
    with open(args.config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    if 'text_names' not in config['data']:
        # text names are listed in a separate YAML file at text_names_path
        with open(config['data']['text_names_path'], 'r', encoding='utf-8') as f:
            config['data']['text_names'] = yaml.safe_load(f)

    if args.find_max_batch_size:
        print("Finding maximum batch size...")
        find_max_batch_size(config, args.device)
    else:
        # normal reconstruction
        main(config, device=args.device)