from copy import deepcopy
from pickle import TRUE
import numpy as np
from sklearn.linear_model import LogisticRegression

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score, log_loss
from torch.autograd import Variable
from torch.nn.init import kaiming_uniform_
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

from dvutils.data import Custom_Dataset
from sklearn.kernel_ridge import KernelRidge

class Model_Train:
    """Train and evaluate a PyTorch model built from factory callables.

    The constructor seeds torch's RNGs with 0, instantiates the model,
    optimizer and loss from the supplied factories, and snapshots the initial
    weights so that ``restart_model`` can reset training from scratch.
    """

    def __init__(
        self,
        model_fn=None,
        optimizer_fn=None,
        loss_fn=None,
        lr=0.01,
        batch_size=64,
        epochs=6,
        device=None,
        # allow any additional arguments to be passed in
        **kwargs):
        """Build the model, optimizer and loss.

        Args:
            model_fn: zero-argument callable returning an ``nn.Module``.
            optimizer_fn: callable ``(params, lr=...)`` returning an optimizer.
            loss_fn: zero-argument callable returning the loss function.
            lr: learning rate handed to ``optimizer_fn``.
            batch_size: batch size used by ``fit_e`` / ``evaluate_e``.
            epochs: default number of training epochs.
            device: device the model and batches are moved to.
            **kwargs: accepted and ignored.
        """
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)
        self.model = model_fn().to(device)
        self.optimizer = optimizer_fn(self.model.parameters(), lr=lr)
        self.loss_fn = loss_fn()
        self.lr = lr
        self.batch_size = batch_size
        self.epochs = epochs
        self.device = device
        # backup for the initial model weight and optimizer
        self.init_model_weight = deepcopy(self.model.state_dict())
        self.optimizer_fn = optimizer_fn

    def _train_epoch(self, loader):
        """Run one optimisation pass over every batch of ``loader``."""
        self.model.train()
        for batch in loader:
            inputs, targets = batch[0].to(self.device), batch[1].to(self.device)
            self.optimizer.zero_grad()
            self.loss_fn(self.model(inputs), targets).backward()
            self.optimizer.step()

    def _evaluate_loader(self, loader, argmax_always):
        """Return ``(loss, accuracy)`` of the current model over ``loader``.

        The loss is the sum of the per-batch loss values divided by the number
        of samples. Accuracy is the fraction of samples whose arg-max over
        dim 1 equals the target; when ``argmax_always`` is false it is only
        computed for ``torch.long`` targets and is 0 otherwise. An empty
        loader yields ``(0, 0)``.
        """
        self.model.eval()
        total = 0
        correct = 0
        loss = 0
        is_classification = argmax_always
        with torch.no_grad():
            for batch in loader:
                inputs, targets = batch[0].to(self.device), batch[1].to(self.device)
                outputs = self.model(inputs)
                loss += self.loss_fn(outputs, targets)
                total += len(targets)
                is_classification = argmax_always or targets.dtype == torch.long
                if is_classification:
                    predicted = torch.max(outputs, 1)[1].view(targets.size())
                    correct += (predicted == targets).sum().item()
        if total == 0:
            # nothing was evaluated; mirror the empty-data result of evaluate_e
            return 0, 0
        accuracy = correct / total if is_classification else 0
        return (loss / total).cpu(), accuracy

    def fit_e(self, data, label, verbose=False, epochs=None, *args):
        """Reset the model, then train it on raw ``data`` / ``label``.

        Returns the (possibly untrained, if ``data`` is empty) model.
        """
        self.restart_model()
        if epochs is None:
            epochs = self.epochs
        if len(data) == 0:
            return self.model
        data_loader = DataLoader(
            Custom_Dataset(data, label, device=self.device, return_idx=False),
            batch_size=self.batch_size,
            shuffle=False,
        )
        for _ in range(epochs):
            self._train_epoch(data_loader)
        return self.model

    def evaluate_e(self, data, label):
        """Evaluate on raw ``data`` / ``label``; returns ``(loss, accuracy)``.

        Returns ``(0, 0)`` for empty data. Accuracy is 0 for non-``long``
        (regression) targets.
        """
        if len(data) == 0:
            return 0, 0
        data_loader = DataLoader(
            Custom_Dataset(data, label, device=self.device, return_idx=False),
            batch_size=self.batch_size,
            shuffle=False,
        )
        return self._evaluate_loader(data_loader, argmax_always=False)

    def fit(self, train_loader, val_loader, verbose=False, epochs=None):
        """Train on ``train_loader`` and keep the weights with the lowest validation loss.

        After every epoch the model is evaluated on ``val_loader``; the weights
        of the best epoch are loaded back before the model is returned.
        """
        if epochs is None:
            epochs = self.epochs

        best_loss = float('inf')
        # state_dict() returns references to the live parameters, so the
        # snapshot must be deep-copied or it would silently track later epochs.
        best_model = deepcopy(self.model.state_dict())
        for e in range(epochs):
            self._train_epoch(train_loader)

            # evaluate validation performance
            loss, accuracy = self.evaluate(val_loader)
            if verbose:
                print('Epoch: {}, Loss: {:.4f}, Accuracy: {:.4f}'.format(e, loss, accuracy))
            if best_loss > loss:
                best_loss = loss
                best_model = deepcopy(self.model.state_dict())

        self.model.load_state_dict(best_model)
        return self.model

    def evaluate(self,eval_loader):
        """Evaluate a classification model on ``eval_loader``; returns ``(loss, accuracy)``."""
        return self._evaluate_loader(eval_loader, argmax_always=True)

    def restart_model(self):
        """Restore the initial weights and create a fresh optimizer."""
        self.model.load_state_dict(deepcopy(self.init_model_weight))
        self.optimizer = self.optimizer_fn(self.model.parameters(), lr=self.lr)


