import time

from gradiend.model import GradiendModel, ModelWithGradiend
from gradiend.setups.emotion import EmotionSetup
from gradiend.training.gradiend_training import train_for_configs

# Training configurations passed to train_for_configs, keyed by model name.
# Each inner dict is forwarded as-is; the meaning of the individual options is
# defined by train_for_configs (not visible here). multi_train() and
# further_train() add a 'post_processing' hook to every entry at runtime.
configs = {
    'distilbert-base-cased': {
        'eval_max_size': 100,
        'supervised': False,
        'batch_size': 1,
        'n_evaluation': 250,
        'epochs': 100,
        'source': 'counterfactual',
        'target': 'diff',
        'max_iterations': 10000,
        # 'use_gradients': False,
    },
    # Disabled alternative:
    # 'roberta-base': {
    #     'eval_max_size': 200, 'batch_size': 16, 'n_evaluation': 100,
    #     'epochs': 1, 'source': 'counterfactual', 'target': 'diff',
    #     'max_iterations': 20000,
    # },
}

def multi_train():
    """Run ``train_for_configs`` on a two-feature (valence, arousal) EmotionSetup.

    Loads the single-feature valence and arousal GRADIEND checkpoints, then
    installs a ``post_processing`` hook on every entry of the module-level
    ``configs`` (mutating those dicts in place). Finally it calls
    ``train_for_configs`` with version ``'tanh_from_pretrained'`` and ``n=5``.
    """
    valence_path = 'results/experiments/gradiend/valence/distilbert-base-cased/valence_v10/1'
    arousal_path = 'results/experiments/gradiend/arousal/distilbert-base-cased/arousal_v10/1'

    # NOTE(review): both models are loaded but never referenced afterwards.
    # This looks like unfinished work (presumably meant to be used by
    # model_inserter); confirm before relying on it.
    valence_gradiend = GradiendModel.from_pretrained(valence_path)
    arousal_gradiend = GradiendModel.from_pretrained(arousal_path)

    setup = EmotionSetup(n_features=2, features=['valence', 'arousal'])

    def model_inserter(model):
        # NOTE(review): placeholder hook. It only prints and implicitly returns
        # None, whereas the hook in further_train() returns a model. Confirm
        # how train_for_configs handles a None result.
        print('called')

    # Only the per-model config dicts are modified; the model-name keys are not needed.
    for config in configs.values():
        config['post_processing'] = model_inserter

    # NOTE(review): uses the same version tag as further_train(); confirm
    # that runs of the two experiments do not overwrite each other's output.
    train_for_configs(setup, configs, version='tanh_from_pretrained', n=5)


def further_train():
    """Continue training from the saved emotion-3 GRADIEND checkpoint.

    Installs a ``post_processing`` hook on every entry of the module-level
    ``configs`` (mutating those dicts in place). The hook ignores its argument
    and returns ``ModelWithGradiend.from_pretrained(initial_gradiend)``. The
    function then calls ``train_for_configs`` on a three-feature
    ``EmotionSetup`` with version ``'tanh_from_pretrained'`` and ``n=4``.
    """
    # Plain string literal: the original f-string contained no placeholders.
    initial_gradiend = 'results/experiments/gradiend/emotion-3/distilbert-base-cased/tanh_3_supervised_cf/4_best'
    setup = EmotionSetup(n_features=3)

    def model_inserter(model):
        # The incoming model is ignored; the hook always loads the same
        # checkpoint from initial_gradiend.
        return ModelWithGradiend.from_pretrained(initial_gradiend)

    # Only the per-model config dicts are modified; the model-name keys are not needed.
    for config in configs.values():
        config['post_processing'] = model_inserter

    train_for_configs(setup, configs, version='tanh_from_pretrained', n=4)


if __name__ == '__main__':
    # Script entry point: only the "continue from pretrained checkpoint"
    # experiment is enabled. Call multi_train() here to run the
    # valence/arousal experiment instead.
    further_train()