import torch_pruning as tp
from pathlib import Path
import sys
from types import SimpleNamespace
import torch
# Allow CUDA matmuls and cuDNN convolutions to use TF32 tensor-core math on GPUs
# that support it: faster evaluation at the cost of slightly reduced float32 precision.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from models import TinyResnet56, vgg19_bn, resnet50
from pruned_model_loader import prune_loader_tp
from torch_dataset import load_data_torch
from torchvision.models import ResNet50_Weights

# Reference metrics of the un-pruned baselines, used as denominators in the
# pruning reports. `acc` is the number of correct top-1 predictions out of
# `count` test samples (accuracy % = acc / count * 100); `flops` is compared
# with the op count from tp.utils.count_ops_and_params; `params` is the
# parameter count.
VGG19_PRETRAINED = SimpleNamespace(acc=7351, flops=512862308, params=20086692, count=10000)  # VGG19-BN, CIFAR-100 report
R56_PRETRAINED = SimpleNamespace(acc=9353, flops=127373962, params=855770, count=10000)  # ResNet-56, CIFAR-10 report

def accuracy(output, target, topk=(1,)):
    """Count the correctly classified samples for each requested top-k.

    Despite the name, this returns raw *counts* (as floats), not percentages;
    callers divide by the number of samples themselves.

    Args:
        output: 2-D score tensor; top-k is taken along dim 1 (the class axis).
        target: 1-D tensor of class indices, or a 2-D tensor (e.g. one-hot /
            soft labels) that is reduced to its per-row argmax.
        topk: the k values to evaluate.

    Returns:
        list[float]: for each k in ``topk`` (same order), the number of samples
        whose target class is among the k highest-scoring predictions.
    """
    with torch.no_grad():
        if target.ndim == 2:
            # Label distributions / one-hot rows -> class indices.
            target = target.max(dim=1)[1]

        # Clamp to the number of classes so that e.g. top-5 on a 3-class head
        # does not make `topk` raise; such a k then counts every sample.
        maxk = min(max(topk), output.size(1))
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch): row i holds every sample's i-th best class
        correct = pred.eq(target[None])  # broadcast targets across the k rows

        return [correct[:k].flatten().sum(dtype=torch.float32).item() for k in topk]
def infer(model,test_loader,device):
    """Evaluate ``model`` on every batch of ``test_loader`` and accumulate metrics.

    The model is switched to eval mode and run under ``torch.no_grad()``.

    Returns:
        SimpleNamespace with
          count -- number of samples seen,
          loss  -- sum of the per-batch *mean* cross-entropy losses
                   (not a per-sample sum),
          acc   -- number of top-1 correct samples,
          acc5  -- number of top-5 correct samples,
          ADW   -- never updated here (always 0).
    """
    model.eval()
    stats = SimpleNamespace(count=0, loss=0, acc=0, acc5=0, ADW=0)
    criterion = torch.nn.CrossEntropyLoss()

    for images, labels in test_loader:
        images = images.to(device=device)
        labels = labels.to(device=device)

        with torch.no_grad():
            logits = model(images)
            batch_loss = criterion(logits, labels)
            stats.count += logits.shape[0]
            stats.loss += batch_loss.item()
            top1, top5 = accuracy(logits, labels, topk=(1, 5))
            stats.acc += top1
            stats.acc5 += top5

    return stats
