import os
import argparse
import numpy as np
from flax import traverse_util
from flax.core import freeze, unfreeze
import jax
import jax.numpy as jnp
from jax import random
import ml_collections
import matplotlib.pyplot as plt
from scipy.optimize import linear_sum_assignment
import flax.jax_utils as jax_utils
from vit_jax import checkpoint, input_pipeline, models
from vit_jax.configs import common as common_config
from vit_jax.weight_matching import weight_matching, apply_permutation, vit_permutation_spec_moe, permute_moe_block
from cifar10.finetune_cifar10_moe import VisionTransformerMoE

# Utility functions
def flatten_params(params):
  """Flatten a (possibly frozen) nested parameter tree into {"a/b/c": leaf}.

  NOTE(review): a later definition of `flatten_params` in this module rebinds
  the name, so this version is not the one callers end up using.
  """
  nested = unfreeze(params)
  by_path = traverse_util.flatten_dict(nested)
  # flatten_dict yields tuple key paths; join them into slash-separated names.
  return {"/".join(path): leaf for path, leaf in by_path.items()}

def unflatten_params(flat_params):
  """Rebuild a frozen nested parameter tree from {"a/b/c": leaf} entries.

  NOTE(review): a later definition of `unflatten_params` in this module
  rebinds the name, so this version is not the one callers end up using.
  """
  # Split each slash-separated name back into the tuple path unflatten_dict expects.
  by_path = {tuple(name.split("/")): leaf for name, leaf in flat_params.items()}
  nested = traverse_util.unflatten_dict(by_path)
  return freeze(nested)

def lerp(lam, a, b):
    """Linearly interpolate two pytrees leaf by leaf.

    Args:
        lam: Interpolation weight; 0 returns the leaves of `a`, 1 those of `b`.
        a: First pytree.
        b: Second pytree with the same structure as `a`.

    Returns:
        A pytree with the structure of `a` whose leaves are
        `(1 - lam) * leaf_a + lam * leaf_b`.
    """
    def _blend(leaf_a, leaf_b):
        return (1 - lam) * leaf_a + lam * leaf_b

    return jax.tree.map(_blend, a, b)

def flatten_params(params):
    """Flatten a nested parameter dict into one level keyed by '/'-joined paths."""
    flattened = traverse_util.flatten_dict(params, sep='/')
    return flattened

def unflatten_params(flat_params):
    """Rebuild a nested parameter dict from one keyed by '/'-joined paths.

    The result is returned as produced by `unflatten_dict`; it is not frozen here.
    """
    nested = traverse_util.unflatten_dict(flat_params, sep='/')
    return nested

def load_model(filepath):
    """Load model parameters from an .npz checkpoint file.

    Args:
        filepath: Path to an .npz archive whose entry names are '/'-joined
            parameter paths.

    Returns:
        A frozen nested parameter dict built from the archive entries.
    """
    # np.load on an .npz returns a lazily-read NpzFile that keeps the file
    # open; read every array inside the context manager so the handle is
    # closed before returning.
    with np.load(filepath) as data:
        flat_params = {name: data[name] for name in data}
    params = traverse_util.unflatten_dict(flat_params, sep='/')
    return freeze(params)

def _gating_vectors(moe_params):
    """Stack the transposed gating kernel with the gating bias as an extra column."""
    kernel = moe_params['gating_kernel']
    bias = moe_params['gating_bias']
    return np.hstack([kernel.T, bias[:, np.newaxis]])


def _expert_gram_matrices(moe_params, expert_index):
    """Return (gram_layer1, gram_layer2) for one expert's bias-augmented weights."""
    grams = []
    for layer in ('layer1', 'layer2'):
        kernel = moe_params[f'expert_{expert_index}_{layer}_kernel']
        bias = moe_params[f'expert_{expert_index}_{layer}_bias']
        # Append the bias as one more row so it contributes to the Gram matrix.
        augmented = np.vstack([kernel, bias[np.newaxis, :]])
        grams.append(augmented.T @ augmented)
    return tuple(grams)


