import os
import jax 
import optax
import copy
import pickle 
import argparse
import itertools
import numpy as np
from tqdm import tqdm
import json 
from jax import random
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax.tree_util import tree_map
from jax.lax import stop_gradient
from dataclasses import dataclass, asdict
from jax import jit, random, value_and_grad
from flax.training.train_state import TrainState
from flax.core import freeze, unfreeze
from flax.serialization import from_bytes
from datasets import Dataset
import torch
from flax.training.common_utils import get_metrics, onehot, shard
from transformers.models.bert.modeling_flax_bert import FlaxBertForSequenceClassification, BertConfig
from flax.training import checkpoints, train_state

# Module-level tokenizer/sequence constants.
# NOTE(review): none of these names is referenced anywhere else in this file —
# the model is built from BertConfig and the sequence length comes from the
# --max-sequence-length CLI option (default 256, not max_seq_len). Confirm they
# are still needed before relying on them.
vocab_size = 15000
max_seq_len = 100
PAD = 0  # presumably the padding token id — TODO confirm against the tokenizer
UNK = 1  # presumably the unknown-token id — TODO confirm against the tokenizer

def lerp(lam, t1, t2):
    """Linearly interpolate two pytrees leaf by leaf.

    Every output leaf is ``(1 - lam) * a + lam * b``, where ``a`` and ``b`` are
    the corresponding leaves of ``t1`` and ``t2``. ``lam = 0`` puts all weight on
    ``t1`` and ``lam = 1`` puts all weight on ``t2``. The two trees must have the
    same structure (required by ``tree_map``).
    """
    weight_a = 1 - lam

    def _blend(leaf_a, leaf_b):
        return weight_a * leaf_a + lam * leaf_b

    return tree_map(_blend, t1, t2)

def to_serializable(val):
    """Convert a single NumPy/JAX value into a JSON-friendly Python object.

    * NumPy or JAX arrays (any rank) become (nested) Python lists via ``tolist()``.
    * NumPy scalar numbers of any width -- float16/32/64, signed and unsigned
      integers, and ``bool_`` -- become the matching built-in ``float``/``int``/
      ``bool`` via ``item()``.
    * Anything else is returned unchanged.
    """
    if isinstance(val, (jnp.ndarray, np.ndarray)):
        return val.tolist()
    # Abstract NumPy scalar bases cover every width (the JAX scalar types such as
    # jnp.float32 describe these same NumPy scalar classes). Listing only the
    # 32/64-bit types let e.g. np.float16 or np.bool_ through, which json.dump
    # cannot encode.
    if isinstance(val, (np.floating, np.integer, np.bool_)):
        return val.item()
    return val

