import argparse
import os
import time
from datetime import timedelta
from typing import Any, Dict, List
import copy
import jax
import json
import jax.numpy as jnp
import optax
import torch
import wandb
import itertools 
import numpy as np
from tqdm import tqdm
from jax import random
from datasets import Dataset
import matplotlib.pyplot as plt
from flax.training import checkpoints, train_state
from flax.training.common_utils import get_metrics, onehot, shard
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, GPT2Config
from model import print_model, FlaxGPT2MoELMHeadModel
from flax.core.frozen_dict import freeze, unfreeze
from weight_matching import apply_permutation, weight_matching, permute_moe_block
from weight_matching import flax_gpt2_permutation_spec_moe, plot_interp_loss, plot_interp_pbc
from utils import ec2_get_instance_type, flatten_params, lerp, unflatten_params
import multiprocessing as mp
import os, pprint, pathlib
from datamodule import  JAXTextDataset
from torch.utils.data import  DataLoader
# Use the "spawn" start method for any child processes; force=True replaces a
# start method that may already have been set. NOTE(review): presumably this is
# to avoid fork()-ing a process that has already initialised JAX; the
# DataLoaders in main() don't pass num_workers, so confirm it is still needed.
mp.set_start_method("spawn", force=True)
def to_serializable(val):
    """Convert a NumPy/JAX array or scalar into a JSON-serializable Python value.

    * Arrays (``np.ndarray`` / ``jax.Array``, including 0-d ones) become
      (nested) Python lists or scalars via ``tolist()``.
    * NumPy scalar objects of *any* dtype (float16/32/64, int8..int64,
      uint*, bool_, ...) become the equivalent Python scalar via ``item()``.
      The previous version only recognised float32/64 and int32/64, so e.g.
      ``np.bool_`` or ``np.uint8`` leaked through and broke ``json.dump``.
    * Anything else is returned unchanged.
    """
    if isinstance(val, (jnp.ndarray, np.ndarray)):
        return val.tolist()
    if isinstance(val, np.generic):
        return val.item()
    return val

