import copy
import time

import numpy as np
import torch
import tqdm

from typing import Literal
from src.models.heads import get_classification_head
from src.models.modeling import ImageClassifier
from src.datasets.registry import get_dataset
from src.models.task_vectors import _Checkpoint, _TaskVector
from src.utils import utils
from torch.cuda.amp import autocast
from torchinfo import summary
from torchinfo.model_statistics import ModelStatistics
from onnxsim import simplify
import onnx

def get_summary(image_encoder, dataset_name, args, depth: int = 3, verbose: Literal[0, 1, 2] = 1) -> ModelStatistics:
    """Print and return a torchinfo summary of the encoder plus classification head.

    The encoder is wrapped in an ``ImageClassifier`` together with the
    classification head for ``dataset_name``, moved to ``args.device`` and put
    in eval mode. One sample from the dataset's train split (preprocessed with
    the model's ``val_preprocess``) is loaded only to determine the input shape.

    Side effects: prints the total parameter count (head included) and a
    standalone summary of the classification head (always at verbosity 2).

    Args:
        image_encoder: Image encoder to wrap.
        dataset_name: Dataset whose classification head and sample shape are used.
        args: Namespace providing at least ``device``, ``data_location`` and
            ``batch_size`` (plus whatever ``get_classification_head`` reads).
        depth: Maximum module nesting depth shown in the full-model summary.
        verbose: torchinfo verbosity for the full-model summary.

    Returns:
        The ``ModelStatistics`` for the full ``ImageClassifier`` (encoder AND head).
    """
    model = ImageClassifier(image_encoder, get_classification_head(args, dataset_name))
    model.to(args.device)
    model.eval()

    dataset = get_dataset(dataset_name, model.val_preprocess, location=args.data_location, batch_size=args.batch_size)
    # Only the shape of one preprocessed sample is needed; the label is ignored.
    sample, _ = dataset.train_dataset[0]
    input_size = (1, *sample.shape)

    n_params = sum(p.numel() for p in model.parameters())
    print("Model parameters (including classification_head):", f"{n_params:,}")

    # Explicitly summarize the classification head as well. Its input width is
    # the encoder's feature size, which depends on the backbone, so read it from
    # the head instead of assuming 512. Fall back to the previous hard-coded 512
    # when the head does not expose `in_features`.
    head = model.classification_head
    head_in_features = getattr(head, "in_features", 512)
    summary(head, input_size=(1, head_in_features), verbose=2)

    # Summary of the whole ImageClassifier (encoder and head) at the caller's depth/verbosity.
    return summary(model, input_size=input_size, verbose=verbose, depth=depth)

def export_to_onnx(image_encoder, dataset_name, args, onnx_path):
    """Export the encoder plus classification head to ONNX, then try to simplify it.

    The encoder is wrapped in an ``ImageClassifier`` with the classification head
    for ``dataset_name``, moved to ``args.device`` and put in eval mode. One
    preprocessed sample from the dataset's train split is used as the tracing
    input. The batch dimension of ``input``/``output`` is exported as dynamic.

    After export, the written file is reloaded and passed through
    ``onnxsim.simplify``. If the simplifier's check passes, the simplified model
    is saved next to the export with a ``_simplified`` suffix. Otherwise only a
    message is printed and the original export is left as-is.

    Args:
        image_encoder: Image encoder to wrap.
        dataset_name: Dataset whose classification head and sample shape are used.
        args: Namespace providing at least ``device``, ``data_location`` and
            ``batch_size`` (plus whatever ``get_classification_head`` reads).
        onnx_path: Destination file for the exported model (str or path-like).
    """
    model = ImageClassifier(image_encoder, get_classification_head(args, dataset_name))
    model.to(args.device)
    model.eval()

    dataset = get_dataset(dataset_name, model.val_preprocess, location=args.data_location, batch_size=args.batch_size)
    # One preprocessed sample serves as the tracing input; its label is unused.
    sample, _ = dataset.train_dataset[0]
    dummy_input = sample.unsqueeze(0).to(torch.float32).to(args.device)

    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        export_params=True,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},  # Allow dynamic batch size
    )

    # Simplify the ONNX model
    print("Simplifying the ONNX model...")
    onnx_model = onnx.load(onnx_path)
    simplified_model, check = simplify(onnx_model)

    if check:
        print("Simplified ONNX model saved successfully!" if False else "Simplified ONNX model successfully!")
        # Derive "<stem>_simplified.onnx" from the trailing suffix only. A plain
        # str.replace(".onnx", ...) would also rewrite ".onnx" inside directory
        # names. When the path has no ".onnx" suffix, it would leave the path
        # unchanged and overwrite the original export. On a pathlib.Path it
        # would call the file-rename method instead.
        base = str(onnx_path)
        if base.endswith(".onnx"):
            base = base[: -len(".onnx")]
        simplified_onnx_path = f"{base}_simplified.onnx"
        onnx.save(simplified_model, simplified_onnx_path)
        print(f"Simplified ONNX model saved to: {simplified_onnx_path}")
    else:
        print("Failed to simplify the ONNX model.")





    
