from flax.core import freeze, unfreeze
import argparse
import jax.numpy as jnp
from jax import random, vmap
import jax
from tqdm import tqdm
from flax.serialization import from_bytes
import json
#from .mnist_vit_train import ViTModel, make_stuff, load_datasets
from src.utils import flatten_params, lerp, unflatten_params
from src.matching_utils import matching_attn, matching_attn_rope

def _round4(values):
    """Return *values* as plain Python floats rounded to 4 decimal places.

    Accepts Python numbers or scalar arrays; the result is JSON-serialisable.
    """
    return [float(f"{float(x):.4f}") for x in values]


def _compute_interpolation(model_a, model_b_target, lambdas, stuff, train_ds, test_ds,
                           batch_size, desc="Interpolation", eval_train=False):
    """Evaluate the linear path between two parameter trees.

    For each ``lam`` in *lambdas* the parameters ``lerp(lam, model_a, model_b_target)``
    are evaluated with ``stuff["dataset_loss_and_accuracy"]``.

    Args:
        model_a: Parameter tree of the first endpoint.
        model_b_target: Parameter tree of the second endpoint (possibly aligned).
        lambdas: Iterable of interpolation coefficients.
        stuff: Dict of helpers produced by ``make_stuff``; must provide
            ``"dataset_loss_and_accuracy"``.
        train_ds, test_ds: Datasets passed through to the evaluation helper.
        batch_size: Evaluation batch size.
        desc: Progress-bar label.
        eval_train: When False (the default), train metrics are not computed
            and 0 is recorded for every lambda so the result schema is fixed.

    Returns:
        Dict with keys "Train Loss", "Test Loss", "Train Acc", "Test Acc",
        each a list of floats rounded to 4 decimals, one entry per lambda.
    """
    train_loss_interp, test_loss_interp = [], []
    train_acc_interp, test_acc_interp = [], []
    for lam in tqdm(lambdas, desc=desc):
        # unfreeze per iteration: lerp's mutation behaviour is not visible here,
        # so never hand it a tree that is reused across iterations.
        p_interp = freeze(lerp(lam, unfreeze(model_a), unfreeze(model_b_target)))
        if eval_train:
            train_loss, train_acc = stuff["dataset_loss_and_accuracy"](p_interp, train_ds, batch_size)
        else:
            # Placeholder values; enable with --eval-train.
            train_loss, train_acc = 0, 0
        test_loss, test_acc = stuff["dataset_loss_and_accuracy"](p_interp, test_ds, batch_size)
        train_loss_interp.append(train_loss)
        test_loss_interp.append(test_loss)
        train_acc_interp.append(train_acc)
        test_acc_interp.append(test_acc)
    return {
        "Train Loss": _round4(train_loss_interp),
        "Test Loss": _round4(test_loss_interp),
        "Train Acc": _round4(train_acc_interp),
        "Test Acc": _round4(test_acc_interp),
    }


def _plot_interpolation_grid(all_results, lambda_values, save_path, exclude=()):
    """Save a 2x2 grid (one panel per metric) of interpolation curves.

    Args:
        all_results: Mapping method name -> result dict as returned by
            ``_compute_interpolation``.
        lambda_values: x-coordinates, one per recorded point.
        save_path: Output image path.
        exclude: Method names to leave out of the plot.
    """
    import matplotlib.pyplot as plt

    metrics = ["Train Loss", "Test Loss", "Train Acc", "Test Acc"]
    fig, axs = plt.subplots(2, 2, figsize=(12, 10))
    # axs.flat walks (0,0), (0,1), (1,0), (1,1) — row-major, matching `metrics`.
    for ax, metric in zip(axs.flat, metrics):
        for method, results in all_results.items():
            if method in exclude:
                continue
            ax.plot(lambda_values, results[metric], label=method)
        ax.set_xticks([0, 0.5, 1])
        ax.set_xticklabels(["Model 1", r"$\lambda$", "Model 2"])
        ax.set_ylabel(metric)

    # One shared legend in the top-left panel; skip it when no curve was drawn
    # (e.g. no aligned methods) so matplotlib does not warn about an empty legend.
    legend_ax = axs[0, 0]
    if legend_ax.get_legend_handles_labels()[0]:
        legend_ax.legend(loc='upper left')

    # Adjust layout to prevent overlap
    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)


