import argparse
import os
import time
from datetime import timedelta
from typing import Any, Dict, List
import copy
import jax
import json
import torch
import optax
import itertools
import numpy as np
from jax import jit 
from tqdm import tqdm
from jax import random
import jax.numpy as jnp
from datasets import Dataset
import matplotlib.pyplot as plt
from flax.training.train_state import TrainState
from flax.core.frozen_dict import freeze, unfreeze
from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, GPT2Config
from model import print_model, FlaxGPT2MoELMHeadModel
from flax.core.frozen_dict import freeze, unfreeze
from weight_matching_moe import make_stuff, batch_collate_fn
from flax.training import checkpoints, train_state
from finetune_moe import pretrained2finetune_parmas
import multiprocessing as mp
# Use the "spawn" start method for any multiprocessing child processes;
# force=True overrides a start method that may already have been set.
# NOTE(review): this runs at import time, so merely importing this module
# changes the process-wide start method — presumably chosen to avoid forking
# a process that has already initialized JAX/CUDA; confirm, and consider
# moving it under the __main__ guard.
mp.set_start_method("spawn", force=True)

def to_serializable(val):
    """Convert a JAX/NumPy array or NumPy scalar into plain Python data.

    Arrays (JAX or NumPy, any rank including 0-d) become nested Python lists /
    scalars via ``tolist()``. NumPy scalar instances of any dtype (float16,
    float32, float64, int8 ... uint64, bool_, ...) become the matching Python
    scalar via ``item()``. Anything else is returned unchanged.
    """
    if isinstance(val, (jnp.ndarray, np.ndarray)):
        return val.tolist()
    if isinstance(val, np.generic):
        # Covers every NumPy scalar type rather than only float32/64 and
        # int32/64, so e.g. float16 values (--dtype float16) and np.bool_
        # also come out as plain, JSON-serializable Python scalars.
        return val.item()
    return val

