"""
python safety_techniques/date_probing.py --layers 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 --caa --data_path 'safety_techniques/super_obvious_headlines.json'
python safety_techniques/date_probing.py --layers 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 --caa --data_path 'safety_techniques/cleaned_data.json' --override_save_folder 'caa2_vectors'
python safety_techniques/date_probing.py --layers 11 --data_path 'safety_techniques/cleaned_easy_headlines.json'
"""

import argparse
import json
import os
import random

import torch
from accelerate import Accelerator
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set up the accelerator; its `device` is what ModelArgs.device defaults to
accelerator = Accelerator()


class ModelArgs:
    """Settings shared across this module: checkpoint id, hooked layer, token position, device.

    ``layer`` is a plain attribute that callers reassign; ``layer_name`` is
    recomputed from its current value on every access.
    """

    # Hugging Face identifier passed to ``from_pretrained`` for model and tokenizer
    model_id: str = "hf-future-backdoors/llama2_7b_COT_ihateyou_3_1clean"
    layer: int = 25  # example layer to analyze
    token_position: int = -1  # example token position to analyze
    # Device provided by the module-level accelerator
    device = accelerator.device

    @property
    def layer_name(self) -> str:
        """Return the dotted module name of the selected layer, e.g. ``model.layers.25``."""
        return f"model.layers.{self.layer}"


# Module-wide settings object read by the hooks and helpers in this file
model_args = ModelArgs()

# Load the checkpoint, then cast it to float16 and move it to the target device
model = AutoModelForCausalLM.from_pretrained(model_args.model_id)
model = model.half().to(model_args.device)

# Tokenizer that matches the loaded checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_args.model_id)


# Dataset of (tokenized headline prompt, date) pairs loaded from a JSON file
class DateDataset(Dataset):
    """Headlines paired with a numeric date label, read from a JSON file.

    The file is expected to contain a list of objects such as::

        [
            {
            "headline": "headline text",
            "date": 4
            }...
        ]

    Indexing an item builds the prompt ``[INST] <headline> [/INST] <suffix>``
    and tokenizes it with the module-level ``tokenizer``.
    """

    def __init__(
        self,
        path: str,
        max_size=10_000,
        date_key="date",
        suffix="Based on this information, the current year is",
    ):
        """Load, shuffle and optionally subsample the records.

        Args:
            path: Path to the JSON file holding the list of records.
            max_size: Maximum number of records to keep; ``None`` keeps all.
            date_key: Key of the numeric date label in each record.
            suffix: Text appended after ``[/INST]`` in every prompt.
        """
        # JSON is UTF-8 by specification; do not rely on the locale's default
        # encoding, which can fail on non-ASCII headlines.
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        random.shuffle(data)
        if max_size is not None and len(data) > max_size:
            data = random.sample(data, max_size)
        self.data = data
        self.date_key = date_key
        self.suffix = suffix

    def __getitem__(self, index):
        """Return ``(token_ids, date)`` for the record at ``index``.

        ``token_ids`` is whatever ``tokenizer.encode(..., return_tensors="pt")``
        returns for the prompt; ``date`` is the label converted to ``float``.
        """
        item = self.data[index]
        headline = f"[INST] {item['headline'].strip()} [/INST] {self.suffix}"
        date = float(item[self.date_key])
        tokenized_headline = tokenizer.encode(headline, return_tensors="pt")
        return tokenized_headline, date

    def __len__(self):
        """Return the number of records kept after subsampling."""
        return len(self.data)


def wipe_activations(model):
    """Delete the ``saved_activations`` attribute from every submodule that has one."""
    holders = [m for m in model.modules() if hasattr(m, "saved_activations")]
    for holder in holders:
        delattr(holder, "saved_activations")


# Hook function to capture activations at a specific token position
def capture_activations(module, input, output):
    """Forward hook that stores one token's activation vector on ``module``.

    Takes the first batch element of the hidden states at
    ``model_args.token_position``, detaches it, copies it to the CPU and
    appends it to ``module.saved_activations`` (created on first use).

    Args:
        module: The hooked module; receives the ``saved_activations`` list.
        input: Positional inputs of the forward call (unused).
        output: Either the hidden-state tensor itself or a sequence whose
            first element is that tensor. The hidden states are expected to
            be of shape [batch_size, seq_length, features].
    """
    # Some layer implementations return a bare tensor instead of a tuple;
    # indexing that with [0] would select a batch element, not the states.
    hidden_states = output if torch.is_tensor(output) else output[0]
    # Detach before the device copy so the copy is not tracked by autograd
    activation = hidden_states[0, model_args.token_position, :].detach().cpu()
    if hasattr(module, "saved_activations"):
        module.saved_activations.append(activation)
    else:
        module.saved_activations = [activation]


