import argparse
import os
import jax
import flax
import copy
import torch
import wandb
import optax
import itertools
import numpy as np
from tqdm import tqdm
import jax.numpy as jnp
from jax import random, vmap
from flax import linen as nn
from typing import Any, Dict, List
from jax import random, vmap
from flax.jax_utils import replicate, unreplicate
from flax.core.frozen_dict import freeze, unfreeze
from flax.training import checkpoints, train_state
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.training.common_utils import get_metrics, onehot, shard
from utils import flatten_params, lerp, unflatten_params
from lmc_model import  LMCFlaxViTForImageClassification, print_model, print_model_with_prefix
from transformers.models.vit.modeling_flax_vit import ViTConfig
from matching_utils import weight_matching_attn
from datasets import build_dataset
from flax.serialization import from_bytes
import multiprocessing as mp
import matplotlib.pyplot as plt
import numpy as np
import json
# Force the "spawn" multiprocessing start method at import time; force=True
# overrides any start method that was already selected.
# NOTE(review): presumably chosen so worker processes do not fork a process
# that has already initialised JAX/accelerator state — confirm.
mp.set_start_method("spawn", force=True)
def load_flax_params(checkpoint_dir, target):
    """Restore parameters from ``flax_model.msgpack`` inside ``checkpoint_dir``.

    Args:
        checkpoint_dir: Directory that contains the ``flax_model.msgpack`` file.
        target: Template handed to ``from_bytes`` as the restore target.

    Returns:
        The object produced by ``from_bytes(target, <file bytes>)``.

    Raises:
        OSError: If the msgpack file cannot be opened (e.g. it does not exist).
    """
    weights_file = os.path.join(checkpoint_dir, "flax_model.msgpack")
    with open(weights_file, "rb") as handle:
        serialized = handle.read()
    # Deserialize after the file has been closed.
    return from_bytes(target, serialized)
def imagenet_data_loader(args):
    """Create the training and validation ``DataLoader`` objects.

    The datasets come from ``build_dataset``; the number of classes it reports
    for the training split is stored on ``args.nb_classes`` as a side effect.

    Args:
        args: Namespace providing ``batch_size``, ``num_workers`` and
            ``pin_mem``, plus whatever ``build_dataset`` itself reads.

    Returns:
        ``(train_loader, val_loader)``. The training loader samples randomly
        and drops the last incomplete batch; the validation loader iterates
        sequentially and keeps it.
    """
    train_set, args.nb_classes = build_dataset(is_train=True, args=args)
    val_set, _ = build_dataset(is_train=False, args=args)

    # Options that are identical for both loaders.
    shared_options = {
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        "pin_memory": args.pin_mem,
    }

    train_loader = torch.utils.data.DataLoader(
        train_set,
        sampler=torch.utils.data.RandomSampler(train_set),
        drop_last=True,
        **shared_options,
    )
    val_loader = torch.utils.data.DataLoader(
        val_set,
        sampler=torch.utils.data.SequentialSampler(val_set),
        drop_last=False,
        **shared_options,
    )
    return train_loader, val_loader
def sample_random_batch(dataloader, seed=0):
    """Pick one batch at a seeded random position without loading all batches.

    Args:
        dataloader: Iterable of batches that also supports ``len()``.
        seed: Seed for the NumPy generator that chooses the batch position.

    Returns:
        The batch found at the chosen position of a fresh iteration.
    """
    num_batches = len(dataloader)  # the loader must implement __len__
    generator = np.random.default_rng(seed)
    position = int(generator.integers(num_batches))
    # Lazily skip ``position`` batches, then return the one that follows.
    remaining = itertools.islice(iter(dataloader), position, None)
    return next(remaining)
