import os
import argparse
import numpy as np
from flax import traverse_util
from flax.core import freeze, unfreeze
import jax
import jax.numpy as jnp
from jax import random
import ml_collections
import matplotlib.pyplot as plt
from scipy.optimize import linear_sum_assignment
import flax.jax_utils as jax_utils
from vit_jax import checkpoint, input_pipeline, models
from vit_jax.configs import common as common_config
from vit_jax.weight_matching import weight_matching, apply_permutation, vit_permutation_spec_moe, permute_moe_block
from cifar10.finetune_cifar10_moe import VisionTransformerMoE

# Utility functions
def flatten_params(params):
    """Flatten a nested (possibly frozen) parameter dict.

    Returns a plain dict mapping each leaf's key path, joined with "/",
    to the leaf value (e.g. ``{"a": {"b": x}}`` -> ``{"a/b": x}``).
    """
    return traverse_util.flatten_dict(params, sep='/')


def unflatten_params(flat_params):
    """Inverse of ``flatten_params``.

    Splits every "/"-joined key back into a key path and returns the
    rebuilt nested structure as a plain (unfrozen) dict.
    """
    return traverse_util.unflatten_dict(flat_params, sep='/')


def lerp(lam, a, b):
    """Linearly interpolate two pytrees leaf by leaf.

    Computes ``(1 - lam) * a + lam * b`` for every pair of matching leaves,
    so ``lam=0`` gives ``a`` and ``lam=1`` gives ``b``. ``b`` must share the
    tree structure of ``a``. A new tree is returned; inputs are not mutated.
    """
    return jax.tree.map(lambda x, y: (1 - lam) * x + lam * y, a, b)

def load_model(filepath):
    """Load model parameters from an ``.npz`` checkpoint file.

    Each array name in the archive is treated as a "/"-joined key path and
    the arrays are unflattened into a nested structure.

    Args:
        filepath: path to the ``.npz`` archive.

    Returns:
        A frozen nested dict of parameter arrays.
    """
    # np.load on an .npz returns an NpzFile that holds an open file handle;
    # use it as a context manager so the handle is closed, and materialize
    # every array before leaving the block.
    with np.load(filepath) as data:
        flat_params = {key: data[key] for key in data.files}
    return freeze(traverse_util.unflatten_dict(flat_params, sep='/'))

def _centered_gating_vectors(moe_params):
    """Return one row per expert: its gating-kernel column plus bias, centered over experts.

    The kernel is (hidden, num_experts); this is implied by stacking its
    transpose next to the per-expert bias. Adding the same vector to every
    kernel column and the same scalar to every bias entry shifts all gating
    logits equally. Such a shift should not change routing, assuming routing
    depends only on logit differences (softmax / top-k) — TODO confirm
    against the MoE layer. So both are centered over the expert axis.
    """
    kernel = np.asarray(moe_params['gating_kernel'])
    bias = np.asarray(moe_params['gating_bias'])
    centered_kernel = kernel - kernel.mean(axis=1, keepdims=True)
    centered_bias = bias - bias.mean()
    return np.hstack([centered_kernel.T, centered_bias[:, np.newaxis]])


def _expert_grams(moe_params, e):
    """Return Gram matrices of expert ``e``'s two layers that ignore hidden-unit order."""
    w1 = np.asarray(moe_params[f'expert_{e}_layer1_kernel'])
    b1 = np.asarray(moe_params[f'expert_{e}_layer1_bias'])
    w2 = np.asarray(moe_params[f'expert_{e}_layer2_kernel'])
    b2 = np.asarray(moe_params[f'expert_{e}_layer2_bias'])
    # Layer 1 augmented with its bias row: one column per hidden unit.
    w1_aug = np.vstack([w1, b1[np.newaxis, :]])
    # Layer 2 augmented with its bias row: one row per hidden unit, plus the bias.
    w2_aug = np.vstack([w2, b2[np.newaxis, :]])
    # Contract over the hidden-unit axis so reordering an expert's hidden
    # units (w1 columns / w2 rows) leaves both Gram matrices unchanged.
    return w1_aug @ w1_aug.T, w2_aug.T @ w2_aug


