import os
import sys
import torch
import torch.nn as nn
import logging
import numpy as np
from scipy.io import savemat, loadmat
import random

sys.path.append("../")
import Models
from dst_scheduler import DSTScheduler
from models.utils import En_Decoding2
from models.new_convert_code_2 import SpikeModel,SpikeModule
from exp_comparison import grid_best_input
from tqdm import tqdm
from TrainTest import  Train, Test
sys.path.append("../QCFS")
from QCFS.arg_parser import get_args


def seed_all(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and CUDA), exports
    ``PYTHONHASHSEED`` into the process environment, and switches cuDNN to
    deterministic mode with benchmarking turned off.

    Args:
        seed: Integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Path-building helper; keeps the existing directory-layout convention.
def path(args):
    """Build the ANN checkpoint directory and the relative save name for a run.

    For the ``MLP`` architecture the input directory depends only on
    ``args.linear_sparsity``. For any other architecture it depends on
    ``args.conv_sparsity`` plus the learning rate and batch size returned by
    ``grid_best_input`` for that architecture/dataset pair.

    Args:
        args: Namespace providing ``architecture``, ``dataset``,
            ``linear_sparsity``, ``conv_sparsity``, ``lr``, ``bs``, ``t`` and
            ``save``.

    Returns:
        A tuple ``(input_path, tr_save_name)``. ``tr_save_name`` is ``None``
        when ``args.save`` is falsy.
    """
    arch, dataset = args.architecture, args.dataset
    components = ["./snn_conversion/input", arch, dataset]

    if arch == 'MLP':
        sparse = f"s_{args.linear_sparsity}"
        components.append(sparse)
    else:
        sparse = f"conv_{args.conv_sparsity}/s_0.0"
        best_lr, best_bs = grid_best_input(arch, dataset)
        components.extend([sparse, 'd_0.0/onefc_True', f'lr_{best_lr}/bs_{best_bs}'])

    input_path = os.path.join(*components)

    run_suffix = f"lr_{args.lr}/bs_{args.bs}/t_{args.t}/"
    tr_save_name = os.path.join(f"{arch}/{dataset}", sparse, run_suffix)

    return input_path, (tr_save_name if args.save else None)

