import torch
from classification_models import FullyConnected, LogReg
from .downstream import downstream_demographic_parity, downstream_prediction
from .logic import implication_violation
from .statistical import expectation, variance, standard_deviation, entropy
from utils import dl2_geq, dl2_neq


def loss_from_constraint_structure(constraint_structure, dataset, syn_data, device=None):
    """
    Sum up the loss terms of all constraint specifications into a single constraint loss.

    Every top-level key of the constraint structure is checked against the four supported groups first. Then,
    group by group, each specification is handed to the matching ``parse_*_specs`` function and the returned
    term is added in-place to the running total.

    :param constraint_structure: (dict) A dictionary containing all constraint specifications that are to be
        enforced on the synthetic data. Each of the four groups is optional. The expected structure is:
            {
                'statistical': [{
                    'operation': (str) The statistical operation to be calculated on the function applied over
                        the features. Available are: 'expectation', 'variance', 'standard_deviation', and
                        'entropy'. Note that the entropy does not take a function.
                    'involved_features': (list of str) The names of all features involved in the operation.
                    'function': (callable) A function that maps from the space of all involved features to a
                        real number. The statistical operation is evaluated over this function.
                    'constraint': (tuple(str, float)) A comparative constraint (inequality or equality), where
                        the left side is assumed to be the statistical operation and the right side is given
                        here. Supported comparisons are '>=', '<=', '>', '<', '==', and '!='.
                        Some examples: ('>', 5.5), ('==', 10.), ('!=', 2.3), ('<=', 3.4)
                    'optimization_weight': (float) The weight with which the produced term is included in the
                        final optimization of the synthetic data generation model.
                }, ... ],
                'logical': [{
                    'conditionA': (str) The left-side condition of the implication. For now, this can only take
                        two forms: 'feature_name=feature_value' or 'feature_name=NOT feature_value'.
                    'conditionB': (str) The right-side condition of the implication. For now, this can only take
                        two forms: 'feature_name=feature_value' or 'feature_name=NOT feature_value'.
                    'optimization_weight': (float) The weight with which the produced term is included in the
                        final optimization of the synthetic data generation model.
                }, ...],
                'downstream_accuracy': [{
                    'predict_on_features': (list of str) The names of all features the downstream model
                        makes its prediction on.
                    'target_feature': (str) The target variable in the downstream prediction task.
                    'model_name': (str) The name of the model to be used. Available are: 'FullyConnected', and
                        'LogReg'.
                    'model_architecture': (list) In case of a fully connected model, the architecture can be
                        specified in a list. The list should have the following structure:
                        [num_neurons_layer1, ..., num_neurons_layerN, num_output_classes].
                    'optimization_direction': (str) Specify with 'max' or 'min' if the optimization objective is
                        to maximize or minimize the downstream accuracy of the model on the given features.
                    'optimization_weight': (float) The weight with which the produced term is included in the
                        final optimization of the synthetic data generation model.
                    'training_specs': (dict) A dictionary containing the training specifications:
                        {'lr': (float) Learning rate, 'batch_size': (int) Batch size, 'num_epochs': Number of epochs}.
                        Note that this can be left as None, in which case we default to 0.01, 512, and 10.
                }, ...],
                'downstream_fairness': [{
                    'protected_feature': (str) The name of the sensitive or protected feature from a fairness
                        perspective.
                    'model_name': (str) The name of the model to be used. Available are: 'FullyConnected', and
                        'LogReg'.
                    'model_architecture': (list) In case of a fully connected model, the architecture can be
                        specified in a list. The list should have the following structure:
                        [num_neurons_layer1, ..., num_neurons_layerN, num_output_classes].
                    'fairness_measure': (str) The name of the fairness measure to be applied. For now, only
                        'demographic_parity' is available.
                    'optimization_direction': (str) Specify with 'max' or 'min' if the optimization objective is
                        to maximize or minimize the downstream fairness of the model on the given protected
                        feature.
                    'optimization_weight': (float) The weight with which the produced term is included in the
                        final optimization of the synthetic data generation model.
                    'training_specs': (dict) A dictionary containing the training specifications:
                        {'lr': (float) Learning rate, 'batch_size': (int) Batch size, 'num_epochs': Number of epochs}.
                        Note that this can be left as None, in which case we default to 0.01, 512, and 10.
                }, ...]
            }
    :param dataset: (BaseDataset) The instantiated dataset object containing the necessary information for the data.
    :param syn_data: (torch.tensor) The produced one-hot encoded synthetic data that the constraints are to be
        enforced over.
    :param device: (str) Name of the device on which the tensors are stored and the operations shall be
        conducted. If None, the device of ``syn_data`` is used.
    :return: (torch.tensor) The total constraint loss on the current data sample over all specifications. The
        accumulator starts out as a one-element tensor holding zero.
    :raises AssertionError: If the constraint structure contains a key other than the four supported groups.
    """
    device = syn_data.device if device is None else device

    # running total; every parsed term is added to it in-place
    total_loss = torch.tensor([0.], device=device)

    # refuse any group name that we do not know how to parse
    supported_groups = ('statistical', 'logical', 'downstream_accuracy', 'downstream_fairness')
    for group_name in constraint_structure:
        assert group_name in supported_groups, f'Invalid spec type: {group_name}'

    # a group that is absent from the structure simply contributes nothing

    # ---------------- Statistical ---------------- #
    for spec in constraint_structure.get('statistical', ()):
        total_loss += parse_statistical_specs(spec, dataset, syn_data, device)

    # ------------------ Logical ------------------ #
    for spec in constraint_structure.get('logical', ()):
        total_loss += parse_logical_specs(spec, dataset, syn_data, device)

    # ------------ Downstream Accuracy ------------ #
    for spec in constraint_structure.get('downstream_accuracy', ()):
        total_loss += parse_downstream_accuracy_specs(spec, dataset, syn_data, device)

    # ------------ Downstream Fairness ------------ #
    for spec in constraint_structure.get('downstream_fairness', ()):
        total_loss += parse_downstream_fairness_specs(spec, dataset, syn_data, device)

    return total_loss


