from .debias import *
import torch
from .model import normalize, denormalize


class Baseline:
    """Callable wrapper that selects one debiasing method by name and runs it.

    Args:
        debias: Name of the debiasing method. Supported names are the keys of
            the two tables in ``_build_model``.
        hparams: Sequence whose first element is forwarded as ``params`` to
            the methods that take a hyper-parameter. May be ``None`` for
            'caleqodds', 'eqodds', 'wasserstein' and 'none'.
        stats: Mapping from a key to an object exposing the tensor attributes
            ``mean``, ``std``, ``min`` and ``max``. ``__call__`` reads the
            'data' entry; the presence of a 'target' entry switches the
            wrapper to regression mode (``mode == 'r'``), otherwise it is in
            classification mode (``mode == 'c'``).
        run: Flag forwarded as ``train=`` to the underlying model's
            ``debias`` method.

    Raises:
        NotImplementedError: If ``debias`` is not a supported method name.
        TypeError: If the chosen method needs a hyper-parameter and
            ``hparams`` is ``None``.
    """

    def __init__(self, debias, hparams, stats, run=True):
        super().__init__()
        self.debias = debias
        self.params = hparams
        self.run = run

        # Expose every statistics entry as two numpy arrays on the instance:
        # ``<key>_stats_moment`` = [mean, std], ``<key>_stats_minmax`` = [min, max].
        for key in stats:
            entry = stats[key]
            setattr(self, '{}_stats_moment'.format(key),
                    torch.stack([entry.mean, entry.std]).numpy())
            setattr(self, '{}_stats_minmax'.format(key),
                    torch.stack([entry.min, entry.max]).numpy())

        # Target statistics are only available for regression tasks.
        self.mode = 'r' if hasattr(self, 'target_stats_moment') else 'c'

        self.model = self._build_model(debias, hparams)

    @staticmethod
    def _build_model(debias, hparams):
        """Instantiate the debiasing model registered under ``debias``."""
        # Factories are lambdas so that a model class is only looked up, and
        # ``hparams`` only indexed, for the method that was actually requested.
        needs_param = {
            'reduction-sp': lambda p: Reduction(constraint='sp', params=p),
            'reduction-eo': lambda p: Reduction(constraint='eo', params=p),
            'rejection-sp': lambda p: Rejection(constraint='sp', params=p),
            'rejection-eo': lambda p: Rejection(constraint='eo', params=p),
            'fairproj-kl-eo': lambda p: FairProjKL(constraint='eo', params=p),
            'fairproj-kl-sp': lambda p: FairProjKL(constraint='sp', params=p),
            'fairproj-ce-eo': lambda p: FairProjCE(constraint='eo', params=p),
            'fairproj-ce-sp': lambda p: FairProjCE(constraint='sp', params=p),
            'postprocess-sp': lambda p: PostProcessClf(constraint='sp', params=p),
            'postprocess-eo': lambda p: PostProcessClf(constraint='eo', params=p),
            'postprocess-reg': lambda p: PostProcessReg(params=p),
            'fairreg': lambda p: FairReg(params=p),
        }
        no_param = {
            'caleqodds': lambda: CalEqOdds(),
            'eqodds': lambda: EqOdds(),
            'wasserstein': lambda: Wasserstein(),
            'none': lambda: Vanilla(),
        }

        if debias in needs_param:
            if hparams is None:
                # Same exception type as the former ``None[0]`` failure, but
                # with a message that names the actual problem.
                raise TypeError(
                    "debias method '{}' requires a hyper-parameter, "
                    "but hparams is None".format(debias))
            return needs_param[debias](hparams[0])

        if debias in no_param:
            return no_param[debias]()

        supported = list(needs_param) + list(no_param)
        raise NotImplementedError(
            'module not implemented, choose from {{{}}}'.format(', '.join(supported)))

    def __call__(self, input):
        """Run the debiasing model on one batch and return its predictions.

        Args:
            input: Mapping with the keys 'data', 'target' and 'sensitive'.
                NOTE: it is modified in place. 'target' and 'sensitive' are
                replaced by flattened torch tensors, and in regression mode
                'target' is normalized before the model call and
                denormalized again before returning.

        Returns:
            dict with the key 'target' (model output as a torch tensor) and,
            in regression mode only, 'loss' (mean squared error computed
            before the targets are denormalized).
        """
        # Explicit import: ``np`` is otherwise only in scope if the star
        # import from ``.debias`` happens to re-export it.
        import numpy as np

        output = {}
        # normalize data with the stored [mean, std] statistics
        # (presumably (x - mean) / std; see ``normalize`` in ``.model``)
        x = normalize(input['data'], self.data_stats_moment[0], self.data_stats_moment[1])

        # min-max normalization for target if regression
        if self.mode == 'r':
            t_min = self.target_stats_minmax[0]
            t_range = self.target_stats_minmax[1] - self.target_stats_minmax[0]
            input['target'] = normalize(input['target'], t_min, t_range)

        # The model receives features, target and sensitive attribute in a
        # single container (assumes ``x`` supports item assignment by key).
        x['target'] = input['target']
        x['sensitive'] = input['sensitive']
        target = self.model.debias(x, train=self.run)
        if target.dtype == np.float32:
            # float32 outputs keep their shape
            output['target'] = torch.tensor(target)
        else:
            # every other dtype is flattened to one dimension
            output['target'] = torch.tensor(target).view(-1)

        #  convert to torch tensor for evaluation
        input['target'] = torch.tensor(input['target']).view(-1)
        input['sensitive'] = torch.tensor(input['sensitive']).view(-1)

        # denormalize target if regression
        if self.mode == 'r':
            # loss is measured on the normalized scale
            output['loss'] = torch.nn.functional.mse_loss(
                output['target'], input['target'], reduction='mean')
            t_min = self.target_stats_minmax[0]
            t_range = self.target_stats_minmax[1] - self.target_stats_minmax[0]
            input['target'] = denormalize(input['target'], t_min, t_range)
            output['target'] = denormalize(output['target'], t_min, t_range)
        return output

    def train(self, run):
        """Set the flag that is forwarded as ``train=`` on the next call."""
        self.run = run

    def load_state_dict(self, state_dict):
        """Copy the fitted models from another wrapper into this one.

        Despite its name, ``state_dict`` is accessed through attributes: it
        must expose ``.model.debias_model`` and, when this wrapper's model
        has a ``base_model`` attribute, ``.model.base_model``.
        """
        self.model.debias_model = state_dict.model.debias_model
        if hasattr(self.model, 'base_model'):
            self.model.base_model = state_dict.model.base_model


def baseline(cfg):
    """Build a ``Baseline`` from a configuration mapping.

    Reads three keys from ``cfg``:
      * 'hparams': either the literal string 'none' (no hyper-parameters)
        or a '-'-separated string of numbers, each parsed with ``float``;
      * 'stats': forwarded unchanged as ``stats``;
      * 'debias': forwarded unchanged as ``debias``.

    Returns the constructed ``Baseline`` instance.
    """
    raw_hparams = cfg['hparams']
    # 'none' is the sentinel for "this method takes no hyper-parameters".
    hparams = None if raw_hparams == 'none' else list(map(float, raw_hparams.split('-')))

    # Look both entries up before constructing, 'stats' first.
    stats, debias = cfg['stats'], cfg['debias']
    return Baseline(debias=debias, hparams=hparams, stats=stats)


    