import argparse
import os
import jax
import copy
import json
import torch
import optax
import numpy as np 
from tqdm import tqdm
import jax.numpy as jnp
from jax import random, vmap
from typing import Any, Dict, List
from flax.core import freeze, unfreeze
from flax.serialization import from_bytes
from utils import flatten_params, lerp, unflatten_params
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers.models.llama.configuration_llama import LlamaConfig
from lmc_model import LMCFlaxLlamaForCausalLM, print_model
from flax.training import checkpoints, train_state
from flax.training.common_utils import get_metrics, onehot, shard
from matching_utils import weight_matching
from data_utils import get_lm_corpus
import matplotlib.pyplot as plt
import numpy as np


def load_flax_params(checkpoint_dir, target):
    """Deserialize Flax parameters from ``flax_model.msgpack`` in a directory.

    Args:
        checkpoint_dir: Directory that contains the ``flax_model.msgpack`` file.
        target: Restore target handed to ``from_bytes`` (callers in this file
            pass a deep copy of the freshly initialised model parameters).

    Returns:
        The object produced by ``from_bytes(target, <file bytes>)``.
    """
    checkpoint_file = os.path.join(checkpoint_dir, "flax_model.msgpack")
    with open(checkpoint_file, "rb") as stream:
        return from_bytes(target, stream.read())
def prepare_lm_batch(data: torch.Tensor, target: torch.Tensor) -> Dict[str, Any]:
    """Turn a (data, target) tensor pair into a dict of transposed JAX arrays.

    Each tensor is transposed with ``.T`` and then converted with
    ``jnp.array``.  The intent is a (batch, seq_len) layout, so the inputs are
    presumably (seq_len, batch) -- confirm against the data iterator.
    No device sharding happens here.

    Returns:
        ``{'input_ids': <data.T>, 'target': <target.T>}`` as JAX arrays.
    """
    batch = {}
    for key, tensor in (('input_ids', data), ('target', target)):
        batch[key] = jnp.array(tensor.T)
    return batch
