import argparse
import os
import time
from datetime import timedelta
from typing import Any, Dict, List
import copy
import jax
import json
import jax.numpy as jnp
import optax
import torch
import wandb
import itertools 
import numpy as np
from tqdm import tqdm
from jax import random
from datasets import Dataset
import matplotlib.pyplot as plt
from flax.training import checkpoints, train_state
from flax.training.common_utils import get_metrics, onehot, shard
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, GPT2Config
from model import print_model, FlaxGPT2MoELMHeadModel
from flax.core.frozen_dict import freeze, unfreeze
from weight_matching import apply_permutation, weight_matching, permute_moe_block
from weight_matching import flax_gpt2_permutation_spec_moe, plot_interp_loss, plot_interp_ppl
from utils import ec2_get_instance_type, flatten_params, lerp, unflatten_params
import multiprocessing as mp
import os, pprint, pathlib

# Child processes (if any are started, e.g. DataLoader workers) use "spawn"
# instead of "fork"; presumably this avoids forking after JAX has initialised.
# force=True overrides a start method that may already have been chosen.
mp.set_start_method(method="spawn", force=True)
def to_serializable(val):
    """Convert a NumPy/JAX array or NumPy scalar into a JSON-serialisable value.

    - NumPy and JAX arrays (including 0-d arrays) are converted with
      ``tolist()``, yielding nested Python lists or a Python scalar.
    - Any NumPy scalar (``np.generic``: float16/32/64, int8..64, uint*,
      ``np.bool_``, ...) is converted to the matching Python scalar with
      ``item()``.
    - Every other value is returned unchanged.
    """
    if isinstance(val, (jnp.ndarray, np.ndarray)):
        return val.tolist()
    # np.generic covers every NumPy scalar type, not just float32/64 and
    # int32/64, so e.g. np.bool_ or np.float16 no longer break json.dump.
    if isinstance(val, np.generic):
        return val.item()
    return val

