import optuna
import optuna.visualization as vis

import random
from torch.utils.data import Subset, DataLoader

import torch.distributed.launch
import argparse
import torch
import torch.backends.cudnn as cudnn
from pathlib import Path
from timm.models import create_model
from datasets import build_dataset
from engine_for_finetuning import evaluate, evaluate_snn
import utils
import trans_utils
import model_eva
import model_vit

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    because every non-empty string is truthy. This converter makes
    ``--monitor False`` behave as written.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"true", "t", "yes", "y", "1", "on"}:
        return True
    # The empty string stays falsy, as it was under the old ``type=bool``.
    if token in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args():
    """Parse the command-line options for the evaluation / search script.

    Returns:
        argparse.Namespace: the parsed options (model, dataset, test-mode and
        SFN settings) read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    # Model parameters
    parser.add_argument('--model', default='eva_g_patch14', type=str, metavar='MODEL', help='Name of model to train')
    parser.add_argument('--input_size', default=224, type=int, help='images input size')
    parser.add_argument('--nb_classes', default=1000, type=int, help='number of the classification types')
    parser.add_argument('--model_path', default='')
    parser.add_argument('--percent', default=0.99, type=float)
    # Parsed with _str2bool so that "--monitor False" really disables it.
    parser.add_argument('--monitor', default=True, type=_str2bool)

    # Dataset parameters
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--eval_data_path', default='../datasets/val', type=str, help='dataset path for evaluation')
    parser.add_argument('--imagenet_default_mean_and_std', default=False, action='store_true')
    parser.add_argument('--data_set', default='image_folder', choices=['CIFAR10', 'CIFAR100', 'IMNET', 'image_folder'], type=str, help='dataset type')
    parser.add_argument('--output_dir', default='../models/threshold', help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--savename', default='test', type=str)

    # Mode parameter
    parser.add_argument('--test_mode', default='ann', choices=['ann', 'for_v', 'snn'], help="test mode")
    parser.add_argument('--test_T', default=1, type=int)

    # SFN parameter
    parser.add_argument('--linear_num', default=8, type=int)
    parser.add_argument('--qkv_num', default=8, type=int)
    parser.add_argument('--softmax_num', default=8, type=int)
    parser.add_argument('--softmax_p', default=0.93 / 263, type=float)

    return parser.parse_args()

def get_partial_loader(dataset, batch_size=64, ratio=0.02, shuffle=True, seed=0):
    """Return a DataLoader over a seeded random fraction of ``dataset``.

    ``int(len(dataset) * ratio)`` indices are taken from the front of a
    seeded random permutation, so the same ``seed`` always selects the same
    samples. The same generator is then handed to the DataLoader.

    Args:
        dataset: any object supporting ``len()`` and integer indexing.
        batch_size: batch size of the returned loader.
        ratio: fraction of the dataset to keep (truncated toward zero).
        shuffle: forwarded to ``DataLoader``.
        seed: seed for the generator used for selection and by the loader.

    Returns:
        DataLoader: loader over the selected ``Subset``.
    """
    rng = torch.Generator()
    rng.manual_seed(seed)

    total = len(dataset)
    keep = int(total * ratio)

    # Draw the permutation first; the DataLoader receives the generator in
    # the state left behind by this call.
    chosen = torch.randperm(total, generator=rng)[:keep].tolist()

    return DataLoader(
        Subset(dataset, chosen),
        batch_size=batch_size,
        shuffle=shuffle,
        generator=rng,
    )

def objective(trial):
    """Optuna objective: evaluate the converted model for one sampled ``lambda``.

    Samples ``lambda`` in [1e-4, 1.0], builds the model named by ``--model``,
    optionally loads a checkpoint, applies the ``trans_utils`` replacement
    passes, and evaluates on a fixed 2% subset of the validation set.

    Args:
        trial: the ``optuna`` trial supplying the ``lambda`` suggestion.

    Returns:
        The ``global_avg`` of the ``"acc"`` entry returned by ``evaluate_snn``.
        The same value and the ``"energy"`` entry are also stored as the
        trial user attributes ``val_acc`` and ``val_energy``.
    """
    args = get_args()

    # ``lambda`` is a reserved word, so ``args.lambda`` is a SyntaxError that
    # stops the whole file from being imported. Store it with setattr().
    # NOTE(review): consumers must read it as getattr(args, "lambda") —
    # confirm against trans_utils / evaluate_snn.
    lam = trial.suggest_float("lambda", 1e-4, 1.00)
    setattr(args, "lambda", lam)
    args.distributed = False

    print(f"Trying lambda = {lam:.4f}")

    device = torch.device(args.device)
    cudnn.benchmark = True

    dataset_val, args.nb_classes = build_dataset(is_train=False, args=args)
    # Fixed seed so every trial is scored on the same subset of samples.
    data_loader_val = get_partial_loader(
        dataset_val,
        batch_size=args.batch_size,
        ratio=0.02,
        shuffle=True,
        seed=42,
    )

    model = create_model(args.model, pretrained=False, img_size=args.input_size, num_classes=args.nb_classes)
    trans_utils.replace_test_by_testneuron(model)

    if args.model_path:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(args.model_path, map_location='cpu')
        print("Load ckpt from %s" % args.model_path)
        # Use the weights under 'model' or 'module' if present; otherwise
        # treat the whole checkpoint as the state dict.
        checkpoint_model = None
        for model_key in ['model', 'module']:
            if model_key in checkpoint:
                checkpoint_model = checkpoint[model_key]
                print("Load state_dict by model_key = %s" % model_key)
                break
        if checkpoint_model is None:
            checkpoint_model = checkpoint
        utils.load_state_dict(model, checkpoint_model, prefix='')
    model.to(device)

    # These passes run after the weights are loaded and the model is moved.
    trans_utils.replace_nonlinear_by_neuron(model)
    trans_utils.replace_at_by_neuron(model)
    trans_utils.replace_testneuron_by_sfneuron(model, args)

    logfile = "optlogs/eva.txt"
    # Make sure the log directory exists in case evaluate_snn opens the file
    # without creating parent directories.
    Path(logfile).parent.mkdir(parents=True, exist_ok=True)
    results = evaluate_snn(data_loader_val, model, device, args.test_T, args, logfile=logfile)

    val_acc = results["acc"].global_avg
    trial.set_user_attr("val_acc", val_acc)
    trial.set_user_attr("val_energy", results["energy"])

    return val_acc

def main():
    """Run the 50-trial ``lambda`` search and print per-trial and best-trial results."""
    # NOTE: objective() never calls trial.report() / trial.should_prune(), so
    # the MedianPruner has no intermediate values to act on and prunes
    # nothing. It is kept so the study configuration is unchanged.
    study = optuna.create_study(
        direction="maximize",
        pruner=optuna.pruners.MedianPruner(n_startup_trials=10),
    )
    study.optimize(objective, n_trials=50)

    print("\n================ Final Trial Results ================\n")
    for trial in study.trials:
        print(f"Trial {trial.number}:")
        print(f"  lambda       = {trial.params['lambda']:.6f}")
        print(f"  val_acc  = {trial.user_attrs.get('val_acc', 'N/A')}")
        print(f"  val_energy   = {trial.user_attrs.get('val_energy', 'N/A')}")
        print("--------------------------------------------------")

    best = study.best_trial
    print("\n================ Best Trial =========================\n")
    print(f"Best Trial Number: {best.number}")
    # The tuned parameter is ``lambda``; the old label "Best gap" was wrong.
    print(f"Best lambda: {study.best_params['lambda']:.6f}")
    print(f"Best val_acc: {best.user_attrs.get('val_acc', 'N/A')}")
    print(f"Best val_energy: {best.user_attrs.get('val_energy', 'N/A')}")


if __name__ == "__main__":
    main()