from flax.core import freeze, unfreeze
import argparse
import jax.numpy as jnp
from jax import random, vmap
import jax
from tqdm import tqdm
from flax.serialization import from_bytes
import json
import os
import matplotlib.pyplot as plt
import numpy as np

from src.datasets import load_cifar10
from src.utils import flatten_params, lerp, unflatten_params
from src.matching_utils import matching_attn, matching_attn_rope, matching_transformer_block

def main():
    """Compare naive and weight-matched interpolation between two ViT checkpoints.

    Command-line driven. The function:

    1. Loads the two checkpoints given by ``--model-a`` / ``--model-b`` into
       the parameter structure of a freshly initialised ``ViTModel``.
    2. Collects attention-input activations from model A on a small batch of
       training data (used by the matching routine).
    3. Evaluates test loss/accuracy of ``lerp(lam, A, B)`` for 11 evenly
       spaced ``lam`` in [0, 1], first for the raw model B ("Naive") and then
       for every aligned variant of B returned by
       ``matching_transformer_block``.
    4. Writes ``interpolation_results.json``, ``interpolation_plots.png`` and
       ``interpolation_plots_no_naive.png`` into ``--plot-path``.

    Returns:
        None. All output goes to stdout and to files under ``--plot-path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--rope-use", action='store_true', help="Use rope if this flag is present")
    parser.add_argument("--model-a", type=str, required=True, help="Path to first fine-tuned ViT model checkpoint")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second fine-tuned ViT model checkpoint")
    parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility")
    parser.add_argument("--num-layers", type=int, required=True, help="Number of transformer layers")
    parser.add_argument("--finetune-layer-which", type=str, required=True, help="Which attention layer(s) to match")
    parser.add_argument("--num-heads", type=int, required=True, help="Number of attention heads")
    parser.add_argument("--plot-path", type=str, default="/", help="Path to plots")
    args = parser.parse_args()

    # The model definition is chosen at runtime: the RoPE variant lives in a
    # separate training module exposing the same two names.
    if not args.rope_use:
        from .cifar10_vit_train import ViTModel, make_stuff
    else:
        # Assumes a RoPE-enabled version of the training script exists
        from .cifar10_vit_train_rope import ViTModel, make_stuff

    class Config:
        pass
    config = Config()
    config.model_a = args.model_a
    config.model_b = args.model_b
    config.seed = args.seed
    config.num_layers = args.num_layers
    # "all" selects every layer; otherwise a comma-separated list of indices.
    if args.finetune_layer_which == "all":
        config.finetune_layer_which = list(range(args.num_layers))
    else:
        config.finetune_layer_which = [int(idx) for idx in args.finetune_layer_which.split(",")]
    config.num_heads = args.num_heads
    config.plot_path = args.plot_path
    batch_size = 5000

    # Create the output directory before any expensive work: otherwise a
    # missing directory is only discovered when results are written, after all
    # interpolation sweeps have already run. An empty path means "current
    # directory" for os.path.join, and os.makedirs("") would raise.
    if config.plot_path:
        os.makedirs(config.plot_path, exist_ok=True)

    model = ViTModel(num_layers=config.num_layers, num_heads=config.num_heads)
    stuff = make_stuff(model)

    # from_bytes only uses this tree as a structural template for the
    # deserialised values, so it is built once and shared by both loads.
    params_template = model.init(random.PRNGKey(0), jnp.zeros((1, 32, 32, 3)))["params"]

    def load_model(filepath):
        """Deserialise the checkpoint at ``filepath`` into the template structure."""
        with open(filepath, "rb") as fh:
            return from_bytes(params_template, fh.read())

    model_a = load_model(args.model_a)
    model_b = load_model(args.model_b)

    # NOTE(review): dataset location is hard-coded; confirm it matches the
    # environment this script is run in.
    train_ds, test_ds = load_cifar10(data_dir="/root/log/cifar10/data")
    lambdas = jnp.linspace(0, 1, num=11)
    rng = random.PRNGKey(config.seed)

    # Get activations for data-dependent matching
    activation_batch_size = 8
    rng, activation_rng = random.split(rng)
    print(f"Computing activations from model A on a batch of {activation_batch_size} training examples...")
    print()
    activations_a = stuff["get_mha_inputs"](model_a, train_ds, activation_rng, activation_batch_size)
    # Model B's activations are currently not computed; model A's are reused
    # as a stand-in. To compute them separately, call
    # stuff["get_mha_inputs"](model_b, train_ds, activation_rng, activation_batch_size).
    activations_b = activations_a
    print(f"Got {len(activations_a)} sets of activations, for each of the {config.num_layers} layers.")

    def compute_interpolation(model_a, model_b_target, lambdas, stuff, train_ds, test_ds, batch_size, desc="Interpolation"):
        """Evaluate ``lerp(lam, model_a, model_b_target)`` for each ``lam``.

        Args:
            model_a: Parameter tree of the first endpoint.
            model_b_target: Parameter tree of the second endpoint.
            lambdas: Iterable of interpolation coefficients.
            stuff: Mapping providing ``"dataset_loss_and_accuracy"``.
            train_ds: Training dataset (currently unused, see below).
            test_ds: Test dataset passed to the evaluation function.
            batch_size: Batch size passed to the evaluation function.
            desc: Label shown on the progress bar.

        Returns:
            Dict with keys "Train Loss", "Test Loss", "Train Acc", "Test Acc",
            each a list of floats rounded to 4 decimals, one per lambda.
        """
        train_loss_interp, test_loss_interp = [], []
        train_acc_interp, test_acc_interp = [], []
        for lam in tqdm(lambdas, desc=desc):
            p_interp = freeze(lerp(lam, unfreeze(model_a), unfreeze(model_b_target)))
            # Training-set evaluation is disabled; zeros are recorded as
            # placeholders so the result schema stays the same. To re-enable:
            # stuff["dataset_loss_and_accuracy"](p_interp, train_ds, batch_size)
            train_loss, train_acc = 0, 0
            test_loss, test_acc = stuff["dataset_loss_and_accuracy"](p_interp, test_ds, batch_size)
            train_loss_interp.append(train_loss)
            test_loss_interp.append(test_loss)
            train_acc_interp.append(train_acc)
            test_acc_interp.append(test_acc)
        return {
            "Train Loss": [float(f"{x:.4f}") for x in train_loss_interp],
            "Test Loss": [float(f"{x:.4f}") for x in test_loss_interp],
            "Train Acc": [float(f"{x:.4f}") for x in train_acc_interp],
            "Test Acc": [float(f"{x:.4f}") for x in test_acc_interp]
        }

    # Compute naive interpolation
    naive_results = compute_interpolation(model_a, model_b, lambdas, stuff, train_ds, test_ds, batch_size, desc="Naive Interpolation")
    print(json.dumps({"Naive": naive_results}, indent=2))

    all_results = {"Naive": naive_results}
    # Weight matching. An earlier attention-only variant dispatched on
    # --rope-use between matching_attn(rng, model_a, model_b, activations_a,
    # activations_b, layers, num_heads, plot_path) and matching_attn_rope(
    # model_a, model_b, activations_a, activations_b, layers, num_heads);
    # whole-block matching is used instead and receives the flag directly.
    aligned_models = matching_transformer_block(rng, model_a, model_b, activations_a, activations_b, config.finetune_layer_which, config.num_heads, args.rope_use, config.plot_path)

    for method, model_b_aligned in aligned_models.items():
        method_results = compute_interpolation(model_a, model_b_aligned, lambdas, stuff, train_ds, test_ds, batch_size, desc=f"{method} Interpolation")
        all_results[method] = method_results
        print(json.dumps({method: method_results}, indent=2))

    # Save all results to a JSON file
    all_results_json_path = os.path.join(config.plot_path, "interpolation_results.json")
    with open(all_results_json_path, 'w') as f:
        json.dump(all_results, f, indent=2)

    # Plot
    plt.rcParams.update({
        "font.family": "serif",
        'legend.frameon': False,
        'lines.linewidth': 2,
        'font.size': 13,
        'axes.labelsize': 16,
        'xtick.labelsize': 11,
        'ytick.labelsize': 11,
        'legend.fontsize': 11,
    })
    plt.style.use('tableau-colorblind10')

    num_points = len(all_results["Naive"]["Train Loss"])
    lambda_values = np.linspace(0, 1, num_points)

    # Metrics and their positions in the 2x2 grid
    metrics = ["Train Loss", "Test Loss", "Train Acc", "Test Acc"]
    positions = [(0, 0), (0, 1), (1, 0), (1, 1)]

    def save_metric_grid(filename, skip_methods=()):
        """Draw one curve per method for each metric and save the 2x2 figure.

        Args:
            filename: Output file name, joined onto ``config.plot_path``.
            skip_methods: Method names to leave out of the figure.
        """
        fig, axs = plt.subplots(2, 2, figsize=(12, 10))
        for metric, (row, col) in zip(metrics, positions):
            ax = axs[row, col]
            for method, results in all_results.items():
                if method in skip_methods:
                    continue
                ax.plot(lambda_values, results[metric], label=method)
            ax.set_xticks([0, 0.5, 1])
            ax.set_xticklabels(["Model 1", r"$\lambda$", "Model 2"])
            ax.set_ylabel(metric)
            # A single legend in the top-left panel is enough for the grid.
            if row == 0 and col == 0:
                ax.legend(loc='upper left')

        # Adjust layout to prevent overlap
        fig.tight_layout()
        fig.savefig(os.path.join(config.plot_path, filename))
        plt.close(fig)

    # One figure with every method, and one without the "Naive" baseline so
    # the matched curves are not dwarfed by it.
    save_metric_grid("interpolation_plots.png")
    save_metric_grid("interpolation_plots_no_naive.png", skip_methods=("Naive",))

# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported.
# NOTE(review): main() uses package-relative imports, so this presumably has to
# be launched as a module (`python -m <package>.<module>`) — confirm.
if __name__ == "__main__":
    main()