# MIT License

# Copyright (c) 2025 bartbussmann

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.



import accelerate
import wandb
import torch
from functools import partial
import os
import json
import transformer_lens.utils as utils
from dataclasses import asdict

from activation_store import ActivationsStore
from distibuted_activation_store import DistributedActivationsStore
from sae import TrainingConfig, SAEConfig


def init_wandb(cfg):
    """Start a W&B run described by ``cfg`` and return the run object.

    ``cfg`` is indexed like a mapping: ``cfg["wandb_project"]`` selects the project,
    ``cfg["name"]`` names the run, and the whole ``cfg`` is recorded as the run config.
    The run is created with ``reinit=True``.
    """
    run_kwargs = dict(
        project=cfg["wandb_project"],
        name=cfg["name"],
        config=cfg,
        reinit=True,
    )
    return wandb.init(**run_kwargs)

def log_wandb(output, step, wandb_run, index=None, model=None):
    """Log one batch of SAE training metrics to a W&B run.

    Args:
        output: mapping of metric name to value (tensor or plain number). It must
            contain ``"feature_acts"``. Of the known metric names below, only those
            present in ``output`` are logged.
        step: step value passed to ``wandb_run.log``.
        wandb_run: object exposing ``log(dict, step=...)``.
        index: if given, every logged key gets an ``_{index}`` suffix.
        model: if given, ``model.cfg['topk2']`` is logged as ``batch_topk2``.
    """
    tracked_metrics = (
        "loss", "l2_loss", "l1_loss", "l0_norm", "l1_norm", "aux_loss",
        "num_dead_features", "explained_variance", "topk2",
    )
    metrics = {}
    for name in tracked_metrics:
        if name not in output:
            continue
        value = output[name]
        # Anything exposing .item() (e.g. 0-d tensors) is turned into a Python scalar.
        metrics[name] = value.item() if hasattr(value, 'item') else value

    # Number of features whose activations sum to zero over dim 0
    # (presumably the batch dim of a (batch, n_features) tensor — TODO confirm).
    feature_acts = output["feature_acts"]
    metrics["n_dead_in_batch"] = feature_acts.sum(0).eq(0).sum().item()

    if model is not None:
        metrics['batch_topk2'] = model.cfg['topk2']

    if index is not None:
        metrics = {f"{name}_{index}": value for name, value in metrics.items()}

    wandb_run.log(metrics, step=step)

# Hooks for model performance evaluation
def reconstr_hook(activation, hook, sae_out):
    """Forward hook that splices an SAE reconstruction into the activation.

    Overwrites ``activation[:, 1:]`` in place with ``sae_out`` and returns the
    same tensor. Index 0 along dim 1 is left untouched (presumably the BOS
    position, which the activation store skips — confirm). ``hook`` is unused.
    """
    after_first = slice(1, None)
    activation[:, after_first] = sae_out
    return activation

def zero_abl_hook(activation, hook):
    """Forward hook that zero-ablates the activation.

    Returns a new all-zeros tensor with ``activation``'s shape, dtype and device.
    ``hook`` is unused.
    """
    return torch.zeros_like(activation, memory_format=torch.preserve_format)

def mean_abl_hook(activation, hook):
    """Forward hook that mean-ablates the activation.

    Averages over dims 0 and 1 and broadcasts the result back to the full shape.
    The result is an expanded view, not a fresh copy. ``hook`` is unused.
    """
    per_feature_mean = activation.mean(dim=(0, 1))
    return per_feature_mean.expand(activation.shape)