def recursive_to_serializable(obj):
    """Recursively convert a nested dict/list/tuple structure into JSON-ready values.

    Dicts keep their keys and have each value converted. Lists and tuples
    become lists with each element converted (``json`` emits tuples as arrays
    anyway, but without this branch arrays nested *inside* a tuple were passed
    through unconverted). Every other object is delegated to
    ``to_serializable``.
    """
    if isinstance(obj, dict):
        return {k: recursive_to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [recursive_to_serializable(v) for v in obj]
    return to_serializable(obj)
    
# log2(e): multiplying a natural-log (nats) cross-entropy by this converts it to
# bits; make_stuff() derives every "bpc" value as `loss * LOG2E`.
LOG2E = jnp.log2(jnp.e) 
def load_token_file(path, dtype=np.uint8):
    """Read whitespace-separated integer token ids from a text file.

    Args:
        path: text file whose content is integer ids separated by any
            whitespace (spaces and/or newlines).
        dtype: integer NumPy dtype of the returned array. The default
            ``np.uint8`` can only hold ids in [0, 255].

    Returns:
        1-D ``np.ndarray`` of ``dtype`` holding the ids in file order
        (empty if the file has no tokens).

    Raises:
        ValueError: if a token is not an integer literal, or if any id falls
            outside the range of ``dtype``. Older NumPy would otherwise
            silently wrap such ids; newer NumPy raises a bare OverflowError.
    """
    with open(path, "r", encoding="utf-8") as f:
        ids = np.array([int(t) for t in f.read().split()], dtype=np.int64)
    bounds = np.iinfo(dtype)
    if ids.size and (ids.min() < bounds.min or ids.max() > bounds.max):
        raise ValueError(
            f"{path}: token ids span [{ids.min()}, {ids.max()}], outside the range "
            f"of {np.dtype(dtype).name} [{bounds.min}, {bounds.max}]"
        )
    return ids.astype(dtype)
# Collate function for torch.utils.data.DataLoader: turns the list of samples of
# one batch into a dict of jax.numpy arrays.
def jax_collate(batch):
    """Stack a list of per-sample dicts into a dict of batched ``jnp`` arrays.

    Keys (and their order) come from the first sample; for each key the
    per-sample values are stacked along a new leading batch axis with
    ``jnp.stack``. An empty ``batch`` raises ``IndexError``.
    """
    first = batch[0]
    collated = {}
    for name in first:
        column = [sample[name] for sample in batch]
        collated[name] = jnp.stack(column)
    return collated

def make_stuff(model):
    """Build the jitted evaluation / training helpers for ``model``.

    Args:
        model: Flax model whose ``__call__`` accepts the batch entries (minus
            ``"labels"``) plus ``params=``, ``train=`` and, for training,
            ``dropout_rng=`` as keyword arguments; element ``[0]`` of its
            output is used as the logits.

    Returns:
        dict with:
          ``"batch_eval"``: jitted ``(params, batch) -> (loss, labels, bpc)``.
          ``"step"``: jitted ``(state, batch, dropout_rng) ->
              (new_state, metrics, new_dropout_rng)``.
          ``"dataset_loss_and_pbc"``: ``(params, dataloader) ->
              (mean_loss, mean_bpc)``.

    Positions whose label is -100 are excluded from every loss. BPC is the
    nats loss times ``LOG2E``.
    NOTE(review): logits and labels are compared position-by-position (no
    shift here), so this assumes the dataset already supplies next-token
    labels; confirm in JAXTextDataset.
    """
    apply_fn = model.__call__
    ignore_index = -100

    def masked_mean_xent(logits, labels):
        """Cross-entropy averaged over positions whose label != ignore_index."""
        per_token = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
        mask = (labels != ignore_index).astype(per_token.dtype)
        # Dividing by the number of *kept* tokens (not the batch size) keeps this
        # consistent with the ntok weighting in dataset_loss_and_pbc; max(.., 1)
        # makes an all-ignored batch yield 0 instead of 0/0 = NaN.
        return jnp.sum(per_token * mask) / jnp.maximum(jnp.sum(mask), 1.0)

    @jax.jit
    def batch_eval(params, batch):
        labels = batch["labels"]
        # Build the model inputs without mutating the caller's mapping.
        inputs = {k: v for k, v in batch.items() if k != "labels"}
        logits = apply_fn(**inputs, params=params, train=False)[0]
        loss = masked_mean_xent(logits, labels)
        return loss, labels, loss * LOG2E

    @jax.jit
    def step(state, batch, dropout_rng):
        # NOTE(review): the lax.pmean calls require a bound "batch" axis, so this
        # is presumably meant to be wrapped as jax.pmap(step, axis_name="batch");
        # calling it directly under plain jit fails with an unbound-axis error.
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng, 2)

        def loss_fn(params):
            logits = apply_fn(params=params, dropout_rng=dropout_rng, train=True,
                              **{k: v for k, v in batch.items() if k != "labels"})[0]
            return masked_mean_xent(logits, batch["labels"]), logits

        (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
        grads = jax.lax.pmean(grads, axis_name="batch")
        new_state = state.apply_gradients(grads=grads)
        metrics = {"loss": loss, "bpc": loss * LOG2E, "logits": logits, "labels": batch["labels"]}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return new_state, metrics, new_dropout_rng

    def dataset_loss_and_pbc(params, dataloader):
        """Token-weighted mean loss and bits-per-character over ``dataloader``.

        Each batch's masked mean is weighted by its number of non-ignored
        tokens, so a short final batch contributes proportionally.
        Assumes evaluation is on a single device.

        Returns:
            (mean_loss, mean_bpc) as Python floats.

        Raises:
            ValueError: if the dataloader yields no non-ignored tokens.
        """
        total_loss, total_bpc, total_tok = 0.0, 0.0, 0
        pbar = tqdm(dataloader, desc="Evaluating", leave=False)
        for batch in pbar:
            # dict(...) hands batch_eval a shallow copy of the batch mapping.
            loss, labels, bpc = batch_eval(params, dict(batch))
            ntok = jnp.sum(labels != ignore_index)
            total_loss += loss * ntok
            total_bpc += bpc * ntok
            total_tok += ntok
            pbar.set_postfix(bpc=f"{bpc:.4f}")
        if int(total_tok) == 0:
            raise ValueError("dataset_loss_and_pbc: dataloader produced no labelled tokens")
        mean_loss = (total_loss / total_tok).item()
        mean_bpc = (total_bpc / total_tok).item()
        return mean_loss, mean_bpc

    return {"batch_eval": batch_eval, "step": step, "dataset_loss_and_pbc": dataset_loss_and_pbc}


def _make_eval_loader(token_path, seq_len, batch_size):
    """Load ``token_path`` and wrap it in an unshuffled DataLoader of jnp batches.

    Evaluation averages over every token, so sample order only affects
    floating-point rounding and shuffling is skipped.
    """
    tokens = load_token_file(token_path)
    return DataLoader(JAXTextDataset(tokens, seq_len), batch_size=batch_size, collate_fn=jax_collate)


def _checkpoint_parent(path):
    """Directory containing checkpoint ``path``; trailing slashes are ignored."""
    return os.path.dirname(os.path.normpath(path))


def _interpolation_curves(eval_fn, params_a, params_b, lambdas, train_loader, test_loader, desc):
    """Evaluate ``eval_fn`` at ``lerp(lam, params_a, params_b)`` for every ``lam``.

    Returns:
        (train_loss, test_loss, train_bpc, test_bpc): four lists with one
        value per entry of ``lambdas``.
    """
    train_loss, test_loss, train_bpc, test_bpc = [], [], [], []
    for lam in tqdm(lambdas, desc=desc):
        point = freeze(lerp(lam, unfreeze(params_a), unfreeze(params_b)))
        tr_loss, tr_bpc = eval_fn(point, train_loader)
        te_loss, te_bpc = eval_fn(point, test_loader)
        train_loss.append(tr_loss)
        test_loss.append(te_loss)
        train_bpc.append(tr_bpc)
        test_bpc.append(te_bpc)
    return train_loss, test_loss, train_bpc, test_bpc


def _save_figure(fig, path):
    """Write ``fig`` to ``path`` at 300 dpi, then close it."""
    # NOTE(review): assumes plot_interp_* return a matplotlib Figure (the
    # original plt.close(fig) calls suggest so). Saving through the figure
    # itself avoids depending on pyplot's implicit "current figure".
    fig.savefig(path, dpi=300)
    plt.close(fig)


def main():
    """Weight-match two GPT2-MoE checkpoints and plot loss/BPC along interpolation paths.

    Steps:
      1. Evaluate both endpoints on the "train" split and on test.txt.
      2. Naively interpolate A -> B.
      3. For every permutation of the routed experts, permute B's MoE block,
         weight-match it to A, and interpolate A -> aligned B.
    Results go to results/enwik8/[A+B].json; plots go to plots/enwik8/.
    """
    parser = argparse.ArgumentParser(description="Weight matching for GPT2-MoE models on enwik8")
    parser.add_argument("--model-a", type=str, required=True, help="Path to first GPT2-MoE model checkpoint")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second GPT2-MoE model checkpoint")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--data-path", type=str, default="data/enwik8",
                        help="directory containing train.txt, valid.txt and test.txt")
    parser.add_argument("--batch-size", type=int, default=48, help="evaluation batch size")
    parser.add_argument("--max-sequence-length", type=int, default=256, help="sequence length of model input")
    parser.add_argument("--train-eval-split", choices=("train", "valid"), default="train",
                        help="split evaluated for the curves reported as 'train' "
                             "(valid.txt is much smaller, hence much faster)")
    args = parser.parse_args()

    # Previously the validation loader was assigned to `train_loader`, silently
    # discarding the real train loader, so all "train" numbers came from
    # valid.txt. The split is now explicit and defaults to what the labels say.
    train_file = "train.txt" if args.train_eval_split == "train" else "valid.txt"
    train_loader = _make_eval_loader(os.path.join(args.data_path, train_file),
                                     args.max_sequence_length, args.batch_size)
    test_loader = _make_eval_loader(os.path.join(args.data_path, "test.txt"),
                                    args.max_sequence_length, args.batch_size)

    config = GPT2Config.from_json_file(os.path.join(_checkpoint_parent(args.model_a), "config.json"))
    model = FlaxGPT2MoELMHeadModel(config)
    params_a = checkpoints.restore_checkpoint(ckpt_dir=args.model_a, target={"params": model.params})["params"]
    params_b = checkpoints.restore_checkpoint(ckpt_dir=args.model_b, target={"params": model.params})["params"]
    stuff = make_stuff(model=model)
    evaluate = stuff["dataset_loss_and_pbc"]

    train_loss_a, train_pbc_a = evaluate(params_a, train_loader)
    train_loss_b, train_pbc_b = evaluate(params_b, train_loader)
    test_loss_a, test_pbc_a = evaluate(params_a, test_loader)
    test_loss_b, test_pbc_b = evaluate(params_b, test_loader)
    print({
        "train_loss_a": float(train_loss_a), "train_pbc_a": float(train_pbc_a),
        "train_loss_b": float(train_loss_b), "train_pbc_b": float(train_pbc_b),
        "test_loss_a": float(test_loss_a), "test_pbc_a": float(test_pbc_a),
        "test_loss_b": float(test_loss_b), "test_pbc_b": float(test_pbc_b),
    })
    # (A previous run recorded 2.7338707785748593 here.)
    baseline_train_loss = 0.5 * (train_loss_a + train_loss_b)
    print(baseline_train_loss)

    permutation_spec = flax_gpt2_permutation_spec_moe(config=config)
    # Every ordering of the routed experts; the count (and the runtime) grows
    # factorially with config.num_routed_experts.
    expert_perms = [list(p) for p in itertools.permutations(range(config.num_routed_experts))]
    perm_labels = [str(p) for p in expert_perms]  # e.g., "[0, 1]", "[1, 0]"
    lambdas = jnp.linspace(0, 1, num=25)

    # Naive interpolation (no permutation).
    (train_loss_interp_naive, test_loss_interp_naive,
     train_pbc_interp_naive, test_pbc_interp_naive) = _interpolation_curves(
        evaluate, params_a, params_b, lambdas, train_loader, test_loader, "Naive Interpolation")

    train_loss_interp_clever_list, test_loss_interp_clever_list = [], []
    train_pbc_interp_clever_list, test_pbc_interp_clever_list = [], []
    for pi, label in zip(expert_perms, perm_labels):
        model_b_pi = permute_moe_block(params_b, pi, config)
        final_permutation = weight_matching(
            random.PRNGKey(args.seed), permutation_spec, flatten_params(params_a), flatten_params(model_b_pi)
        )
        model_b_pi_aligned = unflatten_params(
            apply_permutation(permutation_spec, final_permutation, flatten_params(model_b_pi))
        )
        tr_loss, te_loss, tr_bpc, te_bpc = _interpolation_curves(
            evaluate, params_a, model_b_pi_aligned, lambdas, train_loader, test_loader,
            f"Permuted Interpolation {label}")
        train_loss_interp_clever_list.append(tr_loss)
        test_loss_interp_clever_list.append(te_loss)
        train_pbc_interp_clever_list.append(tr_bpc)
        test_pbc_interp_clever_list.append(te_bpc)

    # Every curve must have one value per interpolation point.
    n_points = len(lambdas)
    all_curves = [train_loss_interp_naive, test_loss_interp_naive,
                  train_pbc_interp_naive, test_pbc_interp_naive,
                  *train_loss_interp_clever_list, *test_loss_interp_clever_list,
                  *train_pbc_interp_clever_list, *test_pbc_interp_clever_list]
    assert all(len(curve) == n_points for curve in all_curves)

    results = recursive_to_serializable({
        "lambdas": lambdas,
        "train_eval_split": args.train_eval_split,
        "train_loss_interp_naive": train_loss_interp_naive,
        "test_loss_interp_naive": test_loss_interp_naive,
        "train_pbc_interp_naive": train_pbc_interp_naive,
        "test_pbc_interp_naive": test_pbc_interp_naive,
        "train_loss_interp_clever_list": train_loss_interp_clever_list,
        "test_loss_interp_clever_list": test_loss_interp_clever_list,
        "train_pbc_interp_clever_list": train_pbc_interp_clever_list,
        "test_pbc_interp_clever_list": test_pbc_interp_clever_list,
        "baseline_train_loss": baseline_train_loss,
    })

    os.makedirs("./plots/enwik8", exist_ok=True)
    os.makedirs("./results/enwik8", exist_ok=True)
    print("Save List of Values...")
    name_a = os.path.basename(_checkpoint_parent(args.model_a))
    name_b = os.path.basename(_checkpoint_parent(args.model_b))
    with open(f'results/enwik8/[{name_a}+{name_b}].json', 'w', encoding="utf-8") as f:
        json.dump(results, f, indent=4)

    print("Generating plots...")
    loss_fig = plot_interp_loss(
        lambdas,
        train_loss_interp_naive, test_loss_interp_naive,
        train_loss_interp_clever_list, test_loss_interp_clever_list,
        perm_labels
    )
    _save_figure(loss_fig, f"./plots/enwik8/[{name_a}+{name_b}]_weight_matching_interp_loss.png")
    pbc_fig = plot_interp_pbc(
        lambdas,
        train_pbc_interp_naive, test_pbc_interp_naive,
        train_pbc_interp_clever_list, test_pbc_interp_clever_list,
        perm_labels
    )
    _save_figure(pbc_fig, f"./plots/enwik8/[{name_a}+{name_b}]_weight_matching_interp_pbc.png")

if __name__ == "__main__":
    # Script entry point; main() parses the command-line arguments itself.
    main()
