import torch.distributed.launch
import argparse
import torch
import torch.backends.cudnn as cudnn
from pathlib import Path
from datasets import build_dataset
from engine_for_finetuning import evaluate, evaluate_snn
import model_resnet
import random
import utils
from tqdm import tqdm
import copy

def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
    """Copy matching tensors from ``state_dict`` into ``model`` without raising on mismatches.

    Every submodule's ``_load_from_state_dict`` is invoked recursively with the
    dotted ``prefix`` of its position in the module tree. Missing keys that
    contain any of the ``'|'``-separated substrings in ``ignore_missing`` are
    reported separately as "ignored". Missing, unexpected and ignored keys, as
    well as any error messages collected by torch, are printed rather than raised.

    Args:
        model: module to load weights into (modified in place).
        state_dict: mapping of parameter/buffer names to tensors; not mutated.
        prefix: key prefix under which ``model``'s own entries live.
        ignore_missing: ``'|'``-separated substrings; matching missing keys are
            not warned about.
    """
    missing_keys, unexpected_keys, error_msgs = [], [], []

    # Work on a shallow copy because _load_from_state_dict may modify the
    # mapping; carry the version metadata over to the copy.
    saved_metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if saved_metadata is not None:
        state_dict._metadata = saved_metadata

    def _visit(module, module_prefix=''):
        if saved_metadata is None:
            module_meta = {}
        else:
            module_meta = saved_metadata.get(module_prefix[:-1], {})
        # strict=True so unexpected keys under this prefix are collected too.
        module._load_from_state_dict(
            state_dict, module_prefix, module_meta, True,
            missing_keys, unexpected_keys, error_msgs)
        for child_name, child in module._modules.items():
            if child is not None:
                _visit(child, module_prefix + child_name + '.')

    _visit(model, module_prefix=prefix)

    patterns = ignore_missing.split('|')

    def _is_ignored(key):
        return any(pattern in key for pattern in patterns)

    ignore_missing_keys = [key for key in missing_keys if _is_ignored(key)]
    missing_keys = [key for key in missing_keys if not _is_ignored(key)]

    model_name = model.__class__.__name__
    if missing_keys:
        print("Weights of {} not initialized from pretrained model: {}".format(
            model_name, missing_keys))
    if unexpected_keys:
        print("Weights from pretrained model not used in {}: {}".format(
            model_name, unexpected_keys))
    if ignore_missing_keys:
        print("Ignored weights of {} not initialized from pretrained model: {}".format(
            model_name, ignore_missing_keys))
    if error_msgs:
        print('\n'.join(error_msgs))

def save_model(args, model, model_without_ddp):
    """Save ``model_without_ddp``'s weights and the run args to ``<output_dir>/<savename>.pth``.

    Nothing is written when ``args.output_dir`` is empty, matching the
    ``--output_dir`` help text ("empty for no saving").

    Args:
        args: parsed CLI namespace; ``output_dir`` and ``savename`` are read.
        model: unused; kept so existing call sites keep working.
        model_without_ddp: unwrapped module whose ``state_dict()`` is saved.
    """
    if not args.output_dir:
        return
    checkpoint_path = Path(args.output_dir) / (args.savename + '.pth')
    to_save = {'model': model_without_ddp.state_dict(), 'args': args}
    # The bare name `save_on_master` is not defined or imported in this file.
    # It is presumably provided by `utils`, alongside get_world_size/get_rank — confirm.
    utils.save_on_master(to_save, checkpoint_path)
        
def rename_pretrained_keys(state_dict):
    """Map checkpoint key names onto this project's ResNet naming.

    Every ``'downsample'`` becomes ``'shortcut'`` and every ``'fc.'`` becomes
    ``'linear.'``. Values are passed through untouched, and a new plain dict is
    returned. If two renamed keys collide, the later entry wins.
    """
    # str.replace is a no-op when the substring is absent, so no guard is needed.
    return {
        name.replace('downsample', 'shortcut').replace('fc.', 'linear.'): tensor
        for name, tensor in state_dict.items()
    }

