import argparse
from re import M
from tkinter.font import names
from typing import Iterable
from matplotlib.font_manager import weight_dict
import scipy
import yaml
import torch
import torch.nn as nn
from spikingjelly.clock_driven import functional
from timm.utils import *
from timm.models import create_model,resume_checkpoint
import model
from timm.data import create_dataset,create_loader,resolve_data_config
import os
import logging
from spikingjelly.clock_driven.neuron import MultiStepParametricLIFNode, MultiStepLIFNode
import json
import matplotlib.pyplot as plt
from timm.optim import create_optimizer_v2, optimizer_kwargs

import numpy as np





_logger = logging.getLogger('train')

# Stage 1: a minimal parser that only understands -c/--config. _parse_args()
# runs it first so the YAML file can supply defaults for the full parser below.
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='cifar10.yml', type=str, metavar='FILE',
                    help='YAML config file specifying default arguments')  # imagenet.yml  cifar10.yml

# Stage 2: the full parser. Rebinding `parser` here leaves `config_parser`
# pointing at the stage-1 parser created above.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

parser.add_argument('--resume', default='', type=str, metavar='PATH')
parser.add_argument('--model', default='vitsnn', type=str, metavar='MODEL',
                    help='Name of model to train (default: "vitsnn")')
parser.add_argument('--local_rank', default=0, type=int)
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='how many training processes to use (default: 4)')
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')

# Optimizer / learning rate
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "adamw")')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='Optimizer momentum (default: 0.9)')
parser.add_argument('--lr', type=float, default=None, metavar='LR',
                    help='learning rate (absolute lr)')
parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
                    help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--weight_decay', type=float, default=0.05,
                    help='weight decay (default: 0.05)')

# Data
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='Input image center crop percent (for validation only)')
parser.add_argument('--dataset', '-d', metavar='NAME', default='torch/cifar10',
                    help='dataset type (default: ImageFolder/ImageTar if empty)')
# The option used to be registered only as the single-dash `-data-dir`, so the
# conventional `--data-dir` spelling was rejected. Both spellings are accepted
# now; the destination attribute is still `args.data_dir`.
parser.add_argument('--data-dir', '-data-dir', metavar='DIR',
                    default="/media/data/spike-transformer-network/torch/cifar10/",
                    help='path to dataset')  # ./torch/imagenet/
parser.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 32)')
parser.add_argument('-vb', '--val-batch-size', type=int, default=16, metavar='N',
                    help='input batch size for validation (default: 16)')
parser.add_argument('--img-size', type=int, default=None, metavar='N',
                    help='Input image size (default: None => model default)')
parser.add_argument('--val-split', metavar='NAME', default='validation',
                    help='dataset validation split (default: validation)')

# Model
parser.add_argument('-T', '--time-step', type=int, default=4, metavar='time',
                    help='simulation time step of spiking neuron (default: 4)')
parser.add_argument('--td', action='store_true', default=False,
                    help='load top-down mechanism')
parser.add_argument('-L', '--layer', type=int, default=4, metavar='layer',
                    help='model layer (default: 4)')
parser.add_argument('--num-classes', type=int, default=None, metavar='N',
                    help='number of label classes (Model default if None)')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
                    metavar='N N N',
                    help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--dim', type=int, default=None, metavar='N',
                    help='embedding dimension of feature')
parser.add_argument('--num_heads', type=int, default=None, metavar='N',
                    help='attention head number')
parser.add_argument('--patch-size', type=int, default=None, metavar='N',
                    help='Image patch size')
parser.add_argument('--mlp-ratio', type=int, default=None, metavar='N',
                    help='expansion ratio of embedding dimension in MLP block')
parser.add_argument('--layer_td', type=str, default='batch')

# Output path for saved artifacts and the weight of the second loss term
# used when --td is enabled (see main()).
parser.add_argument('--save', default=None, type=str)
parser.add_argument('--beta', default=0.5, type=float)




def _parse_args():
    """Parse command-line arguments, taking defaults from a YAML config file.

    ``config_parser`` is run first and only consumes ``-c/--config``. If a
    config path is set, the YAML mapping it contains is installed as defaults
    on ``parser``; the remaining command-line arguments are then parsed by
    ``parser`` and therefore override the file.

    Returns:
        tuple: ``(args, args_text)`` where ``args`` is the parsed namespace and
        ``args_text`` is the same namespace dumped as a YAML string.

    Raises:
        OSError: if the config file cannot be opened.
    """
    # Do we have a config file to parse?
    args_config, remaining = config_parser.parse_known_args()
    print(args_config)
    if args_config.config:
        with open(args_config.config, 'r', encoding='utf-8') as f:
            cfg = yaml.safe_load(f)
        # An empty YAML document loads as None; `set_defaults(**None)` would
        # raise TypeError, so only apply a non-empty result.
        if cfg:
            parser.set_defaults(**cfg)

    # The main arg parser parses the rest of the args, the usual
    # defaults will have been overridden if config file specified.
    args = parser.parse_args(remaining)

    # Cache the args as a text string to save them in the output dir later
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text