def make_stuff(model):
    """Build the evaluation / training helpers bound to ``model``.

    Args:
        model: Callable model.  It is invoked as
            ``model(**inputs, params=..., train=...)`` (plus ``dropout_rng``
            when training) and the first element of its result is used as the
            logits.

    Returns:
        Dict with three callables:
            ``"batch_eval"``: jitted ``(params, batch) -> (logits, labels, loss)``.
            ``"step"``: jitted ``(state, batch, dropout_rng) ->
            (new_state, metrics, new_dropout_rng)``.
            ``"dataset_loss_and_ppl"``: ``(params, dataloader) ->
            (mean_loss, perplexity)`` as Python floats.
    """
    apply_fn = model.__call__

    def split_labels(batch):
        """Return ``(model_inputs, labels)`` without mutating ``batch``."""
        inputs = {name: value for name, value in batch.items() if name != "target"}
        return inputs, batch["target"]

    @jax.jit
    def batch_eval(params, batch):
        """Forward pass in eval mode; loss is the CE mean over all positions."""
        inputs, labels = split_labels(batch)
        logits = apply_fn(**inputs, params=params, train=False)[0]
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
        return logits, labels, loss

    @jax.jit
    def step(state: train_state.TrainState, batch, dropout_rng):
        """One optimisation step; returns the updated state, metrics and next RNG."""
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng, 2)
        inputs, labels = split_labels(batch)

        def loss_fn(params):
            logits = apply_fn(**inputs, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
            return loss, logits

        (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
        new_state = state.apply_gradients(grads=grads)
        metrics = {"batch_loss": loss, "logits": logits}
        return new_state, metrics, new_dropout_rng

    def dataset_loss_and_ppl(params, dataloader):
        """
        Iterate once over `dataloader`, accumulate token-level CE, and
        return (mean_loss, perplexity).  Works on **one device**.

        `dataloader` must yield 3-tuples whose first two items are the data
        and target tensors accepted by `prepare_lm_batch`.

        Raises:
            ValueError: if no token with a label other than -100 was seen
                (e.g. the dataloader is empty).
        """
        # Accumulate on the host in Python floats/ints: a device-side running
        # sum would inherit the loss dtype, which loses increments over a long
        # dataset if the model emits low-precision (e.g. bfloat16) values.
        loss_sum, token_count = 0.0, 0
        pbar = tqdm(dataloader, desc="Evaluating", leave=False)
        for eval_data, eval_target, _ in pbar:
            eval_batch = prepare_lm_batch(eval_data, eval_target)
            _, labels, loss = batch_eval(params, eval_batch)
            batch_loss = float(loss)
            # `batch_loss` is a mean over *every* position, and a label that
            # matches no class (such as -100) one-hot-encodes to an all-zero
            # row and so adds nothing to it.  Scaling by the full position
            # count therefore recovers the summed CE of the real tokens.
            loss_sum += batch_loss * labels.size
            token_count += int(jnp.sum(labels != -100))
            pbar.set_postfix(loss=f"{batch_loss:.4f}", ppl=f"{np.exp(batch_loss):.2f}")
        if token_count == 0:
            raise ValueError("dataset_loss_and_ppl: dataloader produced no tokens to evaluate")
        mean_loss = loss_sum / token_count
        ppl = float(np.exp(mean_loss))
        return mean_loss, ppl

    return {"batch_eval": batch_eval, "step": step, "dataset_loss_and_ppl": dataset_loss_and_ppl}
def compute_interpolation(params_a, params_b_target, lambdas, stuff, val_ds, test_ds, desc="Interpolation"):
    """Evaluate validation/test loss and perplexity along an interpolation path.

    For every coefficient in ``lambdas`` the two parameter trees are combined
    with ``lerp`` and scored on both datasets via
    ``stuff["dataset_loss_and_ppl"]``.

    Args:
        params_a: First parameter tree.
        params_b_target: Second parameter tree (possibly aligned to the first).
        lambdas: Iterable of interpolation coefficients passed to ``lerp``.
        stuff: Mapping that provides the ``"dataset_loss_and_ppl"`` callable.
        val_ds: Validation dataloader.
        test_ds: Test dataloader.
        desc: Progress-bar label.

    Returns:
        Dict with keys ``"Val Loss"``, ``"Test Loss"``, ``"Val PPL"`` and
        ``"Test PPL"``; each value is a list (one entry per coefficient) of
        floats rounded to four decimals.
    """
    def to_4dp(values):
        # Round-trip through a 4-decimal string, as the stored results expect.
        return [float(f"{value:.4f}") for value in values]

    curves = {"Val Loss": [], "Test Loss": [], "Val PPL": [], "Test PPL": []}
    for lam in tqdm(lambdas, desc=desc):
        blended = freeze(lerp(lam, unfreeze(params_a), unfreeze(params_b_target)))
        evaluate = stuff["dataset_loss_and_ppl"]
        val_loss, val_ppl = evaluate(blended, val_ds)
        test_loss, test_ppl = evaluate(blended, test_ds)
        curves["Val Loss"].append(val_loss)
        curves["Test Loss"].append(test_loss)
        curves["Val PPL"].append(val_ppl)
        curves["Test PPL"].append(test_ppl)
    return {label: to_4dp(series) for label, series in curves.items()}
def main():
    """Evaluate interpolation paths between two Llama checkpoints.

    Parses the command line, loads both parameter sets, evaluates the naive
    linear interpolation and one interpolation per aligned model returned by
    ``weight_matching``, prints every result as JSON and writes all of them to
    ``./results/llama-<dataset>/[<name_a>+<name_b>].json``.

    Raises:
        FileNotFoundError: If either checkpoint path does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-a", type=str, required=True, help="Path to first fine-tuned Llama model checkpoint directory")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second fine-tuned Llama model checkpoint directory")
    parser.add_argument("--dist", type=str, default="L2", choices=["L2", "L1", "cosine", "corr", "spectral", "fro_cos"])
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--data-path", type=str, default="./data/lm1b", help="train datset paths (multiple paths)")
    parser.add_argument('--dataset', type=str, default='lm1b', choices=['wt103', 'lm1b', 'enwik8', 'text8'], help='dataset name')
    parser.add_argument("--batch-size", type=int, default=24, help="train, eval batch size (batch size will be devided by device count)")
    parser.add_argument('--tgt_len', type=int, default=256, help='number of tokens to predict')
    parser.add_argument('--eval_tgt_len', type=int, default=256, help='number of tokens to predict for evaluation')
    parser.add_argument('--ext_len', type=int, default=0, help='length of the extended context')
    parser.add_argument('--mem_len', type=int, default=0, help='length of the retained previous heads')
    parser.add_argument("--save-path", type=str, default="/", help="Path to plot directory")
    parser.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="bfloat16", help="model datatype")
    args = parser.parse_args()

    # Fail fast -- before the corpus and model are built -- and name the
    # path(s) that are actually missing.
    missing = [path for path in (args.model_a, args.model_b) if not os.path.exists(path)]
    if missing:
        raise FileNotFoundError(f"Checkpoint path does not exist: {', '.join(missing)}")

    corpus = get_lm_corpus(args.data_path, args.dataset)
    eval_batch_size = 12
    val_ds = corpus.get_iterator('valid', eval_batch_size, args.eval_tgt_len, ext_len=args.ext_len)
    test_ds = corpus.get_iterator('test', eval_batch_size, args.eval_tgt_len, ext_len=args.ext_len)

    # NOTE(review): config.json is looked up in os.path.dirname(--model-a).
    # That is the parent directory when the path has no trailing slash, but
    # the checkpoint directory itself when it does (the rstrip runs after
    # dirname).  The result-file names below behave the same way.  Confirm the
    # intended checkpoint layout before changing this.
    config = LlamaConfig.from_json_file(os.path.join(os.path.dirname(args.model_a).rstrip("/"), 'config.json'))
    lmc_config = LlamaConfig(**config.lmc_config)
    config.lmc_config = lmc_config
    model = LMCFlaxLlamaForCausalLM(config, input_shape=(1, args.tgt_len), seed=args.seed, dtype=jnp.dtype(args.dtype))
    print_model(model.params)

    # Each load gets its own copy of the initial params as restore target.
    params_a = load_flax_params(args.model_a, copy.deepcopy(model.params))
    params_b = load_flax_params(args.model_b, copy.deepcopy(model.params))

    stuff = make_stuff(model=model)
    lambdas = jnp.linspace(0, 1, num=25)
    rng = random.PRNGKey(args.seed)

    # Compute naive interpolation
    naive_results = compute_interpolation(params_a, params_b, lambdas, stuff, val_ds, test_ds, desc="Naive Interpolation")
    print(json.dumps({"Naive": naive_results}, indent=2))
    all_results = {"Naive": naive_results}

    # Compute weight matching interpolations for each method
    aligned_models = weight_matching(rng, params_a, params_b, config, args)
    for method, params_b_aligned in aligned_models.items():
        method_results = compute_interpolation(params_a, params_b_aligned, lambdas, stuff, val_ds, test_ds, desc=f"{method} Interpolation")
        all_results[method] = method_results
        print(json.dumps({method: method_results}, indent=2))

    # Save directories
    plots_dir = f"./plots/llama-{args.dataset}"
    results_dir = f"./results/llama-{args.dataset}"
    os.makedirs(plots_dir, exist_ok=True)
    os.makedirs(results_dir, exist_ok=True)

    # Save results JSON
    print("Save List of Values...")
    name_a = os.path.basename(os.path.dirname(args.model_a).rstrip("/"))
    name_b = os.path.basename(os.path.dirname(args.model_b).rstrip("/"))
    result_path = os.path.join(results_dir, f"[{name_a}+{name_b}].json")
    with open(result_path, 'w', encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()