import argparse
import os
import jax
import copy
import json
import torch
import optax
import numpy as np 
from tqdm import tqdm
import jax.numpy as jnp
from jax import random, vmap
from typing import Any, Dict, List
from flax.core import freeze, unfreeze
from flax.serialization import from_bytes
from utils import flatten_params, lerp, unflatten_params
from flax.traverse_util import flatten_dict, unflatten_dict
from lmc_model import LMCFlaxGPT2LMHeadModel, print_model
from transformers.models.gpt2.modeling_flax_gpt2 import GPT2Config, FlaxGPT2LMHeadModel
from flax.training import checkpoints, train_state
from flax.training.common_utils import get_metrics, onehot, shard
from matching_utils import weight_matching
from data_utils import get_lm_corpus
import matplotlib.pyplot as plt
import numpy as np

def load_flax_params(checkpoint_dir, target):
    """
    Deserialize ``flax_model.msgpack`` from a checkpoint folder.

    Args:
        checkpoint_dir (str): directory that contains ``flax_model.msgpack``.
        target (PyTree): template whose structure the stored parameters are
            restored into (e.g. the freshly initialised model parameters).

    Returns:
        PyTree: the parameters read from disk, shaped like ``target``.
    """
    checkpoint_file = os.path.join(checkpoint_dir, "flax_model.msgpack")
    with open(checkpoint_file, mode="rb") as handle:
        return from_bytes(target, handle.read())
def prepare_lm_batch(data: torch.Tensor, target: torch.Tensor) -> Dict[str, Any]:
    """
    Convert a language-modeling batch from PyTorch to JAX arrays.

    No device sharding is performed; the arrays are placed on JAX's default
    device.

    Args:
        data (torch.Tensor): input token ids, laid out as (seq_len, batch).
        target (torch.Tensor): target token ids, laid out as (seq_len, batch).

    Returns:
        Dict[str, jnp.ndarray]: ``{'input_ids': ..., 'target': ...}``, both
            transposed to (batch, seq_len).
    """
    # Go through NumPy explicitly so CUDA tensors or tensors that track
    # gradients convert cleanly as well.
    input_ids = jnp.array(data.T.detach().cpu().numpy())
    labels = jnp.array(target.T.detach().cpu().numpy())
    return {'input_ids': input_ids, 'target': labels}
