import os
import sys
from datetime import datetime as dt
from pathlib import Path
from typing import Dict, Optional, Union

sys.path.append("..")

import comet_ml
import torch
import torch.utils.data as data
from comet_ml import ExistingExperiment, Experiment
from common.dail.dail import (train_gama, train_source_policy,
                              train_target_dynamics_model)
from common.dail.models import DAILAgent
from evaluate import _evaluate
from omegaconf import DictConfig, OmegaConf
from utils.utils import process_args

from ours.utils.dataset_utils import get_dataset, process_alignment_dataset
from ours.utils.datasets import TrajectoryDataset


def align(
    args: DictConfig,
    experiment: Optional[Union[Experiment, ExistingExperiment]],
    dataloader_dict: Dict[str, data.DataLoader],
):
    """Run the alignment phase of training.

    Builds a ``DAILAgent`` on ``args.device`` and optionally restores it from
    ``args.pretrained``. Unless ``args.evaluate_only`` is set, it then runs, in
    order and each gated by its config flag:

    1. ``train_source_policy`` on ``dataloader_dict["source"]``
       (skipped when a pretrained model was loaded),
    2. ``train_target_dynamics_model`` on ``dataloader_dict["target"]``
       (skipped when a pretrained model was loaded),
    3. ``train_gama`` on ``dataloader_dict["align"]``.

    Args:
        args: Run configuration. Reads ``logdir``, ``device``,
            ``load_pretrained_model``, ``pretrained``, ``evaluate_only``,
            ``train_source_policy``, ``train_dynamics_model``, ``train_gama``
            and ``num_epoch_bc`` / ``num_epoch_dynamics`` / ``num_epoch_gama``.
        experiment: Comet experiment forwarded to the training routines
            (may be ``None``).
        dataloader_dict: Loaders keyed by ``"source"``, ``"target"`` and
            ``"align"``.
    """
    model_dir = Path(args.logdir) / "model"

    agent = DAILAgent(args).to(args.device)
    loaded = False
    if args.load_pretrained_model:
        try:
            # Pass the path itself so torch opens and closes the file; the old
            # ``open(...)`` handle was never closed.
            state_dict = torch.load(args.pretrained, map_location=args.device)
            agent.load_state_dict(state_dict)
        except Exception as err:
            # Best-effort load: fall back to training from scratch, but no
            # longer swallow KeyboardInterrupt/SystemExit, and say why it failed.
            print("Unable to load file:", args.pretrained)
            print("Reason:", repr(err))
            print("Start training from scratch")
        else:
            print("Loaded pretrained model", args.pretrained)
            loaded = True

    if args.evaluate_only:
        return

    # NOTE(review): the BC and dynamics stages are both given the same
    # ``000.pt`` model_path; confirm that sharing one checkpoint is intended.
    if args.train_source_policy and not loaded:
        train_source_policy(
            args=args,
            omega_args=args,
            agent=agent,
            epochs=args.num_epoch_bc,
            train_loader=dataloader_dict["source"],
            model_path=model_dir / "000.pt",
            experiment=experiment,
        )

    if args.train_dynamics_model and not loaded:
        train_target_dynamics_model(
            args=args,
            omega_args=args,
            agent=agent,
            epochs=args.num_epoch_dynamics,
            train_loader=dataloader_dict["target"],
            model_path=model_dir / "000.pt",
            experiment=experiment,
        )

    if args.train_gama:
        train_gama(
            args=args,
            omega_args=args,
            agent=agent,
            epochs=args.num_epoch_gama,
            train_loader=dataloader_dict["align"],
            model_path=model_dir / "010.pt",
            experiment=experiment,
        )


def adapt(
    args: DictConfig,
    experiment: Optional[Union[Experiment, ExistingExperiment]],
    dataloader_dict: Dict[str, data.DataLoader],
):
    """Adaptation phase: behaviour-clone the source policy on top of the aligned agent.

    Restores a fresh ``DAILAgent`` from ``<logdir>/model/010.pt`` (the
    ``model_path`` the alignment phase hands to ``train_gama``), then calls
    ``train_source_policy`` for ``args.num_epoch_adapt`` epochs on
    ``dataloader_dict["source"]`` with ``model_path`` set to
    ``<logdir>/model/020.pt``.

    Args:
        args: Run configuration; reads ``logdir``, ``device`` and
            ``num_epoch_adapt``.
        experiment: Comet experiment forwarded to training (may be ``None``).
        dataloader_dict: Must contain a ``"source"`` loader.
    """
    checkpoint_dir = Path(args.logdir).joinpath("model")

    # Start from the aligned weights produced by the previous phase.
    aligned_agent = DAILAgent(args).to(args.device)
    aligned_agent.load_state_dict(
        torch.load(checkpoint_dir / "010.pt", map_location=args.device))

    train_source_policy(
        args=args,
        omega_args=args,
        agent=aligned_agent,
        epochs=args.num_epoch_adapt,
        train_loader=dataloader_dict["source"],
        model_path=checkpoint_dir / "020.pt",
        experiment=experiment,
    )


def _make_loader(dataset: data.Dataset, batch_size: int) -> data.DataLoader:
    """Return a shuffling DataLoader with ``num_workers=4``, as used by every phase."""
    return data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
    )


