import torch
import torch.nn as nn
import numpy as np
from sklearn.linear_model import LogisticRegression
import lightgbm as lgb

def get_algorithm_class(algorithm_name):
    """Resolve ``algorithm_name`` to the module-level object with that name.

    Args:
        algorithm_name: name bound at module level, e.g. ``'GDBT_Classifier'``.

    Returns:
        The object bound to ``algorithm_name`` in this module's namespace.

    Raises:
        NotImplementedError: if no module-level name matches.
    """
    namespace = globals()
    missing = object()
    found = namespace.get(algorithm_name, missing)
    if found is missing:
        raise NotImplementedError('Algorithm not found: {}'.format(algorithm_name))
    return found

class Clf_Algorithm(torch.nn.Module):
    '''
    Base class for a classifier that is trained twice: once with sample
    weights (``self.model``) and once without (``self.model_non_weights``).

    The constructor concatenates the training and validation splits, and
    their weights, row-wise into ``all_train_X_``, ``all_train_y_`` and
    ``all_weight_``. Subclasses implement ``fit`` and ``fit_all_data``. Both
    must assign ``self.model`` and ``self.model_non_weights`` to objects that
    expose ``predict``.

    Args:
        configs: configuration object, stored as ``self.configs``.
        train_X, validate_X: feature arrays. They must agree on every axis
            except the first, because they are concatenated along axis 0.
        train_y, validate_y: label arrays.
        train_weight, validate_weight: weight arrays, one entry per row.
        seed: random seed, stored as ``self.seed_``.
        num_threads: thread count, stored as ``self.num_threads_``.

    Raises:
        ValueError: if the concatenated features, labels and weights do not
            all have the same number of rows.
    '''

    def __init__(self, configs,
                 train_X, train_y,
                 validate_X, validate_y,
                 train_weight, validate_weight,
                 seed, num_threads):
        super(Clf_Algorithm, self).__init__()
        self.configs = configs
        self.seed_ = seed
        self.num_threads_ = num_threads
        # Populated by a subclass's fit()/fit_all_data().
        self.model = None
        self.model_non_weights = None
        self.all_train_X_ = np.concatenate([train_X, validate_X], axis=0)
        self.all_train_y_ = np.concatenate([train_y, validate_y], axis=0)
        self.all_weight_ = np.concatenate([train_weight, validate_weight],
                                          axis=0)
        # Fail early with a clear message. Otherwise a mismatch only
        # surfaces later, deep inside the backend's fit().
        n_rows = len(self.all_train_X_)
        if len(self.all_train_y_) != n_rows or len(self.all_weight_) != n_rows:
            raise ValueError(
                'Row count mismatch: X has {}, y has {}, weights have {}'.format(
                    n_rows, len(self.all_train_y_), len(self.all_weight_)))

    def fit(self, hparams, *args, **kwargs):
        '''Train both models using ``hparams``; implemented by subclasses.'''
        raise NotImplementedError

    def fit_all_data(self, hparams, *args, **kwargs):
        '''Train both models on train + validation data; implemented by subclasses.'''
        raise NotImplementedError

    def predict(self, x):
        '''Predict ``x`` with both fitted models.

        Returns:
            Tuple ``(pred, pred_non_weights)``. These are the outputs of
            ``self.model.predict(x)`` and ``self.model_non_weights.predict(x)``.

        NOTE(review): what the outputs mean depends on the backend.
        sklearn's ``LogisticRegression.predict`` presumably yields class
        labels, while a LightGBM binary ``Booster.predict`` yields
        probabilities. Confirm that callers handle both.
        '''
        pred = self.model.predict(x)
        pred_non_weights = self.model_non_weights.predict(x)
        return pred, pred_non_weights