def main(args):
    """Convert a trained ANN to an SNN, fine-tune it, and evaluate it.

    Steps performed:
      1. Build the model and data loaders via ``Models.prepare_model_and_loader``
         and load the ANN weights from ``<input_path>/best_model.pth``.
      2. When linear sparsity is enabled (and ``args.one_fc`` is off), rebuild
         the ``DSTScheduler`` from ``<input_path>/pruner.pth``.
      3. Freeze all ANN parameters, wrap the ANN in ``SpikeModel``, and
         re-enable gradients for ``BatchNorm2d`` parameters only.
      4. Train with SGD + cosine annealing through ``Train``.
      5. Reload the best checkpoint (when a save name exists), run ``Test``,
         and, if ``args.save`` is set, write the metrics to ``res.mat``.

    Args:
        args: Parsed command-line namespace (see ``get_args``).
    """
    device = torch.device(args.device)
    input_path, save_name = path(args)
    res_path = "./snn_conversion/AEC/results"
    work_dir = "./snn_conversion/AEC/bestmodel"

    # Data + model.
    source_ann, _num_activations, train_loader, test_loader = Models.prepare_model_and_loader(args)

    # Load the trained ANN weights.
    statedict = torch.load(os.path.join(input_path, "best_model.pth"), map_location="cpu", weights_only=True)
    torch.cuda.empty_cache()
    source_ann.load_state_dict(statedict, strict=True)
    del statedict
    source_ann.to(device)

    if args.linear_sparsity > 0.0 and not args.one_fc:
        pruner_dict = torch.load(os.path.join(input_path, 'pruner.pth'), weights_only=True, map_location='cpu')
        # Original author's note: DSTScheduler overrides load_state_dict and
        # adapts to the device automatically.
        torch.cuda.empty_cache()
        T_end = int(args.epochs * 0.75) if (args.adaptive_zeta or args.EM_S) else args.epochs
        # The training optimizer is created further down (it needs the SNN's
        # parameters), so referencing it here raised UnboundLocalError.
        # Give the scheduler its own optimizer over the ANN parameters instead.
        # NOTE(review): assumes DSTScheduler only needs an optimizer instance
        # to restore its state here — confirm against dst_scheduler.
        mask_optimizer = torch.optim.SGD(source_ann.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        # The scheduler is not used again below; the name is kept so the
        # object stays alive for the rest of this function, as before.
        pruner = DSTScheduler(
            source_ann,
            mask_optimizer,
            alpha=args.zeta,
            delta=args.update_interval * len(train_loader),
            sparsity_distribution=args.sparsity_distribution,
            static_topo=False,
            T_end=T_end * len(train_loader),
            ignore_linear_layers=False,
            grad_accumulation_n=1,
            args=args,
            state_dict=pruner_dict,
        )
        del pruner_dict

    # Freeze every ANN parameter before wrapping the model.
    for param in source_ann.parameters():
        param.requires_grad = False

    snn = SpikeModel(model=source_ann, sim_length=args.t, dataset=args.dataset)

    # Only BatchNorm2d parameters are left trainable.
    for module in snn.modules():
        if isinstance(module, nn.BatchNorm2d):
            for param in module.parameters():
                param.requires_grad = True
    snn.to(device)

    criterion = nn.CrossEntropyLoss().to(device)

    optimizer = torch.optim.SGD(snn.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=args.epochs)

    tr_acc_val, tr_loss_val, tr_acc_train, tr_loss_train = Train(
        train_loader, test_loader, snn, optimizer, scheduler, criterion,
        args.epochs, device, save_name, work_directory=work_dir,
    )

    if args.save:
        os.makedirs(os.path.join(res_path, save_name), exist_ok=True)
        savemat(
            os.path.join(res_path, save_name, "res.mat"),
            {
                "acc_val": tr_acc_val,
                "loss_val": tr_loss_val,
                "acc_train": tr_acc_train,
                "loss_train": tr_loss_train,
            }
        )

    # Testing should use the best SNN checkpoint. When saving is disabled
    # `save_name` is None, so there is no checkpoint path to build
    # (os.path.join would raise TypeError); evaluate the in-memory model then.
    if save_name is not None:
        statedict = torch.load(os.path.join(work_dir, save_name, "best_model.pth"), map_location="cpu", weights_only=True)
        torch.cuda.empty_cache()
        snn.load_state_dict(statedict, strict=True)
        del statedict
    snn.to(device)  # requires_grad flags do not matter at test time.
    acc, loss, LASFR = Test(test_loader, snn, criterion, device, True)
    print('tested')

    # Append the test metrics to the saved training results.
    if args.save:
        res = dict(loadmat(os.path.join(res_path, save_name, "res.mat")))
        res['LASFR'] = LASFR
        res['snn_acc'] = acc
        savemat(os.path.join(res_path, save_name, "res.mat"), res)

    print("Done")


if __name__ == "__main__":
    # Script entry point: parse arguments, seed the RNGs, reconcile the
    # sparsity options, then run main() with failures logged to ./error.log.
    args = get_args()
    seed_all(args.seed)

    # Reconcile the sparsity settings for VGG-16.
    if args.architecture == 'VGG-16':
        if args.conv_sparsity == 0.0:
            # Explicit raise instead of assert so the check survives `python -O`.
            if args.linear_sparsity != 0.0:
                raise ValueError(
                    "linear_sparsity must be 0.0 when conv_sparsity is 0.0 for VGG-16"
                )
        elif args.one_fc:
            # The original line used `==`, a comparison with no effect;
            # the intent is to force the linear sparsity to zero.
            args.linear_sparsity = 0.0

    print(args)

    logging.basicConfig(
        filename="./error.log",
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
    )

    try:
        main(args)
    except Exception:
        # Top-level boundary: record the traceback and the arguments, then
        # exit with a non-zero status.
        logging.exception("exception in main\n%s", args)
        sys.exit(1)
