import argparse
import os
import jax
import flax
import copy
import torch
import wandb
import optax
import itertools
import numpy as np
from tqdm import tqdm
import jax.numpy as jnp
from jax import random, vmap
from flax import linen as nn
from typing import Any, Dict, List
from jax import random, vmap
from flax.jax_utils import replicate, unreplicate
from flax.core.frozen_dict import freeze, unfreeze
from flax.training import checkpoints, train_state
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.training.common_utils import get_metrics, onehot, shard
from utils import flatten_params, lerp, unflatten_params
from lmc_model import  LMCFlaxViTForImageClassification, print_model, print_model_with_prefix
from transformers.models.vit.modeling_flax_vit import ViTConfig
from matching_utils import weight_matching_attn
from datasets import build_dataset
from flax.serialization import from_bytes
import multiprocessing as mp
import matplotlib.pyplot as plt
import numpy as np
import json
# Make "spawn" the default multiprocessing start method; force=True overrides any
# start method that was already set. NOTE(review): presumably chosen so worker
# processes (e.g. DataLoader workers, see num_workers below) are not forked from a
# process that has already initialized JAX -- confirm.
mp.set_start_method("spawn", force=True)
def load_flax_params(checkpoint_dir, target):
    """Deserialize ``<checkpoint_dir>/flax_model.msgpack`` into the structure of *target*.

    Args:
        checkpoint_dir: directory containing ``flax_model.msgpack``.
        target: pytree whose structure ``flax.serialization.from_bytes`` uses as the
            restore template.

    Returns:
        The restored parameter pytree.
    """
    weights_file = os.path.join(checkpoint_dir, "flax_model.msgpack")
    with open(weights_file, "rb") as handle:
        return from_bytes(target, handle.read())
def imagenet_data_loader(args, is_generalization):
    """Build an in-order (SequentialSampler) evaluation DataLoader over the validation split.

    Side effect: the second value returned by ``build_dataset`` (presumably the
    number of classes) is stored on ``args.nb_classes``.
    """
    dataset_val, args.nb_classes = build_dataset(is_train=False, args=args, is_generalization=is_generalization)
    loader_options = dict(
        sampler=torch.utils.data.SequentialSampler(dataset_val),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,  # keep the trailing partial batch
    )
    return torch.utils.data.DataLoader(dataset_val, **loader_options)
def sample_random_batch(dataloader, seed=0):
    """Return one batch from *dataloader*, picked by index with a seeded NumPy generator.

    The index is drawn from ``[0, len(dataloader))``. Batches before it are pulled
    from a fresh iterator and discarded, so the loader is never materialized.
    """
    num_batches = len(dataloader)  # the loader must define __len__
    chosen = int(np.random.default_rng(seed).integers(num_batches))
    batches = iter(dataloader)
    for _ in range(chosen):
        next(batches)
    return next(batches)
def accuracy(logits, labels, topk=(1,)):
    """Top-k accuracy, in percent, for each k in *topk*.

    Args:
        logits: 2-D scores indexed as ``logits[example, class]``.
        labels: integer class ids, one per row of *logits*.
        topk: the k values to report.

    Returns:
        List with one scalar per k (same order as *topk*): the percentage of rows
        whose label is among that row's k highest-ranked classes.
    """
    n_examples = labels.shape[0]
    # Classes per row in descending score order, truncated to the largest k needed.
    ranked = jnp.argsort(logits, axis=-1)[:, ::-1][:, :max(topk)]
    matches = ranked == labels[:, None]
    return [100.0 * jnp.sum(jnp.any(matches[:, :k], axis=1)) / n_examples for k in topk]
def prepare_image_batch(images: torch.Tensor, labels: torch.Tensor) -> Dict[str, Any]:
    """Convert a host batch to JAX arrays and split its leading axis across local devices.

    Returns a dict with ``"images"`` and ``"labels"``, each reshaped by flax's
    ``shard`` to ``(local_device_count, per_device_batch, ...)``.
    """
    as_jax = {"images": jnp.array(images), "labels": jnp.array(labels)}
    return {name: shard(array) for name, array in as_jax.items()}
