import argparse
import logging
import random
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import comet_ml
import d4rl
import gym
import numpy as np
import torch.nn
from comet_ml import Experiment

from custom.transformer_modules import TransformerPredictor
from custom.utils import (PairedTrajectoryDataset, create_savedir_root,
                          eval_cond_transformer_policy,
                          prepare_paired_trajectory_dataset, read_dataset)
from transformer_modules import generate_square_subsequent_mask

# Module-level logger. The level is set to INFO on this logger only; no handler
# is attached here, so output depends on how logging is configured elsewhere.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Task IDs kept by `_filter_by_id`. This is only the default value: the
# `__main__` block rebinds it from the command-line arguments, and `main()`
# temporarily swaps it to `[args.goal]` while building the test split.
TRAIN_TASK_IDS = [1, 2, 3, 4, 5, 6]


def trans_into_source_obs(original_obs, shift: bool = False):
    """Map an observation into the source domain by swapping feature pairs.

    The last axis is reordered as ``[1, 0, 3, 2]`` (features 0/1 swapped and
    features 2/3 swapped). Index-list selection produces a new array, so the
    caller's ``original_obs`` is left untouched.

    Args:
        original_obs: Array-like supporting ``[..., index_list]`` selection
            whose last axis has at least four entries.
        shift: When true, subtract 6 from the first two features of the
            reordered result.

    Returns:
        The reordered (and optionally shifted) observation.
    """
    feature_order = [1, 0, 3, 2]
    swapped = original_obs[..., feature_order]
    if not shift:
        return swapped

    # Only the first two (already swapped) features are offset.
    swapped[..., :2] -= 6
    return swapped


def get_action_translator(action_type: str):
    """Build a function that maps an action into the source domain.

    Args:
        action_type: ``'normal'`` to pass actions through unchanged, or
            ``'inv'`` to negate them.

    Returns:
        A one-argument callable applying the selected mapping.

    Raises:
        AssertionError: If ``action_type`` is neither ``'normal'`` nor
            ``'inv'``.
    """
    assert action_type in ['normal',
                           'inv'], f'invalid action_type "{action_type}"'

    def trans_into_source_action(original_action):
        # Anything other than 'normal' (i.e. 'inv') flips the sign.
        return original_action if action_type == 'normal' else -original_action

    return trans_into_source_action


def _filter_by_id(task_ids: np.ndarray):
    """Return a boolean mask marking entries whose ID is in ``TRAIN_TASK_IDS``.

    The module-level ``TRAIN_TASK_IDS`` is read at call time, so rebinding
    that global changes which IDs are selected by later calls.

    Args:
        task_ids: Array of task IDs.

    Returns:
        Boolean array with the same shape as ``task_ids``.
    """
    # Element-wise membership test against the currently active ID list.
    return np.isin(task_ids, TRAIN_TASK_IDS)


def task_ids_to_target_pos_array(
    goal_to_task_id: Dict,
    task_ids: List[int],
) -> List[Tuple[int, int]]:
    """Collect the keys of ``goal_to_task_id`` whose task ID is requested.

    Args:
        goal_to_task_id: Mapping from a position key to its task ID.
        task_ids: Task IDs to look up.

    Returns:
        The matching keys, in the mapping's iteration order.

    Raises:
        AssertionError: If no key maps to any of ``task_ids``.
    """
    matched_positions = [
        position for position, task_id in goal_to_task_id.items()
        if task_id in task_ids
    ]
    assert len(matched_positions) > 0, 'Target pos not found'

    return matched_positions


