#!/usr/bin/env python
import re
import pickle
import sys
import argparse
import random
import torch
import math
import numpy as np
import logging
import copy

from models.models import Module
from core.metrics import Metrics

from core.numpy_dataset import NumpyDataset
from learner.optimize_gradient_descent_learner import (
    OptimizeGradientDescentLearner)

from sklearn.metrics import zero_one_loss


###############################################################################

class NNLearner():
    """Scikit-learn-like wrapper around a ``Module`` trained by gradient descent.

    Training is delegated to an ``OptimizeGradientDescentLearner`` using SGD
    and the ``"BoundedCrossEntropyLoss"`` metric from ``core.metrics``.
    Inference (``predict`` / ``output``) runs the wrapped model batch by batch.
    """

    def __init__(
        self, model, model_kwargs, batch_size, lr, device, load=None,
    ):
        """Build the model, loss, optimizer and gradient-descent learner.

        Parameters
        ----------
        model :
            Model identifier forwarded to ``Module``.
        model_kwargs :
            Extra arguments forwarded to ``Module``.
        batch_size : int or None
            Mini-batch size passed to the training learner and used at
            inference; ``None`` means one batch holding all of ``X`` at
            inference time.
        lr : float
            SGD learning rate.
        device : str
            Requested torch device. CPU is used when CUDA is unavailable or
            ``device == "cpu"``.
        load : optional
            Stored as ``self.load_arg``; not used inside this class.
        """
        self.model_kwargs = model_kwargs
        self.batch_size = batch_size
        self.lr = lr

        self.device = torch.device('cpu')
        if(torch.cuda.is_available() and device != "cpu"):
            self.device = torch.device(device)

        self.model = Module(model, self.device, self.model_kwargs)
        self.model.to(self.device)

        # Stored under a different name: assigning ``self.load`` would shadow
        # the ``load`` method below and make it uncallable on instances.
        self.load_arg = load

        # ------------------------------------------------------------------- #

        # Keep the Metrics object itself: its ``fit`` is the loss callable and
        # its ``parameters()`` feed the optimizer. (Taking ``.fit`` here, as
        # before, made both ``.fit`` and ``.parameters()`` below fail.)
        self.__loss = Metrics("BoundedCrossEntropyLoss", self.model)
        self.loss = self.__loss.fit

        # NOTE(review): presumably Metrics exposes the model's parameters
        # through ``parameters()`` since it is built from ``self.model`` —
        # confirm against core.metrics.
        self.optim = torch.optim.SGD(
            self.__loss.parameters(), lr=self.lr)

        self.learner = OptimizeGradientDescentLearner(
            self.model, self.loss, self.optim, self.device,
            batch_size=self.batch_size)

        # ------------------------------------------------------------------- #

    def fit(self, X, y):
        """Train the wrapped model on ``(X, y)`` via the gradient-descent learner."""
        logging.info("Learning ...\n")
        self.learner.fit(X, y)

    def save(self):
        """Return the learner's saved state wrapped as ``{"state_dict": ...}``."""
        return {"state_dict": self.learner.save()}

    def load(self, load_dict):
        """Restore state from a dict produced by :meth:`save`.

        Returns whatever the underlying learner's ``load`` returns.
        """
        return self.learner.load(load_dict["state_dict"])

    def _forward(self, X, init_keep, attr):
        """Run the model over ``X`` in mini-batches and collect ``model.<attr>``.

        Parameters
        ----------
        X : array-like
            Inputs, wrapped in a ``NumpyDataset`` under key ``"x_test"``.
        init_keep :
            Value of ``batch["keep"]`` for the first batch; every later batch
            gets ``True``.
        attr : str
            Model attribute read after each forward pass (``"pred"`` or
            ``"out"``).

        Returns
        -------
        numpy.ndarray
            Column 0 of the per-batch arrays concatenated along axis 0; an
            empty array when ``X`` is empty.
        """
        if len(X) == 0:
            # DataLoader rejects batch_size=0 and there would be nothing to
            # concatenate, so short-circuit.
            return np.empty(0)

        model = self.model

        data = NumpyDataset({"x_test": X})
        data.set_mode("test")

        # Use a local batch size so inference never overwrites
        # ``self.batch_size``, which is also the training batch size.
        batch_size = len(X) if self.batch_size is None else self.batch_size
        loader = torch.utils.data.DataLoader(data, batch_size=batch_size)

        chunks = []
        for i, batch in enumerate(loader):
            batch["x"] = batch["x"].to(
                device=self.device, dtype=torch.float32)
            batch["keep"] = init_keep if i == 0 else True

            model(batch)
            chunks.append(getattr(model, attr).cpu().detach().numpy())

        # One concatenation at the end instead of re-copying every batch.
        return np.concatenate(chunks)[:, 0]

    def predict(self, X, init_keep=True):
        """Return column 0 of ``model.pred`` for every row of ``X``.

        ``init_keep`` is passed as ``batch["keep"]`` for the first batch only.
        """
        return self._forward(X, init_keep, "pred")

    def output(self, X, init_keep=True):
        """Return column 0 of ``model.out`` for every row of ``X``.

        ``init_keep`` is passed as ``batch["keep"]`` for the first batch only.
        """
        return self._forward(X, init_keep, "out")

    def predict_proba(self, X):
        """Not supported by this learner."""
        raise NotImplementedError

###############################################################################
