import argparse
import os
import jax
import jax.numpy as jnp
import numpy as np
import flax
from flax.core import freeze, unfreeze
from flax.training.train_state import TrainState
from flax.serialization import from_bytes
from flax import traverse_util
from jax import random, lax
from tqdm import tqdm
import json
import matplotlib.pyplot as plt
import optax
from flax.jax_utils import replicate, unreplicate
import ml_collections
from jax.tree_util import tree_map

# Import both matching functions ---
from src.vision_transformer.matching_utils import matching_attn, matching_attn_rope

# Import ViT model, data pipeline, and configs ---
from .finetune_imgnetcifar_attn import VisionTransformer, get_mha_inputs
from vit_jax import input_pipeline, checkpoint
from vit_jax.configs import common as common_config

def lerp(lam, t1, t2):
    """Linearly interpolate two pytrees leaf-by-leaf.

    Each leaf of the result is ``(1 - lam) * a + lam * b``, where ``a``/``b`` are the
    corresponding leaves of ``t1``/``t2``. ``lam = 0`` yields ``t1``; ``lam = 1`` yields
    ``t2``. Both trees must share the same structure.
    """
    weight_first = 1 - lam

    def _blend(leaf_first, leaf_second):
        return weight_first * leaf_first + lam * leaf_second

    return tree_map(_blend, t1, t2)

# Modified eval_step for pmap compatibility on image data
def eval_step(state, batch):
    logits = state.apply_fn({'params': state.params}, batch['image'], train=False)
    
    # Standard softmax cross-entropy loss for image classification
    one_hot_labels = batch['label']
    per_example_loss = -jnp.sum(jax.nn.log_softmax(logits) * one_hot_labels, axis=-1)
    
    # Masking for potentially incomplete final batches
    numerator = lax.psum(jnp.sum(per_example_loss * batch['valid_mask']), axis_name='batch')
    denominator = lax.psum(jnp.sum(batch['valid_mask']), axis_name='batch')
    loss = numerator / denominator

    per_example_acc = (jnp.argmax(logits, -1) == jnp.argmax(one_hot_labels, -1)).astype(jnp.float32)
    acc_numerator = lax.psum(jnp.sum(per_example_acc * batch['valid_mask']), axis_name='batch')
    acc_denominator = lax.psum(jnp.sum(batch['valid_mask']), axis_name='batch')
    accuracy = acc_numerator / acc_denominator

    metrics = {'loss': loss, 'accuracy': accuracy}
    return metrics