def eval_callback(
        policy: TransformerPredictor,
        env: gym.Env,
        epoch: int,
        val_dataset: PairedTrajectoryDataset,
        test_dataset: PairedTrajectoryDataset,
        train_task_ids: List[int],
        test_task_ids: List[int],
        goal_to_task_id: Dict,
        device,
        experiment: Optional[comet_ml.Experiment],
        render_episodes: int = 0,
        savedir_root: Path = Path(),
        times: int = 10,
):
    """Evaluate the policy on the train tasks and the test tasks, then log.

    ``eval_cond_transformer_policy`` is called twice: first with
    ``val_dataset`` / ``train_task_ids``, then with ``test_dataset`` /
    ``test_task_ids``. Both results are written to the module logger and,
    when ``experiment`` is not ``None``, sent to it via ``log_metrics``.

    Args:
        policy: Policy to evaluate.
        env: Environment passed through to the evaluation helper.
        epoch: Epoch index used in video file names, log lines and metrics.
        val_dataset: Trajectory dataset used for the train-task evaluation.
        test_dataset: Trajectory dataset used for the test-task evaluation.
        train_task_ids: Task IDs for the train-task evaluation.
        test_task_ids: Task IDs for the test-task evaluation.
        goal_to_task_id: Mapping passed through to the evaluation helper.
        device: Device passed through to the evaluation helper.
        experiment: Optional comet experiment receiving the metrics.
        render_episodes: Passed through to the evaluation helper.
        savedir_root: Root under which ``train/<epoch>.gif`` and
            ``test/<epoch>.gif`` video paths are formed.
        times: Passed through to the evaluation helper.
    """
    # Arguments that are identical for both evaluation passes.
    common_kwargs = dict(
        env=env,
        times=times,
        policy=policy,
        device=device,
        render_episodes=render_episodes,
        experiment=experiment,
        goal_to_task_id=goal_to_task_id,
    )
    video_name = f'{epoch}.gif'

    train_success_rate, train_steps_mean = eval_cond_transformer_policy(
        video_path=savedir_root / 'train' / video_name,
        traj_dataset=val_dataset,
        task_ids=train_task_ids,
        **common_kwargs,
    )
    test_success_rate, test_steps_mean = eval_cond_transformer_policy(
        video_path=savedir_root / 'test' / video_name,
        traj_dataset=test_dataset,
        task_ids=test_task_ids,
        **common_kwargs,
    )

    logger.info(
        f'Epoch: {epoch} Success rate: {train_success_rate * 100:.1f}% Step len: {train_steps_mean:.1f}'
    )
    logger.info(
        f'Epoch: {epoch} Success rate (test): {test_success_rate * 100:.1f}% Step len (test): {test_steps_mean:.1f}'
    )

    if experiment is None:
        return

    summary = {
        'train_success_rate': train_success_rate,
        'train_steps_mean': train_steps_mean,
        'test_success_rate': test_success_rate,
        'test_steps_mean': test_steps_mean,
    }
    experiment.log_metrics(summary, epoch=epoch)