class Model_Train_SKLearn:
    """Trainer exposing the ``fit`` / ``evaluate`` API on top of a scikit-learn estimator.

    Only ``KernelRidge`` (regression) and ``LogisticRegression``
    (classification) factories are instantiated. Classification code paths
    assume 10 classes and, when reading from loaders, 28*28 features.

    ``train_status`` records the outcome of the last fit: ``'trained'``,
    ``'no_data'`` or ``'one_class<label>'``; the evaluation methods use it to
    fabricate predictions when no estimator could be fitted.
    """

    def __init__(
        self,
        model_fn=None,
        optimizer_fn=None,
        loss_fn=None,
        lr=0.01,
        batch_size=64,
        epochs=6,
        device=None,
        train_dataset=None,
        val_dataset=None,
        val_data=None,
        val_label=None,
        train_data=None,
        train_label=None):
        """Store the configuration and instantiate the estimator.

        ``optimizer_fn``, ``loss_fn``, ``lr``, ``epochs`` and ``device`` are
        accepted but not used by this class.
        """
        self.model_fn = model_fn
        self.batch_size = batch_size
        # stays None for unsupported factories instead of being left undefined
        self.model = None
        if self.model_fn == KernelRidge:
            self.model = self.model_fn(kernel='rbf')
        elif self.model_fn == LogisticRegression:
            self.model = self.model_fn(solver='liblinear', multi_class='auto')
        self.val_loader = None
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.val_data = val_data
        self.val_label = val_label
        self.train_data = train_data
        self.train_label = train_label
        self.train_status = 'trained'

    @staticmethod
    def _loader_to_arrays(loader):
        """Collect a loader into ``(data, label)`` arrays in a single pass.

        Features are reshaped to ``(-1, 28*28)`` and labels converted to
        strings. One pass keeps features and labels aligned even if the loader
        yields batches in a different order on each iteration.
        """
        features, labels = [], []
        for item in loader:
            features.append(item[0].cpu().numpy())
            labels.append(item[1].cpu().numpy())
        data = np.concatenate(features).reshape(-1, 28 * 28)
        label = np.concatenate(labels).reshape(-1)
        return data, np.array([str(tmp) for tmp in label])

    @staticmethod
    def _one_class_index(status):
        """Parse the class index out of a ``'one_class<label>'`` status string."""
        # int(float(...)) accepts '3' as well as '3.0' and multi-digit labels
        return int(float(status[len('one_class'):]))

    def _predict_proba(self, data, n_samples):
        """Return an ``(n_samples, 10)`` class-probability matrix for ``data``.

        Uniform when nothing was trained, one-hot when training data had a
        single class, otherwise the estimator's probabilities scattered into
        10 columns according to ``model.classes_``.
        """
        if self.train_status == 'no_data':
            return np.ones([n_samples, 10]) / 10
        if 'one_class' in self.train_status:
            proba = np.zeros([n_samples, 10])
            proba[:, self._one_class_index(self.train_status)] = 1
            return proba
        proba = self.model.predict_proba(data)
        class_ = np.array([int(float(tmp)) for tmp in self.model.classes_])
        if len(class_) < 10:
            padded = np.zeros([len(proba), 10])
            padded[:, class_] = proba
            proba = padded
        return proba

    @staticmethod
    def _classification_metrics(label, label_pred_proba):
        """Return ``(log_loss, accuracy)`` for string labels and class probabilities."""
        label_pred = [str(tmp) for tmp in np.argmax(label_pred_proba, axis=1)]
        accuracy = accuracy_score(y_pred=label_pred, y_true=label)
        loss = log_loss(label, label_pred_proba)
        return loss, accuracy

    def fit_e(self, data, label, verbose=False, *args):
        """Fit the estimator on raw arrays and record the outcome in ``train_status``.

        The estimator is left untouched when ``data`` is empty or, for
        logistic regression, when ``label`` holds a single class.
        """
        if len(data) == 0:
            self.train_status = 'no_data'
            return self.model

        if self.model_fn == LogisticRegression:
            classes = np.unique(label)
            if len(classes) == 1:
                self.train_status = 'one_class' + str(classes.item())
                return self.model

        self.model.fit(data, label)
        self.train_status = 'trained'

        return self.model

    def fit(self, train_loader, val_loader, verbose=False, *args):
        """Fit on the samples selected by ``train_loader``'s sampler.

        The full training set is materialised once (on first call) and cached;
        later calls only pick rows via the sampler's indices.
        """
        if self.train_dataset is None:
            self.train_dataset = train_loader.dataset
            loader_ = DataLoader(self.train_dataset, batch_size=100, shuffle=False)
            self.train_data, self.train_label = self._loader_to_arrays(loader_)

        if isinstance(train_loader.sampler, torch.utils.data.SequentialSampler):
            indices_ = np.arange(len(self.train_data))
        else:
            indices_ = train_loader.sampler.indices

        if len(indices_) == 0:
            self.train_status = 'no_data'
            return self.model

        data, label = self.train_data[indices_], self.train_label[indices_]

        classes = np.unique(label)
        if len(classes) == 1:
            self.train_status = 'one_class' + str(classes.item())
            return self.model

        self.model.fit(data, label)
        self.train_status = 'trained'

        if verbose:
            # evaluate validation performance
            loss, accuracy = self.evaluate(val_loader)
            print('Loss: {:.4f}, Accuracy: {:.4f}'.format(loss, accuracy))

        return self.model

    def evaluate_e(self, data, label):
        """Evaluate on raw arrays; returns ``(loss, accuracy)``.

        Logistic regression reports log-loss and accuracy; kernel ridge reports
        mean squared error and an accuracy of 0.

        Raises:
            ValueError: if ``model_fn`` is neither supported estimator.
        """
        if self.model_fn == LogisticRegression:
            label_pred_proba = self._predict_proba(data, len(label))
            return self._classification_metrics(label, label_pred_proba)

        if self.model_fn == KernelRidge:
            # regression task
            if self.train_status == 'no_data':
                label_pred = np.zeros(len(label))
            else:
                label_pred = self.model.predict(data)
            # use mse loss
            loss = np.mean((label_pred - label)**2)
            return loss, 0

        raise ValueError('evaluate_e supports only LogisticRegression and KernelRidge models')

    def evaluate(self,val_loader):
        """Evaluate a classifier on ``val_loader``; returns ``(log_loss, accuracy)``.

        The loader's contents are cached and only re-read when a different
        loader object is passed.
        """
        if self.val_loader is not val_loader:
            self.val_data, self.val_label = self._loader_to_arrays(val_loader)
            self.val_loader = val_loader

        label_pred_proba = self._predict_proba(self.val_data, len(self.val_label))
        return self._classification_metrics(self.val_label, label_pred_proba)

    def restart_model(self):
        """No-op kept for interface parity with the torch trainer."""
        pass

