import os
import sys
import traceback
import warnings
from datetime import datetime as dt
from pathlib import Path
from typing import Optional

sys.path.append("../..")

import comet_ml
import gym
import h5py
import numpy as np
import torch.optim as optim
from comet_ml import Experiment
from common.cca.cca import train_cca
from common.cca.ddpg import train_ddpg
from evaluate import _evaluate
from omegaconf import DictConfig, OmegaConf
from utils.utils import process_args

from ours.utils.dataset_utils import get_dataset


def main(
    args: DictConfig,
    experiment: Optional[Experiment] = None,
):
    """Fit the CCA alignment, then train the DDPG agent on top of it.

    Steps, in order:

    1. Load the source and target datasets restricted to ``args.task_ids``.
    2. Call ``train_cca`` on that pair; it yields a source transform and a
       target transform.
    3. Reload the source dataset restricted to ``args.inference_task_ids``.
    4. Call ``train_ddpg`` with both transforms and the reloaded source data.

    Args:
        args: Run configuration. Fields read here: ``logdir``,
            ``source_dataset``, ``target_dataset``, ``task_ids``,
            ``inference_task_ids``, ``reverse_source_observations``,
            ``reverse_source_actions``, ``reverse_target_observations``,
            ``reverse_target_actions``, ``source_state_dim``,
            ``target_state_dim`` and ``transfer.max_steps``. The whole object
            is also forwarded to ``train_ddpg``.
        experiment: Not used by this function.

    Returns:
        None. The value returned by ``train_ddpg`` is not kept.
    """
    out_dir = Path(args.logdir)

    def load_source(task_ids):
        # The source data is loaded twice; only the task subset differs.
        return get_dataset(
            dataset_path=args.source_dataset,
            task_ids=task_ids,
            transform_observations=args.reverse_source_observations,
            transform_actions=args.reverse_source_actions,
        )

    # --- Stage 1: alignment on the shared task set -------------------------
    src_data = load_source(args.task_ids)
    tgt_data = get_dataset(
        dataset_path=args.target_dataset,
        task_ids=args.task_ids,
        transform_observations=args.reverse_target_observations,
        transform_actions=args.reverse_target_actions,
    )

    # The number of CCA components is capped by the smaller state dimension.
    shared_dim = min(args.source_state_dim, args.target_state_dim)
    source_transform, target_transform = train_cca(
        source_dataset=src_data,
        target_dataset=tgt_data,
        task_ids=args.task_ids,
        n_components=shared_dim,
        model_path=out_dir / "cca.pkl",
    )

    # --- Stage 2: agent training on the inference task set -----------------
    # Rebind the same name so the alignment-stage source data is released
    # before training starts.
    src_data = load_source(args.inference_task_ids)

    train_ddpg(
        args=args,
        source_transform=source_transform,
        target_transform=target_transform,
        source_dataset=src_data,
        model_path=out_dir / "ddpg_agent",
        max_steps=args.transfer.max_steps,
    )


if __name__ == "__main__":
    # Build the run configuration: the command line must provide
    # `config=<path to YAML>`; that file is loaded and then merged with the
    # command-line values (listed last in the merge, so they take priority).
    cli_args = OmegaConf.from_cli()
    conf_args = OmegaConf.load(cli_args.config)
    args = OmegaConf.merge(conf_args, cli_args)
    OmegaConf.resolve(args)
    args = process_args(args, "align", args.inference_task_ids)

    # Create a fresh, timestamped log directory for this run.
    # parents=True lets a nested root_dir (e.g. "results/align/runs") be
    # created in one go instead of failing with FileNotFoundError when an
    # intermediate directory is missing.
    root_dir = Path(args.root_dir)
    root_dir.mkdir(parents=True, exist_ok=True)
    current_time = dt.now().strftime("%Y%m%d_%H%M%S")
    logdir = root_dir / (args.experiment_name + "_" + current_time)
    # Deliberately no exist_ok here: a name collision raises rather than
    # silently writing into another run's directory.
    logdir.mkdir()
    args.logdir = str(logdir)
    # Store the resolved configuration (including logdir) next to the outputs.
    OmegaConf.save(args, logdir / "config.yaml")

    # No experiment tracker is attached; both calls below receive None.
    experiment = None

    main(args, experiment)
    _evaluate(args, experiment)