def main(args: DictConfig, experiment: Optional[Union[Experiment,
                                                      ExistingExperiment]]):
    """Run the alignment phase followed by the adaptation phase.

    Alignment: builds source/target ``TrajectoryDataset`` loaders, each capped
    at ``args.max_dataset_size // 2``, plus a ``TensorDataset`` loader over the
    paired alignment data. It then runs :func:`align` if
    ``args.train_alignment`` and evaluates with ``log_prefix="align"`` if
    ``args.evaluate_alignment``.

    Adaptation: re-processes ``args`` for the ``"adapt"`` phase and rebuilds the
    source loader capped at the full ``args.max_dataset_size``. It then runs
    :func:`adapt` if ``args.train_adaptation`` and evaluates with
    ``log_prefix="adapt"`` if ``args.evaluate_adaptation``.

    Args:
        args: Run configuration; ``args.logdir`` must already be set.
        experiment: Comet experiment forwarded to training/evaluation
            (may be ``None``).
    """
    # Derive the run directory from the config. The previous code read a
    # module-level ``logdir`` global that only exists when this file is run
    # as a script, so calling ``main`` from elsewhere raised NameError.
    model_dir = Path(args.logdir) / "model"
    model_dir.mkdir(parents=True, exist_ok=True)

    # ----------------------------------------
    # Alignment phase: prepare datasets
    # ----------------------------------------
    source_dataset_ = get_dataset(
        dataset_path=args.source_dataset,
        task_ids=args.task_ids,
        transform_observations=args.reverse_source_observations,
        transform_actions=args.reverse_source_actions,
    )
    target_dataset_ = get_dataset(
        dataset_path=args.target_dataset,
        task_ids=args.task_ids,
        transform_observations=args.reverse_target_observations,
        transform_actions=args.reverse_target_actions,
    )
    alignment_dataset_ = process_alignment_dataset(
        source_dataset=source_dataset_,
        target_dataset=target_dataset_,
        num_task_ids=args.num_task_ids,
        max_size=args.max_dataset_size,
    )

    # Source and target each get half of the size budget in this phase.
    source_dataset = TrajectoryDataset(
        dataset=source_dataset_,
        num_task_ids=args.num_task_ids,
        max_size=args.max_dataset_size // 2,
        domain_id=args.source_domain_id,
        domain_dim=args.domain_dim,
    )
    target_dataset = TrajectoryDataset(
        dataset=target_dataset_,
        num_task_ids=args.num_task_ids,
        max_size=args.max_dataset_size // 2,
        domain_id=args.target_domain_id,
        domain_dim=args.domain_dim,
    )
    # Tensor order here defines the tuple layout of each alignment batch.
    keys = [
        "observations", "task_ids", "domain_ids", "actions",
        "next_observations", "action_masks"
    ]
    alignment_dataset = data.TensorDataset(
        *[alignment_dataset_[k] for k in keys])
    dataloader_dict = {
        "source": _make_loader(source_dataset, args.bc.batch_size),
        "target": _make_loader(target_dataset, args.batch_size),
        "align": _make_loader(alignment_dataset, args.batch_size),
    }

    if args.train_alignment:
        align(args, experiment, dataloader_dict)

    if args.evaluate_alignment:
        evaluate_args = OmegaConf.merge(args, args.evaluate_args)
        _evaluate(evaluate_args, experiment, log_prefix="align")

    # ----------------------------------------
    # Adaptation phase
    # ----------------------------------------
    args = process_args(
        args=args,
        phase="adapt",
        inference_task_ids=args.inference_task_ids,
    )

    source_dataset_ = get_dataset(
        dataset_path=args.source_dataset,
        task_ids=args.task_ids,
        transform_observations=args.reverse_source_observations,
        transform_actions=args.reverse_source_actions,
    )
    # Only the source domain is used here, so it gets the full size budget.
    source_dataset = TrajectoryDataset(
        dataset=source_dataset_,
        num_task_ids=args.num_task_ids,
        max_size=args.max_dataset_size,
        domain_id=args.source_domain_id,
        domain_dim=args.domain_dim,
    )
    dataloader_dict = {
        "source": _make_loader(source_dataset, args.bc.batch_size),
    }

    if args.train_adaptation:
        adapt(args, experiment, dataloader_dict)

    if args.evaluate_adaptation:
        evaluate_args = OmegaConf.merge(args, args.evaluate_args)
        _evaluate(evaluate_args, experiment, log_prefix="adapt")


if __name__ == "__main__":
    # Configuration: the YAML file named by ``config=...`` on the command line,
    # with the remaining CLI key=value overrides merged on top.
    cli_args = OmegaConf.from_cli()
    conf_args = OmegaConf.load(cli_args.config)
    args = OmegaConf.merge(conf_args, cli_args)
    OmegaConf.resolve(args)

    args = process_args(
        args=args,
        phase="align",
        inference_task_ids=args.inference_task_ids,
    )

    root_dir = Path(args.root_dir)
    root_dir.mkdir(parents=True, exist_ok=True)
    if args.logdir:
        # Reuse the run directory supplied via config/CLI.
        logdir = Path(args.logdir)
    else:
        # Fresh run: timestamped directory under root_dir plus a config snapshot.
        current_time = dt.now().strftime("%Y%m%d_%H%M%S")
        logdir = root_dir / f"{args.experiment_name}_{current_time}"
        logdir.mkdir()
        args.logdir = str(logdir)
        OmegaConf.save(args, logdir / "config.yaml")

    # No Comet experiment is created here; ``None`` is passed through to
    # training and evaluation.
    experiment = None
    main(args, experiment)
