from Main_Compression import *
def CompressionMain(args):
    BATCH_SIZE = 100
    criterion = nn.CrossEntropyLoss().to(device)
    workers = 4 if device == 'cuda' else 0
    use_gpu = True if device == 'cuda' else False
    input_dim = 3
    if args.dataset == 'cifar100':
        output_dim = 100
    else :
        output_dim = 10

    if args.architecture == 'alexnet':
        net = AlextNet(input_dim,output_dim)

    elif args.architecture == 'vgg16':
        lin1_inp_size = 512
        lin1_out_size = 512
        lin2_inp_size = 512
        lin2_out_size = 512
        lin3_inp_size = 512
        lin3_out_size = output_dim 
        net = vgg16(pretrained=False)
        del net.avgpool
        net.avgpool = lambda x: x
        net.classifier._modules['0'] = nn.Linear(lin1_inp_size, lin1_out_size)
        net.classifier._modules['3'] = nn.Linear(lin2_inp_size, lin2_out_size)
        net.classifier._modules['6'] = nn.Linear(lin3_inp_size, lin3_out_size)

    elif args.architecture == 'lenet':
        net = LeNet()
    model_to_load = args.architecture+ args.dataset
    print('Loading ' + args.architecture + ' ' + args.dataset + '...')
    state_dict = torch.load('./Models/'+model_to_load+'.pth')
    state_dict = OrderedDict(zip(net.state_dict().keys(), state_dict.values()))
    net.load_state_dict(state_dict)
    

    if args.dataset == 'mnist':
        transforms = Compose([Resize(96),ToTensor(),Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5])])
        train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./files/', train=True, download=True,
                                 transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize((0.1307,), (0.3081,))])),  batch_size=BATCH_SIZE, shuffle=True)

        test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./files/', train=False, download=True,
                                 transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize((0.1307,), (0.3081,))])),  batch_size=BATCH_SIZE, shuffle=True)
    elif args.dataset == 'mnistfashion':

        train_loader = torch.utils.data.DataLoader(datasets.FashionMNIST('data', train=True, download=True,transform=transforms.Compose([
                                                    transforms.ToTensor(),transforms.Normalize((0.2868,), (0.35244,))])),
                                                    batch_size=BATCH_SIZE, shuffle=True)

        test_loader = torch.utils.data.DataLoader(datasets.FashionMNIST('data', train=False, transform=transforms.Compose([
                                                    transforms.ToTensor(),transforms.Normalize((0.2868,), (0.35244,))])),
                                                     batch_size=BATCH_SIZE, shuffle=True)
    elif args.dataset == 'svhn':
        if args.architecture == 'alexnet':
            transforms = Compose([Resize(96),ToTensor(), Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
        else: 
            transforms = Compose([ToTensor(), Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])       
        trainset = datasets.SVHN(root='./data_SVHN', split = 'train', download=True, transform=transforms)
        train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=workers, pin_memory=use_gpu,drop_last=True)
        testset = datasets.SVHN(root='./data_SVHN', split = 'test', download=True, transform=transforms)
        test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=workers, pin_memory=use_gpu,drop_last=False)
        net.to(device)
    
    elif args.dataset == 'cifar10':
        if args.architecture == 'alexnet':
            transforms = Compose([Resize(96),ToTensor(),Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5])])
        else:
            transforms = Compose([ToTensor(),Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5])])
        trainset = datasets.CIFAR10(root='./data_cifar10', train=True, transform=transforms, download=True)
        train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=workers, pin_memory=use_gpu,drop_last=True)
        testset = datasets.CIFAR10(root='./data_cifar10', train=False, transform=transforms, download=True)
        test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=workers, pin_memory=use_gpu,drop_last=False)
        net.to(device)
    
    elif args.dataset == 'cifar100':
        if args.architecture == 'alexnet':
            transforms = Compose([Resize(96),ToTensor(),Normalize(mean=[0.5071, 0.4867, 0.4408],std=[0.2675, 0.2565, 0.2761])])
        else:
            transforms = Compose([ToTensor(),Normalize(mean=[0.5071, 0.4867, 0.4408],std=[0.2675, 0.2565, 0.2761])])
        trainset = torchvision.datasets.CIFAR100(root='./data_cifar100/', train=True, transform=transforms, download=True)
        train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=workers, pin_memory=use_gpu,drop_last=True)
        testset = torchvision.datasets.CIFAR100(root='./data_cifar100/', train=False, transform=transforms, download=True)
        test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=workers, pin_memory=use_gpu,drop_last=False)
        net.to(device)
        
    print('Dataset :'+ args.dataset)
    print('Architecture :'+ args.architecture)
    print('Retrain :'+ str(args.retrain))
    print('Method :'+ args.method)
    print('Category :'+ args.category)   

    if args.method == 'classblind':
        _,acc,percentage = compressclassblind(net,train_loader ,test_loader,retrain=args.retrain,category= args.category,architecture = args.architecture)
    if args.method == 'classuniform':
        _,acc,percentage = compressclassunif(net,train_loader ,test_loader,retrain=args.retrain,category= args.category,architecture = args.architecture)
    if args.method == 'classdistribution':
        _,acc,percentage = compressclassdist(net,train_loader ,test_loader,retrain=args.retrain,category= args.category,architecture = args.architecture)
    if args.method == 'tgcomp':
        _,acc,percentage = compress_tg(net,train_loader ,test_loader,criterion,retrain=args.retrain,category= args.category,architecture = args.architecture)

    np.savetxt('./mylist/accuracy.csv',acc,delimiter= ',')
    np.savetxt('./mylist/percentage.csv',percentage,delimiter= ',')
