import argparse
import random

import numpy as np
from lightning.pytorch import seed_everything

from evaluation.authenticity import authenticity_score
from evaluation.utils import get_nfeatures_seqlen_maxlag, get_runid_and_epoch, get_data_real, get_data_synthetic, \
    get_data_real_crvae


def _decimal_comma(value) -> str:
    """Return ``str(value)`` with every ``.`` replaced by ``,``.

    Presumably so printed numbers paste into tools that expect a decimal
    comma — TODO confirm the intended consumer.
    """
    return str(value).replace('.', ',')


def main():
    """Evaluate the authenticity score of a run's synthetic data over 10 seeds.

    Command-line arguments:
        -r/--run: run name, forwarded to the ``evaluation.utils`` loaders.
            Runs whose name contains ``'crvae'`` take the CR-VAE path: the
            real data is re-loaded for every seed and 1024 samples are
            evaluated per seed instead of 7000.
        -d/--dataset: dataset name, forwarded to the same loaders.

    For each seed in ``range(10)`` the Python and NumPy RNGs are reseeded,
    ``n_samples_evaluation`` synthetic samples are drawn with replacement,
    and ``authenticity_score`` is computed against the real data. Per-seed
    scores, their mean and their population standard deviation (``np.std``
    with the default ``ddof=0``) are printed with decimal commas.
    """
    seed_everything(0)
    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--run', required=True)
    parser.add_argument('-d', '--dataset', required=True)
    args = parser.parse_args()

    run = args.run
    dataset = args.dataset
    # Decide the CR-VAE branch once so the sample count and the data-loading
    # path can never disagree (e.g. for a run named 'crvae_v2').
    is_crvae = 'crvae' in run

    n_features, seq_len, _ = get_nfeatures_seqlen_maxlag(dataset)
    flat_dim = seq_len * n_features

    n_samples_evaluation = 1024 if is_crvae else 7000

    run_id, epoch = get_runid_and_epoch(run, dataset)

    print(f'Dataset: {dataset}\tRun: {run}\tID: {run_id}\tEpoch: {epoch}')

    data_real = None
    if not is_crvae:
        # Non-CR-VAE runs use one real dataset for every seed, so load it once.
        data_real, _ = get_data_real(run, dataset, seq_len)
        data_real = data_real.reshape(-1, flat_dim)

    scores = []
    for seed in range(10):
        random.seed(seed)
        np.random.seed(seed)

        if is_crvae:
            # CR-VAE real data is seed- and epoch-specific; reload per seed.
            data_real = get_data_real_crvae(run, dataset, epoch, seed)
            data_real = data_real.reshape(n_samples_evaluation, flat_dim)

        data_synthetic = get_data_synthetic(run_id, epoch, seed, seq_len, n_features)
        # Uniform sampling *with replacement* of n_samples_evaluation rows.
        sample_idx = np.random.randint(0, len(data_synthetic), n_samples_evaluation)
        data_synthetic_ev = data_synthetic[sample_idx].reshape(n_samples_evaluation, flat_dim)

        score = authenticity_score(data_real, data_synthetic_ev)
        scores.append(score)
        print(f'Seed: {seed}\tScore: {_decimal_comma(score)}')

    mean = np.mean(scores)
    std = np.std(scores)
    print(f'Mean: {_decimal_comma(mean)}\tStd: {_decimal_comma(std)}')


if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not when imported.
    main()
