import argparse
import os
import time
from datetime import timedelta
from typing import Any, Dict, List

import json
import jax.numpy as jnp
import torch
import itertools 
import numpy as np
from tqdm import tqdm
from jax import random
from scipy.optimize import linear_sum_assignment
from datasets import Dataset
import matplotlib.pyplot as plt
from flax.training import checkpoints, train_state
from flax.training.common_utils import get_metrics, onehot, shard
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, GPT2Config
from model import print_model, FlaxGPT2MoELMHeadModel
from flax.core.frozen_dict import freeze, unfreeze
from weight_matching import apply_permutation, weight_matching, permute_moe_block
from weight_matching import flax_gpt2_permutation_spec_moe, plot_interp_loss, plot_interp_ppl
from weight_matching_moe import make_stuff, recursive_to_serializable
from utils import ec2_get_instance_type, flatten_params, lerp, unflatten_params
import multiprocessing as mp
import os, pprint, pathlib
from data_utils import get_lm_corpus

def prepare_lm_batch(data: torch.Tensor, target: torch.Tensor) -> Dict[str, Any]:
    """
    Turn one PyTorch language-modeling batch into a dict of JAX arrays.

    Each tensor is transposed with ``.T`` and then copied into a
    ``jnp`` array. No device sharding happens in this function.

    Args:
        data (torch.Tensor): Input tokens; callers are expected to pass a
            (seq_len, batch) layout so the result is (batch, seq_len).
        target (torch.Tensor): Target tokens, same layout as ``data``.
    Returns:
        Dict[str, Any]: ``{'input_ids': ..., 'target': ...}`` holding the
            transposed inputs as JAX arrays.
    """
    # Transpose first, convert afterwards; the dict order fixes the key order
    # of the returned batch.
    transposed = {'input_ids': data.T, 'target': target.T}
    return {field: jnp.array(tensor) for field, tensor in transposed.items()}

def _centered_gating_vectors(gate):
    """Build one centered gating vector per expert from a gate's kernel and bias.

    The kernel has its per-column mean (mean over axis 0) removed, the bias has
    its overall mean removed, and the bias is appended as one extra component.

    Args:
        gate: Mapping with 'kernel' and 'bias' entries. The kernel is assumed
            to be [hidden_dim, num_experts] and the bias [num_experts].

    Returns:
        np.ndarray with one row per kernel column: the centered kernel column
        followed by the centered bias entry.
    """
    kernel = np.array(gate['kernel'])
    bias = np.array(gate['bias'])
    centered_kernel = kernel - np.mean(kernel, axis=0, keepdims=True)
    centered_bias = bias - np.mean(bias)
    return np.hstack([centered_kernel.T, centered_bias[:, None]])


def _expert_gram_matrices(expert):
    """Return the pair of Gram matrices that summarize one expert's FFN weights.

    Each bias is stacked as an extra row beneath its transposed kernel, giving
    augmented matrices W1 (from 'c_fc') and W2 (from 'c_proj'). The result is
    ``(W1.T @ W1, W2 @ W2.T)``.

    Args:
        expert: Mapping with 'c_fc' and 'c_proj' entries, each holding
            'kernel' and 'bias'. For the stacking to work, the bias length must
            equal the kernel's first dimension.
            NOTE(review): the original comments assumed GPT-2 sizes (768 and
            3072) - confirm against the model definition.

    Returns:
        Tuple ``(gram1, gram2)`` of square np.ndarrays.
    """
    fc_kernel = np.array(expert['c_fc']['kernel'])
    fc_bias = np.array(expert['c_fc']['bias'])
    proj_kernel = np.array(expert['c_proj']['kernel'])
    proj_bias = np.array(expert['c_proj']['bias'])

    w1_aug = np.vstack([fc_kernel.T, fc_bias[None, :]])
    w2_aug = np.vstack([proj_kernel.T, proj_bias[None, :]])

    return w1_aug.T @ w1_aug, w2_aug @ w2_aug.T


