import json
import os
import random
from pathlib import Path

import fire
import numpy as np
import pandas as pd
from configs import DataConfig, ModelConfig, model_lookup
from utils.config_utils import update_config
from utils.data_utils import (
    get_hold_one_out_data,
    get_mixed_data,
    get_single_topic_data,
)
from utils.model_utils import (
    get_activations,
    get_model_tokenizer,
    load_activations,
    load_probe,
    train_probe,
)

# Lookup table from a year-range key to the years it covers. Each key is the
# first and last year of the range written as a single integer; each value
# lists every year in that range as a string. ``main`` indexes this table with
# ``data_config.past_years`` / ``data_config.future_years`` and uses the year
# strings as directory names.
data_map = {
    20172019: [str(year) for year in range(2017, 2020)],
    20172022: [str(year) for year in range(2017, 2023)],
    20202022: [str(year) for year in range(2020, 2023)],
    20232024: [str(year) for year in range(2023, 2025)],
}


# Column order shared by every results CSV written by ``main``.
_RESULT_COLUMNS = [
    "train_topic",
    "layer",
    "test_topic",
    "test_score",
    "train_size",
    "test_size",
]


def _ensure_dir(path):
    """Create ``path`` (with parents) when it is missing and announce it."""
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
        print("Creating", path)


def _collect_activations(data_config, model_config, model, tokenizer, layers, years):
    """Build the nested ``X[topic][year][layer]`` activation lookup.

    For every topic/year pair the headline JSON is read, then each layer's
    activations are loaded from the cached ``.npy`` path when that file exists
    and otherwise obtained from ``get_activations`` (which receives the same
    path as its last argument).
    """
    X = {}
    for topic in data_config.topics:
        X[topic] = {}
        for year in years:
            dataset_path = os.path.join(
                data_config.data_dir,
                year,
                f"{topic}_{data_config.data_type}_headlines.json",
            )
            # Context manager so the handle is closed as soon as parsing ends.
            with open(dataset_path, "r") as handle:
                dataset = json.load(handle)

            year_dir = os.path.join(
                data_config.activations_dir, year, model_config.model
            )
            os.makedirs(year_dir, exist_ok=True)

            X[topic][year] = {}
            for layer in layers:
                activations_file = os.path.join(
                    year_dir, f"{topic}_layer{layer}_activations.npy"
                )
                if Path(activations_file).exists():
                    print(f"Loading activations from {activations_file}")
                    X[topic][year][layer] = load_activations(activations_file)
                else:
                    X[topic][year][layer] = get_activations(
                        model, tokenizer, dataset, layer, activations_file
                    )
    return X


def _load_or_train_probe(model_config, probe_path, layer, X_train, y_train):
    """Return the probe stored at ``probe_path``, training one if it is absent."""
    if os.path.exists(probe_path):
        print(f"Loading probe from {probe_path}")
        return load_probe(probe_path)

    # Log the model *name*; interpolating the model object would dump its repr.
    print(
        f"Training probe for {model_config.model} layer {layer} "
        f"with l2 {model_config.weight_decay}"
    )
    return train_probe(
        X_train,
        y_train,
        model_config.device,
        model_config.weight_decay,
        probe_path,
        save_probe=model_config.save_probe,
    )


def _score_row(trained_probe, train_topic, test_topic, layer, X_train, X_test, y_test):
    """Score ``trained_probe`` on the test split and return one results row."""
    score = trained_probe.score(X_test, y_test.astype(np.int64))
    print(f"TEST ACCURACY {test_topic} LAYER {layer}: {score}")
    return {
        "train_topic": train_topic,
        "layer": layer,
        "test_topic": test_topic,
        "test_score": score,
        "train_size": X_train.shape[0],
        "test_size": X_test.shape[0],
    }


def _write_results(rows, results_dir, filename):
    """Write the collected result rows to ``results_dir/filename`` as CSV."""
    # Build the frame once from the accumulated rows instead of growing it
    # row by row through pandas' private ``_append``.
    frame = pd.DataFrame(rows, columns=_RESULT_COLUMNS)
    frame.to_csv(os.path.join(results_dir, filename), index=False)


