import numpy as np
import math
import torch
from torch import nn
from torch.nn.functional import cross_entropy
from keras_lightning import KLModel, SparseCategoricalAccuracy
import torch_wnn as wnn
import openml
from sklearn.model_selection import train_test_split
import pickle

from difflogic import LogicLayer

from unoptimized.modules.linear import UnoptimizedLinear
from deepshift.modules import LinearShift
from deepshift.modules_q import LinearShiftQ
from deepshift.convert import convert_to_shift, round_shift_weights, count_layer_type

import sys
import random
import signal
from brevitas.nn import QuantIdentity, QuantLinear
from common import CommonWeightQuant, CommonActQuant
from tensor_norm import TensorNorm

import itertools

# Number of independent training runs per hyper-parameter configuration.
# `acc` is a reusable buffer with one slot per run; the search loop fills
# acc[0..N-1] for each configuration before summarizing it.
N = 10
acc = np.zeros(shape=(N,))

class Temperature(nn.Module):
    """Divide the input by a fixed temperature ``tau``.

    The module holds no learnable parameters; ``tau`` is stored as a plain
    attribute.

    Args:
        tau: divisor applied to every input. The default of 1 leaves the
            input values unchanged.
    """

    def __init__(self, tau=1):
        super(Temperature, self).__init__()
        self.tau = tau

    def forward(self, x):
        """Return ``x / self.tau``."""
        scaled = x / self.tau
        return scaled

# --- Data loading and binarization ------------------------------------------
x_train = np.load('mnist_x_train.npy')
y_train = np.load('mnist_y_train.npy')

# Hold out 10% of the training data for validation.
# NOTE(review): no random_state is passed, so the split (and therefore the
# search result) differs between runs.
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, train_size=0.9)

# Binarize the inputs with `num_bits` thresholds fitted on the training split
# only, then apply the same thresholds to the validation split.
num_bits = 24
thresholds = wnn.binarization.distributive_thresholds(x_train, num_bits, individual=True)
x_train_b = wnn.binarization.apply_thresholds(x_train, thresholds).flatten(start_dim=1)
x_val_b = wnn.binarization.apply_thresholds(x_val, thresholds).flatten(start_dim=1)

input_lenght = x_train_b.size(1)      # binary features per sample after flattening
num_output = int(y_train.max() + 1)   # assumes labels are 0..C-1 — TODO confirm

# Size budget from which `num_models` is derived for each configuration.
# NOTE(review): presumably a size in bits (98 KiB * 8) — confirm.
model_size = 98 * 8 * 1024

# --- Search space -------------------------------------------------------------
tuple_lenghts = range(2, 28)
num_keyss = range(1, 5)
filter_tuple_lenghts = range(2, 9)
dropouts = (0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8)

# Best configuration seen so far, ranked by mean accuracy over N runs.
# acc_max starts at -inf so the first evaluated configuration is always recorded.
acc_max = -math.inf
std_max = 0
config = None  # (tuple_lenght, num_keys, filter_tuple_lenght, dropout)

for tuple_lenght, num_keys, filter_tuple_lenght, dropout in itertools.product(
        tuple_lenghts, num_keyss, filter_tuple_lenghts, dropouts):

    # Skip configurations whose filter tuple is longer than the tuple itself.
    if filter_tuple_lenght > tuple_lenght:
        continue

    # Fit as many models into the budget as possible; each one is charged
    # ceil(input_lenght / tuple_lenght) * 2**filter_tuple_lenght * num_output.
    # NOTE(review): num_keys does not enter this cost — confirm that is intended.
    num_models = math.floor(
        model_size
        / (math.ceil(input_lenght / tuple_lenght) * 2**filter_tuple_lenght * num_output)
    )
    if num_models == 0:
        continue  # budget too small for even a single model

    # Train the same configuration N times and collect each run's accuracy.
    for i in range(N):
        model = KLModel(nn.Sequential(
            wnn.MultiDoMaFilter2(
                input_lenght=input_lenght,
                tuple_lenght=tuple_lenght,
                num_keys=num_keys,
                filter_tuple_lenght=filter_tuple_lenght,
                num_output=num_output,
                num_models=num_models,
                dropout=dropout,
            )))

        optimizer = torch.optim.Adam(model.parameters(), 1e-2)
        # LR is multiplied by 0.1 every 10 scheduler steps (presumably one step
        # per epoch inside KLModel — confirm).
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=10)

        model.compile(
            optimizer=optimizer,
            scheduler=scheduler,
            loss=cross_entropy,
            metrics={'acc': SparseCategoricalAccuracy()},
        )

        # NOTE(review): the evaluation data is passed as x_val=/y_test=
        # (mismatched keyword names) and read back as 'test_acc'; verify this
        # pair against KLModel.fit's signature.
        results = model.fit(
            x_train_b, y_train,
            x_val=x_val_b, y_test=y_val,
            epochs=30,
            batch_size=32,
        )

        acc[i] = results['test_acc']

    # Keep the configuration with the highest mean accuracy over the N runs.
    acc_mean, acc_std = acc.mean(), acc.std()
    if acc_mean > acc_max:
        acc_max, std_max = acc_mean, acc_std
        config = (tuple_lenght, num_keys, filter_tuple_lenght, dropout)

print(f'\n  Best Config - Mean: {acc_max}, Std: {std_max}, Config: {config} \n')
###################################################################################################


    


