import argparse
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from opacus import PrivacyEngine
from tqdm import tqdm
import timm
import numpy as np
import matplotlib.pyplot as plt

def ECEMCE(y_te, y_te_single, targets_te, n_partitions=50):
    """Compute the expected and maximum calibration error (ECE, MCE).

    Each sample is assigned to one of ``n_partitions`` equal-width bins on
    [0, 1] according to its top-class probability (a probability of exactly
    1.0 is placed in the last bin). For every non-empty bin the calibration
    error is |bin accuracy - bin mean confidence|; empty bins contribute 0.

    Args:
        y_te: (N, C) array-like of predicted class probabilities.
        y_te_single: length-N array-like of predicted labels.
        targets_te: length-N array-like of true labels.
        n_partitions: number of confidence bins (default 50).

    Returns:
        (ECE, MCE): ECE is the sample-weighted mean of the per-bin errors,
        MCE is the largest per-bin error.

    Raises:
        ValueError: if there are no samples (ECE would be 0/0).
    """
    y_te = np.asarray(y_te)
    if len(y_te) == 0:
        raise ValueError("ECEMCE needs at least one sample")
    correct = (np.asarray(y_te_single) == np.asarray(targets_te)).astype(float)
    y_te_prob = np.max(y_te, axis=1)  # predicted (top-class) probability

    # Truncate to a bin index; clamp so that prob == 1.0 lands in the last bin.
    bins = np.minimum((y_te_prob * n_partitions).astype(int), n_partitions - 1)
    counts = np.bincount(bins, minlength=n_partitions)
    acc_sum = np.bincount(bins, weights=correct, minlength=n_partitions)
    conf_sum = np.bincount(bins, weights=y_te_prob, minlength=n_partitions)

    CEs = np.zeros(n_partitions)
    filled = counts > 0
    # |acc/n - conf/n| == |acc - conf| / n for each populated bin.
    CEs[filled] = np.abs(acc_sum[filled] - conf_sum[filled]) / counts[filled]

    ECE = np.sum(counts * CEs) / len(y_te)
    MCE = np.max(CEs)
    return ECE, MCE

def _evaluate(net, criterion, testloader):
    """Run ``net`` over ``testloader`` without gradients and collect predictions.

    Leaves ``net`` in eval mode; the caller must switch back to training mode.

    Returns:
        (mean_loss, accuracy, probs, preds, targets): ``mean_loss`` is the
        criterion averaged over test batches, ``accuracy`` the fraction of
        correct top-1 predictions, ``probs`` the stacked softmax outputs
        (one row per test sample), and ``preds``/``targets`` the predicted
        and true labels.

    Raises:
        ValueError: if ``testloader`` yields no batches.
    """
    net.eval()
    total_loss = 0.0
    n_batches = 0
    prob_chunks, pred_chunks, target_chunks = [], [], []
    with torch.no_grad():
        for images, labels in tqdm(testloader):
            images = images.cuda()
            labels = labels.cuda()
            outputs = net(images)
            total_loss += criterion(outputs, labels).mean().item()
            n_batches += 1
            probs = F.softmax(outputs, dim=1)  # computed once, reused below
            prob_chunks.append(probs.cpu().numpy())
            pred_chunks.append(probs.argmax(dim=1).cpu().numpy())
            target_chunks.append(labels.cpu().numpy())
    if n_batches == 0:
        raise ValueError('test loader yielded no batches')
    # Collect per-batch arrays and concatenate once (np.append in a loop is
    # quadratic). float64 keeps the precision the metrics were computed in.
    probs = np.concatenate(prob_chunks).astype(np.float64)
    preds = np.concatenate(pred_chunks)
    targets = np.concatenate(target_chunks)
    accuracy = float((preds == targets).mean())
    return total_loss / n_batches, accuracy, probs, preds, targets


def _plot_reliability_diagram(probs, preds, targets, dr=0.025):
    """Draw a reliability diagram (accuracy vs. confidence) with pyplot.

    Samples are binned by top-class probability into bins of width ``dr``;
    a probability of exactly 1.0 goes into the last bin. For each non-empty
    bin a bar of height equal to the bin accuracy is drawn, plus a gap bar
    from the accuracy to the bin centre: yellow where the centre exceeds the
    accuracy (over-confidence), green where it is below (under-confidence).
    """
    confidence = np.max(probs, axis=1)
    n_bins = int(round(1 / dr))
    bins = np.minimum((confidence * n_bins).astype(int), n_bins - 1)
    counts = np.bincount(bins, minlength=n_bins)
    hits = np.bincount(bins, weights=(preds == targets).astype(float), minlength=n_bins)
    # Select on population, not accuracy, so bins with 0 accuracy still show.
    filled = counts > 0
    centers = ((np.arange(n_bins) + 0.5) / n_bins)[filled]
    accuracy = hits[filled] / counts[filled]
    gap = centers - accuracy

    plt.plot([0, 1], [0, 1], color='black', linewidth=2, linestyle='dashed',
             label='Perfect calibration')
    plt.bar(centers, accuracy, width=dr, edgecolor="black", label='Accuracy')
    for mask, color, label in ((gap > 0, 'yellow', 'Over-confidence'),
                               (gap < 0, 'green', 'Under-confidence')):
        # Skip empty selections: an empty bar container gives no legend patch.
        if mask.any():
            plt.bar(centers[mask], gap[mask], bottom=accuracy[mask], width=dr,
                    edgecolor="black", color=color, alpha=0.5, label=label)
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.xlabel('Confidence', fontsize=15)
    plt.ylabel('Accuracy', fontsize=15)
    plt.legend(fontsize=13, frameon=0)


