import logging
import os
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Callable, List, Optional

import comet_ml
import gym
import h5py
import imageio.v2
import numpy as np
import torch
import torch.nn as nn
import yaml
from common.dail.models import DAILAgent
from pygifsicle import gifsicle
from torch.nn.utils import spectral_norm
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def visualize_move(env, obs, goals):
    """Replay a sequence of observations in ``env`` and render each one.

    For every (observation, goal) pair the environment is reset to the first
    two entries of the observation (with ``no_noise=True``), the goal is set
    as the target, the marker is updated and a frame is rendered.

    Args:
        env: Environment exposing ``reset_to_location``, ``set_target``,
            ``set_marker`` and ``render``.
        obs: Iterable of observations; only ``ob[:2]`` is used.
            (presumably the x/y position -- TODO confirm against the env)
        goals: Iterable of goals, paired element-wise with ``obs``.
    """
    for step_obs, step_goal in zip(obs, goals):
        location = step_obs[:2]
        env.reset_to_location(location, no_noise=True)
        env.set_target(step_goal)
        env.set_marker()
        env.render()


def save_model(model: torch.nn.Module, path: Path, epoch=None):
    """Save a model's ``state_dict`` (and optionally the epoch) under ``path``.

    The directory ``path`` is created if needed. The weights are written to
    ``path / 'model.pt'``; when ``epoch`` is given, its string form is written
    to ``path / 'epoch.txt'``.

    Args:
        model: Module whose ``state_dict()`` is serialized with ``torch.save``.
        path: Target directory. Strings are accepted and converted to ``Path``.
        epoch: Optional epoch number to record next to the weights.
    """
    # Accept plain strings as well; the `/` joins below need a Path.
    path = Path(path)
    os.makedirs(path, exist_ok=True)
    torch.save(model.state_dict(), path / 'model.pt')
    logger.info(f'The model is saved to {path}')

    # Compare against None instead of relying on truthiness so that
    # epoch 0 is recorded too.
    if epoch is not None:
        with open(path / 'epoch.txt', 'w') as f:
            f.write(str(epoch))


def read_dataset(path: str,
                 source_trans_fn: Callable,
                 source_action_trans_fn: Callable,
                 shift: bool = False):
    """Load an HDF5 dataset and derive task IDs and source-domain views.

    The datasets ``observations``, ``actions``, ``infos/goal`` and
    ``timeouts`` are read from the file. Goals are rounded to integers and
    each distinct goal is assigned a task ID starting at 1.

    Args:
        path: Path of the HDF5 file.
        source_trans_fn: Called as ``source_trans_fn(obs, shift=shift)`` to
            produce the source-domain observations.
        source_action_trans_fn: Called on the actions to produce the
            source-domain actions.
        shift: Forwarded to ``source_trans_fn``.

    Returns:
        Tuple ``(task_ids, task_ids_onehot, source_obs, source_actions,
        target_obs, target_actions, dones, goal_to_task_id)``. The one-hot
        array has ``task_ids.max() + 1`` columns; since IDs start at 1,
        column 0 is not set by any goal.
    """
    with h5py.File(path, 'r') as h5_file:
        obs = np.array(h5_file['observations'])
        actions = np.array(h5_file['actions'])
        goals = np.array(h5_file['infos/goal'])
        dones = np.array(h5_file['timeouts'])

    goals = goals.round().astype(int)

    # Number the distinct goals 1, 2, ... in the order np.unique returns them.
    goal_to_task_id = {}
    for task_id, unique_goal in enumerate(np.unique(goals, axis=0), start=1):
        goal_to_task_id[tuple(unique_goal)] = task_id
    logger.info(f'Goal to task ID: {goal_to_task_id}')

    task_ids = np.array([goal_to_task_id[tuple(goal)] for goal in goals])
    identity = np.eye(task_ids.max() + 1)
    task_ids_onehot = identity[task_ids].astype(np.float32)

    # The target domain uses the raw data as-is.
    target_obs = obs.astype(np.float32)
    target_actions = actions.astype(np.float32)

    # The source domain is whatever the caller-supplied transforms produce.
    # (original note: x <-> y and vx <-> vy are swapped in the source domain
    #  -- depends on the transform passed in; verify against caller)
    source_obs = source_trans_fn(obs, shift=shift).astype(np.float32)
    source_actions = source_action_trans_fn(actions).astype(np.float32)

    return (task_ids, task_ids_onehot, source_obs, source_actions,
            target_obs, target_actions, dones, goal_to_task_id)


