import argparse
import json
import os
import random

import numpy as np
import torch
from lightning import Trainer
from lightning.pytorch import seed_everything

from evaluation.discriminative.datamodule import DiscriminativeDM
from evaluation.discriminative.pl_model import DiscriminatorPL
from evaluation.utils import get_nfeatures_seqlen_maxlag, get_runid_and_epoch, get_data_real, get_data_synthetic, \
    get_data_real_crvae


def _with_decimal_comma(value):
    """Return ``str(value)`` with every ``.`` replaced by ``,`` (console output format)."""
    return str(value).replace('.', ',')


def main():
    """Compute the discriminative score of a trained generator over several seeds.

    Command-line arguments:
        -r/--run: run name, passed to the ``evaluation.utils`` helpers. Any run
            name containing ``'crvae'`` is treated as a CR-VAE run: its real
            evaluation data is loaded per seed via ``get_data_real_crvae`` and
            1024 synthetic samples are drawn instead of 7000.
        -d/--dataset: dataset name.

    For each seed, real and synthetic evaluation sets are drawn and a
    ``DiscriminatorPL`` is trained on them. The maximum per-epoch
    ``epoch{e}_discriminative_score`` logged during training is that seed's
    score. Per-seed scores, their mean and std are printed with decimal commas,
    and the list of scores is written as JSON to
    ``storage/<run_id>/discriminative_score_epoch=<epoch>.json`` (the directory
    is created if missing).
    """
    seed_everything(0)
    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--run', required=True)
    parser.add_argument('-d', '--dataset', required=True)
    args = parser.parse_args()

    run = args.run
    dataset = args.dataset
    # Single CR-VAE predicate used everywhere. Previously the sample count
    # tested `run != 'crvae'` while data loading tested `'crvae' not in run`,
    # so a run like 'crvae_v2' loaded CR-VAE data but drew 7000 samples.
    is_crvae = 'crvae' in run

    n_features, seq_len, _ = get_nfeatures_seqlen_maxlag(dataset)

    n_samples_evaluation = 1024 if is_crvae else 7000
    train_percentage = .8
    max_epochs = 30
    lr = 1e-4
    hidden_size = 8
    num_layers = 2
    n_seeds = 10

    accelerator = 'cuda' if torch.cuda.is_available() else 'cpu'
    devices = "auto" if torch.cuda.device_count() == 0 else torch.cuda.device_count()
    pin_memory = accelerator == 'cuda'
    num_workers = 8 if accelerator == 'cuda' else 1
    batch_size = 32

    run_id, epoch = get_runid_and_epoch(run, dataset)

    print(f'Dataset: {dataset}\tRun: {run}\tID: {run_id}\tEpoch: {epoch}')

    # Non-CR-VAE runs load the full real dataset once and subsample it per seed;
    # `pipeline` is also handed to get_data_synthetic (None for CR-VAE runs).
    data_real, pipeline = None, None
    if not is_crvae:
        data_real, pipeline = get_data_real(run, dataset, seq_len, 'discriminative')
        data_real = torch.Tensor(data_real)

    scores = []
    for seed in range(n_seeds):
        print(f'Seed: {seed}')
        # Re-seed the Python, NumPy and torch RNGs for this seed.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.empty_cache()

        if is_crvae:
            data_real_ev = get_data_real_crvae(run, dataset, epoch, seed)
        else:
            # Draw n_samples_evaluation indices uniformly *with replacement*.
            data_real_ev = data_real[np.random.randint(0, len(data_real), n_samples_evaluation)]

        data_synthetic = torch.Tensor(get_data_synthetic(run_id, epoch, seed, seq_len, n_features, pipeline))
        data_synthetic_ev = data_synthetic[np.random.randint(0, len(data_synthetic), n_samples_evaluation)]

        dm = DiscriminativeDM(data_real_ev, data_synthetic_ev, train_percentage, batch_size, num_workers, pin_memory)
        model = DiscriminatorPL(n_features, hidden_size, num_layers, lr)

        trainer = Trainer(accelerator=accelerator, devices=devices, fast_dev_run=False, max_epochs=max_epochs)

        trainer.fit(model, datamodule=dm)

        # Best score across epochs. Assumes DiscriminatorPL logs one
        # `epoch{e}_discriminative_score` metric per epoch — TODO confirm;
        # a missing epoch (e.g. early stop) would raise KeyError here.
        score = max(
            trainer.logged_metrics[f'epoch{e}_discriminative_score'].item()
            for e in range(max_epochs)
        )
        scores.append(score)
        print()
        print(f'Seed: {seed}\tScore: {_with_decimal_comma(score)}')
        print()

    print()
    print()
    print()
    for seed, score in enumerate(scores):
        print(f'Seed: {seed}\tScore: {_with_decimal_comma(score)}')

    mean = np.mean(scores)
    std = np.std(scores)
    print(f'Mean: {_with_decimal_comma(mean)}\tStd: {_with_decimal_comma(std)}')

    # Ensure the output directory exists before writing; path is unchanged.
    output_dir = f'storage/{run_id}'
    os.makedirs(output_dir, exist_ok=True)
    with open(f'{output_dir}/discriminative_score_epoch={epoch}.json', 'w') as json_file:
        json.dump(scores, json_file)


if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script, not on import.
    main()