def accuracy(logits, labels, topk=(1,)):
    """Compute top-k accuracy percentages for one batch.

    Args:
        logits: Per-class scores; the last axis is ranked.
        labels: Integer class labels, one per row of ``logits``.
        topk: Iterable of ``k`` values to report.

    Returns:
        A list with one entry per ``k`` in ``topk`` (same order), each the
        percentage of rows whose label is among the ``k`` highest-scoring
        classes.
    """
    largest_k = max(topk)
    num_examples = labels.shape[0]
    # Class indices ordered from highest to lowest score, truncated to the
    # largest k that was requested.
    ranked = jnp.flip(jnp.argsort(logits, axis=-1)[:, -largest_k:], axis=1)
    targets = labels[:, None]

    def percent_correct(k):
        hit_per_row = jnp.any(ranked[:, :k] == targets, axis=1)
        return 100.0 * jnp.sum(hit_per_row) / num_examples

    return [percent_correct(k) for k in topk]
def prepare_image_batch(images:torch.Tensor,labels:torch.Tensor) -> Dict[str, Any]:
    """Convert a torch batch to JAX arrays and pass both through ``shard``.

    Args:
        images: Batch of images from the torch data loader.
        labels: Matching batch of labels.

    Returns:
        Dict with keys ``'images'`` and ``'labels'`` holding the sharded arrays.
    """
    # NOTE(review): ``shard`` presumably requires the batch size to be
    # divisible by the local device count — confirm for the final, possibly
    # smaller, validation batch.
    image_array = jnp.array(images)
    label_array = jnp.array(labels)
    return {'images': shard(image_array), 'labels': shard(label_array)}
def make_stuff(model):
    """Build the evaluation helpers bound to ``model``.

    Args:
        model: Callable model whose ``__call__`` accepts the keyword arguments
            ``params``, ``pixel_values``, ``train`` and ``attention_input``.

    Returns:
        Dict with three callables:

        * ``"batch_eval"``: per-device metric function; it calls
          ``jax.lax.pmean`` over the axis named ``"batch"`` and therefore has
          to run under a ``pmap`` that defines that axis.
        * ``"dataset_loss_and_acc"``: evaluates a whole data loader and
          returns ``(loss, top-1 accuracy)`` as Python floats.
        * ``"get_attention_inputs"``: runs the model on the first batch with a
          dict passed as ``attention_input`` and returns that dict.
    """
    apply_fn = model.__call__

    def batch_eval(params, batch):
        """Return loss, top-1 and top-5 accuracy averaged across devices."""
        outputs = apply_fn(params=params, pixel_values=batch["images"], train=False)
        # Accept both output objects exposing ``.logits`` and plain tuples.
        logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch["labels"]).mean()
        acc1, acc5 = accuracy(logits, batch["labels"], topk=(1, 5))
        metrics = {"loss": loss, "acc1": acc1, "acc5": acc5}
        return jax.lax.pmean(metrics, axis_name="batch")

    parallel_batch_eval = jax.pmap(batch_eval, axis_name="batch")

    def dataset_loss_and_acc(params, dataloader):
        """Evaluate ``params`` over every batch of ``dataloader`` on all devices.

        Per-batch metrics are combined with a mean weighted by batch size, so
        a smaller final batch does not skew the dataset-level result.

        Returns:
            ``(loss, acc1)`` as Python floats.

        Raises:
            ValueError: If ``dataloader`` yields no batches.
        """
        rep_params = replicate(params)
        step_metrics = []
        step_sizes = []
        for images, labels in tqdm(dataloader, desc="Evaluating", leave=False):
            batch = prepare_image_batch(images, labels)
            # Keep the device results as-is here; they are fetched to the host
            # once after the loop.
            step_metrics.append(parallel_batch_eval(rep_params, batch))
            step_sizes.append(int(labels.shape[0]))

        if not step_metrics:
            raise ValueError("dataloader yielded no batches to evaluate")

        # After pmean every device holds the same value, so one copy per step
        # is enough. Each step is reduced exactly once: a second
        # first-element selection would keep only the first batch.
        host_metrics = jax.device_get([unreplicate(m) for m in step_metrics])
        weights = np.asarray(step_sizes, dtype=np.float64)

        def weighted_mean(key):
            values = [float(m[key]) for m in host_metrics]
            return float(np.average(values, weights=weights))

        return weighted_mean("loss"), weighted_mean("acc1")

    def get_attention_inputs(params, dataloader):
        """Run the model on the first batch with a dict as ``attention_input``.

        Returns:
            The dict passed as ``attention_input``; it is returned empty if
            ``dataloader`` yields nothing.
        """
        # NOTE(review): presumably the model records attention-layer inputs
        # into this dict in place — confirm in lmc_model.
        activation = {}
        for images, _ in dataloader:
            apply_fn(
                params=params,
                pixel_values=jnp.array(images),
                train=False,
                attention_input=activation,
            )
            break  # only the first batch is used
        return activation

    return {
        "batch_eval": batch_eval,
        "dataset_loss_and_acc": dataset_loss_and_acc,
        "get_attention_inputs": get_attention_inputs,
    }
