"""
OCL-PDS Experiment

Main entry-point script for running a complete experiment.

For the command-line arguments accepted by this script, see initialize.get_parser();
defaults for options left unset are filled in by config.populate_defaults().
"""

import setuptools.version  # Without this, there could be a deadlock in some environments
import time
from initialize import initialize_model, initialize_loss_function, initialize_algorithm, \
     initialize_dataset, initialize_eval_metric, initialize_feedback, initialize_transform, get_parser
from transforms import get_weak_transform, get_strong_transform
from config import populate_defaults
from utils import evaluate_model, enable_deterministic, attr_is_none


def _safe_mean(total, count):
    """Return ``total / count``, or NaN when ``count`` is not positive.

    Run-level averages can have an empty denominator: a single-batch run, or
    a metric that was never measured. This helper lets the summary and the
    CSV report ``nan`` in those cases instead of raising ``ZeroDivisionError``.
    """
    return total / count if count > 0 else float('nan')


def main():
    """Run one complete OCL experiment.

    Parses the command-line arguments (see ``initialize.get_parser()``) and
    builds the dataset, model, evaluation metric, feedback and algorithm.
    It then iterates over the data batches. For each batch ``t``:

    * The current model is evaluated on the new batch through ``feedback``.
    * When configured, it is also evaluated on the ``recent_batches``
      previous test batches and on the regression test set.
    * The model is then updated with the algorithm. This is skipped for the
      last batch.

    Summary statistics are printed at the end. If ``csv_file`` is set, they
    are also appended to that file as one CSV row.
    """
    parser = get_parser()
    config = parser.parse_args()
    config = populate_defaults(config)
    for name, val in vars(config).items():
        print('{}: {}'.format(name, val))

    # 1. Load dataset and transform
    dataset_train = initialize_dataset(config)
    transform_train = initialize_transform(config, dataset_train, True)
    transform_eval = initialize_transform(config, dataset_train, False)

    # Fix seed for reproducibility.
    # NOTE(review): the seed is set after the dataset and transforms are built.
    # If their construction draws random numbers, that is not covered; confirm.
    if not attr_is_none(config, 'seed'):
        enable_deterministic(config.seed)

    # 2. Build model, eval_metric, feedback, loss_function and algorithm
    model = initialize_model(config, dataset_train)
    model = model.to(config.device)
    eval_metric = initialize_eval_metric(config)
    feedback = initialize_feedback(config, eval_metric)
    config.loss_function_dummy = initialize_loss_function(config)
    alg = initialize_algorithm(config, model)

    # Loop-invariant switches, read from the config once.
    use_fixmatch = not attr_is_none(config, 'fixmatch')
    eval_regression_once = not attr_is_none(config, 'eval_regression_once')
    eval_post_train = not attr_is_none(config, 'eval_post_train')
    recent_batches = 0 if attr_is_none(config, 'recent_batches') else config.recent_batches

    if use_fixmatch:
        transform_weak = get_weak_transform(config, dataset_train)
        transform_strong = get_strong_transform(config, dataset_train)

    # 3. Main routine: OCL
    num_batches = dataset_train.num_batches
    if not attr_is_none(config, 'max_batches'):
        num_batches = min(num_batches, config.max_batches)

    # Running sums and the number of terms in each. Averages use these
    # explicit counts, so short runs cannot divide by zero.
    online_performance, n_online = 0, 0
    recent_all, n_recent = 0, 0
    regression_all, n_regression = 0, 0
    train_time, n_train = 0, 0
    train_time_first = 0
    # Worst-value trackers (minimum seen). inf means "never measured".
    recent_worst = float('inf')
    regression_worst = float('inf')

    for t in range(num_batches):
        print('=== Batch {} ==='.format(t))

        # Step 1: Receive a new data batch
        batch_train = dataset_train.get_batch(t=t, transform=transform_train)
        # Recent test batches are evaluated only once `recent_batches` of them
        # exist, i.e. from batch t == recent_batches onward.
        w = recent_batches if recent_batches <= t else 0
        # Entry 0 is the new batch itself.
        # Entries 1..w are the test batches t-1, t-2, ..., t-w.
        batches_eval = [dataset_train.get_batch(t=t, transform=transform_eval)] + \
            [dataset_train.get_test_batch(t=t0, transform=transform_eval) for t0 in range(t - 1, t - w - 1, -1)]
        print('Batch {} has {} samples'.format(t, len(batch_train)))

        # Step 2: Evaluate the current model on the new batch, recent batches and regression set, and get feedback
        fb = feedback(t, batch_train, batches_eval, alg.model)
        if t > 0:
            online_performance += fb['performance_0']
            n_online += 1
            print('Online Performance:\t{}'.format(fb['performance_0']))
            if w > 0:
                recent = sum(fb['performance_{}'.format(i)] for i in range(1, w + 1)) / w
                print('Recent performance (recent {} batches):\t{}'.format(w, recent))
                recent_worst = min(recent, recent_worst)
                recent_all += recent
                n_recent += 1

            # Regression set evaluation on every batch, unless it is evaluated
            # only once (after the initial training, below).
            if not eval_regression_once:
                batch_regression = dataset_train.get_test_regression(transform=transform_eval)
                c, n = evaluate_model(alg.model, batch_regression, config, eval_metric)
                print('Regression set Performance:\t{} ({} samples in total)'.format(c, n))
                regression_worst = min(c, regression_worst)
                regression_all += c
                n_regression += 1
        else:
            # The training regression set is given in the first iteration only
            fb['train_regression'] = dataset_train.get_train_regression(transform=transform_train)
            print('The training regression set has {} samples.'.format(len(fb['train_regression'])))

        # For FixMatch
        if use_fixmatch:
            fb['transform_strong'] = transform_strong
            fb['transform_weak'] = transform_weak

        # Step 3: Fine-tune the model with alg (skipped after the last batch)
        if t < num_batches - 1:
            start = time.perf_counter()  # monotonic clock for elapsed time
            alg(t, fb)
            t1 = time.perf_counter() - start
            print('Elapsed time:\t{}'.format(t1))
            # The initial (t == 0) training time is reported separately.
            if t > 0:
                train_time += t1
                n_train += 1
            else:
                train_time_first = t1

            # Post-fine-tune evaluation on the new batch. It is also forced
            # after the initial training when the regression set is evaluated
            # only once.
            if eval_post_train or (t == 0 and eval_regression_once):
                c, n = evaluate_model(alg.model, batches_eval[0], config, eval_metric)
                print('Post-Train Performance:\t{} ({} samples in total)'.format(c, n))

            if t == 0 and eval_regression_once:
                print('Regression set only evaluated after initial training.')
                batch_regression = dataset_train.get_test_regression(transform=transform_eval)
                c, n = evaluate_model(alg.model, batch_regression, config, eval_metric)
                print('Regression set Performance:\t{} ({} samples in total)'.format(c, n))
                regression_worst = min(c, regression_worst)

    # 4. Print final results
    online_avg = _safe_mean(online_performance, n_online)
    recent_avg = _safe_mean(recent_all, n_recent)
    regression_avg = _safe_mean(regression_all, n_regression)
    train_time_avg = _safe_mean(train_time, n_train)

    print('=== Overall Results ===')
    print('Average online performance:\t{}'.format(online_avg))
    print('Initial model training time:\t{}'.format(train_time_first))
    if n_train > 0:
        print('Average online training time:\t{}'.format(train_time_avg))
    if n_recent > 0:
        print('Average recent performance:\t{}'.format(recent_avg))
        print('Worst recent performance:\t{}'.format(recent_worst))
    if not eval_regression_once:
        print('Average regression set performance:\t{}'.format(regression_avg))
    print('Worst regression set performance:\t{}'.format(regression_worst))

    # 'acc' values are scaled by 100 (presumably to percent); other metrics are unscaled.
    k = 100 if config.eval_metric == 'acc' else 1
    if not attr_is_none(config, 'csv_file'):
        # Columns, in order:
        #   online avg, recent avg, recent worst, regression avg,
        #   regression worst, average online training time (seconds, unscaled).
        # Averages with no samples are written as 'nan'; worst values that
        # were never measured are written as 'inf'.
        with open(config.csv_file, 'a') as f:
            f.write('{:.3f},{:.3f},{:.3f},{:.3f},{:.3f},{:.3f}\n'.format(
                online_avg * k,
                recent_avg * k,
                recent_worst * k,
                regression_avg * k,
                regression_worst * k,
                train_time_avg))


if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()