def compute_weights_cost_matrices(model, params_a, params_b, num_experts):
    """Compute cost matrices based on expert weights and gating similarity.

    Args:
        model: Object exposing `get_moe_params(params)`, which must return a
            mapping with 'gating_kernel', 'gating_bias' and, for each expert
            index i, 'expert_{i}_layer{1,2}_kernel' / '..._bias' entries.
        params_a: Parameters of the first model.
        params_b: Parameters of the second model.
        num_experts: Number of experts to compare; D is
            (num_experts, num_experts).

    Returns:
        A tuple `(D, S)`:
          * `D[i, j]` is the sum, over both expert layers, of the Frobenius
            norm of the difference between the Gram matrices of expert i of
            model A and expert j of model B (lower means more similar).
          * `S` is the cosine similarity between the rows of the two models'
            bias-augmented gating matrices (presumably one row per expert --
            confirm against `get_moe_params`).
    """
    moe_params_a = model.get_moe_params(params_a)
    moe_params_b = model.get_moe_params(params_b)

    # Gating-based similarity matrix S
    gating_vectors_a = _gating_vectors(moe_params_a)
    gating_vectors_b = _gating_vectors(moe_params_b)

    norms_a = np.linalg.norm(gating_vectors_a, axis=1)
    norms_b = np.linalg.norm(gating_vectors_b, axis=1)
    dot_product = gating_vectors_a @ gating_vectors_b.T
    # The small epsilon guards against division by zero for all-zero rows.
    S = dot_product / (norms_a[:, np.newaxis] * norms_b[np.newaxis, :] + 1e-8)

    # Expert parameters-based distance matrix D.
    # Model B's Gram matrices do not depend on the expert of A being compared,
    # so build them once instead of once per (i, j) pair.
    grams_b = [_expert_gram_matrices(moe_params_b, j) for j in range(num_experts)]

    D = np.zeros((num_experts, num_experts))
    for i in range(num_experts):
        gram1_a, gram2_a = _expert_gram_matrices(moe_params_a, i)
        for j, (gram1_b, gram2_b) in enumerate(grams_b):
            D[i, j] = (np.linalg.norm(gram1_a - gram1_b, 'fro')
                       + np.linalg.norm(gram2_a - gram2_b, 'fro'))

    return D, S

def select_permutations(D, S, num_experts, alphas=(0.5,)):
    """Select permutations based on D, S, and hybrid costs.

    Each permutation is the column assignment returned by
    `scipy.optimize.linear_sum_assignment`, i.e. `perm[i]` is the column
    matched to row `i`.

    Args:
        D: Square distance matrix (lower is better).
        S: Square similarity matrix (higher is better).
        num_experts: Not used by this function; kept so existing callers
            keep working.
        alphas: Iterable of mixing weights. For each `alpha` the hybrid cost
            is `alpha * D_std - (1 - alpha) * S_std`, where `D_std` and
            `S_std` are D and S standardized to zero mean and unit variance.

    Returns:
        A tuple `(perm_D, perm_S, perm_hybrids)`: the assignment minimizing
        D, the assignment maximizing S, and a list with one assignment per
        entry of `alphas`, in the same order.
    """
    _, perm_D = linear_sum_assignment(D)
    # Negate S so that minimizing the cost maximizes total similarity.
    _, perm_S = linear_sum_assignment(-S)

    # Standardize both matrices so alpha mixes quantities on a comparable
    # scale; the epsilon avoids division by zero for constant matrices.
    D_std = (D - D.mean()) / (D.std() + 1e-8)
    S_std = (S - S.mean()) / (S.std() + 1e-8)

    perm_hybrids = []
    for alpha in alphas:
        cost_hybrid = alpha * D_std - (1 - alpha) * S_std
        _, perm_hybrid = linear_sum_assignment(cost_hybrid)
        perm_hybrids.append(perm_hybrid)

    return perm_D, perm_S, perm_hybrids