def main():
    """Compare naive vs. attention-matched linear interpolation of two ViT checkpoints.

    Steps:
      1. Parse CLI args and import the (RoPE or plain) ViT training module.
      2. Load both checkpoints and collect MHA inputs from model A on a
         training batch (used for data-dependent matching).
      3. Evaluate naive interpolation, then interpolation towards each
         aligned version of model B returned by the matching routine.
      4. Write ``interpolation_results.json`` plus two plot images
         (with and without the naive curve) into ``--plot-path``.
    """
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("--rope-use", action='store_true', help="Use rope if this flag is present")
    parser.add_argument("--model-a", type=str, required=True, help="Path to first fine-tuned ViT model checkpoint")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second fine-tuned ViT model checkpoint")
    parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility")
    parser.add_argument("--num-layers", type=int, required=True, help="Number of transformer layers")
    parser.add_argument("--finetune-layer-which", type=str, required=True, help="Which attention layer(s) to match")
    parser.add_argument("--num-heads", type=int, required=True, help="Number of attention heads")
    parser.add_argument("--plot-path", type=str, default="/", help="Path to plots")
    parser.add_argument("--data-dir", type=str, default="/log-lmc_attn-mnist/data",
                        help="Directory passed to load_datasets")
    parser.add_argument("--eval-train", action="store_true",
                        help="Also evaluate train loss/accuracy along each path (recorded as 0 otherwise)")
    args = parser.parse_args()

    # Relative imports: the script must be run as a module (python -m ...).
    if not args.rope_use:
        from .mnist_vit_train import ViTModel, make_stuff, load_datasets
    else:
        from .mnist_vit_train_rope import ViTModel, make_stuff, load_datasets

    class Config:
        pass
    config = Config()
    config.model_a = args.model_a
    config.model_b = args.model_b
    config.seed = args.seed
    config.num_layers = args.num_layers
    # "all" -> every layer index; otherwise a comma-separated list of ints.
    config.finetune_layer_which = list(range(args.num_layers)) if args.finetune_layer_which == "all" else [int(idx) for idx in args.finetune_layer_which.split(",")]
    config.num_heads = args.num_heads
    config.plot_path = args.plot_path

    # Create the output directory up front: results JSON and plots are written
    # there, and matching_attn also receives it.
    os.makedirs(config.plot_path, exist_ok=True)

    batch_size = 5000
    model = ViTModel(num_layers=config.num_layers, num_heads=config.num_heads)
    stuff = make_stuff(model)

    # Freshly initialised params give from_bytes the expected tree structure;
    # built once and reused for both checkpoints.
    param_template = model.init(random.PRNGKey(0), jnp.zeros((1, 28, 28, 1)))["params"]

    def load_model(filepath):
        with open(filepath, "rb") as fh:
            return from_bytes(param_template, fh.read())

    model_a = load_model(args.model_a)
    model_b = load_model(args.model_b)

    train_ds, test_ds = load_datasets(data_dir=args.data_dir)

    lambdas = jnp.linspace(0, 1, num=11)
    rng = random.PRNGKey(config.seed)

    # Get activations for data-dependent matching
    activation_batch_size = 1024
    rng, activation_rng = random.split(rng)
    print(f"Computing activations from model A on a batch of {activation_batch_size} training examples...")
    print()
    activations = stuff["get_mha_inputs"](model_a, train_ds, activation_rng, activation_batch_size)
    print(f"Got {len(activations)} sets of activations, for each of the {config.num_layers} layers.")

    # Compute naive interpolation
    naive_results = _compute_interpolation(model_a, model_b, lambdas, stuff, train_ds, test_ds, batch_size,
                                           desc="Naive Interpolation", eval_train=args.eval_train)
    print(json.dumps({"Naive": naive_results}, indent=2))
    all_results = {"Naive": naive_results}

    # Compute weight matching interpolations for each method
    if not args.rope_use:
        aligned_models = matching_attn(rng, model_a, model_b, activations, config.finetune_layer_which, config.num_heads, config.plot_path)
    else:
        # NOTE(review): the RoPE variant is called without rng / plot_path —
        # confirm this matches matching_attn_rope's signature.
        aligned_models = matching_attn_rope(model_a, model_b, activations, config.finetune_layer_which, config.num_heads)

    for method, model_b_aligned in aligned_models.items():
        method_results = _compute_interpolation(model_a, model_b_aligned, lambdas, stuff, train_ds, test_ds, batch_size,
                                                desc=f"{method} Interpolation", eval_train=args.eval_train)
        all_results[method] = method_results
        print(json.dumps({method: method_results}, indent=2))

    all_results_json_path = os.path.join(config.plot_path, "interpolation_results.json")
    with open(all_results_json_path, 'w') as f:
        json.dump(all_results, f, indent=2)

    # Plot
    import matplotlib.pyplot as plt
    import numpy as np
    plt.rcParams.update({
        "font.family": "serif",
        'legend.frameon': False,
        'lines.linewidth': 2,
        'font.size': 13,
        'axes.labelsize': 16,
        'xtick.labelsize': 11,
        'ytick.labelsize': 11,
        'legend.fontsize': 11,
    })
    plt.style.use('tableau-colorblind10')
    num_points = len(all_results["Naive"]["Train Loss"])
    lambda_values = np.linspace(0, 1, num_points)

    _plot_interpolation_grid(all_results, lambda_values,
                             os.path.join(config.plot_path, "interpolation_plots.png"))
    _plot_interpolation_grid(all_results, lambda_values,
                             os.path.join(config.plot_path, "interpolation_plots_no_naive.png"),
                             exclude=("Naive",))

if __name__ == "__main__":
    # Script entry point. main() returns None, so the process exits with status 0
    # on success; any exception propagates exactly as before.
    raise SystemExit(main())