def make_stuff(model):
    """Build pmapped evaluation helpers for *model*.

    Args:
        model: object whose ``__call__`` accepts ``params=``, ``pixel_values=`` and
            ``train=`` keywords and returns either an object with a ``logits``
            attribute or a sequence whose first element is the logits.

    Returns:
        dict with
          * ``"batch_eval"``: per-device eval step. It reduces over the ``"batch"``
            axis, so it must run under ``jax.pmap(..., axis_name="batch")``.
          * ``"dataset_loss_and_acc"``: ``(params, dataloader) -> (loss, top1_acc)``
            averaged over every example the dataloader yields.
    """
    apply_fn = model.__call__

    def batch_eval(params, batch):
        """Mean loss and top-1/top-5 accuracy (percent) over the valid examples of a sharded batch.

        ``batch`` holds ``"images"``, integer ``"labels"`` and optionally a 0/1
        ``"mask"`` (1 = real example, 0 = padding). Without a mask every example
        counts. Sums are ``psum``-reduced over the ``"batch"`` axis, so every
        device returns the same global values. ``"count"`` is the number of valid
        examples in the whole (all-device) batch.
        """
        outputs = apply_fn(params=params, pixel_values=batch["images"], train=False)
        logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
        labels = batch["labels"]
        mask = batch["mask"] if "mask" in batch else jnp.ones(labels.shape[0], jnp.float32)
        # Accumulate in float32 even if the model runs in reduced precision (--dtype).
        per_example_loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).astype(jnp.float32)
        # Same ranking as `accuracy`: classes in descending score order, best first.
        top5 = jnp.argsort(logits, axis=-1)[:, ::-1][:, :5]
        hits = top5 == labels[:, None]
        count = jax.lax.psum(jnp.sum(mask), axis_name="batch")

        def masked_mean(values):
            # Global mean over valid examples only; padding has mask == 0.
            return jax.lax.psum(jnp.sum(values * mask), axis_name="batch") / count

        return {
            "loss": masked_mean(per_example_loss),
            "acc1": 100.0 * masked_mean(hits[:, 0]),
            "acc5": 100.0 * masked_mean(jnp.any(hits, axis=1)),
            "count": count,
        }

    parallel_batch_eval = jax.pmap(batch_eval, axis_name="batch")

    def shard_with_padding(images, labels):
        """Shard a host batch across local devices, padding it to a multiple of the device count.

        `shard` reshapes the leading axis to (local_device_count, -1), which fails
        when the batch size is not divisible. The DataLoader keeps its last partial
        batch (drop_last=False), so that case is padded here and masked out.
        """
        images, labels = np.asarray(images), np.asarray(labels)
        n_valid = labels.shape[0]
        n_pad = (-n_valid) % jax.local_device_count()
        if n_pad:
            # Repeat the last example so padding holds well-formed pixels/labels; mask zeroes it.
            images = np.concatenate([images, np.repeat(images[-1:], n_pad, axis=0)])
            labels = np.concatenate([labels, np.repeat(labels[-1:], n_pad, axis=0)])
        mask = np.concatenate([np.ones(n_valid, np.float32), np.zeros(n_pad, np.float32)])
        batch = prepare_image_batch(images, labels)
        batch["mask"] = shard(jnp.asarray(mask))
        return batch

    def dataset_loss_and_acc(params, dataloader):
        """Evaluate over the dataloader on all local devices.

        Returns:
            ``(loss, acc1)``: mean cross-entropy and top-1 accuracy (percent) over
            every example yielded by *dataloader*.

        Raises:
            ValueError: if *dataloader* yields no batches.
        """
        rep_params = replicate(params)
        device_metrics = []
        for images, labels in tqdm(dataloader, desc="Evaluating", leave=False):
            batch = shard_with_padding(images, labels)
            device_metrics.append(parallel_batch_eval(rep_params, batch))
        if not device_metrics:
            raise ValueError("dataloader yielded no batches")
        # get_metrics already keeps a single replica and stacks along a new leading
        # axis, giving one value per batch. Unreplicating again would keep only the
        # first batch.
        per_batch = get_metrics(device_metrics)
        # Weight each batch by its number of real examples so a short final batch
        # does not count as much as a full one.
        weights = per_batch["count"]
        loss = float(np.average(per_batch["loss"], weights=weights))
        acc1 = float(np.average(per_batch["acc1"], weights=weights))
        return loss, acc1

    return {"batch_eval": batch_eval, "dataset_loss_and_acc": dataset_loss_and_acc}
def compute_interpolation(params_a, params_b_target, lambdas, stuff, test_loader, desc="Interpolation"):
    """Evaluate models on the linear path between two parameter sets.

    For each ``lam`` in *lambdas* (in order), the parameters produced by
    ``lerp(lam, params_a, params_b_target)`` are evaluated on *test_loader* with
    ``stuff["dataset_loss_and_acc"]``.

    Args:
        params_a, params_b_target: the two parameter pytrees to interpolate between.
        lambdas: iterable of interpolation coefficients.
        stuff: dict from ``make_stuff``; only ``"dataset_loss_and_acc"`` is used.
        test_loader: dataloader passed through to the evaluator.
        desc: progress-bar label.

    Returns:
        ``{"Test Loss": [...], "Test Acc": [...]}`` with one entry per lambda, each
        rounded to 4 decimal places.
    """
    test_losses, test_accs = [], []
    for lam in tqdm(lambdas, desc=desc):
        p_interp = freeze(lerp(lam, unfreeze(params_a), unfreeze(params_b_target)))
        test_loss, test_acc = stuff["dataset_loss_and_acc"](p_interp, test_loader)
        test_losses.append(test_loss)
        test_accs.append(test_acc)
    return {
        "Test Loss": [float(f"{x:.4f}") for x in test_losses],
        "Test Acc": [float(f"{x:.4f}") for x in test_accs],
    }