def make_stuff(model):
    """Build jitted eval/train steps and dataset-level helpers for ``model``.

    Args:
        model: model object whose ``__call__`` accepts ``input_ids`` plus the
            keyword arguments ``params``, ``train``, ``dropout_rng`` and
            ``attention_input``, and returns a sequence whose first element
            is the logits.

    Returns:
        dict with the callables ``"batch_eval"``, ``"step"``,
        ``"dataset_loss_and_ppl"`` and ``"get_attention_inputs"``.
    """
    apply_fn = model.__call__

    def _split_batch(batch):
        """Return (model inputs, labels) from ``batch`` without mutating it."""
        inputs = {k: v for k, v in batch.items() if k != "target"}
        return inputs, batch["target"]

    def _masked_ce(logits, labels):
        """Return (mean cross-entropy over tokens whose label != -100, that token count)."""
        # Compute in float32: with a bfloat16/float16 model the logits may be
        # low precision, and the softmax / reduction should not be.
        logits = logits.astype(jnp.float32)
        per_token = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
        # -100 marks ignored positions: leave them out of both the sum and the
        # token count so the mean matches the count used for weighting.
        mask = labels != -100
        ntok = jnp.sum(mask)
        loss = jnp.sum(jnp.where(mask, per_token, 0.0)) / jnp.maximum(ntok, 1)
        return loss, ntok

    @jax.jit
    def batch_eval(params, batch, attention_input=None):
        """Run ``batch`` in eval mode; return (logits, labels, masked mean CE loss)."""
        inputs, labels = _split_batch(batch)
        logits = apply_fn(**inputs, params=params, train=False, attention_input=attention_input)[0]
        loss, _ = _masked_ce(logits, labels)
        return logits, labels, loss

    @jax.jit
    def step(state: train_state.TrainState, batch, dropout_rng):
        """One gradient step; return (new_state, metrics, next dropout rng)."""
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng, 2)
        inputs, labels = _split_batch(batch)

        def loss_fn(params):
            logits = apply_fn(**inputs, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss, _ = _masked_ce(logits, labels)
            return loss, logits

        (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
        new_state = state.apply_gradients(grads=grads)
        metrics = {"batch_loss": loss, "logits": logits}
        return new_state, metrics, new_dropout_rng

    def dataset_loss_and_ppl(params, dataloader):
        """
        Iterate once over `dataloader`, accumulate token-level CE, and
        return (mean_loss, perplexity) as Python floats.  Works on **one device**.

        Raises:
            ValueError: if the loader yields no non-ignored tokens.
        """
        total_loss, total_tok = 0.0, 0
        pbar = tqdm(dataloader, desc="Evaluating", leave=False)
        for eval_data, eval_target, _ in pbar:
            eval_batch = prepare_lm_batch(eval_data, eval_target)
            _, labels, loss = batch_eval(params, eval_batch)
            # Accumulate in Python floats/ints (float64 / exact) rather than in
            # the model's possibly-bfloat16 dtype.
            loss = float(loss)
            ntok = int(jnp.sum(labels != -100))
            total_loss += loss * ntok
            total_tok += ntok
            pbar.set_postfix(loss=f"{loss:.4f}", ppl=f"{np.exp(loss):.2f}")
        if total_tok == 0:
            raise ValueError("dataset_loss_and_ppl: dataloader produced no scored tokens")
        mean_loss = total_loss / total_tok
        ppl = float(np.exp(mean_loss))
        return mean_loss, ppl

    def get_attention_inputs(params, dataloader):
        """Run one eval forward pass on the first batch of ``dataloader`` and
        return the dict passed to the model as ``attention_input``.

        The dict starts empty; whatever the model writes into it is returned
        (presumably per-layer attention inputs — confirm in lmc_model). It stays
        empty if ``dataloader`` yields nothing.
        """
        activation = {}
        for data, target, _ in dataloader:
            inputs, _ = _split_batch(prepare_lm_batch(data, target))
            # Called only for its side effect on `activation`; not jitted, so
            # in-place writes by the model are visible here.
            apply_fn(**inputs, params=params, train=False, attention_input=activation)
            break
        return activation

    return {
        "batch_eval": batch_eval,
        "step": step,
        "dataset_loss_and_ppl": dataset_loss_and_ppl,
        "get_attention_inputs": get_attention_inputs,
    }
def compute_interpolation(params_a, params_b_target, lambdas, stuff, val_ds, test_ds, desc="Interpolation"):
    """Evaluate validation and test loss/perplexity along a linear parameter path.

    For every ``lam`` in ``lambdas`` the two parameter trees are combined with
    ``lerp(lam, params_a, params_b_target)`` and the result is scored on both
    ``val_ds`` and ``test_ds``.

    Args:
        params_a: parameter pytree of the first model.
        params_b_target: parameter pytree of the second model (e.g. a
            weight-matched version of it).
        lambdas: iterable of interpolation coefficients.
        stuff: dict as returned by ``make_stuff``; only its
            ``"dataset_loss_and_ppl"`` entry is used.
        val_ds: validation loader, iterated once per lambda.
        test_ds: test loader, iterated once per lambda.
        desc: progress-bar label.

    Returns:
        dict with keys ``"Val Loss"``, ``"Test Loss"``, ``"Val PPL"`` and
        ``"Test PPL"``, each a list with one value per lambda, rounded to 4
        decimals.
    """
    evaluate = stuff["dataset_loss_and_ppl"]
    val_loss, test_loss, val_ppl, test_ppl = [], [], [], []
    for lam in tqdm(lambdas, desc=desc):
        p_interp = freeze(lerp(lam, unfreeze(params_a), unfreeze(params_b_target)))
        v_loss, v_ppl = evaluate(p_interp, val_ds)
        t_loss, t_ppl = evaluate(p_interp, test_ds)
        # Keep the test metrics separate: they must not be replaced by the
        # validation ones, or the "Test" curves just duplicate "Val".
        val_loss.append(v_loss)
        test_loss.append(t_loss)
        val_ppl.append(v_ppl)
        test_ppl.append(t_ppl)

    def _round4(values):
        return [float(f"{x:.4f}") for x in values]

    return {
        "Val Loss": _round4(val_loss),
        "Test Loss": _round4(test_loss),
        "Val PPL": _round4(val_ppl),
        "Test PPL": _round4(test_ppl),
    }
def _build_arg_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-a", type=str, required=True, help="Path to first fine-tuned GPT2 model checkpoint")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second fine-tuned GPT2 model checkpoint")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--data-path", type=str, default="./data/lm1b", help="path to the dataset directory")
    parser.add_argument('--dataset', type=str, default='lm1b', choices=['wt103', 'lm1b', 'enwik8', 'text8'], help='dataset name')
    parser.add_argument("--batch-size", type=int, default=24, help="train batch size (kept for CLI compatibility; not used by this evaluation script)")
    parser.add_argument("--eval-batch-size", type=int, default=12, help="batch size of the validation and test iterators")
    parser.add_argument('--tgt_len', type=int, default=256, help='number of tokens to predict')
    parser.add_argument('--eval_tgt_len', type=int, default=256, help='number of tokens to predict for evaluation')
    parser.add_argument('--ext_len', type=int, default=0, help='length of the extended context')
    parser.add_argument('--mem_len', type=int, default=0, help='length of the retained previous heads')
    parser.add_argument("--num-lambdas", type=int, default=25, help="number of evenly spaced interpolation points in [0, 1]")
    parser.add_argument("--save-path", type=str, default="/", help="Path to plot directory (currently unused: outputs go to ./plots and ./results)")
    parser.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="bfloat16", help="model datatype")
    return parser