def compute_weights_cost_matrices(model, params_a, params_b, num_experts):
    """Build expert-to-expert cost matrices between two MoE parameter sets.

    Args:
        model: object exposing ``get_moe_params(params)``. It must return a
            mapping with ``gating_kernel``, ``gating_bias`` and, for each
            expert ``e``, the ``expert_{e}_layer1_kernel``,
            ``expert_{e}_layer1_bias``, ``expert_{e}_layer2_kernel`` and
            ``expert_{e}_layer2_bias`` arrays.
        params_a: parameters of model A.
        params_b: parameters of model B.
        num_experts: number of experts whose weights enter ``D``.

    Returns:
        ``(D, S)``. ``D[i, j]`` is a weight-based distance between expert
        ``i`` of A and expert ``j`` of B: the Frobenius distance of their
        hidden-unit-order-invariant Gram matrices. ``S[i, j]`` is the
        Euclidean distance between their expert-centered gating vectors.
        Lower means more similar in both.
    """
    moe_params_a = model.get_moe_params(params_a)
    moe_params_b = model.get_moe_params(params_b)

    # Gating-based distance matrix S.
    gating_a = _centered_gating_vectors(moe_params_a)
    gating_b = _centered_gating_vectors(moe_params_b)
    diff_vectors = gating_a[:, np.newaxis, :] - gating_b[np.newaxis, :, :]
    S = np.sqrt(np.sum(diff_vectors ** 2, axis=2))

    # Expert-parameter distance matrix D. Grams are computed once per expert
    # instead of once per (i, j) pair.
    grams_a = [_expert_grams(moe_params_a, e) for e in range(num_experts)]
    grams_b = [_expert_grams(moe_params_b, e) for e in range(num_experts)]
    D = np.zeros((num_experts, num_experts))
    for i, (g1_a, g2_a) in enumerate(grams_a):
        for j, (g1_b, g2_b) in enumerate(grams_b):
            norm1 = np.linalg.norm(g1_a - g1_b, 'fro')
            norm2 = np.linalg.norm(g2_a - g2_b, 'fro')
            D[i, j] = np.sqrt(norm1 ** 2 + norm2 ** 2)

    return D, S

def select_permutations(D, S, num_experts, alphas=None):
    """Solve the minimum-cost expert assignment for each of two cost matrices.

    Args:
        D: square cost matrix; entry ``[i, j]`` is the cost of pairing row
            expert ``i`` with column expert ``j``.
        S: second cost matrix with the same layout as ``D``.
        num_experts: currently unused; kept for call-site compatibility.
        alphas: currently unused; reserved for hybrid ``D``/``S`` costs, which
            are not implemented. Defaults to ``None`` rather than a shared
            mutable list.

    Returns:
        ``(perm_D, perm_S)``: for each row ``i``, the column index assigned to
        it by the optimal (minimum total cost) matching of ``D`` and of ``S``.
    """
    _, perm_D = linear_sum_assignment(D)
    _, perm_S = linear_sum_assignment(S)
    return perm_D, perm_S

def _print_metrics(title, metrics):
    """Print one interpolation run's per-lambda metric curves, rounded to 4 decimals."""
    print(f"{title}:")
    for label, key in (
        ("Train Loss", "train_loss"),
        ("Test Loss", "test_loss"),
        ("Train Acc", "train_acc"),
        ("Test Acc", "test_acc"),
    ):
        print(f"  {label}:", [float(f"{x:.4f}") for x in metrics[key]])


def _interpolation_curve(params_a, params_b, lambdas, evaluate_fn):
    """Evaluate the linear path from ``params_a`` to ``params_b``.

    Args:
        params_a: parameters at ``lam=0``.
        params_b: parameters at ``lam=1``; must share ``params_a``'s structure.
        lambdas: interpolation coefficients to evaluate.
        evaluate_fn: takes device-replicated params and returns
            ``(test_loss, test_acc)``.

    Returns:
        Dict of per-lambda lists under ``train_loss``, ``train_acc``,
        ``test_loss`` and ``test_acc``. Train metrics are not computed and
        are recorded as 0 placeholders.
    """
    metrics = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': []}
    # lerp builds a new tree without mutating its inputs, so unfreeze once.
    tree_a = unfreeze(params_a)
    tree_b = unfreeze(params_b)
    for lam in lambdas:
        interpolated = freeze(lerp(lam, tree_a, tree_b))
        test_loss, test_acc = evaluate_fn(jax_utils.replicate(interpolated))
        metrics['train_loss'].append(0)
        metrics['train_acc'].append(0)
        metrics['test_loss'].append(float(test_loss))
        metrics['test_acc'].append(float(test_acc))
    return metrics


