import json
import pickle
from typing import Optional

import numpy as np
import torch

def get_nfeatures_seqlen_maxlag(dataset: str):
    """Return the hard-coded ``(n_features, seq_len, max_lag)`` for *dataset*.

    * n_features: 3 for any name containing ``'rivers'`` or exactly ``'henon'``,
      6 for ``'henon6'``, otherwise 36.
    * seq_len: 29 for ``'rivers_lag3'``, 22 for ``'airquality'``, otherwise 30.
    * max_lag: 3 for ``'rivers_lag3'``, otherwise 2.

    Unknown dataset names fall back to ``(36, 30, 2)``.
    """
    # Substring match: every 'rivers*' variant shares the 3-feature layout.
    if 'rivers' in dataset or dataset == 'henon':
        features = 3
    else:
        features = 6 if dataset == 'henon6' else 36

    if dataset == 'rivers_lag3':
        window, lag = 29, 3
    else:
        window = 22 if dataset == 'airquality' else 30
        lag = 2

    return features, window, lag

def get_runid_and_epoch(run: str, dataset: str):
    """Look up the run id and checkpoint epoch registered for *run* on *dataset*.

    Reads ``src/evaluation/{dataset}_run2idNepoch.json`` (relative to the
    current working directory) and returns the value stored under *run*.
    A ``'crvae'`` entry is always injected (overriding any such key in the
    file) as ``(f'{dataset}_crvae', epoch)``, where *epoch* is 99000 when
    ``dataset == 'rivers'`` exactly and 90000 otherwise.

    Args:
        run: Key to look up in the mapping.
        dataset: Dataset name; selects the mapping file.

    Returns:
        The mapping value for *run* (as decoded from JSON, or the tuple above
        for ``'crvae'``).

    Raises:
        FileNotFoundError: if the mapping file does not exist.
        KeyError: if *run* is not present in the mapping.
    """
    # JSON text is UTF-8 by specification; do not rely on the locale default.
    with open(f'src/evaluation/{dataset}_run2idNepoch.json', encoding='utf-8') as json_file:
        run2idNepoch = json.load(json_file)
    run2idNepoch['crvae'] = (f'{dataset}_crvae', 99000 if dataset == 'rivers' else 90000)
    return run2idNepoch[run]


def get_data_real(run: str, dataset: str, seq_len: int, metric: Optional[str] = None):
    """Load the real series for *dataset* and cut it into sliding windows.

    Source of the series:

    * ``'causaltime'`` in *run*: ``storage/{dataset}_{run}/data_ori.npy``;
      no pipeline is returned.
    * otherwise: the pickled datamodule at ``data/{dataset}/datamodule.pkl``.
      Its ``pipeline`` attribute is returned, and the series is
      ``dataset_train.data`` for ``'discriminative'``/``'predictive'``,
      ``dataset_val.data_unprocessed`` for ``'tsne_pca'``, and
      ``dataset_train.data_unprocessed`` for any other *metric* (incl. None).

    Windows have length *seq_len* and stride 1 along axis 0. For a 2-D
    ``(T, F)`` input the result has shape ``(T - seq_len + 1, seq_len, F)``.

    Returns:
        ``(windows, pipeline)`` where *pipeline* is None for CausalTime runs.
    """
    if 'causaltime' not in run:
        with open(f'data/{dataset}/datamodule.pkl', 'rb') as handle:
            # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
            module = pickle.load(handle)
        pipeline = module.pipeline
        if metric in ('discriminative', 'predictive'):
            series = module.dataset_train.data
        elif metric == 'tsne_pca':
            series = module.dataset_val.data_unprocessed
        else:
            series = module.dataset_train.data_unprocessed
    else:
        series = np.load(f'storage/{dataset}_{run}/data_ori.npy')
        pipeline = None

    # unfold -> (n_windows, ..., seq_len); move the window axis next to the batch axis.
    windows = torch.from_numpy(series).unfold(0, seq_len, 1)
    return windows.transpose(1, 2).numpy(), pipeline

def get_data_real_crvae(run, dataset, epoch, seed):
    """Load the ``ori`` array saved for a CR-VAE run.

    Args:
        run: Run name; the directory is ``storage/{dataset}_{run}``.
        dataset: Dataset name.
        epoch: Iteration number embedded in the file name (``ori_it=...``).
        seed: Seed embedded in the file name.

    Returns:
        The array in ``storage/{dataset}_{run}/ori_it={epoch}_seed={seed}.npy``.
    """
    array_path = f'storage/{dataset}_{run}/ori_it={epoch}_seed={seed}.npy'
    return np.load(array_path)

def get_data_synthetic(run_id, epoch, seed, seq_len, n_features, pipeline=None, coefficients=False):
    """Load synthetic samples produced by the run *run_id*.

    The storage layout is chosen by substrings of *run_id* (checked in order):

    * ``'crvae'``: ``storage/{run_id}/syn_it={epoch}_seed={seed}.npy`` is
      returned unchanged.
    * ``'causaltime'``: *epoch* must be indexable as a pair ``(ne1, ne2)``.
      The array from
      ``storage/{run_id}/generated_datas_seed={seed}_ne1={ne1}_ne2={ne2}.npy``
      is sliced to ``[:, seq_len:, :n_features]``.
    * otherwise: the pickle
      ``storage/{run_id}/inference_data/{coefficients|synthetics}_epoch={epoch}_seed={seed}.pkl``
      is loaded. If *pipeline* is provided, each sample is passed through
      ``pipeline.preprocess`` and the results are stacked with ``np.asarray``.

    *pipeline* and *coefficients* are only consulted in the last case.

    Args:
        run_id: Run directory name under ``storage/``.
        epoch: Epoch/iteration; a ``(ne1, ne2)`` pair for CausalTime runs.
        seed: Sampling seed embedded in the file name.
        seq_len: Number of leading steps (axis 1) dropped for CausalTime runs.
        n_features: Number of leading feature columns (axis 2) kept for CausalTime runs.
        pipeline: Optional object with a ``preprocess(sample)`` method.
        coefficients: Load the ``coefficients`` pickle instead of ``synthetics``.

    Returns:
        A numpy array, or (pickle case without *pipeline*) whatever object the
        pickle contains.
    """
    if 'crvae' in run_id:
        data_synthetic = np.load(f'storage/{run_id}/syn_it={epoch}_seed={seed}.npy')
    elif 'causaltime' in run_id:
        data_synthetic = np.load(f'storage/{run_id}/generated_datas_seed={seed}_ne1={epoch[0]}_ne2={epoch[1]}.npy')
        # NOTE(review): presumably the first seq_len steps are a conditioning prefix
        # and columns past n_features are auxiliary; confirm against the generator.
        data_synthetic = data_synthetic[:, seq_len:, :n_features]
    else:
        to_open = 'coefficients' if coefficients else 'synthetics'
        with open(f'storage/{run_id}/inference_data/{to_open}_epoch={epoch}_seed={seed}.pkl', 'rb') as f:
            # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
            data_synthetic = pickle.load(f)
        # Test for presence, not truthiness: a pipeline object may define
        # __len__/__bool__ and evaluate falsy while still being usable.
        if pipeline is not None:
            data_synthetic = np.asarray([pipeline.preprocess(x) for x in data_synthetic])
    return data_synthetic