# for MNIST 32*32
class CNN_Net(nn.Module):
    """Two-convolution tanh CNN that outputs log-probabilities over 10 classes.

    ``forward`` reshapes its input to ``(-1, 1, 32, 32)`` before the first
    convolution. ``device`` is accepted by the constructor but unused.
    """

    def __init__(self, device=None):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(64, 16, kernel_size=7, stride=1)
        self.fc1 = nn.Linear(4 * 4 * 16, 200)
        self.fc2 = nn.Linear(200, 10)

    def forward(self, x):
        """Return ``(batch, 10)`` log-softmax scores."""
        images = x.view(-1, 1, 32, 32)
        stage1 = F.max_pool2d(torch.tanh(self.conv1(images)), 2, 2)
        stage2 = F.max_pool2d(torch.tanh(self.conv2(stage1)), 2, 2)
        # 16 channels of 4x4 remain after the second pooling
        flat = stage2.view(-1, 4 * 4 * 16)
        hidden = torch.tanh(self.fc1(flat))
        return F.log_softmax(self.fc2(hidden), dim=1)


# for MNIST 32*32 LogReg
class MNIST_LogisticRegression(nn.Module):
    """Multinomial logistic regression: a single linear layer with log-softmax output.

    Args:
        input_dim: number of features each sample is flattened to.
        output_dim: number of classes.
        device: accepted but unused.
    """

    def __init__(self, input_dim=1024, output_dim=10, device=None):
        super(MNIST_LogisticRegression, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.linear = torch.nn.Linear(self.input_dim, self.output_dim)

    def forward(self, x):
        """Return ``(batch, output_dim)`` log-probabilities."""
        x = x.view(-1,  self.input_dim)
        outputs = self.linear(x)
        # The previous ``F.sigmoid(outputs, dim=1)`` raised TypeError (sigmoid
        # takes no ``dim``); normalising over the class dimension matches the
        # other classifiers in this file.
        return F.log_softmax(outputs, dim=1)

# for MNIST 32*32
class MLP_Net(nn.Module):
    """ReLU perceptron (1024 -> 128 -> 64 -> 10) returning log-probabilities.

    ``device`` is accepted by the constructor but unused.
    """

    def __init__(self, device=None):
        super().__init__()
        self.fc1 = nn.Linear(1024, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Flatten ``x`` to 1024 features and return ``(batch, 10)`` log-softmax scores."""
        flat = x.view(-1, 1024)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return F.log_softmax(self.fc3(hidden), dim=1)

# for MNIST 28*28 (inputs are flattened to 784 features)
class MLP_MNIST(nn.Module):
    """ReLU perceptron (784 -> 128 -> 32 -> 10) returning softmax probabilities.

    ``device`` is accepted by the constructor but unused.
    """

    def __init__(self, device=None):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 32)
        self.fc3 = nn.Linear(32, 10)

    def forward(self, x):
        """Return ``(batch, 10)`` class probabilities."""
        logits = self.fc3(self.get_feature(x))
        return F.softmax(logits, dim=1)

    def get_feature(self, x):
        """Return the 32-dimensional activations of the last hidden layer."""
        flat = x.view(-1, 784)
        first = torch.relu(self.fc1(flat))
        return torch.relu(self.fc2(first))
        
class MLP_MNIST_S(nn.Module):
    """Single-hidden-layer perceptron (784 -> 32 -> 10) returning softmax probabilities.

    ``device`` is accepted by the constructor but unused.
    """

    def __init__(self, device=None):
        super().__init__()
        self.fc1 = nn.Linear(784, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        """Return ``(batch, 10)`` class probabilities."""
        logits = self.fc2(self.get_feature(x))
        return F.softmax(logits, dim=1)

    def get_feature(self, x):
        """Return the 32-dimensional hidden-layer activations."""
        flat = x.view(-1, 784)
        return torch.relu(self.fc1(flat))

class MLP_R(nn.Module):
    """Single-hidden-layer regressor (8 -> 32 -> 1) returning one scalar per sample.

    ``device`` is accepted by the constructor but unused.
    """

    def __init__(self, device=None):
        super().__init__()
        self.fc1 = nn.Linear(8, 32)
        self.fc2 = nn.Linear(32, 1)

    def forward(self, x):
        """Return a flat ``(batch,)`` tensor of predictions."""
        prediction = self.fc2(self.get_feature(x))
        return prediction.reshape([-1])

    def get_feature(self, x):
        """Return the 32-dimensional hidden-layer activations."""
        flat = x.view(-1, 8)
        return torch.relu(self.fc1(flat))

# class LogisticRegression(nn.Module):

#     def __init__(self, input_dim=86, output_dim=2, device=None):
#         super(LogisticRegression, self).__init__()
#         self.input_dim = input_dim
#         self.output_dim = output_dim
#         self.linear = torch.nn.Linear(self.input_dim, self.output_dim)

#     def forward(self, x):
#         outputs = self.linear(x)
#         return outputs

class Flower_LR(nn.Module):
    """Linear classifier mapping 2048 features to 5 log-probabilities.

    ``device`` is accepted by the constructor but unused.
    """

    def __init__(self, device=None):
        super().__init__()
        self.input_dim = 2048
        self.output_dim = 5
        self.linear = torch.nn.Linear(self.input_dim, self.output_dim)

    def forward(self, x):
        """Return ``(batch, 5)`` log-softmax scores for ``x``."""
        logits = self.linear(x)
        return F.log_softmax(logits, dim=1)

class MLP(nn.Module):
    """Single-hidden-layer ReLU classifier returning log-probabilities.

    Args:
        input_dim: number of input features.
        output_dim: number of classes.
        device: accepted but unused.
    """

    def __init__(self, input_dim=86, output_dim=2, device=None):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 32)
        self.fc2 = nn.Linear(32, output_dim)

    def forward(self, x):
        """Return ``(batch, output_dim)`` log-softmax scores."""
        hidden = F.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)


# For names language classification
class RNN(nn.Module):
    """Hand-rolled recurrent classifier that processes each sequence step by step.

    Args:
        input_size: number of features per time step.
        output_size: number of classes.
        hidden_size: size of the recurrent state.
        device: device on which the initial hidden state is created.
    """

    def __init__(self, input_size=57, output_size=7, hidden_size=64, device=None):
        super().__init__()

        self.hidden_size = hidden_size

        combined_size = input_size + hidden_size
        self.i2h = nn.Linear(combined_size, hidden_size)
        self.i2o = nn.Linear(combined_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)
        self.device = device

    def forward(self, line_tensors):
        """Return the stacked final-step outputs, one row per sequence."""
        per_sequence = [self.forward_one_tensor(sequence) for sequence in line_tensors]
        return torch.cat(per_sequence)

    def forward_one_tensor(self, line_tensor):
        """Run one sequence through the cell and return the last computed output.

        Steps whose first feature equals -1 are treated as padding and skipped.
        """
        hidden = self.initHidden()
        n_steps = line_tensor.size()[0]
        for step_idx in range(n_steps):
            step = line_tensor[step_idx]
            if step[0] == -1:  # ignore the padded -1 at the end
                continue
            output, hidden = self.forward_once(step.view(1, -1), hidden)
        return output

    def forward_once(self, input, hidden):
        """Advance one time step; returns ``(log_softmax_output, next_hidden)``."""
        combined = torch.cat((input, hidden), 1)
        next_hidden = self.i2h(combined)
        scores = self.softmax(self.i2o(combined))
        return scores, next_hidden

    def initHidden(self):
        """Return a zero hidden state of shape ``(1, hidden_size)`` on ``self.device``."""
        zeros = torch.zeros(1, self.hidden_size)
        return zeros.to(self.device)

# For time series prediction
class RNN_TS(nn.Module):
    """Hand-rolled recurrent regressor with a sigmoid output per sequence.

    Args:
        input_size: number of features per time step.
        output_size: size of the per-step output.
        hidden_size: size of the recurrent state.
        device: device on which the initial hidden state is created.
    """

    def __init__(self, input_size=5, output_size=1, hidden_size=10, device=None):
        super().__init__()

        self.hidden_size = hidden_size

        combined_size = input_size + hidden_size
        self.i2h = nn.Linear(combined_size, hidden_size)
        self.i2o = nn.Linear(combined_size, output_size)
        self.device = device

    def forward(self, line_tensors):
        """Return the final-step outputs of all sequences as a flat tensor."""
        per_sequence = [self.forward_one_tensor(sequence) for sequence in line_tensors]
        return torch.cat(per_sequence).reshape([-1])

    def forward_one_tensor(self, line_tensor):
        """Run one sequence through the cell and return the last computed output.

        Steps whose first feature equals -1 are treated as padding and skipped.
        """
        hidden = self.initHidden()
        n_steps = line_tensor.size()[0]
        for step_idx in range(n_steps):
            step = line_tensor[step_idx]
            if step[0] == -1:  # ignore the padded -1 at the end
                continue
            output, hidden = self.forward_once(step.view(1, -1), hidden)
        return output

    def forward_once(self, input, hidden):
        """Advance one time step; returns ``(sigmoid_output, next_hidden)``."""
        combined = torch.cat((input, hidden), 1)
        next_hidden = self.i2h(combined)
        squashed = torch.sigmoid(self.i2o(combined))
        return squashed, next_hidden

    def initHidden(self):
        """Return a zero hidden state of shape ``(1, hidden_size)`` on ``self.device``."""
        zeros = torch.zeros(1, self.hidden_size)
        return zeros.to(self.device)

# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
# LeNet
class CNNCifar(nn.Module):
    """LeNet-style CNN for 3-channel images returning 10 log-probabilities.

    The flatten step expects 16 feature maps of 5x5, which is what a 32x32
    input produces. ``device`` is accepted by the constructor but unused.
    """

    def __init__(self, device=None):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return ``(batch, 10)`` log-softmax scores."""
        maps = self.pool(F.relu(self.conv1(x)))
        maps = self.pool(F.relu(self.conv2(maps)))
        flat = maps.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return F.log_softmax(self.fc3(hidden), dim=1)

# https://www.tensorflow.org/tutorials/images/cnn
class CNNCifar_TF(nn.Module):
    """Three-convolution CNN for 3-channel images returning 10 log-probabilities.

    The flatten step expects 64 feature maps of 4x4, which is what a 32x32
    input produces. ``device`` is accepted by the constructor but unused.
    """

    def __init__(self, device=None):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 4 * 4, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        """Return ``(batch, 10)`` log-softmax scores."""
        maps = self.pool(F.relu(self.conv1(x)))
        maps = self.pool(F.relu(self.conv2(maps)))
        # the third convolution is not followed by pooling
        maps = F.relu(self.conv3(maps))
        flat = maps.view(-1, 64 * 4 * 4)
        hidden = F.relu(self.fc1(flat))
        return F.log_softmax(self.fc2(hidden), dim=1)

class CNNCifar_10(nn.Module):
    """LeNet-style CNN with a configurable width that returns raw logits.

    Args:
        in_channels: number of input image channels.
        n_kernels: filters in the first convolution (the second uses twice as many).
        out_dim: number of output logits.
        device: accepted but unused.
    """

    def __init__(self, in_channels=3, n_kernels=16, out_dim=10, device=None):
        super().__init__()

        wide = 2 * n_kernels
        self.conv1 = nn.Conv2d(in_channels, n_kernels, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(n_kernels, wide, kernel_size=5)
        self.fc1 = nn.Linear(wide * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, out_dim)

    def forward(self, x):
        """Return ``(batch, out_dim)`` unnormalised logits."""
        maps = self.pool(F.relu(self.conv1(x)))
        maps = self.pool(F.relu(self.conv2(maps)))
        flat = maps.view(maps.shape[0], -1)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)

class CNNCIFAR10_SMALL(nn.Module):
    """Small two-convolution CNN returning softmax probabilities.

    Args:
        in_channels: number of input image channels.
        n_kernels: filters in each convolution.
        out_dim: number of classes.
        device: accepted but unused.
    """

    def __init__(self, in_channels=3, n_kernels=32, out_dim=10, device=None):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, n_kernels, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(n_kernels, n_kernels, kernel_size=3)
        self.fc1 = nn.Linear(n_kernels * 6 * 6, 50)
        self.fc2 = nn.Linear(50, 32)
        self.fc3 = nn.Linear(32, out_dim)

    def forward(self, x):
        """Return ``(batch, out_dim)`` class probabilities."""
        logits = self.fc3(self.get_feature(x))
        return F.softmax(logits, dim=1)

    def get_feature(self, x):
        """Return the 32-dimensional activations feeding the output layer."""
        maps = self.pool(F.relu(self.conv1(x)))
        maps = self.pool(F.relu(self.conv2(maps)))
        flat = maps.view(maps.shape[0], -1)
        hidden = F.relu(self.fc1(flat))
        return F.relu(self.fc2(hidden))

class CNNCifar_100(nn.Module):
    """LeNet-style CNN with a configurable width that returns raw logits (100 by default).

    Args:
        in_channels: number of input image channels.
        n_kernels: filters in the first convolution (the second uses twice as many).
        out_dim: number of output logits.
        device: accepted but unused.
    """

    def __init__(self, in_channels=3, n_kernels=16, out_dim=100, device=None):
        super().__init__()

        wide = 2 * n_kernels
        self.conv1 = nn.Conv2d(in_channels, n_kernels, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(n_kernels, wide, kernel_size=5)
        self.fc1 = nn.Linear(wide * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, out_dim)

    def forward(self, x):
        """Return ``(batch, out_dim)`` unnormalised logits."""
        maps = self.pool(F.relu(self.conv1(x)))
        maps = self.pool(F.relu(self.conv2(maps)))
        flat = maps.view(maps.shape[0], -1)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)

class CNN_Cifar100_BN(nn.Module):
    """Six-convolution CNN returning 100 log-probabilities.

    Despite the name, no batch-normalisation layers are active. The fully
    connected head expects 4096 flattened features (256 maps of 4x4, i.e. a
    32x32 input after three 2x2 poolings).
    """
    # https://zhenye-na.github.io/2018/09/28/pytorch-cnn-cifar10.html
    def __init__(self, device=None):
        """Build the convolutional trunk and the classifier head (``device`` is unused)."""
        super().__init__()

        self.conv_layer = nn.Sequential(
            # block 1: 3 -> 32 -> 64 channels, then halve the resolution
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            # block 2: 64 -> 128 -> 128 channels, pooling and light dropout
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(p=0.05),

            # block 3: 128 -> 256 -> 256 channels, then pooling
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )

        self.fc_layer = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(4096, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Linear(512, 100),
        )

    def forward(self, x):
        """Return ``(batch, 100)`` log-softmax scores."""
        feature_maps = self.conv_layer(x)
        flat = feature_maps.view(feature_maps.size(0), -1)
        logits = self.fc_layer(flat)
        return F.log_softmax(logits, dim=1)


class BasicBlock(nn.Module):
    """Residual block of two 3x3 conv + batch-norm layers with a skip connection.

    When the stride or channel count changes, the skip path is a 1x1
    convolution followed by batch-norm; otherwise it is an empty
    ``nn.Sequential`` (identity).
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        shape_changes = stride != 1 or in_planes != out_planes
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(residual_branch(x) + shortcut(x))``."""
        residual = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        residual += self.shortcut(x)
        return F.relu(residual)

class ResNet18(nn.Module):
    """ResNet-18 style network for 3-channel images returning log-probabilities.

    Args:
        block: residual block class; must expose an ``expansion`` attribute.
        num_blocks: number of blocks in each of the four stages.
        num_classes: number of output classes.
        device: accepted but unused.
    """

    # An immutable tuple replaces the former list default so the default can
    # never be shared-and-mutated across instances; it is only indexed.
    def __init__(self, block=BasicBlock, num_blocks=(2, 2, 2, 2), num_classes=10, device=None):
        super(ResNet18, self).__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest stride 1.

        Updates ``self.in_planes`` so the next block/stage gets the right input width.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(batch, num_classes)`` log-softmax scores."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # 4x4 average pooling collapses the final maps of a 32x32 input to 1x1
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return F.log_softmax(out, dim=1)


from torchvision import models


class ResNet18_torch(nn.Module):
    """torchvision ResNet-18 with a 100-way head, a 3x3 stride-1 stem and no stem max-pool.

    Args:
        pretrained: forwarded to ``models.resnet18``.
        device: accepted but unused.
    """

    def __init__(self, pretrained=False, device=None):
        super().__init__()
        backbone = models.resnet18(pretrained=pretrained)

        # replace the classification head with a 100-class linear layer
        backbone.fc = nn.Linear(backbone.fc.in_features, 100)

        # swap the stem for a 3x3 stride-1 convolution and drop the max-pool
        backbone.conv1 = torch.nn.Conv2d(
            3, 64, kernel_size=3, stride=1, padding=1, bias=False
        )
        backbone.maxpool = torch.nn.Identity()
        self.resnet = backbone

    def forward(self, x):
        """Return ``(batch, 100)`` log-softmax scores."""
        logits = self.resnet(x)
        return F.log_softmax(logits, dim=1)


class AlexNet(nn.Module):
    """AlexNet-style CNN returning raw logits.

    The classifier expects 256 feature maps of 6x6; with the layer settings
    below a 3x64x64 input produces that size.

    Args:
        n_class: number of output logits.
        device: accepted but unused.
    """

    def __init__(self, n_class=100, device=None):
        super().__init__()
        feature_layers = [
            nn.Conv2d(3, 64, kernel_size=8, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=1),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*feature_layers)

        classifier_layers = [
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, n_class),
        ]
        self.classifier = nn.Sequential(*classifier_layers)

    def forward(self, x):
        """Return ``(batch, n_class)`` unnormalised logits."""
        maps = self.features(x)
        flat = maps.view(maps.size(0), 256 * 6 * 6)
        return self.classifier(flat)



# Layer layouts for the VGG variants, keyed by variant name. As consumed by
# VGG._make_layers: an integer is the output channel count of a 3x3
# convolution (followed by batch-norm and ReLU) and 'M' is a 2x2 max-pool.
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    """VGG network with batch-norm, built from a named layout in the module-level ``cfg``.

    Args:
        vgg_name: key into ``cfg`` (e.g. ``'VGG16'``).
        device: accepted but unused.
    """

    def __init__(self, vgg_name, device=None):
        super().__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(512, 10)

    def forward(self, x):
        """Return ``(batch, 10)`` unnormalised logits."""
        maps = self.features(x)
        flat = maps.view(maps.size(0), -1)
        return self.classifier(flat)

    def _make_layers(self, cfg):
        """Translate a layout list into an ``nn.Sequential`` feature extractor.

        ``'M'`` adds a 2x2 max-pool; any other entry is the output width of a
        3x3 conv + batch-norm + ReLU group. A 1x1 average pool is appended last.
        """
        modules = []
        in_channels = 3
        for entry in cfg:
            if entry == 'M':
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            modules.extend([
                nn.Conv2d(in_channels, entry, kernel_size=3, padding=1),
                nn.BatchNorm2d(entry),
                nn.ReLU(inplace=True),
            ])
            in_channels = entry
        modules.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*modules)