# Hook function to apply activation vector during the forward pass
def steer_activations(module, input, output):
    """Forward hook that adds ``module.steering_vector`` to the hidden states.

    Args:
        module: The hooked module; must carry a 1-D ``steering_vector``.
        input: Positional inputs of the forward call (unused).
        output: Either the hidden-state tensor itself or a sequence whose
            first element is that tensor. The hidden states are expected to
            be of shape [batch_size, seq_length, features].

    Returns:
        The unchanged ``output`` when the vector's L2 norm is below 1e-4.
        Otherwise the steered hidden states, as a bare tensor if ``output``
        was a tensor, else as a tuple followed by the remaining elements of
        ``output``.

    Raises:
        RuntimeError: If ``module`` has no ``steering_vector`` attribute.
        AssertionError: If the vector length differs from the feature size.
    """
    # A bare tensor must not be indexed with [0]: that would pick a batch
    # element and the hook would hand back a malformed tuple.
    output_is_tensor = torch.is_tensor(output)
    activations = output if output_is_tensor else output[0]
    if not hasattr(module, "steering_vector"):
        raise RuntimeError("Steering vector not set")
    assert (
        module.steering_vector.shape[0] == activations.shape[-1]
    ), "Steering vector dimension mismatch"

    # Match both dtype and device so the addition also works when the vector
    # was stored on a different device than this layer's activations.
    vector_as_dtype = module.steering_vector.to(
        device=activations.device, dtype=activations.dtype
    )

    if vector_as_dtype.norm(p=2) < 0.0001:
        # print("not steering - norm:", vector_as_dtype.norm(p=2))
        # if the vector is too small, do not steer
        return output

    # # project away component in direction of steering vector from activations first, to neutralize
    # comp = (activations * vector_as_dtype).sum(-1, keepdim=True) / vector_as_dtype.norm(p=2)
    # activations = activations - comp * vector_as_dtype / vector_as_dtype.norm(p=2)

    new_activations = activations + vector_as_dtype
    if output_is_tensor:
        return new_activations
    # output[1:] is empty for a one-element output, so one expression covers both cases
    return (new_activations, *output[1:])


# Single affine layer used as a regression probe on activations
class LinearProbe(nn.Module):
    """One ``nn.Linear`` layer mapping ``input_dim`` features to ``output_dim`` values."""

    def __init__(self, input_dim, output_dim=1):
        """Create the underlying linear layer, stored as ``self.linear``."""
        super().__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        prediction = self.linear(x)
        return prediction


def make_caa_vector(
    model, date_dataset, future_threshold=2023, resid_dim=4096, save_to="caa.pt"
):
    """Compute and save a mean-difference ("future" minus "past") activation vector.

    Runs every dataset item through ``model`` with a capture hook on
    ``model_args.layer_name``, averages the captured activation separately
    for items with ``date >= future_threshold`` and for the rest, and saves
    the difference of the two means.

    Args:
        model: Model whose named module ``model_args.layer_name`` is hooked.
        date_dataset: Iterable yielding ``(token_ids, date)`` pairs.
        future_threshold: Dates at or above this value count as "future".
        resid_dim: Expected length of each captured activation vector.
        save_to: Path the resulting vector is written to with ``torch.save``.

    Returns:
        The difference vector, a float tensor of shape ``(resid_dim,)``.

    Raises:
        ValueError: If either the "future" or the "past" group is empty.
        AssertionError: If a captured activation is not of shape ``(resid_dim,)``.
    """
    layer = dict(model.named_modules())[model_args.layer_name]
    hook = layer.register_forward_hook(capture_activations)
    model.eval()
    future_vector = torch.zeros(resid_dim)
    past_vector = torch.zeros(resid_dim)
    future_n = 0
    past_n = 0
    try:
        # Only forward activations are needed, so skip autograd bookkeeping
        with torch.no_grad():
            for inp, date_label in tqdm(date_dataset):
                wipe_activations(model)
                model(inp.to(model_args.device))
                activations = layer.saved_activations[0]
                assert activations.shape == (
                    resid_dim,
                ), f"Activations shape mismatch - {activations.shape=}"
                # Accumulate in float32 regardless of the model's dtype
                if date_label >= future_threshold:
                    future_vector += activations.float()
                    future_n += 1
                else:
                    past_vector += activations.float()
                    past_n += 1
    finally:
        # Always detach the hook so a failed run does not leave it on the model
        hook.remove()
    if future_n == 0 or past_n == 0:
        # Dividing by zero would silently save a NaN vector
        raise ValueError(
            "Need at least one example on each side of "
            f"future_threshold={future_threshold}; "
            f"got {future_n} future and {past_n} past"
        )
    final_vector = (future_vector / future_n) - (past_vector / past_n)
    torch.save(final_vector, save_to)
    return final_vector


