import sys
from pathlib import Path
from typing import Optional, Union

sys.path.append("../..")

import comet_ml
import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as data
from comet_ml import ExistingExperiment, Experiment
from common.algorithm_main import epoch_loop
from common.models import Discriminator, Policy
from omegaconf import DictConfig, OmegaConf
from utils.dataset_utils import (get_dataset, process_alignment_dataset,
                                 split_dataset)


def align(
    args: DictConfig,
    experiment: Optional[Union[Experiment, ExistingExperiment]] = None,
):
    """Train the policy (and optionally a domain discriminator) on the
    aligned source/target dataset and checkpoint the policy.

    Steps:
      1. Load the source and target datasets with ``get_dataset`` and merge
         them with ``process_alignment_dataset``.
      2. Split the merged ``TensorDataset`` into train/validation loaders.
      3. Build the ``Policy`` (and, if ``args.discriminator.enable``, a
         ``Discriminator``) together with their Adam optimizers.
      4. Run ``args.num_epoch`` epochs of ``epoch_loop``.

    The policy ``state_dict`` is written to ``<args.logdir>/align/model/``:
      * ``best.pt`` -- each time the value returned by ``epoch_loop``
        (treated as the validation loss) is lower than the best seen so far;
      * ``{epoch:03d}.pt`` -- every ``args.log_interval`` epochs.

    Args:
        args: Experiment configuration. Fields read include ``logdir``,
            dataset paths and ``reverse_*`` flags, ``task_ids``,
            ``no_task_id``/``inference_task_ids``, ``num_task_ids``,
            ``max_dataset_size``, ``train_ratio``, ``batch_size``,
            ``policy.*``, ``discriminator.*``, ``betas``, ``device``,
            ``label_smoothing``, ``target_coef``, ``source_coef``,
            ``verbose``, ``num_epoch`` and ``log_interval``.
        experiment: Optional Comet experiment, forwarded unchanged to
            ``epoch_loop`` (presumably for metric logging).
    """

    logdir = Path(args.logdir) / "align"
    model_dir = logdir / "model"
    # parents=True so a fresh logdir (without the "align" level) works too.
    model_dir.mkdir(parents=True, exist_ok=True)

    # --------------------------------------------------
    #  Prepare datasets
    # --------------------------------------------------

    source_dataset_ = get_dataset(
        dataset_path=args.source_dataset,
        task_ids=args.task_ids,
        transform_observations=args.reverse_source_observations,
        transform_actions=args.reverse_source_actions,
    )
    target_dataset_ = get_dataset(
        dataset_path=args.target_dataset,
        task_ids=args.task_ids,
        transform_observations=args.reverse_target_observations,
        transform_actions=args.reverse_target_actions,
    )

    if args.no_task_id:
        # Collapse every sample onto a single goal id so the task id carries
        # no information.
        source_dataset_["infos/goal_id"][:] = args.inference_task_ids[0]
        target_dataset_["infos/goal_id"][:] = args.inference_task_ids[0]

    alignment_dataset_ = process_alignment_dataset(
        source_dataset=source_dataset_,
        target_dataset=target_dataset_,
        num_task_ids=args.num_task_ids,
        max_size=args.max_dataset_size,
    )

    dataloader_dict = {}

    # Order matters: downstream code presumably unpacks batches positionally
    # in this order -- verify against epoch_loop before changing it.
    keys = [
        "observations", "task_ids", "domain_ids", "actions",
        "next_observations", "action_masks"
    ]
    alignment_dataset = data.TensorDataset(
        *[alignment_dataset_[key] for key in keys])

    train_dataset, val_dataset = split_dataset(alignment_dataset,
                                               args.train_ratio)

    dataloader_dict["train"] = data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
    )

    dataloader_dict["validation"] = data.DataLoader(
        dataset=val_dataset,
        batch_size=args.batch_size,
        num_workers=4,
    )

    # --------------------------------------------------
    #  Prepare models
    # --------------------------------------------------

    model_dict = {}

    policy = Policy(
        state_dim=args.policy.state_dim,
        cond_dim=args.policy.cond_dim,
        out_dim=args.policy.out_dim,
        domain_dim=args.policy.domain_dim,
        latent_dim=args.policy.latent_dim,
        hid_dim=args.policy.hid_dim,
        num_hidden_layers=args.policy.num_hidden_layers,
        activation=args.policy.activation,
        repr_activation=args.policy.latent_activation,
        enc_sn=args.policy.spectral_norm,
        decode_with_state=args.policy.decode_with_state,
    ).to(args.device)
    model_dict["policy"] = policy

    optimizer_dict = {}

    params = list(policy.parameters())
    optimizer_dict["policy"] = optim.Adam(params,
                                          lr=args.policy.lr,
                                          betas=args.betas)

    if args.discriminator.enable:
        discriminator = Discriminator(
            latent_dim=args.policy.latent_dim,
            hid_dim=args.policy.hid_dim,
            num_classes=args.policy.domain_dim,
            cond_dim=args.policy.cond_dim,
            sa_disc=args.discriminator.alpha,
            task_cond=args.discriminator.use_task_id,
            activation=args.discriminator.activation,
            sn=args.discriminator.spectral_norm,
        ).to(args.device)
        optimizer_dict["discriminator"] = optim.Adam(
            discriminator.parameters(),
            lr=args.discriminator.lr,
            betas=args.betas)
        model_dict["discriminator"] = discriminator

    align_args = OmegaConf.create({
        "device": args.device,
        "enc_decay": args.policy.encoder_decay,
        "calc_align_score": False,
        "adversarial_coef": args.discriminator.coef,
        "smooth": args.label_smoothing,
        "target_coef": args.target_coef,
        "source_coef": args.source_coef,
        "verbose": args.verbose,
    })

    # --------------------------------------------------
    #  Alignment
    # --------------------------------------------------

    # inf (not a large finite constant) so the first epoch always counts as
    # an improvement regardless of the loss scale.
    best_val_loss = float("inf")
    for epoch in range(1, 1 + args.num_epoch):
        val_loss = epoch_loop(
            args=align_args,
            model_dict=model_dict,
            optimizer_dict=optimizer_dict,
            dataloader_dict=dataloader_dict,
            epoch=epoch,
            experiment=experiment,
            log_prefix="align",
        )

        if val_loss < best_val_loss:
            # Track the running best; without this update every epoch would
            # overwrite best.pt with the latest weights.
            best_val_loss = val_loss
            torch.save(model_dict["policy"].state_dict(),
                       model_dir / "best.pt")

        if epoch % args.log_interval == 0:
            torch.save(model_dict["policy"].state_dict(),
                       model_dir / f"{epoch:03d}.pt")