def compute_weights_cost_matrices(model, params_a, params_b, config):
    """
    Compute cost matrices for GPT2-MoE models using routing gate weights and expert FFN parameters.

    Only the MoE block at ``params['transformer']['h'][str(config.moe_layer_indices)]['mlp']``
    is inspected. Row index ``i`` always refers to an expert of model A and
    column index ``j`` to an expert of model B.

    Args:
        model: Unused; kept so existing call sites keep working.
        params_a, params_b: Nested parameter mappings of the two models.
        config: Object exposing ``moe_layer_indices`` (stringified to select
            the layer) and ``num_routed_experts``.

    Returns:
        D: [num_experts, num_experts] array where
            ``D[i, j] = sqrt(||G1_a_i - G1_b_j||_F^2 + ||G2_a_i - G2_b_j||_F^2)``
            and G1/G2 are the Gram matrices of the augmented expert weights.
        S: Array of Euclidean distances between the centered gating vectors
            of model A (rows) and model B (columns).
    """
    moe_idx = str(config.moe_layer_indices)
    num_experts = config.num_routed_experts

    moe_block_a = params_a['transformer']['h'][moe_idx]['mlp']
    moe_block_b = params_b['transformer']['h'][moe_idx]['mlp']

    # ----- Gating distance matrix (S) -----
    gating_vectors_a = _centered_gating_vectors(moe_block_a['gate'])
    gating_vectors_b = _centered_gating_vectors(moe_block_b['gate'])

    # Broadcast to all (a_i, b_j) pairs, then reduce over the feature axis.
    diff = gating_vectors_a[:, None, :] - gating_vectors_b[None, :, :]
    S = np.sqrt(np.sum(diff ** 2, axis=-1))  # Euclidean distance

    # ----- Expert weight distance matrix (D) -----
    # The Gram matrices of B's experts do not depend on i, so build them once
    # instead of once per A expert. This keeps num_experts pairs of Gram
    # matrices alive at the same time (more memory, far fewer matmuls).
    grams_b = [
        _expert_gram_matrices(moe_block_b[f'routed_experts_{j}'])
        for j in range(num_experts)
    ]

    D = np.zeros((num_experts, num_experts))
    for i in range(num_experts):
        gram1_a, gram2_a = _expert_gram_matrices(moe_block_a[f'routed_experts_{i}'])
        for j, (gram1_b, gram2_b) in enumerate(grams_b):
            diff1 = gram1_a - gram1_b
            diff2 = gram2_a - gram2_b
            D[i, j] = np.sqrt(np.linalg.norm(diff1, 'fro')**2 + np.linalg.norm(diff2, 'fro')**2)

    return D, S
def select_permutations(D_score, S_score, config, alphas):
    """
    Select expert permutations from a cost matrix, a similarity matrix, and
    weighted combinations of the two.

    Note the opposite conventions: ``D_score`` is minimized (lower = better
    match) while ``S_score`` is maximized (higher = better match). A distance
    matrix passed as ``S_score`` must be negated by the caller first.

    Args:
        D_score: [num_experts, num_experts] array-like of pairwise costs.
        S_score: [num_experts, num_experts] array-like of pairwise similarities.
        config: Unused; kept for backward compatibility with existing callers.
        alphas: Iterable of weights; each ``alpha`` yields the hybrid cost
            ``alpha * D_std - (1 - alpha) * S_std`` on the standardized matrices.

    Returns:
        perm_D: np.ndarray, column assigned to each row when minimizing D_score.
        perm_S: np.ndarray, column assigned to each row when maximizing S_score.
        perm_hybrids: list of np.ndarrays, one assignment per entry of ``alphas``.
    """
    # Accept lists / other array-likes, not only ndarrays.
    D_score = np.asarray(D_score)
    S_score = np.asarray(S_score)

    # Match experts by minimizing cost
    _, perm_D = linear_sum_assignment(D_score)
    # Match experts by maximizing similarity (negate to minimize)
    _, perm_S = linear_sum_assignment(-S_score)

    # Standardize both matrices so alpha trades them off on a common scale;
    # the epsilon guards against a zero standard deviation.
    D_std = (D_score - D_score.mean()) / (D_score.std() + 1e-8)
    S_std = (S_score - S_score.mean()) / (S_score.std() + 1e-8)

    perm_hybrids = []
    for alpha in alphas:
        hybrid_cost = alpha * D_std - (1 - alpha) * S_std
        _, perm_hybrid = linear_sum_assignment(hybrid_cost)
        perm_hybrids.append(np.array(perm_hybrid))

    return perm_D, perm_S, perm_hybrids