def VGG11(device=None):
    """Build a ``VGG`` model using the ``'VGG11'`` layout."""
    return VGG(vgg_name='VGG11', device=device)


def VGG13(device=None):
    """Build a ``VGG`` model using the ``'VGG13'`` layout."""
    return VGG(vgg_name='VGG13', device=device)


def VGG16(device=None):
    """Build a ``VGG`` model using the ``'VGG16'`` layout."""
    return VGG(vgg_name='VGG16', device=device)


def VGG19(device=None):
    """Build a ``VGG`` model using the ``'VGG19'`` layout."""
    return VGG(vgg_name='VGG19', device=device)

class CNN_Text(nn.Module):
    """Convolutional text classifier: embedding, parallel convolutions, max-over-time pooling.

    Args:
        args: mapping with keys ``'embed_num'``, ``'embed_dim'``,
            ``'class_num'``, ``'kernel_num'``, ``'kernel_sizes'`` and
            ``'static'``.
        device: device the embedded batch is moved to when ``args['static']`` is truthy.
    """

    def __init__(self, args=None, device=None):
        super().__init__()

        self.args = args
        self.device = device

        vocab_size = args['embed_num']
        embed_dim = args['embed_dim']
        n_classes = args['class_num']
        in_channels = 1
        n_kernels = args['kernel_num']
        kernel_heights = args['kernel_sizes']

        self.embed = nn.Embedding(vocab_size, embed_dim)
        # one convolution per kernel height, each spanning the full embedding width
        self.convs1 = nn.ModuleList(
            [nn.Conv2d(in_channels, n_kernels, (height, embed_dim)) for height in kernel_heights]
        )
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(len(kernel_heights) * n_kernels, n_classes)

    def conv_and_pool(self, x, conv):
        """Apply ``conv`` + ReLU to ``x`` and max-pool over the remaining length axis."""
        activated = F.relu(conv(x)).squeeze(3)
        return F.max_pool1d(activated, activated.size(2)).squeeze(2)

    def forward(self, x):
        """Return ``(batch, class_num)`` log-softmax scores for token-id input ``x``.

        NOTE(review): the batch is presumably already laid out batch-first by
        the data loader (no permute happens here) - confirm against the caller.
        """
        embedded = self.embed(x)

        if not self.args or self.args['static']:
            embedded = Variable(embedded).to(self.device)

        # add the single input-channel axis expected by Conv2d
        embedded = embedded.unsqueeze(1)

        pooled = [self.conv_and_pool(embedded, conv) for conv in self.convs1]
        features = self.dropout(torch.cat(pooled, 1))
        logit = self.fc1(features)
        return F.log_softmax(logit, dim=1)

