from genericpath import exists
import itertools
from select import select
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
import os
from tqdm import tqdm
from random import random, seed, shuffle
WB = 20  # number of binary weights per candidate model (model() reads w[0]..w[19])
WN = 2 ** WB  # size of the exhaustive weight search space
import argparse
parser = argparse.ArgumentParser()
# -p and -n are required: without them the slice arithmetic below would fail
# with an opaque TypeError on None.
parser.add_argument('-p', type=int, required=True, help='number of process')
parser.add_argument('-n', type=int, required=True, help='id of process')
parser.add_argument('-f', type=str, help='storage')  # NOTE(review): currently unused
args = parser.parse_args()
if args.p <= 0 or not 0 <= args.n < args.p:
    parser.error('-n must satisfy 0 <= n < p (with p > 0)')
print(args.n)
print(args.p)
# Integer arithmetic keeps the bounds exact, so the slices [A, B) of processes
# 0..p-1 tile [0, WN) with no gaps or overlaps.
A = args.n * WN // args.p
B = (args.n + 1) * WN // args.p
N = B - A  # actual size of this process's slice (used as the progress-bar total)
def matrix2bitstring(matrix: torch.Tensor):
    """Encode a 3x3 binary pattern as an integer code in [0, 512).

    Cell ``(row, col)`` contributes bit ``3 * row + col``, so ``matrix[0][0]``
    is the least-significant bit and ``matrix[2][2]`` the most significant.
    When ``matrix`` is a tensor the result is a 0-d tensor, not a Python int.
    """
    code = 0
    for row in range(3):
        for col in range(3):
            code = code + matrix[row][col] * (1 << (3 * row + col))
    return code
