import numpy as np
import os
import torch
from scipy.io import savemat,loadmat
import logging
import sys

from funcs import train_ann,  eval_snn, eval_ann
import utils
import sys
sys.path.append("..")
from dst_scheduler import DSTScheduler, create_step_wrapper
import Models 
from arg_parser import get_args
from exp_comparison import find_best_sparsity
from CHT import fuse_bn_recursively
from exp_comparison import grid_best_input

def path(args):
    """Build the checkpoint paths and result sub-directories for one run.

    Args:
        args: Parsed options. Reads ``architecture``, ``dataset``, ``lr``,
            ``bs``, ``l`` and ``save``, plus ``linear_sparsity`` for the MLP
            architecture or ``conv_sparsity`` for every other architecture.

    Returns:
        A 5-tuple ``(input_path, tr_save_name, ft_save_name, ft_model,
        val_model)``:

        * ``input_path``: directory under ``../input`` for this configuration.
        * ``tr_save_name``: relative directory name for the training stage,
          or ``None`` when ``args.save`` is falsy.
        * ``ft_save_name``: ``tr_save_name`` with ``/ft`` appended, or
          ``None`` when ``args.save`` is falsy.
        * ``ft_model``: ``best_model.pth`` path of the training stage.
        * ``val_model``: ``best_model.pth`` path of the last stage (training
          for MLP, finetuning otherwise).
    """
    input_root = os.path.join("../input", args.architecture, args.dataset)
    is_mlp = args.architecture == 'MLP'

    if is_mlp:
        sparse = f"s_{args.linear_sparsity}"
        input_parts = (sparse,)
    else:
        # Non-MLP inputs live under a fixed method directory plus the
        # lr/bs pair reported by grid_best_input for this setup.
        sparse = f"conv_{args.conv_sparsity}/s_0.0"
        input_lr, input_bs = grid_best_input(args.architecture, args.dataset)
        input_parts = (sparse, 'd_0.0/onefc_True', f'lr_{input_lr}/bs_{input_bs}')
    input_path = os.path.join(input_root, *input_parts)

    train_dir = os.path.join(
        f'{args.architecture}/{args.dataset}',
        sparse,
        f'lr_{args.lr}/bs_{args.bs}/L_{args.l}',
    )
    finetune_dir = train_dir + '/ft'

    model_root = "../QCFS/bestmodel"
    ft_model = os.path.join(model_root, train_dir, 'best_model.pth')
    last_stage_dir = train_dir if is_mlp else finetune_dir
    val_model = os.path.join(model_root, last_stage_dir, 'best_model.pth')

    # The model paths are always returned; only the save names are blanked
    # out when saving is disabled.
    if args.save:
        return input_path, train_dir, finetune_dir, ft_model, val_model
    return input_path, None, None, ft_model, val_model


