import argparse
import random

import numpy as np
import torch
from lightning.pytorch import seed_everything

from evaluation.mmd import mmd_rbf
from evaluation.utils import get_nfeatures_seqlen_maxlag, get_runid_and_epoch, get_data_real, get_data_synthetic, \
    get_data_real_crvae


def _with_decimal_comma(value):
    """Return ``str(value)`` with every ``.`` replaced by ``,`` (decimal-comma output)."""
    return str(value).replace(".", ",")


def main():
    """Score a trained run against real data with ``mmd_rbf`` over ten seeds.

    Command-line arguments:
        -r/--run:     run name; runs whose name contains ``'crvae'`` load their
                      real data per seed via ``get_data_real_crvae``.
        -d/--dataset: dataset name used to look up shapes and data.

    For each seed in ``range(10)`` the Python, NumPy and torch RNGs are reseeded.
    Then 2048 indices are drawn (with replacement) from the real data and 2048
    from the synthetic data. Each selected sample is flattened to a vector with
    ``reshape(2048, -1)``, and the two sets are passed to ``mmd_rbf``. The
    per-seed score is printed, followed by the mean and the ``np.std`` (numpy's
    default, ddof=0) over all seeds. Printed numbers use ``,`` as the decimal
    separator.
    """
    seed_everything(0)

    cli = argparse.ArgumentParser()
    cli.add_argument('-r', '--run', required=True)
    cli.add_argument('-d', '--dataset', required=True)
    opts = cli.parse_args()
    run, dataset = opts.run, opts.dataset

    n_features, seq_len, _ = get_nfeatures_seqlen_maxlag(dataset)
    n_eval = 2048
    run_id, epoch = get_runid_and_epoch(run, dataset)
    print(f'Dataset: {dataset}\tRun: {run}\tID: {run_id}\tEpoch: {epoch}')

    is_crvae = 'crvae' in run

    # Non-CRVAE runs load one real dataset shared by every seed;
    # CRVAE runs load their real data inside the loop, once per seed.
    real = None
    if not is_crvae:
        real, _ = get_data_real(run, dataset, seq_len)

    per_seed = []
    for seed in range(10):
        # Reseed all RNGs so each round is reproducible independently.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        # NOTE: call order below is kept deliberately. The loaders may consume
        # the seeded RNGs, and that affects which indices randint draws.
        if is_crvae:
            real = get_data_real_crvae(run, dataset, epoch, seed)
        idx_real = np.random.randint(0, real.shape[0], n_eval)
        real_flat = real[idx_real].reshape(n_eval, -1)

        synthetic = get_data_synthetic(run_id, epoch, seed, seq_len, n_features)
        idx_synth = np.random.randint(0, synthetic.shape[0], n_eval)
        synth_flat = synthetic[idx_synth].reshape(n_eval, -1)

        score = mmd_rbf(real_flat, synth_flat)
        per_seed.append(score)
        print(f'Seed: {seed}\tScore: {_with_decimal_comma(score)}')

    mean_txt = _with_decimal_comma(np.mean(per_seed))
    std_txt = _with_decimal_comma(np.std(per_seed))
    print(f'Mean: {mean_txt}\tStd: {std_txt}')


# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