# Sentiment analysis : binary classification
class RNN_IMDB(nn.Module):
    """Averaged-embedding classifier (no recurrent layer despite the name).

    Tokens are embedded, the embeddings are averaged over dimension 1 of the
    input, and a single linear layer followed by log-softmax produces the
    class scores.

    ``args`` is read by attribute: ``embed_num``, ``embed_dim``,
    ``class_num`` and ``pad_idx`` (used as the embedding's ``padding_idx``).
    """

    def __init__(self, args=None, device=None):
        super().__init__()

        self.args, self.device = args, device

        vocab_size, width, num_classes, padding = (
            args.embed_num,
            args.embed_dim,
            args.class_num,
            args.pad_idx,
        )

        self.embedding = nn.Embedding(vocab_size, width, padding_idx=padding)
        self.fc = nn.Linear(width, num_classes)

    def forward(self, text):
        """Return log-softmax class scores, one row per entry of ``text``'s dim 0.

        Dimension 1 of ``text`` is the one averaged away, so the input is
        presumably batch-first ([batch size, sent len]) -- confirm against the
        data loader.
        """
        token_vectors = self.embedding(text)

        # Pooling window covers all of dimension 1 and a single embedding
        # column, i.e. a plain mean over dimension 1.
        window = (token_vectors.shape[1], 1)
        averaged = F.avg_pool2d(token_vectors, window).squeeze(1)  # [dim 0, embed_dim]

        scores = self.fc(averaged)
        return F.log_softmax(scores, dim=1)