def _parse_args():
    """Parse the command-line options for ``main``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-a", type=str, required=True, help="Path to first fine-tuned ViT checkpoint directory")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second fine-tuned ViT checkpoint directory")
    parser.add_argument("--dist", type=str, default="L2", choices=["L2", "L1", "cosine", "corr", "spectral", "fro_cos"])
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--data-path", type=str, default=None)
    parser.add_argument("--data-original", type=str, required=True)
    parser.add_argument("--data-generalization", type=str, required=True)
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'])
    parser.add_argument("--input-size", type=int, default=224)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--pin-mem', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--color-jitter', type=float, default=0.4)
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1')
    parser.add_argument('--train-interpolation', type=str, default='bicubic')
    parser.add_argument('--reprob', type=float, default=0.25)
    parser.add_argument('--remode', type=str, default='pixel')
    parser.add_argument('--recount', type=int, default=1)
    parser.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="bfloat16", help="model datatype")
    parser.add_argument("--save-path", type=str, default="/", help="Path to plot directory")
    parser.add_argument("--num-lambdas", type=int, default=3, help="Number of evenly spaced interpolation coefficients in [0, 1]")
    args = parser.parse_args()
    if args.num_lambdas < 1:
        parser.error("--num-lambdas must be at least 1")
    return args


def main():
    """Evaluate linear interpolation between two fine-tuned ViT checkpoints.

    Loads both checkpoints and evaluates naive interpolation on the original and the
    generalization validation sets. It then repeats the evaluation for every aligned
    version of model B returned by ``weight_matching_attn``. Results are printed as JSON.

    Raises:
        FileNotFoundError: if either checkpoint path does not exist.
    """
    args = _parse_args()
    # Fail fast with the offending path before touching the config or the data.
    for flag, path in (("--model-a", args.model_a), ("--model-b", args.model_b)):
        if not os.path.exists(path):
            raise FileNotFoundError(f"Checkpoint path for {flag} does not exist: {path}")

    # NOTE(review): os.path.dirname depends on a trailing slash. "ckpt/a/" resolves to
    # "ckpt/a/config.json", while "ckpt/a" resolves to "ckpt/config.json". Confirm which
    # location is intended (the rstrip after dirname is effectively a no-op).
    config = ViTConfig.from_json_file(os.path.join(os.path.dirname(args.model_a).rstrip("/"), 'config.json'))
    config.lmc_config = ViTConfig(**config.lmc_config)
    model = LMCFlaxViTForImageClassification(
        config,
        input_shape=(1, config.image_size, config.image_size, config.num_channels),
        seed=args.seed,
        dtype=jnp.dtype(args.dtype),
    )
    print_model(model.params)

    # model.params serves only as the restore template; deep copies keep it untouched.
    params_a = load_flax_params(args.model_a, copy.deepcopy(model.params))
    params_b = load_flax_params(args.model_b, copy.deepcopy(model.params))
    stuff = make_stuff(model=model)

    # args.data_path is re-pointed before each loader is built (presumably read by build_dataset).
    args.data_path = args.data_original
    original_val_loader = imagenet_data_loader(args, False)
    args.data_path = args.data_generalization
    generalization_val_loader = imagenet_data_loader(args, True)
    eval_sets = (("Original", original_val_loader), ("Generalization", generalization_val_loader))

    lambdas = jnp.linspace(0, 1, num=args.num_lambdas)
    rng = random.PRNGKey(args.seed)

    # Naive interpolation: model B used as-is.
    for split_name, loader in eval_sets:
        naive_results = compute_interpolation(params_a, params_b, lambdas, stuff, loader, desc=f"Naive {split_name} Interpolation")
        print(json.dumps({f"Naive-{split_name}": naive_results}, indent=2))

    # Interpolation after aligning model B to model A with each weight-matching method.
    aligned_models = weight_matching_attn(rng, params_a, params_b, config, args)
    for method, params_b_aligned in aligned_models.items():
        for split_name, loader in eval_sets:
            method_results = compute_interpolation(params_a, params_b_aligned, lambdas, stuff, loader, desc=f"{method} {split_name} Interpolation")
            # Keys stay "<method>original" / "<method>generalization", matching the previous output format.
            print(json.dumps({method + split_name.lower(): method_results}, indent=2))


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()