def get_model_performance(
    model, 
    config: TrainingConfig, 
    activations_store: ActivationsStore, 
    sae, 
    index=None,
    batch_tokens=None):
    """Measure how well the SAE reconstruction preserves the model's LM loss.

    Runs the model four times on one token batch:
    - unmodified;
    - with the hooked activation (positions 1 onward) replaced by the SAE reconstruction;
    - zero-ablated;
    - mean-ablated.

    Args:
        model: model supporting ``model(tokens, return_type="loss")`` and
            ``run_with_hooks`` (presumably a TransformerLens ``HookedTransformer``).
        config: training config. Reads ``batch_size``, ``seq_len``, ``hook_point``,
            ``act_size`` and ``layer``.
        activations_store: provides tokens (``get_batch_tokens``) and the
            activations fed to the SAE (``get_activations``).
        sae: callable returning a dict with key ``"sae_out"``. It is called with one
            tensor, or with two when ``config.hook_point`` is a tuple.
        index: optional suffix; each metric key becomes ``f"{key}_{index}"``.
        batch_tokens: tokens to evaluate on. Fetched from the store when None.

    Returns:
        dict with:
        - ``performance/ce_degradation``: original minus reconstructed loss, so it is
          negative when reconstruction hurts.
        - ``performance/recovery_from_zero`` and ``performance/recovery_from_mean``:
          the fraction of the ablation-vs-original loss gap recovered by the
          reconstruction. NaN if that gap is exactly zero.
    """
    if batch_tokens is None:
        batch_tokens = activations_store.get_batch_tokens()
        if config.batch_size > config.seq_len:
            # Keep only as many sequences as make up one batch of activations.
            batch_tokens = batch_tokens[:config.batch_size // config.seq_len]

    # Evaluation only. Without no_grad the SAE forward would build an autograd
    # graph that is never used but stays alive as long as sae_output does.
    with torch.no_grad():
        batch = activations_store.get_activations(batch_tokens)
        # The reconstruction is reshaped to (n_seqs, seq_len - 1, -1). The store is
        # assumed to return activations for positions 1 onward, which is the span
        # reconstr_hook overwrites.
        out_shape = (batch_tokens.shape[0], batch_tokens.shape[1] - 1, -1)
        if isinstance(config.hook_point, tuple):
            # Paired hook points: the SAE receives both activation tensors, and the
            # reconstruction is spliced into the second hook point.
            first_acts = batch[0].reshape(-1, config.act_size)
            second_acts = batch[1].reshape(-1, config.act_size)
            sae_output = sae(first_acts, second_acts)["sae_out"].reshape(out_shape)
            hp = config.hook_point[1]
        else:
            sae_output = sae(batch.reshape(-1, config.act_size))["sae_out"].reshape(out_shape)
            hp = config.hook_point
        hook_name = utils.get_act_name(hp, config.layer)

        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
            original_loss = model(batch_tokens, return_type="loss").item()
            reconstr_loss = model.run_with_hooks(
                batch_tokens,
                fwd_hooks=[(hook_name, partial(reconstr_hook, sae_out=sae_output))],
                return_type="loss",
            ).item()
            zero_loss = model.run_with_hooks(
                batch_tokens,
                fwd_hooks=[(hook_name, zero_abl_hook)],
                return_type="loss",
            ).item()
            mean_loss = model.run_with_hooks(
                batch_tokens,
                fwd_hooks=[(hook_name, mean_abl_hook)],
                return_type="loss",
            ).item()

    ce_degradation = original_loss - reconstr_loss
    zero_degradation = original_loss - zero_loss
    mean_degradation = original_loss - mean_loss

    # (reconstr - ablated) / (original - ablated): 1.0 when reconstruction matches
    # the original, 0.0 when it is no better than ablation. A zero denominator
    # would raise ZeroDivisionError on Python floats, so report NaN instead.
    nan = float("nan")
    log_dict = {
        "performance/ce_degradation": ce_degradation,
        "performance/recovery_from_zero": (
            (reconstr_loss - zero_loss) / zero_degradation if zero_degradation else nan
        ),
        "performance/recovery_from_mean": (
            (reconstr_loss - mean_loss) / mean_degradation if mean_degradation else nan
        ),
    }

    if index is not None:
        log_dict = {f"{k}_{index}": v for k, v in log_dict.items()}
    
    return log_dict



def get_distributed_model_performance(
    model, 
    config: TrainingConfig, 
    activations_store: DistributedActivationsStore, 
    iterator_dl,
    sae, 
    index=None,
    batch_tokens=None,
    ):
    """Distributed-store variant of the SAE loss-recovery evaluation.

    Tokens come from ``activations_store.get_batch_tokens(model, iterator_dl)`` and
    activations from ``activations_store.get_activations(batch_tokens, model)``.
    The model is then run four times:
    - unmodified;
    - with the hooked activation (positions 1 onward) replaced by the SAE reconstruction;
    - zero-ablated;
    - mean-ablated.

    Args:
        model: model supporting ``model(tokens, return_type="loss")`` and
            ``run_with_hooks`` (presumably a TransformerLens ``HookedTransformer``).
        config: training config. Reads ``batch_size``, ``seq_len``, ``hook_point``,
            ``act_size`` and ``layer``.
        activations_store: distributed activation store.
        iterator_dl: data iterator forwarded to ``get_batch_tokens``.
        sae: callable returning a dict with key ``"sae_out"``. It is called with one
            tensor, or with two when ``config.hook_point`` is a tuple.
        index: optional suffix; each metric key becomes ``f"{key}_{index}"``.
        batch_tokens: tokens to evaluate on. Fetched from the store when None.

    Returns:
        dict with:
        - ``performance/ce_degradation``: original minus reconstructed loss, so it is
          negative when reconstruction hurts.
        - ``performance/recovery_from_zero`` and ``performance/recovery_from_mean``:
          the fraction of the ablation-vs-original loss gap recovered. NaN if that
          gap is exactly zero.
    """
    if batch_tokens is None:
        batch_tokens = activations_store.get_batch_tokens(model, iterator_dl)
        if config.batch_size > config.seq_len:
            # Keep only as many sequences as make up one batch of activations.
            batch_tokens = batch_tokens[:config.batch_size // config.seq_len]

    # Evaluation only. Nothing visible here disables grad, so without no_grad the
    # SAE forward would build an unused autograd graph. If sae is wrapped for
    # distributed training, that is a forward pass with no matching backward.
    with torch.no_grad():
        batch = activations_store.get_activations(batch_tokens, model)
        # Reconstruction is reshaped to (n_seqs, seq_len - 1, -1). The store is
        # assumed to return activations for positions 1 onward, the span
        # reconstr_hook overwrites.
        out_shape = (batch_tokens.shape[0], batch_tokens.shape[1] - 1, -1)
        if isinstance(config.hook_point, tuple):
            # Paired hook points: the SAE receives both activation tensors, and the
            # reconstruction is spliced into the second hook point.
            first_acts = batch[0].reshape(-1, config.act_size)
            second_acts = batch[1].reshape(-1, config.act_size)
            sae_output = sae(first_acts, second_acts)["sae_out"].reshape(out_shape)
            hp = config.hook_point[1]
        else:
            sae_output = sae(batch.reshape(-1, config.act_size))["sae_out"].reshape(out_shape)
            hp = config.hook_point
        hook_name = utils.get_act_name(hp, config.layer)

        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
            original_loss = model(batch_tokens, return_type="loss").item()
            reconstr_loss = model.run_with_hooks(
                batch_tokens,
                fwd_hooks=[(hook_name, partial(reconstr_hook, sae_out=sae_output))],
                return_type="loss",
            ).item()
            zero_loss = model.run_with_hooks(
                batch_tokens,
                fwd_hooks=[(hook_name, zero_abl_hook)],
                return_type="loss",
            ).item()
            mean_loss = model.run_with_hooks(
                batch_tokens,
                fwd_hooks=[(hook_name, mean_abl_hook)],
                return_type="loss",
            ).item()

    ce_degradation = original_loss - reconstr_loss
    zero_degradation = original_loss - zero_loss
    mean_degradation = original_loss - mean_loss

    # (reconstr - ablated) / (original - ablated): 1.0 when reconstruction matches
    # the original, 0.0 when it is no better than ablation. A zero denominator
    # would raise ZeroDivisionError on Python floats, so report NaN instead.
    nan = float("nan")
    log_dict = {
        "performance/ce_degradation": ce_degradation,
        "performance/recovery_from_zero": (
            (reconstr_loss - zero_loss) / zero_degradation if zero_degradation else nan
        ),
        "performance/recovery_from_mean": (
            (reconstr_loss - mean_loss) / mean_degradation if mean_degradation else nan
        ),
    }

    if index is not None:
        log_dict = {f"{k}_{index}": v for k, v in log_dict.items()}
    
    return log_dict



@torch.no_grad()
def log_model_performance(
    wandb_run,
    step,
    model,
    config: TrainingConfig,
    activations_store,
    sae,
    index=None,
    batch_tokens=None):
    """Compute loss-recovery metrics via ``get_model_performance`` and log them.

    The metrics are always computed. They are sent to ``wandb_run.log`` at
    ``step`` only when ``config.enable_wandb`` is truthy. Returns None.
    The whole call runs under ``torch.no_grad()``.
    """
    metrics = get_model_performance(
        model,
        config,
        activations_store,
        sae,
        index=index,
        batch_tokens=batch_tokens,
    )

    if not config.enable_wandb:
        return
    wandb_run.log(metrics, step=step)

def save_checkpoint(wandb_run, sae, cfg: TrainingConfig, step: int, base_dir: str = "/workspace/sae"):
    """Save SAE weights and a JSON copy of the config, optionally uploading both to W&B.

    Writes ``sae.pt`` (the SAE ``state_dict``) and ``config.json`` into
    ``{base_dir}/checkpoints_<model>_<hook tag>_<sae_name>_<step>``. When
    ``cfg.enable_wandb`` is truthy, both files are also logged to ``wandb_run``
    as an artifact of type ``"model"``.

    Args:
        wandb_run: W&B run used for ``log_artifact``. Only touched when
            ``cfg.enable_wandb`` is truthy.
        sae: module whose ``state_dict()`` is saved.
        cfg: training config dataclass.
        step: training step, used in the directory name and artifact description.
        base_dir: root directory for checkpoints (defaults to the previously
            hard-coded location).

    Returns:
        The checkpoint directory path.
    """
    # get_act_name takes a single hook name. For paired (tuple) hook points, use
    # the second one, matching how hook points are resolved for evaluation.
    hook_point = cfg.hook_point[1] if isinstance(cfg.hook_point, tuple) else cfg.hook_point
    act_tag = utils.get_act_name(hook_point, cfg.layer)[:10]
    model_tag = cfg.model_name.replace('/', '_')

    save_dir = os.path.join(base_dir, f"checkpoints_{model_tag}_{act_tag}_{cfg.sae_name}_{step}")
    os.makedirs(save_dir, exist_ok=True)

    # Save model state
    sae_path = os.path.join(save_dir, "sae.pt")
    torch.save(sae.state_dict(), sae_path)

    # JSON primitives pass through unchanged. Everything else (torch dtypes,
    # tuples, nested values, ...) is stored as its str() form.
    json_safe_cfg = {
        key: value if isinstance(value, (int, float, str, bool, type(None))) else str(value)
        for key, value in asdict(cfg).items()
    }

    # Save config
    config_path = os.path.join(save_dir, "config.json")
    with open(config_path, "w") as f:
        json.dump(json_safe_cfg, f, indent=4)

    if cfg.enable_wandb:
        # Only build and upload the artifact when W&B logging is enabled.
        artifact = wandb.Artifact(
            name=f"{cfg.sae_name.replace('/', '-')}_{model_tag}_{act_tag}",
            type="model",
            description=f"Model checkpoint at step {step}",
        )
        artifact.add_file(sae_path)
        artifact.add_file(config_path)
        wandb_run.log_artifact(artifact)
        print(f"Model and config saved as artifact at step {step}")
    else:
        print(f"Model and config saved to {save_dir} at step {step}")

    return save_dir

def save_checkpoint_accelerator(accelerator: accelerate.Accelerator, cfg: TrainingConfig, step: int):
    """Save the Accelerate training state to a step-specific directory.

    The directory is ``checkpoints_<model>_<hook tag>_<sae_name>_<step>``, relative
    to the current working directory. The actual saving is delegated to
    ``accelerator.save_state``.
    """
    # get_act_name takes a single hook name. For paired (tuple) hook points, use
    # the second one, matching how hook points are resolved for evaluation.
    hook_point = cfg.hook_point[1] if isinstance(cfg.hook_point, tuple) else cfg.hook_point
    act_tag = utils.get_act_name(hook_point, cfg.layer)[:10]
    save_dir = f"checkpoints_{cfg.model_name.replace('/', '_')}_{act_tag}_{cfg.sae_name}_{step}"

    accelerator.save_state(save_dir)