def main(args, experiment: Optional[comet_ml.Experiment]):
    """Train the conditional transformer policy and evaluate it periodically.

    Steps, in order:
      1. Build train/validation loaders for the tasks in ``TRAIN_TASK_IDS``.
      2. Build a test loader for the single task ``args.goal`` by temporarily
         rebinding the global ``TRAIN_TASK_IDS`` (which ``_filter_by_id``
         reads at call time).
      3. Create the policy, optimizer and environment, and run an initial
         evaluation.
      4. For each epoch: one training pass, one validation pass and one test
         pass, then log that epoch's mean losses; every
         ``args.eval_interval`` epochs (and after the last one) evaluate the
         policy in the environment.

    Args:
        args: Parsed options. Reads ``action``, ``goal_demo``, ``sa_demo``,
            ``env``, ``goal``, ``device``, ``seq_len``, ``lr``, ``epochs``,
            ``eval_interval`` and ``render_episodes``. ``args.n_traj`` is
            overwritten with ``None`` before the test split is built and is
            not restored.
        experiment: Optional comet experiment that receives parameters and
            metrics; pass ``None`` to disable.
    """
    global TRAIN_TASK_IDS

    def build_loaders():
        # Same arguments for both the train/val split and the test split;
        # which tasks are kept depends on TRAIN_TASK_IDS at call time.
        return prepare_paired_trajectory_dataset(
            args=args,
            filter_by_id_fn=_filter_by_id,
            trans_into_source_obs=trans_into_source_obs,
            trans_into_source_action=get_action_translator(
                action_type=args.action),
            goal_demo=args.goal_demo,
            sa_demo=args.sa_demo,
        )

    train_loader, val_loader, goal_to_task_id = build_loaders()
    savedir_root = create_savedir_root(phase_tag='cond', env_tag=args.env[:-3])
    if experiment:
        experiment.log_parameter('save_dir', str(savedir_root))
        logger.info(f'Save to directory {savedir_root}')

    # TODO fix this: the test split is selected by swapping the module global
    # that `_filter_by_id` consults. try/finally guarantees the training task
    # list is put back even if building the test split fails.
    train_task_ids = deepcopy(TRAIN_TASK_IDS)
    test_task_ids = [args.goal]
    TRAIN_TASK_IDS = [args.goal]
    args.n_traj = None
    try:
        test_loader, _, _ = build_loaders()
    finally:
        TRAIN_TASK_IDS = train_task_ids

    s_obs_sample, t_obs_sample, t_act_sample, _, _ = train_loader.dataset[0]

    policy = TransformerPredictor(
        enc_in_dim=s_obs_sample.shape[-1],
        dec_in_dim=t_obs_sample.shape[-1],
        out_dim=t_act_sample.shape[-1],
        embed_dim=256,
        n_head=8,
        ff_dim=256,
        n_enc_layers=3,
        n_dec_layers=3,
    ).to(args.device)

    look_ahead_mask = generate_square_subsequent_mask(sz=args.seq_len).to(
        args.device)

    optimizer = torch.optim.Adam(policy.parameters(), lr=args.lr)
    # Stateless loss module; created once instead of once per batch.
    criterion = torch.nn.MSELoss()

    env = gym.make(args.env)

    def process_batch(batch, train: bool = False):
        """Run one batch through the policy; step the optimizer if `train`."""
        s_obs, t_obs, t_act, s_pad_mask, t_pad_mask = (
            item.to(args.device) for item in batch)

        out = policy(
            source_obs=s_obs,
            target_obs=t_obs,
            tgt_look_ahead_mask=look_ahead_mask,
            src_pad_mask=s_pad_mask,
            tgt_pad_mask=t_pad_mask,
        )

        # The loss only covers positions where the target pad mask is False.
        keep = ~t_pad_mask
        loss = criterion(out[keep], t_act[keep])
        loss_dict = {'main_loss': loss.item()}

        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return loss_dict

    def run_epoch(loader, train: bool):
        """Process every batch of `loader`; return the mean of each loss."""
        # Fresh accumulator per pass, so the returned means describe this
        # epoch only rather than a running average over all epochs so far.
        losses = defaultdict(list)
        for batch in loader:
            if train:
                loss_dict = process_batch(batch, train=True)
            else:
                with torch.no_grad():
                    loss_dict = process_batch(batch, train=False)
            for key, val in loss_dict.items():
                losses[key].append(val)
        return {key: np.array(vals).mean() for key, vals in losses.items()}

    def evaluate(epoch_label: int):
        eval_callback(
            policy,
            env,
            epoch_label,
            val_loader.dataset,
            test_loader.dataset,
            train_task_ids,
            test_task_ids,
            goal_to_task_id,
            args.device,
            experiment,
            render_episodes=args.render_episodes,
            savedir_root=savedir_root,
        )

    logger.info('Initial Policy Evaluation...')
    evaluate(0)

    for epoch in range(args.epochs):
        # Dict literal entries are evaluated in order: train, val, test.
        split_means = {
            'train': run_epoch(train_loader, train=True),
            'val': run_epoch(val_loader, train=False),
            'test': run_epoch(test_loader, train=False),
        }

        logger.info(f'=== Epoch {epoch} Summary ===')
        metrics_dict = {}
        for split, means in split_means.items():
            for key, mean in means.items():
                logger.info(f'{split}_{key}:\t{mean:.5f}')
                metrics_dict[f'{split}_{key}'] = mean

        if experiment:
            experiment.log_metrics(
                metrics_dict,
                epoch=epoch,
            )

        if (epoch + 1) % args.eval_interval == 0 or (epoch + 1) == args.epochs:
            logger.info('Evaluate Policy...')
            evaluate(epoch + 1)