# Layer name -> fraction of inspected input elements equal to 1 (rounded to
# 3 decimals). Written by the firing-rate hook defined inside make_firing_hook.
firing_rate = {}
# NOTE(review): no active code writes to this dict; it is only mentioned in
# commented-out code. "synpase" is presumably a typo for "synapse" -- confirm
# before renaming, other files may import it.
synpase_num = {}


def layer_num(module):
    """Return how many direct children ``module.children()`` yields."""
    return sum(1 for _ in module.children())


def find_layer(module, prefix=''):
    """Walk ``module`` and attach a forward-pre-hook to every matching layer.

    A direct child that is a ``MultiStepLIFNode``, ``nn.Linear``, ``nn.Conv1d``,
    ``nn.Conv2d``, ``nn.BatchNorm1d`` or ``nn.BatchNorm2d`` gets the hook built
    by ``make_firing_hook`` for its dotted path. Any other child that has
    children of its own is searched recursively.

    Args:
        module: the module whose children are inspected.
        prefix: dotted path of ``module``; ``''`` for the root.
    """
    hooked_types = (
        MultiStepLIFNode,
        nn.Linear,
        nn.Conv1d,
        nn.Conv2d,
        nn.BatchNorm1d,
        nn.BatchNorm2d,
    )

    for name, child in module.named_children():
        # Dotted path of this child, e.g. "block.0.attn"; no leading dot at the root.
        child_path = name if prefix == '' else prefix + '.' + name

        if isinstance(child, hooked_types):
            # The returned hook handle is intentionally not kept.
            child.register_forward_pre_hook(make_firing_hook(child_path))
        elif layer_num(child) > 0:
            find_layer(child, prefix=child_path)




def make_firing_hook(prefix):
    """Build the forward-pre-hook used for the layer named ``prefix``.

    The returned hook takes ``(module, input)``, flattens ``input[0]`` and
    prints its variance together with ``prefix``. It returns ``None``.

    Args:
        prefix: layer name captured by the closure and used in the output.

    Returns:
        The variance-printing hook.
    """

    def _record_firing_rate(layer, inputs, output):
        # NOTE: this closure is defined but never returned, so it is not
        # registered anywhere. It stores, for Conv1d/Conv2d/Linear layers, the
        # fraction of elements equal to 1 in inputs[0][0].
        if not isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            return
        first_slice = inputs[0][0]
        spikes = (first_slice == 1).sum().item()
        firing_rate[prefix] = round(spikes / first_slice.numel(), 3)

    def _print_input_variance(layer, inputs):
        # Applies to every hooked layer type; only the first positional input is used.
        flat = inputs[0].flatten()
        variance = flat.var(dim=0).mean().item()
        print(f"Layer {prefix} intput variance: {variance}")

    return _print_input_variance




# def firing_hook(module,input,output):


#     count_ones = torch.sum(output == 1).item()
#     total_elements = output.numel()
#     ratio_ones = count_ones / total_elements
#     print(f'firing rate of {module} is: {ratio_ones:.4f}')





# def make_hook(prefix,hook_type = 'fire'):      #closure


#     def firing_hook(module,input,output):

#         count_ones = torch.sum(output == 1).item()
#         total_elements = output.numel()
#         ratio_ones = count_ones / total_elements
#         # print(f'firing rate of {prefix} is: {ratio_ones:.4f}')
#         firing_rate[prefix] = ratio_ones

#         v  = module.v.shape
#         # synpase_num[prefix] = synpase_num
#         print(f'shape of {prefix} is: {v}')



#     def entropy_hook(module, input, output):
    

#         # input = input[0]
#         activations = output.detach().cpu().numpy()

#         flattened_activations = activations.flatten()

#         # normalized_activations = (flattened_activations - np.min(flattened_activations)) / (np.max(flattened_activations) - np.min(flattened_activations))
        
#         hist, bin_edges = np.histogram(flattened_activations, bins=100, density=True)

#         h = scipy.stats.entropy(hist)

