"""
Reconstruct texts from the original and perturbed features of a language model.
Please refer to the README for more details on how to run this script.
"""
from typing import Any, Optional, Callable
import argparse
import math
import os
import traceback
import json
import yaml
import tqdm
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch import optim
import wandb

from metamer.reconstruct.experiment_db import JSONExperimentDB
from metamer.icnn_replication.evaluation import (
    cosine_distance, correlation_distance, l2_distance
)

# Small constant added to denominators (standard deviations / std products) to avoid division by zero.
EPSILON = 1e-8
TOLERANCE = 0.0001  # default tolerance for the error in correlation distance (used by add_noise)

def parse_output_dir(config):
    """
    Return the experiment's output directory.

    Layout: output/readout_language/results/<model>/<dataset>/<exp_name>, where
    '/' in the pretrained model name is replaced by '_' so it stays one path component.
    """
    path_parts = (
        'output',
        'readout_language',
        'results',
        config['model']['pretrained'].replace('/', '_'),
        config['data']['dataset_name'],
        config['exp_name'],
    )
    return os.path.join(*path_parts)


def initialize_experiment_db(config, output_dir):
    """
    Open the JSON experiment database in ``output_dir`` and register every
    parameter combination described by ``config``.

    One experiment is added per (layer_idx, text_name, target_corr_dist, noise_seed).
    A target correlation distance of 0.0 means "no noise", so it gets a single
    experiment with ``noise_seed=None`` instead of one per configured seed.

    Args:
        config: Parsed configuration. Reads ``layer_indices``, ``data.text_names``,
            ``noise.target_corr_dists`` and (for non-zero distances) ``noise.noise_seeds``.
        output_dir: Directory holding ``experiment_db.json`` and its lock file.
    Returns:
        The ``JSONExperimentDB`` instance.
    """
    experiment_db = JSONExperimentDB(
        exp_db_path=os.path.join(output_dir, 'experiment_db.json'),
        exp_db_lock_path=os.path.join(output_dir, 'experiment_db.lock'),
        param_clms=['layer_idx', 'text_name', 'target_corr_dist', 'noise_seed'],
        encoder_cls=FlatListEncoder,
    )

    new_experiments = []
    for layer_idx in config['layer_indices']:
        for text_name in config['data']['text_names']:
            for target_corr_dist in config['noise']['target_corr_dists']:
                # corr dist 0 == no noise == no noise_seeds. The seed list is only read
                # for noisy settings, so configs without noise_seeds keep working.
                if target_corr_dist == 0.0:
                    noise_seeds = [None]
                else:
                    noise_seeds = config['noise']['noise_seeds']
                new_experiments.extend(
                    {
                        'layer_idx': layer_idx,
                        'text_name': text_name,
                        'target_corr_dist': target_corr_dist,
                        'noise_seed': noise_seed,
                    }
                    for noise_seed in noise_seeds
                )
    experiment_db.add_experiments(new_experiments)
    return experiment_db


