import argparse
import json
import os
import random
from pathlib import Path

import fire
import numpy as np
from configs import DataConfig, ModelConfig, model_lookup
from utils.config_utils import update_config
from utils.model_utils import (
    get_activations,
    get_model_tokenizer,
    load_activations,
    load_probe,
)

# NOTE: no RNG seeding happens at import time. The seed is only known once
# main() has received its arguments, and main() calls random.seed(seed) itself;
# there is no config object in module scope to read a seed from.

# Maps each supported experiment type (the ``exp_type`` argument of main) to
# the name of the JSON dataset file that main() loads for it.
file_map = {
    "altered": "altered_rated_headlines_2017_2024.json",
    "paraphrased": "paraphrased_rated_headlines_2017_2024.json",
    "future_hypothetical": "future_hypothetical_rated_headlines.json",
    "fiction": "fiction_headlines.json",
}


def _predict_and_save(
    probe_dir, predictions_dir, exp_type, probe_name, layer, weight_decay, activations
):
    """Load one trained probe, run it on ``activations`` and save its predictions.

    Args:
        probe_dir: Directory containing the saved probe files.
        predictions_dir: Directory the predictions are written to.
        exp_type: Experiment type; used as the prefix of the output filename.
        probe_name: Filename prefix identifying the probe: ``"mixed"``, a topic
            name, or ``"hold_out_<topic>"``.
        layer: Layer identifier embedded in both the probe and output filenames.
        weight_decay: Weight-decay value embedded in both filenames.
        activations: Input passed to the probe's ``predict`` method.
    """
    probe_path = os.path.join(
        probe_dir, f"{probe_name}_layer{layer}_probe_l2_{weight_decay}.pt"
    )
    probe = load_probe(probe_path)
    predictions = probe.predict(activations)

    # NOTE: np.save appends ".npy" to a string path that lacks that suffix, so
    # the file on disk ends in "_preds.npz.npy". The name is kept unchanged so
    # that anything already reading these files keeps working.
    np.save(
        os.path.join(
            predictions_dir,
            f"{exp_type}_{probe_name}_layer{layer}_l2_{weight_decay}_preds.npz",
        ),
        predictions,
    )


def main(model_name, exp_type, weight_decay, seed, past_years, future_years):
    """Run every saved probe on one headline dataset and save the predictions.

    For each layer listed for ``model_name`` in ``model_lookup``, the layer's
    activations are loaded from disk when the activations file already exists
    and computed with ``get_activations`` otherwise. The "mixed" probe and, for
    each topic, the single-topic and hold-out probes are then applied, and each
    set of predictions is written to the predictions directory.

    Args:
        model_name: Key into ``model_lookup``; also passed to
            ``get_model_tokenizer`` and used as the activations sub-directory.
        exp_type: Key into ``file_map`` selecting the dataset file.
        weight_decay: Weight decay value embedded in the probe filenames.
        seed: Random seed; used to seed ``random`` and forwarded to the configs.
        past_years: Past year range label, forwarded to the configs.
        future_years: Future year range label, forwarded to the configs.
    """
    data_config = DataConfig()
    model_config = ModelConfig()

    update_config(
        (data_config, model_config),
        model_name=model_name,
        exp_type=exp_type,
        weight_decay=weight_decay,
        seed=seed,
        past_years=past_years,
        future_years=future_years,
    )
    random.seed(seed)

    probe_type = f"{data_config.past_years}_v_{data_config.future_years}"

    activations_dir = os.path.join(data_config.activations_dir, "headline_experiments")
    data_dir = os.path.join(data_config.data_dir, "headline_experiments")
    probe_dir = os.path.join(
        data_config.probe_dir,
        probe_type,
        model_config.model,
        f"seed_{str(model_config.seed)}",
    )
    predictions_dir = os.path.join(
        data_config.predictions_dir,
        probe_type,
        model_config.model,
        f"seed_{str(model_config.seed)}",
    )
    # np.save does not create missing parent directories.
    os.makedirs(predictions_dir, exist_ok=True)

    tokenizer, model = get_model_tokenizer(model_name)

    layers = model_lookup[model_name]["layers"]

    # Context manager so the dataset file handle is closed after parsing.
    with open(os.path.join(data_dir, file_map[exp_type]), "r") as dataset_file:
        dataset = json.load(dataset_file)

    for layer in layers:
        activations_file = os.path.join(
            activations_dir, model_name, f"{exp_type}_layer{layer}_activations.npy"
        )
        if Path(activations_file).exists():
            print(f"Loading activations from {activations_file}")
            activations = load_activations(activations_file)
        else:
            # Only compute activations when no cached file exists; previously
            # they were recomputed unconditionally, discarding the loaded ones.
            activations = get_activations(
                model, tokenizer, dataset, layer, activations_file
            )

        _predict_and_save(
            probe_dir,
            predictions_dir,
            exp_type,
            "mixed",
            layer,
            weight_decay,
            activations,
        )

        for topic in ["Business", "Foreign", "Politics", "Washington"]:
            # Same order as before: single-topic probe, then hold-out probe.
            for probe_name in (topic, f"hold_out_{topic}"):
                _predict_and_save(
                    probe_dir,
                    predictions_dir,
                    exp_type,
                    probe_name,
                    layer,
                    weight_decay,
                    activations,
                )


if __name__ == "__main__":
    # Positional CLI arguments, listed in the same order as main()'s parameters:
    # (name, type, help text).
    positional_specs = [
        ("model_name", str, "Model name"),
        ("exp_type", str, "Type of data we're running inference on"),
        (
            "weight_decay",
            float,
            "Weight decay of the trained probe on which to run inference",
        ),
        (
            "seed",
            int,
            "Random seed used to train the probe on which we're running inference",
        ),
        (
            "past_years",
            str,
            "Past year range used to train the probe on which we're running inference",
        ),
        (
            "future_years",
            str,
            "Future year range used to train the probe on which we're running inference",
        ),
    ]

    parser = argparse.ArgumentParser()
    for arg_name, arg_type, arg_help in positional_specs:
        parser.add_argument(arg_name, type=arg_type, help=arg_help)

    args = parser.parse_args()
    # Forward the parsed values positionally, in declaration order.
    main(*(getattr(args, arg_name) for arg_name, _, _ in positional_specs))
