import os
import sys

import torch

sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir, 'src')))
sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
from experiment_utils import data_ingredient, load_dataset, training_ingredient, train_model, make_experiment, \
    make_experiment_tempfile, serialization_guard
from models.reconstruction import USADModel, USADAnomalyDetector, USADDecoder1Loss, USADDecoder2Loss
from utils.utils import Bunch


experiment = make_experiment(ingredients=[data_ingredient, training_ingredient])


def get_training_pipeline(window_size=12):
    """Return the description of the data transform pipeline used for training.

    The result is a fresh dict on every call, mapping a step name to a
    ``{'class': <transform class name>, 'args': <keyword arguments>}`` entry.
    Steps are listed in insertion order: first ``'window'``, then
    ``'reconstruction'``.

    Args:
        window_size: Value passed as ``window_size`` to the ``'window'`` step.
            Defaults to 12, the value this experiment has always used. Keep it
            equal to the value used for ``get_test_pipeline``.

    Returns:
        dict: The pipeline description.
    """
    return {
        'window': {'class': 'WindowTransform', 'args': {'window_size': window_size}},
        # Presumably replaces the labels with the input itself as the target - TODO confirm
        'reconstruction': {'class': 'ReconstructionTargetTransform', 'args': {'replace_labels': True}},
    }


def get_test_pipeline(window_size=12):
    """Return the description of the data transform pipeline used for testing.

    The result is a fresh dict on every call and contains only the
    ``'window'`` step, described as
    ``{'class': <transform class name>, 'args': <keyword arguments>}``.

    Args:
        window_size: Value passed as ``window_size`` to the ``'window'`` step.
            Defaults to 12, the value this experiment has always used. Keep it
            equal to the value used for ``get_training_pipeline``.

    Returns:
        dict: The pipeline description.
    """
    return {
        'window': {'class': 'WindowTransform', 'args': {'window_size': window_size}},
    }


def get_batch_dim():
    """Return the index of the batch dimension used by this experiment (always 0)."""
    batch_axis = 0
    return batch_axis


@data_ingredient.config
def data_config():
    """Default configuration values for the data ingredient.

    NOTE(review): presumably every local variable assigned here becomes a
    config entry of ``data_ingredient``, so the local names are part of the
    configuration interface - confirm against ``experiment_utils``.
    """
    # Transform pipeline description used for training (the 'window' and 'reconstruction' steps)
    pipeline = get_training_pipeline()

    # Dataset keyword arguments; presumably forwarded to the dataset constructor - TODO confirm
    ds_args = dict(
        training=True
    )

    # NOTE(review): looks like (train, validation) proportions, i.e. no validation split - confirm
    split = (1, 0)


@training_ingredient.config
def training_config():
    """Default configuration values for the training ingredient.

    NOTE(review): presumably every local variable assigned here becomes a
    config entry of ``training_ingredient``, so the local names are part of
    the configuration interface - confirm against ``experiment_utils``.
    """
    # The two USAD loss classes (classes, not instances), decoder 1 first and decoder 2 second
    loss = [USADDecoder1Loss, USADDecoder2Loss]
    # Index of the batch dimension (0, see get_batch_dim)
    batch_dim = get_batch_dim()
    # No trainer hooks are registered by default
    trainer_hooks = []


@experiment.config
def config():
    """Default experiment-level configuration values.

    NOTE(review): presumably every local variable assigned here becomes a
    config entry of the experiment; the names match parameters of ``main``
    and ``get_anomaly_detector`` in this file.
    """
    # Model-specific parameters
    model_params = dict(
        z_size=40,  # Size of the latent space (will be multiplied by window size)
    )

    # Keyword arguments unpacked into USADAnomalyDetector in get_anomaly_detector
    detector_params = dict(
        alpha=0.5,  # Trade-off parameter in the anomaly score
    )

    # Whether main() builds an anomaly detector after training the model
    train_detector = True
    # Whether get_anomaly_detector() writes the detector to 'final_model.pth'
    save_detector = True


@experiment.command(unobserved=True)
@serialization_guard('model', 'val_loader')
def get_anomaly_detector(model, val_loader, training, detector_params, _run, save_detector=True):
    """Wrap ``model`` in a ``USADAnomalyDetector`` and optionally save it.

    Args:
        model: Model handed to ``USADAnomalyDetector`` as its first argument.
        val_loader: Currently unused; the detector is not fitted on it.
        training: Training configuration; only its ``device`` entry is read.
        detector_params: Keyword arguments unpacked into ``USADAnomalyDetector``.
        _run: Run object handed to ``make_experiment_tempfile``.
        save_detector: If true, the detector is written with ``torch.save`` as
            ``{'detector': detector}`` to the temp file ``'final_model.pth'``.

    Returns:
        The detector, moved to ``training.device``.
    """
    train_cfg = Bunch(training)
    detector = USADAnomalyDetector(model, **detector_params)
    detector = detector.to(train_cfg.device)
    # NOTE: fitting is intentionally left disabled, as in the original code:
    # detector.fit(val_loader)

    if not save_detector:
        return detector

    with make_experiment_tempfile('final_model.pth', _run, mode='wb') as out_file:
        torch.save({'detector': detector}, out_file)

    return detector


@experiment.automain
@serialization_guard
def main(model_params, dataset, training, _run, train_detector=True):
    """Train a USAD model and optionally build an anomaly detector from it.

    Args:
        model_params: Mapping of model parameters; only ``z_size`` is read.
        dataset: Data configuration. Not read here; kept in the signature so
            the call interface is unchanged.
        training: Training configuration. Not read here; kept in the signature
            so the call interface is unchanged.
        _run: Run object handed to ``train_model``.
        train_detector: If true, ``get_anomaly_detector`` is called with the
            trained model and ``trainer.val_iter``; otherwise no detector is
            built.

    Returns:
        dict: ``detector`` (the detector or ``None``), ``model`` and ``trainer``.
    """
    model_params = Bunch(model_params)

    train_ds, val_ds = load_dataset()

    # Both sizes are scaled by the sequence length of the training dataset;
    # presumably the model operates on flattened windows - TODO confirm.
    input_size = train_ds.num_features * train_ds.seq_len
    latent_size = model_params.z_size * train_ds.seq_len
    model = USADModel(input_size, latent_size)

    trainer = train_model(_run, model, train_ds, val_ds)

    # The remaining arguments of get_anomaly_detector are not passed here;
    # presumably they are filled in from the experiment configuration.
    detector = get_anomaly_detector(model, trainer.val_iter) if train_detector else None

    return dict(detector=detector, model=model, trainer=trainer)