def infer_cifar10(configs,device='cuda'):
    """Evaluate a pruned CIFAR-10 ResNet-56 checkpoint and print a pruning report.

    The report compares against the hard-coded ``R56_PRETRAINED`` baseline:
    MAC reduction and speed-up, parameter reduction, and top-1 accuracy drop.

    Args:
        configs: namespace handed to ``load_data_torch``; must also provide
            ``num_classes``, ``input_shape`` (per-sample shape, batch dim
            excluded) and ``ckpt_path``.
        device: device that the model, data and checkpoint tensors are placed on.

    Returns:
        The metrics namespace produced by ``infer`` (raw sums, e.g. ``count``,
        ``acc``, ``acc5``).
    """
    test_loader = load_data_torch(configs)
    model = TinyResnet56(configs.num_classes).to(device=device)
    # Only used for op/parameter counting, so its values are irrelevant.
    dummy_input = torch.rand([1,*configs.input_shape]).to(device=device)
    # NOTE(review): baseline numbers come from R56_PRETRAINED rather than from
    # evaluating configs.pretrained_ckpt here; presumably they were measured
    # from that checkpoint once — confirm if the checkpoint changes.

    # map_location lets a checkpoint saved on another device (e.g. a different
    # GPU index, or CUDA when running on CPU) be loaded onto `device`.
    model_state_dict = torch.load(configs.ckpt_path, map_location=device)
    model = prune_loader_tp(model,model_state_dict).to(device=device)
    model.eval()
    total_metrics = infer(model,test_loader,device)
    pruned_flops, pruned_num_params = tp.utils.count_ops_and_params(model,example_inputs=dummy_input)

    baseline = R56_PRETRAINED
    flops_reduced = (1-pruned_flops/baseline.flops)*100
    params_reduce = (1-pruned_num_params/baseline.params)*100
    speedup = baseline.flops/pruned_flops
    base_acc1 = (baseline.acc/baseline.count)*100
    acc1 = (total_metrics.acc/total_metrics.count)*100
    print('MACs reduced:%.2f %% (speedup %.2f x)'%(flops_reduced,speedup))
    print('#Params reduced:%.3f %%'%params_reduce)
    print('ACC1:%.3f %% (baseline : %.3f %%, Drop:%.3f %%)'%(acc1,base_acc1,base_acc1-acc1))
    return total_metrics
def infer_cifar100(configs,device = 'cuda'):
    """Evaluate a pruned CIFAR-100 VGG19-BN checkpoint and print a pruning report.

    The report compares against the hard-coded ``VGG19_PRETRAINED`` baseline:
    MAC reduction and speed-up, parameter reduction, and top-1 accuracy drop.

    Args:
        configs: namespace handed to ``load_data_torch``; must also provide
            ``num_classes``, ``input_shape`` (per-sample shape, batch dim
            excluded) and ``ckpt_path``.
        device: device that the model, data and checkpoint tensors are placed on.

    Returns:
        The metrics namespace produced by ``infer`` (raw sums, e.g. ``count``,
        ``acc``, ``acc5``).
    """
    test_loader = load_data_torch(configs)
    model = vgg19_bn(num_classes=configs.num_classes).to(device=device)
    # Only used for op/parameter counting, so its values are irrelevant.
    dummy_input = torch.rand([1,*configs.input_shape]).to(device=device)
    # NOTE(review): baseline numbers come from VGG19_PRETRAINED rather than from
    # evaluating configs.pretrained_ckpt here; presumably they were measured
    # from that checkpoint once — confirm if the checkpoint changes.

    # map_location lets a checkpoint saved on another device (e.g. a different
    # GPU index, or CUDA when running on CPU) be loaded onto `device`.
    model_state_dict = torch.load(configs.ckpt_path, map_location=device)
    model = prune_loader_tp(model,model_state_dict).to(device=device)
    model.eval()
    total_metrics = infer(model,test_loader,device)
    pruned_flops, pruned_num_params = tp.utils.count_ops_and_params(model,example_inputs=dummy_input)

    baseline = VGG19_PRETRAINED
    flops_reduced = (1-pruned_flops/baseline.flops)*100
    params_reduce = (1-pruned_num_params/baseline.params)*100
    speedup = baseline.flops/pruned_flops
    base_acc1 = (baseline.acc/baseline.count)*100
    acc1 = (total_metrics.acc/total_metrics.count)*100
    print('MACs reduced:%.2f %% (speedup %.2f x)'%(flops_reduced,speedup))
    print('#Params reduced:%.3f %%'%params_reduce)
    print('ACC1:%.3f %% (baseline : %.3f %%, Drop:%.3f %%)'%(acc1,base_acc1,base_acc1-acc1))
    return total_metrics


