import logging
from collections import defaultdict
import torch
import torch.nn as nn
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import LazyConfig, instantiate
from detectron2.engine import (
    AMPTrainer,
    SimpleTrainer,
    default_argument_parser,
    default_setup,
    default_writers,
    hooks,
    launch,
)
from detectron2.engine.defaults import create_ddp_model
from detectron2.evaluation import inference_on_dataset, inference_on_dataset_for_v, inference_on_dataset_snn, print_csv_format
from detectron2.utils import comm
from detectron2.modeling.backbone import trans_utils

# Module logger; shares the "detectron2" logger name (presumably configured by
# detectron2's default_setup in main) so these messages use the same handlers.
logger = logging.getLogger("detectron2")

def count_params_by_module(model):
    """Sum trainable parameter counts grouped by parameter-name prefix.

    The grouping key is the first three dot-separated components of each name
    from ``model.named_parameters()``. For example, ``a.b.c.weight`` and
    ``a.b.c.bias`` both count toward ``a.b.c``. Parameters with
    ``requires_grad=False`` are ignored.

    Returns:
        ``defaultdict(int)`` mapping prefix -> number of elements.
    """
    totals = defaultdict(int)
    trainable = (
        (param_name, tensor)
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    )
    for param_name, tensor in trainable:
        prefix = ".".join(param_name.split(".")[:3])
        totals[prefix] += tensor.numel()
    return totals

def model_info(module, file_path, indent=0):
    """Write the child-module tree of ``module`` to ``file_path``.

    Each descendant is written on its own line as ``<name>: <ClassName>``, and
    every nesting level adds two spaces of indentation. The root module itself
    is not written, only its descendants. The file is overwritten.

    Args:
        module: root ``nn.Module`` whose ``named_children()`` are walked recursively.
        file_path: destination text file (opened in ``"w"`` mode, UTF-8).
        indent: number of leading spaces prepended to every line (default 0,
            which produces the unindented top level).
    """

    def _write(f, parent, prefix):
        # Depth-first walk: a child is written before its own children.
        for name, child in parent.named_children():
            f.write(f"{prefix}{name}: {child.__class__.__name__}\n")
            _write(f, child, prefix + "  ")

    with open(file_path, "w", encoding="utf-8") as f:
        _write(f, module, " " * indent)

    logger.info(f"Save to: {file_path}")
    
def load_single_threshold(name, module, state_dict):
    """Copy the ``scale_p`` / ``scale_n`` thresholds for one module from ``state_dict``.

    The keys ``f"{name}.scale_p"`` and ``f"{name}.scale_n"`` are looked up.
    Each checkpoint tensor is broadcast with ``Tensor.expand`` to the shape of
    the matching tensor on ``module`` and copied in place. Both values are
    broadcast before either is written, so a mismatch in one of them leaves the
    module untouched instead of half-updated.

    Args:
        name: qualified module name as yielded by ``model.named_modules()``.
        module: module exposing ``scale_p`` and ``scale_n`` tensor attributes.
        state_dict: mapping from key to tensor (e.g. the result of ``torch.load``).
    """
    scale_p_key = f"{name}.scale_p"
    scale_n_key = f"{name}.scale_n"

    if scale_p_key not in state_dict or scale_n_key not in state_dict:
        logger.warning(f"[Skip] Missing keys in checkpoint for {name}: {scale_p_key}, {scale_n_key}")
        return

    shape_p = module.scale_p.shape
    shape_n = module.scale_n.shape
    scale_p_val = state_dict[scale_p_key]
    scale_n_val = state_dict[scale_n_key]

    try:
        # Validate both broadcasts up front; nothing is written if either fails.
        new_p = scale_p_val.expand(shape_p)
        new_n = scale_n_val.expand(shape_n)
    except (RuntimeError, AttributeError, TypeError) as e:
        # RuntimeError: incompatible shapes; AttributeError/TypeError: entry is not a tensor.
        logger.warning(f"[Load Failed] {name} scale parameters shape mismatch: {e}")
        return

    # ``.data`` bypasses autograd so in-place copy into a leaf Parameter is allowed.
    module.scale_p.data.copy_(new_p)
    module.scale_n.data.copy_(new_n)
    logger.info(f"[Load OK] {scale_p_key}: {scale_p_val.shape} -> {shape_p}")
    logger.info(f"[Load OK] {scale_n_key}: {scale_n_val.shape} -> {shape_n}")

def load_testneuron_thresholds(model: nn.Module, state_dict: dict):
    """Load saved thresholds into every ``trans_utils.TestNeuron`` inside ``model``.

    Each matching submodule is handed to ``load_single_threshold`` together with
    its qualified name from ``model.named_modules()``. Modules of any other
    type are left alone.
    """
    neurons = (
        (qualname, submodule)
        for qualname, submodule in model.named_modules()
        if isinstance(submodule, trans_utils.TestNeuron)
    )
    for qualname, submodule in neurons:
        load_single_threshold(qualname, submodule, state_dict)

