import optuna
from torch.utils.data import Subset
import random

import logging
from collections import defaultdict
import torch
import torch.nn as nn
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import LazyConfig, instantiate
from detectron2.engine import (
    AMPTrainer,
    SimpleTrainer,
    default_argument_parser,
    default_setup,
    default_writers,
    hooks,
    launch,
)
from detectron2.engine.defaults import create_ddp_model
from detectron2.evaluation import inference_on_dataset, inference_on_dataset_for_v, inference_on_dataset_snn, print_csv_format
from detectron2.utils import comm
from detectron2.modeling.backbone import trans_utils

# Module-level logger; it is looked up under the fixed "detectron2" name
# rather than this module's own __name__.
logger = logging.getLogger("detectron2")

def get_partial_dataset(dataset, ratio=0.02, seed=None):
    """Return a random ``Subset`` containing a fraction of ``dataset``.

    Args:
        dataset: Object supporting ``len()``; it is wrapped, not copied.
        ratio: Fraction of the dataset to keep. The sample size is
            ``int(len(dataset) * ratio)``, i.e. truncated toward zero.
        seed: Optional seed for a private random generator. With a seed the
            same indices are drawn on every call, so repeated evaluations
            can be compared on the same subset. With ``None`` (default) the
            global ``random`` module state is used, as before.

    Returns:
        A ``torch.utils.data.Subset`` over distinct indices in random order.

    Raises:
        ValueError: If ``ratio`` is outside ``[0, 1]``.
    """
    if not 0 <= ratio <= 1:
        raise ValueError(f"ratio must be within [0, 1], got {ratio!r}")

    total_len = len(dataset)
    sample_size = int(total_len * ratio)
    # A dedicated Random instance keeps seeding local: it neither reads nor
    # disturbs the global generator used elsewhere.
    rng = random if seed is None else random.Random(seed)
    indices = rng.sample(range(total_len), sample_size)
    return Subset(dataset, indices)
    
def load_single_threshold(name, module, state_dict):
    """Copy the ``scale_p`` / ``scale_n`` values for one module from a checkpoint.

    The checkpoint entries ``"<name>.scale_p"`` and ``"<name>.scale_n"`` are
    expanded to the shapes of ``module.scale_p`` / ``module.scale_n`` and
    copied in place. Loading is best-effort: problems are logged, never raised.

    Args:
        name: Prefix used to build the two checkpoint keys.
        module: Object exposing ``scale_p`` and ``scale_n`` tensors.
        state_dict: Mapping from key to tensor.

    Returns:
        None. ``module`` is updated in place only when both values load.
    """
    scale_p_key = f"{name}.scale_p"
    scale_n_key = f"{name}.scale_n"

    if scale_p_key not in state_dict or scale_n_key not in state_dict:
        logger.info(f"[Skip] Missing keys in checkpoint for {name}: {scale_p_key}, {scale_n_key}")
        return

    shape_p = module.scale_p.shape
    shape_n = module.scale_n.shape
    scale_p_val = state_dict[scale_p_key]
    scale_n_val = state_dict[scale_n_key]

    try:
        # Expand BOTH values before writing either one, so an incompatible
        # scale_n cannot leave the module with only scale_p overwritten.
        expanded_p = scale_p_val.expand(shape_p)
        expanded_n = scale_n_val.expand(shape_n)
        module.scale_p.data.copy_(expanded_p)
        module.scale_n.data.copy_(expanded_n)
    except Exception as e:
        # Deliberately broad: a bad entry must not abort loading the others.
        logger.warning(f"[Load Failed] {name} scale parameters shape mismatch: {e}")
    else:
        logger.info(f"[Load OK] {scale_p_key}: {scale_p_val.shape} -> {shape_p}")
        logger.info(f"[Load OK] {scale_n_key}: {scale_n_val.shape} -> {shape_n}")

def load_testneuron_thresholds(model: nn.Module, state_dict: dict):
    """Load checkpointed threshold scales into every ``TestNeuron`` of ``model``.

    Walks ``model.named_modules()`` and hands each ``trans_utils.TestNeuron``
    instance, together with its qualified name, to ``load_single_threshold``.
    Modules of any other type are left untouched.

    Args:
        model: Model whose sub-modules are scanned.
        state_dict: Checkpoint mapping passed through unchanged.
    """
    for qualified_name, submodule in model.named_modules():
        if not isinstance(submodule, trans_utils.TestNeuron):
            continue
        load_single_threshold(qualified_name, submodule, state_dict)