class Linear_Classifier(Clf_Algorithm):
    '''Elastic-net logistic regression, trained with and without sample weights.

    Both models are fit on the concatenated training and validation rows
    (``all_train_X_`` / ``all_train_y_``). As a result, ``fit`` and
    ``fit_all_data`` do the same thing for this algorithm.
    '''

    def __init__(self, configs,
                 train_X, train_y,
                 validate_X, validate_y,
                 train_weight, validate_weight,
                 seed, num_threads):
        super(Linear_Classifier, self).__init__(
            configs,
            train_X, train_y,
            validate_X, validate_y,
            train_weight, validate_weight,
            seed, num_threads)

    def _make_model(self, hparams):
        '''Build one unfitted elastic-net LogisticRegression from ``hparams``.'''
        # NOTE(review): sklearn's ``C`` is the *inverse* regularization
        # strength. A larger ``hparams['lambda']`` therefore means weaker
        # regularization. Confirm this matches the intended search space.
        return LogisticRegression(
            penalty='elasticnet',
            solver='saga',
            l1_ratio=hparams['l1_ratio'],
            C=hparams['lambda'],
            random_state=self.seed_)

    def fit(self, hparams, *args, **kwargs):
        '''Fit the weighted and unweighted models on all rows.

        Args:
            hparams: mapping with keys ``'l1_ratio'`` and ``'lambda'``.
            *args, **kwargs: accepted for interface compatibility and ignored.
        '''
        self.model = self._make_model(hparams)
        self.model_non_weights = self._make_model(hparams)
        self.model.fit(X=self.all_train_X_, y=self.all_train_y_,
                       sample_weight=self.all_weight_)
        self.model_non_weights.fit(X=self.all_train_X_,
                                   y=self.all_train_y_)

    def fit_all_data(self, hparams, *args, **kwargs):
        '''Same as ``fit``, because ``fit`` already trains on all rows.'''
        # Unpack so fit() receives the extra arguments exactly as given.
        self.fit(hparams, *args, **kwargs)

class GDBT_Classifier(Clf_Algorithm):
    '''LightGBM binary classifier, trained with and without sample weights.

    ``fit`` trains on the training split and early-stops on the validation
    split. ``fit_all_data`` trains on both splits combined for a fixed
    number of rounds.
    '''

    # Maximum number of boosting rounds, used by both fit methods.
    _NUM_ITERATIONS = 250
    # Rounds without validation improvement before ``fit`` stops early.
    _EARLY_STOPPING_ROUNDS = 5

    def __init__(self,
                 configs,
                 train_X, train_y,
                 validate_X, validate_y,
                 train_weight, validate_weight,
                 seed, num_threads):
        super(GDBT_Classifier, self).__init__(
            configs,
            train_X, train_y,
            validate_X, validate_y,
            train_weight, validate_weight,
            seed, num_threads)
        self.feature_name_ = [f'col{i}' for i in range(train_X.shape[1])]
        self.train_data_ = self._make_dataset(
            train_X, train_y, train_weight)
        self.validation_data_ = self._make_dataset(
            validate_X, validate_y, validate_weight)
        self.train_data_non_weights_ = self._make_dataset(train_X, train_y)
        self.validation_data_non_weights_ = self._make_dataset(
            validate_X, validate_y)
        # Backward-compatible alias for the old, misleading name. This
        # dataset carries no weights.
        self.validation_data_weights_ = self.validation_data_non_weights_
        self.all_train_data_ = self._make_dataset(
            self.all_train_X_, self.all_train_y_, self.all_weight_)
        self.all_train_data_non_weights_ = self._make_dataset(
            self.all_train_X_, self.all_train_y_)

    def _make_dataset(self, X, y, weight=None):
        '''Wrap arrays in an ``lgb.Dataset`` with this model's feature names.'''
        # free_raw_data=False keeps the raw arrays, so the datasets can be
        # reused across repeated fit() calls.
        return lgb.Dataset(
            data=X,
            label=y,
            weight=weight,
            feature_name=self.feature_name_,
            free_raw_data=False)

    def _build_params(self, hparams, early_stopping):
        '''Assemble the LightGBM parameter dict shared by both fit methods.

        Args:
            hparams: mapping with ``lambda_l1``, ``lambda_l2``,
                ``num_leaves``, ``learning_rate``, ``feature_fraction`` and
                ``bagging_fraction``. It may optionally contain
                ``bagging_freq``, which defaults to 1.
            early_stopping: whether to add ``early_stopping_round``.
        '''
        params = {
            'objective': 'binary',
            'verbosity': -1,
            'lambda_l1': hparams['lambda_l1'],
            'lambda_l2': hparams['lambda_l2'],
            'num_leaves': hparams['num_leaves'],
            'learning_rate': hparams['learning_rate'],
            'feature_fraction': hparams['feature_fraction'],
            'bagging_fraction': hparams['bagging_fraction'],
            # LightGBM ignores bagging_fraction unless bagging_freq > 0.
            # With bagging_fraction == 1.0 this is still a no-op.
            'bagging_freq': hparams.get('bagging_freq', 1),
            'num_threads': self.num_threads_,
            'num_iterations': self._NUM_ITERATIONS,
            'seed': self.seed_,
        }
        if early_stopping:
            params['early_stopping_round'] = self._EARLY_STOPPING_ROUNDS
        return params

    def fit(self, hparams, *args, **kwargs):
        '''Train both boosters on the training split, early-stopping on validation.'''
        params = self._build_params(hparams, early_stopping=True)
        self.model = lgb.train(
            params,
            train_set=self.train_data_,
            valid_sets=[self.validation_data_])
        self.model_non_weights = lgb.train(
            params,
            train_set=self.train_data_non_weights_,
            valid_sets=[self.validation_data_non_weights_])

    def fit_all_data(self, hparams, *args, **kwargs):
        '''Train both boosters on train + validation rows for a fixed round count.'''
        params = self._build_params(hparams, early_stopping=False)
        self.model = lgb.train(params, train_set=self.all_train_data_)
        self.model_non_weights = lgb.train(
            params, train_set=self.all_train_data_non_weights_)
            