#         print("The entropy of ",type(module).__name__,"is",h)


    
#     hook_dict = 
#     {
#         "fire" : firing_hook,
#         "entropy" : entropy_hook
#     }


#     return hook_dict[hook_type]







def main():
    """Run one validation batch through a hook-instrumented QKFormer model.

    Steps, in order: parse arguments, optionally initialise
    ``torch.distributed``, build the model and the validation loader,
    optionally load a checkpoint, attach the per-layer hooks with
    ``find_layer``, then run a forward pass (and compute the loss) on the
    first batch only. The hooks print their output during that forward pass.
    """
    setup_default_logging()
    args, _ = _parse_args()

    # WORLD_SIZE > 1 in the environment selects distributed mode.
    args.distributed = int(os.environ.get('WORLD_SIZE', 1)) > 1

    # args.prefetcher = not args.no_prefetcher
    # NOTE(review): args.device is recorded but nothing below reads it; the
    # model and the batch are moved with .cuda(), i.e. to the current CUDA device.
    args.device = 'cuda:1'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        # _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
        #              % (args.rank, args.world_size))
    # else:
    #     _logger.info('Training with a single process on 1 GPUs.')

    assert args.rank >= 0

    # NOTE(review): the architecture name is hard-coded; args.model is not used here.
    model = create_model(
        "QKFormer",
        drop_rate=0.,
        drop_path_rate=0.2,
        drop_block_rate=None,
        img_size_h=args.img_size, img_size_w=args.img_size,
        patch_size=args.patch_size, embed_dims=args.dim, num_heads=args.num_heads, mlp_ratios=args.mlp_ratio,
        in_channels=3, num_classes=args.num_classes, qkv_bias=False,
        depths=args.layer, sr_ratios=1,
        T=args.time_step,
        layer_td=args.layer_td,
        td=args.td, pretrained=False)

    model.cuda()

    data_config = resolve_data_config(vars(args), model=model, verbose=True)

    dataset_eval = create_dataset(
        args.dataset, root=args.data_dir,
        split=args.val_split, is_training=False,
        batch_size=args.batch_size)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.val_batch_size,
        is_training=False,
        use_prefetcher=False,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    random_seed(args.seed, args.rank)
    if args.resume:
        # Only the model weights are restored; the returned epoch is not needed.
        resume_checkpoint(
            model, args.resume,
            optimizer=None,
            loss_scaler=None,
            log_info=args.local_rank == 0)

    # Not used by the active code path; kept because the disabled gradient
    # experiment at the end of the loop refers to `optimizer`.
    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
    loss_scaler = NativeScaler()

    # Attach the per-layer hooks before the forward pass.
    find_layer(model)

    # Built once instead of on every iteration.
    criterion = torch.nn.CrossEntropyLoss()

    # NOTE(review): the model is left in its default (training) mode and
    # autograd stays enabled -- presumably for the disabled gradient
    # experiment below; confirm before adding eval()/no_grad().
    for images, target in loader_eval:
        images = images.cuda()
        target = target.cuda()

        # Print the image shape and labels
        # print(f"First image shape: {images.shape}")
        if args.td:
            # Two passes: the first returns extra tensors that are fed back
            # into the second; neuron state is reset in between.
            x1, td, tmp = model(images)
            functional.reset_net(model)
            x2 = model(tmp, td=td)

            loss = (1 - args.beta) * criterion(x1, target) + args.beta * criterion(
                x2, target)
        else:
            output = model(images)
            loss = criterion(output, target)

        # Disabled gradient-histogram experiment, kept for reference:
        # optimizer.zero_grad()
        # loss.backward()
        # functional.reset_net(model)
        #
        # gradients = model.patch_embed1.proj_conv.weight.grad
        # gradients = gradients.flatten().cpu().numpy()
        #
        # # kur = kurtosis(gradients)
        # print("model is", args.model, "kur is", kur)
        #
        # print(json.dumps(synpase_num, indent=4))
        #
        # plt.hist(gradients, bins=70, color="red", alpha=0.7)
        # # plt.xlim(-0.05, 0.05)
        # # plt.ylim(0, 300)
        # plt.xlabel("Gradient Value")
        # plt.ylabel("Frequency")
        #
        # plt.savefig(args.save)

        # Only the first batch is analysed. This replaces the interactive-shell
        # helper `exit(0)` (and a second, unreachable one); the script still
        # ends with status 0 once main() returns.
        break

















# Script entry point: run main() only when this file is executed directly,
# not when it is imported.
if __name__ == '__main__':
    main()