import torch
import numpy as np
import pickle
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression

# Name -> in-place initializer from torch.nn.init. The MLP builder looks up its
# `we_init` string here and applies the function to each Linear layer's weight.
init_dict = dict(
    kaiming_uniform=torch.nn.init.kaiming_uniform_,
    kaiming_normal=torch.nn.init.kaiming_normal_,
    xavier_uniform=torch.nn.init.xavier_uniform_,
    xavier_normal=torch.nn.init.xavier_normal_,
    uniform=torch.nn.init.uniform_,
    normal=torch.nn.init.normal_,
)

class Feedforward(torch.nn.Module):
    """Fully connected ReLU network with a fixed 2-unit linear output layer.

    Args:
        input_size: Number of input features (size of the last dimension of ``x``).
        hidden_arr: Widths of the hidden ``Linear`` layers, applied in order and
            each followed by ReLU. ``None`` (the default) or an empty sequence
            gives a single ``Linear(input_size, 2)`` layer.
        we_init: Optional key into ``init_dict`` naming an initializer that is
            applied to every layer's weight (biases keep PyTorch's default
            init). ``None`` keeps PyTorch's default initialization.

    Raises:
        KeyError: if ``we_init`` is not ``None`` and not a key of ``init_dict``.
    """

    def __init__(self, input_size, hidden_arr=None, we_init=None):
        super().__init__()
        # None sentinel instead of a shared mutable `[]` default.
        if hidden_arr is None:
            hidden_arr = []
        self.input_size = input_size
        if we_init is not None:
            self.init_func = init_dict[we_init]
        self.relu = torch.nn.ReLU()

        # Build the hidden stack locally, then register it once as a ModuleList
        # so its parameters are tracked by the module.
        layers = []
        prev = self.input_size
        for width in hidden_arr:
            layer = torch.nn.Linear(prev, width)
            if we_init is not None:
                self.init_func(layer.weight)
            layers.append(layer)
            prev = width
        self.fca = torch.nn.ModuleList(layers)

        self.fcout = torch.nn.Linear(prev, 2)
        if we_init is not None:
            self.init_func(self.fcout.weight)

    def forward(self, x):
        """Apply each hidden layer followed by ReLU, then the output layer.

        The output layer has no activation, so raw scores are returned; the
        last dimension of the result has size 2.
        """
        for fcele in self.fca:
            x = self.relu(fcele(x))
        return self.fcout(x)

def get_mlp_model(input_size, hidden_layers, we_init, ckpt=None, cuda=True, dataloader=None):
    """Load or build a ``Feedforward`` MLP, optionally moving it to the GPU.

    Args:
        input_size: Number of input features, forwarded to ``Feedforward``.
        hidden_layers: Hidden-layer widths, forwarded to ``Feedforward``.
        we_init: Weight-initializer key from ``init_dict``, or ``None`` for
            PyTorch's default initialization. Ignored when ``ckpt`` is given.
        ckpt: Optional path to a whole model saved with ``torch.save``; when
            given, the model is loaded instead of constructed.
        cuda: If true, move the model to the default CUDA device with
            ``.cuda()``; otherwise a loaded checkpoint is mapped to the CPU.
        dataloader: Unused; kept for interface compatibility.

    Returns:
        The loaded or newly constructed model.

    Raises:
        ValueError: if ``ckpt`` is ``None`` and ``we_init`` is neither ``None``
            nor a key of ``init_dict``.
    """
    if ckpt is not None:
        # SECURITY: torch.load unpickles arbitrary objects; only load trusted
        # checkpoints. Without CUDA, remap storages to the CPU so checkpoints
        # saved on a GPU still load on a CPU-only host.
        # NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects
        # whole pickled modules; may need weights_only=False - confirm the
        # torch version in use.
        model = torch.load(ckpt, map_location=None if cuda else 'cpu')
    elif we_init is None or we_init in init_dict:
        # Feedforward treats we_init=None as "keep PyTorch's default init".
        model = Feedforward(input_size, hidden_layers, we_init)
    else:
        raise ValueError(
            f"unknown weight initializer {we_init!r}; "
            f"expected None or one of {sorted(init_dict)}"
        )

    if cuda:
        model = model.cuda()
    return model

def get_dtc_model(ckpt=None):
    """Load a pickled classifier from ``ckpt``, or create a fresh decision tree.

    Args:
        ckpt: Optional path to a pickle file. When given, its unpickled object
            is returned as-is.

    Returns:
        The unpickled object, or a new ``DecisionTreeClassifier`` seeded from
        ``np.random.randint(10000)`` (so reproducible only if NumPy's global
        RNG has been seeded by the caller).
    """
    if ckpt is not None:
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        # `with` ensures the file handle is closed instead of leaked.
        with open(ckpt, 'rb') as fh:
            return pickle.load(fh)
    return DecisionTreeClassifier(random_state=np.random.randint(10000))

def get_lr_model(ckpt=None):
    """Load a pickled classifier from ``ckpt``, or create a fresh logistic regression.

    Args:
        ckpt: Optional path to a pickle file. When given, its unpickled object
            is returned as-is.

    Returns:
        The unpickled object, or a new ``LogisticRegression`` using the
        ``'sag'`` solver and seeded from ``np.random.randint(10000)`` (so
        reproducible only if NumPy's global RNG has been seeded by the caller).
    """
    if ckpt is not None:
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        # `with` ensures the file handle is closed instead of leaked.
        with open(ckpt, 'rb') as fh:
            return pickle.load(fh)
    return LogisticRegression(random_state=np.random.randint(10000), solver='sag')

def get_model(arch_id, input_size, hidden_layers, we_init, ckpt=None, cuda=True, dataloader=None):
    """Dispatch to the model factory selected by ``arch_id``.

    ``arch_id`` is matched by substring, in this priority order: ``'mlp'``,
    then ``'dtc'``, then ``'lr'``. An id containing several tags resolves to
    the first match. Only the MLP path uses ``input_size``, ``hidden_layers``,
    ``we_init``, ``cuda`` and ``dataloader``; all paths forward ``ckpt``.

    Returns:
        Whatever the selected factory returns.

    Raises:
        ValueError: if ``arch_id`` contains none of the known tags.
    """
    if 'mlp' in arch_id:
        return get_mlp_model(input_size, hidden_layers, we_init, ckpt=ckpt, cuda=cuda, dataloader=dataloader)
    if 'dtc' in arch_id:
        return get_dtc_model(ckpt=ckpt)
    if 'lr' in arch_id:
        return get_lr_model(ckpt=ckpt)
    # Previously this fell through to an UnboundLocalError on `model`.
    raise ValueError(
        f"unknown architecture id {arch_id!r}; expected it to contain 'mlp', 'dtc' or 'lr'"
    )

def flatten_weights(model):
    """Concatenate every parameter of ``model`` into one flat Python list.

    Parameters are visited in the order ``model.parameters()`` yields them.
    Each is detached, copied to host memory and flattened. The list holds
    NumPy scalar elements, not Python floats.
    """
    return [
        value
        for param in model.parameters()
        for value in param.detach().cpu().numpy().ravel()
    ]