class Base_Feature_Extractor(torch.nn.Module):
    '''
    Base wrapper around a feature-extraction network.

    Subclasses assign ``self.feature_extractor`` to a callable, such as an
    ``nn.Module``. ``feature_extract`` and ``forward`` both apply it to
    their input.
    '''

    def __init__(self, dataset_configs):
        super(Base_Feature_Extractor, self).__init__()
        self.dataset_configs = dataset_configs
        # Left unset here; subclasses provide the concrete backbone.
        self.feature_extractor = None

    def feature_extract(self, x, *args, **kwargs):
        '''Return ``self.feature_extractor(x)``; extra arguments are ignored.'''
        # Must be ``return``: ``raise`` on a non-exception value always throws.
        return self.feature_extractor(x)

    def forward(self, x):
        '''Return the features extracted from ``x``.'''
        feat = self.feature_extractor(x)
        return feat


class Feature_Extractor(Base_Feature_Extractor):
    '''
    Feature extractor whose backbone is built by a factory.

    Args:
        backbone_fe: callable invoked as ``backbone_fe(dataset_configs)``.
            Its result becomes ``self.feature_extractor``.
        dataset_configs: forwarded to the base class and to ``backbone_fe``.
    '''

    def __init__(self, backbone_fe, dataset_configs):
        super().__init__(dataset_configs)
        backbone = backbone_fe(dataset_configs)
        self.feature_extractor = backbone


class Classifier(nn.Module):
    """MLP classification head applied to extracted features.

    Three linear layers with ReLU activations between them map an input of
    width ``configs.final_out_channels`` to ``num_classes`` raw logits. No
    softmax is applied. The hidden width is half the input width.

    Args:
        configs: object exposing ``final_out_channels`` (the input feature width).
        num_classes: number of output units. Defaults to 2, the previously
            hard-coded value.
    """

    def __init__(self, configs, num_classes=2):
        """Build the layers from ``configs`` and ``num_classes``."""
        super(Classifier, self).__init__()
        # Hidden width is half the input width, truncated to an int.
        self.hid_dim = int(configs.final_out_channels / 2)

        self.layer = nn.Sequential(
            nn.Linear(configs.final_out_channels, self.hid_dim),
            nn.ReLU(),
            nn.Linear(self.hid_dim, self.hid_dim),
            nn.ReLU(),
            nn.Linear(self.hid_dim, num_classes),
        )

    def forward(self, input):
        """Return logits of shape ``(..., num_classes)`` for ``input``."""
        return self.layer(input)