def compute_interpolation(params_a, params_b_target, lambdas, stuff, test_loader, desc="Interpolation"):
    """Evaluate parameter sets interpolated between two models.

    For every value in ``lambdas`` the two parameter trees are combined with
    ``lerp`` and scored by ``stuff["dataset_loss_and_acc"]`` on ``test_loader``.

    Args:
        params_a: Parameters of the first model.
        params_b_target: Parameters of the second model.
        lambdas: Iterable of interpolation coefficients passed to ``lerp``.
        stuff: Dict providing the ``"dataset_loss_and_acc"`` callable.
        test_loader: Data loader handed to the evaluation callable.
        desc: Label for the progress bar.

    Returns:
        Dict with keys ``"Test Loss"`` and ``"Test Acc"``; each maps to a list
        of floats (one per coefficient) rounded to four decimals.
    """
    losses = []
    accuracies = []
    for coefficient in tqdm(lambdas, desc=desc):
        blended = lerp(coefficient, unfreeze(params_a), unfreeze(params_b_target))
        loss_value, acc_value = stuff["dataset_loss_and_acc"](freeze(blended), test_loader)
        losses.append(loss_value)
        accuracies.append(acc_value)

    def rounded(values):
        # Round-trip through a 4-decimal string so the JSON output stays compact.
        return [float(f"{value:.4f}") for value in values]

    return {"Test Loss": rounded(losses), "Test Acc": rounded(accuracies)}
def _build_arg_parser():
    """Create the command-line parser for the interpolation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-a", type=str, required=True, help="Path to first fine-tuned ViT model checkpoint")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second fine-tuned ViT model checkpoint")
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument("--data-set", default="IMNET", choices=["CIFAR", "IMNET", "INAT", "INAT19"])
    parser.add_argument("--input-size", type=int, default=224)
    # The original underscore spelling stays first (it is what existing
    # commands use); the hyphenated alias matches every other flag. Both store
    # into ``args.num_workers``.
    parser.add_argument("--num_workers", "--num-workers", type=int, default=8)
    parser.add_argument("--pin-mem", action="store_true")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--color-jitter", type=float, default=0.4)
    parser.add_argument("--aa", type=str, default="rand-m9-mstd0.5-inc1")
    parser.add_argument("--train-interpolation", type=str, default="bicubic")
    parser.add_argument("--reprob", type=float, default=0.25)
    parser.add_argument("--remode", type=str, default="pixel")
    parser.add_argument("--recount", type=int, default=1)
    parser.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="bfloat16", help="model datatype")
    # NOTE(review): --save-path is parsed but never read by this script;
    # outputs always go to ./results/imagenet and ./plots/imagenet.
    parser.add_argument("--save-path", type=str, default="/", help="Path to plot directory")
    return parser


def _plot_results(all_results, plot_path):
    """Plot every method's test loss/accuracy curve and save to ``plot_path``."""
    plt.rcParams.update({
        "font.family": "serif",
        "legend.frameon": False,
        "lines.linewidth": 2,
        "font.size": 13,
        "axes.labelsize": 16,
        "xtick.labelsize": 11,
        "ytick.labelsize": 11,
        "legend.fontsize": 11,
    })
    plt.style.use("tableau-colorblind10")
    num_points = len(all_results["Naive"]["Test Loss"])
    lambda_values = np.linspace(0, 1, num_points)
    _, axs = plt.subplots(2, 2, figsize=(12, 10))
    # NOTE(review): the bottom row repeats the top-row metrics; only test
    # metrics are computed, so there is nothing else to show there.
    metrics = ["Test Loss", "Test Acc", "Test Loss", "Test Acc"]
    positions = [(0, 0), (0, 1), (1, 0), (1, 1)]
    for metric, (row, col) in zip(metrics, positions):
        ax = axs[row, col]
        for method, method_results in all_results.items():
            ax.plot(lambda_values, method_results[metric], label=method)
        ax.set_xticks([0, 0.5, 1])
        ax.set_xticklabels(["Model 1", r"$\lambda$", "Model 2"])
        ax.set_ylabel(metric)
        ax.legend(loc="best")
    plt.tight_layout()
    plt.savefig(plot_path)
    plt.close()