def dataset_loss_and_accuracy(compute_batch_metrics, params_repl, dataset, num_samples, batch_size):
    """Average per-batch loss and accuracy over the first `num_samples` examples.

    Args:
        compute_batch_metrics: Callable `(params_repl, batch) -> (loss, accuracy)`.
        params_repl: Parameters passed through unchanged to `compute_batch_metrics`.
        dataset: Object whose `as_numpy_iterator()` yields batches.
        num_samples: Stop once `batches_seen * batch_size` reaches this value.
        batch_size: Nominal number of examples per batch, used only for the
            stopping rule.

    Returns:
        `(mean_loss, mean_accuracy)`, each the unweighted mean of the
        per-batch means. At least one batch is always consumed when the
        dataset is non-empty.
    """
    loss_sum = 0.0
    acc_sum = 0.0
    batches_seen = 0
    for batches_seen, batch in enumerate(dataset.as_numpy_iterator(), start=1):
        batch_loss, batch_acc = compute_batch_metrics(params_repl, batch)
        # Metrics may carry extra (e.g. per-device) axes; reduce them to a scalar.
        loss_sum = loss_sum + jnp.mean(batch_loss)
        acc_sum = acc_sum + jnp.mean(batch_acc)
        if batches_seen * batch_size >= num_samples:
            break
    return loss_sum / batches_seen, acc_sum / batches_seen

def _interpolation_metrics(compute_batch_metrics, params_a, params_b, lambdas,
                           train_ds, test_ds, batch_size, num_samples=10000):
    """Evaluate loss/accuracy on train and test data along the line from A to B."""
    endpoint_a = unfreeze(params_a)
    endpoint_b = unfreeze(params_b)
    metrics = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': []}
    for lam in lambdas:
        interpolated = freeze(lerp(lam, endpoint_a, endpoint_b))
        params_repl = jax_utils.replicate(interpolated)
        for split, ds in (('train', train_ds), ('test', test_ds)):
            loss, acc = dataset_loss_and_accuracy(
                compute_batch_metrics, params_repl, ds, num_samples, batch_size)
            metrics[f'{split}_loss'].append(float(loss))
            metrics[f'{split}_acc'].append(float(acc))
    return metrics


def _print_metrics(header, lambdas, metrics):
    """Print one line of train/test loss and accuracy per interpolation weight."""
    print(header)
    for i, lam in enumerate(lambdas):
        print(f"Lambda {lam:.2f}: Train Loss: {metrics['train_loss'][i]:.4f}, "
              f"Train Acc: {metrics['train_acc'][i]:.4f}, "
              f"Test Loss: {metrics['test_loss'][i]:.4f}, "
              f"Test Acc: {metrics['test_acc'][i]:.4f}")