class DQN(nn.Module):
    """Three-layer convolutional network with a 512-unit hidden layer and a linear head.

    :param c: number of input channels.
    :param h: input height in pixels.
    :param w: input width in pixels.
    :param outputs: size of the output layer (one value per output).
    :param device: device the input is moved to at the start of ``forward``.
    """

    def __init__(self, c, h, w, outputs, device=None):
        super().__init__()
        self.device = device

        # Layers are created and re-initialised strictly one after another so
        # the random-number stream is consumed in a fixed order.
        self.conv1 = nn.Conv2d(c, 32, kernel_size=8, stride=4)
        kaiming_uniform_(self.conv1.weight, mode='fan_in', nonlinearity='relu')
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        kaiming_uniform_(self.conv2.weight, mode='fan_in', nonlinearity='relu')
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        kaiming_uniform_(self.conv3.weight, mode='fan_in', nonlinearity='relu')

        def _extent_after_convs(size):
            # Unpadded convolution: out = (in - kernel) // stride + 1, applied
            # once per conv layer with that layer's (kernel, stride).
            for kernel_size, stride in ((8, 4), (4, 2), (3, 1)):
                size = (size - kernel_size) // stride + 1
            return size

        # The flattened feature count depends on the input image size.
        out_w = _extent_after_convs(w)
        out_h = _extent_after_convs(h)
        flat_features = out_w * out_h * 64

        self.hidden = nn.Linear(flat_features, 512)
        self.head = nn.Linear(512, outputs)

    def forward(self, x):
        """Map a batch of frames (N, c, h, w) to a (N, outputs) tensor.

        Works for a single-element batch as well as a full batch. The input
        is divided by 255 first (presumably 8-bit pixel values).
        """
        scaled = x.to(self.device) / 255
        feats = F.relu(self.conv1(scaled))
        feats = F.relu(self.conv2(feats))
        feats = F.relu(self.conv3(feats))
        flat = feats.view(feats.size(0), -1)
        return self.head(F.relu(self.hidden(flat)))