def main():
    """Compare naive and weight-matched interpolation between two checkpoints.

    Steps: parse arguments, build the model from ``config.json``, load both
    parameter sets, evaluate the naive interpolation on the validation loader,
    evaluate the interpolation towards every aligned variant returned by
    ``weight_matching_attn``, then write the results as JSON and as a PDF plot.

    Raises:
        FileNotFoundError: If either checkpoint path does not exist.
    """
    args = _build_arg_parser().parse_args()

    # NOTE(review): config.json is looked up in os.path.dirname(args.model_a).
    # When --model-a ends with "/" that is the checkpoint directory itself;
    # without the trailing slash it is the parent directory. The same applies
    # to the result/plot file names below — confirm which layout is intended.
    config = ViTConfig.from_json_file(os.path.join(os.path.dirname(args.model_a).rstrip("/"), "config.json"))
    lmc_config = ViTConfig(**config.lmc_config)
    config.lmc_config = lmc_config
    model = LMCFlaxViTForImageClassification(
        config,
        input_shape=(1, config.image_size, config.image_size, config.num_channels),
        seed=args.seed,
        dtype=jnp.dtype(args.dtype),
    )
    print_model(model.params)

    missing = [path for path in (args.model_a, args.model_b) if not os.path.exists(path)]
    if missing:
        raise FileNotFoundError(f"Checkpoint path does not exist: {', '.join(missing)}")
    # Each load gets its own copy of the freshly initialised parameters as
    # the restore template.
    params_a = load_flax_params(args.model_a, copy.deepcopy(model.params))
    params_b = load_flax_params(args.model_b, copy.deepcopy(model.params))

    stuff = make_stuff(model=model)
    # Only the validation loader is used below; the call also sets
    # args.nb_classes.
    _, val_loader = imagenet_data_loader(args)
    lambdas = jnp.linspace(0, 1, num=25)
    rng = random.PRNGKey(args.seed)

    # Compute naive interpolation
    activation = stuff["get_attention_inputs"](params_b, val_loader)
    naive_results = compute_interpolation(params_a, params_b, lambdas, stuff, val_loader, desc="Naive Interpolation")
    print(json.dumps({"Naive": naive_results}, indent=2))
    all_results = {"Naive": naive_results}

    # Compute weight matching interpolations for each method
    aligned_models = weight_matching_attn(rng, params_a, params_b, activation, config)
    for method, params_b_aligned in aligned_models.items():
        method_results = compute_interpolation(params_a, params_b_aligned, lambdas, stuff, val_loader, desc=f"{method} Interpolation")
        all_results[method] = method_results
        print(json.dumps({method: method_results}, indent=2))

    # Save directories
    os.makedirs("./plots/imagenet", exist_ok=True)
    os.makedirs("./results/imagenet", exist_ok=True)

    # Save results JSON
    print("Save List of Values...")
    name_a = os.path.basename(os.path.dirname(args.model_a).rstrip("/"))
    name_b = os.path.basename(os.path.dirname(args.model_b).rstrip("/"))
    result_path = f"results/imagenet/{name_a}+{name_b}.json"
    with open(result_path, "w") as f:
        json.dump(all_results, f, indent=2)

    # Plot
    print("Generating plots...")
    _plot_results(all_results, f"./plots/imagenet/{name_a}+{name_b}.pdf")

# Run the interpolation experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()