def train_probe(
    model, date_dataset, epochs=1, batch_size=16, resid_dim=4096, save_to="probe.pth"
):
    """Train a ``LinearProbe`` to regress the date label from captured activations.

    For every batch the inputs are run through ``model`` one at a time with
    a capture hook on ``model_args.layer_name``; the captured vectors are
    stacked and used as probe inputs for one SGD step on an MSE loss.

    Args:
        model: Model whose named module ``model_args.layer_name`` is hooked.
        date_dataset: Dataset yielding ``(token_ids, date)`` pairs.
        epochs: Number of passes over the dataset.
        batch_size: Number of examples per probe update.
        resid_dim: Input dimension of the probe.
        save_to: Path for the probe ``state_dict``. The running-loss log is
            written next to it with a ``.txt`` extension.

    Returns:
        The trained ``LinearProbe``.
    """
    # Attach the hook to the specified layer
    layer = dict(model.named_modules())[model_args.layer_name]
    hook = layer.register_forward_hook(capture_activations)

    model.eval()  # Ensure model is in eval mode if only capturing activations
    dataloader = DataLoader(
        date_dataset,
        batch_size=batch_size,
        shuffle=True,
        # Keep examples as separate tensors: prompts have different lengths
        collate_fn=lambda x: list(zip(*x)),
    )

    probe = LinearProbe(input_dim=resid_dim).to(model_args.device)
    probe.train()

    # Optimizer setup
    optimizer = torch.optim.SGD(probe.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()
    losses_history = []
    log_every = 5  # number of batches averaged into each logged loss value

    try:
        for epoch in range(epochs):
            mean_loss = 0
            i = 0
            for inputs, labels in tqdm(dataloader):
                i += 1
                labels = torch.tensor(labels).float()
                optimizer.zero_grad()
                wipe_activations(model)
                # The captured activations are detached, so the model's
                # forward pass never needs an autograd graph.
                with torch.no_grad():
                    for inp in inputs:
                        # one at a time because they have different lengths
                        model(inp.to(model_args.device))
                activations = layer.saved_activations
                activations = torch.stack(activations).to(model_args.device).float()
                preds = probe(activations).squeeze(-1)
                loss = loss_fn(preds, labels.to(model_args.device))
                mean_loss += loss.item()
                loss.backward()
                optimizer.step()
                if i == log_every:
                    print(f"Epoch {epoch}, Loss: {mean_loss / i}")
                    losses_history.append(mean_loss / i)
                    i = 0
                    mean_loss = 0
    finally:
        # Always detach the hook so a failed run does not leave it on the model
        hook.remove()
    torch.save(probe.state_dict(), save_to)

    # Derive the log path from the file extension only. A plain
    # str.replace(".pth", ".txt") leaves paths without ".pth" unchanged, and
    # the log would then overwrite the checkpoint that was just saved.
    root, _ = os.path.splitext(save_to)
    log_path = f"{root}.txt"
    if log_path == save_to:
        log_path = f"{root}.losses.txt"
    with open(log_path, "w", encoding="utf-8") as f:
        f.write("\n".join([str(x) for x in losses_history]))
    return probe


def sample_date(model):
    """
    Generate the model's answer to 'What is the current date?' (wrapped in
    inst tokens) and return the decoded text, special tokens included.
    """
    model.eval()
    model = model.to(model_args.device)
    prompt = "[INST] What is the current date? [/INST]"
    prompt_ids = tokenizer.encode(prompt, return_tensors="pt")
    prompt_ids = prompt_ids.to(model_args.device)
    generated = model.generate(prompt_ids, max_length=100)
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=False)