def parse_statistical_specs(specs, dataset, syn_data, device=None):
    """
    Convert a single statistical constraint specification into a weighted violation score.

    The statistical operation named in the specification is evaluated on the synthetic data first. Its value is
    then compared against the target of the constraint, where the inequalities are expressed through the DL2
    primitives ``dl2_geq`` and ``dl2_neq`` and the equality through an absolute difference.

    :param specs: (dict) One entry of the 'statistical' list of the constraint structure. The keys read here are
        'operation' (one of 'expectation', 'variance', 'standard_deviation', 'entropy'), 'involved_features',
        'function', 'constraint' (a pair of comparison string and target), and 'optimization_weight'.
        For the 'entropy' operation the 'function' key may be omitted, in which case None is passed on.
    :param dataset: The dataset object, forwarded unchanged to the statistical operation.
    :param syn_data: The synthetic data, forwarded unchanged to the statistical operation.
    :param device: Not used by this function; accepted so that all parsers share the same signature.
    :return: The violation score of the constraint multiplied by the optimization weight.
    :raises KeyError: If the operation name is unknown or a required key is missing from ``specs``.
    :raises ValueError: If the comparison string of the constraint is not one of
        '>=', '<=', '>', '<', '==', '!='.
    """
    operations = {
        'expectation': expectation,
        'variance': variance,
        'standard_deviation': standard_deviation,
        'entropy': entropy
    }

    operation_name = specs['operation']
    statistical_operation = operations[operation_name]

    # The entropy is documented as not taking a function, so its specification may leave the key out.
    # NOTE(review): this presumes that entropy ignores the function argument -- confirm in .statistical
    if operation_name == 'entropy':
        function = specs.get('function')
    else:
        function = specs['function']

    # calculate the current value of the given operation over the given function
    op_f = statistical_operation(
        data=syn_data,
        dataset=dataset,
        features=specs['involved_features'],
        function=function
    )

    # now, parse the constraint -- TODO: ridiculously simple, but will need a better one
    # using DL2 primitives
    comparison, target = specs['constraint']
    alpha = specs['optimization_weight']
    if comparison == '>=':
        violation = dl2_geq(op_f, target)
    elif comparison == '<=':
        violation = dl2_geq(target, op_f)
    elif comparison == '>':
        # strict inequality: non-strict inequality and non-equality have to hold together
        violation = dl2_geq(op_f, target) + dl2_neq(op_f, target)
    elif comparison == '<':
        violation = dl2_geq(target, op_f) + dl2_neq(op_f, target)
    elif comparison == '==':
        violation = (target - op_f).abs()  # L1 norm, but others would work as well
    elif comparison == '!=':
        violation = dl2_neq(op_f, target)
    else:
        raise ValueError(f'Unknown constraint: {comparison!r}')

    return alpha * violation