def _interpolation_curves(evaluate, params_start, params_end, lambdas,
                          train_loader, test_loader, desc):
    """Evaluate loss and perplexity along the linear path between two parameter trees.

    Args:
        evaluate: Callable ``(params, loader) -> (loss, ppl)``.
        params_start, params_end: Parameter trees at lambda = 0 and lambda = 1.
        lambdas: Iterable of interpolation coefficients.
        train_loader, test_loader: Data iterators passed to ``evaluate``.
        desc: Progress-bar label.

    Returns:
        Tuple ``(train_losses, test_losses, train_ppls, test_ppls)`` of lists,
        each with one entry per lambda.
    """
    train_losses, test_losses, train_ppls, test_ppls = [], [], [], []
    for lam in tqdm(lambdas, desc=desc):
        mixed = freeze(lerp(lam, unfreeze(params_start), unfreeze(params_end)))
        train_loss, train_ppl = evaluate(mixed, train_loader)
        test_loss, test_ppl = evaluate(mixed, test_loader)
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        train_ppls.append(train_ppl)
        test_ppls.append(test_ppl)
    return train_losses, test_losses, train_ppls, test_ppls


def main():
    """Run weight matching between two GPT2-MoE checkpoints and plot interpolation curves.

    Steps:
      1. Parse CLI arguments, load the corpus, and restore both checkpoints
         using the config.json next to ``--model-a``.
      2. Evaluate naive linear interpolation between the two parameter sets.
      3. Derive expert permutations from the weight-based cost matrices,
         permute and align model B, and evaluate the aligned interpolation.
      4. Write the curves to ``results/lm1b/[<a>+<b>].json`` and save loss /
         perplexity plots under ``./plots/lm1b``.
    """
    parser = argparse.ArgumentParser(description="Weight matching for GPT2-MoE models on One Billion Words")
    parser.add_argument("--model-a", type=str, required=True, help="Path to first GPT2-MoE model checkpoint")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second GPT2-MoE model checkpoint")
    parser.add_argument("--seed", type=int, default =0)
    parser.add_argument("--data-path", type=str, default="./data/lm1b", help="train datset paths (multiple paths)")
    parser.add_argument('--dataset', type=str, default='lm1b',choices=['wt103', 'lm1b', 'enwik8', 'text8'],help='dataset name')
    parser.add_argument("--batch-size", type=int, default=24, help="train, eval batch size (batch size will be devided by device count)")
    parser.add_argument('--tgt_len', type=int, default=256,help='number of tokens to predict')
    parser.add_argument('--eval_tgt_len', type=int, default=256,help='number of tokens to predict for evaluation')
    parser.add_argument('--ext_len', type=int, default=0,help='length of the extended context')
    parser.add_argument('--mem_len', type=int, default=0,help='length of the retained previous heads')
    # NOTE(review): --save-path is parsed but never read below; outputs always
    # go to the hard-coded lm1b directories.
    parser.add_argument("--save-path", type=str, default="/", help="Path to plot directory")
    args = parser.parse_args()

    # ----- Data -----
    corpus = get_lm_corpus(args.data_path, args.dataset)
    ntokens = len(corpus.vocab)
    args.n_token = ntokens
    eval_batch_size = 10
    # NOTE(review): tr_iter is built but not used afterwards.
    tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len, ext_len=args.ext_len)
    va_iter = corpus.get_iterator('valid', eval_batch_size, args.eval_tgt_len, ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test', eval_batch_size, args.eval_tgt_len, ext_len=args.ext_len)
    # The curves labelled "train" are computed on the validation split.
    train_loader = va_iter
    test_loader = te_iter

    # ----- Models -----
    config = GPT2Config.from_json_file(os.path.join(os.path.dirname(args.model_a).rstrip("/"),'config.json'))
    model = FlaxGPT2MoELMHeadModel(config)
    params_a = checkpoints.restore_checkpoint(ckpt_dir=args.model_a, target={"params": model.params})["params"]
    params_b = checkpoints.restore_checkpoint(ckpt_dir=args.model_b, target={"params": model.params})["params"]
    stuff = make_stuff(model = model)
    evaluate = stuff["dataset_loss_and_ppl"]
    permutation_spec = flax_gpt2_permutation_spec_moe(config=config)
    lambdas = jnp.linspace(0, 1, num=25)

    # --- Naive Interpolation ---
    (train_loss_interp_naive, test_loss_interp_naive,
     train_ppl_interp_naive, test_ppl_interp_naive) = _interpolation_curves(
        evaluate, params_a, params_b, lambdas, train_loader, test_loader,
        desc="Naive Interpolation",
    )

    # Weight cost matrices; both are treated as costs and minimized.
    D_weight, S_weight = compute_weights_cost_matrices(model, params_a, params_b, config)
    _, perm_expert_weight = linear_sum_assignment(D_weight)
    _, perm_gating_weight = linear_sum_assignment(S_weight)

    # Collect unique permutations: if two criteria yield the same assignment,
    # only the first label is kept so the path is evaluated once.
    selected_perms, seen_perms = {}, set()
    def add_perm(name, perm):
        perm_tuple = tuple(perm)
        if perm_tuple not in seen_perms:
            selected_perms[name] = perm_tuple
            seen_perms.add(perm_tuple)

    add_perm("expert_weight_matching", perm_expert_weight)
    add_perm("gating_weight_matching", perm_gating_weight)

    # Interpolation with clever alignment
    train_loss_interp_clever_list, test_loss_interp_clever_list = [], []
    train_ppl_interp_clever_list, test_ppl_interp_clever_list = [], []

    for label, pi in selected_perms.items():
        # First reorder B's experts, then run weight matching on the result.
        params_b_pi = permute_moe_block(params_b, list(pi), config)

        final_permutation = weight_matching(
            random.PRNGKey(args.seed), permutation_spec,
            flatten_params(params_a), flatten_params(params_b_pi)
        )

        params_b_pi_aligned = unflatten_params(
            apply_permutation(permutation_spec, final_permutation, flatten_params(params_b_pi))
        )

        train_loss_interp, test_loss_interp, train_ppl_interp, test_ppl_interp = _interpolation_curves(
            evaluate, params_a, params_b_pi_aligned, lambdas, train_loader, test_loader,
            desc=f"Permuted Interpolation {label}",
        )

        train_loss_interp_clever_list.append(train_loss_interp)
        test_loss_interp_clever_list.append(test_loss_interp)
        train_ppl_interp_clever_list.append(train_ppl_interp)
        test_ppl_interp_clever_list.append(test_ppl_interp)

    # Save results
    results = {
        "train_loss_interp_naive": train_loss_interp_naive,
        "test_loss_interp_naive": test_loss_interp_naive,
        "train_ppl_interp_naive": train_ppl_interp_naive,
        "test_ppl_interp_naive": test_ppl_interp_naive,
        "train_loss_interp_clever_list": train_loss_interp_clever_list,
        "test_loss_interp_clever_list": test_loss_interp_clever_list,
        "train_ppl_interp_clever_list": train_ppl_interp_clever_list,
        "test_ppl_interp_clever_list": test_ppl_interp_clever_list,
        "selected_perms": {k: list(v) for k, v in selected_perms.items()}
    }
    results = recursive_to_serializable(results)

    # Validation: every curve must have one point per lambda.
    assert len(lambdas) == len(train_loss_interp_naive)
    assert all(len(lambdas) == len(x) for x in train_loss_interp_clever_list)

    # Save directories
    os.makedirs("./plots/lm1b", exist_ok=True)
    os.makedirs("./results/lm1b", exist_ok=True)

    # Save results JSON
    print("Save List of Values...")
    name_a = os.path.basename(os.path.dirname(args.model_a).rstrip("/"))
    name_b = os.path.basename(os.path.dirname(args.model_b).rstrip("/"))
    result_path = f'results/lm1b/[{name_a}+{name_b}].json'
    with open(result_path, 'w') as f:
        json.dump(results, f, indent=4)

    # Plot
    print("Generating plots...")
    perm_labels_selected = list(selected_perms.keys())

    loss_fig = plot_interp_loss(
        lambdas,
        train_loss_interp_naive, test_loss_interp_naive,
        train_loss_interp_clever_list, test_loss_interp_clever_list,
        perm_labels_selected
    )
    loss_fig_path = f"./plots/lm1b/[{name_a}+{name_b}]_weight_matching_interp_loss.png"
    # plt.savefig writes the current pyplot figure.
    plt.savefig(loss_fig_path, dpi=300)
    plt.close(loss_fig)

    ppl_fig = plot_interp_ppl(
        lambdas,
        train_ppl_interp_naive, test_ppl_interp_naive,
        train_ppl_interp_clever_list, test_ppl_interp_clever_list,
        perm_labels_selected
    )
    ppl_fig_path = f"./plots/lm1b/[{name_a}+{name_b}]_weight_matching_interp_ppl.png"
    plt.savefig(ppl_fig_path, dpi=300)
    plt.close(ppl_fig)

if __name__ == "__main__":
    # Run the pipeline only when executed as a script; main() parses the
    # command-line arguments itself.
    main()