import argparse
import json
import os
import random

import numpy as np
import torch
from lightning import Trainer
from lightning.pytorch import seed_everything

from evaluation.predictive.datamodule import PredictiveDM
from evaluation.predictive.pl_model import PredictorPL
from evaluation.utils import get_nfeatures_seqlen_maxlag, get_runid_and_epoch, get_data_real, get_data_synthetic, \
    get_data_real_crvae


def _n_samples_for_run(run):
    """Return how many sequences to draw per seed, chosen by substring of the run name."""
    if 'crvae' in run:
        return 1024
    if 'causaltime' in run:
        return 7000
    return 8192


def _subsample(data, n_samples):
    """Draw ``n_samples`` rows from ``data`` uniformly with replacement using the global NumPy RNG."""
    return data[np.random.randint(0, len(data), n_samples)]


def _format_decimal(value):
    """Render ``value`` as text with a comma instead of a dot as decimal separator."""
    return str(value).replace(".", ",")


def main():
    """Compute and store the predictive score of a trained run over 10 seeds.

    Command-line arguments:
        -r/--run: run name. Used to resolve the run id and epoch, and (by the
            substrings 'crvae' / 'causaltime') to choose the real-data source
            and the number of evaluation samples.
        -d/--dataset: dataset name. Determines the number of features and the
            sequence length.

    For each seed 0..9 the real and synthetic data are subsampled, a
    ``PredictorPL`` model is fit with a Lightning ``Trainer`` on a
    ``PredictiveDM`` whose ``train_on_real`` flag is switched off (the
    "syn2real" setting), and the smallest of the ``epoch{e}_loss`` metrics
    logged during training is taken as that seed's score. Scores, their mean
    and their standard deviation are printed with a decimal comma, and the
    list of scores is written as JSON to
    ``storage/<run_id>/predictive_score_epoch=<epoch>.json``.
    """
    seed_everything(0)
    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--run', required=True)
    parser.add_argument('-d', '--dataset', required=True)
    args = parser.parse_args()

    run = args.run
    dataset = args.dataset

    n_features, seq_len, _ = get_nfeatures_seqlen_maxlag(dataset)

    n_samples_evaluation = _n_samples_for_run(run)
    train_percentage = .8
    max_epochs = 10
    lr = 1e-3
    hidden_size = 32
    num_layers = 2
    cutoff = seq_len // 10  # passed to both the datamodule and the model

    accelerator = 'cuda' if torch.cuda.is_available() else 'cpu'
    n_gpus = torch.cuda.device_count()
    devices = "auto" if n_gpus == 0 else n_gpus
    pin_memory = accelerator == 'cuda'
    num_workers = 8 if accelerator == 'cuda' else 1
    batch_size = 32

    run_id, epoch = get_runid_and_epoch(run, dataset)

    print(f'Dataset: {dataset}\tRun: {run}\tID: {run_id}\tEpoch: {epoch}')

    # CR-VAE runs fetch their real data per seed inside the loop; every other
    # run loads it once here, together with the pipeline that is handed to
    # get_data_synthetic.
    data_real, pipeline = None, None
    if 'crvae' not in run:
        data_real, pipeline = get_data_real(run, dataset, seq_len, 'predictive')
        data_real = torch.Tensor(data_real)

    scores = []
    for seed in range(10):
        print(f'Seed: {seed}')
        # Re-seed all RNGs each iteration so the random draws of this
        # iteration are determined by `seed` alone.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.empty_cache()

        if 'crvae' not in run:
            data_real_ev = _subsample(data_real, n_samples_evaluation)
        else:
            data_real_ev = get_data_real_crvae(run, dataset, epoch, seed)

        data_synthetic = torch.Tensor(get_data_synthetic(run_id, epoch, seed, seq_len, n_features, pipeline))
        data_synthetic_ev = _subsample(data_synthetic, n_samples_evaluation)

        dm = PredictiveDM(
            data_real_ev, data_synthetic_ev, train_percentage, cutoff, batch_size, num_workers, pin_memory
        )
        model = PredictorPL(n_features, num_layers, hidden_size, cutoff, lr)

        # NOTE: a real-to-real baseline (train_on_real left on, max over the
        # epoch losses, final score |real2real - syn2real|) was previously
        # computed here; it is disabled and only the syn2real score is kept.
        dm.train_on_real = False
        trainer = Trainer(accelerator=accelerator, devices=devices, fast_dev_run=False, max_epochs=max_epochs)
        torch.cuda.empty_cache()
        trainer.fit(model, datamodule=dm)
        score = min(trainer.logged_metrics[f'epoch{e}_loss'].item() for e in range(max_epochs))

        scores.append(score)
        print()
        print(f'Seed: {seed}\tScore: {_format_decimal(score)}')
        print()

    print()
    print()
    print()
    for seed, score in enumerate(scores):
        print(f'Seed: {seed}\tScore: {_format_decimal(score)}')

    mean = np.mean(scores)
    std = np.std(scores)
    print(f'Mean: {_format_decimal(mean)}\tStd: {_format_decimal(std)}')

    # Create the output directory if needed so a long evaluation is not lost
    # to a FileNotFoundError at the very last step.
    out_dir = f'storage/{run_id}'
    os.makedirs(out_dir, exist_ok=True)
    with open(f'{out_dir}/predictive_score_epoch={epoch}.json', 'w', encoding='utf-8') as json_file:
        json.dump(scores, json_file)


if __name__ == "__main__":
    # Only run the evaluation when executed as a script, never on import.
    main()