def parse_logical_specs(specs, dataset, syn_data, device=None):
    """
    Convert a single logical (implication) constraint specification into a weighted violation score.

    :param specs: (dict) One entry of the 'logical' list of the constraint structure. The keys read here are
        'optimization_weight', 'conditionA', and 'conditionB'.
    :param dataset: The dataset object, forwarded unchanged to ``implication_violation``.
    :param syn_data: The synthetic data, forwarded unchanged to ``implication_violation``.
    :param device: Only defaulted to the device of ``syn_data`` when None; not used any further here.
    :return: The sum over the result of ``implication_violation`` multiplied by the optimization weight.
    """
    device = syn_data.device if device is None else device

    weight = specs['optimization_weight']
    violations = implication_violation(
        data=syn_data,
        dataset=dataset,
        condition_A=specs['conditionA'],
        condition_B=specs['conditionB']
    )
    total_violation = violations.sum()

    return weight * total_violation


def parse_downstream_accuracy_specs(specs, dataset, syn_data, device=None):
    """
    Convert a single downstream accuracy specification into a weighted loss term.

    A fresh downstream model is instantiated and handed to ``downstream_prediction`` together with training data
    cut out of the synthetic data and evaluation data cut out of the true one-hot training data of the dataset.
    The returned score is multiplied by the optimization weight, with a flipped sign if the optimization
    direction is 'max'.

    :param specs: (dict) One entry of the 'downstream_accuracy' list of the constraint structure. The keys read
        here are 'predict_on_features', 'target_feature', 'model_name' ('LogReg' or 'FullyConnected'),
        'model_architecture' (only for 'FullyConnected'; its last entry is replaced by the number of classes on
        a copy, the given list is left untouched), 'optimization_direction', 'optimization_weight', and
        'training_specs'. The latter may be None, missing, or incomplete; every setting that is not given falls
        back to its default (lr=0.01, batch_size=512, num_epochs=10).
    :param dataset: The dataset object. Used here are ``get_Dtrain_full_one_hot``, ``full_one_hot_index_map``,
        and ``features``.
    :param syn_data: The one-hot encoded synthetic data the downstream model is trained on.
    :param device: The device the true data and the model are moved to. If None, the device of ``syn_data``
        is used.
    :return: The score returned by ``downstream_prediction`` multiplied by the signed optimization weight.
    :raises NotImplementedError: If the model name is neither 'LogReg' nor 'FullyConnected'.
    """
    if device is None:
        device = syn_data.device

    # complete the training settings with the defaults; None or a missing key means all defaults
    training_specs = {'lr': 0.01, 'batch_size': 512, 'num_epochs': 10, **(specs.get('training_specs') or {})}

    index_map = dataset.full_one_hot_index_map
    input_features = specs['predict_on_features']
    target_feature = specs['target_feature']

    # prepare the data: training split from the synthetic data, evaluation split from the true data
    true_one_hot_train = dataset.get_Dtrain_full_one_hot(return_torch=True).to(device)
    X_train = torch.cat([syn_data[:, index_map[feature]] for feature in input_features], dim=1)
    y_train = syn_data[:, index_map[target_feature]][:, -1].long()  # TODO only binary for now
    X_eval = torch.cat([true_one_hot_train[:, index_map[feature]] for feature in input_features], dim=1)
    y_eval = true_one_hot_train[:, index_map[target_feature]][:, -1].long()  # TODO only binary for now

    # instantiate the model
    num_classes = len(dataset.features[target_feature])
    num_inputs = X_train.size()[1]
    model_name = specs['model_name']
    if model_name == 'LogReg':
        model = LogReg(num_inputs, num_classes).to(device)
    elif model_name == 'FullyConnected':
        # copy first, such that the list inside the caller's specification is not modified
        architecture = list(specs['model_architecture'])
        architecture[-1] = num_classes  # make sure that we make no mistake at the output
        model = FullyConnected(num_inputs, architecture).to(device)
    else:
        raise NotImplementedError('Downstream Model: Only LogReg and FullyConnected models are implemented')

    # train and record the evaluation loss
    # NOTE(review): whether a larger value returned by downstream_prediction stands for a better or a worse
    # accuracy is not visible here; the sign convention is the same as in parse_downstream_fairness_specs
    alpha = -1. * specs['optimization_weight'] if specs['optimization_direction'] == 'max' else specs['optimization_weight']
    downstream_prediction_score = alpha * downstream_prediction(
        X_train=X_train,
        y_train=y_train,
        X_eval=X_eval,
        y_eval=y_eval,
        model=model,
        lr=training_specs['lr'],
        batch_size=training_specs['batch_size'],
        num_epochs=training_specs['num_epochs']
    )

    return downstream_prediction_score
    

