import os
import shutil
from datetime import datetime
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from comet_ml import Experiment
from omegaconf import DictConfig, OmegaConf

import d4rl
from common.cdil.cdil import (CDILConfig, behavioral_cloning,
                              calc_alignment_score, main_loop,
                              pretrain_position_encoding,
                              save_translation_models, train_idm)
from common.cdil.dataset_utils import read_step_dateset
from common.cdil.models import (Discriminator, InverseDynamicsModel, Policy,
                                PositionalEncoder, StateConverter)
from common.utils.process_dataset import (get_goal_candidates,
                                          read_env_config_yamls)


def create_logdir(args):
    """Create (if needed) and return a per-run log directory.

    The directory is ``results/cdil/<source_tag>_<target_tag>/`` followed by
    a final component made of ``args.name``, a ``%Y%m%d-%H%M%S`` timestamp
    and a random integer, i.e. ``<name><timestamp>-<random>``.

    NOTE(review): no separator is inserted between ``args.name`` and the
    timestamp, so a name without a trailing ``_``/``/`` is glued to the date.

    Args:
        args: Run configuration; reads ``args.domains[0].env_tag``,
            ``args.domains[1].env_tag`` and ``args.name``.

    Returns:
        The (relative) directory path as a ``pathlib.Path``.
    """
    domain_pair = f"{args.domains[0].env_tag}_{args.domains[1].env_tag}"
    stamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    # Drawn from NumPy's *global* RNG: the suffix is only as random as the
    # global seed state at call time.
    suffix = np.random.randint(low=0, high=int(1e6) - 1)

    run_dir = Path(f"results/cdil/{domain_pair}/{args.name}{stamp}-{suffix}")
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir


def main(
    args: DictConfig,
    experiment: Optional[Experiment],
):
    """Run the complete CDIL pipeline for one merged configuration.

    In order, this:
      1. creates a log directory, copies ``args.config`` into it as
         ``config.yaml`` and stores the directory on ``args.logdir`` (and on
         the Comet experiment, when one is given);
      2. loads the alignment (``align_*``) and adaptation (``adapt_*``)
         dataloaders into a single dict;
      3. builds every network plus the "pretrain", "inverse_dynamics_model",
         "bc" and "all" Adam optimizers;
      4. pretrains the positional encoders, runs ``main_loop`` and saves the
         translation models;
      5. computes the alignment score for two-domain ``point`` pairs whose
         target domain has no ``obs_converter``;
      6. trains the inverse dynamics model, then runs behavioral cloning.

    Args:
        args: Merged run configuration; ``args.logdir`` is set in place.
        experiment: Comet experiment for logging, or ``None``.
    """
    run_dir = create_logdir(args)
    shutil.copy(args.config, run_dir / "config.yaml")
    args.logdir = str(run_dir)
    if experiment:
        experiment.log_parameter("logdir", str(run_dir))

    # --- data ---------------------------------------------------------------
    align_loaders, task_id_managers = read_step_dateset(args)
    adapt_loaders, _ = read_step_dateset(args, inference=True)
    # Both splits share one dict; the key prefix tells them apart.
    dataloader_dict = {
        f"align_{split}": loader for split, loader in align_loaders.items()
    }
    dataloader_dict.update(
        (f"adapt_{split}", loader) for split, loader in adapt_loaders.items())

    # Conditioning size for the task-conditioned networks; one more slot is
    # added when complex tasks are enabled.
    n_task_id = (args.n_task_ids if args.multienv
                 else task_id_managers[0].n_task_id)
    if args.complex_task:
        n_task_id += 1

    # --- model factories ----------------------------------------------------
    source_obs_dim = args.domains[0].obs_dim
    target_obs_dim = args.domains[1].obs_dim
    conv_cfg = args.state_converter
    disc_cfg = args.discriminator
    pe_cfg = args.positional_encoder
    idm_cfg = args.inverse_dynamics_model
    policy_cfg = args.policy

    def build_converter(in_dim, out_dim):
        # State translator between the two domains (via args.latent_dim).
        return StateConverter(
            in_state_dim=in_dim,
            out_state_dim=out_dim,
            hid_dim=conv_cfg.hid_dim,
            latent_dim=args.latent_dim,
            num_hidden_layers=conv_cfg.num_hidden_layers,
            activation=conv_cfg.activation,
        ).to(args.device)

    def build_discriminator(input_dim):
        # Callers pass twice the state/latent size (presumably a concatenated
        # state pair -- confirm against Discriminator).
        return Discriminator(
            latent_dim=input_dim,
            hid_dim=disc_cfg.hid_dim,
            num_hidden_layers=disc_cfg.num_hidden_layers,
            activation=disc_cfg.activation,
            task_cond=True,
            cond_dim=n_task_id,
            adv_coef=disc_cfg.adversarial_coef,
        ).to(args.device)

    def build_positional_encoder(state_dim, cond_dim):
        return PositionalEncoder(
            state_dim=state_dim,
            hid_dim=pe_cfg.hid_dim,
            num_hidden_layers=pe_cfg.num_hidden_layers,
            cond_dim=cond_dim,
            activation=pe_cfg.activation,
        ).to(args.device)

    # NOTE: entries are built in this exact order on purpose; constructors may
    # draw from the global torch RNG, and the key order also fixes the
    # parameter order seen by the "all" optimizer below.
    model_dict = {
        "forward_converter": build_converter(source_obs_dim, target_obs_dim),
        "backward_converter": build_converter(target_obs_dim, source_obs_dim),
        "source_discriminator": build_discriminator(source_obs_dim * 2),
        "target_discriminator": build_discriminator(target_obs_dim * 2),
        "latent_discriminator": build_discriminator(args.latent_dim * 2),
        "source_positional_encoder":
            build_positional_encoder(source_obs_dim, n_task_id),
        "target_positional_encoder":
            build_positional_encoder(target_obs_dim, n_task_id),
        "latent_positional_encoder":
            build_positional_encoder(args.latent_dim, 0),
        "inverse_dynamics_model": InverseDynamicsModel(
            state_dim=target_obs_dim,
            action_dim=args.domains[1].action_dim,
            hid_dim=idm_cfg.hid_dim,
            num_hidden_layers=idm_cfg.num_hidden_layers,
            activation=idm_cfg.activation,
        ).to(args.device),
        "policy": Policy(
            state_dim=target_obs_dim,
            action_dim=args.domains[1].action_dim,
            hid_dim=policy_cfg.hid_dim,
            num_hidden_layers=policy_cfg.num_hidden_layers,
            activation=policy_cfg.activation,
            image_observation=args.image_observation,
            image_state_dim=args.image_state_dim,
            coord_conv=args.use_coord_conv,
            pretrained=args.pretrained,
            use_image_decoder=args.use_image_decoder,
        ).to(args.device),
    }
    inverse_dynamics_model = model_dict["inverse_dynamics_model"]
    policy = model_dict["policy"]

    # --- optimizers ---------------------------------------------------------
    pretrain_params = [
        *model_dict["source_positional_encoder"].parameters(),
        *model_dict["target_positional_encoder"].parameters(),
        *policy.parameters(),  # for image encoder/decoder
    ]
    # The "all" optimizer covers every model except these four.
    not_in_joint = {
        "source_positional_encoder", "target_positional_encoder",
        "inverse_dynamics_model", "policy",
    }
    joint_params = [
        param
        for model_name, model in model_dict.items()
        if model_name not in not_in_joint
        for param in model.parameters()
    ]
    optimizer_dict = {
        "pretrain": torch.optim.Adam(params=pretrain_params, lr=args.lr),
        "inverse_dynamics_model": torch.optim.Adam(
            params=inverse_dynamics_model.parameters(), lr=args.lr),
        "bc": torch.optim.Adam(params=policy.net.parameters(), lr=args.lr),
        "all": torch.optim.Adam(params=joint_params, lr=args.lr),
    }

    # --- training pipeline --------------------------------------------------
    pretrain_position_encoding(args, experiment, model_dict, optimizer_dict,
                               dataloader_dict)
    main_loop(args, experiment, model_dict, optimizer_dict, dataloader_dict)
    save_translation_models(args, model_dict)

    wants_alignment_score = (
        args.n_domains == 2
        and 'point' in args.domains[0].env_tag
        and 'point' in args.domains[1].env_tag
        and not hasattr(args.domains[1], 'obs_converter')
    )
    if wants_alignment_score:
        calc_alignment_score(
            args=args,
            data_loader=dataloader_dict['align_val'][1],  # target
            model_dict=model_dict,
            experiment=experiment,
        )

    train_idm(args, experiment, model_dict, optimizer_dict, dataloader_dict)
    behavioral_cloning(
        args,
        experiment,
        model_dict,
        optimizer_dict,
        dataloader_dict,
        task_id_manager=task_id_managers[1],
    )