# Recursively convert the results
def recursive_to_serializable(obj):
    """Recursively apply ``to_serializable`` to every leaf of a nested container.

    - Dicts keep their keys and have their values converted.
    - Lists and tuples are converted element-wise and returned as lists
      (JSON has no tuple type, and leaves inside tuples also need converting).
    - Any other value is passed to ``to_serializable``.
    """
    if isinstance(obj, dict):
        return {k: recursive_to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [recursive_to_serializable(v) for v in obj]
    return to_serializable(obj)
    
def batch_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collate a list of per-example dicts into a dict of stacked JAX arrays.

    The fields are taken from the first example's keys, in that order.
    Each field is converted with ``np.asarray`` per example and stacked along
    a new leading batch axis. A stacked array of shape ``(B, L, 1)`` is
    squeezed to ``(B, L)`` so GPT-2 receives 2-D token arrays. Every other
    shape is left alone. The result is then converted with ``jnp.asarray``.
    """
    def _stack_field(name):
        # Per-example values may be lists or torch tensors; both go through np.asarray.
        stacked_field = np.stack([np.asarray(example[name]) for example in batch], axis=0)
        squeeze_trailing = stacked_field.ndim == 3 and stacked_field.shape[-1] == 1
        return np.squeeze(stacked_field, -1) if squeeze_trailing else stacked_field

    return {name: jnp.asarray(_stack_field(name)) for name in batch[0].keys()}
def make_stuff(model):
    """Build jitted eval/train-step functions and a dataset evaluator for ``model``.

    Args:
        model: a Flax model whose ``__call__`` accepts the batch fields (other
            than ``labels``) as keyword arguments plus ``params=``, ``train=``
            and, when training, ``dropout_rng=``. It must return a tuple whose
            first element is the logits.

    Returns:
        dict with keys ``"batch_eval"``, ``"step"`` and ``"dataset_loss_and_ppl"``.

    NOTE(review): logits and labels are compared position-by-position (there
    is no next-token shift in this file). Presumably the dataset's ``labels``
    are already shifted relative to ``input_ids`` — confirm.
    """
    apply_fn = model.__call__

    @jax.jit
    def batch_eval(params, batch):
        """Forward pass with ``train=False``; return ``(logits, labels, mean_ce)``.

        ``mean_ce`` is averaged over *all* label positions, including any
        -100 positions.
        """
        labels = batch.pop("labels")  # (B, L)
        pred_logits = apply_fn(**batch, params=params, train=False)[0]  # (B, L, V)
        loss = optax.softmax_cross_entropy(pred_logits, onehot(labels, pred_logits.shape[-1])).mean()
        return pred_logits, labels, loss

    @jax.jit
    def _batch_token_stats(params, batch):
        """Return ``(sum of CE over positions with label != -100, count of such positions)``."""
        labels = batch.pop("labels")
        pred_logits = apply_fn(**batch, params=params, train=False)[0]
        token_ce = optax.softmax_cross_entropy(pred_logits, onehot(labels, pred_logits.shape[-1]))
        valid = labels != -100
        return jnp.sum(jnp.where(valid, token_ce, 0.0)), jnp.sum(valid)

    @jax.jit
    def step(state: train_state.TrainState, batch, dropout_rng):
        """One gradient step on ``batch``; return ``(new_state, metrics, next_dropout_rng)``.

        ``dropout_rng`` is split: one half drives dropout in this step and the
        other half is returned for the next step.
        """
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng, 2)

        def loss_fn(params):
            labels = batch["labels"]
            pred_logits = apply_fn(params=params, dropout_rng=dropout_rng, train=True, **{k: v for k, v in batch.items() if k != "labels"})[0]
            loss = optax.softmax_cross_entropy(pred_logits, onehot(labels, pred_logits.shape[-1])).mean()
            return loss, pred_logits

        (loss, pred_logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
        new_state = state.apply_gradients(grads=grads)
        metrics = {"batch_loss": loss, "pred_logits": pred_logits, "labels": batch["labels"]}
        return new_state, metrics, new_dropout_rng

    def dataset_loss_and_ppl(params, dataloader):
        """Iterate once over ``dataloader`` and return ``(mean_loss, perplexity)``.

        ``mean_loss`` is token-level: the summed cross-entropy over every
        position whose label is not -100, divided by the number of such
        positions. Totals are accumulated as Python float/int, so precision
        does not degrade over large datasets. Works on **one device**.

        Raises:
            ValueError: if the dataloader yields no positions with label != -100.
        """
        total_loss, total_tok = 0.0, 0
        pbar = tqdm(dataloader, desc="Evaluating", leave=False)
        for batch in pbar:
            loss_sum, ntok = _batch_token_stats(params, dict(batch))  # copy → labels pop safe
            loss_sum, ntok = float(loss_sum), int(ntok)
            total_loss += loss_sum
            total_tok += ntok
            if ntok:
                batch_loss = loss_sum / ntok
                pbar.set_postfix(loss=f"{batch_loss:.4f}", ppl=f"{np.exp(batch_loss):.2f}")
        if total_tok == 0:
            raise ValueError("dataset_loss_and_ppl: dataloader produced no labelled tokens")
        mean_loss = total_loss / total_tok
        ppl = float(np.exp(mean_loss))
        return mean_loss, ppl

    return {"batch_eval": batch_eval, "step": step, "dataset_loss_and_ppl": dataset_loss_and_ppl}


def _parse_args():
    """Define the command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="Weight matching for GPT2-MoE models on Wikitext103")
    parser.add_argument("--model-a", type=str, required=True, help="Path to first GPT2-MoE model checkpoint")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second GPT2-MoE model checkpoint")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--train-dataset-paths", type=str, default="dataset/wikitext.train**", help="train datset paths (multiple paths)")
    parser.add_argument("--eval-dataset-paths", type=str, default="dataset/wikitext.test**", help="eval dataset paths (multiple paths)")
    parser.add_argument("--batch-size", type=int, default=16, help="train and eval batch size (evaluation runs on a single device; the batch is not split across devices)")
    # NOTE(review): parsed but not read anywhere in this script.
    parser.add_argument("--max-sequence-length", type=int, default=256, help="sequence lenght of model input")
    return parser.parse_args()


