import sys
from pathlib import Path
from typing import Optional, Union

sys.path.append("../..")

import comet_ml
import torch
import torch.optim as optim
import torch.utils.data as data
from comet_ml import ExistingExperiment, Experiment
from common.algorithm_main import epoch_loop
from common.models import Policy
from omegaconf import DictConfig, OmegaConf
from utils.dataset_utils import (get_dataset, process_alignment_dataset,
                                 split_dataset)


def adapt(
    args: DictConfig,
    experiment: Optional[Union[Experiment, ExistingExperiment]] = None,
) -> None:
    """Continue training a saved policy on the alignment dataset.

    The function builds an alignment dataset from the source and target
    datasets, restores a ``Policy`` from ``args.model_path``, and runs
    ``epoch_loop`` for ``args.num_epoch`` epochs while optimizing only
    ``policy.core`` with Adam.

    Checkpoints are written to ``<args.logdir>/adapt/model``:

    * ``000.pt`` -- the restored weights before any adaptation epoch,
    * ``{epoch:03d}.pt`` -- every ``args.log_interval`` epochs,
    * ``best.pt`` -- the weights with the lowest value returned by
      ``epoch_loop`` so far (presumably the validation loss -- confirm
      against ``common.algorithm_main.epoch_loop``).

    Args:
        args: Run configuration. Reads ``logdir``, ``source_dataset``,
            ``target_dataset``, ``task_ids``, ``num_task_ids``,
            ``max_dataset_size``, ``train_ratio``, ``batch_size``,
            ``reverse_{source,target}_{observations,actions}``, ``device``,
            ``model_path``, ``betas``, ``verbose``, ``num_epoch``,
            ``log_interval`` and the ``policy`` sub-config.
        experiment: Optional Comet experiment handle forwarded unchanged to
            ``epoch_loop``.
    """
    logdir = Path(args.logdir) / "adapt"
    model_dir = logdir / "model"
    # parents=True: "<logdir>/adapt" usually does not exist on a fresh run,
    # and a plain mkdir() would raise FileNotFoundError.
    model_dir.mkdir(parents=True, exist_ok=True)

    # --------------------------------------------------
    #  Prepare datasets
    # --------------------------------------------------

    source_dataset_ = get_dataset(
        dataset_path=args.source_dataset,
        task_ids=args.task_ids,
        transform_observations=args.reverse_source_observations,
        transform_actions=args.reverse_source_actions,
    )
    target_dataset_ = get_dataset(
        dataset_path=args.target_dataset,
        task_ids=args.task_ids,
        transform_observations=args.reverse_target_observations,
        transform_actions=args.reverse_target_actions,
    )

    alignment_dataset_ = process_alignment_dataset(
        source_dataset=source_dataset_,
        target_dataset=target_dataset_,
        num_task_ids=args.num_task_ids,
        max_size=args.max_dataset_size,
    )

    # The order of the keys fixes the order of the tensors in every batch
    # that the data loaders yield.
    keys = [
        "observations", "task_ids", "domain_ids", "actions",
        "next_observations", "action_masks"
    ]
    alignment_dataset = data.TensorDataset(
        *[alignment_dataset_[key] for key in keys])

    train_dataset, val_dataset = split_dataset(alignment_dataset,
                                               args.train_ratio)

    dataloader_dict = {
        "train": data.DataLoader(
            dataset=train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=4,
        ),
        "validation": data.DataLoader(
            dataset=val_dataset,
            batch_size=args.batch_size,
            num_workers=4,
        ),
    }

    # --------------------------------------------------
    #  Prepare models
    # --------------------------------------------------

    # Configs that do not define ``decode_with_state`` fall back to False.
    decode_with_state = getattr(args.policy, "decode_with_state", False)

    policy = Policy(
        state_dim=args.policy.state_dim,
        cond_dim=args.policy.cond_dim,
        out_dim=args.policy.out_dim,
        domain_dim=args.policy.domain_dim,
        latent_dim=args.policy.latent_dim,
        hid_dim=args.policy.hid_dim,
        num_hidden_layers=args.policy.num_hidden_layers,
        activation=args.policy.activation,
        repr_activation=args.policy.latent_activation,
        enc_sn=args.policy.spectral_norm,
        decode_with_state=decode_with_state,
    ).to(args.device)
    policy.load_state_dict(
        torch.load(args.model_path, map_location=args.device))
    model_dict = {"policy": policy}

    # Only the parameters of ``policy.core`` are handed to the optimizer;
    # the remaining sub-modules are not updated by it.
    optimizer_dict = {
        "policy": optim.Adam(policy.core.parameters(),
                             lr=args.policy.lr,
                             betas=args.betas),
    }

    # --------------------------------------------------
    #  Adaptation
    # --------------------------------------------------

    adapt_args = OmegaConf.create({
        "device": args.device,
        "enc_decay": args.policy.encoder_decay,
        "calc_align_score": False,
        "target_coef": 0,
        "source_coef": 1,
        "verbose": args.verbose,
    })

    best_val_loss = float("inf")

    # Checkpoint 000 holds the restored weights before any adaptation step.
    torch.save(policy.state_dict(), model_dir / f"{0:03d}.pt")

    for epoch in range(1, 1 + args.num_epoch):
        val_loss = epoch_loop(
            args=adapt_args,
            model_dict=model_dict,
            optimizer_dict=optimizer_dict,
            dataloader_dict=dataloader_dict,
            epoch=epoch,
            experiment=experiment,
            log_prefix="adapt",
        )

        if val_loss < best_val_loss:
            # Remember the new best so that later, worse epochs cannot
            # overwrite best.pt.
            best_val_loss = val_loss
            torch.save(policy.state_dict(), model_dir / "best.pt")

        if epoch % args.log_interval == 0:
            torch.save(policy.state_dict(), model_dir / f"{epoch:03d}.pt")