def get_args(argv=None):
    """Build the CLI parser and parse arguments.

    Args:
        argv: optional list of argument strings; ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace. The ``--lambda`` option is stored under the
        attribute name ``'lambda'``, which is a Python keyword. Read and write it
        with ``getattr(args, 'lambda')`` / ``setattr(args, 'lambda', v)``.
    """
    def _str2bool(value):
        # argparse's `type=bool` treats any non-empty string (even "False") as
        # True, so parse the usual spellings explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)

    parser = argparse.ArgumentParser()
    # Model parameters
    parser.add_argument('--model', default='eva_g_patch14', type=str, metavar='MODEL', help='Name of model to train')
    parser.add_argument('--input_size', default=32, type=int, help='images input size')
    parser.add_argument('--nb_classes', default=10, type=int, help='number of the classification types')
    parser.add_argument('--model_path', default='')
    parser.add_argument('--percent', default=0.99, type=float)
    # Accepts `--monitor`, `--monitor True` and `--monitor False`.
    parser.add_argument('--monitor', default=False, type=_str2bool, nargs='?', const=True)

    # Dataset parameters
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--eval_data_path', default='../datasets/val', type=str, help='dataset path for evaluation')
    parser.add_argument('--imagenet_default_mean_and_std', default=False, action='store_true')
    parser.add_argument('--data_set', default='image_folder', choices=['CIFAR10', 'CIFAR100', 'IMNET', 'image_folder'], type=str, help='dataset type')
    parser.add_argument('--output_dir', default='../models/threshold', help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--savename', default='test', type=str)

    # Mode parameter
    parser.add_argument('--test_mode', default='ann', choices=['ann', 'for_v', 'snn'], help="test mode")
    parser.add_argument('--test_T', default=8, type=int)

    # Multi-threshold neuron parameter
    parser.add_argument('--linear_num', default=1, type=int)
    # dest is 'lambda' (a keyword): access via getattr/setattr, never `args.lambda`.
    parser.add_argument('--lambda', default=1.0, type=float)

    return parser.parse_args(argv)
        
def _load_pretrained_weights(model, model_path):
    """Load the checkpoint at ``model_path`` into ``model`` in place.

    The checkpoint may keep its weights under a ``'model'`` or ``'module'`` key,
    or be a bare state_dict. Keys are renamed with ``rename_pretrained_keys``
    before loading, and ``load_state_dict`` prints mismatches instead of raising.
    """
    # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
    checkpoint = torch.load(model_path, map_location='cpu')
    print("Load ckpt from %s" % model_path)
    checkpoint_model = None
    for model_key in ['model', 'module']:
        if model_key in checkpoint:
            checkpoint_model = checkpoint[model_key]
            print("Load state_dict by model_key = %s" % model_key)
            break
    if checkpoint_model is None:
        checkpoint_model = checkpoint
    load_state_dict(model, rename_pretrained_keys(checkpoint_model), prefix='')


def main(args):
    """Build the validation loader and run the evaluation selected by ``args.test_mode``.

    - ``'for_v'``: loads a ResNet18 and its optional checkpoint, calls
      ``replace_test_by_testneuron(model, percent=1.0)`` and runs ``evaluate``.
    - ``'snn'``: loads a ResNet18 after ``replace_test_by_testneuron(percent=0.99)``
      and then ``replace_nonlinear_by_neuron``. For each T in 1..256 it sets
      ``args.linear_num = T`` and ``args.lambda = 1/T``, converts a fresh deep
      copy with ``replace_testneuron_by_sfneuron`` and runs ``evaluate_snn``.
    - any other mode (e.g. the default ``'ann'``): nothing is evaluated here,
      and a message says so.
    """
    args.distributed = False
    device = torch.device(args.device)
    cudnn.benchmark = True

    dataset_val, args.nb_classes = build_dataset(is_train=False, args=args)
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    # Evaluation gains nothing from shuffling; a fixed order keeps runs reproducible.
    sampler_val = torch.utils.data.DistributedSampler(
        dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=int(args.batch_size),
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )

    logfile = 'logs/mtn.txt'
    # The log path is handed to evaluate/evaluate_snn; ensure its directory
    # exists in case they write to it.
    Path(logfile).parent.mkdir(parents=True, exist_ok=True)

    if args.test_mode == 'for_v':
        model = model_resnet.ResNet18()
        model.to(device)
        if args.model_path:
            _load_pretrained_weights(model, args.model_path)
        model_resnet.replace_test_by_testneuron(model, percent=1.0)

        evaluate(data_loader_val, model, device, args=args, model_without_ddp=model, logfile=logfile)

    elif args.test_mode == 'snn':
        model = model_resnet.ResNet18()
        model.to(device)
        model_resnet.replace_test_by_testneuron(model, percent=0.99)
        if args.model_path:
            _load_pretrained_weights(model, args.model_path)
        model_resnet.replace_nonlinear_by_neuron(model)

        for T in range(1, 257):
            args.linear_num = T
            # `lambda` is a Python keyword, so `args.lambda` is a SyntaxError;
            # the argparse dest must be accessed via setattr/getattr.
            setattr(args, 'lambda', 1.0 / T)
            print(args.linear_num, getattr(args, 'lambda'))
            # Convert a fresh copy each time so every T starts from the same base model.
            model_T = copy.deepcopy(model)
            model_resnet.replace_testneuron_by_sfneuron(model_T, args)
            evaluate_snn(data_loader_val, model_T, device, args.test_T, args, logfile=logfile)

    else:
        print("test_mode '%s' is not handled by this script; nothing was evaluated." % args.test_mode)

if __name__ == '__main__':
    cli_args = get_args()
    # An empty --output_dir means "no saving", so only create it when given.
    target_dir = cli_args.output_dir
    if target_dir:
        Path(target_dir).mkdir(exist_ok=True, parents=True)
    main(cli_args)