def main(args):
    """Fine-tune a pretrained timm model on CIFAR-10 with Opacus and report calibration.

    Training uses gradient accumulation: ``args.mini_batch_size`` samples per
    forward pass, with an optimizer step every
    ``args.batch_size // args.mini_batch_size`` mini-batches. The test set is
    evaluated every two logical steps and at the end of each epoch.

    Per-evaluation train/test loss, test accuracy and (ECE, MCE) are saved to
    ``flat_R{R}_lr{lr}.pt``. A reliability diagram for the final evaluation
    is then shown.
    """
    transform = transforms.Compose(
        [transforms.Resize(224),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.mini_batch_size,
                                              shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=50,
                                             shuffle=False, num_workers=2)

    net = timm.create_model(args.model, pretrained=True, num_classes=10).cuda()

    # Freeze the class token and positional embeddings.
    for name, param in net.named_parameters():
        if 'cls_token' in name or 'pos_embed' in name:
            param.requires_grad = False

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=args.lr)

    # NOTE: pre-1.0 Opacus API (PrivacyEngine(module, ...), attach, virtual_step).
    privacy_engine = PrivacyEngine(
        net,
        sample_rate=args.batch_size / len(trainset),
        epochs=args.epochs,
        max_grad_norm=args.R,
        target_epsilon=args.epsilon,
        target_delta=1e-5,
    )
    privacy_engine.attach(optimizer)

    # Number of mini-batches per logical batch; at least 1 so the modulo
    # tests below cannot divide by zero when batch_size < mini_batch_size.
    grad_acc_steps = max(1, args.batch_size // args.mini_batch_size)
    eval_every = 2 * grad_acc_steps  # evaluate every two logical steps

    train_loss = []
    test_loss = []
    test_acc = []
    CalibrationError = []
    last_eval = None
    for epoch in range(args.epochs):
        net.train()
        for step, (inputs, labels) in enumerate(tqdm(trainloader)):
            inputs = inputs.cuda()
            labels = labels.cuda()

            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            is_last = (step + 1) == len(trainloader)
            if (step + 1) % grad_acc_steps == 0 or is_last:
                optimizer.step()
                optimizer.zero_grad()
                train_loss.append(loss.mean().item())
            else:
                # Accumulate clipped per-sample grads without updating weights.
                optimizer.virtual_step()

            if (step + 1) % eval_every == 0 or is_last:
                last_eval = _evaluate(net, criterion, testloader)
                mean_loss, accuracy, probs, preds, targets = last_eval
                # Append all three together so the metric lists stay aligned.
                CalibrationError.append(ECEMCE(probs, preds, targets))
                test_loss.append(mean_loss)
                test_acc.append(accuracy)
                net.train()

        print(f'R={args.R}, lr={args.lr}, loss={test_loss[-1]}, {CalibrationError[-1]},'
              f'Accuracy of the network on the {len(testset)} test images: '
              f'{100 * test_acc[-1]} %')

    torch.save([train_loss, test_loss, test_acc, CalibrationError],
               f'flat_R{args.R}_lr{args.lr}.pt')

    if last_eval is None:  # e.g. epochs == 0: nothing to plot
        return
    _, _, probs, preds, targets = last_eval
    _plot_reliability_diagram(probs, preds, targets)
    #plt.savefig(args.save,format='pdf')
    plt.show()

def build_parser():
    """Return the command-line argument parser for this training script."""
    parser = argparse.ArgumentParser(description='PyTorch CIFAR Training')
    parser.add_argument('--lr', default=0.0015, type=float, help='learning rate')
    parser.add_argument('--epochs', default=1, type=int,
                        help='number of epochs')
    parser.add_argument('--batch_size', default=1000, type=int, help='logical batch size')
    parser.add_argument('--mini_batch_size', default=100, type=int,
                        help='physical batch size per forward pass; gradients are '
                             'accumulated up to --batch_size')
    parser.add_argument('--epsilon', default=2, type=float, help='target epsilon')
    # Required: it is passed to PrivacyEngine as max_grad_norm and has no sane default.
    parser.add_argument('--R', type=float, required=True,
                        help='per-sample gradient clipping norm (max_grad_norm)')
    parser.add_argument('--model', default='vit_base_patch16_224', type=str,
                        help='timm model name')
    # NOTE(review): --clip_function is not read by main(); --save is only used by
    # a commented-out plt.savefig call.
    parser.add_argument('--clip_function', default='vanilla', type=str, help='per-sample clipping')
    parser.add_argument('--save', default='temp1.pdf', type=str)
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()
    torch.manual_seed(1)
    main(args)