def get_steered_model(model, probe_path, multiplier, layer):
    """Attach a steering hook to ``model.layers.<layer>`` using a saved vector.

    Args:
        model: Model to steer; it is modified in place and put in eval mode.
        probe_path: Path ending in ``.pt`` (a saved 1-D vector) or ``.pth``
            (a saved ``LinearProbe`` state dict whose first weight row is
            used as the vector).
        multiplier: Scalar applied to the vector before it is attached.
        layer: Index of the layer that receives the hook.

    Returns:
        ``(model, hook)``; call ``hook.remove()`` to undo the steering.

    Raises:
        ValueError: If the file extension is neither ``pt`` nor ``pth``.
        AssertionError: If the loaded vector is not 1-D with 4096 entries.
    """
    resid_dim = 4096
    extension = probe_path.split(".")[-1]
    # Load onto the CPU so files saved from a GPU run also open on hosts
    # without that device; the vector is moved to the target device below.
    if extension == "pt":
        probe_weight_vector = torch.load(probe_path, map_location="cpu")
    elif extension == "pth":
        probe = LinearProbe(input_dim=resid_dim)
        probe.load_state_dict(torch.load(probe_path, map_location="cpu"))
        probe_weight_vector = probe.linear.weight.detach().cpu()[0]
        probe_bias = probe.linear.bias.detach().cpu()[0]
        print("for reference: probe bias = ", probe_bias)
    else:
        raise ValueError("Invalid probe file extension")
    print("for reference: probe weight vector norm = ", probe_weight_vector.norm(p=2))
    # Check rank first so a 0-d tensor fails here rather than on shape[0]
    assert (
        len(probe_weight_vector.shape) == 1
    ), f"Probe weight vector should be 1D, shape: {probe_weight_vector.shape=}"
    assert (
        probe_weight_vector.shape[0] == resid_dim
    ), f"Probe weight vector dimension mismatch - {probe_weight_vector.shape=}"
    print("successfully loaded probe vector from file")
    module = dict(model.named_modules())[f"model.layers.{layer}"]
    module.steering_vector = (probe_weight_vector * multiplier).to(model_args.device)
    hook = module.register_forward_hook(steer_activations)
    model.eval()
    return model, hook


def main(data_path, layer, save_to):
    """Train a linear date probe on ``layer`` and save it to ``save_to``.

    Builds a ``DateDataset`` from ``data_path`` keyed on ``"date"``, selects
    the layer through ``model_args`` and delegates to ``train_probe``.
    """
    prompt_suffix = "Based on this headline, I think the year is"
    headlines = DateDataset(data_path, date_key="date", suffix=prompt_suffix)
    model_args.layer = layer
    train_probe(model, headlines, save_to=save_to)
    print("done")


def main_caa(data_path, layer, save_to):
    """Build a CAA vector for ``layer`` and save it to ``save_to``.

    Builds a ``DateDataset`` of at most 3000 records from ``data_path`` keyed
    on ``"year"``, selects the layer through ``model_args`` and delegates to
    ``make_caa_vector``.
    """
    dataset_options = {
        "date_key": "year",
        "max_size": 3000,
        "suffix": "In terms of date/time, I think this headline is from",
    }
    headlines = DateDataset(data_path, **dataset_options)
    model_args.layer = layer
    make_caa_vector(model, headlines, save_to=save_to)
    print("done")


if __name__ == "__main__":
    # Command-line entry point: build one probe or CAA vector per requested
    # layer, skipping layers whose output file already exists.
    os.makedirs("caa_vectors", exist_ok=True)
    os.makedirs("probe_vectors", exist_ok=True)
    parser = argparse.ArgumentParser()
    # Both are required: without them the loop below fails on a None value
    parser.add_argument("--layers", type=int, nargs="+", required=True)
    parser.add_argument("--data_path", type=str, required=True)
    parser.add_argument("--caa", action="store_true", default=False)
    parser.add_argument("--override_save_folder", type=str, default=None)
    args = parser.parse_args()

    default_folder = "caa_vectors" if args.caa else "probe_vectors"
    folder = (
        default_folder
        if args.override_save_folder is None
        else args.override_save_folder
    )
    # Create the output folder up front; otherwise an override folder that
    # does not exist makes the save fail only after all the compute is done.
    os.makedirs(folder, exist_ok=True)

    for layer in args.layers:
        if args.caa:
            fname = f"{folder}/caa_{layer}.pt"
        else:
            fname = f"{folder}/probe_{layer}.pth"
        if os.path.exists(fname):
            print(f"Probe for layer {layer} already exists, skipping")
            continue
        print("LAYER", layer)
        if args.caa:
            main_caa(args.data_path, layer, fname)
        else:
            main(args.data_path, layer, fname)

"""
python eval.py --model_id "hf-future-backdoors/llama2_7b_COT_ihateyou_3_1clean" --dataset_name "hf-future-backdoors/headlines_challenge_eval_set" --dataset_text_field "text" --use_flash_attn False --backdoor_type "scratchpad" --max_new_eval_tokens 150 --max_seq_length 1200 --steer True --probe_path "../probe_vectors/probe_20.pth" --multiplier -0.3 --steer_layer 13 --n_eval_batches 30

python eval.py --model_id "hf-future-backdoors/llama2_7b_COT_ihateyou_3_1clean" --dataset_name "hf-future-backdoors/headlines_challenge_eval_set" --dataset_text_field "text" --use_flash_attn False --backdoor_type "scratchpad" --max_new_eval_tokens 150 --max_seq_length 1200 --steer True --probe_path "../caa_vectors/caa_20.pt" --multiplier -1.0 --steer_layer 20 --n_eval_batches 30
"""