def get_tokens(tokenizer, config, parameters: list[dict[str, Any]], max_length: int, device):
    """
    Load the texts named in ``parameters`` and tokenize them into one padded batch.

    Each text is read as UTF-8 from ``<config['data']['text_dir']>/<text_name>.txt``.

    Args:
        tokenizer: The tokenizer to use (``AutoTokenizer`` in this script).
        config: Experiment configuration; ``config['data']['text_dir']`` is read.
        parameters: One dictionary per sample; each must contain ``'text_name'``.
        max_length: Every sequence is padded / truncated to exactly this many tokens.
        device: Device the returned tensors are moved to.
    Returns:
        true_tokens (torch.Tensor): The original tokens. (batch, seq_len).
        attention_mask (torch.Tensor): The attention mask for the tokens. (batch, seq_len).
        normal_token_mask (torch.Tensor): True where a token is neither padding nor a
            special token. (batch, seq_len).
        info (dict): Per-sample lists under 'true_texts' (decoded without special
            tokens), 'true_tokens' and 'normal_token_mask'.
    Raises:
        FileNotFoundError: If a text file does not exist.
    """
    # load the original texts
    text_dir = config['data']['text_dir']
    original_texts = []
    for param in parameters:
        text_path = os.path.join(text_dir, f"{param['text_name']}.txt")
        try:
            # explicit encoding: the platform default is not UTF-8 everywhere
            with open(text_path, 'r', encoding='utf-8') as f:
                original_texts.append(f.read())
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Text file {text_path} does not exist.") from err

    # tokenize
    inputs = tokenizer(original_texts, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')
    true_tokens = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # create mask for normal tokens (not padding or special tokens)
    special_tokens = torch.tensor(tokenizer.all_special_ids, device=true_tokens.device)  # shape (n_special,)
    # compare each token against all special tokens
    is_special_token = (true_tokens[..., None] == special_tokens).any(-1)  # shape (batch_size, seq_len)
    is_normal_token = ~is_special_token & attention_mask.bool()

    # decode the texts to get the tokenized texts
    decoded_texts = tokenizer.batch_decode(true_tokens, skip_special_tokens=True)
    info = {
        'true_texts': decoded_texts,
        'true_tokens': true_tokens.tolist(),
        'normal_token_mask': is_normal_token.tolist()
    }
    return true_tokens, attention_mask, is_normal_token, info


def _extract_embeddings(model: nn.Module, input_ids: torch.Tensor):
    return model.get_input_embeddings()(input_ids)


def _extract_features(
        model: nn.Module, 
        input_ids: Optional[torch.Tensor] = None, 
        input_embeds: Optional[torch.Tensor] = None, 
        attention_mask: Optional[torch.Tensor] = None, 
        layer_idx: Optional[int] = None,
    ) -> torch.Tensor:
    """
    Extract features from the model.

    Returns:
        features (torch.Tensor): The features extracted from the specified layer. (batch, seq_len, hidden_size)
    """
    if input_ids is not None:
        out = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
    elif input_embeds is not None:
        out = model(inputs_embeds=input_embeds, attention_mask=attention_mask, output_hidden_states=True)
    else:
        raise ValueError("Either input_ids or input_embeds must be provided.")
    # out: tuple[torch.Tensor]
    features = out.hidden_states[1:]  # remove the input embedding
    features = features[layer_idx]
    return features


def _corr_distance(a: torch.Tensor, b: torch.Tensor) -> float:
    """
    Return ``1 - Pearson correlation`` between the flattened ``a`` and ``b``.

    Population (biased) standard deviations are used, and EPSILON is added to
    the denominator to avoid division by zero.
    """
    flat_a = a.flatten()
    flat_b = b.flatten()
    dev_a = flat_a - flat_a.mean()
    dev_b = flat_b - flat_b.mean()
    covariance = (dev_a * dev_b).mean()
    scale = dev_a.std(unbiased=False) * dev_b.std(unbiased=False) + EPSILON
    correlation = (covariance / scale).item()
    return 1.0 - correlation


def add_noise(
    x: torch.Tensor,
    d_c: float,
    tol: float = TOLERANCE,
    max_iter: int = 30,
    seed: Optional[int] = None
):
    """
    Add Gaussian noise to the given feature x so that corr-distance(x, x+noise) ≈ d_c within `tol` error.
    Uses a fixed noise vector and bisection on its scale.

    Parameters
    ----------
    x : torch.Tensor
        Original tensor.
    d_c : float
        Desired correlation distance (0 ≤ d_c < 1).
    tol : float, optional
        Allowed absolute error |d_actual - d_c|.
    max_iter : int, optional
        Maximum iterations of the bracketing loop and, separately, of the bisection loop.
    seed: int, optional
        Random seed for reproducibility. The noise is drawn from a private generator
        seeded with this value, so the global torch RNG state is left untouched.
        If None, the default (global) generator is used.

    Returns
    -------
    torch.Tensor
        x + noise whose sample correlation distance is within `tol`
        (or the closest value found within `max_iter` bisection steps).
    float
        Actual correlation distance of the returned tensor.
    float
        Scale of the noise vector used to perturb the original tensor.

    Raises
    ------
    ValueError
        If `d_c` is outside [0, 1) or `tol` is not positive.
    """
    # explicit checks instead of `assert`, which is stripped under `python -O`
    if not 0.0 <= d_c < 1.0:
        raise ValueError(f"0 <= d_c < 1 required, got {d_c}.")
    if not tol > 0.0:
        raise ValueError(f"tol must be positive, got {tol}.")

    if d_c <= tol:  # no noise needed
        return x.clone(), 0.0, 0.0

    # Fixed noise direction; only its scale is searched below.
    generator = None
    if seed is not None:
        generator = torch.Generator(device=x.device)
        generator.manual_seed(seed)
    noise = torch.randn(x.shape, generator=generator, dtype=x.dtype, device=x.device)

    def dist(std: float) -> float:
        return _corr_distance(x, x + noise * std)

    # Initial guess: for independent noise of std s, corr = sqrt(var_x / (var_x + s^2)),
    # hence s = sqrt(var_x) * sqrt(1 / r^2 - 1) for the desired correlation r.
    var_x = x.var(unbiased=False)
    target_r = 1.0 - d_c
    initial_std = math.sqrt(var_x) * math.sqrt(1.0 / target_r**2 - 1.0)

    # bracket the solution
    lo, hi = 0.0, initial_std
    if dist(hi) < d_c:  # initial guess too small
        for _ in range(max_iter):
            lo = hi
            hi *= 2.0
            if dist(hi) >= d_c:
                break

    # bisection, remembering the iterate closest to the target
    best_std, best_dist = None, None
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        d_mid = dist(mid)
        if best_dist is None or abs(d_mid - d_c) < abs(best_dist - d_c):
            best_std, best_dist = mid, d_mid
        if abs(d_mid - d_c) <= tol:  # tolerance satisfied
            break
        if d_mid < d_c:  # need more noise
            lo = mid
        else:  # need less noise
            hi = mid

    if best_std is None:  # max_iter <= 0: no bisection step was taken
        best_std = hi
        best_dist = dist(hi)
    return x + noise * best_std, best_dist, best_std


def get_features(parameters, model, layer_idx, true_tokens, attention_mask, device):
    """
    Get the true embeddings, true features and noisy target features for the given tokens.

    Args:
        parameters: One dict per sample. ``'target_corr_dist'`` sets the noise level and
            ``'noise_seed'`` (if present) seeds the noise so each run is reproducible.
        model: Model the embeddings and features are extracted from.
        layer_idx: Layer whose hidden states are used (passed to ``_extract_features``).
        true_tokens: Token ids, (batch, seq_len).
        attention_mask: Attention mask, (batch, seq_len).
        device: Device the returned tensors are moved to.
    Returns:
        true_embeddings, true_features, target_features, and ``noise_info`` (per-sample
        lists of true-vs-target correlation / cosine / L2 distances and noise stds).
    """
    with torch.no_grad():
        true_embeddings = _extract_embeddings(model, true_tokens)
        true_features = _extract_features(model, input_ids=true_tokens, attention_mask=attention_mask, layer_idx=layer_idx)

    # noise true features to obtain target features
    # NOTE(review): noise is fitted over every position, padding included.
    target_features = []
    true_target_corr_dists = []
    noise_stds = []
    for i, param in enumerate(parameters):
        # pass the configured seed; previously it was ignored, so noise_seed had no effect
        tf, d, std = add_noise(true_features[i], param['target_corr_dist'], seed=param.get('noise_seed'))
        target_features.append(tf)
        true_target_corr_dists.append(d)
        noise_stds.append(std)
    target_features = torch.stack(target_features, dim=0).to(device)
    true_features = true_features.to(device)
    true_embeddings = true_embeddings.to(device)
    noise_info = {
        'true_target_correlation_distance': true_target_corr_dists,
        'true_target_cosine_distance': cosine_distance(true_features, target_features).tolist(),
        'true_target_l2_distance': l2_distance(true_features, target_features).tolist(),
        'noise_stds': noise_stds,
    }
    return true_embeddings, true_features, target_features, noise_info


class CosineLoss:
    """
    Sample-wise cosine loss, ``1 - cosine similarity``.

    Modes:
        - 'token': cosine loss of every token vector, averaged over the sequence
        - 'layer': cosine loss of each sample's flattened (seq_len * dim) vector
    """
    def __init__(self, mode: str = 'token'):
        """
        Args:
            mode (str): 'token' or 'layer'.
        Raises:
            ValueError: For any other mode.
        """
        if mode not in ('token', 'layer'):
            raise ValueError(f"Invalid mode '{mode}'. Must be 'token' or 'layer'.")
        self.mode = mode

    def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): shape (batch, seq_len, dim)
            y (torch.Tensor): shape (batch, seq_len, dim)
        Returns:
            torch.Tensor: shape (batch,), one loss value per sample.
        """
        if self.mode != 'token':
            # layer-wise: a single vector per sample
            n_samples = x.size(0)
            flat_x = x.reshape(n_samples, -1)
            flat_y = y.reshape(n_samples, -1)
            similarity = (F.normalize(flat_x, dim=-1) * F.normalize(flat_y, dim=-1)).sum(dim=-1)  # (batch,)
            return 1.0 - similarity
        # token-wise: one similarity per position, then average over positions
        similarity = (F.normalize(x, dim=-1) * F.normalize(y, dim=-1)).sum(dim=-1)  # (batch, seq_len)
        per_token_loss = 1.0 - similarity
        return per_token_loss.mean(dim=-1)  # (batch,)


class MSELoss:
    """Squared error averaged over dims 1 and 2, giving one loss value per sample."""

    def __init__(self):
        # element-wise squared error; the reduction is done in __call__
        self.loss_func = torch.nn.MSELoss(reduction='none')

    def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Sample-wise MSE loss.

        Args:
            x, y: (batch, seq_len, dim)
        Returns:
            torch.Tensor: (batch,)
        """
        squared_error = self.loss_func(x, y)
        per_sample = squared_error.mean(dim=(1, 2))
        return per_sample


