import numpy as np
import math
import torch
from torch import nn
from torch.nn.functional import cross_entropy
from keras_lightning import KLModel, SparseCategoricalAccuracy
import torch_wnn as wnn
import openml
from sklearn.model_selection import train_test_split
import pickle

from difflogic import LogicLayer

from unoptimized.modules.linear import UnoptimizedLinear
from deepshift.modules import LinearShift
from deepshift.modules_q import LinearShiftQ
from deepshift.convert import convert_to_shift, round_shift_weights, count_layer_type

import sys
import random
import signal
from brevitas.nn import QuantIdentity, QuantLinear
from common import CommonWeightQuant, CommonActQuant
from tensor_norm import TensorNorm

import itertools

# Number of independent training runs whose final accuracy is averaged below.
N = 1000
acc = np.zeros(N)

# Load and split the raw data FIRST: the binarization thresholds below are
# fitted on the training split, so both splits must exist before that step.
x_train = np.load('mnist_x_train.npy')
y_train = np.load('mnist_y_train.npy')
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, train_size=0.9)

# The rest of the script uses tensor APIs on these arrays (`.size(1)`,
# `.flatten(start_dim=1)`), so convert them to torch tensors here.
# cross_entropy with class-index targets requires int64 labels.
x_train = torch.as_tensor(x_train, dtype=torch.float32)
x_val = torch.as_tensor(x_val, dtype=torch.float32)
y_train = torch.as_tensor(y_train, dtype=torch.long)
y_val = torch.as_tensor(y_val, dtype=torch.long)

# Binarize every input feature with `num_bits` thresholds fitted on the
# training split only. The same thresholds are reused for the validation split.
# NOTE(review): exact semantics of `individual=True` come from torch_wnn; verify there.
num_bits = 24
thresholds = wnn.binarization.distributive_thresholds(x_train, num_bits, individual=True)
# LogicLayer consumes floating-point inputs, so cast the binary codes explicitly.
x_train_b = wnn.binarization.apply_thresholds(x_train, thresholds).flatten(start_dim=1).float()
x_val_b = wnn.binarization.apply_thresholds(x_val, thresholds).flatten(start_dim=1).float()

input_length = x_train.size(1)
input_lenght = input_length  # backward-compatible alias for the original misspelled name
num_output = int(y_train.max() + 1)

# Presumably a 98 KiB model budget expressed in bits (98 * 8 bits * 1024) — confirm.
model_size = 98 * 8 * 1024

dropout = 0.

class GroupSum(nn.Module):
    """Sum the last dimension of the input in ``k`` equal-sized groups.

    The last dimension is zero-padded on the right up to a multiple of ``k``.
    It is then split into ``k`` contiguous groups, each group is summed, and
    the result is divided by ``tau``.

    Args:
        k: number of output groups (size of the new last dimension).
        tau: divisor applied to every group sum.

    Shape:
        input ``(*, n)`` -> output ``(*, k)`` for any number of leading dims.
    """

    def __init__(self, k, tau=1):
        super().__init__()
        self.k = k
        self.tau = tau

    def forward(self, x):
        remainder = x.size(-1) % self.k
        if remainder:
            # F.pad's pad tuple starts at the LAST dimension as (left, right),
            # so (0, r) pads only the right end of the last dim for any rank.
            x = torch.nn.functional.pad(x, (0, self.k - remainder))
        group_size = x.size(-1) // self.k
        # reshape (not view) so non-contiguous inputs are also accepted.
        return x.reshape(*x.shape[:-1], self.k, group_size).sum(dim=-1) / self.tau

# Width of the single logic layer. This presumably charges 4 bits per
# difflogic neuron (one of 16 two-input gates) against the `model_size` bit
# budget. `model_size` already includes the 8 * 1024 factor, so it must not be
# applied again (doing so gives ~1.6e9 neurons).
# NOTE(review): confirm the intended per-neuron cost.
h = model_size // 4

for i in range(N):
    # Build a fresh, randomly initialised model for every run, so that
    # acc[i] holds the result of an independent trial.
    model = KLModel(nn.Sequential(
        LogicLayer(x_train_b.size(1), h),
        GroupSum(k=num_output, tau=1/0.1)
    ))

    optimizer = torch.optim.Adam(model.parameters(), 1e-2)

    model.compile(
        optimizer=optimizer,
        scheduler=torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=10),
        loss=cross_entropy,
        metrics={'acc': SparseCategoricalAccuracy()}
    )

    # Train and evaluate on the binarized features the LogicLayer was sized for.
    # NOTE(review): `x_test`/`y_test` were chosen to match the existing `y_test`
    # keyword and the 'test_acc' result key. Verify against KLModel.fit's signature.
    results = model.fit(
        x_train_b, y_train,
        x_test=x_val_b, y_test=y_val,
        epochs=30,
        batch_size=32
    )

    # float() accepts Python/numpy scalars as well as 0-d (possibly CUDA) tensors.
    acc[i] = float(results['test_acc'])

print(f'\n Mean: {acc.mean()}, Std: {acc.std()} \n')
###################################################################################################


    