def parse_downstream_fairness_specs(specs, dataset, syn_data, device=None):
    """
    Convert a single downstream fairness specification into a weighted loss term.

    A fresh downstream model is instantiated and handed to the fairness measure together with training data
    taken from the synthetic data and evaluation data taken from the true one-hot training data of the dataset.
    The returned score is multiplied by the optimization weight, with a flipped sign if the optimization
    direction is 'max'.

    :param specs: (dict) One entry of the 'downstream_fairness' list of the constraint structure. The keys read
        here are 'protected_feature', 'model_name' ('LogReg' or 'FullyConnected'), 'model_architecture' (only
        for 'FullyConnected'; its last entry is replaced by the number of classes on a copy, the given list is
        left untouched), 'fairness_measure' (only 'demographic_parity'), 'optimization_direction',
        'optimization_weight', and 'training_specs'. The latter may be None, missing, or incomplete; every
        setting that is not given falls back to its default (lr=0.01, batch_size=512, num_epochs=10).
    :param dataset: The dataset object. Used here are ``get_Dtrain_full_one_hot``, ``features``, and ``label``.
    :param syn_data: The one-hot encoded synthetic data the downstream model is trained on. The last two
        columns are treated as the one-hot encoded binary label.
    :param device: The device the true data and the model are moved to. If None, the device of ``syn_data``
        is used.
    :return: The score returned by ``downstream_demographic_parity`` multiplied by the signed optimization weight.
    :raises NotImplementedError: If the model name is neither 'LogReg' nor 'FullyConnected', or if the fairness
        measure is not 'demographic_parity'.
    """
    if device is None:
        device = syn_data.device

    # complete the training settings with the defaults; None or a missing key means all defaults
    training_specs = {'lr': 0.01, 'batch_size': 512, 'num_epochs': 10, **(specs.get('training_specs') or {})}

    # prepare the data both for training and eval -- TODO this now only makes sense/works for binary classification
    # where the label column is the last in the data table
    true_one_hot_train = dataset.get_Dtrain_full_one_hot(return_torch=True).to(device)
    X_train, y_train = syn_data[:, :-2], syn_data[:, -1].long()
    X_eval, y_eval = true_one_hot_train[:, :-2], true_one_hot_train[:, -1].long()

    # instantiate the model
    num_classes = len(dataset.features[dataset.label])
    num_inputs = X_train.size()[1]
    model_name = specs['model_name']
    if model_name == 'LogReg':
        model = LogReg(num_inputs, num_classes).to(device)
    elif model_name == 'FullyConnected':
        # copy first, such that the list inside the caller's specification is not modified
        architecture = list(specs['model_architecture'])
        architecture[-1] = num_classes  # make sure that we make no mistake at the output
        model = FullyConnected(num_inputs, architecture).to(device)
    else:
        raise NotImplementedError('Downstream Model: Only LogReg and FullyConnected models are implemented')

    # calculate the desired fairness measure -- note that at the moment only demographic parity is available
    if specs['fairness_measure'] != 'demographic_parity':
        raise NotImplementedError('Fairness Measure: Only Demographic Parity is implemented')

    # the sign of the weight is flipped if the measure is to be maximized
    alpha = -1. * specs['optimization_weight'] if specs['optimization_direction'] == 'max' else specs['optimization_weight']
    fairness_score = alpha * downstream_demographic_parity(
        X_train=X_train,
        y_train=y_train,
        X_eval=X_eval,
        y_eval=y_eval,
        model=model,
        dataset=dataset,
        protected_feature=specs['protected_feature'],
        target_feature=dataset.label,
        lr=training_specs['lr'],
        batch_size=training_specs['batch_size'],
        num_epochs=training_specs['num_epochs']
    )

    return fairness_score
    