import torch.distributed.launch
import argparse
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from pathlib import Path
from timm.models import create_model
from datasets import build_dataset
from engine_for_finetuning import evaluate, evaluate_snn
import utils
import trans_utils
import model_eva
import model_vit

def load_single_threshold(name, module, state_dict):
    """Copy the ``scale_p`` / ``scale_n`` thresholds for one module from a checkpoint.

    The checkpoint entries are looked up as ``"{name}.scale_p"`` and
    ``"{name}.scale_n"``. Each value is expanded (broadcast) to the shape of
    the corresponding attribute on ``module`` and copied into it in place.

    The load is best-effort: missing keys or values that cannot be expanded to
    the target shape are logged and the module is left untouched.

    Args:
        name: Module name used as the key prefix in ``state_dict``.
        module: Object exposing ``scale_p`` and ``scale_n`` tensors.
        state_dict: Mapping from parameter names to tensors.
    """
    # The file defines no module-level logger, so obtain one here; a
    # function-scope import keeps this block self-contained.
    import logging
    logger = logging.getLogger(__name__)

    scale_p_key = f"{name}.scale_p"
    scale_n_key = f"{name}.scale_n"

    if scale_p_key not in state_dict or scale_n_key not in state_dict:
        logger.info(f"[Skip] Missing keys in checkpoint for {name}: {scale_p_key}, {scale_n_key}")
        return

    shape_p = module.scale_p.shape
    shape_n = module.scale_n.shape

    scale_p_val = state_dict[scale_p_key]
    scale_n_val = state_dict[scale_n_key]

    try:
        # Expand both values before copying either one, so a mismatch on
        # scale_n cannot leave the module with only scale_p updated.
        expanded_p = scale_p_val.expand(shape_p)
        expanded_n = scale_n_val.expand(shape_n)
        module.scale_p.data.copy_(expanded_p)
        module.scale_n.data.copy_(expanded_n)
    except (RuntimeError, TypeError, AttributeError) as e:
        # RuntimeError: shapes cannot be broadcast; TypeError/AttributeError:
        # the checkpoint entry is not a tensor. Loading stays best-effort.
        logger.warning(f"[Load Failed] {name} scale parameters shape mismatch: {e}")
        return

    logger.info(f"[Load OK] {scale_p_key}: {scale_p_val.shape} -> {shape_p}")
    logger.info(f"[Load OK] {scale_n_key}: {scale_n_val.shape} -> {shape_n}")

def load_testneuron_thresholds(model: nn.Module, state_dict: dict):
    """Load checkpoint thresholds into every ``TestNeuron`` inside ``model``.

    Iterates ``model.named_modules()`` and, for each submodule that is a
    ``trans_utils.TestNeuron``, passes its qualified name, the submodule and
    ``state_dict`` to ``load_single_threshold``.

    Args:
        model: Network whose submodules are scanned.
        state_dict: Checkpoint mapping forwarded unchanged to the loader.
    """
    neuron_layers = (
        (qualified_name, layer)
        for qualified_name, layer in model.named_modules()
        if isinstance(layer, trans_utils.TestNeuron)
    )
    for qualified_name, layer in neuron_layers:
        load_single_threshold(qualified_name, layer, state_dict)
            