def _make_loader(dataset_paths, batch_size):
    """Load the parquet file(s) at ``dataset_paths`` and wrap them in a DataLoader.

    Dataset order is kept (no shuffling). The final partial batch is dropped,
    so every batch has exactly ``batch_size`` rows and the jitted evaluation
    sees a single input shape.
    """
    dataset = Dataset.from_parquet(dataset_paths)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, drop_last=True, collate_fn=batch_collate_fn)


def _interpolation_curve(eval_fn, params_a, params_b, lambdas, train_loader, test_loader, desc):
    """Evaluate ``lerp(lam, params_a, params_b)`` for every ``lam`` in ``lambdas``.

    Returns:
        ``(train_losses, test_losses, train_ppls, test_ppls)``, each a list
        with one entry per lambda, in the order of ``lambdas``.
    """
    train_losses, test_losses, train_ppls, test_ppls = [], [], [], []
    for lam in tqdm(lambdas, desc=desc):
        mixed = freeze(lerp(lam, unfreeze(params_a), unfreeze(params_b)))
        train_loss, train_ppl = eval_fn(mixed, train_loader)
        test_loss, test_ppl = eval_fn(mixed, test_loader)
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        train_ppls.append(train_ppl)
        test_ppls.append(test_ppl)
    return train_losses, test_losses, train_ppls, test_ppls


def _save_figure(fig, path):
    """Write ``fig`` to ``path`` at 300 dpi, then close it.

    Saving through the figure object (not ``plt.savefig``) writes the figure
    that was actually returned, even if a different figure is current.
    Falls back to the current pyplot figure when ``fig`` is None.
    """
    target = plt.gcf() if fig is None else fig
    target.savefig(path, dpi=300)
    plt.close(target)