def main():
    """Run the interpolation analysis between two ViT-MoE checkpoints.

    Command-line arguments:
        --model-a / --model-b: Paths to the two .npz checkpoints (required).
        --num-experts: Number of experts in each MoE block (default 2).
        --plot-path: Directory for the output plots (default "./plots").

    The script evaluates CIFAR-10 train/test loss and accuracy at 15 evenly
    spaced interpolation weights, first for the raw parameters ("naive") and
    then after permuting model B's experts (by expert-weight distance, by
    gating similarity, and by a 0.5 hybrid of both) followed by weight
    matching. One PNG per metric is written to --plot-path and all numbers
    are printed to stdout.
    """
    parser = argparse.ArgumentParser(description="Interpolation analysis for ViT-MoE models")
    parser.add_argument("--model-a", type=str, required=True, help="Path to first model checkpoint")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second model checkpoint")
    parser.add_argument("--num-experts", type=int, default=2, help="Number of experts in MoE")
    parser.add_argument("--plot-path", type=str, default="./plots", help="Directory to save plots")
    args = parser.parse_args()

    # Load dataset
    dataset = 'cifar10'
    batch_size = 512
    config = common_config.with_dataset(common_config.get_config(), dataset)
    config.batch = batch_size
    config.pp.crop = 224
    train_ds = input_pipeline.get_data_from_tfds(config=config, mode='train')
    test_ds = input_pipeline.get_data_from_tfds(config=config, mode='test')
    num_classes = input_pipeline.get_dataset_info(dataset, 'train')['num_classes']

    # Model configurations
    config_vit = ml_collections.ConfigDict()
    config_vit.model_name = 'ViT-B_16'
    config_vit.patches = ml_collections.ConfigDict({'size': (32, 32)})
    config_vit.hidden_size = 768
    config_vit.transformer = ml_collections.ConfigDict()
    config_vit.transformer.mlp_dim = 3072
    config_vit.transformer.num_heads = 12
    config_vit.transformer.num_layers = 12
    config_vit.transformer.attention_dropout_rate = 0.0
    config_vit.transformer.dropout_rate = 0.0
    config_vit.classifier = 'token'
    config_vit.representation_size = None
    config_vit.transformer.num_experts = args.num_experts
    config_vit.transformer.expert_hidden_dim = 3072
    model = VisionTransformerMoE(num_classes=num_classes, **config_vit)
    num_layers = config_vit.transformer.num_layers

    # Load models. Both interpolation endpoints are exactly the checkpoints
    # given on the command line.
    model_a = load_model(args.model_a)
    model_b = load_model(args.model_b)

    def _batch_metrics(params, batch):
        # Run the forward pass once and derive both metrics from the logits.
        logits = model.apply({'params': params}, batch['image'], train=False)
        labels = batch['label']
        # Cross-entropy against the label distribution (labels are combined
        # with argmax below, so presumably one-hot -- confirm in the pipeline).
        loss = -jnp.sum(jax.nn.log_softmax(logits) * labels) / labels.shape[0]
        accuracy = (logits.argmax(axis=-1) == labels.argmax(axis=-1)).mean()
        return loss, accuracy

    # pmap maps over the leading axis of both arguments, so parameters are
    # replicated before each call; assumes the input pipeline yields batches
    # with a leading device axis -- TODO confirm.
    compute_batch_metrics = jax.pmap(_batch_metrics)

    # Interpolation lambdas
    lambdas = jnp.linspace(0, 1, num=15)

    # Naive interpolation
    naive_metrics = _interpolation_metrics(
        compute_batch_metrics, model_a, model_b, lambdas, train_ds, test_ds, batch_size)

    # Compute cost matrices
    D_weight, S_weight = compute_weights_cost_matrices(model, model_a, model_b, args.num_experts)

    # Select permutations
    perm_D, perm_S, perm_hybrids = select_permutations(D_weight, S_weight, args.num_experts)
    selected_perms = {
        "expert_matching": perm_D,
        "gating_matching": perm_S,
        "hybrid_alpha_0.5": perm_hybrids[0]
    }
    print(selected_perms)

    # Aligned interpolation: permute B's experts, then weight-match B to A.
    aligned_metrics = {}
    permutation_spec = vit_permutation_spec_moe(num_layers=num_layers, num_experts=args.num_experts)
    for method, pi in selected_perms.items():
        model_b_pi = permute_moe_block(model_b, pi, num_layers=num_layers, num_experts=args.num_experts)
        final_permutation = weight_matching(
            random.PRNGKey(0),
            permutation_spec,
            flatten_params(model_a),
            flatten_params(model_b_pi)
        )
        model_b_pi_aligned = unflatten_params(
            apply_permutation(permutation_spec, final_permutation, flatten_params(model_b_pi))
        )
        aligned_metrics[method] = _interpolation_metrics(
            compute_batch_metrics, model_a, model_b_pi_aligned, lambdas,
            train_ds, test_ds, batch_size)

    # Output results
    os.makedirs(args.plot_path, exist_ok=True)
    for metric in ['train_loss', 'train_acc', 'test_loss', 'test_acc']:
        plt.figure(figsize=(10, 6))
        plt.plot(lambdas, naive_metrics[metric], label='Naive', marker='o')
        for method, metrics in aligned_metrics.items():
            plt.plot(lambdas, metrics[metric], label=method, marker='x')
        plt.xlabel('Lambda')
        plt.ylabel(metric.replace('_', ' ').title())
        plt.legend()
        plt.grid(True)
        plt.savefig(os.path.join(args.plot_path, f'{metric}.png'))
        plt.close()

    # Print metrics
    _print_metrics("Naive Interpolation Metrics:", lambdas, naive_metrics)
    for method, metrics in aligned_metrics.items():
        _print_metrics(f"\n{method} Interpolation Metrics:", lambdas, metrics)

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()