import argparse
import os
import time
from datetime import timedelta
from typing import Any, Dict, List
import copy
import jax
import json
import jax.numpy as jnp
import optax
import torch
import wandb
import itertools 
import numpy as np
from tqdm import tqdm
from jax import random
from datasets import Dataset
import matplotlib.pyplot as plt
from flax.training import checkpoints, train_state
from flax.training.common_utils import get_metrics, onehot, shard
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, GPT2Config
from model import print_model, FlaxGPT2MoELMHeadModel
from flax.core.frozen_dict import freeze, unfreeze
from weight_matching import apply_permutation, weight_matching, permute_moe_block
from weight_matching import flax_gpt2_permutation_spec_moe, plot_interp_loss, plot_interp_ppl
from utils import ec2_get_instance_type, flatten_params, lerp, unflatten_params
from data_utils import get_lm_corpus
import multiprocessing as mp
import os, pprint, pathlib

# Use the "spawn" start method for any multiprocessing workers; force=True
# overrides a start method that may already have been set elsewhere.
# NOTE(review): presumably chosen because forking after JAX/torch have started
# threads is unsafe — TODO confirm.
mp.set_start_method("spawn", force=True)
def to_serializable(val):
    """
    Convert a single JAX/NumPy value into a JSON-serializable Python object.

    - JAX or NumPy arrays become (nested) Python lists via ``tolist()``.
    - NumPy scalars of any dtype (float16/32/64, int8..int64, bool_, ...)
      become the matching Python scalar via ``item()``.
    - Anything else is returned unchanged.
    """
    if isinstance(val, (jnp.ndarray, np.ndarray)):
        return val.tolist()
    if isinstance(val, np.generic):
        # np.generic is the base of every NumPy scalar type, so this covers
        # float32/float64/int32/int64 as well as bool_, float16, int8, etc.
        return val.item()
    return val

