import os
import warnings
from pathlib import Path
from typing import Callable, Dict

os.environ["D4RL_SUPPRESS_IMPORT_ERROR"] = "1"

import d4rl
import gym
import numpy as np
from .env import CCATransferWrapper
from omegaconf import DictConfig
from stable_baselines3 import DDPG
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.td3.policies import TD3Policy


def train_ddpg(
    args: DictConfig,
    source_transform: Callable,
    target_transform: Callable,
    source_dataset: Dict[str, np.ndarray],
    model_path: Path,
    max_steps: int,
) -> DDPG:
    """Train a DDPG agent on the (wrapped) target environment for transfer.

    The target environment is created with ``reward_type="sparse"`` and
    wrapped in ``CCATransferWrapper``, which receives the source dataset
    and both transforms. A DDPG agent using ``TD3Policy`` and Gaussian
    exploration noise is trained on the wrapped environment, saved to
    ``model_path``, and returned.

    Args:
        args (DictConfig): experiment configuration. This function reads
            the keys below; the whole object is also forwarded to
            ``CCATransferWrapper``, which may read further keys.
                target_env_id: str, gym id of the target environment
                inference_task_ids: List[int]
                num_task_ids: int
                alpha: float
                aux_reward_only: bool
                policy.pi: List[int], layer sizes of the actor network
                policy.qf: List[int], layer sizes of the critic network
        source_transform (Callable): source-domain transform, forwarded
            to ``CCATransferWrapper``.
        target_transform (Callable): target-domain transform, forwarded
            to ``CCATransferWrapper``.
        source_dataset (Dict[str, np.ndarray]): source-domain dataset; the
            keys "terminals", "observations" and "infos/goal_id" are
            required.
        model_path (Path): destination of the saved model zip file.
        max_steps (int): total number of environment steps to train for.

    Returns:
        DDPG: the trained agent (also persisted at ``model_path``).
    """
    # Sparse-reward target env, wrapped so the agent trains for transfer.
    # Task ids are not appended to observations (use_task_id_for_obs=False).
    wrapped_env = CCATransferWrapper(
        env=gym.make(args.target_env_id, reward_type="sparse"),
        args=args,
        source_dataset=source_dataset,
        source_transform=source_transform,
        target_transform=target_transform,
        inference_task_ids=args.inference_task_ids,
        num_task_ids=args.num_task_ids,
        use_task_id_for_obs=False,
        alpha=args.alpha,
        aux_reward_only=args.aux_reward_only,
    )
    # NOTE(review): one explicit reset before handing the env to DDPG;
    # presumably the wrapper needs it to initialize its state - TODO confirm.
    wrapped_env.reset()

    action_dim = wrapped_env.action_space.shape[-1]
    # Zero-mean Gaussian exploration noise, std 0.1 on every action dimension.
    exploration_noise = NormalActionNoise(
        mean=np.zeros(action_dim),
        sigma=np.full(action_dim, 0.1),
    )

    # Separate layer sizes for the actor ("pi") and critic ("qf") networks.
    net_arch = dict(pi=args.policy.pi, qf=args.policy.qf)

    agent = DDPG(
        policy=TD3Policy,
        env=wrapped_env,
        action_noise=exploration_noise,
        verbose=1,
        policy_kwargs=dict(net_arch=net_arch),
    )
    agent.learn(total_timesteps=max_steps)

    agent.save(model_path)
    print("DDPG model saved to", model_path)
    return agent