def _checkpoint_name(path):
    """Name used in output files: basename of the directory part of ``path``."""
    return os.path.basename(os.path.dirname(path).rstrip("/"))


def _plot_interpolation(all_results, out_path):
    """Plot val/test loss against lambda for every method in ``all_results`` and save a PDF."""
    plt.rcParams.update({
        "font.family": "serif",
        'legend.frameon': False,
        'lines.linewidth': 2,
        'font.size': 13,
        'axes.labelsize': 16,
        'xtick.labelsize': 11,
        'ytick.labelsize': 11,
        'legend.fontsize': 11,
    })
    plt.style.use('tableau-colorblind10')
    num_points = len(all_results["Naive"]["Val Loss"])
    lambda_values = np.linspace(0, 1, num_points)
    fig, axs = plt.subplots(1, 2, figsize=(12, 5))  # 1 row, 2 columns
    custom_colors = [
        "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
        "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"
    ]
    # With more than one aligned method, a bare "Match" label would be ambiguous.
    several_aligned = len(all_results) > 2
    for ax, metric in zip(axs, ["Val Loss", "Test Loss"]):
        for j, method in enumerate(all_results):   # Naive + aligned methods
            if method == "Naive":
                label = "Naive"
            elif several_aligned:
                label = f"Match ({method})"
            else:
                label = "Match"
            ax.plot(lambda_values,
                    all_results[method][metric],
                    label=label,
                    color=custom_colors[j % len(custom_colors)])
        ax.set_xticks([0, 0.5, 1])
        ax.set_xticklabels(["Model 1", r"$\lambda$", "Model 2"])
        ax.set_ylabel(metric)
        ax.legend(loc='best')

    plt.tight_layout()
    plt.savefig(out_path)
    plt.close(fig)


def main():
    """Compare naive and weight-matched interpolation between two GPT-2 checkpoints.

    Writes the metrics as JSON to ``results/<dataset>/[A+B].json`` and a plot
    to ``./plots/<dataset>/[A+B]_row0.pdf``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.num_lambdas < 1:
        parser.error("--num-lambdas must be >= 1")
    # Fail early, naming the offending path, before anything reads from it.
    for ckpt in (args.model_a, args.model_b):
        if not os.path.exists(ckpt):
            raise FileNotFoundError(f"Checkpoint path does not exist: {ckpt}")

    corpus = get_lm_corpus(args.data_path, args.dataset)
    val_ds = corpus.get_iterator('valid', args.eval_batch_size, args.eval_tgt_len, ext_len=args.ext_len)
    test_ds = corpus.get_iterator('test', args.eval_batch_size, args.eval_tgt_len, ext_len=args.ext_len)

    # NOTE(review): config.json is looked up in os.path.dirname(--model-a), i.e.
    # the checkpoint's parent unless the path ends with "/", whereas
    # load_flax_params reads inside --model-a itself. Confirm this layout.
    config = GPT2Config.from_json_file(os.path.join(os.path.dirname(args.model_a).rstrip("/"), 'config.json'))
    config.lmc_config = GPT2Config(**config.lmc_config)
    model = LMCFlaxGPT2LMHeadModel(config, input_shape=(1, args.tgt_len), seed=args.seed, dtype=jnp.dtype(args.dtype))
    print_model(model.params)
    params_a = load_flax_params(args.model_a, copy.deepcopy(model.params))
    params_b = load_flax_params(args.model_b, copy.deepcopy(model.params))

    stuff = make_stuff(model=model)
    lambdas = jnp.linspace(0, 1, num=args.num_lambdas)
    rng = random.PRNGKey(args.seed)

    # Forward model B on the first validation batch; the collected dict is
    # handed to weight_matching below.
    activation = stuff["get_attention_inputs"](params_b, val_ds)
    naive_results = compute_interpolation(params_a, params_b, lambdas, stuff, val_ds, test_ds, desc="Naive Interpolation")
    print(json.dumps({"Naive": naive_results}, indent=2))
    all_results = {"Naive": naive_results}

    # Interpolate towards each weight-matched version of model B.
    aligned_models = weight_matching(rng, params_a, params_b, activation, config)
    for method, params_b_aligned in aligned_models.items():
        method_results = compute_interpolation(params_a, params_b_aligned, lambdas, stuff, val_ds, test_ds, desc=f"{method} Interpolation")
        all_results[method] = method_results
        print(json.dumps({method: method_results}, indent=2))

    os.makedirs(f"./plots/{args.dataset}", exist_ok=True)
    os.makedirs(f"./results/{args.dataset}", exist_ok=True)
    name_a = _checkpoint_name(args.model_a)
    name_b = _checkpoint_name(args.model_b)

    print("Save List of Values...")
    result_path = f'results/{args.dataset}/[{name_a}+{name_b}].json'
    with open(result_path, 'w') as f:
        json.dump(all_results, f, indent=2)

    print("Generating plots...")
    _plot_interpolation(all_results, f"./plots/{args.dataset}/[{name_a}+{name_b}]_row0.pdf")


if __name__ == "__main__":
    main()