def recursive_to_serializable(obj):
    """
    Recursively convert a nested structure into JSON-serializable values.

    Dicts are rebuilt with the same keys, lists stay lists, tuples stay
    tuples, and every leaf is passed through ``to_serializable``.
    """
    if isinstance(obj, dict):
        return {k: recursive_to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [recursive_to_serializable(v) for v in obj]
    if isinstance(obj, tuple):
        # Recurse into tuples too so NumPy/JAX leaves inside them get converted.
        return tuple(recursive_to_serializable(v) for v in obj)
    return to_serializable(obj)

def _to_numpy(x) -> np.ndarray:
    """Return ``x`` as a host NumPy array, moving torch tensors to CPU first."""
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().numpy()
    return np.asarray(x)


def prepare_lm_batch(data: torch.Tensor, target: torch.Tensor) -> Dict[str, Any]:
    """
    Convert a language-modeling batch from PyTorch to JAX.

    The batch is NOT sharded across devices; it is returned as a plain
    single-device batch.

    Args:
        data (torch.Tensor): Input token ids of shape (seq_len, batch).
        target (torch.Tensor): Target token ids of shape (seq_len, batch).
    Returns:
        Dict[str, jnp.ndarray]: ``{'input_ids': ..., 'target': ...}``, both
            transposed to shape (batch, seq_len).
    """
    # Inputs are (seq_len, batch); transpose to batch-major (batch, seq_len).
    # Converting through NumPy explicitly also works for CUDA tensors.
    input_ids = jnp.asarray(_to_numpy(data).T)
    target = jnp.asarray(_to_numpy(target).T)
    return {'input_ids': input_ids, 'target': target}


def make_stuff(model):
    """
    Build jitted training/evaluation helpers bound to ``model``.

    Args:
        model: Flax LM head model. Its ``__call__`` is invoked as
            ``model(input_ids=..., params=..., train=..., [dropout_rng=...])``
            and element ``[0]`` of the result is used as the logits.
    Returns:
        Dict with:
          - ``"batch_eval"(params, batch) -> (logits, labels, loss)``
          - ``"step"(state, batch, dropout_rng) -> (new_state, metrics, new_dropout_rng)``
          - ``"dataset_loss_and_ppl"(params, dataloader) -> (mean_loss, ppl)``
        ``batch`` is a dict as returned by ``prepare_lm_batch`` (keys
        ``"input_ids"`` and ``"target"``). Target positions equal to -100 are
        excluded from every loss.
    """
    apply_fn = model.__call__

    def _split_batch(batch):
        # Separate labels from model inputs without mutating ``batch``.
        inputs = {k: v for k, v in batch.items() if k != "target"}
        return inputs, batch["target"]

    def _masked_ce(logits, labels):
        # Mean token-level cross-entropy over positions whose label != -100.
        # Ignored positions are dropped from both numerator and denominator so
        # the result is a true per-token mean; max(..., 1) guards an all-ignored batch.
        mask = labels != -100
        per_token = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
        return jnp.sum(per_token * mask) / jnp.maximum(jnp.sum(mask), 1)

    @jax.jit
    def batch_eval(params, batch):
        """Forward pass in eval mode; returns (logits, labels, masked mean CE)."""
        inputs, labels = _split_batch(batch)
        logits = apply_fn(**inputs, params=params, train=False)[0]
        loss = _masked_ce(logits, labels)
        return logits, labels, loss

    @jax.jit
    def step(state: train_state.TrainState, batch, dropout_rng):
        """One optimizer step; returns (new_state, metrics, next dropout rng)."""
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng, 2)
        inputs, labels = _split_batch(batch)

        def loss_fn(params):
            logits = apply_fn(**inputs, params=params, dropout_rng=dropout_rng, train=True)[0]
            return _masked_ce(logits, labels), logits

        (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
        new_state = state.apply_gradients(grads=grads)
        metrics = {"batch_loss": loss, "logits": logits}
        return new_state, metrics, new_dropout_rng

    def dataset_loss_and_ppl(params, dataloader):
        """
        Iterate once over ``dataloader``, accumulate token-weighted CE, and
        return ``(mean_loss, perplexity)`` as Python floats.

        ``dataloader`` must yield ``(data, target, _)`` triples accepted by
        ``prepare_lm_batch``. Runs on a single device (batches are not sharded).

        Raises:
            ValueError: if the loader yields no non-ignored target tokens.
        """
        total_loss, total_tok = 0.0, 0
        pbar = tqdm(dataloader, desc="Evaluating", leave=False)
        for eval_data, eval_target, _ in pbar:
            eval_batch = prepare_lm_batch(eval_data, eval_target)
            _, labels, loss = batch_eval(params, eval_batch)
            # Accumulate on the host in float64 to avoid float32 drift over
            # many batches; ``loss`` is already a mean over these ``ntok`` tokens.
            loss = float(loss)
            ntok = int(jnp.sum(labels != -100))
            total_loss += loss * ntok
            total_tok += ntok
            pbar.set_postfix(loss=f"{loss:.4f}", ppl=f"{np.exp(loss):.2f}")
        if total_tok == 0:
            raise ValueError("dataloader yielded no non-ignored target tokens")
        mean_loss = total_loss / total_tok
        ppl = float(np.exp(mean_loss))
        return mean_loss, ppl

    return {"batch_eval": batch_eval, "step": step, "dataset_loss_and_ppl": dataset_loss_and_ppl}

def _evaluate_params(stuff, params, train_loader, test_loader):
    """
    Return ``(train_loss, train_ppl, test_loss, test_ppl)`` for ``params``.

    When ``test_loader`` is None the test split is not evaluated and the
    "test" values are copies of the "train" values.
    """
    evaluate = stuff["dataset_loss_and_ppl"]
    train_loss, train_ppl = evaluate(params, train_loader)
    if test_loader is None:
        return train_loss, train_ppl, train_loss, train_ppl
    test_loss, test_ppl = evaluate(params, test_loader)
    return train_loss, train_ppl, test_loss, test_ppl


def _interpolation_curve(stuff, params_a, params_b, lambdas, train_loader, test_loader, desc):
    """
    Evaluate ``lerp(lam, params_a, params_b)`` for every ``lam`` in ``lambdas``.

    Returns:
        Four lists ``(train_losses, test_losses, train_ppls, test_ppls)`` with
        one entry per lambda.
    """
    train_losses, test_losses, train_ppls, test_ppls = [], [], [], []
    for lam in tqdm(lambdas, desc=desc):
        interp_params = freeze(lerp(lam, unfreeze(params_a), unfreeze(params_b)))
        train_loss, train_ppl, test_loss, test_ppl = _evaluate_params(
            stuff, interp_params, train_loader, test_loader
        )
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        train_ppls.append(train_ppl)
        test_ppls.append(test_ppl)
    return train_losses, test_losses, train_ppls, test_ppls


def main():
    """
    Compare naive and weight-matched linear interpolation between two GPT2-MoE checkpoints.

    For each permutation of the routed experts, model B's MoE block is
    permuted, aligned to model A by weight matching, and loss/perplexity are
    evaluated along the interpolation path. Results are written as JSON to
    ``results/lm1b/`` and as PNG plots to ``plots/lm1b/``.
    """
    parser = argparse.ArgumentParser(description="Weight matching for GPT2-MoE models on lm1b")
    parser.add_argument("--model-a", type=str, required=True, help="Path to first GPT2-MoE model checkpoint")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second GPT2-MoE model checkpoint")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--data-path", type=str, default="./data/lm1b", help="train datset paths (multiple paths)")
    parser.add_argument('--dataset', type=str, default='lm1b', choices=['wt103', 'lm1b', 'enwik8', 'text8'], help='dataset name')
    parser.add_argument('--max_step', type=int, default=80000, help='upper epoch limit')
    parser.add_argument("--batch-size", type=int, default=24, help="train, eval batch size (batch size will be devided by device count)")
    parser.add_argument('--tgt_len', type=int, default=256, help='number of tokens to predict')
    parser.add_argument('--eval_tgt_len', type=int, default=256, help='number of tokens to predict for evaluation')
    parser.add_argument('--ext_len', type=int, default=0, help='length of the extended context')
    parser.add_argument('--mem_len', type=int, default=0, help='length of the retained previous heads')
    parser.add_argument("--eval-test", action="store_true",
                        help="also evaluate on the test split (default: 'test' metrics reuse the 'train' metrics)")
    args = parser.parse_args()

    corpus = get_lm_corpus(args.data_path, args.dataset)
    args.n_token = len(corpus.vocab)
    eval_batch_size = 10
    va_iter = corpus.get_iterator('valid', eval_batch_size, args.eval_tgt_len, ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test', eval_batch_size, args.eval_tgt_len, ext_len=args.ext_len)
    # NOTE: every metric labelled "train" below is computed on the *validation*
    # split; the actual training split is never evaluated by this script.
    train_loader = va_iter
    # Test-split evaluation is opt-in. With test_loader=None each "test" metric
    # is a copy of the matching "train" metric, keeping the output schema intact.
    test_loader = te_iter if args.eval_test else None

    # Strip a trailing "/" BEFORE dirname; stripping afterwards would make
    # "run/ckpt/" resolve to "run/ckpt" instead of its parent "run".
    model_dir_a = os.path.dirname(args.model_a.rstrip("/"))
    model_dir_b = os.path.dirname(args.model_b.rstrip("/"))
    config = GPT2Config.from_json_file(os.path.join(model_dir_a, 'config.json'))
    model = FlaxGPT2MoELMHeadModel(config)
    params_a = checkpoints.restore_checkpoint(ckpt_dir=args.model_a, target={"params": model.params})["params"]
    params_b = checkpoints.restore_checkpoint(ckpt_dir=args.model_b, target={"params": model.params})["params"]
    stuff = make_stuff(model=model)

    # Endpoint metrics for the two unmodified models.
    train_loss_a, train_ppl_a, test_loss_a, test_ppl_a = _evaluate_params(stuff, params_a, train_loader, test_loader)
    train_loss_b, train_ppl_b, test_loss_b, test_ppl_b = _evaluate_params(stuff, params_b, train_loader, test_loader)
    print({
        "train_loss_a": float(train_loss_a), "train_ppl_a": float(train_ppl_a),
        "train_loss_b": float(train_loss_b), "train_ppl_b": float(train_ppl_b),
        "test_loss_a": float(test_loss_a), "test_ppl_a": float(test_ppl_a),
        "test_loss_b": float(test_loss_b), "test_ppl_b": float(test_ppl_b),
    })
    baseline_train_loss = 0.5 * (train_loss_a + train_loss_b)
    print(baseline_train_loss)

    permutation_spec = flax_gpt2_permutation_spec_moe(config=config)
    # Every ordering of the routed experts (num_routed_experts! of them), so
    # this loop grows factorially with the expert count.
    expert_perms = [list(p) for p in itertools.permutations(range(config.num_routed_experts))]
    perm_labels = [str(p) for p in expert_perms]  # e.g. "[0, 1]", "[1, 0]"
    lambdas = jnp.linspace(0, 1, num=25)

    # Naive interpolation (no permutation).
    (train_loss_interp_naive, test_loss_interp_naive,
     train_ppl_interp_naive, test_ppl_interp_naive) = _interpolation_curve(
        stuff, params_a, params_b, lambdas, train_loader, test_loader, "Naive Interpolation"
    )

    # Weight-matched interpolation, once per initial expert permutation of model B.
    train_loss_interp_clever_list, test_loss_interp_clever_list = [], []
    train_ppl_interp_clever_list, test_ppl_interp_clever_list = [], []
    for pi, label in zip(expert_perms, perm_labels):
        model_b_pi = permute_moe_block(params_b, pi, config)
        final_permutation = weight_matching(
            random.PRNGKey(args.seed), permutation_spec, flatten_params(params_a), flatten_params(model_b_pi)
        )
        model_b_pi_aligned = unflatten_params(
            apply_permutation(permutation_spec, final_permutation, flatten_params(model_b_pi))
        )
        train_loss_interp, test_loss_interp, train_ppl_interp, test_ppl_interp = _interpolation_curve(
            stuff, params_a, model_b_pi_aligned, lambdas, train_loader, test_loader,
            f"Permuted Interpolation {label}",
        )
        train_loss_interp_clever_list.append(train_loss_interp)
        test_loss_interp_clever_list.append(test_loss_interp)
        train_ppl_interp_clever_list.append(train_ppl_interp)
        test_ppl_interp_clever_list.append(test_ppl_interp)

    results = {
        "train_loss_interp_naive": train_loss_interp_naive,
        "test_loss_interp_naive": test_loss_interp_naive,
        "train_ppl_interp_naive": train_ppl_interp_naive,
        "test_ppl_interp_naive": test_ppl_interp_naive,
        "train_loss_interp_clever_list": train_loss_interp_clever_list,
        "test_loss_interp_clever_list": test_loss_interp_clever_list,
        "train_ppl_interp_clever_list": train_ppl_interp_clever_list,
        "test_ppl_interp_clever_list": test_ppl_interp_clever_list,
        "baseline_train_loss": baseline_train_loss
    }
    results = recursive_to_serializable(results)

    # Internal consistency checks: one value per lambda on every curve.
    assert len(lambdas) == len(train_loss_interp_naive) == len(test_loss_interp_naive)
    assert all(len(lambdas) == len(tl) for tl in train_loss_interp_clever_list)
    assert all(len(lambdas) == len(tl) for tl in test_loss_interp_clever_list)

    os.makedirs("./plots/lm1b", exist_ok=True)
    os.makedirs("./results/lm1b", exist_ok=True)
    print("Save List of Values...")
    name_a = os.path.basename(model_dir_a)
    name_b = os.path.basename(model_dir_b)
    with open(f'results/lm1b/[{name_a}+{name_b}].json', 'w') as f:
        json.dump(results, f, indent=4)

    print("Generating plots...")
    loss_fig = plot_interp_loss(
        lambdas,
        train_loss_interp_naive, test_loss_interp_naive,
        train_loss_interp_clever_list, test_loss_interp_clever_list,
        perm_labels
    )
    loss_fig_path = f"./plots/lm1b/[{name_a}+{name_b}]_weight_matching_interp_loss.png"
    plt.savefig(loss_fig_path, dpi=300)
    plt.close(loss_fig)
    ppl_fig = plot_interp_ppl(
        lambdas,
        train_ppl_interp_naive, test_ppl_interp_naive,
        train_ppl_interp_clever_list, test_ppl_interp_clever_list,
        perm_labels
    )
    ppl_fig_path = f"./plots/lm1b/[{name_a}+{name_b}]_weight_matching_interp_ppl.png"
    plt.savefig(ppl_fig_path, dpi=300)
    plt.close(ppl_fig)
if __name__ == "__main__":
    # Script entry point; command-line arguments are parsed inside main().
    main()