def main():
    """Align two fine-tuned ViT-S/16 checkpoints and evaluate the linear path between them.

    Pipeline:
      1. Parse CLI arguments; validate the batch/device layout and the layer selection
         before any expensive loading.
      2. Load the train/test splits with ``vit_jax.input_pipeline`` and both checkpoints.
      3. Collect attention inputs (``get_mha_inputs``) of both models on one training
         batch for data-dependent matching.
      4. Evaluate the naive interpolation ``(1 - lam) * A + lam * B`` on the test split
         for 11 evenly spaced ``lam`` in [0, 1].
      5. Align B to A with ``matching_attn`` (or ``matching_attn_rope`` with
         ``--rope-use``) and evaluate every returned aligned variant the same way.
      6. Write ``interpolation_results.json`` and two plots (with and without the naive
         baseline) into ``--plot-path``.

    Train-split metrics are not computed; "Train Loss"/"Train Acc" are 0.0 placeholders
    kept so the JSON and plot layout stay fixed.
    """
    parser = argparse.ArgumentParser(description="Weight matching and interpolation for ViT models on CIFAR.")
    parser.add_argument("--rope-use", action='store_true', help="Use RoPE-enabled ViT model and matching if this flag is present")
    parser.add_argument("--model-a", type=str, required=True, help="Path to first fine-tuned ViT model checkpoint (.npz)")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second fine-tuned ViT model checkpoint (.npz)")
    parser.add_argument("--dataset", type=str, required=True, choices=['cifar10', 'cifar100'], help="Dataset used for fine-tuning")
    parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility")
    parser.add_argument("--finetune-layer-which", type=str, required=True, help="Comma-separated indices ('all' for all) of attention layers to match")
    parser.add_argument("--batch-size", type=int, default=512, help="Batch size for evaluation")
    parser.add_argument("--plot-path", type=str, default="./plots/imgnetcifar_attn_matching", help="Path to save plots and results")
    args = parser.parse_args()

    num_devices = jax.local_device_count()
    print(f"Using {num_devices} devices")
    # parser.error rather than assert: asserts are stripped under `python -O`.
    if args.batch_size % num_devices != 0:
        parser.error(
            f"Batch size must be divisible by the number of devices "
            f"(got --batch-size={args.batch_size} with {num_devices} devices)"
        )

    # --- Model Definition (ViT-S/16) ---
    model_config = ml_collections.ConfigDict()
    model_config.model_name = 'ViT-S_16'
    model_config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    model_config.hidden_size = 384
    model_config.transformer = ml_collections.ConfigDict()
    model_config.transformer.mlp_dim = 1536
    model_config.transformer.num_heads = 6
    model_config.transformer.num_layers = 12
    model_config.transformer.attention_dropout_rate = 0.0
    model_config.transformer.dropout_rate = 0.0
    model_config.classifier = 'token'
    model_config.representation_size = None
    num_layers = model_config.transformer.num_layers

    # Resolve the layer selection up front so bad input fails before data/checkpoint loading.
    # (The layer count comes from the model config; there is no --num-layers argument.)
    if args.finetune_layer_which == "all":
        finetune_layer_indices = list(range(num_layers))
    else:
        try:
            finetune_layer_indices = [int(idx) for idx in args.finetune_layer_which.split(",")]
        except ValueError:
            parser.error(
                f"--finetune-layer-which must be 'all' or comma-separated integers, "
                f"got {args.finetune_layer_which!r}"
            )
        out_of_range = [idx for idx in finetune_layer_indices if not 0 <= idx < num_layers]
        if out_of_range:
            parser.error(f"Layer indices {out_of_range} are out of range for a {num_layers}-layer model")

    os.makedirs(args.plot_path, exist_ok=True)
    rng = random.PRNGKey(args.seed)

    # --- Data Loading ---
    print(f"Loading {args.dataset} data...")
    data_config = common_config.with_dataset(common_config.get_config(), args.dataset)
    data_config.batch = args.batch_size
    # vit_jax's get_data_from_tfds batches the 'test' split with `batch_eval`, not `batch`;
    # keep both equal so evaluation batches match --batch-size.
    data_config.batch_eval = args.batch_size
    data_config.pp.crop = 224  # As used in finetuning
    ds_train = input_pipeline.get_data_from_tfds(config=data_config, mode='train')
    ds_test = input_pipeline.get_data_from_tfds(config=data_config, mode='test')
    num_classes = input_pipeline.get_dataset_info(args.dataset, 'train')['num_classes']
    test_num_examples = input_pipeline.get_dataset_info(args.dataset, 'test')['num_examples']

    model = VisionTransformer(num_classes=num_classes, **model_config)

    # Built once and reused by every evaluation pass.
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    def dataset_loss_and_accuracy(params, model_apply_fn, dataset, batch_size):
        """Return the example-weighted (loss, accuracy) of ``params`` over one pass of ``dataset``.

        Batches are expected pre-sharded as (devices, per_device_batch, ...), which is the
        layout the 'image' reshape below relies on. The validity mask is derived from each
        batch's actual shape. Slots past the split's ``num_examples`` are treated as padding
        and masked out. If the pipeline drops the final partial batch, those examples are
        simply never evaluated.
        Returns (0.0, 0.0) if the dataset yields no batches.
        """
        state = replicate(TrainState.create(apply_fn=model_apply_fn, params=params, tx=optax.sgd(1e-3)))

        loss_sum = acc_sum = 0.0
        seen = evaluated = 0
        total_batches = -(-test_num_examples // batch_size)  # ceil division, for the progress bar only
        for batch_data in tqdm(dataset.as_numpy_iterator(), total=total_batches, desc="Evaluating", leave=False):
            if seen >= test_num_examples:
                break  # Guard against a pipeline that repeats past one epoch.
            shard_count, per_shard = batch_data['image'].shape[:2]
            capacity = shard_count * per_shard
            num_valid = min(capacity, test_num_examples - seen)

            valid_mask = np.zeros(capacity, dtype=np.float32)
            valid_mask[:num_valid] = 1.0

            batch = {
                'image': batch_data['image'],
                'label': batch_data['label'],
                'valid_mask': valid_mask.reshape(shard_count, per_shard),
            }
            # psum inside eval_step makes every replica hold the same value; take device 0's.
            metrics = unreplicate(p_eval_step(state, batch))

            # Weight each batch by its number of real examples so a padded batch is not over-counted.
            loss_sum += float(metrics['loss']) * num_valid
            acc_sum += float(metrics['accuracy']) * num_valid
            evaluated += num_valid
            seen += capacity

        if evaluated == 0:
            return 0.0, 0.0
        return loss_sum / evaluated, acc_sum / evaluated

    def compute_interpolation(model_a_params, model_b_params, model_apply_fn, lambdas, dataset, batch_size, desc="Interpolation"):
        """Evaluate test metrics along ``(1 - lam) * A + lam * B`` for each ``lam`` in ``lambdas``.

        Returns a dict of per-lambda lists rounded to 4 decimals; the train entries are
        0.0 placeholders (train evaluation is skipped for speed).
        """
        # Unfreeze once; lerp builds a fresh tree for every lambda anyway.
        params_a = unfreeze(model_a_params)
        params_b = unfreeze(model_b_params)

        test_losses, test_accs = [], []
        for lam in tqdm(lambdas, desc=desc):
            p_interp = freeze(lerp(lam, params_a, params_b))
            test_loss, test_acc = dataset_loss_and_accuracy(p_interp, model_apply_fn, dataset, batch_size)
            test_losses.append(test_loss)
            test_accs.append(test_acc)

        return {
            "Train Loss": [0.0] * len(lambdas),  # Skipping train for speed
            "Test Loss": [round(float(x), 4) for x in test_losses],
            "Train Acc": [0.0] * len(lambdas),  # Skipping train for speed
            "Test Acc": [round(float(x), 4) for x in test_accs],
        }

    print(f"Loading model A from: {args.model_a}")
    model_a_params = checkpoint.load(args.model_a)

    print(f"Loading model B from: {args.model_b}")
    model_b_params = checkpoint.load(args.model_b)

    # Get activations for data-dependent matching.
    # NOTE(review): 1028 may have been meant as 1024; it only matters if the training batch is larger.
    activation_batch_size = 1028
    activation_batch = next(ds_train.as_numpy_iterator())
    images = activation_batch['image']

    # The input pipeline pre-shards the data as (devices, per_device, H, W, C); flatten to 4D.
    if images.ndim == 5:
        shard_count, per_shard, h, w, c = images.shape
        print(f"Detected sharded data with shape {images.shape}. Reshaping to 4D.")
        images = images.reshape((shard_count * per_shard, h, w, c))

    # Cap the number of examples (no-op if the batch is already small enough).
    images = images[:activation_batch_size]
    print(f"Computing activations from models on a batch of {images.shape[0]} training examples...")

    activations_a = get_mha_inputs(model, model_a_params, images, train=False)
    activations_b = get_mha_inputs(model, model_b_params, images, train=False)
    print(f"Got {len(activations_a)} sets of activations, for each of the {num_layers} layers.")

    # --- Interpolation and Evaluation ---
    lambdas = jnp.linspace(0, 1, num=11)
    all_results = {}

    # 1. Naive Interpolation
    naive_results = compute_interpolation(
        model_a_params, model_b_params, model.apply, lambdas, ds_test, args.batch_size, desc="Naive Interpolation"
    )
    all_results["Naive"] = naive_results
    print("\n--- Naive Interpolation Results ---")
    print(json.dumps({"Naive": naive_results}, indent=2))

    # 2. Matching on Attention Layers
    print(f"\nPerforming attention weight matching on layers: {finetune_layer_indices}")

    rng, matching_rng = random.split(rng)
    if not args.rope_use:
        aligned_models = matching_attn(
            matching_rng, model_a_params, model_b_params, activations_a, activations_b,
            finetune_layer_indices, model_config.transformer.num_heads, args.plot_path
        )
    else:
        aligned_models = matching_attn_rope(
            model_a_params, model_b_params, activations_a, activations_b,
            finetune_layer_indices, model_config.transformer.num_heads
        )

    for method, model_b_aligned in aligned_models.items():
        method_results = compute_interpolation(
            model_a_params, model_b_aligned, model.apply, lambdas, ds_test, args.batch_size, desc=f"{method} Interpolation"
        )
        all_results[method] = method_results
        print(f"\n--- {method} Interpolation Results ---")
        print(json.dumps({method: method_results}, indent=2))

    # Save all results to a JSON file
    results_path = os.path.join(args.plot_path, "interpolation_results.json")
    with open(results_path, "w") as f:
        json.dump(all_results, f, indent=2)
    print(f"\nAll results saved to {results_path}")

    # --- Plotting ---
    plt.rcParams.update({
        "font.family": "serif", 'legend.frameon': False, 'lines.linewidth': 2, 'font.size': 13,
        'axes.labelsize': 16, 'xtick.labelsize': 11, 'ytick.labelsize': 11, 'legend.fontsize': 11,
    })
    plt.style.use('tableau-colorblind10')

    lambda_values = np.asarray(lambdas)
    metrics = ["Train Loss", "Test Loss", "Train Acc", "Test Acc"]

    def save_interpolation_plot(results_by_method, filename):
        """Draw a 2x2 grid (one panel per metric, one line per method) and save it under --plot-path."""
        fig, axs = plt.subplots(2, 2, figsize=(12, 10))
        # axs.flat visits (0,0), (0,1), (1,0), (1,1) -- the same order as `metrics`.
        for metric, ax in zip(metrics, axs.flat):
            for method, results in results_by_method.items():
                ax.plot(lambda_values, results[metric], label=method)
            ax.set_xticks([0, 0.5, 1])
            ax.set_xticklabels(["Model A", r"$\lambda$", "Model B"])
            ax.set_ylabel(metric)
            ax.grid(True, linestyle='--', alpha=0.6)
        axs[0, 0].legend(loc='best')
        fig.tight_layout()
        save_path = os.path.join(args.plot_path, filename)
        fig.savefig(save_path, dpi=300)
        plt.close(fig)
        return save_path

    save_path = save_interpolation_plot(all_results, "interpolation_plots.png")
    print(f"Combined plot saved to {save_path}")

    aligned_only = {method: results for method, results in all_results.items() if method != "Naive"}
    save_path = save_interpolation_plot(aligned_only, "interpolation_plots_no_naive.png")
    print(f"Plot without naive baseline saved to {save_path}")
    
if __name__ == "__main__":
    main()