def main(args):
    """Run the train -> (finetune) -> SNN evaluation pipeline for one setup.

    Stages, in order:

    1. Build the model and loaders, run one forward pass, and load the
       pretrained weights from ``<input_path>/best_model.pth``.
    2. Swap max-pooling / activations (and optionally fuse BN), then train.
    3. For every architecture except ``'MLP'``: reload the best training
       checkpoint, switch to per-channel activations and finetune.
    4. Reload the final checkpoint, swap activations for neurons and run
       ``eval_snn``.

    When ``args.save`` is truthy, per-stage curves are written to
    ``../QCFS/results/<save name>/res.mat`` and the SNN accuracy and
    ``LASFR`` values are merged into the last stage's ``res.mat``.

    Args:
        args: Parsed options (see ``get_args``). NOTE: ``args.epochs`` is
            overwritten in place (300 for training, 100 for finetuning) by
            the DEBUG overrides below.

    Raises:
        ValueError: re-raised from ``train_ann`` with the failing stage
            ("During Traning" / "During Finetuning") prefixed to the message.
    """
    # Prepare things.
    device = torch.device(args.device)
    input_path, tr_save_name, ft_save_name, ft_model, val_model = path(args)
    res_path = '../QCFS/results'
    work_dir = '../QCFS/'

    model, num_activations, train_loader, test_loader = Models.prepare_model_and_loader(args)
    # Original note: the dense or the "hanming" model is chosen based on conv sparsity.
    model.to(device)
    # Lazy initialization: a single forward pass triggers register_buffer.
    for img, lab in train_loader:
        img = img.to(device)
        model(img)
        break

    # Load the pretrained model.
    statedict = torch.load(os.path.join(input_path, 'best_model.pth'), weights_only=True, map_location='cpu')
    torch.cuda.empty_cache()
    model.load_state_dict(statedict, strict=True)
    del statedict
    # CHT.set_initilized_true_convcht(model) is no longer needed because of
    # the forward pass above.

    # Step 1: replace ReLU, pooling and BN.
    model = utils.replace_maxpool2d_by_avgpool2d(model)
    model = utils.replace_activation_by_floor(model, args.l, args.activation_mode)
    if args.fold_BN:
        model = fuse_bn_recursively(model)  # TODO: do a small experiment
    model.to(device)

    print('========================================train========================================')
    args.epochs = 300  # DEBUG: overrides whatever epoch count was passed in
    # Order: optimizer (step not yet overridden) -> LR scheduler -> pruner
    # -> override optimizer.step.
    optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    if args.linear_sparsity > 0.0 and not args.one_fc:
        pruner_dict = torch.load(os.path.join(input_path, 'pruner.pth'), weights_only=True, map_location='cpu')
        # Original note: load_state_dict was overridden and adapts to the
        # device automatically.
        torch.cuda.empty_cache()
        T_end = int(args.epochs * 0.75) if (args.adaptive_zeta or args.EM_S) else args.epochs
        # NOTE(review): `pruner` is not referenced again in this function, so
        # any effect comes from DSTScheduler's constructor -- confirm.
        pruner = DSTScheduler(
            model,
            optimizer,
            alpha=args.zeta,
            delta=args.update_interval * len(train_loader),
            sparsity_distribution=args.sparsity_distribution,
            static_topo=False,
            T_end=T_end * len(train_loader),
            ignore_linear_layers=False,
            grad_accumulation_n=1,
            args=args,
            state_dict=pruner_dict,
        )
        del pruner_dict
    else:
        pruner = None

    try:
        tr_acc_val, tr_loss_val, tr_acc_train, tr_loss_train = train_ann(
            train_loader, test_loader, model, optimizer, scheduler, args.epochs, device,
            activation_mode=args.activation_mode, save=tr_save_name, work_directory=work_dir)
    except ValueError as e:
        msg = f'During Traning {e}'
        raise ValueError(msg) from e

    if args.save:
        os.makedirs(os.path.join(res_path, tr_save_name), exist_ok=True)
        savemat(os.path.join(res_path, tr_save_name, 'res.mat'),
                {'tr_acc_val': tr_acc_val, 'tr_loss_val': tr_loss_val,
                 'tr_acc_train': tr_acc_train, 'tr_loss_train': tr_loss_train})

    if args.architecture != 'MLP':
        print('======================================finetune======================================')
        args.epochs = 100  # DEBUG: overrides the epoch count for finetuning
        statedict = torch.load(ft_model, weights_only=True, map_location='cpu')
        torch.cuda.empty_cache()
        model.load_state_dict(statedict, strict=False)
        del statedict

        model = utils.replace_layer_activation_by_channel(model, args.l, args.activation_mode)
        model.to(device)
        # Re-wrap the optimizer. The pruner's step wrapper
        # (create_step_wrapper) is not re-attached for this stage; that call
        # was already disabled in the original code.
        optimizer = torch.optim.SGD(model.parameters(), 0.1 * args.lr, momentum=0.0, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

        try:
            ft_acc_val, ft_loss_val, ft_acc_train, ft_loss_train = train_ann(
                train_loader, test_loader, model, optimizer, scheduler, args.epochs, device,
                activation_mode=args.activation_mode, save=ft_save_name, work_directory=work_dir)
        except ValueError as e:
            msg = f'During Finetuning {e}'
            raise ValueError(msg) from e

        if args.save:
            os.makedirs(os.path.join(res_path, ft_save_name), exist_ok=True)
            savemat(os.path.join(res_path, ft_save_name, 'res.mat'),
                    {'ft_acc_val': ft_acc_val, 'ft_loss_val': ft_loss_val,
                     'ft_acc_train': ft_acc_train, 'ft_loss_train': ft_loss_train})

    print('========================================test========================================')
    #args.t=1 #DEBUG
    statedict = torch.load(val_model, weights_only=True, map_location='cpu')
    torch.cuda.empty_cache()
    model.load_state_dict(statedict, strict=False)
    del statedict

    model.to(device)
    model.eval()

    model = utils.replace_activation_by_neuron(model)
    acc, layer_asfr = eval_snn(test_loader, model, device, args.t, num_activations)

    if args.save:
        final_save_name = ft_save_name if args.architecture != 'MLP' else tr_save_name
        # Read and write the SAME file. The previous code always wrote to the
        # finetune directory, which is never created for the MLP architecture.
        res_file = os.path.join(res_path, final_save_name, 'res.mat')
        res_dict = loadmat(res_file)
        # savemat silently drops non-standard types, including torch.Tensor!
        # np.array is the safest thing to save.
        res_dict['snn_acc'] = acc
        res_dict['LASFR'] = layer_asfr
        savemat(res_file, res_dict)
    print('Done')


if __name__ == "__main__":
    # Script entry point: parse options, seed RNGs, settle the linear-layer
    # sparsity for VGG-16, then run main() with failures logged to a file.
    args = get_args()
    utils.seed_all(args.seed)

    if args.architecture == 'VGG-16':
        if args.conv_sparsity == 0.0:
            # Explicit raise instead of `assert`: assertions are stripped
            # when Python runs with -O, which would skip this validation.
            if args.linear_sparsity != 0.0:
                raise ValueError(
                    "linear_sparsity must be 0.0 when conv_sparsity is 0.0 "
                    f"(got {args.linear_sparsity})"
                )
        elif args.one_fc:
            # Was `args.linear_sparsity==0.0`, a comparison whose result was
            # discarded; the one-FC variant is meant to force zero sparsity.
            args.linear_sparsity = 0.0
        else:
            method_str = f'd_{args.dropout}/onefc_{args.one_fc}'
            input_lr, input_bs = grid_best_input(args.architecture, args.dataset)
            config_str = f'lr_{input_lr}/bs_{input_bs}'
            # NOTE(review): the first lookup's result is overwritten by the
            # second one immediately; only the second value takes effect.
            # Confirm which key is intended and drop the other call.
            args.linear_sparsity = find_best_sparsity(args.dataset, method_str + '/' + config_str)
            args.linear_sparsity = find_best_sparsity(args.dataset, method_str)

    print(args)

    logging.basicConfig(
        filename="../QCFS/error.log",
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
    )

    try:
        main(args)
    except Exception:
        # Top-level boundary: record the traceback together with the full
        # argument set, then exit with a non-zero status.
        logging.exception("exception in main\n%s", args)
        sys.exit(1)