def split_dataset(dataset: torch.utils.data.TensorDataset, train_ratio: float,
                  batch_size: int):
    """Randomly split ``dataset`` into train/validation ``DataLoader``s.

    Args:
        dataset: Dataset to split.
        train_ratio: Fraction of samples for training; the train size is
            ``int(len(dataset) * train_ratio)`` and the rest is validation.
        batch_size: Batch size used by both loaders.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    n_total = len(dataset)
    n_train = int(n_total * train_ratio)
    split_sizes = [n_train, n_total - n_train]
    train_subset, val_subset = random_split(dataset, split_sizes)

    return (
        DataLoader(train_subset, batch_size=batch_size, shuffle=True),
        DataLoader(val_subset, batch_size=batch_size),
    )


def _remove_single_step_trajectories(source_obs: np.ndarray,
                                     target_obs: np.ndarray,
                                     task_ids_onehot: np.ndarray,
                                     source_actions: np.ndarray,
                                     target_actions: np.ndarray,
                                     dones: np.ndarray):
    done_idx = np.concatenate(([-1], np.where(dones)[0]))
    traj_lens = done_idx[1:] - done_idx[:-1]
    single_step_traj_ids = np.where(traj_lens == 1)[0]

    valid_steps = np.ones(len(dones), dtype=bool)
    for single_step_traj_id in single_step_traj_ids:
        invalid_step_idx = done_idx[single_step_traj_id + 1]
        valid_steps[invalid_step_idx] = False

    source_obs = source_obs[valid_steps]
    target_obs = target_obs[valid_steps]
    task_ids_onehot = task_ids_onehot[valid_steps]
    source_actions = source_actions[valid_steps]
    target_actions = target_actions[valid_steps]
    dones = dones[valid_steps]

    return source_obs, target_obs, task_ids_onehot, source_actions, target_actions, dones


def _select_n_trajecotries(source_obs: np.ndarray, target_obs: np.ndarray,
                           task_ids_onehot: np.ndarray,
                           source_actions: np.ndarray,
                           target_actions: np.ndarray, dones: np.ndarray,
                           n: int):

    done_idx = np.concatenate(([-1], np.where(dones)[0]))
    start_bit = np.concatenate(([False], dones))[:-1]
    traj_ids = np.cumsum(start_bit)

    selected_traj_ids = np.random.choice(range(traj_ids.max() + 1),
                                         n,
                                         replace=False)

    selected_steps = np.zeros(len(traj_ids))
    for selected_id in selected_traj_ids:
        st_idx, gl_idx = done_idx[selected_id] + 1, done_idx[selected_id + 1]
        selected_steps[st_idx] += 1
        if gl_idx + 1 < len(selected_steps):
            selected_steps[gl_idx + 1] -= 1

    selected_steps = np.cumsum(selected_steps).astype(bool)

    source_obs = source_obs[selected_steps]
    target_obs = target_obs[selected_steps]
    task_ids_onehot = task_ids_onehot[selected_steps]
    source_actions = source_actions[selected_steps]
    target_actions = target_actions[selected_steps]
    dones = dones[selected_steps]

    return source_obs, target_obs, task_ids_onehot, source_actions, target_actions, dones


def filter_dataset(
    source_obs: np.ndarray,
    target_obs: np.ndarray,
    task_ids_onehot: np.ndarray,
    source_actions: np.ndarray,
    target_actions: np.ndarray,
    dones: np.ndarray,
    task_ids: np.ndarray,
    filter_by_id_fn: Callable,
    n_traj: Optional[int] = None,
):
    """Filter the per-step arrays by task ID and trajectory criteria.

    Steps:
      1. Force the last step to be terminal.
      2. Keep only the steps selected by ``filter_by_id_fn(task_ids)``.
      3. Remove single-step trajectories.
      4. If ``n_traj`` is truthy, randomly keep ``n_traj`` trajectories.

    Args:
        source_obs, target_obs, task_ids_onehot, source_actions,
        target_actions: Per-step arrays aligned with ``dones``.
        dones: Per-step terminal flags. The caller's array is not modified.
        task_ids: Per-step task IDs passed to ``filter_by_id_fn``.
        filter_by_id_fn: Maps ``task_ids`` to an index/mask used to select
            steps from every array.
        n_traj: Optional number of trajectories to keep.

    Returns:
        ``(source_obs, target_obs, task_ids_onehot, source_actions,
        target_actions, dones)`` after filtering.

    Raises:
        AssertionError: If ``n_traj`` is set and the number of terminal flags
            in the result differs from it.
    """
    # Copy before forcing the final terminal flag so the caller's array is
    # left untouched.
    dones = dones.copy()
    dones[-1] = True
    logger.info(f'At first, {dones.sum()} trajectories available.')

    select_flag = filter_by_id_fn(task_ids)
    source_obs = source_obs[select_flag]
    target_obs = target_obs[select_flag]
    task_ids_onehot = task_ids_onehot[select_flag]
    source_actions = source_actions[select_flag]
    target_actions = target_actions[select_flag]
    dones = dones[select_flag]
    task_id_list = np.unique(task_ids[select_flag])
    logger.info(f'After task ID filtering, {dones.sum()} trajectories remain.')
    logger.info(
        f'{len(task_id_list)} tasks are going to be used. (ID={task_id_list})')

    # remove single_step trajectory
    source_obs, target_obs, task_ids_onehot, source_actions, target_actions, dones = _remove_single_step_trajectories(
        source_obs=source_obs,
        target_obs=target_obs,
        task_ids_onehot=task_ids_onehot,
        source_actions=source_actions,
        target_actions=target_actions,
        dones=dones,
    )
    logger.info(
        f'After removing single-step trajectories, {dones.sum()} trajectories remain.'
    )

    # select final trajectories
    if n_traj:
        source_obs, target_obs, task_ids_onehot, source_actions, target_actions, dones = _select_n_trajecotries(
            source_obs=source_obs,
            target_obs=target_obs,
            task_ids_onehot=task_ids_onehot,
            source_actions=source_actions,
            target_actions=target_actions,
            dones=dones,
            n=n_traj,
        )

    logger.info(f'Finally, {dones.sum()} trajectories are selected.')
    if n_traj:
        # Sanity check: one terminal flag per selected trajectory.
        assert dones.sum() == n_traj

    return source_obs, target_obs, task_ids_onehot, source_actions, target_actions, dones


def create_savedir_root(phase_tag: str, env_tag) -> Path:
    """Create and return a fresh results directory for one run.

    The directory is
    ``custom/results/<env_tag>/<phase_tag>/<YYYYmmdd-HHMMSS>-<6-digit number>``
    relative to the current working directory, where the number is drawn
    with ``np.random.randint``.

    Args:
        phase_tag: Name of the phase, used as a path component.
        env_tag: Name of the environment, used as a path component.

    Returns:
        Path of the created directory.
    """
    timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    random_no = np.random.randint(low=0, high=int(1e6) - 1)

    parent_dir = Path(f'custom/results/{env_tag}/{phase_tag}')
    run_dir = parent_dir / f'{timestamp}-{random_no:06d}'
    os.makedirs(run_dir, exist_ok=True)
    return run_dir


def get_processed_data(dataset: str,
                       task_id_zero: bool,
                       filter_by_id_fn: Callable,
                       trans_into_source_obs: Callable,
                       trans_into_source_action: Callable,
                       n_traj: Optional[int] = None,
                       shift: bool = False):
    """Read, filter and post-process a dataset into per-step arrays.

    The dataset is loaded with ``read_dataset``, filtered with
    ``filter_dataset``, and next-step observations are derived for both
    domains. On terminal steps the next observation is the observation
    itself.

    Args:
        dataset: Path of the HDF5 dataset.
        task_id_zero: When true, the returned one-hot task IDs are replaced
            by an all-zero array of the same shape.
        filter_by_id_fn: Forwarded to ``filter_dataset``.
        trans_into_source_obs: Observation transform for the source domain.
        trans_into_source_action: Action transform for the source domain.
        n_traj: Optional number of trajectories to keep.
        shift: Forwarded to ``read_dataset``.

    Returns:
        ``(source_obs, source_actions, target_obs, target_actions,
        task_ids_onehot, goal_to_task_id, source_next_obs, target_next_obs)``.
    """
    (task_ids, task_ids_onehot, source_obs, source_actions, target_obs,
     target_actions, dones, goal_to_task_id) = read_dataset(
         path=dataset,
         source_trans_fn=trans_into_source_obs,
         source_action_trans_fn=trans_into_source_action,
         shift=shift,
     )

    filtered = filter_dataset(
        source_obs,
        target_obs,
        task_ids_onehot,
        source_actions,
        target_actions,
        dones,
        task_ids,
        filter_by_id_fn=filter_by_id_fn,
        n_traj=n_traj,
    )
    (source_obs, target_obs, task_ids_onehot, source_actions, target_actions,
     dones) = filtered

    # Shift every observation array by one step (repeating the last row),
    # then keep the current observation wherever the step is terminal so the
    # "next" observation never crosses a trajectory boundary.
    is_terminal = dones[..., None]
    source_shifted = np.concatenate((source_obs[1:], source_obs[-1:]))
    target_shifted = np.concatenate((target_obs[1:], target_obs[-1:]))
    source_next_obs = np.where(is_terminal, source_obs, source_shifted)
    target_next_obs = np.where(is_terminal, target_obs, target_shifted)

    # An all-zero task ID is used for adaptation (per the original note).
    if task_id_zero:
        task_ids_onehot = np.zeros_like(task_ids_onehot, dtype=np.float32)

    return (source_obs, source_actions, target_obs, target_actions,
            task_ids_onehot, goal_to_task_id, source_next_obs,
            target_next_obs)


def prepare_dataset(args, filter_by_id_fn: Callable,
                    trans_into_source_obs: Callable,
                    trans_into_source_action: Callable,
                    dataset_concat_fn: Callable, task_id_zero: bool):
    """Build train/validation loaders from the dataset described by ``args``.

    Args:
        args: Object providing ``dataset``, ``n_traj``, ``shift``,
            ``train_ratio`` and ``batch_size``; ``source_only`` and
            ``target_only`` are optional and default to ``False``.
        filter_by_id_fn: Forwarded to ``get_processed_data``.
        trans_into_source_obs: Observation transform for the source domain.
        trans_into_source_action: Action transform for the source domain.
        dataset_concat_fn: Combines the per-domain arrays into the five
            arrays (obs, cond, domains, actions, next obs) of the dataset.
        task_id_zero: Forwarded to ``get_processed_data``.

    Returns:
        ``(train_loader, val_loader, goal_to_task_id)``.
    """
    logger.info('Start creating dataset...')
    (source_obs, source_actions, target_obs, target_actions, task_ids_onehot,
     goal_to_task_id, source_next_obs, target_next_obs) = get_processed_data(
         dataset=args.dataset,
         task_id_zero=task_id_zero,
         filter_by_id_fn=filter_by_id_fn,
         trans_into_source_obs=trans_into_source_obs,
         trans_into_source_action=trans_into_source_action,
         n_traj=args.n_traj,
         shift=args.shift,
     )

    # One-hot domain labels: source is [1, 0], target is [0, 1].
    source_domain_id = np.array([[1, 0]], dtype=np.float32)
    target_domain_id = np.array([[0, 1]], dtype=np.float32)

    (obs_for_dataset, cond_for_dataset, domains_for_dataset,
     actions_for_dataset, next_obs_for_dataset) = dataset_concat_fn(
         source_obs=source_obs,
         target_obs=target_obs,
         source_actions=source_actions,
         target_actions=target_actions,
         source_domain_id=source_domain_id,
         target_domain_id=target_domain_id,
         task_ids_onehot=task_ids_onehot,
         source_next_obs=source_next_obs,
         target_next_obs=target_next_obs,
         source_only=getattr(args, 'source_only', False),
         target_only=getattr(args, 'target_only', False),
     )

    # create torch dataset
    arrays = (obs_for_dataset, cond_for_dataset, domains_for_dataset,
              actions_for_dataset, next_obs_for_dataset)
    dataset = TensorDataset(*(torch.from_numpy(array) for array in arrays))

    train_loader, val_loader = split_dataset(dataset=dataset,
                                             train_ratio=args.train_ratio,
                                             batch_size=args.batch_size)
    logger.info('Dataset has been successfully created.')
    return train_loader, val_loader, goal_to_task_id


def get_success(obs, target, threshold=0.1):
    """Return whether the first two entries of ``obs`` are close to ``target``.

    Closeness is the Euclidean norm of ``obs[0:2] - target`` being at most
    ``threshold``.
    """
    offset = obs[:2] - target
    distance = np.linalg.norm(offset)
    return distance <= threshold


def save_video(path: Path,
               images: List[np.ndarray],
               fps: int = 20,
               skip_rate: int = 5,
               experiment: Optional[comet_ml.Experiment] = None):
    """Write a subsampled frame sequence as an mp4 and optionally log it.

    The output file is ``path`` with its suffix replaced by ``.mp4``. Frames
    are taken from ``images`` starting at index 1 with stride ``skip_rate``.

    Args:
        path: Destination path; its parent directory is created if needed.
        images: Frames to encode.
        fps: Frame rate passed to ``imageio.mimsave``.
        skip_rate: Stride used when subsampling ``images``.
        experiment: Optional experiment; when truthy the mp4 is logged as an
            asset named after ``path.parent.name`` with ``int(path.stem)`` as
            the step (so the file stem must be an integer string).
    """
    os.makedirs(path.parent, exist_ok=True)

    mp4_path = path.with_suffix('.mp4')
    frames = images[1::skip_rate]
    imageio.mimsave(mp4_path, frames, fps=fps)
    logger.info(f'video is saved to {mp4_path}')

    if not experiment:
        return

    experiment.log_asset(mp4_path,
                         file_name=path.parent.name,
                         step=int(path.stem))


import random
from typing import List, Literal

#### GAMA SPECIFIC
from omegaconf import DictConfig, OmegaConf


def get_omega_args(
    phase: Literal["align", "adapt"] = "align",
    source_env: str = 'maze2d-medium-v1',
    target_env: str = '',
    inference_task_ids: Optional[List[int]] = None,
) -> DictConfig:
    """Build the OmegaConf configuration for an align/adapt run.

    The base configuration is created, a random seed is drawn and applied to
    ``random``, ``numpy`` and ``torch``, state/action dimensions are derived
    from the environment IDs, and the task IDs for the requested phase are
    set. The result is passed through ``generate_params``.

    Args:
        phase: ``"align"`` uses every task ID except the inference ones;
            ``"adapt"`` uses only the inference task IDs.
        source_env: Source environment ID. Must contain ``ant``, ``point``
            or ``maze2d`` and one of ``umaze``, ``medium`` or ``large``.
        target_env: Target environment ID; an empty string means "same as
            ``source_env``".
        inference_task_ids: Task IDs held out for inference. Defaults to
            ``[7]``.

    Returns:
        The populated configuration.

    Raises:
        ValueError: If an environment ID is not recognized, or an inference
            task ID is not among the task IDs of the maze.
        AssertionError: If the second ``-``-separated component of the
            source and target environment IDs differ.
    """
    # Avoid a mutable default argument; [7] is the historical default.
    if inference_task_ids is None:
        inference_task_ids = [7]

    if target_env == '':
        target_env = source_env

    args = OmegaConf.create({
        "experiment_name": "dail-test",
        "source_env_id": source_env,  # domain y
        "target_env_id": target_env,  # domain x
        "source_dataset": "../datasets/point/point-medium-v1.hdf5",
        "target_dataset": "../datasets/ant/ant-medium-v1.hdf5",
        "source_domain_id": 0,
        "target_domain_id": 1,
        "domain_dim": 2,
        "reverse_source_observations": False,
        "reverse_source_actions": False,
        "reverse_target_observations": False,
        "reverse_target_actions": False,
        "load_pretrained_model": False,
        "pretrained":
        "saved_models/${source_env_id}_${target_env_id}_${reverse_source_observations}_${reverse_source_actions}.pt",
        "train_source_policy": False,
        "source_policy_trained_path":
        "saved_models/${source_env_id}_${target_env_id}_${reverse_source_observations}_${reverse_source_actions}.pt",
        "train_dynamics_model": False,
        "dynamics_trained_path":
        "saved_models/${source_env_id}_${target_env_id}_${reverse_source_observations}_${reverse_source_actions}.pt",
        "train_gama": True,
    })
    OmegaConf.resolve(args)

    # Draw a seed and apply it to every RNG used by the pipeline.
    args.seed = np.random.randint(2**31)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    # (keyword in env id, state dim, action dim); first match wins.
    env_dims = (("ant", 29, 8), ("point", 6, 2), ("maze2d", 4, 2))
    for domain in ["source", "target"]:
        env_id = args[domain + "_env_id"]
        for keyword, state_dim, action_dim in env_dims:
            if keyword in env_id:
                args[domain + "_state_dim"] = state_dim
                args[domain + "_action_dim"] = action_dim
                break
        else:
            raise ValueError(f"Unrecognized env_id: {env_id}")

    # Both domains must share the same second component of the env id.
    assert args.source_env_id.split("-")[1] == args.target_env_id.split("-")[1], (
        f"Maze mismatch: {args.source_env_id} vs {args.target_env_id}")

    # (maze keyword, number of tasks); task IDs are 1..number of tasks.
    task_counts = (("umaze", 7), ("medium", 26), ("large", 46))
    for maze_keyword, num_tasks in task_counts:
        if maze_keyword in args.source_env_id:
            all_task_ids = list(range(1, num_tasks + 1))
            break
    else:
        # Report the source id, which is the one actually being checked.
        raise ValueError(f"Unrecognized env_id: {args.source_env_id}")

    proxy_task_ids = list(all_task_ids)
    for inference_task_id in inference_task_ids:
        if inference_task_id not in proxy_task_ids:
            raise ValueError(
                f"Inference task ID {inference_task_id} is not an available "
                f"task ID for {args.source_env_id}")
        proxy_task_ids.remove(inference_task_id)

    # NOTE(review): any other `phase` value leaves `task_ids` unset.
    if phase == "align":
        args.task_ids = proxy_task_ids
    elif phase == "adapt":
        args.task_ids = inference_task_ids
    # +1 because task IDs start at 1.
    args["num_task_ids"] = len(all_task_ids) + 1

    return generate_params(args)


def generate_params(args):
    """Add training and model hyper-parameters to ``args`` and return it.

    ``args`` must already provide ``num_task_ids`` and the
    ``source``/``target`` ``_state_dim`` and ``_action_dim`` entries. The
    keys ``source_buffer_size``, ``target_buffer_size``, ``batch_size``,
    ``bc`` and ``models`` are set on ``args`` in place.

    Naming convention: y == expert == source, x == learner == target.
    """
    hidden = 64
    act = "leaky_relu"

    def layer_spec(lr, in_dim, hid_dims):
        # Every layer but the last uses `act`; the output layer has none.
        return {
            "lr": lr,
            "in_dim": in_dim,
            "hid_dims": hid_dims,
            "activations": [act] * (len(hid_dims) - 1) + [None],
        }

    # params for training gama
    args["source_buffer_size"] = int(1e5)
    args["target_buffer_size"] = int(1e4)
    args["batch_size"] = 256

    # params for BC
    args["bc"] = {
        "batches_per_epoch": 10000,
        "batch_size": 256,
        "lr": 1e-3,
    }

    # The goal is encoded with one slot per task ID.
    goal_dim = args.num_task_ids

    src_state, src_action = args.source_state_dim, args.source_action_dim
    tgt_state, tgt_action = args.target_state_dim, args.target_action_dim

    # params for GAMA (only the models that are likely to be used)
    args["models"] = {
        # pi_y : s_y -> a_y
        "source_policy": layer_spec(1e-4, src_state + goal_dim,
                                    [300, 200, src_action]),
        # pi_x : s_x -> a_x
        "target_policy": layer_spec(1e-5, tgt_state + goal_dim,
                                    [300, 200, tgt_action]),
        # f : s_x -> s_y
        "state_map": layer_spec(1e-4, tgt_state,
                                [hidden] * 2 + [src_state]),
        # g : a_y -> a_x
        "action_map": layer_spec(1e-4, src_action,
                                 [hidden] * 2 + [tgt_action]),
        # f^-1 : s_y -> s_x
        "inv_state_map": layer_spec(1e-4, src_state,
                                    [200, 200] + [tgt_state]),
        # P_x : (s_x, a_x) -> s'_x
        "dynamics_model": layer_spec(1e-3, tgt_state + tgt_action,
                                     [hidden] * 3 + [tgt_state]),
        # D : (s_y, a_y, s'_y) -> [0, 1]
        "discriminator": layer_spec(1e-4, src_state * 2 + src_action,
                                    [hidden] * 2 + [1]),
    }

    return args


def eval_policy_dail(
    env: gym.Env,
    policy: DAILAgent,
    device,
    source_trans_fn: Callable,
    task_dim: int,
    source_action_type: str = 'normal',
    times=10,
    target_center=(1, 1),
    source_flag=False,
    render_episodes: int = 0,
    video_path: Path = '.',
    experiment: Optional[comet_ml.Experiment] = None,
    shift: bool = False,
):
    """Roll out ``policy`` in ``env`` and report success rate and mean steps.

    Each episode resets the environment, places the target at
    ``target_center`` plus uniform noise in [-0.1, 0.1], and re-resets until
    the start observation is farther than 0.5 from the target. The episode
    ends when the environment reports ``done`` or the agent gets within the
    default ``get_success`` threshold of the target.

    Args:
        env: Environment exposing ``reset``, ``step``, ``render``,
            ``set_target``, ``get_target``, ``np_random`` and ``model.nq``.
        policy: Agent. With ``source_flag`` its ``source_policy`` is called
            on the concatenated state and task ID; otherwise the agent is
            called with ``target_obs`` and ``task_ids``.
        device: Torch device for the policy inputs.
        source_trans_fn: Observation transform used when ``source_flag``.
        task_dim: Width of the all-zero task-ID vector fed to the policy.
        source_action_type: ``'normal'`` or ``'inv'``; with ``'inv'`` and
            ``source_flag`` the action is negated before stepping.
        times: Number of episodes.
        target_center: Center of the sampled target location.
        source_flag: Evaluate the source policy on transformed observations.
        render_episodes: Number of leading episodes whose frames are kept.
        video_path: Where the video is saved when frames were recorded.
        experiment: Optional experiment forwarded to ``save_video``.
        shift: Forwarded to ``source_trans_fn``.

    Returns:
        ``(success_rate, steps_mean)`` averaged over all episodes.
    """
    assert source_action_type in ['normal', 'inv']

    # Only source-domain rollouts may have their actions negated.
    negate_action = source_flag and source_action_type == 'inv'

    results, steps, images = [], [], []

    for t in range(times):
        record = t < render_episodes

        obs = env.reset()
        jitter = env.np_random.uniform(low=-.1, high=.1, size=env.model.nq)
        env.set_target(np.array(target_center) + jitter)

        # Re-draw the start state until it is not already near the target.
        while get_success(obs=obs, target=env.get_target(), threshold=0.5):
            obs = env.reset()

        step = 0
        success = False
        done = False
        while not done:
            step += 1

            # Source evaluation sees the transformed observation.
            raw_state = source_trans_fn(obs,
                                        shift=shift) if source_flag else obs
            state_input = torch.from_numpy(raw_state[None].astype(
                np.float32)).to(device)
            task_id = torch.zeros((1, task_dim),
                                  dtype=torch.float32,
                                  device=device)

            with torch.no_grad():
                if source_flag:
                    policy_input = torch.cat((state_input, task_id), dim=-1)
                    output = policy.source_policy(policy_input)
                else:
                    output = policy(target_obs=state_input, task_ids=task_id)
            action = output.detach().cpu().numpy()[0]

            if negate_action:
                action = -action

            obs, _reward, done, _info = env.step(action)

            if record:
                images.append(env.render('rgb_array'))

            if not get_success(obs=obs, target=env.get_target()):
                continue

            success = True
            if record:
                # Hold on the final frame so the success is visible.
                for _ in range(10):
                    images.append(env.render('rgb_array'))
            break

        results.append(success)
        steps.append(step)
        logger.debug(f'Trial {t}: success={success}; steps={step}')

    if images:
        save_video(path=video_path,
                   images=images,
                   fps=20,
                   skip_rate=5,
                   experiment=experiment)

    return np.array(results).mean(), np.array(steps).mean()


def record_align_hparams(path: Path, experiment: Optional[comet_ml.Experiment],
                         args):
    """Load the align-phase hyper-parameters and copy them onto ``args``.

    Reads ``path / 'hparams.yaml'``. If the file does not exist, this is
    logged and nothing else happens. Otherwise the following attributes are
    set on ``args``:

    * ``action`` from ``align_action_type`` (only when present),
    * ``hid_dim`` from ``align_hid_dim`` (default 256),
    * ``task_cond`` from ``align_task_cond`` (default ``False``),
    * ``activation`` from ``align_activation`` (default ``'relu'``),
    * ``repr_activation`` from ``align_repr_activation`` (default ``'relu'``).

    When ``experiment`` is truthy the loaded parameters are logged to it and
    tags are added for adversarial training and the action type.

    Args:
        path: Directory containing ``hparams.yaml``.
        experiment: Optional experiment to log parameters and tags to.
        args: Namespace-like object updated in place.
    """
    yaml_path = path / 'hparams.yaml'
    if not yaml_path.exists():
        logger.info(f'{yaml_path} not found')
        return

    with open(yaml_path, 'r') as f:
        # safe_load returns None for an empty document; treat that as
        # "no hyper-parameters" instead of failing on the lookups below.
        align_hparams = yaml.safe_load(f) or {}

    logger.info(f'Align hparams in {yaml_path} are loaded.')

    # `action` has no default: it is only overridden when recorded.
    if 'align_action_type' in align_hparams:
        args.action = align_hparams["align_action_type"]

    args.hid_dim = align_hparams.get('align_hid_dim', 256)
    args.task_cond = align_hparams.get('align_task_cond', False)
    args.activation = align_hparams.get('align_activation', 'relu')
    args.repr_activation = align_hparams.get('align_repr_activation', 'relu')

    if experiment:
        experiment.log_parameters(align_hparams)
        if align_hparams.get('align_adversarial'):
            experiment.add_tag('adversarial')

        if 'align_action_type' in align_hparams:
            experiment.add_tag(f'{align_hparams["align_action_type"]}-action')