def main():
    """Compare naive and weight-matched interpolation between two GPT2-MoE checkpoints.

    Steps:
    1. Evaluate the naive curve ``lerp(lam, A, B)`` once, for 25 evenly
       spaced ``lam`` in [0, 1].
    2. For every permutation ``pi`` of the routed experts (n! of them):
       - permute model B's MoE block by ``pi``;
       - align the result to model A with weight matching;
       - evaluate the same lambda sweep against the aligned copy.

    Train/test loss and perplexity curves are written as JSON to
    ``results/wikitext103/`` and plotted to ``plots/wikitext103/``.
    """
    args = _parse_args()
    train_loader = _make_loader(args.train_dataset_paths, args.batch_size)
    test_loader = _make_loader(args.eval_dataset_paths, args.batch_size)

    # config.json is read from the directory that contains checkpoint A.
    config = GPT2Config.from_json_file(os.path.join(os.path.dirname(args.model_a), 'config.json'))
    model = FlaxGPT2MoELMHeadModel(config)
    # The freshly initialised model.params only supply the target tree structure.
    params_a = checkpoints.restore_checkpoint(ckpt_dir=args.model_a, target={"params": model.params})["params"]
    params_b = checkpoints.restore_checkpoint(ckpt_dir=args.model_b, target={"params": model.params})["params"]
    stuff = make_stuff(model=model)
    eval_fn = stuff["dataset_loss_and_ppl"]

    train_loss_a, train_ppl_a = eval_fn(params_a, train_loader)
    train_loss_b, train_ppl_b = eval_fn(params_b, train_loader)
    test_loss_a, test_ppl_a = eval_fn(params_a, test_loader)
    test_loss_b, test_ppl_b = eval_fn(params_b, test_loader)
    print({
        "train_loss_a": float(train_loss_a), "train_ppl_a": float(train_ppl_a),
        "train_loss_b": float(train_loss_b), "train_ppl_b": float(train_ppl_b),
        "test_loss_a": float(test_loss_a), "test_ppl_a": float(test_ppl_a),
        "test_loss_b": float(test_loss_b), "test_ppl_b": float(test_ppl_b),
    })
    baseline_train_loss = 0.5 * (train_loss_a + train_loss_b)
    print(baseline_train_loss)

    permutation_spec = flax_gpt2_permutation_spec_moe(config=config)
    # Every ordering of the routed experts. The labels (e.g. "[0, 1]") name the curves in the plots.
    expert_perms = [list(p) for p in itertools.permutations(range(config.num_routed_experts))]
    perm_labels = [str(p) for p in expert_perms]
    lambdas = jnp.linspace(0, 1, num=25)

    # Naive interpolation: no permutation or alignment of model B.
    (train_loss_interp_naive, test_loss_interp_naive,
     train_ppl_interp_naive, test_ppl_interp_naive) = _interpolation_curve(
        eval_fn, params_a, params_b, lambdas, train_loader, test_loader, "Naive Interpolation")

    train_loss_interp_clever_list, test_loss_interp_clever_list = [], []
    train_ppl_interp_clever_list, test_ppl_interp_clever_list = [], []
    for pi, label in zip(expert_perms, perm_labels):
        model_b_pi = permute_moe_block(params_b, pi, config)
        final_permutation = weight_matching(
            random.PRNGKey(args.seed), permutation_spec, flatten_params(params_a), flatten_params(model_b_pi)
        )
        model_b_pi_aligned = unflatten_params(
            apply_permutation(permutation_spec, final_permutation, flatten_params(model_b_pi))
        )
        train_loss_interp, test_loss_interp, train_ppl_interp, test_ppl_interp = _interpolation_curve(
            eval_fn, params_a, model_b_pi_aligned, lambdas, train_loader, test_loader,
            f"Permuted Interpolation {label}")
        train_loss_interp_clever_list.append(train_loss_interp)
        test_loss_interp_clever_list.append(test_loss_interp)
        train_ppl_interp_clever_list.append(train_ppl_interp)
        test_ppl_interp_clever_list.append(test_ppl_interp)

    # Sanity-check that every curve has one point per lambda.
    assert len(lambdas) == len(train_loss_interp_naive) == len(test_loss_interp_naive)
    assert all(len(lambdas) == len(tl) for tl in train_loss_interp_clever_list)
    assert all(len(lambdas) == len(tl) for tl in test_loss_interp_clever_list)

    results = recursive_to_serializable({
        "train_loss_interp_naive": train_loss_interp_naive,
        "test_loss_interp_naive": test_loss_interp_naive,
        "train_ppl_interp_naive": train_ppl_interp_naive,
        "test_ppl_interp_naive": test_ppl_interp_naive,
        "train_loss_interp_clever_list": train_loss_interp_clever_list,
        "test_loss_interp_clever_list": test_loss_interp_clever_list,
        "train_ppl_interp_clever_list": train_ppl_interp_clever_list,
        "test_ppl_interp_clever_list": test_ppl_interp_clever_list,
        "baseline_train_loss": baseline_train_loss,
        # x-axis values and curve labels, so the JSON is self-describing.
        "lambdas": lambdas,
        "perm_labels": perm_labels,
    })

    os.makedirs("./plots/wikitext103", exist_ok=True)
    os.makedirs("./results/wikitext103", exist_ok=True)
    name_a = os.path.basename(os.path.dirname(args.model_a).rstrip("/"))
    name_b = os.path.basename(os.path.dirname(args.model_b).rstrip("/"))

    print("Save List of Values...")
    with open(f'results/wikitext103/[{name_a}+{name_b}].json', 'w') as f:
        json.dump(results, f, indent=4)

    print("Generating plots...")
    loss_fig = plot_interp_loss(
        lambdas,
        train_loss_interp_naive, test_loss_interp_naive,
        train_loss_interp_clever_list, test_loss_interp_clever_list,
        perm_labels
    )
    _save_figure(loss_fig, f"./plots/wikitext103/[{name_a}+{name_b}]_weight_matching_interp_loss.png")
    ppl_fig = plot_interp_ppl(
        lambdas,
        train_ppl_interp_naive, test_ppl_interp_naive,
        train_ppl_interp_clever_list, test_ppl_interp_clever_list,
        perm_labels
    )
    _save_figure(ppl_fig, f"./plots/wikitext103/[{name_a}+{name_b}]_weight_matching_interp_ppl.png")


if __name__ == "__main__":
    # Script entry point; argument parsing happens inside main().
    main()