def bitstring2matrix(bitstring: int) -> torch.Tensor:
    """Decode an integer pattern code into a 3x3 float tensor of 0s and 1s.

    Bit ``3 * row + col`` of ``bitstring`` becomes ``matrix[row][col]``
    (bit 0 -> ``[0][0]``, bit 8 -> ``[2][2]``), mirroring ``matrix2bitstring``.

    Raises:
        ValueError: if ``bitstring`` is outside ``[0, 512)``. Such codes cannot
            describe a 3x3 binary pattern; decoding them would silently yield
            a wrong matrix.
    """
    if not 0 <= bitstring < 512:
        raise ValueError('bitstring must be in [0, 512), got {}'.format(bitstring))
    matrix = torch.zeros((3, 3))
    for bit in range(9):
        matrix[bit // 3][bit % 3] = (bitstring >> bit) & 1
    return matrix
    


def _load_split(root: str, split: str, label: list):
    """Concatenate ``<root>/<split>/samples_<name>.pt`` for every ``name`` in ``label``.

    Returns ``(data, targets)`` as long tensors, where ``targets[i]`` is the
    position within ``label`` of the file that sample ``i`` came from.
    """
    chunks, targets = [], []
    for position, name in enumerate(label):
        samples = torch.load(os.path.join(root, split, 'samples_{}.pt'.format(name)))
        chunks.append(samples)
        targets.append(torch.full((samples.shape[0],), position, dtype=torch.long))
    return torch.cat(chunks, dim=0).long(), torch.cat(targets, dim=0)


def _majority_patterns(data: torch.Tensor, targets: torch.Tensor, n_classes: int):
    """Reduce samples to one (pattern, class) pair per 3x3 pattern with a unique top class.

    Each sample (expected to hold exactly 9 cells, e.g. shape ``(3, 3)``) is
    encoded as a 9-bit code, cell ``(r, c)`` -> bit ``3*r + c``. Per-class counts
    are tallied for all 512 codes. A code is dropped if it never occurs, or if
    its highest count is shared by two or more classes. Survivors are returned
    in increasing code order as ``(patterns, classes)`` with shapes
    ``(k, 3, 3)`` and ``(k,)``.

    Raises:
        ValueError: if any cell is not 0 or 1 (such values would corrupt the codes).
    """
    if bool(data.lt(0).any()) or bool(data.gt(1).any()):
        raise ValueError('samples must contain only 0/1 values')
    bit_values = torch.tensor([1 << k for k in range(9)], dtype=torch.long)
    # reshape(n, 9) also rejects samples that do not have exactly 9 cells.
    codes = (data.reshape(data.shape[0], 9) * bit_values).sum(dim=1)
    counts = torch.bincount(codes * n_classes + targets,
                            minlength=512 * n_classes).reshape(512, n_classes)
    top = counts.max(dim=1).values
    n_at_top = counts.eq(top.unsqueeze(1)).sum(dim=1)
    # top > 0 drops unseen codes; n_at_top == 1 drops ties between classes.
    keep = top.gt(0) & n_at_top.eq(1)
    kept_codes = keep.nonzero().squeeze(1)
    patterns = (kept_codes.unsqueeze(1) & bit_values).ne(0).long()
    patterns = patterns.reshape(kept_codes.shape[0], 3, 3)
    classes = counts.argmax(dim=1)[keep]
    return patterns, classes.long()


def load(label: list, root: str = './data'):
    """Load the train/test splits for the given classes and reduce them to unique 3x3 patterns.

    Args:
        label: sample-file ids. ``label[i]`` selects ``samples_<label[i]>.pt`` in
            each split and becomes class ``i`` in the returned labels. Every
            entry of ``label`` is used.
        root: directory containing the ``train/`` and ``test/`` sub-directories.

    Returns:
        ``(train_data, train_label, test_data, test_label)`` as long tensors.
        Data has shape ``(k, 3, 3)`` and labels shape ``(k,)``. Each distinct
        pattern appears once, labelled with its unique most frequent class.
        Patterns whose top count is tied between classes are dropped. The two
        splits are reduced independently.
    """
    n_classes = len(label)
    train_x, train_y = _load_split(root, 'train', label)
    test_x, test_y = _load_split(root, 'test', label)
    train_data_output, train_label_output = _majority_patterns(train_x, train_y, n_classes)
    test_data_output, test_label_output = _majority_patterns(test_x, test_y, n_classes)
    return train_data_output, train_label_output, test_data_output, test_label_output


        
train_data, train_label, test_data, test_label = load([1, 2, 7])

# Sanity report, one value per line: the number of reduced training patterns
# in each class 0, 1, 2, then the train and test label-vector shapes.
print(
    *(train_label.eq(cls).long().sum() for cls in range(3)),
    train_label.shape,
    test_label.shape,
    sep='\n',
)


def model(data: torch.Tensor, w: list):
    """Evaluate the two thresholded binary units on a 3x3 input.

    ``m1`` dots the input, read row by row, with ``w[0:9]``. ``m2`` dots the
    input read column by column (its transpose) with ``w[10:19]``. Each unit
    outputs 1 when its dot product is at least 1, and that bit is then XOR-ed
    with the unit's flip weight (``w[9]`` for ``m1``, ``w[19]`` for ``m2``).

    Returns:
        ``(m1, m2)`` as Python ints.
    """
    row_major = sum(data[r][c] * w[3 * r + c] for r in range(3) for c in range(3))
    col_major = sum(data[r][c] * w[10 + 3 * c + r] for c in range(3) for r in range(3))
    m1 = int(row_major >= 1) ^ w[9]
    m2 = int(col_major >= 1) ^ w[19]
    return m1, m2

def _weight_bits(x):
    """Return the WB-bit binary expansion of ``x`` as a list of 0/1 ints, most-significant bit first."""
    return [int(ch == '1') for ch in bin(x)[2:].rjust(WB, '0')]


def _is_correct(label, o1, o2):
    """Return True when the unit outputs ``(o1, o2)`` encode class ``label``.

    The two units act as a decision list: class 0 is ``o1 == 1``. Otherwise
    (``o1 == 0``), ``o2 == 1`` means class 1 and ``o2 == 0`` means class 2.
    """
    if label == 0:
        return o1 == 1
    if label == 1:
        return o1 == 0 and o2 == 1
    if label == 2:
        return o1 == 0 and o2 == 0
    return False


def _count_correct(data, labels, w):
    """Count the samples in ``data`` that ``model`` with weights ``w`` classifies as ``labels``."""
    return sum(1 for sample, label in zip(data, labels)
               if _is_correct(label, *model(sample, w)))


def calculate():
    """Exhaustively score every weight vector in this process's slice ``[A, B)``.

    The weight vector for index ``x`` is the WB-bit binary expansion of ``x``
    (MSB first). It is scored by how many training and test patterns
    ``model`` classifies correctly. Results are written to ``./accuracy/2153/``:

    * ``temp_train{n}_{p}.pt`` / ``temp_test{n}_{p}.pt``: length-``WN`` float
      tensors of correct counts (entries outside ``[A, B)`` stay 0);
    * ``temp_{n}_{p}.log``: best train count, best test count, and the test
      count of the best-on-train weight vector, one per line.

    Only process 0 shows a progress bar.

    Returns:
        ``(train_accuracy, test_accuracy)``.
    """
    # torch.save and open() do not create missing directories.
    os.makedirs('./accuracy/2153', exist_ok=True)
    train_accuracy = torch.zeros((WN))
    test_accuracy = torch.zeros((WN))
    # Decode weight vectors on the fly: each process only visits its own
    # slice, so materialising all WN of them would waste memory and time.
    indices = range(A, B)
    if args.n == 0:
        indices = tqdm(indices, total=N)
    for x in indices:
        w = _weight_bits(x)
        train_accuracy[x] = _count_correct(train_data, train_label, w)
        test_accuracy[x] = _count_correct(test_data, test_label, w)
    torch.save(train_accuracy, './accuracy/2153/temp_train{}_{}.pt'.format(args.n, args.p))
    torch.save(test_accuracy, './accuracy/2153/temp_test{}_{}.pt'.format(args.n, args.p))
    best = train_accuracy.argmax()
    with open('./accuracy/2153/temp_{}_{}.log'.format(args.n, args.p), 'w') as f:
        f.write(str(train_accuracy.max().item()))
        f.write('\n')
        f.write(str(test_accuracy.max().item()))
        f.write('\n')
        f.write(str(test_accuracy[best].item()))
    return train_accuracy, test_accuracy


calculate()