class MSECosineLoss:
    """
    Sample-wise MSE loss plus a weighted cosine loss.

    ``mode`` selects the cosine term: 'token' or 'layer' (``CosineLoss``), or
    'both' (``CosineDoubleLoss``: layer-wise + token-wise).
    """
    def __init__(self, mode, cos_weight: float = 1.0):
        """
        Args:
            mode: 'token', 'layer' or 'both'.
            cos_weight: Weight of the cosine term.
        Raises:
            ValueError: If ``mode`` is not one of the above (previously this left
                ``self.cosine`` unset and failed later with an AttributeError).
        """
        if mode in ['token', 'layer']:
            cosine = CosineLoss(mode=mode)
        elif mode == 'both':
            cosine = CosineDoubleLoss()
        else:
            raise ValueError(f"Invalid mode '{mode}'. Must be 'token', 'layer' or 'both'.")
        self.cos_weight = cos_weight
        self.mse = MSELoss()
        self.cosine = cosine

    def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return ``mse(x, y) + cos_weight * cosine(x, y)``, shape (batch,)."""
        mse_loss = self.mse(x, y)
        cos_loss = self.cosine(x, y)
        return mse_loss + self.cos_weight * cos_loss

class CosineDoubleLoss:
    """Sum of the layer-wise and token-wise cosine losses, one value per sample."""

    def __init__(self):
        self.cos_layer = CosineLoss(mode='layer')
        self.cos_token = CosineLoss(mode='token')

    def __call__(self, x, y):
        """Return ``cos_layer(x, y) + cos_token(x, y)``, shape (batch,)."""
        layer_term = self.cos_layer(x, y)
        token_term = self.cos_token(x, y)
        return layer_term + token_term


def load_loss_function(config: dict):
    """
    Build the loss named by ``config['loss_func']``; MSE is used when the key is absent.

    Raises:
        ValueError: If the loss name is not recognized.
    """
    if 'loss_func' not in config:
        return MSELoss()  # default loss function

    requested = config['loss_func']
    # (name, factory) pairs; factories are lambdas so classes are only resolved when chosen
    registry = (
        ('cosine-token-wise', lambda: CosineLoss(mode='token')),
        ('cosine-layer-wise', lambda: CosineLoss(mode='layer')),
        ('mse+cosine-layer-wise', lambda: MSECosineLoss(mode='layer')),
        ('mse+cosine-token-wise', lambda: MSECosineLoss(mode='token')),
        ('cosine-token-wise+cosine-layer-wise', lambda: CosineDoubleLoss()),
        ('mse+cosine', lambda: MSECosineLoss(mode='both')),
    )
    for name, factory in registry:
        if requested == name:
            return factory()
    raise ValueError(f"Unknown loss function: {config['loss_func']}")


class LMFeatureInversionPipeline:
    """
    Optimize continuous input embeddings so that the model's hidden states at
    ``layer_idx`` match given target features.
    """
    def __init__(
            self,
            model: nn.Module,
            layer_idx: int,
            num_iterations: int,
            recon_embeds: torch.Tensor,
            optimizer: torch.optim.Optimizer,
            loss_func: Callable,
            scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
            project_embeds: bool = False,
            token_eval_metric: Optional[Callable] = None,
            embed_eval_metric: Optional[Callable] = None,
            feature_eval_metric: Optional[Callable] = None,
            use_wandb: bool = False,
            wandb_log_interval: int = 1,
            eval_interval: int = 1,
        ):
        """
        Initialize the feature inversion pipeline for a language model.

        Args:
            model: Language model; must provide ``get_input_embeddings()``.
            layer_idx: Hidden-state index with the embedding output removed
                (see ``_extract_features``).
            num_iterations: Number of optimization steps.
            recon_embeds: Embeddings being optimized, shape (batch_size, n_tokens, emb_dim).
                Presumably the tensor registered with ``optimizer``.
            optimizer: Optimizer that updates ``recon_embeds``.
            loss_func: Callable ``(target_features, recon_features) -> (batch,)`` losses.
            scheduler: Optional LR scheduler, stepped once per iteration.
            project_embeds: If True, the forward pass uses the nearest vocabulary
                embeddings (straight-through gradient to ``recon_embeds``).
            token_eval_metric: Optional callable on reconstructed token ids returning a
                dict of per-sample lists.
            embed_eval_metric: Optional callable on ``recon_embeds`` returning a dict of
                per-sample lists.
            feature_eval_metric: Optional callable on reconstructed features returning a
                dict of per-sample lists.
            use_wandb: If True, log results with ``wandb.log``.
            wandb_log_interval: Log to wandb every this many steps.
            eval_interval: Evaluate every this many steps (and always at the last step).
        """
        self.model = model
        self.layer_idx = layer_idx
        self.num_iterations = num_iterations
        self.recon_embeds = recon_embeds  # (batch_size, n_tokens, emb_dim)
        self.optimizer = optimizer
        self.loss_func = loss_func
        self.scheduler = scheduler
        self.project_embeds = project_embeds
        self.token_eval_metric = token_eval_metric
        self.embed_eval_metric = embed_eval_metric
        self.feature_eval_metric = feature_eval_metric
        self.use_wandb = use_wandb
        self.wandb_log_interval = wandb_log_interval
        self.eval_interval = eval_interval
        self.embed_weight = model.get_input_embeddings().weight  # (vocab_size, emb_dim)
        # Standardized vocabulary embeddings. NOTE(review): not read anywhere in this
        # class (embed_to_tokens standardizes its own copy); kept for compatibility.
        self.embed_weight_normalized = (self.embed_weight - self.embed_weight.mean(dim=-1, keepdim=True)) / (self.embed_weight.std(dim=-1, keepdim=True) + EPSILON)

    def __call__(
            self,
            target_features: torch.Tensor,
            attention_mask: torch.Tensor,
            wandb_names: Optional[list[str]] = None,
        ):
        """
        Run ``num_iterations`` optimization steps.

        Args:
            target_features: Features to match, (batch, n_tokens, hidden_size).
            attention_mask: Attention mask passed to the model, (batch, n_tokens).
            wandb_names: Per-sample name prefixes for wandb logging.
        Returns:
            list[dict[str, list]]: One dict per step. Each has per-sample 'loss'; the
            evaluation metrics are added every ``eval_interval`` steps and at the last step.
        """
        history = []  # list[dict[str, list[float]]] # store evaluation metrics for each iteration
        pbar = tqdm.tqdm(range(self.num_iterations), total=self.num_iterations, dynamic_ncols=True)
        for step in pbar:
            self.optimizer.zero_grad()
            # recon_embeds: shape (batch_size, n_tokens, emb_dim)
            if self.project_embeds:
                # Straight-through estimator: the forward pass sees the nearest token
                # embeddings, the backward pass goes to the standardized recon_embeds.
                with torch.no_grad():
                    # argmax is not differentiable; avoid recording a graph over the
                    # (batch_size, n_tokens, vocab_size) similarity matrix
                    recon_tokens = embed_to_tokens(self.recon_embeds, self.embed_weight)  # (batch_size, n_tokens)
                proj_embeds = self.embed_weight[recon_tokens]  # (batch_size, n_tokens, emb_dim)
                recon_embeds_normalized = (self.recon_embeds - self.recon_embeds.mean(dim=-1, keepdim=True)) / (self.recon_embeds.std(dim=-1, keepdim=True) + EPSILON)
                input_embeds = (proj_embeds - recon_embeds_normalized).detach() + recon_embeds_normalized  # (batch_size, n_tokens, emb_dim)
            else:
                input_embeds = self.recon_embeds  # (batch_size, n_tokens, emb_dim)

            recon_features = _extract_features(self.model, input_embeds=input_embeds, attention_mask=attention_mask, layer_idx=self.layer_idx)
            loss = self.loss_func(target_features, recon_features)  # (batch,)
            loss.sum().backward()  # sum, not mean: keeps per-sample gradients independent of batch size
            self.optimizer.step()
            if self.scheduler is not None:
                self.scheduler.step()

            # calculate and store metrics
            results = {'loss': loss.tolist()}
            if step % self.eval_interval == 0 or step == self.num_iterations - 1:
                # NOTE(review): recon_embeds has already been updated by optimizer.step(),
                # while recon_features / loss come from before the step.
                results.update(self.evaluate(self.recon_embeds, recon_features))
            history.append(results)

            if self.use_wandb and step % self.wandb_log_interval == 0:
                self.log_wandb(step, results, wandb_names)
        return history

    def evaluate(self, recon_embeds, recon_features):
        """Run the configured token / embedding / feature metrics without gradients and merge their dicts."""
        eval_results = {}
        with torch.no_grad():
            if self.token_eval_metric is not None:
                recon_tokens = embed_to_tokens(recon_embeds, self.embed_weight)  # (batch_size, n_tokens)
                eval_results.update(self.token_eval_metric(recon_tokens))
            if self.embed_eval_metric is not None:
                eval_results.update(self.embed_eval_metric(recon_embeds))
            if self.feature_eval_metric is not None:
                eval_results.update(self.feature_eval_metric(recon_features))
        return eval_results

    def log_wandb(self, step: int, results: dict, wandb_names: Optional[list[str]] = None):
        """
        Log ``results`` to wandb grouped by sample: ``{name: {metric: value}}``.

        Every value in ``results`` must be a per-sample list indexed like ``wandb_names``.
        """
        if wandb_names is None:
            wandb_names = [f"stimulus_{i}" for i in range(len(results['loss']))]
        # aggregate the results by stimulus
        log = {}  # name -> metric -> value
        for i, name in enumerate(wandb_names):
            log[name] = {}
            for metric_name, values in results.items():
                log[name][metric_name] = values[i]
        wandb.log(log, step=step)


def embed_to_tokens(embeds: torch.Tensor, embed_weight: torch.Tensor) -> torch.Tensor:
    """
    Map every embedding vector to the id of its most similar vocabulary embedding.

    Similarity is the dot product after standardizing each vector along the last
    dimension (subtract the mean, divide by std + EPSILON).

    Args:
        embeds (torch.Tensor): shape (batch_size, n_tokens, emb_dim)
        embed_weight (torch.Tensor): shape (vocab_size, emb_dim)
    Returns:
        torch.Tensor: shape (batch_size, n_tokens), indices of the closest tokens.
    """
    def _standardize(t: torch.Tensor) -> torch.Tensor:
        centered = t - t.mean(dim=-1, keepdim=True)
        return centered / (t.std(dim=-1, keepdim=True) + EPSILON)

    query = _standardize(embeds)
    vocab = _standardize(embed_weight)
    similarity = torch.matmul(query, vocab.T)  # (batch_size, n_tokens, vocab_size)
    return similarity.argmax(dim=-1)


class TokenMetrics:
    """
    Compare reconstructed token ids with the true tokens, sample by sample.

    Accuracy is reported over non-padding tokens and over "normal" tokens
    (non-padding, non-special) only.

    Keys returned:
        - recon_tokens
        - top_1_accuracy / top_1_accuracy_normal_tokens
        - top_1_error_rate / top_1_error_rate_normal_tokens
    """
    def __init__(self, true_tokens, attention_mask, normal_token_mask):
        self.true_tokens = true_tokens
        self.attention_mask = attention_mask
        self.normal_token_mask = normal_token_mask

        # token counts per sample (denominators of the accuracies)
        self.num_non_padding_tokens = attention_mask.sum(dim=-1)  # (batch_size,)
        self.num_normal_tokens = normal_token_mask.sum(dim=-1)  # (batch_size,)

    def __call__(self, recon_tokens) -> dict[str, list[float]]:
        """
        Args:
            recon_tokens: (batch_size, n_tokens) reconstructed token ids.
        Returns:
            dict[str, list[float]]: the i-th element of each list belongs to the i-th sample.
        """
        matches = recon_tokens == self.true_tokens  # also true on matching padding positions
        hits_all = (matches & self.attention_mask.bool()).sum(dim=-1)  # (batch_size,)
        hits_normal = (matches & self.normal_token_mask.bool()).sum(dim=-1)  # (batch_size,)
        accuracy_all = (hits_all / self.num_non_padding_tokens).tolist()
        accuracy_normal = (hits_normal / self.num_normal_tokens).tolist()
        return {
            'recon_tokens': recon_tokens.tolist(),
            'top_1_accuracy': accuracy_all,
            'top_1_accuracy_normal_tokens': accuracy_normal,
            'top_1_error_rate': [1 - acc for acc in accuracy_all],
            'top_1_error_rate_normal_tokens': [1 - acc for acc in accuracy_normal],
        }


class EmbedMetrics:
    """
    Per-sample distances between reconstructed and true input embeddings,
    computed on non-padding positions only.

    Keys returned:
        - embed_cosine_distance
        - embed_correlation_distance
        - embed_l2_distance
    """
    def __init__(self, true_embeddings, attention_mask):
        """
        Args:
            true_embeddings: (batch_size, n_tokens, emb_dim)
            attention_mask: (batch_size, n_tokens); nonzero marks non-padding tokens.
        """
        self.true_embeddings = true_embeddings
        self.attention_mask = attention_mask.bool()

    def __call__(self, recon_embeddings) -> dict[str, list[float]]:
        """
        Args:
            recon_embeddings: (batch_size, n_tokens, emb_dim)
        Returns:
            dict[str, list[float]]: the i-th element of each list belongs to the i-th sample.
        """
        metrics = {
            'embed_cosine_distance': [],
            'embed_correlation_distance': [],
            'embed_l2_distance': [],
        }
        distance_fns = (
            ('embed_cosine_distance', cosine_distance),
            ('embed_correlation_distance', correlation_distance),
            ('embed_l2_distance', l2_distance),
        )
        # one sample at a time, because each sample has its own padding mask
        for idx in range(recon_embeddings.size(0)):
            keep = self.attention_mask[idx]  # (n_tokens,)
            true_vec = self.true_embeddings[idx][keep].unsqueeze(0)  # (1, n_non_pad_tokens, emb_dim)
            recon_vec = recon_embeddings[idx][keep].unsqueeze(0)
            for key, fn in distance_fns:
                metrics[key].append(fn(true_vec, recon_vec).item())
        return metrics


class FeatureMetrics:
    """
    Per-sample distances of the reconstructed features to the true features and
    to the (noisy) target features, computed on non-padding positions only.

    Keys returned:
        - true_recon_cosine_distance / true_recon_correlation_distance / true_recon_l2_distance
        - target_recon_cosine_distance / target_recon_correlation_distance / target_recon_l2_distance
    """
    def __init__(self, true_features: torch.Tensor, target_features: torch.Tensor, attention_mask: torch.Tensor):
        self.true_features = true_features
        self.target_features = target_features
        self.attention_mask = attention_mask.bool()

    def __call__(self, recon_features: torch.Tensor) -> dict[str, list[float]]:
        """
        Args:
            recon_features: (batch_size, n_tokens, hidden_size)
        Returns:
            dict[str, list[float]]: the i-th element of each list belongs to the i-th sample.
        """
        true_keys = (
            "true_recon_cosine_distance",
            "true_recon_correlation_distance",
            "true_recon_l2_distance",
        )
        target_keys = (
            "target_recon_cosine_distance",
            "target_recon_correlation_distance",
            "target_recon_l2_distance",
        )
        metrics = {key: [] for key in true_keys + target_keys}
        distance_fns = (cosine_distance, correlation_distance, l2_distance)

        for idx in range(recon_features.size(0)):
            keep = self.attention_mask[idx]
            recon = recon_features[idx][keep].unsqueeze(0)
            references = (
                (self.true_features[idx][keep].unsqueeze(0), true_keys),
                (self.target_features[idx][keep].unsqueeze(0), target_keys),
            )
            for reference, keys in references:
                for key, fn in zip(keys, distance_fns):
                    metrics[key].append(fn(reference, recon).item())
        return metrics
        

def resolve_wandb(config: dict[str, Any], parameters: list[dict[str, Any]]):
    """
    Start a wandb run when ``config['wandb']`` is truthy.

    Args:
        config (dict[str, Any]): Configuration dictionary.
        parameters (list[dict[str, Any]]): Parameters of each sample in the batch.
    Returns:
        use_wandb (bool): Whether wandb logging is enabled.
        prefix (list[str] | None): One log-name prefix per sample, or None when disabled.
    """
    if not config.get('wandb', False):
        return False, None

    model_tag = config['model']['pretrained'].replace('/', '_')
    project = 'language_readout_' + model_tag + '_' + config['data']['dataset_name']
    # samples in a batch share the layer, so the first sample names the run
    run_name = f'{config["exp_name"]}_layer_{parameters[0]["layer_idx"]}'
    wandb.init(project=project, name=run_name, config=config)

    sample_prefixes = []
    for p in parameters:
        sample_prefixes.append(p['text_name'] + f'_distance{p["target_corr_dist"]}' + f'_seed{p["noise_seed"]}')
    return True, sample_prefixes


class FlatListEncoder(json.JSONEncoder):
    """
    JSON encoder that writes a top-level dict one key per line, keeping each
    list value on a single line.

    Non-dict objects fall back to the default ``json.JSONEncoder`` behaviour.
    """
    def iterencode(self, obj, _one_shot=False):
        """Encode the given object and yield each string representation as available."""
        if not isinstance(obj, dict):
            # Fallback to default behavior
            yield from super().iterencode(obj, _one_shot)
            return

        yield '{\n'
        items = []
        for key, value in obj.items():
            if isinstance(key, str):
                key_str = json.dumps(key)
            elif key is None or isinstance(key, (bool, int, float)):
                # JSON keys must be strings: quote the JSON text of the key, as the
                # stdlib encoder does (1 -> "1", None -> "null", True -> "true").
                # Previously such keys were written unquoted, producing invalid JSON.
                key_str = json.dumps(json.dumps(key))
            else:
                raise TypeError(f'keys must be str, int, float, bool or None, not {type(key).__name__}')
            if isinstance(value, list):
                value_str = '[' + ', '.join(json.dumps(item) for item in value) + ']'
            else:
                value_str = json.dumps(value)
            items.append(f'    {key_str}: {value_str}')
        yield ',\n'.join(items)
        yield '\n}'


def save_results(
    output_dir: str,
    parameters: list[dict[str, Any]],
    info: dict[str, list],
    noise_info: dict[str, list],
    history: list[dict[str, list[float]]],
    recon_embeds: Optional[torch.Tensor] = None
):
    """
    Save per-run history, summary and (optionally) reconstructed embeddings.

    Each run is written to
    ``<output_dir>/layer_<l>/corr_dist_<d>/noise_seed_<s>/text_<name>/`` as
    ``history.csv``, ``summary.json`` and ``embeds.pt``.

    Args:
        output_dir (str): Base output directory of the experiment.
        parameters (list[dict[str, Any]]): Parameters of each run; not modified.
        info (dict[str, list]): Per-sample tokenization info (from ``get_tokens``).
        noise_info (dict[str, list]): Per-sample noise info (from ``get_features``).
        history (list[dict[str, list[float]]]): One dict per step with per-sample values.
        recon_embeds (Optional[torch.Tensor]): (batch, n_tokens, emb_dim) embeddings to save.
    Returns:
        list[dict]: One summary per run (parameters + info + noise info + last history row).
    """
    summaries = []
    for i, params in enumerate(parameters):
        # full history for this run
        hist = [
            {name: value[i] for name, value in h.items()} for h in history
        ]
        # short summary: parameters + info + noise info + final metrics.
        # Copy so the caller's parameter dicts are not mutated (main passes them on
        # to the experiment database afterwards).
        summary = dict(params)
        summary.update({name: value[i] for name, value in info.items()})
        summary.update({name: value[i] for name, value in noise_info.items()})
        summary.update(hist[-1])
        summaries.append(summary)

        # convert history to DataFrame
        hist = pd.DataFrame(hist)

        run_path = os.path.join(
            output_dir,
            f'layer_{summary["layer_idx"]}',
            f'corr_dist_{summary["target_corr_dist"]}',
            f'noise_seed_{summary["noise_seed"]}',
            f'text_{summary["text_name"]}'
        )
        hist_path = os.path.join(run_path, 'history.csv')
        sum_path = os.path.join(run_path, 'summary.json')
        os.makedirs(run_path, exist_ok=True)

        # save history as csv
        hist.to_csv(hist_path, index=False, escapechar='\\')
        with open(sum_path, 'w') as f:
            json.dump(summary, f, cls=FlatListEncoder)

        # save recon embeddings
        if recon_embeds is not None:
            embed_path = os.path.join(run_path, 'embeds.pt')
            # recon_embeds[i] is a view: torch.save would store the whole batch's storage
            # (and the autograd flag), so save a detached standalone copy instead.
            torch.save(recon_embeds[i].detach().clone(), embed_path)  # (n_tokens, emb_dim)
    return summaries


def main(config, device):
    """
    Run every pending reconstruction experiment described by ``config``.

    For each layer, batches of experiments are pulled from the JSON experiment
    database. Each batch is optimized and saved, then marked 'finished'. On error it
    is marked 'error'; on Ctrl-C it is reverted to 'pending' and the interrupt re-raised.

    Args:
        config: Parsed YAML configuration.
        device: Torch device string, e.g. 'cuda' or 'cpu'.
    """
    # load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(config['model']['pretrained'])
    tokenizer = AutoTokenizer.from_pretrained(config['model']['pretrained'])
    if tokenizer.pad_token is None:
        # set pad token to eos token if not set
        tokenizer.pad_token = tokenizer.eos_token
    model = model.eval()
    model = model.to(device)
    # Only the reconstructed embeddings are optimized. Freezing the weights stops
    # backward() from computing, and accumulating (nothing ever zeroes them),
    # gradients for every model parameter.
    model.requires_grad_(False)

    # output directory
    output_dir = parse_output_dir(config)  # base output directory of all experiments under this config
    os.makedirs(output_dir, exist_ok=True)
    # save the config to the output directory
    with open(os.path.join(output_dir, 'config.yaml'), 'w') as f:
        yaml.dump(config, f)

    # initialize the experiment database
    exp_db = initialize_experiment_db(config, output_dir)

    # run experiments: different layers are always separated for simplicity
    for layer_idx in config['layer_indices']:
        while True:
            parameters = exp_db.get_next_experiments(config['batch_size'], layer_idx=layer_idx)
            if not parameters:
                print(f"No more parameters to process for layer {layer_idx}.")
                break

            # Bound before `try` so `finally` can read it even if an error happens before
            # wandb is configured (previously an UnboundLocalError on the first batch).
            use_wandb = False
            try:
                # get original tokens and attention mask, true_token_mask
                true_tokens, attention_mask, normal_token_mask, info = get_tokens(tokenizer, config, parameters, config['max_length'], device)
                # get true_embeddings, true_features, target_features
                true_embeddings, true_features, target_features, noise_info = get_features(parameters, model, layer_idx, true_tokens, attention_mask, device)
                # wandb setup
                use_wandb, wandb_names = resolve_wandb(config, parameters)

                # randomly initialized embeddings to optimize: (batch, n_tokens, emb_dim)
                batch_size, n_tokens = true_tokens.shape
                emb_dim = model.get_input_embeddings().weight.shape[1]
                recon_embeds = torch.randn(batch_size, n_tokens, emb_dim, requires_grad=True, device=device)
                optimizer = AdamW([recon_embeds], lr=config['optimizer']['lr'])
                if 'scheduler' in config['optimizer']:
                    scheduler = optim.lr_scheduler.LinearLR(optimizer, **config['optimizer']['scheduler'], total_iters=config['pipeline']['num_iterations'])
                else:
                    scheduler = None
                loss_func = load_loss_function(config)

                # build and run the inversion pipeline
                pipeline = LMFeatureInversionPipeline(
                    model=model,
                    layer_idx=layer_idx,
                    recon_embeds=recon_embeds,
                    optimizer=optimizer,
                    loss_func=loss_func,
                    scheduler=scheduler,
                    token_eval_metric=TokenMetrics(true_tokens, attention_mask, normal_token_mask),
                    embed_eval_metric=EmbedMetrics(true_embeddings, attention_mask),
                    feature_eval_metric=FeatureMetrics(true_features, target_features, attention_mask),
                    use_wandb=use_wandb,
                    **config['pipeline']
                )
                history = pipeline(
                    target_features=target_features,
                    attention_mask=attention_mask,
                    wandb_names=wandb_names
                )
                # save the results
                summaries = save_results(output_dir, parameters, info, noise_info, history, recon_embeds=recon_embeds)

                # update experiment database
                exp_db.update_experiment_info(summaries)
                exp_db.mark_experiment_status(parameters, 'finished')

            except KeyboardInterrupt:
                print("Interrupted. Reverting status to 'pending'.")
                exp_db.mark_experiment_status(parameters, 'pending')
                raise
            except Exception as e:
                print(f'Error: {e}. Marking status as "error".')
                traceback.print_exc()
                exp_db.mark_experiment_status(parameters, 'error')
            finally:
                # runs after the status update on every path, including Ctrl-C
                if use_wandb:
                    wandb.finish()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Reconstruct texts from the original and perturbed features.")
    parser.add_argument("config_path", type=str, help="Path to the configuration file.")
    parser.add_argument("--device", type=str, default='cuda')
    parser.add_argument("--find_max_batch_size", action='store_true', help="Find the maximum batch size that fits on the device.")
    args = parser.parse_args()

    # No find_max_batch_size() is defined or imported in this file. Report that as a
    # CLI error up front instead of crashing with a NameError after loading the config.
    max_batch_size_finder = globals().get('find_max_batch_size')
    if args.find_max_batch_size and max_batch_size_finder is None:
        parser.error("--find_max_batch_size is not available: find_max_batch_size() is not defined in this script.")

    with open(args.config_path, "r") as f:
        config = yaml.safe_load(f)

    if 'text_names' not in config['data']:
        # use text_names_path
        with open(config['data']['text_names_path'], 'r') as f:
            config['data']['text_names'] = yaml.safe_load(f)

    if args.find_max_batch_size:
        max_batch_size_finder(config, args.device)
    else:
        # normal reconstruction
        main(config, device=args.device)