def recursive_to_serializable(obj):
    """Recursively convert a nested results structure into JSON-friendly values.

    * Dicts are rebuilt with converted values; keys are kept as-is.
    * Lists are rebuilt as lists and tuples as tuples, element by element.
    * Every other leaf is passed to ``to_serializable``.
    """
    if isinstance(obj, dict):
        return {k: recursive_to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [recursive_to_serializable(v) for v in obj]
    if isinstance(obj, tuple):
        # Recurse into tuples too, so NumPy/JAX leaves inside them get converted;
        # json.dump still writes the tuple as a JSON array.
        return tuple(recursive_to_serializable(v) for v in obj)
    return to_serializable(obj)
def batch_collate_fn(batch):
    """Collate a list of example dicts into a dict of batched columns.

    Used as the ``collate_fn`` of the torch DataLoaders in ``main``. Keys are
    taken from the first example, and each column is stacked with ``jnp.array``.
    Columns that cannot become a JAX array are kept as plain Python lists:
    non-numeric values such as strings raise ``TypeError``, and ragged
    sequences of unequal length raise ``ValueError``.

    Args:
        batch: list of mappings that all contain the keys of ``batch[0]``.

    Returns:
        dict mapping each key to a ``jnp`` array or, if conversion failed, a list.
    """
    columns = {key: [example[key] for example in batch] for key in batch[0]}
    collated = {}
    for key, values in columns.items():
        try:
            collated[key] = jnp.array(values)
        except (TypeError, ValueError):
            collated[key] = values
    return collated

def split_batch(batch):
    """Separate the ``"label"`` entry of a batch from the model inputs.

    Returns:
        ``(inputs, labels)``. ``inputs`` is a new dict holding every entry of
        ``batch`` except ``"label"``, in the original key order. ``labels`` is
        ``batch["label"]``. ``batch`` itself is not modified; a missing
        ``"label"`` raises ``KeyError``.
    """
    targets = batch["label"]
    model_inputs = dict(batch)
    del model_inputs["label"]
    return model_inputs, targets

def make_stuff(model, num_classes, seed=0):
    """Build jitted evaluation/training helpers for a HF Flax sequence classifier.

    Args:
        model: a ``transformers`` Flax model (e.g. ``FlaxBertForSequenceClassification``).
            Its ``__call__`` takes the batch's input columns as keyword arguments
            plus ``params``, ``dropout_rng`` and ``train``.
        num_classes: expected number of classes. Currently unused; the class count
            is taken from the logits' last dimension.
        seed: base seed for the dropout keys that ``step`` derives when the caller
            passes no ``dropout_rng``.

    Returns:
        A dict with three entries:

        * ``"batch_eval"(params, batch)``: jitted, no dropout. Returns
          ``{"loss", "accuracy"}`` averaged over the batch.
        * ``"step"(train_state, batch, dropout_rng=None)``: jitted, dropout
          enabled. Performs one gradient update and returns
          ``(new_train_state, {"loss", "accuracy"})``.
        * ``"dataset_loss_and_accuracy"(params, dataloader)``: returns
          example-weighted mean loss and accuracy over the loader as Python floats.
    """
    apply_fn = model.__call__  # BERT-style: inputs as kwargs, params passed explicitly
    base_dropout_key = random.PRNGKey(seed)

    def _loss_and_accuracy(logits, labels):
        # Mean softmax cross-entropy against one-hot labels, plus top-1 accuracy.
        y_onehot = onehot(labels, logits.shape[-1])
        loss = jnp.mean(optax.softmax_cross_entropy(logits, y_onehot))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
        return loss, accuracy

    @jit
    def batch_eval(params, batch):
        inputs, labels = split_batch(batch)
        # [0] selects the logits from the model output (logits, ...).
        logits = apply_fn(**inputs, params=params, train=False)[0]
        loss, accuracy = _loss_and_accuracy(logits, labels)
        return {"loss": loss, "accuracy": accuracy}

    @jit
    def step(train_state: TrainState, batch, dropout_rng=None):
        # With train=True dropout is active, and Flax raises an error if no
        # "dropout" PRNG key is supplied. When the caller gives no key, derive a
        # fresh one per update from the step counter.
        if dropout_rng is None:
            dropout_rng = random.fold_in(base_dropout_key, train_state.step)
        inputs, labels = split_batch(batch)

        def loss_fn(params):
            logits = apply_fn(**inputs, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss, accuracy = _loss_and_accuracy(logits, labels)
            return loss, {"accuracy": accuracy}

        (loss, metrics), grads = value_and_grad(loss_fn, has_aux=True)(train_state.params)
        train_state = train_state.apply_gradients(grads=grads)
        return train_state, {"loss": loss, **metrics}

    def dataset_loss_and_accuracy(params, dataloader):
        # Per-batch means are re-weighted by batch size, so a short final batch
        # (e.g. with drop_last=False) does not skew the dataset average.
        total_loss = 0.0
        total_accuracy = 0.0
        total_count = 0
        with tqdm(dataloader, desc="Evaluating") as pbar:
            for batch in pbar:
                batch = {k: jax.device_put(v) for k, v in batch.items()}
                metrics = batch_eval(params, batch)
                batch_size = batch["input_ids"].shape[0]
                # Pull each metric to the host once; reuse it for totals and display.
                batch_loss = float(metrics["loss"])
                batch_accuracy = float(metrics["accuracy"])
                total_loss += batch_loss * batch_size
                total_accuracy += batch_accuracy * batch_size
                total_count += batch_size
                pbar.set_postfix({"loss": f"{batch_loss:.4f}", "acc": f"{batch_accuracy:.4f}"})
        if total_count == 0:
            raise ValueError("dataset_loss_and_accuracy: dataloader yielded no examples")
        return total_loss / total_count, total_accuracy / total_count

    return {
        "batch_eval": batch_eval,
        "step": step,
        "dataset_loss_and_accuracy": dataset_loss_and_accuracy,
    }
def _parse_args():
    """Parse the command-line options of the interpolation experiment."""
    parser = argparse.ArgumentParser(description="Weight matching for Transformer-MoE models on dbpedia")
    parser.add_argument("--model-a", type=str, required=True, help="Path to first Transformer-MoE model checkpoint")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second Transformer-MoE model checkpoint")
    parser.add_argument("--train-dataset-paths", type=str, required=True)
    parser.add_argument("--eval-dataset-paths", type=str, required=True)
    parser.add_argument("--max-sequence-length", type=int, default=256)
    parser.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="float32")
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility")
    return parser.parse_args()


def _restore_params(ckpt_dir, template_params):
    """Restore the ``"params"`` tree of a Flax checkpoint, failing loudly if absent.

    ``checkpoints.restore_checkpoint`` returns the passed-in target unchanged
    when it finds no checkpoint. Without this check the script would silently
    evaluate the freshly initialised template weights.
    """
    target = {"params": template_params}
    restored = checkpoints.restore_checkpoint(ckpt_dir=ckpt_dir, target=target)
    if restored is target:
        raise FileNotFoundError(f"No Flax checkpoint found at {ckpt_dir!r}")
    return restored["params"]


def main():
    """Evaluate two BERT checkpoints and their naive linear interpolations.

    Steps:

    1. Load the train/eval parquet datasets.
    2. Restore ``--model-a`` and ``--model-b`` into a
       ``FlaxBertForSequenceClassification`` template.
    3. Print their train/test loss and accuracy.
    4. Evaluate 25 evenly spaced interpolations between them.
    5. Write the curves to ``results/dbpedia/[<name_a>+<name_b>].json``.
    """
    args = _parse_args()
    # Number of classifier outputs, used for the model config and make_stuff.
    num_labels = 219

    train_dataset = Dataset.from_parquet(args.train_dataset_paths)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, drop_last=True, collate_fn=batch_collate_fn
    )
    eval_dataset = Dataset.from_parquet(args.eval_dataset_paths)
    test_loader = torch.utils.data.DataLoader(
        eval_dataset, batch_size=args.batch_size, drop_last=False, collate_fn=batch_collate_fn
    )

    model_config = BertConfig.from_pretrained("bert-base-uncased", num_labels=num_labels)
    # The freshly initialised params only serve as the restore template.
    model = FlaxBertForSequenceClassification(
        model_config,
        input_shape=(args.batch_size, args.max_sequence_length),
        seed=args.seed,
        dtype=jnp.dtype(args.dtype),
    )
    params_a = _restore_params(args.model_a, model.params)
    params_b = _restore_params(args.model_b, model.params)

    stuff = make_stuff(model=model, num_classes=num_labels)
    evaluate = stuff["dataset_loss_and_accuracy"]
    train_loss_a, train_accuracy_a = evaluate(params_a, train_loader)
    train_loss_b, train_accuracy_b = evaluate(params_b, train_loader)
    test_loss_a, test_accuracy_a = evaluate(params_a, test_loader)
    test_loss_b, test_accuracy_b = evaluate(params_b, test_loader)
    print({
        "train_loss_a": float(train_loss_a),
        "train_accuracy_a": float(train_accuracy_a),
        "train_loss_b": float(train_loss_b),
        "train_accuracy_b": float(train_accuracy_b),
        "test_loss_a": float(test_loss_a),
        "test_accuracy_a": float(test_accuracy_a),
        "test_loss_b": float(test_loss_b),
        "test_accuracy_b": float(test_accuracy_b),
    })

    # Mean of the two endpoint training losses (stored as baseline_train_loss).
    baseline_train_loss = 0.5 * (train_loss_a + train_loss_b)
    print(baseline_train_loss)

    # Interpolation coefficients for lerp: lam=0 weights model A fully, lam=1 model B.
    lambdas = jnp.linspace(0, 1, num=25)
    train_loss_interp_naive, test_loss_interp_naive = [], []
    train_acc_interp_naive, test_acc_interp_naive = [], []
    for lam in tqdm(lambdas, desc="Naive Interpolation"):
        # Interpolate the raw parameter trees directly (no permutation alignment).
        naive_p = freeze(lerp(lam, unfreeze(params_a), unfreeze(params_b)))
        train_loss, train_acc = evaluate(naive_p, train_loader)
        test_loss, test_acc = evaluate(naive_p, test_loader)
        train_loss_interp_naive.append(train_loss)
        test_loss_interp_naive.append(test_loss)
        train_acc_interp_naive.append(train_acc)
        test_acc_interp_naive.append(test_acc)

    # This script performs no permutation matching, so the "clever" curves stay
    # empty; their keys are still written to the results file.
    train_loss_interp_clever_list, test_loss_interp_clever_list = [], []
    train_acc_interp_clever_list, test_acc_interp_clever_list = [], []
    results = {
        "train_loss_interp_naive": train_loss_interp_naive,
        "test_loss_interp_naive": test_loss_interp_naive,
        "train_acc_interp_naive": train_acc_interp_naive,
        "test_acc_interp_naive": test_acc_interp_naive,
        "train_loss_interp_clever_list": train_loss_interp_clever_list,
        "test_loss_interp_clever_list": test_loss_interp_clever_list,
        "train_acc_interp_clever_list": train_acc_interp_clever_list,
        "test_acc_interp_clever_list": test_acc_interp_clever_list,
        "baseline_train_loss": baseline_train_loss,
        # x-axis of the curves above, so the file can be plotted on its own.
        "lambdas": lambdas,
    }
    results = recursive_to_serializable(results)

    # Internal sanity checks: one evaluation per interpolation coefficient.
    assert len(lambdas) == len(train_loss_interp_naive) == len(test_loss_interp_naive)
    assert all(len(lambdas) == len(tl) for tl in train_loss_interp_clever_list)
    assert all(len(lambdas) == len(tl) for tl in test_loss_interp_clever_list)

    os.makedirs("./results/dbpedia/", exist_ok=True)
    name_a = os.path.basename(args.model_a.rstrip("/"))
    name_b = os.path.basename(args.model_b.rstrip("/"))
    print("Save List of Values...")
    with open(f'results/dbpedia/[{name_a}+{name_b}].json', 'w', encoding="utf-8") as f:
        json.dump(results, f, indent=4)
    # TODO: plotting is not implemented (matplotlib is imported but unused).

if __name__ == "__main__":
    # Run the experiment only when executed as a script; main() parses the CLI itself.
    main()