def do_test_snn(cfg, model, args):
    """Convert ``model`` with the ``trans_utils`` passes and evaluate it on a test sample.

    Steps, in order: fill in the fixed conversion settings on ``args``, load
    the threshold checkpoint, run the ``trans_utils`` replacement passes on
    ``model`` (in place), build a test loader over a random 20% of the test
    dataset, and run ``inference_on_dataset_snn``.

    Args:
        cfg: Lazy config. ``cfg.dataloader.test`` and
            ``cfg.dataloader.evaluator`` are read, and both are mutated (the
            test dataset is replaced by the sampled subset, the evaluator's
            ``dataset_name`` is overwritten).
        model: Model to convert and evaluate; modified in place.
        args: Namespace that must already carry a ``lambda`` attribute. The
            remaining conversion settings are overwritten here.

    Returns:
        Whatever ``inference_on_dataset_snn`` returns, or ``None`` when the
        config has no evaluator (nothing is converted or evaluated then).

    Raises:
        AttributeError: If ``args`` has no ``lambda`` attribute.
    """
    # `lambda` is a Python keyword, so the attribute is only reachable via
    # hasattr/getattr/setattr. Raise explicitly rather than assert so the
    # check is not stripped under `python -O`.
    if not hasattr(args, "lambda"):
        raise AttributeError("args must define a 'lambda' attribute before calling do_test_snn")

    # Check for the evaluator BEFORE touching cfg.dataloader.evaluator; the
    # membership test is pointless once the attribute has been dereferenced.
    if "evaluator" not in cfg.dataloader:
        logger.warning("cfg.dataloader has no evaluator; skipping SNN evaluation")
        return None

    # Fixed conversion settings; they are handed to the trans_utils passes and
    # to inference_on_dataset_snn through `args`.
    args.model = 'eva'
    args.linear_num = 8
    args.qkv_num = 8
    args.softmax_num = 8
    args.softmax_p = 0.92/263
    args.T = 1
    args.monitor = False
    args.eval_interval = 20

    checkpoint_path = "/root/autodl-tmp/models/threshold/eva_coco_det_threshold.pth"
    # NOTE(review): torch.load unpickles the file, so only point this at a
    # checkpoint from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Order matters: TestNeuron modules must exist before their thresholds are
    # loaded, and are swapped for the final neurons last.
    trans_utils.replace_test_by_testneuron(model, percent=0.99)
    load_testneuron_thresholds(model, checkpoint)
    trans_utils.replace_nonlinear_by_neuron(model)
    trans_utils.replace_at_by_neuron(model)
    trans_utils.replace_testneuron_by_hmtneuron(model, args)

    # Evaluate on a random 20% sample of the test set.
    test_loader_cfg = cfg.dataloader.test
    test_dataset = instantiate(test_loader_cfg.dataset)
    test_dataset = get_partial_dataset(test_dataset, ratio=0.2)
    test_loader_cfg.dataset = test_dataset
    test_loader = instantiate(test_loader_cfg)

    cfg.dataloader.evaluator.dataset_name = "coco_2017_val"
    ret = inference_on_dataset_snn(
        model, test_loader, instantiate(cfg.dataloader.evaluator), args
    )
    print_csv_format(ret)
    return ret

def main(args):
    """Load the config and, in eval-only mode, run the SNN evaluation once.

    The config named by ``args.config_file`` is loaded, ``args.opts`` are
    applied as overrides, and ``default_setup`` is called. Nothing further
    happens unless ``args.eval_only`` is set.

    Note: ``do_test_snn`` requires ``args`` to carry a ``lambda`` attribute,
    which this function does not set itself.

    Args:
        args: Parsed command-line namespace.
    """
    cfg = LazyConfig.apply_overrides(LazyConfig.load(args.config_file), args.opts)
    default_setup(cfg, args)

    if not args.eval_only:
        return

    model = instantiate(cfg.model)
    model.to(cfg.train.device)
    model = create_ddp_model(model)
    DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
    # Other evaluation entry points, currently disabled:
    # print(do_test_ann(cfg, model))
    # print(do_test_for_v(cfg, model))
    snn_results = do_test_snn(cfg, model, args)
    print(snn_results)

def objective(trial):
    """Optuna objective: evaluate the converted model for one sampled ``lambda``.

    Parses the command line, loads the config, samples ``lambda`` from
    ``[1e-4, 1.0]``, builds the model from ``cfg.train.init_checkpoint`` and
    runs ``do_test_snn``. ``AP`` and ``AP50`` from the ``bbox`` results are
    stored as user attributes on the trial.

    Args:
        trial: The Optuna trial supplying the ``lambda`` suggestion.

    Returns:
        The bbox ``AP`` as a float, or ``None`` when ``--eval-only`` was not
        given (no evaluation is run in that case).
    """
    args = default_argument_parser().parse_args()
    cfg = LazyConfig.load(args.config_file)
    cfg = LazyConfig.apply_overrides(cfg, args.opts)
    default_setup(cfg, args)

    # `lambda` is a reserved word, so `args.lambda = ...` is a SyntaxError
    # that prevents the whole file from compiling. setattr stores the value
    # under the same name, which do_test_snn checks with hasattr.
    setattr(args, "lambda", trial.suggest_float("lambda", 1e-4, 1.0))

    if not args.eval_only:
        # Nothing to evaluate; the trial gets no score and no AP attributes.
        return None

    model = instantiate(cfg.model)
    model.to(cfg.train.device)
    model = create_ddp_model(model)
    DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
    results = do_test_snn(cfg, model, args)

    score_ap = float(results["bbox"]["AP"])
    score_ap50 = float(results["bbox"]["AP50"])

    trial.set_user_attr("AP", score_ap)
    trial.set_user_attr("AP50", score_ap50)

    return score_ap

if __name__ == "__main__":
    # Search for the `lambda` value that maximizes bbox AP over 50 trials.
    study = optuna.create_study(direction="maximize")  # maximize AP
    study.optimize(objective, n_trials=50)

    for trial in study.trials:
        attrs = trial.user_attrs
        # Trials that did not reach the evaluation step have no AP/AP50
        # attributes; report their state instead of failing with a KeyError.
        if "lambda" not in trial.params or "AP" not in attrs or "AP50" not in attrs:
            print(f"Trial {trial.number}: state={trial.state.name}, no AP recorded")
            continue
        print(f"Trial {trial.number}: lambda={trial.params['lambda']:.4f}, AP={attrs['AP']:.4f}, AP50={attrs['AP50']:.4f}")

    # study.best_* raise ValueError when no trial completed, so check first.
    has_completed = any(
        trial.state == optuna.trial.TrialState.COMPLETE for trial in study.trials
    )
    if has_completed:
        print("Best lambda:", study.best_params["lambda"])
        print("Best AP:", study.best_value)

        best_trial = study.best_trial
        print("AP50 at best:", best_trial.user_attrs["AP50"])
    else:
        print("No trial completed; no best lambda to report.")