from genericpath import exists
import itertools
from select import select
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
import os
from tqdm import tqdm
from random import random, seed, shuffle
WN = 2**20
WB = 20
def matrix2bitstring(matrix: torch.Tensor):
    """Encode a 3x3 grid as a single number, one binary place per cell.

    Cells are taken in row-major order with ``matrix[0][0]`` in the lowest
    place, so cell ``(r, c)`` is multiplied by ``2 ** (3 * r + c)``.  The
    result has whatever type the element arithmetic produces (a 0-dim tensor
    for tensor input, a plain number for nested lists).
    """
    code = matrix[0][0]
    place = 2
    # Walk the remaining eight cells in flattened order, doubling the place
    # value at every step.
    for flat in range(1, 9):
        code = code + place * matrix[flat // 3][flat % 3]
        place *= 2
    return code
def bitstring2matrix(bitstring : int) -> torch.Tensor:
    """Decode a 9-bit pattern code into a 3x3 tensor of zeros and ones.

    Bit ``3 * r + c`` of ``bitstring`` (bit 0 being the least significant)
    becomes cell ``(r, c)`` of the result.

    Args:
        bitstring: integer code in the range ``0 .. 511``.

    Returns:
        A ``(3, 3)`` tensor created by ``torch.zeros`` (default float dtype)
        holding 0.0 / 1.0 entries.

    Raises:
        ValueError: if ``bitstring`` is outside ``0 .. 511``.  Such values do
            not fit in nine bits and cannot be represented as a 3x3 grid.
    """
    if not 0 <= bitstring < 512:
        raise ValueError(
            'bitstring must be in the range 0..511, got {}'.format(bitstring))
    matrix = torch.zeros((3, 3))
    for position in range(9):
        # Row-major layout: the low three bits fill row 0, the next three
        # row 1, and the top three row 2.
        matrix[position // 3][position % 3] = (bitstring >> position) & 1
    return matrix
    


def _load_labelled_split(split: str, label: list):
    """Load one split's sample files and tag every sample with a 1-based class id.

    The samples for ``label[k]`` are read from
    ``./data/<split>/samples_<label[k]>.pt`` and receive class id ``k + 1``.
    Returns the concatenated samples (cast to long) and their class ids.
    """
    sample_parts = []
    class_parts = []
    for class_id, name in enumerate(label, start=1):
        # NOTE(review): torch.load unpickles the file -- only use it on
        # trusted data files.
        samples: torch.Tensor = torch.load('./data/{}/samples_{}.pt'.format(split, name))
        sample_parts.append(samples)
        class_parts.append(torch.full((samples.shape[0],), class_id, dtype=torch.long))
    return torch.cat(sample_parts, dim=0).long(), torch.cat(class_parts, dim=0)


def _reduce_to_majority(data: torch.Tensor, labels: torch.Tensor, num_classes: int):
    """Keep one entry per 3x3 pattern whose most frequent class is unambiguous.

    Returns the kept patterns as a ``(n, 3, 3)`` long tensor and their
    0-based class labels as an ``(n,)`` long tensor, ordered by pattern code.
    """
    # Row = 9-bit pattern code, column = 1-based class id.  Column 0 is never
    # incremented, so patterns that were not seen keep an all-zero row.
    votes = torch.zeros((512, num_classes + 1))
    for sample, class_id in zip(data, labels):
        votes[matrix2bitstring(sample)][class_id] += 1

    winner = votes.argmax(dim=1)
    # A row is ambiguous when more than one column reaches the row maximum.
    # This covers real ties between classes as well as unseen patterns (all
    # zeros); both are marked with the reserved label 0 and dropped below.
    top = votes.max(dim=1, keepdim=True).values
    tied = votes.ge(top).sum(dim=1) > 1
    winner[tied] = 0

    kept_codes = [code for code in range(winner.shape[0]) if winner[code] != 0]
    data_output = torch.zeros((len(kept_codes), 3, 3)).long()
    label_output = torch.zeros(len(kept_codes)).long()
    for row, code in enumerate(kept_codes):
        data_output[row] = bitstring2matrix(code)
        # Shift the 1-based class id back to a 0-based label.
        label_output[row] = winner[code] - 1
    return data_output, label_output


def load(label: list):
    """Build de-duplicated train and test sets for the given sample-file labels.

    For every entry of ``label`` the files ``./data/train/samples_<entry>.pt``
    and ``./data/test/samples_<entry>.pt`` are loaded; the position of the
    entry in ``label`` becomes its 0-based class.  Within each split, samples
    are grouped by their 3x3 pattern, and a pattern is kept only when a
    single class occurs strictly more often than every other class for it.

    Args:
        label: identifiers used to format the sample file names.  Every entry
            is used as one class.

    Returns:
        ``(train_data, train_label, test_data, test_label)`` where the data
        tensors have shape ``(n, 3, 3)`` and the label tensors shape ``(n,)``,
        all of dtype long.
    """
    num_classes = len(label)
    # Load both splits before reducing so a missing file is reported early.
    train_raw_data, train_raw_label = _load_labelled_split('train', label)
    test_raw_data, test_raw_label = _load_labelled_split('test', label)

    train_data_output, train_label_output = _reduce_to_majority(
        train_raw_data, train_raw_label, num_classes)
    test_data_output, test_label_output = _reduce_to_majority(
        test_raw_data, test_raw_label, num_classes)
    return train_data_output, train_label_output, test_data_output, test_label_output


        
# Build the reduced train/test sets from the sample files for labels 0, 2, 7
# and 8.  These module-level tensors are read directly by calculate().
train_data, train_label, test_data, test_label = load([0, 2, 7, 8])

# Shape of the training labels, i.e. how many training patterns were kept.
print(train_label.shape)

# Shape of the test labels, i.e. how many test patterns were kept.
print(test_label.shape)


def model(data : torch.Tensor, w : list):
    """Classify a 3x3 grid with a two-output logic model driven by 20 weights.

    Every cell is XOR-ed with its own weight.  The first output looks at the
    three rows (weights 0-8), the second at the three columns (weights 10-18):
    the three gated cells of each line are multiplied, the three products are
    summed, and the output is 1 when that sum is at least 1.  ``w[9]`` and
    ``w[19]`` are XOR-ed onto the respective outputs.

    Returns:
        ``first_output + 2 * second_output`` as a plain int.
    """
    def line_score(cells, first_weight):
        # Gate each of the nine cells with its weight, then add up the
        # products of consecutive triples.
        gated = [cell ^ w[first_weight + offset] for offset, cell in enumerate(cells)]
        return gated[0] * gated[1] * gated[2] + \
            gated[3] * gated[4] * gated[5] + \
            gated[6] * gated[7] * gated[8]

    by_rows = [data[r][c] for r in range(3) for c in range(3)]
    by_columns = [data[r][c] for c in range(3) for r in range(3)]

    row_output = int(line_score(by_rows, 0) >= 1) ^ w[9]
    column_output = int(line_score(by_columns, 10) >= 1) ^ w[19]
    return int(row_output + 2 * column_output)

def calculate():
    """Score every weight vector with index in ``[WN // 4, WN // 2)``.

    Each index ``x`` in that range is expanded into its ``WB``-bit binary
    representation (most significant bit first) and used as the weight list
    for ``model``.  The number of samples of the module-level train and test
    sets that ``model`` classifies correctly is stored at position ``x`` of
    the respective result tensor; positions outside the range stay 0.

    The results are saved to ``./accuracy/train_1.pt`` and
    ``./accuracy/test_1.pt`` and the maximum of each tensor is printed.

    Returns:
        ``(train_accuracy, test_accuracy)``: two float tensors of length
        ``WN`` holding correct-prediction counts.
    """
    start, stop = WN // 4, WN // 2

    # Create the output directory up front so that a missing folder cannot
    # make torch.save fail after the whole search has already been run.
    os.makedirs('./accuracy', exist_ok=True)

    train_accuracy = torch.zeros((WN))
    test_accuracy = torch.zeros((WN))

    for x in tqdm(range(start, stop), total=stop - start):
        # Build the weight bits on demand instead of materialising all WN
        # weight lists when only this sub-range is evaluated.
        w = [int(bit == '1') for bit in bin(x)[2:].rjust(WB, '0')]
        train_accuracy[x] = sum(
            1 for sample, target in zip(train_data, train_label)
            if model(sample, w) == target)
        test_accuracy[x] = sum(
            1 for sample, target in zip(test_data, test_label)
            if model(sample, w) == target)

    torch.save(train_accuracy,'./accuracy/train_1.pt')
    torch.save(test_accuracy,'./accuracy/test_1.pt')
    print(train_accuracy.max())
    print(test_accuracy.max())
    return train_accuracy, test_accuracy
# Run the weight evaluation as soon as this file is executed.
calculate()