def main():
    """Compare naive and permutation-aligned interpolation between two ViT-MoE checkpoints.

    Reports test loss/accuracy along the linear path from model A to model B:
    first without alignment, then after permuting model B's experts. Expert
    matching uses expert-weight and gating-weight cost matrices, followed by
    weight matching of the permuted model.
    """
    parser = argparse.ArgumentParser(description="Interpolation analysis for ViT-MoE models")
    parser.add_argument("--model-a", type=str, required=True, help="Path to first model checkpoint")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second model checkpoint")
    parser.add_argument("--num-layers", type=int, required=True, help="Number of layers in MoE")
    parser.add_argument("--moe-layer-which", type=int, required=True, help="which mlp layer is replaced")
    parser.add_argument("--num-experts", type=int, required=True, help="Number of experts in MoE")
    args = parser.parse_args()

    # Data: only the CIFAR-10 test split is evaluated.
    dataset = 'cifar10'
    batch_size = 2048
    data_config = common_config.with_dataset(common_config.get_config(), dataset)
    data_config.batch = batch_size
    data_config.batch_eval = batch_size
    data_config.pp.crop = 224
    ds_test = input_pipeline.get_data_from_tfds(config=data_config, mode='test')
    num_classes = input_pipeline.get_dataset_info(dataset, 'train')['num_classes']

    # Model configuration: ViT-S/16 backbone plus MoE settings from the CLI.
    config = ml_collections.ConfigDict()
    config.model_name = 'ViT-S_16'
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 384
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 1536
    config.transformer.num_heads = 6
    config.transformer.num_layers = args.num_layers
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.0
    config.classifier = 'token'
    config.representation_size = None

    config_moe = ml_collections.ConfigDict(config.to_dict())
    config_moe.transformer.moe_layer_which = args.moe_layer_which
    config_moe.transformer.num_experts = args.num_experts
    config_moe.transformer.expert_hidden_dim = config.transformer.mlp_dim

    model = VisionTransformerMoE(num_classes=num_classes, **config_moe)

    model_a = load_model(args.model_a)
    model_b = load_model(args.model_b)

    @jax.pmap
    def compute_batch_metrics(params, batch):
        """Per-device mean cross-entropy and accuracy for one device shard.

        Labels are presumably one-hot (argmax is taken for accuracy) — confirm
        against the input pipeline.
        """
        logits = model.apply({'params': params}, batch['image'], train=False)
        loss = -jnp.sum(jax.nn.log_softmax(logits) * batch['label']) / batch['label'].shape[0]
        accuracy = (logits.argmax(axis=-1) == batch['label'].argmax(axis=-1)).mean()
        return loss, accuracy

    def evaluate_test(params_repl):
        """Return (mean loss, mean accuracy) over all test batches.

        Per-device values are averaged within a batch, then batch means are
        averaged. Raises ValueError if the test set yields no batches.
        """
        loss_sum = 0.0
        acc_sum = 0.0
        num_batches = 0
        for batch in ds_test.as_numpy_iterator():
            loss, accuracy = compute_batch_metrics(params_repl, batch)
            loss_sum += jnp.mean(loss)
            acc_sum += jnp.mean(accuracy)
            num_batches += 1
        if num_batches == 0:
            raise ValueError("test dataset yielded no batches")
        return loss_sum / num_batches, acc_sum / num_batches

    lambdas = jnp.linspace(0, 1, num=15)

    # Naive interpolation between the unaligned checkpoints.
    naive_metrics = _interpolation_curve(model_a, model_b, lambdas, evaluate_test)
    _print_metrics("Naive Interpolation", naive_metrics)

    # Match model B's experts to model A's under two different cost matrices.
    D_weight, S_weight = compute_weights_cost_matrices(model, model_a, model_b, args.num_experts)
    perm_D, perm_S = select_permutations(D_weight, S_weight, args.num_experts)
    selected_perms = {
        "Expert Weight Matching": perm_D,
        "Gating Weight Matching": perm_S,
    }
    print(selected_perms)

    # Aligned interpolation: permute experts, weight-match, then interpolate.
    permutation_spec = vit_permutation_spec_moe(
        moe_layer_which=args.moe_layer_which, num_experts=args.num_experts)
    for method, pi in selected_perms.items():
        model_b_pi = permute_moe_block(
            model_b, pi, moe_layer_which=args.moe_layer_which, num_experts=args.num_experts)
        final_permutation = weight_matching(
            random.PRNGKey(0),
            permutation_spec,
            flatten_params(model_a),
            flatten_params(model_b_pi),
        )
        model_b_pi_aligned = unflatten_params(
            apply_permutation(permutation_spec, final_permutation, flatten_params(model_b_pi))
        )
        metrics = _interpolation_curve(model_a, model_b_pi_aligned, lambdas, evaluate_test)
        _print_metrics(method, metrics)


if __name__ == "__main__":
    main()