if __name__ == '__main__':
    # Script entry point: parse options, derive per-environment settings and
    # the training task list, optionally create a comet experiment, then train.

    # Without any configured handler, INFO records fall through to logging's
    # last-resort handler (WARNING level) and are dropped. basicConfig does
    # nothing if the root logger already has handlers, so this is safe even
    # when an imported library has configured logging.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument('--name', type=str)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--hid_dim', type=int, default=256)
    parser.add_argument('--latent_dim', type=int, default=256)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--train_ratio', type=float, default=0.9)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--plot_times', type=int, default=5)
    parser.add_argument('--n_tasks', type=int, default=-1)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--env', type=str, default='maze2d-medium-v1')
    parser.add_argument('--n_traj', type=int, default=None)
    parser.add_argument('--comet', action='store_true')
    parser.add_argument('--shift', action='store_true')
    parser.add_argument('--eval_interval', type=int, default=10)
    parser.add_argument('--render_episodes', type=int, default=0)
    parser.add_argument('--action',
                        type=str,
                        default='normal',
                        help='"normal" or "inv"')
    parser.add_argument('--activation',
                        type=str,
                        default='relu',
                        help='activation. ["relu", "noact", "tanh", "ln"].')
    parser.add_argument(
        '--repr_activation',
        type=str,
        default='relu',
        help=
        'activation before representation. ["relu", "noact", "tanh", "ln"].')
    parser.add_argument('--no_plot', action='store_true')
    parser.add_argument('--goal_demo', action='store_true')
    parser.add_argument('--single_task', action='store_true')
    parser.add_argument('--sa_demo', action='store_true')
    parser.add_argument('--policy_num_layers',
                        type=int,
                        nargs=3,
                        default=[3, 5, 3])
    parser.add_argument(
        '--goal',
        type=int,
        default=7,
        help="Goal ID of target task. It's not used in alignment phase.")
    args = parser.parse_args()
    # NOTE(review): this forces sa_demo on regardless of the --sa_demo flag.
    args.sa_demo = True

    # env name -> (dataset path, sequence length, number of task IDs 1..N)
    env_settings = {
        'maze2d-umaze-v1': ('data/maze2d-umaze-sparse-v1.hdf5', 250, 7),
        'maze2d-medium-v1': ('data/maze2d-medium-sparse-v1.hdf5', 400, 26),
    }
    if args.env not in env_settings:
        raise ValueError(f'Invalid env {args.env} is specified.')

    args.dataset, args.seq_len, n_task_ids = env_settings[args.env]
    args.calc_align_score = True

    # Candidate training tasks are every ID except the held-out goal.
    lis = list(range(1, n_task_ids + 1))
    if args.goal not in lis:
        raise ValueError(
            f'Invalid goal {args.goal} for env {args.env}: '
            f'expected an ID between 1 and {n_task_ids}.')
    lis.remove(args.goal)

    if args.n_tasks > 0:
        lis = random.sample(lis, k=args.n_tasks)

    TRAIN_TASK_IDS = lis
    if args.single_task:
        TRAIN_TASK_IDS = [args.goal]
    logger.info(
        f'Train tasks ({len(TRAIN_TASK_IDS)} tasks) = {TRAIN_TASK_IDS}')

    hparams = {
        'batch_size': args.batch_size,
        'lr': args.lr,
        'epochs': args.epochs,
    }

    if args.comet:
        experiment = Experiment(project_name='d4rl-cd')
        experiment.set_name(args.name)
        experiment.add_tags(['cond', args.env[:-3]])
        experiment.log_parameters(hparams)

    else:
        experiment = None

    main(args, experiment)
