import numpy as np
import math
import torch
from torch import nn
from torch.nn.functional import cross_entropy
from keras_lightning import KLModel, SparseCategoricalAccuracy
import torch_wnn as wnn
import openml
from sklearn.model_selection import train_test_split
import pickle

from difflogic import LogicLayer

from unoptimized.modules.linear import UnoptimizedLinear
from deepshift.modules import LinearShift
from deepshift.modules_q import LinearShiftQ
from deepshift.convert import convert_to_shift, round_shift_weights, count_layer_type

import sys
import random
import signal
from brevitas.nn import QuantIdentity, QuantLinear
from common import CommonWeightQuant, CommonActQuant
from tensor_norm import TensorNorm

import itertools

# Number of repeated training runs performed per configuration.
N = 10
# Per-run accuracy scores, one float slot per repetition (initially all zero).
acc = np.zeros((N,), dtype=np.float64)

class FC_BNN(nn.Module):
    """Fully connected network with 1-bit weights and 1-bit activations.

    Inputs are flattened, min-max scaled to [-1, 1] with the stored per-feature
    ``norm_min`` / ``norm_max`` statistics, quantized to 1 bit, and passed
    through a stack of hidden blocks. Each block is a 1-bit ``QuantLinear``,
    ``BatchNorm1d``, a 1-bit ``QuantIdentity`` activation and ``Dropout``. A
    final 1-bit ``QuantLinear`` produces ``num_classes`` outputs, followed by
    ``TensorNorm``.

    Args:
        num_inputs: number of input features after flattening.
        num_classes: number of network outputs.
        intermediate_layers: widths of the hidden layers, in order.
        norm_min: per-feature minimum tensor used for input scaling.
        norm_max: per-feature maximum tensor used for input scaling.
        dropout: dropout probability applied after the input quantizer and
            after every hidden activation.
    """

    def __init__(self, num_inputs, num_classes, intermediate_layers, norm_min, norm_max, dropout=0.2):
        super().__init__()

        # Non-trainable scaling statistics, kept in float32.
        self.norm_min = nn.Parameter(norm_min.type(torch.float32), requires_grad=False)
        self.norm_max = nn.Parameter(norm_max.type(torch.float32), requires_grad=False)

        self.features = nn.Sequential()
        self.features.append(QuantIdentity(act_quant=CommonActQuant, bit_width=1))
        self.features.append(nn.Dropout(p=dropout))

        layer_inps = num_inputs
        for layer_outps in intermediate_layers:
            # Print the (in, out) size of every linear layer as it is built.
            print(layer_inps, layer_outps)
            self.features.append(QuantLinear(
                in_features=layer_inps,
                out_features=layer_outps,
                bias=False,
                weight_bit_width=1,
                weight_quant=CommonWeightQuant))
            self.features.append(nn.BatchNorm1d(num_features=layer_outps))
            self.features.append(QuantIdentity(act_quant=CommonActQuant, bit_width=1))
            self.features.append(nn.Dropout(p=dropout))
            layer_inps = layer_outps
        print(layer_inps, num_classes)
        self.features.append(QuantLinear(
            in_features=layer_inps,
            out_features=num_classes,
            bias=False,
            weight_bit_width=1,
            weight_quant=CommonWeightQuant))
        self.features.append(TensorNorm())

        # Start every quantized linear layer with latent weights drawn
        # uniformly from [-1, 1].
        for m in self.modules():
            if isinstance(m, QuantLinear):
                torch.nn.init.uniform_(m.weight.data, -1, 1)

    def clip_weights(self, min_val, max_val):
        """Clamp, in place, the weights of every QuantLinear in ``features``."""
        for mod in self.features:
            if isinstance(mod, QuantLinear):
                mod.weight.data.clamp_(min_val, max_val)

    def forward(self, x):
        """Scale ``x`` to [-1, 1] per feature and run it through ``features``."""
        # Flatten to (batch, features). reshape also handles non-contiguous
        # input, where view would raise. Cast to the dtype of the stored
        # statistics (float32) so float64 inputs, e.g. coming from NumPy, do
        # not get promoted and then mismatch the layers' default-float32
        # parameters.
        x = x.reshape(x.shape[0], -1).to(self.norm_min.dtype)
        value_range = self.norm_max - self.norm_min
        # A constant feature (max == min) would divide by zero and feed
        # inf/NaN into the network; use a unit range for such features.
        value_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)
        x = (x - self.norm_min) / value_range
        x = (2 * x) - 1
        # Call the module itself rather than .forward() so registered hooks run.
        return self.features(x)

# Load the satimage training data as NumPy arrays.
x_train = np.load('satimage_x_train.npy')
y_train = np.load('satimage_y_train.npy')

# Hold out 10% of the training data for validation.
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, train_size=0.9)

# These are NumPy arrays, so use .shape rather than the torch-style .size(dim).
input_length = x_train.shape[1]
num_output = int(y_train.max() + 1)

# Weight budget for the binarized model, in bits (9.5 KiB).
model_size = 9.5 * 8 * 1024
learning_rate = 1e-2

dropouts = (0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8)

acc_max = 0
std_max = 0
# Best configuration found so far: (dropout, hidden width, learning rate).
config = 0, 0, 0

# Hidden width solving h * (input_length + num_output + 1) + num_output
# = model_size. model_size is already in bits, so it must not be scaled
# by 8 * 1024 a second time. The result does not depend on dropout, so it
# is computed once.
h = int((model_size - num_output) / (input_length + num_output + 1))

# Per-feature scaling statistics from the training split. FC_BNN expects
# torch tensors here, and NumPy arrays have no .amin(dim=...) method.
norm_min = torch.as_tensor(x_train.min(axis=0))
norm_max = torch.as_tensor(x_train.max(axis=0))

for dropout in dropouts:

    # Train N independent models per dropout value; every step of a run
    # belongs inside this loop so each repetition is trained and scored.
    for i in range(N):

        model = KLModel(nn.Sequential(
            FC_BNN(input_length, num_output, [h], norm_min, norm_max, dropout=dropout),
        ))

        optimizer = torch.optim.Adam(model.parameters(), learning_rate)

        model.compile(
            optimizer=optimizer,
            scheduler=torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=10),
            loss=cross_entropy,
            metrics={'acc': SparseCategoricalAccuracy()}
        )

        # NOTE(review): the original passed `x_val=` alongside `y_test=`.
        # `x_test=` is assumed to be the matching keyword, given the
        # 'test_acc' result key; confirm against KLModel.fit.
        results = model.fit(
            x_train, y_train,
            x_test=x_val, y_test=y_val,
            epochs=30,
            batch_size=32
        )

        acc[i] = results['test_acc']

    mean_acc = acc.mean()
    std_acc = acc.std()
    print(f'\n  Dropout: {dropout} - Mean: {mean_acc}, Std: {std_acc} \n')

    # Keep the configuration with the highest mean validation accuracy.
    if mean_acc > acc_max:
        acc_max = mean_acc
        std_max = std_acc
        config = dropout, h, learning_rate

print(f'\n  Best Config - Mean: {acc_max}, Std: {std_max}, Config: {config} \n')
###################################################################################################


    