def infer_imagenet(configs,device='cuda'):
    """Evaluate a pruned ImageNet ResNet-50 checkpoint and print a pruning report.

    Unlike the CIFAR variants, the baseline is measured on the fly: the dense
    model's MACs/params are counted, and torchvision's IMAGENET1K_V1 weights are
    loaded and evaluated on the same loader before the pruned checkpoint is applied.

    Args:
        configs: namespace handed to ``load_data_torch``; must also provide
            ``input_shape`` (per-sample shape, batch dim excluded) and ``ckpt_path``.
        device: device that the model, data and checkpoint tensors are placed on.

    Returns:
        The pruned model's metrics namespace produced by ``infer`` (raw sums,
        e.g. ``count``, ``acc``, ``acc5``).
    """
    test_loader = load_data_torch(configs)
    model = resnet50().to(device=device)
    # Only used for op/parameter counting, so its values are irrelevant.
    dummy_input = torch.rand([1,*configs.input_shape]).to(device=device)
    baseline_flops, baseline_num_params = tp.utils.count_ops_and_params(model,example_inputs=dummy_input)
    pretrained_state_dict = ResNet50_Weights.IMAGENET1K_V1.get_state_dict(progress=True)
    model.load_state_dict(pretrained_state_dict)
    base_metrics = infer(model,test_loader,device)

    # map_location lets a checkpoint saved on another device (e.g. a different
    # GPU index, or CUDA when running on CPU) be loaded onto `device`.
    model_state_dict = torch.load(configs.ckpt_path, map_location=device)
    model = prune_loader_tp(model,model_state_dict).to(device=device)
    model.eval()
    total_metrics = infer(model,test_loader,device)
    pruned_flops, pruned_num_params = tp.utils.count_ops_and_params(model,example_inputs=dummy_input)

    flops_reduced = (1-pruned_flops/baseline_flops)*100
    params_reduce = (1-pruned_num_params/baseline_num_params)*100
    speedup = baseline_flops/pruned_flops
    base_acc1 = (base_metrics.acc/base_metrics.count)*100
    acc1 = (total_metrics.acc/total_metrics.count)*100
    print('Pruning Result for Imagenet_Resnet50')
    print('MACs reduced:%.2f %% (speedup %.2f x)'%(flops_reduced,speedup))
    print('#Params reduced:%.3f %%'%params_reduce)
    print('ACC1:%.3f %% (baseline : %.3f %%, Drop:%.3f %%)'%(acc1,base_acc1,base_acc1-acc1))
    return total_metrics
if __name__ == '__main__':
    # Usage: python <this script> {cifar10|cifar100|imagenet} <ckpt_id>
    # Evaluates ckpt/pruned_<dataset>_<arch>_<ckpt_id>.pt next to this file.
    if len(sys.argv) < 3:
        sys.exit(f'usage: {sys.argv[0]} {{cifar10|cifar100|imagenet}} <ckpt_id>')
    exp_name = sys.argv[1]
    ckpt_id = sys.argv[2]
    if exp_name not in ('cifar10', 'cifar100', 'imagenet'):
        # Previously an unknown name silently ran nothing between START/END.
        sys.exit(f'unknown experiment {exp_name!r}; expected cifar10, cifar100 or imagenet')

    model_ckpt_root = Path(__file__).parent / 'ckpt'

    print(f'========START EVALUATION:{exp_name}_{ckpt_id}=======')
    if exp_name == 'cifar10':
        configs = SimpleNamespace(
            dataset_name = 'benchmark_cifar10',
            pretrained_ckpt = 'cifar10_resnet56.pth',
            eval_batch_size=128,
            torch_seed = 328,
            workers=1,
            input_shape = [3,32,32],
            num_classes = 10,
            ckpt_path = model_ckpt_root / f'pruned_cifar10_resnet56_{ckpt_id}.pt'
        )
        infer_cifar10(configs)
    elif exp_name == 'cifar100':
        configs = SimpleNamespace(
            dataset_name = 'benchmark_cifar100',
            pretrained_ckpt = 'cifar100_vgg19.pth',
            eval_batch_size=128,
            torch_seed = 328,
            workers=1,
            input_shape = [3,32,32],
            num_classes = 100,
            ckpt_path = model_ckpt_root / f'pruned_cifar100_vgg19_{ckpt_id}.pt'
        )
        infer_cifar100(configs)
    else:  # 'imagenet'
        configs = SimpleNamespace(
            dataset_name = 'imagenet',
            pretrained_ckpt = '',
            eval_batch_size=64,
            torch_seed = 807,
            workers=1,
            input_shape = [3,224,224],
            num_classes = 1000,
            ckpt_path = model_ckpt_root / f'pruned_imagenet_resnet50_{ckpt_id}.pt'
        )
        infer_imagenet(configs)

    print(f'========END EVALUATION: {exp_name}_{ckpt_id}=======')