def _str2bool(value):
    """Convert a command-line token such as ``true`` / ``False`` / ``0`` to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string as ``True``,
    so ``--monitor False`` would silently enable the flag. This converter
    parses the text instead.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    # The empty string is kept as False because bool('') was False before.
    if normalized in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace: The parsed options. The ``--lambda`` option is
        stored under the attribute name ``lambda``, which is a Python keyword,
        so it has to be read with ``getattr(args, 'lambda')``.
    """
    parser = argparse.ArgumentParser()
    # Model parameters
    parser.add_argument('--model', default='eva_g_patch14', type=str, metavar='MODEL',help='Name of model to train')
    parser.add_argument('--input_size', default=224, type=int,help='images input size')
    parser.add_argument('--nb_classes', default=1000, type=int,help='number of the classification types')
    parser.add_argument('--model_path', default='')
    parser.add_argument('--percent', default=0.97, type=float)
    # Parsed with _str2bool so that "--monitor False" really yields False.
    parser.add_argument('--monitor', default=True, type=_str2bool)

    # Dataset parameters
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--eval_data_path', default='../datasets/val', type=str,help='dataset path for evaluation')
    parser.add_argument('--imagenet_default_mean_and_std', default=False, action='store_true')
    parser.add_argument('--data_set', default='image_folder', choices=['CIFAR10','CIFAR100', 'IMNET', 'image_folder'],type=str, help='dataset type')
    parser.add_argument('--output_dir', default='../models/threshold',help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--savename', default='test', type=str)

    # Mode parameter
    parser.add_argument('--test_mode', default='ann',choices=['ann', 'for_v', 'snn'], help="test mode")
    parser.add_argument('--test_T', default = 8, type=int)

    # SFN parameter
    parser.add_argument('--linear_num', default = 8, type=int)
    parser.add_argument('--qkv_num', default = 8, type=int)
    parser.add_argument('--softmax_num', default = 8, type=int)
    parser.add_argument('--softmax_p', default = 0.93 / 263, type=float)
    parser.add_argument('--lambda', default = 0.282360, type=float)

    return parser.parse_args()

def main(args):
    """Build the validation loader and the model, then run one evaluation pass.

    Reads ``device``, ``model``, ``input_size``, ``model_path``, ``batch_size``,
    ``num_workers``, ``test_mode``, ``percent`` and ``test_T`` from ``args``.
    Sets ``args.distributed`` to ``False`` and overwrites ``args.nb_classes``
    with the class count returned by ``build_dataset``.

    ``args.test_mode`` selects the path:
      * ``'ann'``   - evaluate the model as built, then call ``trans_utils.final_flush``.
      * ``'for_v'`` - after the checkpoint is loaded, call
        ``replace_test_by_testneuron(model, args.percent)`` and evaluate.
      * ``'snn'``   - call ``replace_test_by_testneuron(model)`` before the
        checkpoint is loaded, apply the neuron replacements afterwards, and
        run ``evaluate_snn`` with ``args.test_T``.
    """
    args.distributed = False
    print(str(args).replace(', ', ',\n'))
    device = torch.device(args.device)
    cudnn.benchmark = True

    dataset_val, args.nb_classes = build_dataset(is_train=False, args=args)
    sampler_val = torch.utils.data.DistributedSampler(
        dataset_val,
        num_replicas=utils.get_world_size(),
        rank=utils.get_rank(),
        shuffle=True,
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=int(args.batch_size),
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )

    model = create_model(
        args.model,
        pretrained=False,
        img_size=args.input_size,
        num_classes=args.nb_classes,
    )

    # Only the 'snn' path alters the model before the checkpoint is loaded.
    if args.test_mode == 'snn':
        trans_utils.replace_test_by_testneuron(model)

    if args.model_path:
        # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
        checkpoint = torch.load(args.model_path, map_location='cpu')
        print("Load ckpt from %s" % args.model_path)

        # Use the first of 'model' / 'module' present in the checkpoint;
        # otherwise treat the checkpoint itself as the state dict.
        matched_key = next((key for key in ('model', 'module') if key in checkpoint), None)
        checkpoint_model = checkpoint
        if matched_key is not None:
            nested_state = checkpoint[matched_key]
            print("Load state_dict by model_key = %s" % matched_key)
            if nested_state is not None:
                checkpoint_model = nested_state
        utils.load_state_dict(model, checkpoint_model, prefix='')
    model.to(device)

    trainable_count = sum(
        param.numel() for param in model.parameters() if param.requires_grad
    )
    print('number of params:', trainable_count)
    print("Batch size = %d" % args.batch_size)

    # NOTE(review): nothing here creates the 'logs' directory - confirm the
    # evaluation helpers handle a missing directory.
    logfile = 'logs/log_snn.txt'
    mode = args.test_mode
    if mode == 'ann':
        test_stats = evaluate(
            data_loader_val, model, device,
            args=args, model_without_ddp=model, logfile=logfile,
        )
        trans_utils.final_flush(model)
    elif mode == 'for_v':
        trans_utils.replace_test_by_testneuron(model, args.percent)
        test_stats = evaluate(
            data_loader_val, model, device,
            args=args, model_without_ddp=model, logfile=logfile,
        )
    elif mode == 'snn':
        trans_utils.replace_nonlinear_by_neuron(model)
        trans_utils.replace_at_by_neuron(model)
        trans_utils.replace_testneuron_by_sfneuron(model, args)
        test_stats = evaluate_snn(
            data_loader_val, model, device, args.test_T, args, logfile=logfile,
        )
        
if __name__ == '__main__':
    # Script entry point: parse the CLI, prepare the output directory, evaluate.
    cli_args = get_args()
    output_dir = cli_args.output_dir
    # An empty --output_dir is falsy, in which case no directory is created.
    if output_dir:
        Path(output_dir).mkdir(parents=True, exist_ok=True)
    main(cli_args)