def recursive_to_serializable(obj):
    """Recursively convert a nested container of results into plain Python data.

    Dicts keep their keys and have their values converted. Lists and tuples are
    converted element-wise and returned as lists (JSON has no tuple type, and
    ``json.dumps`` emits a tuple as an array anyway). Every non-container leaf
    is passed to ``to_serializable``.
    """
    if isinstance(obj, dict):
        return {key: recursive_to_serializable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Tuples were previously treated as leaves, leaving any NumPy/JAX
        # values inside them unconverted.
        return [recursive_to_serializable(item) for item in obj]
    return to_serializable(obj)



def _parse_args():
    """Parse command-line options for the per-layer MoE conversion sweep."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--train-dataset-paths", type=str, default="dataset/wikitext.train**", help="train dataset paths (multiple paths)")
    parser.add_argument("--eval-dataset-paths", type=str, default="dataset/wikitext.test**", help="eval dataset paths (multiple paths)")
    parser.add_argument("--batch-size", type=int, default=16, help="train, eval batch size (batch size will be divided by device count)")
    parser.add_argument("--max-sequence-length", type=int, default=256, help="sequence length of model input")
    parser.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="float32", help="model datatype")
    # Output locations; the defaults are the file names this script always used.
    parser.add_argument("--results-path", type=str, default="wikitext103_results.npz", help="output .npz file for the per-layer metrics")
    parser.add_argument("--plot-path", type=str, default="wikitext103_layer_metric.pdf", help="output file for the loss bar plot")
    return parser.parse_args()


def _evaluate(label, model, params, train_loader, test_loader):
    """Evaluate ``params`` on both loaders, print the metrics, and return them.

    ``make_stuff(model=...)["dataset_loss_and_ppl"]`` (from weight_matching_moe)
    is called as ``fn(params, loader) -> (loss, ppl)``. The values are printed
    as floats together with ``label`` and returned converted by
    ``to_serializable``.

    NOTE(review): the "val_*" metrics are computed on ``train_loader`` (the
    --train-dataset-paths data), not on a separate validation split, yet the
    plot labels them "Validation Loss" — confirm this is intended.

    Returns:
        dict with keys "val_loss", "test_loss", "val_ppl", "test_ppl".
    """
    loss_and_ppl = make_stuff(model=model)["dataset_loss_and_ppl"]
    val_loss, val_ppl = loss_and_ppl(params, train_loader)
    test_loss, test_ppl = loss_and_ppl(params, test_loader)
    print({"model": label, "val_loss": float(val_loss), "val_ppl": float(val_ppl), "test_loss": float(test_loss), "test_ppl": float(test_ppl)})
    return {
        "val_loss": to_serializable(val_loss),
        "test_loss": to_serializable(test_loss),
        "val_ppl": to_serializable(val_ppl),
        "test_ppl": to_serializable(test_ppl),
    }


def _plot_layer_losses(val_loss, test_loss, output_path):
    """Save (and show) side-by-side bar plots of per-layer losses.

    Each list holds one loss per converted layer, followed by the pretrained
    model's loss as its last element, which is drawn as a dashed baseline.
    """
    num_layers = len(val_loss) - 1
    x = np.arange(num_layers)
    plt.rcParams.update({"font.family": "serif", "legend.frameon": False, "lines.linewidth": 2})
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, title, losses in zip(axes, ("Validation Loss", "Test Loss"), (val_loss, test_loss)):
        baseline = losses[-1]
        ax.bar(x, losses[:-1], color="steelblue")
        ax.axhline(y=baseline, color="lightsalmon", linestyle="--", label=f"Pretrained: {baseline:.4f}")
        ax.set_title(title, fontsize=14)
        ax.set_xlabel("Layer", fontsize=12)
        ax.set_ylabel("Loss", fontsize=12)
        ax.set_xticks(x)
        ax.tick_params(axis="both", labelsize=11)
        ax.legend(fontsize=10)
    plt.tight_layout()
    plt.savefig(output_path)
    print(f"Saved plot to {output_path}")
    plt.show()
    plt.close(fig)


def main():
    """Measure how converting each GPT-2 layer to a 1-expert MoE layer affects loss.

    For every transformer layer index, this builds a FlaxGPT2MoELMHeadModel.
    Its config marks that single layer via ``moe_layer_indices``, with
    1 routed expert, 0 shared experts and top-1 routing. The restored
    pretrained weights are mapped into it with ``pretrained2finetune_parmas``,
    and the model is evaluated. The unmodified pretrained model is evaluated
    last as the baseline. Metrics are saved to ``--results-path`` (.npz) and
    plotted to ``--plot-path``.

    NOTE(review): presumably a 1-expert / top-1 MoE layer initialized from the
    dense weights is expected to reproduce the pretrained loss — confirm
    against FlaxGPT2MoELMHeadModel.
    """
    args = _parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # drop_last=True: presumably so every batch matches the fixed input_shape
    # the models are built with; the final partial batch is not evaluated.
    loader_kwargs = dict(batch_size=args.batch_size, drop_last=True, collate_fn=batch_collate_fn)
    train_loader = torch.utils.data.DataLoader(Dataset.from_parquet(args.train_dataset_paths), **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(Dataset.from_parquet(args.eval_dataset_paths), **loader_kwargs)

    dtype = jnp.dtype(args.dtype)
    input_shape = (args.batch_size, args.max_sequence_length)
    base_config = GPT2Config.from_pretrained("gpt2")
    pretrained_model = FlaxGPT2LMHeadModel(base_config, input_shape=input_shape, seed=args.seed, dtype=dtype)
    pretrained_model.params = checkpoints.restore_checkpoint(
        ckpt_dir=args.model_path, target={"params": pretrained_model.params}
    )["params"]

    results = {"val_loss": [], "test_loss": [], "val_ppl": [], "test_ppl": []}

    def record(metrics):
        for key, value in metrics.items():
            results[key].append(value)

    # Derive the layer count from the config rather than hard-coding 12.
    for moe_idx in range(base_config.n_layer):
        # Always copy the untouched base config so the pretrained model's
        # config is never mutated and no state leaks between iterations.
        moe_config = copy.deepcopy(base_config)
        moe_config.num_routed_experts = 1
        moe_config.num_shared_experts = 0
        moe_config.topk = 1
        moe_config.moe_layer_indices = moe_idx
        finetune_model = FlaxGPT2MoELMHeadModel(moe_config, input_shape=input_shape, seed=args.seed, dtype=dtype)
        finetune_model.params = pretrained2finetune_parmas(pretrained_model.params, finetune_model.params, moe_config)
        record(_evaluate(f"moe_layer_{moe_idx}", finetune_model, finetune_model.params, train_loader, test_loader))

    # Baseline: the unmodified pretrained model is always the last entry.
    record(_evaluate("pretrained", pretrained_model, pretrained_model.params, train_loader, test_loader))

    np.savez(args.results_path, **results)
    _plot_layer_losses(results["val_loss"], results["test_loss"], args.plot_path)


if __name__ == "__main__":
    main()