# for MNIST 32*32
class MLP_Net(nn.Module):
    """Fully connected classifier: 1024 -> 128 -> 64 -> 10 with log-softmax output.

    The input is flattened to rows of 1024 values (e.g. 32x32 single-channel
    images). ``device`` is accepted but not used by this class.
    """

    def __init__(self, device=None):
        super().__init__()
        self.fc1 = nn.Linear(1024, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return log-probabilities of shape (N, 10)."""
        flat = x.view(-1, 1024)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return F.log_softmax(logits, dim=1)

from typing import List, Callable, Union, Any, TypeVar, Tuple
# from torch import tensor as Tensor

from torch import nn
from abc import abstractmethod

class BaseVAE(nn.Module):
    """Interface skeleton for variational autoencoders.

    ``encode``, ``decode``, ``sample`` and ``generate`` raise
    ``NotImplementedError`` until a subclass overrides them. ``forward`` and
    ``loss_function`` are marked with ``@abstractmethod``, but the class does
    not use ``ABCMeta``, so the marker is not enforced: ``BaseVAE`` can still
    be instantiated and both methods simply return ``None``.
    """

    def __init__(self) -> None:
        super(BaseVAE, self).__init__()

    def encode(self, input: torch.Tensor) -> List[torch.Tensor]:
        """Map an input batch to latent codes; must be overridden."""
        raise NotImplementedError

    def decode(self, input: torch.Tensor) -> Any:
        """Map latent codes back to data space; must be overridden."""
        raise NotImplementedError

    def sample(self, batch_size: int, current_device: int, **kwargs) -> torch.Tensor:
        """Draw ``batch_size`` samples from the model; must be overridden."""
        raise NotImplementedError

    def generate(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Reconstruct ``x`` through the model; must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def forward(self, *inputs: torch.Tensor) -> torch.Tensor:
        """Run the full model; subclasses are expected to override."""
        pass

    @abstractmethod
    def loss_function(self, *inputs: Any, **kwargs) -> torch.Tensor:
        """Compute the training loss; subclasses are expected to override."""
        pass
    

class VanillaVAE(BaseVAE):
    """Convolutional variational autoencoder.

    The encoder is a stack of stride-2 ``Conv2d -> BatchNorm2d -> LeakyReLU``
    blocks, one per entry of ``hidden_dims``. Its flattened output feeds two
    linear heads (``fc_mu``, ``fc_var``) that expect exactly
    ``hidden_dims[-1]`` features, so the input must be small enough for the
    encoder to reduce it to a 1x1 spatial map. The decoder mirrors the encoder
    with transposed convolutions and always emits a single output channel
    squashed by ``tanh``.

    :param in_channels: number of channels of the input images.
    :param latent_dim: size of the latent vector.
    :param hidden_dims: channel widths of the encoder blocks; defaults to
        ``[32, 64, 128, 256, 512]``. The list passed in is not modified.
    :param kwargs: accepted for interface compatibility and ignored.
    """

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 **kwargs) -> None:
        super(VanillaVAE, self).__init__()

        self.latent_dim = latent_dim

        # Work on a private copy: the decoder needs the widths in reverse
        # order, and reversing the caller's list in place would silently
        # corrupt it for any later use.
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            hidden_dims = list(hidden_dims)

        # Build Encoder
        encoder_blocks = []
        for h_dim in hidden_dims:
            encoder_blocks.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*encoder_blocks)
        self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1], latent_dim)

        # Build Decoder
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1])

        decoder_dims = hidden_dims[::-1]

        decoder_blocks = []
        for i in range(len(decoder_dims) - 1):
            decoder_blocks.append(
                nn.Sequential(
                    nn.ConvTranspose2d(decoder_dims[i],
                                       decoder_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(decoder_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*decoder_blocks)

        # One more x2 upsampling, then an unpadded 5x5 convolution (which
        # trims 4 pixels per spatial dimension) down to a single channel.
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(decoder_dims[-1],
                               decoder_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(decoder_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(decoder_dims[-1], out_channels=1,
                      kernel_size=5, padding=0),
            nn.Tanh())

    def encode(self, input: torch.Tensor) -> List[torch.Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (torch.Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[torch.Tensor]) ``[mu, log_var]``, each [N x latent_dim]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (torch.Tensor) [B x D]
        :return: (torch.Tensor) [B x 1 x H x W]
        """
        result = self.decoder_input(z)
        # Channel count follows the configured widths instead of a fixed 512,
        # so custom ``hidden_dims`` decode correctly.
        result = result.view(-1, self.decoder_input.out_features, 1, 1)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (torch.Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (torch.Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (torch.Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: torch.Tensor, **kwargs) -> List[torch.Tensor]:
        """
        Encodes the input, samples a latent code and decodes it.
        :param input: (torch.Tensor) [N x C x H x W]
        :return: ``[reconstruction, input, mu, log_var]``
        """
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        r"""
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args: ``(recons, input, mu, log_var)`` as returned by ``forward``
        :param kwargs: must contain ``M_N``, the weight applied to the KL term
        :return: dict with ``'loss'`` (differentiable total),
            ``'Reconstruction_Loss'`` (detached MSE) and ``'KLD'`` (detached,
            negated KL term)
        :raises KeyError: if ``M_N`` is not supplied
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
        recons_loss = F.mse_loss(recons, input)

        # Sum over latent dimensions, then average over the batch.
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss': recons_loss.detach(), 'KLD': -kld_loss.detach()}

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> torch.Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (torch.Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (torch.Tensor) [B x C x H x W]
        :return: (torch.Tensor) [B x 1 x H' x W']
        """

        return self.forward(x)[0]
    
# Registry keyed by short model name; each value is a two-element list
# [model class, trainer class].
model_dict = {
    "MLP": [MLP_MNIST, Model_Train],
    "MLP_S": [MLP_MNIST_S, Model_Train],
    "MLP_R": [MLP_R, Model_Train],
    "CNN": [CNNCIFAR10_SMALL, Model_Train],
    "KR": [KernelRidge, Model_Train_SKLearn],
    "Logistic": [LogisticRegression, Model_Train_SKLearn],
    "Logistic_torch": [MNIST_LogisticRegression, Model_Train],
}