import torch_pruning as tp
from pathlib import Path
import sys
from types import SimpleNamespace
import os
import torch
from timm.models import create_model
#torch.backends.cuda.matmul.allow_tf32 = False
sys.path.append('/workspace/ICLR2026_submission_number16959/reproduce_iclr/IPPRO_DeiTTiny')
from torch_dataset import load_data_torch

# Use the default dataset location unless the caller already exported
# IMAGENET_ROOT; an existing value is left untouched.
os.environ.setdefault('IMAGENET_ROOT', '/workspace/Imagenet')

def accuracy(output, target, topk=(1,)):
    """Count the samples whose target is among the top-k predictions.

    Despite the name, the values returned are raw hit counts, not
    percentages; the caller is expected to divide by the sample count.

    Args:
        output: Score tensor; the top-k search runs along dim 1.
        target: Tensor of class indices. A 2-D target is reduced to class
            indices by taking the position of the maximum along dim 1.
        topk: Iterable of k values to evaluate.

    Returns:
        A list of Python floats, one per entry of ``topk``, each holding the
        number of samples whose target appears in the k highest-scoring
        predictions.
    """
    with torch.no_grad():
        # Clamp to the number of available classes: Tensor.topk raises when
        # asked for more entries than dim 1 holds (e.g. top-5 on a 3-class
        # head). Slicing ``correct[:k]`` below is already safe for k > maxk.
        maxk = min(max(topk), output.size(1))
        if target.ndim == 2:
            target = target.max(dim=1)[1]

        _, pred = output.topk(maxk, 1, True, True)
        # Transpose so row i holds every sample's i-th ranked prediction,
        # then broadcast the targets against each rank.
        correct = pred.t().eq(target[None])

        return [
            correct[:k].flatten().sum(dtype=torch.float32).cpu().item()
            for k in topk
        ]
def infer(model,test_loader,device):
    """Run ``model`` over ``test_loader`` and return the accumulated metrics.

    The model is switched to eval mode and every batch is moved to
    ``device`` before the forward pass, which runs without gradients.

    Returns:
        A ``SimpleNamespace`` with these fields:
          * ``count`` - total number of outputs seen (``outputs.shape[0]``
            summed over batches).
          * ``loss``  - sum of the ``CrossEntropyLoss`` value of each batch
            (one value per batch, not per sample).
          * ``acc`` / ``acc5`` - summed top-1 / top-5 hit counts as returned
            by ``accuracy``.
          * ``ADW``   - initialised to 0 and never updated in this function.
    """
    model.eval()
    criterion = torch.nn.CrossEntropyLoss()
    totals = SimpleNamespace(count=0, loss=0, acc=0, acc5=0, ADW=0)

    for image, label in test_loader:
        image = image.to(device=device)
        label = label.to(device=device)

        with torch.no_grad():
            logits = model(image)
            batch_loss = criterion(logits, label)
            top1_hits, top5_hits = accuracy(logits, label, topk=(1, 5))

        totals.count += logits.shape[0]
        totals.loss += batch_loss.item()
        totals.acc += top1_hits
        totals.acc5 += top5_hits

    return totals
def infer_imagenet(configs,device='cuda'):
    """Evaluate the pruned DeiT-Tiny model against the pretrained baseline.

    Builds the pretrained ``deit_tiny_patch16_224`` model, measures its op
    count and top-1 accuracy on the loader returned by ``load_data_torch``,
    then does the same for the model object stored at
    ``configs.pruned_weight`` and prints the comparison.

    Args:
        configs: Namespace read here for ``input_shape`` (prepended with a
            batch dimension of 1 to build the dummy input) and
            ``pruned_weight`` (path passed to ``torch.load``). It is also
            handed as a whole to ``load_data_torch``.
        device: Device on which both models and all batches are placed.

    Returns:
        The metrics namespace produced by ``infer`` for the pruned model.
    """
    test_loader = load_data_torch(configs)

    model = create_model(
                    'deit_tiny_patch16_224',
                    pretrained=True,
                    num_classes=1000,
                    drop_rate=0,
                    drop_path_rate=0.1,
                    drop_block_rate=None
                ).to(device=device)

    dummy_input = torch.rand([1,*configs.input_shape]).to(device=device)

    baseline_flops, _ = tp.utils.count_ops_and_params(model,example_inputs=dummy_input)
    base_metrics = infer(model,test_loader,device)

    # The checkpoint is used directly as a model below, so it must hold a
    # whole pickled module rather than a state dict.
    # NOTE(security): torch.load unpickles arbitrary objects; only point
    # ``pruned_weight`` at checkpoints from a trusted source.
    # NOTE(review): recent PyTorch releases default to ``weights_only=True``,
    # which rejects full-module pickles; pass ``weights_only=False`` here if
    # this call fails on such a release.
    # ``map_location`` keeps the loaded model on the same device as the
    # inputs regardless of where the checkpoint was saved.
    model = torch.load(configs.pruned_weight, map_location=device)
    model.eval()
    total_metrics = infer(model,test_loader,device)
    pruned_flops, _ = tp.utils.count_ops_and_params(model,example_inputs=dummy_input)

    # Use the exact ratio: floor-dividing both counts by 1e8 first quantises
    # the percentage, disagrees with ``speedup`` below, and divides by zero
    # for a baseline under 1e8 ops.
    flops_reduced = (1 - pruned_flops / baseline_flops) * 100
    speedup = baseline_flops / pruned_flops

    base_acc1 = (base_metrics.acc/base_metrics.count)*100
    acc1 = (total_metrics.acc/total_metrics.count)*100

    print('Pruning Result for DeiT-Tiny-ImageNet1k')
    print('MACs reduced:%.2f (speedup %.2f x)'%(flops_reduced,speedup))
    print('ACC1:%.2f (baseline : %.2f, Drop:%.2f)'%(acc1,base_acc1,base_acc1-acc1))
    return total_metrics

if __name__ == '__main__':
    # Script entry point: build the evaluation config and run the
    # baseline-vs-pruned comparison.
    # Only the ResNet-50 weights enum is referenced below; the other
    # torchvision names previously imported here were unused.
    from torchvision.models import ResNet50_Weights

    configs = SimpleNamespace(
        dataset_name='imagenet',
        # NOTE(review): this is a ResNet-50 weights enum in a DeiT-Tiny
        # evaluation. infer_imagenet never reads ``pretrained_ckpt``;
        # confirm whether load_data_torch relies on it before changing it.
        pretrained_ckpt=ResNet50_Weights.IMAGENET1K_V1,
        eval_batch_size=128,
        torch_seed=328,
        workers=1,
        input_shape=[3,224,224],
        num_classes=1000,
        pruned_weight='/workspace/ICLR2026_submission_number16959/ICLR_deittiny_param/deit_tiny_best.pt',
    )
    infer_imagenet(configs)

    print('========END EVALUATION=======')