def do_test_ann(cfg, model):
    """Evaluate ``model`` as a plain ANN after dumping its structure to ``model.txt``.

    ``model.txt`` contains the module tree (written by ``model_info``), followed
    by one row per parameter-name prefix giving the trainable parameter count,
    largest first. Evaluation runs via ``inference_on_dataset`` only when
    ``cfg.dataloader`` defines an ``evaluator``.

    Returns:
        The evaluation results (also printed with ``print_csv_format``), or
        ``None`` when no evaluator is configured.
    """
    model_info(model, "model.txt")
    param_count = count_params_by_module(model)
    with open("model.txt", "a", encoding="utf-8") as f:
        for module_key, count in sorted(param_count.items(), key=lambda kv: -kv[1]):
            # Newline-terminated so each prefix lands on its own row.
            f.write(f"{module_key:<40} {count:>10}  ({count / 1e6:.3f}M)\n")

    if "evaluator" not in cfg.dataloader:
        return None

    ret = inference_on_dataset(
        model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
    )
    #trans_utils.final_flush(model)
    print_csv_format(ret)
    return ret

def do_test_for_v(
    cfg,
    model,
    save_path="/root/autodl-tmp/models/threshold/eva_coco_det_threshold.pth",
    percent=0.99,
):
    """Insert ``TestNeuron`` modules and run the threshold-collection evaluation.

    ``trans_utils.replace_test_by_testneuron`` is applied to ``model`` (its
    return value is ignored, so the replacement is presumably in place). The
    module tree is dumped to ``model.txt``, and then
    ``inference_on_dataset_for_v`` runs with ``save_path``.

    Args:
        cfg: LazyConfig with ``dataloader.test`` and optionally ``dataloader.evaluator``.
        model: model to instrument.
        save_path: output path forwarded to ``inference_on_dataset_for_v``
            (default is the original hard-coded location).
        percent: forwarded to ``replace_test_by_testneuron`` (default 0.99).

    Returns:
        Evaluation results, or ``None`` when no evaluator is configured.
    """
    trans_utils.replace_test_by_testneuron(model, percent=percent)
    model_info(model, "model.txt")

    if "evaluator" not in cfg.dataloader:
        return None

    ret = inference_on_dataset_for_v(
        model,
        instantiate(cfg.dataloader.test),
        instantiate(cfg.dataloader.evaluator),
        save_path=save_path,
    )
    print_csv_format(ret)
    return ret

def do_test_snn(
    cfg,
    model,
    args,
    checkpoint_path="/root/autodl-tmp/models/threshold/eva_coco_det_threshold.pth",
):
    """Convert ``model`` to its spiking form and evaluate it.

    Steps:
      1. Overwrite the conversion hyper-parameters on ``args``.
      2. Insert ``TestNeuron`` modules and load their ``scale_p``/``scale_n``
         thresholds from ``checkpoint_path``.
      3. Apply the remaining ``trans_utils.replace_*`` passes. Their return
         values are ignored, so ``model`` is presumably modified in place.
      4. Dump the module tree to ``model.txt`` and run ``inference_on_dataset_snn``.

    Args:
        cfg: LazyConfig with ``dataloader.test`` and optionally ``dataloader.evaluator``.
        model: model to convert.
        args: argparse namespace; the fields set below override any CLI values.
        checkpoint_path: file whose ``torch.load`` result maps
            ``"<module>.scale_p"`` / ``"<module>.scale_n"`` to tensors.

    Returns:
        Evaluation results, or ``None`` when no evaluator is configured.
    """
    # Conversion settings; presumably read by trans_utils and
    # inference_on_dataset_snn, which both receive ``args``.
    args.model = 'eva'
    args.linear_num = 8
    args.qkv_num = 8
    args.softmax_num = 8
    args.softmax_p = 0.95 / 263
    # ``lambda`` is a Python keyword, so ``args.lambda = ...`` does not parse.
    # setattr stores the attribute under that exact name; read it back with
    # ``getattr(args, "lambda")``.
    setattr(args, "lambda", 0.3498856)
    args.T = 1
    args.monitor = True
    args.eval_interval = 20

    # torch.load unpickles the file, so only point it at trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    trans_utils.replace_test_by_testneuron(model, percent=0.99)
    load_testneuron_thresholds(model, checkpoint)
    trans_utils.replace_nonlinear_by_neuron(model)
    trans_utils.replace_at_by_neuron(model)
    trans_utils.replace_testneuron_by_hmtneuron(model, args)

    model_info(model, "model.txt")

    if "evaluator" not in cfg.dataloader:
        return None

    ret = inference_on_dataset_snn(
        model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator), args
    )
    print_csv_format(ret)
    return ret

def main(args):
    """Load the LazyConfig, then (with ``--eval-only``) build, load and evaluate the model.

    Without ``args.eval_only`` only the config is loaded and ``default_setup``
    runs; no model is built. The active evaluation path is ``do_test_snn``.
    ``do_test_ann`` and ``do_test_for_v`` are alternative entry points.
    """
    config = LazyConfig.apply_overrides(LazyConfig.load(args.config_file), args.opts)
    default_setup(config, args)

    if not args.eval_only:
        return

    net = instantiate(config.model)
    net.to(config.train.device)
    net = create_ddp_model(net)
    DetectionCheckpointer(net).load(config.train.init_checkpoint)
    #print(do_test_ann(config, net))
    #print(do_test_for_v(config, net))
    outcome = do_test_snn(config, net, args)
    print(outcome)

if __name__ == "__main__":
    # Parse the standard detectron2 CLI, then start ``main`` in each worker
    # process via detectron2's ``launch``.
    cli_args = default_argument_parser().parse_args()
    launch_kwargs = {
        "num_machines": cli_args.num_machines,
        "machine_rank": cli_args.machine_rank,
        "dist_url": cli_args.dist_url,
        "args": (cli_args,),
    }
    launch(main, cli_args.num_gpus, **launch_kwargs)