if __name__ == "__main__":
    # Configuration precedence, lowest to highest: CDILConfig defaults, then
    # the YAML file named by ``config=...`` on the command line (or the p2p
    # default), then the remaining command-line overrides.
    args = OmegaConf.structured(CDILConfig)
    cli_args = OmegaConf.from_cli()

    config_file_path = cli_args.get('config', 'common/cdil/config/p2p.yaml')
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``.
    if Path(config_file_path).suffix != '.yaml':
        raise ValueError(
            f"config must be a '.yaml' file, got {config_file_path!r}")
    file_args = OmegaConf.load(config_file_path)
    args = OmegaConf.merge(args, file_args)

    # TODO: when args.complex_task is True this may need to be reloaded when
    # the inference dataset is loaded.
    args: CDILConfig = OmegaConf.merge(args, cli_args)
    read_env_config_yamls(args)

    goal_candidates = get_goal_candidates(
        n_goals=args.domains[0].n_goals
        if not args.multienv else args.n_task_ids,
        target_goal=args.goal,
        align=True,
        complex_task=args.complex_task,
        n_tasks=args.n_tasks,
        is_r2r='Lift' in args.domains[0]['env'],
    )
    args.train_goal_ids = goal_candidates

    hparams = OmegaConf.to_container(args)

    if args.comet:
        # Requires COMET_PLP_PROJECT_NAME in the environment (KeyError if unset).
        experiment = Experiment(
            project_name=os.environ['COMET_PLP_PROJECT_NAME'])
        experiment.set_name(args.name)
        experiment.log_parameters(hparams)
        experiment.add_tag("cdil")
        # Wrap in Path so this works whether args.config is a str or a Path;
        # the tag is the config file's stem (e.g. "p2p").
        experiment.add_tag(Path(args.config).stem)
    else:
        experiment = None

    print(args)
    main(args, experiment)