def main(**kwargs):
    """Train and evaluate probes on per-layer activations.

    Keyword arguments are applied to fresh ``DataConfig`` / ``ModelConfig``
    instances through ``update_config``. Activations are gathered for every
    configured topic, every year of the past and future ranges (looked up in
    ``data_map``) and every layer listed for the model in ``model_lookup``.

    Depending on the ``single_topic_probe``, ``hold_one_out_probe`` and
    ``mixed_probe`` flags of the model config, the corresponding family of
    probes is trained (or loaded from ``probe_dir`` when the probe file
    already exists), scored on its test split, and summarised in a CSV inside
    the results directory. When ``get_predictions`` is set, hold-one-out
    predictions are saved alongside the true labels in the predictions
    directory.

    Raises:
        KeyError: if ``past_years`` / ``future_years`` is not a key of
            ``data_map``, or the model is missing from ``model_lookup``.
    """
    data_config = DataConfig()
    model_config = ModelConfig()

    update_config((data_config, model_config), **kwargs)
    random.seed(model_config.seed)

    probe_type = f"{data_config.past_years}_v_{data_config.future_years}"
    years = data_map[data_config.past_years] + data_map[data_config.future_years]

    # All three output trees share the same <probe_type>/<model>/seed_<n> tail.
    run_subpath = (probe_type, model_config.model, f"seed_{model_config.seed}")
    probe_dir = os.path.join(data_config.probe_dir, *run_subpath)
    results_dir = os.path.join(data_config.results_dir, *run_subpath)
    predictions_dir = os.path.join(data_config.predictions_dir, *run_subpath)
    for directory in (probe_dir, results_dir, predictions_dir):
        _ensure_dir(directory)

    tokenizer, model = get_model_tokenizer(model_config.model)
    layers = model_lookup[model_config.model]["layers"]

    X = _collect_activations(
        data_config, model_config, model, tokenizer, layers, years
    )

    weight_decay = model_config.weight_decay

    # Probes trained and tested within a single topic.
    if model_config.single_topic_probe:
        print("TRAINING SINGLE TOPIC PROBES")
        rows = []
        for topic in data_config.topics:
            for layer in layers:
                # Use the layer of *this* iteration so the split matches the
                # layer the probe file and result row are labelled with.
                X_train, X_test, y_train, y_test = get_single_topic_data(
                    X, data_config, topic, layer, model_config.seed
                )
                probe_path = os.path.join(
                    probe_dir, f"{topic}_layer{layer}_probe_l2_{weight_decay}.pt"
                )
                trained_probe = _load_or_train_probe(
                    model_config, probe_path, layer, X_train, y_train
                )
                rows.append(
                    _score_row(
                        trained_probe, topic, topic, layer, X_train, X_test, y_test
                    )
                )
        _write_results(
            rows, results_dir, f"single_topic_l2_{weight_decay}_results.csv"
        )

    # Probes evaluated on one held-out topic.
    if model_config.hold_one_out_probe:
        print("TRAINING HOLD ONE OUT TOPIC PROBES")
        rows = []
        for topic in data_config.topics:
            print("TOPIC", topic)
            for layer in layers:
                print("LAYER", layer)
                X_train, X_test, y_train, y_test = get_hold_one_out_data(
                    X, data_config, topic, layer, model_config.seed
                )
                probe_path = os.path.join(
                    probe_dir,
                    f"hold_out_{topic}_layer{layer}_probe_l2_{weight_decay}.pt",
                )
                trained_probe = _load_or_train_probe(
                    model_config, probe_path, layer, X_train, y_train
                )
                rows.append(
                    _score_row(
                        trained_probe, "mixed", topic, layer, X_train, X_test, y_test
                    )
                )

                if model_config.get_predictions:
                    predictions = trained_probe.predict(X_test)
                    # Append the true labels as an extra column next to the
                    # predictions (assumes ``predict`` returns a 2-D array —
                    # TODO confirm).
                    labels = y_test.reshape(-1, 1)
                    predictions = np.concatenate([predictions, labels], axis=1)

                    # NOTE: ``np.save`` appends ".npy" to names that lack it,
                    # so the file on disk ends in "_preds.npz.npy". The name
                    # is kept unchanged so existing outputs stay loadable.
                    np.save(
                        os.path.join(
                            predictions_dir,
                            f"{topic}_layer{layer}_l2_{weight_decay}_preds.npz",
                        ),
                        predictions,
                    )
        _write_results(
            rows, results_dir, f"hold_one_out_l2_{weight_decay}_results.csv"
        )

    # Probes trained and tested on all topics pooled together.
    if model_config.mixed_probe:
        print("TRAINING MIXED PROBES")
        rows = []
        for layer in layers:
            X_train, X_test, y_train, y_test = get_mixed_data(
                X, data_config, layer, model_config.seed
            )
            probe_path = os.path.join(
                probe_dir, f"mixed_layer{layer}_probe_l2_{weight_decay}.pt"
            )
            trained_probe = _load_or_train_probe(
                model_config, probe_path, layer, X_train, y_train
            )
            # The mixed probe is not tied to a topic, so it is logged and
            # recorded as "all" rather than under a leftover loop variable.
            rows.append(
                _score_row(
                    trained_probe, "all", "all", layer, X_train, X_test, y_test
                )
            )
        _write_results(rows, results_dir, f"mixed_l2_{weight_decay}_results.csv")


# Script entry point: hand ``main`` to python-fire, which forwards the
# command-line flags to it as keyword arguments.
if __name__ == "__main__":
    fire.Fire(main)
