__author__ = ''
__date__ = '2023/07/04'

'''
Fairness-aware restricted generated projection.
'''

from networks.main import build_networks
from optim.trainer import Trainer
from data.main import load_dataset

class FRGP(object):
    """Fairness-aware restricted generated projection (FRGP) wrapper.

    Creates three collaborators at construction time -- a network from
    ``build_networks``, a dataset from ``load_dataset`` and a ``Trainer`` --
    and exposes a single ``train`` entry point that hands the dataset and the
    network to the trainer.

    Attributes set by ``__init__``:
        ae_net: Object returned by ``build_networks``.
        dataset: Object returned by ``load_dataset``.
        trainer: The ``Trainer`` instance that drives training.
    """

    def __init__(
        self,
        dataset_name,
        net_name,
        data_path,
        optimizer_name: str,
        lr: float,
        epochs: int,
        batch_size: int,
        device: str,
        results_dir: str,
        print=None,
        in_channels=None,
        _lambda=None,
        latent_dimension=None,
        Fairway=None,
        stop_threshold=None,
        entropy_reg_coe=None,
        beta=None,
        balanced=False,
        threshold_ratio=None,
        fair_c_ratio=None,
        af_name=None,
    ):
        """Build the network, load the dataset and configure the trainer.

        Args:
            dataset_name: Forwarded to ``load_dataset`` and to ``Trainer``.
            net_name: Forwarded to ``build_networks`` to select the network.
            data_path: Forwarded to ``load_dataset``.
            optimizer_name: First positional argument of ``Trainer``.
            lr: Forwarded to ``Trainer``.
            epochs: Number added to the starting epoch (1) to obtain the
                ``end_epoch`` value handed to ``Trainer``.
            batch_size: Forwarded to ``Trainer``.
            device: Forwarded to ``Trainer``.
            results_dir: Forwarded to ``Trainer``.
            print: Forwarded to ``Trainer`` (presumably a logging callable;
                confirm against ``Trainer``).
            in_channels: Forwarded to ``build_networks``.
            _lambda: Forwarded to ``Trainer``.
            latent_dimension: Forwarded to ``build_networks`` as ``mid_dim``
                and to ``Trainer`` as ``latent_dimension``.
            Fairway: Forwarded to ``Trainer``.
            stop_threshold: Forwarded to ``Trainer``.
            entropy_reg_coe: Forwarded to ``Trainer``.
            beta: Forwarded to ``Trainer``.
            balanced: Forwarded to ``load_dataset``.
            threshold_ratio: Forwarded to ``Trainer``.
            fair_c_ratio: Forwarded to ``load_dataset``.
            af_name: Forwarded to ``build_networks``.
        """
        # Creation order is kept fixed: network first, then dataset, then trainer.
        self.ae_net = build_networks(
            net_name,
            in_channels=in_channels,
            mid_dim=latent_dimension,
            af_name=af_name,
        )

        # Epoch numbering starts at 1; the end value is start + epochs.
        first_epoch = 1
        stop_epoch = first_epoch + epochs

        self.dataset = load_dataset(
            dataset_name,
            data_path,
            balanced=balanced,
            fair_c_ratio=fair_c_ratio,
        )

        # Collect every keyword option for the trainer in one place.
        trainer_options = {
            "lr": lr,
            "begin_epoch": first_epoch,
            "dataset_name": dataset_name,
            "end_epoch": stop_epoch,
            "batch_size": batch_size,
            "device": device,
            "print": print,
            "results_dir": results_dir,
            "_lambda": _lambda,
            "latent_dimension": latent_dimension,
            "Fairway": Fairway,
            "stop_threshold": stop_threshold,
            "entropy_reg_coe": entropy_reg_coe,
            "beta": beta,
            "threshold_ratio": threshold_ratio,
        }
        self.trainer = Trainer(optimizer_name, **trainer_options)

    def train(self):
        """Run training by passing the dataset and network to the trainer.

        The trainer's return value is discarded; this method returns ``None``.
        """
        dataset, network = self.dataset, self.ae_net
        self.trainer.train(dataset, network)
