import argparse
import random

import numpy as np
import torch
from lightning.pytorch import seed_everything

from evaluation.utils import get_nfeatures_seqlen_maxlag, get_runid_and_epoch, get_data_real, get_data_synthetic, \
    get_data_real_crvae
from evaluation.xcorr import CrossCorrelLoss


def _format_decimal_comma(value):
    """Return ``str(value)`` with every ``.`` replaced by ``,`` (decimal-comma output)."""
    return str(value).replace(".", ",")


def main():
    """Score a run's synthetic data against real data with a cross-correlation loss.

    For each seed ``0 .. n_seeds - 1`` the Python, NumPy and torch RNGs are
    re-seeded. Then ``n_samples`` real and ``n_samples`` synthetic sequences are
    drawn with replacement, and ``CrossCorrelLoss(max_lag)`` is evaluated on the
    pair. The per-seed scores are printed, followed by their mean and population
    standard deviation (``np.std`` default, ddof=0). All numbers are printed
    with a decimal comma.

    Command-line arguments:
        -r/--run: run identifier (required).
        -d/--dataset: dataset name (required).
        --n-samples: sequences drawn per seed from each data set (default 1024).
        --max-lag: ``max_lag`` passed to ``CrossCorrelLoss`` (default 10).
        --n-seeds: number of evaluation seeds (default 10).
    """
    seed_everything(0)
    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--run', required=True)
    parser.add_argument('-d', '--dataset', required=True)
    parser.add_argument('--n-samples', type=int, default=1024,
                        help='number of sequences sampled per seed (default: 1024)')
    parser.add_argument('--max-lag', type=int, default=10,
                        help='max_lag passed to CrossCorrelLoss (default: 10)')
    parser.add_argument('--n-seeds', type=int, default=10,
                        help='number of evaluation seeds (default: 10)')
    args = parser.parse_args()
    if args.n_samples < 1:
        parser.error('--n-samples must be a positive integer')
    if args.n_seeds < 1:
        # np.mean/np.std over an empty score list would only yield nan.
        parser.error('--n-seeds must be a positive integer')

    run = args.run
    dataset = args.dataset
    n_samples_evaluation = args.n_samples
    max_lag = args.max_lag
    is_crvae = 'crvae' in run

    # NOTE(review): the third value returned here (presumably a dataset-specific
    # max lag, judging by the helper's name) is discarded in favour of --max-lag;
    # confirm this is intended.
    n_features, seq_len, _ = get_nfeatures_seqlen_maxlag(dataset)

    run_id, epoch = get_runid_and_epoch(run, dataset)

    print(f'Dataset: {dataset}\tRun: {run}\tID: {run_id}\tEpoch: {epoch}')

    # Non-CRVAE runs share one real data set across all seeds. CRVAE runs
    # load a seed-specific real data set inside the loop instead.
    data_real = None
    if not is_crvae:
        data_real, _ = get_data_real(run, dataset, seq_len)

    scores = []
    for seed in range(args.n_seeds):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        if is_crvae:
            data_real = get_data_real_crvae(run, dataset, epoch, seed)

        # The order of the RNG-consuming calls below is kept fixed so that the
        # scores for a given seed stay reproducible.
        data_real_ev = data_real[np.random.randint(0, len(data_real), n_samples_evaluation)]
        data_real_ev = torch.Tensor(data_real_ev)

        data_synthetic = get_data_synthetic(run_id, epoch, seed, seq_len, n_features)
        data_synthetic_ev = data_synthetic[np.random.randint(0, len(data_synthetic), n_samples_evaluation)]
        data_synthetic_ev = torch.Tensor(data_synthetic_ev)

        # Constructed per seed (as originally) so that no state CrossCorrelLoss
        # might keep can carry over between seeds; its internals are not visible here.
        loss_fn_xcorr = CrossCorrelLoss(max_lag)
        score = loss_fn_xcorr(data_real_ev, data_synthetic_ev).item()
        scores.append(score)
        print(f'Seed: {seed}\tScore: {_format_decimal_comma(score)}')

    mean = np.mean(scores)
    std = np.std(scores)
    print(f'Mean: {_format_decimal_comma(mean)}\tStd: {_format_decimal_comma(std)}')


if "__main__" == __name__:
    # Script entry point: run the evaluation only when executed directly, not on import.
    main()
