from genericpath import exists
import itertools
from select import select
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
import os
from tqdm import tqdm
from random import random, sample, seed, shuffle
from matplotlib import cm
from matplotlib.collections import LineCollection
import matplotlib.pyplot as plt
# Accuracy values are stored in NUM_SHARDS files per split. From shard i only the
# contiguous slice [i * WN // NUM_SHARDS, (i + 1) * WN // NUM_SHARDS) is kept, and the
# kept slices are concatenated. Presumably each worker filled only its own slice of a
# WN-long buffer -- TODO confirm against the script that writes these files.
WN = 2**20
NUM_SHARDS = 32
ACCURACY_DIR = './accuracy/2153'


def _load_shard_slices(kind):
    """Load every shard file for one split and concatenate the slices each one owns.

    For shard ``i`` the file ``<ACCURACY_DIR>/temp_<kind><i>_<NUM_SHARDS>.pt`` is
    loaded onto the CPU. Only elements ``i * WN // NUM_SHARDS`` up to (but excluding)
    ``(i + 1) * WN // NUM_SHARDS`` are kept from it.

    Args:
        kind: File-name prefix selecting the split, ``'train'`` or ``'test'``.

    Returns:
        One tensor holding the kept slices concatenated along dim 0.
    """
    pieces = []
    for shard in range(NUM_SHARDS):
        path = '{}/temp_{}{}_{}.pt'.format(ACCURACY_DIR, kind, shard, NUM_SHARDS)
        # Load onto the CPU so the script works without the device the tensors
        # may have been saved from; the values are later handed to matplotlib,
        # which needs host memory anyway.
        loaded = torch.load(path, map_location='cpu')
        # Integer arithmetic gives exact, gap-free shard boundaries.
        start = shard * WN // NUM_SHARDS
        stop = (shard + 1) * WN // NUM_SHARDS
        # clone() detaches the slice from the full loaded buffer, so that buffer
        # can be freed now instead of staying alive until torch.cat.
        pieces.append(loaded[start:stop].clone())
    return torch.cat(pieces, dim=0)


train_accuracy = _load_shard_slices('train')
test_accuracy = _load_shard_slices('test')

# Report peak and mean accuracy for each split (train first, then test) ...
for reduction in (torch.Tensor.max, torch.Tensor.mean):
    print(reduction(train_accuracy))
    print(reduction(test_accuracy))

# ... then how many training entries reach each accuracy threshold.
for threshold in (67, 64, 60):
    print(train_accuracy.ge(threshold).long().sum())
# Bar chart of training accuracy: sort all WN entries in descending order, plot
# every 32nd one, and colour each bar by its value.
sampler = list(range(0, WN, 32))
sorted_train_accuracy = train_accuracy.sort(descending=True)
# Take the subsample once; move it to the CPU because matplotlib works on host arrays.
sampled = sorted_train_accuracy.values[sampler].cpu().numpy()

norm = plt.Normalize(sampled.min(), sampled.max())
# plt.get_cmap replaces matplotlib.cm.get_cmap (deprecated in 3.7, removed in 3.9).
map_vir = plt.get_cmap('OrRd')
color = map_vir(norm(sampled))

plt.title('MNIST Accuracy Distribution')
plt.xlabel('Weight Sorted by Accuracy')
plt.ylabel('Accuracy')
# NOTE(review): the divisor 79 presumably turns raw values into a fraction -- confirm.
plt.bar(list(range(len(sampler))), sampled / 79, color=color)
# savefig does not create missing directories.
os.makedirs('./img', exist_ok=True)
plt.savefig('./img/mnist_accuracy_distribution3.pdf')
plt.clf()