import copy
import os
import uuid
from dataclasses import asdict, dataclass
from typing import Optional, Tuple, Any

import gymnasium as gym
import itertools
from collections import defaultdict
from gymnasium.vector import SyncVectorEnv
import numpy as np
import pyrallis
import torch
from torch import Tensor
from torch.nn import functional as F  # noqa
from torch.utils.data import DataLoader

import wandb
from tqdm import tqdm

import src.envs
from src.envs.ant_dir import train_test_goals_ant
from src.envs.dark_room import train_test_goals_dr
from src.envs.gridworld import train_test_goals_gw
from src.envs.half_cheetah_vel import train_test_goals_hcv
from src.envs.hopper_params import train_test_goals_hopp
from src.envs.semi_circle import train_test_goals_sc
from src.envs.walker_params import train_test_goals_walkp
from src.model_tuples_cache import ADTuples, ADIQLTuples
from src.utils.data import TuplesMapDataset
from src.utils.misc import set_seed, Timeit
from src.utils.schedule import cosine_annealing_with_warmup
from src.envs.dark_key_to_door import train_test_goals
from src.utils.visualization import per_episode_in_context


# Device on which every tensor in this module is allocated: the GPU when one is visible, else the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"


@dataclass
class TrainConfig:
    """Hyper-parameters for IQL-AD training, filled in from the command line by ``pyrallis``."""

    # Weights & Biases run identification
    project: str = "IQLAD"
    group: str = "IQLAD-Tuples-K2D"
    name: str = "iqlad-tuples-k2d"  # __post_init__ appends a random 8-character suffix
    # Transformer architecture
    hidden_dim: int = 512
    num_layers: int = 4
    num_heads: int = 4
    seq_len: int = 200
    attention_dropout: float = 0.5
    residual_dropout: float = 0.1
    embedding_dropout: float = 0.3
    normalize_qk: bool = False
    pre_norm: bool = True
    # Optimisation and data loading
    env_name: str = "Dark-Key2Door-9x9-v0"
    learning_rate: float = 3e-4
    warmup_ratio: float = 0.05
    betas: Tuple[float, float] = (0.9, 0.99)
    weight_decay: float = 0.0
    clip_grad: Optional[float] = 1.0  # None disables gradient-norm clipping
    subsample: int = 1
    batch_size: int = 128
    update_epochs: int = 1
    num_workers: int = 0
    label_smoothing: float = 0.0
    data_uniform_sample: bool = False
    # Order in which the training data is presented
    data_order: str = "default"
    # Evaluation schedule
    eval_every: int = 25_000
    eval_episodes: int = 200
    eval_train_goals: int = 10
    eval_test_goals: int = 50
    eval_h: int = 500
    # Implicit Q-learning
    iql_tau: float = 0.9  # expectile used by the value loss
    iql_beta: float = 3.0  # multiplier of the advantage inside exp() for AWR weights
    discount: float = 0.99
    bc_weight: float = 0.0
    soft_update_rate: float = 0.005
    use_ln: bool = False
    use_iql: bool = True
    use_iql_vf: bool = False
    detach_v: bool = False
    reward_multiplier: float = 1.0
    # TD3-style target policy smoothing
    policy_noise: float = 0.2
    noise_clip: float = 0.5
    policy_freq: int = 1
    # Test-time fine-tuning
    tune_test: bool = False
    skip_steps: bool = False
    freeze_base: bool = False
    steps_tune: bool = False  # tune every `tune_every` env steps instead of every `tune_every` episodes
    tune_every: int = 1
    tune_updates: int = 5
    test_bc_weight: float = 0.0
    # Paths and seeds
    learning_histories_path: str = "trajectories"
    checkpoints_path: Optional[str] = None
    train_seed: int = 42
    data_seed: int = 0
    eval_seed: int = 42

    def __post_init__(self):
        head_dim = self.hidden_dim / self.num_heads
        assert head_dim % 8 == 0, "head dim should be multiple of 8 for flash attn"

        # Give every run a unique name; checkpoints then go to a per-run sub-directory.
        self.name = "-".join((self.name, uuid.uuid4().hex[:8]))
        if self.checkpoints_path is None:
            return
        self.checkpoints_path = os.path.join(self.checkpoints_path, self.name)


def asymmetric_l2_loss(u: torch.Tensor, tau: float) -> torch.Tensor:
    """Expectile (asymmetric squared) loss averaged over all elements.

    Non-negative residuals are weighted by ``tau`` and negative residuals by
    ``1 - tau`` (computed as ``|tau - 1|``).
    """
    is_negative = (u < 0).float()
    weight = (tau - is_negative).abs()
    return (weight * u.pow(2)).mean()


def get_v_loss(q1_target, q2_target, v, iql_tau) -> tuple[Tensor, Any]:
    """IQL value loss: expectile regression of ``v`` towards the clipped target Q.

    The target is the element-wise minimum of the two target-Q estimates with
    gradients blocked.

    Returns:
        ``(loss, adv)`` where ``adv = min(q1_target, q2_target) - v`` per element
        (``adv`` still carries gradient through ``v``).
    """
    clipped_q = torch.minimum(q1_target, q2_target).detach()
    advantage = clipped_q - v
    return asymmetric_l2_loss(advantage, iql_tau), advantage


def get_q_loss(
        next_v: torch.Tensor,
        q1, q2,
        rewards,
        dones,
        discount,
):
    """TD loss for both Q heads, bootstrapping from the (detached) next-state value.

    target = rewards + (1 - dones) * discount * next_v

    Returns:
        Sum of the mean-squared errors of ``q1`` and ``q2`` against the target,
        computed over the flattened tensors.
    """
    bootstrap = (1.0 - dones) * discount * next_v.detach()
    td_target = (rewards + bootstrap).flatten()

    loss_q1 = F.mse_loss(input=q1.flatten(), target=td_target)
    loss_q2 = F.mse_loss(input=q2.flatten(), target=td_target)
    return loss_q1 + loss_q2


def get_awr_policy_loss(
        adv: torch.Tensor,
        policy_out,
        actions: torch.Tensor,
        beta,
):
    """Advantage-weighted regression loss for a distribution-valued policy.

    Each sample's negative log-likelihood (summed over the last axis) is
    weighted by ``exp(beta * adv)`` capped at 100; ``adv`` is detached.
    Non-finite log-probabilities are replaced (NaN -> 0, +/-inf -> +/-1e6) and
    the result is clipped to [-1e6, 1e6] before summation.

    Returns:
        Scalar mean of the weighted negative log-likelihoods.
    """
    weights = torch.exp(beta * adv.detach()).clamp(max=100)

    raw_log_prob = policy_out.log_prob(actions)
    safe_log_prob = torch.nan_to_num(raw_log_prob, nan=0.0, neginf=-1e6, posinf=1e6).clamp(min=-1e6, max=1e6)
    neg_log_likelihood = -safe_log_prob.sum(-1, keepdim=False)

    return (weights * neg_log_likelihood).mean()


def get_dqn_loss(
        q1, q2,
        next_q1_target, next_q2_target,
        rewards,
        dones,
        discount,
):
    """Clipped double-Q TD loss for both Q heads.

    target = rewards + (1 - dones) * discount * min(next_q1_target, next_q2_target)
    (the bootstrap term is detached).

    Returns:
        Sum of the mean-squared errors of ``q1`` and ``q2`` against the target,
        computed over the flattened tensors.
    """
    bootstrap = torch.minimum(next_q1_target, next_q2_target).detach()
    not_done = 1.0 - dones
    td_target = (rewards + not_done * discount * bootstrap).flatten()

    loss_q1 = F.mse_loss(input=q1.flatten(), target=td_target)
    loss_q2 = F.mse_loss(input=q2.flatten(), target=td_target)
    return loss_q1 + loss_q2


def get_cql_loss(q_all, q_data):
    """Conservative Q-learning penalty: mean of logsumexp over the action axis minus the data-action Q.

    ``q_all`` is reduced along dim 2, so ``q_data`` must broadcast against the
    result.
    """
    soft_max_q = q_all.logsumexp(dim=2)
    return (soft_max_q - q_data).mean()


def train_step(
        config: TrainConfig, model, scaler, optim, policy_optim,
        states, prev_actions, prev_rewards, prev_dones, target_actions, rewards, dones, steps,
        continuous_actions, update_targets=True,
):
    """Run one optimisation step on a batch of context sequences.

    The model is called once with the dataset actions and must return
    ``(q1, q2, q1_target, q2_target, v, (policy_out, target_policy_out), _)``.
    Which losses are built depends on ``config.use_iql`` / ``config.use_iql_vf``
    and on ``continuous_actions``:

    * ``use_iql``: expectile value loss (``get_v_loss``) plus a TD loss that
      bootstraps from ``v`` at the next time position (``get_q_loss``). For
      continuous actions without ``use_iql_vf`` an advantage-weighted policy
      loss is also optimised immediately with ``policy_optim``.
    * not ``use_iql``, discrete actions: clipped double-Q loss whose bootstrap
      is the max over actions of each target Q head at the next position.
    * not ``use_iql``, continuous actions: TD3-style target computed at the
      target policy's mean action plus clipped Gaussian noise.

    For discrete actions a cross-entropy term between the Q values (used as
    logits) and ``target_actions`` is added with weight ``config.bc_weight``.

    Tensors are indexed as ``[batch, time, ...]`` (time positions are shifted
    with ``[:, :-1]`` / ``[:, 1:]`` below); as built by the callers in this file,
    ``rewards[:, t]`` / ``dones[:, t]`` are what followed ``target_actions[:, t]``.

    Args:
        config: hyper-parameters (discount, IQL/TD3 settings, clip_grad, ...).
        model: the critic/actor model; must also provide ``update_targets``.
        scaler: AMP grad scaler used for the main (critic) loss.
        optim: optimiser stepped through ``scaler`` on the main loss.
        policy_optim: optimiser stepped directly (unscaled) on policy losses.
        states, prev_actions, prev_rewards, prev_dones, steps: model inputs.
        target_actions: the action taken at every position.
        rewards, dones: reward / done signal following each action.
        continuous_actions: whether actions are continuous vectors.
        update_targets: if True, soft-update target networks and, for
            continuous non-IQL or IQL-VF runs, perform the deterministic actor
            update.

    Returns:
        ``(q1, log_dict)``: the first Q head's output and scalar metrics.
    """
    # print(states.shape, prev_actions.shape, prev_rewards.shape, prev_dones.shape, target_actions.shape, steps.shape)
    with (torch.cuda.amp.autocast()):
        q1, q2, q1_target, q2_target, v, (policy_out, target_policy_out), _ = model(
            states=states, prev_actions=prev_actions, prev_rewards=prev_rewards, prev_dones=prev_dones,
            actions=target_actions, steps=steps,
        )
        log_dict = {}
        # Zero placeholder so `loss` below is defined when the IQL value loss is not used.
        v_loss = torch.FloatTensor([0.0]).to(DEVICE)

        if config.use_iql:
            # Discrete actions: select Q(s, a) of the dataset action from the per-action outputs.
            # Continuous actions: the model already evaluated Q at `target_actions`.
            v_loss, adv = get_v_loss(
                (torch.gather(q1_target, dim=2, index=target_actions.to(torch.int64).unsqueeze(
                    -1)) if not continuous_actions else q1_target).squeeze(-1),
                (torch.gather(q2_target, dim=2, index=target_actions.to(torch.int64).unsqueeze(
                    -1)) if not continuous_actions else q2_target).squeeze(-1),
                v.squeeze(-1),
                config.iql_tau
            )
            # dones = torch.zeros_like(rewards)[:, :-1]
            # dones[:, -1] = 1
            # The TD target at position t bootstraps from V at t + 1, so the last position is dropped.
            q_loss = get_q_loss(
                v.squeeze(-1)[:, 1:],
                (torch.gather(q1, dim=2, index=target_actions.to(torch.int64).unsqueeze(
                    -1)) if not continuous_actions else q1).squeeze(-1)[:, :-1],
                (torch.gather(q2, dim=2, index=target_actions.to(torch.int64).unsqueeze(
                    -1)) if not continuous_actions else q2).squeeze(-1)[:, :-1],
                rewards[:, :-1] * config.reward_multiplier,
                dones=dones[:, :-1],
                discount=config.discount,
            )
            log_dict['value_funcs/adv_mean'] = adv.mean().item()
            if continuous_actions and not config.use_iql_vf:
                # AWR actor update, applied immediately and outside the grad scaler.
                # NOTE(review): this backward runs without retain_graph before the main loss
                # backward below; if policy_out shares graph nodes with q/v (e.g. a common
                # trunk), the later backward would fail — presumably the policy path is
                # separate. Confirm against the model definition.
                policy_loss = get_awr_policy_loss(adv, policy_out, target_actions, config.iql_beta)
                policy_loss.backward()
                policy_optim.step()
                policy_optim.zero_grad(set_to_none=True)
                log_dict['policy_loss'] = policy_loss.item()
            if config.use_iql_vf:
                # Evaluate Q at the policy's mean action; used by the actor update after the main backward.
                pred_actions = policy_out.mean.to(torch.float)
                q1_actor, _, _, _, _, _, _ = model(
                    states=states, prev_actions=prev_actions, prev_rewards=prev_rewards, prev_dones=prev_dones,
                    actions=pred_actions, steps=steps,
                )

        else:
            if not continuous_actions:
                # Clipped double-Q with max over actions of each target head at the next position.
                q_loss = get_dqn_loss(
                    torch.gather(q1, dim=2, index=target_actions.to(torch.int64).unsqueeze(-1)).squeeze(-1)[:,
                    :-1],
                    torch.gather(q2, dim=2, index=target_actions.to(torch.int64).unsqueeze(-1)).squeeze(-1)[:,
                    :-1],
                    torch.max(q1_target, dim=2)[0][:, 1:], torch.max(q2_target, dim=2)[0][:, 1:],
                    rewards[:, :-1] * config.reward_multiplier,
                    dones=dones[:, :-1],
                    discount=config.discount,
                )
            else:
                # TD3-style target policy smoothing: noisy target-policy action, clipped to [-1, 1].
                pred_actions = policy_out.mean.to(torch.float)
                target_pred_actions = target_policy_out.mean.to(torch.float)
                noise = (torch.randn_like(prev_actions) * config.policy_noise).clamp(
                    -config.noise_clip, config.noise_clip
                )
                # print(pred_actions)
                target_pred_actions = (target_pred_actions + noise).clamp(-1, 1)
                # Overwrites q1_target / q2_target with the target heads evaluated at the smoothed action.
                _, _, q1_target, q2_target, _, _, _ = model(
                    states=states, prev_actions=prev_actions, prev_rewards=prev_rewards, prev_dones=prev_dones,
                    actions=target_pred_actions, steps=steps,
                )
                q1_actor, _, _, _, _, _, _ = model(
                    states=states, prev_actions=prev_actions, prev_rewards=prev_rewards, prev_dones=prev_dones,
                    actions=pred_actions, steps=steps,
                )

                q_loss = get_dqn_loss(
                    q1.squeeze(-1)[:, :-1],
                    q2.squeeze(-1)[:, :-1],
                    q1_target.squeeze(-1)[:, 1:], q2_target.squeeze(-1)[:, 1:],
                    rewards[:, :-1] * config.reward_multiplier,
                    dones=dones[:, :-1],
                    discount=config.discount,
                )

        # Despite its name, this is a behaviour-cloning cross-entropy on the Q logits,
        # not the CQL logsumexp penalty (kept commented out below).
        cql_loss = torch.FloatTensor([0.0]).to(DEVICE)
        if not continuous_actions:
            # cql_loss = get_cql_loss(
            #     q1, torch.gather(q1, dim=2, index=target_actions.to(torch.int64).unsqueeze(-1)).squeeze(-1)
            # ) + get_cql_loss(
            #     q2, torch.gather(q2, dim=2, index=target_actions.to(torch.int64).unsqueeze(-1)).squeeze(-1)
            # )
            cql_loss = F.cross_entropy(
                input=q1.flatten(0, 1),
                target=target_actions.flatten(0, 1),
                label_smoothing=config.label_smoothing,
            ) + F.cross_entropy(
                input=q2.flatten(0, 1),
                target=target_actions.flatten(0, 1),
                label_smoothing=config.label_smoothing,
            )

        loss = v_loss + q_loss + config.bc_weight * cql_loss
        predicted_actions = q1

    # retain_graph=True: the actor update below backpropagates through q1_actor, whose graph
    # contains policy_out from the first forward pass, which this backward would otherwise free.
    scaler.scale(loss).backward(retain_graph=True)

    log_dict.update({
        "q_loss": q_loss.item(),
        "v_loss": v_loss.item(),
        "cql_loss": cql_loss.item(),
        "value_funcs/q1_mean": q1.mean().item(),
        "value_funcs/q2_mean": q2.mean().item(),
        "value_funcs/v_mean": v.mean().item(),
    })

    # Deterministic actor update (TD3+BC style), gated by `update_targets`; config.policy_freq is
    # not read in this function.
    # NOTE(review): policy_loss.backward() also writes gradients into every parameter reachable
    # from q1_actor (critic heads / shared layers); only policy_optim's params are zeroed here,
    # so those gradients are still present when `optim` is unscaled and stepped below — confirm
    # whether that is intended.
    if continuous_actions and update_targets:
        if not config.use_iql or config.use_iql_vf:
            bc_loss = F.mse_loss(pred_actions, target_actions)
            # Normalise the Q term by its mean magnitude so bc_weight has a scale-free meaning.
            norm_coef = 1 / q1_actor.abs().mean().detach()
            policy_loss = config.bc_weight * bc_loss - norm_coef * q1_actor.mean()
            log_dict['policy_loss'] = policy_loss.item()
            policy_loss.backward()
            policy_optim.step()
            policy_optim.zero_grad(set_to_none=True)
            log_dict["BC_loss"] = bc_loss.item()

    # Gradients must be unscaled before clipping so the threshold applies to true gradient norms.
    if config.clip_grad is not None:
        scaler.unscale_(optim)
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip_grad)
    scaler.step(optim)
    scaler.update()
    optim.zero_grad(set_to_none=True)
    if update_targets:
        model.update_targets(config.soft_update_rate)

    return predicted_actions, log_dict


@torch.no_grad()
def evaluate_in_context(env_name, model, goals, eval_episodes, seed=None, cont_s=False, cont_a=False,):
    """Roll out ``model`` in-context, without parameter updates, on one env per goal.

    The model sees a sliding window of the last ``model.seq_len`` (state,
    previous action, previous reward, previous done, within-episode step)
    tuples and acts greedily: the mode of a categorical over the Q values for
    discrete actions, the policy mean for continuous ones. Rollouts continue
    until every env has finished more than ``eval_episodes`` episodes.

    Args:
        env_name: gymnasium id; the env must accept ``goal_pos`` and expose
            ``unwrapped.pos_to_state``.
        model: in-context model, called with batch-first context tensors.
        goals: numpy array of goals; one env is created per goal.
        eval_episodes: number of episodes recorded per goal.
        seed: forwarded to ``vec_env.reset``.
        cont_s, cont_a: whether states / actions are continuous vectors.

    Returns:
        defaultdict(list) mapping each goal's key (built with ``pos_to_state``)
        to the returns of its first ``eval_episodes`` episodes.
    """
    vec_env = SyncVectorEnv([lambda goal=goal: gym.make(env_name, goal_pos=goal) for goal in goals])
    tmp_env = gym.make(env_name, goal_pos=goals[0])
    try:
        # Log key per env, computed once: "Params" envs key on the whole goal, otherwise a
        # 4-element goal is treated as a pair of positions.
        pos_to_state = tmp_env.unwrapped.pos_to_state
        if "Params" in env_name:
            log_keys = [pos_to_state(goal) for goal in goals]
        else:
            log_keys = [
                (pos_to_state(goal[0]), pos_to_state(goal[1])) if np.prod(goal.shape) == 4 else pos_to_state(goal)
                for goal in goals
            ]

        num_envs = vec_env.num_envs
        # Time-major context buffers: [seq_len, num_envs, ...].
        if cont_s:
            states = torch.zeros((model.seq_len, num_envs, vec_env.observation_space.shape[1]),
                                 dtype=torch.float32, device=DEVICE)
        else:
            states = torch.zeros((model.seq_len, num_envs), dtype=torch.long, device=DEVICE)
        if cont_a:
            prev_actions = torch.zeros((model.seq_len, num_envs, vec_env.action_space.shape[1]),
                                       dtype=torch.float32, device=DEVICE)
        else:
            prev_actions = torch.zeros((model.seq_len, num_envs), dtype=torch.long, device=DEVICE)
        prev_rewards = torch.zeros((model.seq_len, num_envs), dtype=torch.float32, device=DEVICE)
        prev_dones = torch.zeros((model.seq_len, num_envs), dtype=torch.float32, device=DEVICE)
        steps = torch.zeros((model.seq_len, num_envs), dtype=torch.float32, device=DEVICE)

        # Episode counters and running returns per env.
        num_episodes = np.zeros(num_envs)
        returns = np.zeros(num_envs)
        eval_info = defaultdict(list)

        state, _ = vec_env.reset(seed=seed)
        if cont_a:
            prev_action = torch.zeros((num_envs, vec_env.action_space.shape[1]), dtype=torch.float32, device=DEVICE)
        else:
            prev_action = 0
        prev_reward, prev_done = 0, 0
        prev_step = torch.zeros((num_envs,), dtype=torch.float32, device=DEVICE)

        for step in itertools.count(start=1):
            # Slide the window by one position and write the newest tuple at the end.
            states = states.roll(-1, dims=0)
            prev_actions = prev_actions.roll(-1, dims=0)
            prev_rewards = prev_rewards.roll(-1, dims=0)
            prev_dones = prev_dones.roll(-1, dims=0)
            steps = steps.roll(-1, dims=0)

            states[-1] = torch.as_tensor(state, device=DEVICE)
            prev_actions[-1] = torch.as_tensor(prev_action, device=DEVICE)
            prev_rewards[-1] = torch.as_tensor(prev_reward, device=DEVICE)
            prev_dones[-1] = torch.as_tensor(prev_done, device=DEVICE)
            steps[-1] = torch.as_tensor(prev_step, device=DEVICE)

            # Only the filled part of the window (at most `step` positions) is fed, batch-first.
            with torch.cuda.amp.autocast():
                q, _, _, _, v, (policy_out, _), _ = model(
                    states=states[-step:].permute(1, 0, 2) if cont_s else states[-step:].permute(1, 0),
                    prev_actions=prev_actions[-step:].permute(1, 0, 2) if cont_a else prev_actions[-step:].permute(1, 0),
                    prev_rewards=prev_rewards[-step:].permute(1, 0),
                    prev_dones=prev_dones[-step:].permute(1, 0),
                    steps=steps[-step:].permute(1, 0),
                )
            if not cont_a:
                logits = q[:, -1]
                action = torch.distributions.Categorical(logits=logits).mode
            else:
                action = policy_out.mean[:, -1]

            state, reward, terminated, truncated, _ = vec_env.step(action.cpu().numpy())
            prev_step += 1
            done = terminated | truncated

            prev_action = action
            prev_reward = reward
            prev_done = done

            num_episodes += done.astype(int)
            returns += reward

            for i, d in enumerate(done):
                if not d:
                    continue
                if num_episodes[i] <= eval_episodes:
                    eval_info[log_keys[i]].append(returns[i])
                # Reset per-episode trackers on every episode end — also for envs that already
                # used up their budget and keep running until the others finish; otherwise
                # their step feature would grow without bound.
                returns[i] = 0.0
                prev_step[i] = 0.0

            if np.all(num_episodes > eval_episodes):
                break

        return eval_info
    finally:
        vec_env.close()
        tmp_env.close()


def _batch_first(buf, n, is_vector):
    """Take the last ``n`` time positions of a time-major buffer and return them batch-first."""
    recent = buf[-n:]
    return recent.permute(1, 0, 2) if is_vector else recent.permute(1, 0)


def _tune_and_evaluate_goal(env_name, original_model, goal, tmp_env, eval_episodes, seed, cont_s, cont_a, config,
                            eval_info):
    """Fine-tune a fresh copy of ``original_model`` in-context on one goal and record its returns.

    The returns of the first ``eval_episodes`` episodes are appended to
    ``eval_info[<goal key>]``. The model copy is discarded afterwards.
    """
    model = copy.deepcopy(original_model)
    if config.freeze_base:
        model.freeze_base()
    scaler = torch.cuda.amp.GradScaler()
    # Only parameters left trainable (e.g. after freeze_base) are optimised.
    optim = torch.optim.Adam(
        params=filter(lambda p: p.requires_grad, model.parameters()),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
        betas=config.betas,
    )
    policy_optim = torch.optim.Adam(
        params=filter(lambda p: p.requires_grad, model.pi.parameters()),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
        betas=config.betas,
    )

    # Same goal-key convention as `evaluate_in_context`.
    pos_to_state = tmp_env.unwrapped.pos_to_state
    if "Params" not in env_name and np.prod(goal.shape) == 4:
        log_key = (pos_to_state(goal[0]), pos_to_state(goal[1]))
    else:
        log_key = pos_to_state(goal)

    vec_env = SyncVectorEnv([lambda goal=goal: gym.make(env_name, goal_pos=goal)])
    try:
        num_envs = vec_env.num_envs
        # Time-major context buffers: [seq_len, num_envs, ...].
        if cont_s:
            states = torch.zeros((model.seq_len, num_envs, vec_env.observation_space.shape[1]),
                                 dtype=torch.float32, device=DEVICE)
        else:
            states = torch.zeros((model.seq_len, num_envs), dtype=torch.long, device=DEVICE)
        if cont_a:
            prev_actions = torch.zeros((model.seq_len, num_envs, vec_env.action_space.shape[1]),
                                       dtype=torch.float32, device=DEVICE)
        else:
            prev_actions = torch.zeros((model.seq_len, num_envs), dtype=torch.long, device=DEVICE)
        prev_rewards = torch.zeros((model.seq_len, num_envs), dtype=torch.float32, device=DEVICE)
        prev_dones = torch.zeros((model.seq_len, num_envs), dtype=torch.float32, device=DEVICE)
        steps = torch.zeros((model.seq_len, num_envs), dtype=torch.float32, device=DEVICE)

        num_episodes = np.zeros(num_envs)
        returns = np.zeros(num_envs)

        state, _ = vec_env.reset(seed=seed)
        if cont_a:
            prev_action = torch.zeros((num_envs, vec_env.action_space.shape[1]), dtype=torch.float32, device=DEVICE)
        else:
            prev_action = 0
        prev_reward, prev_done = 0, 0
        prev_step = torch.zeros((num_envs,), dtype=torch.float32, device=DEVICE)

        for step in itertools.count(start=1):
            model.eval()
            # Slide the window by one position and write the newest tuple at the end.
            states = states.roll(-1, dims=0)
            prev_actions = prev_actions.roll(-1, dims=0)
            prev_rewards = prev_rewards.roll(-1, dims=0)
            prev_dones = prev_dones.roll(-1, dims=0)
            steps = steps.roll(-1, dims=0)

            states[-1] = torch.as_tensor(state, device=DEVICE)
            prev_actions[-1] = torch.as_tensor(prev_action, device=DEVICE)
            prev_rewards[-1] = torch.as_tensor(prev_reward, device=DEVICE)
            prev_dones[-1] = torch.as_tensor(prev_done, device=DEVICE)
            steps[-1] = torch.as_tensor(prev_step, device=DEVICE)

            # Acting needs no gradients. With grad enabled, a continuous action (the policy
            # mean) would carry a grad_fn into the context buffers above and chain the
            # autograd graphs of successive steps together.
            with torch.no_grad(), torch.cuda.amp.autocast():
                q, _, _, _, v, (policy_out, _), _ = model(
                    states=_batch_first(states, step, cont_s),
                    prev_actions=_batch_first(prev_actions, step, cont_a),
                    prev_rewards=_batch_first(prev_rewards, step, False),
                    prev_dones=_batch_first(prev_dones, step, False),
                    steps=_batch_first(steps, step, False),
                )
            if not cont_a:
                logits = q[:, -1]
                action = torch.distributions.Categorical(logits=logits).mode
            else:
                action = policy_out.mean[:, -1]

            state, reward, terminated, truncated, _ = vec_env.step(action.detach().cpu().numpy())
            # Advance the within-episode step feature exactly like `evaluate_in_context` does.
            prev_step += 1
            done = terminated | truncated

            prev_action = action
            prev_reward = reward
            prev_done = done

            num_episodes += done.astype(int)
            returns += reward

            for i, d in enumerate(done):
                if not d:
                    continue
                if num_episodes[i] <= eval_episodes:
                    eval_info[log_key].append(returns[i])
                returns[i] = 0.0
                prev_step[i] = 0.0

            if np.all(num_episodes > eval_episodes):
                break

            # Training targets: position t is paired with the action/reward/done that followed
            # state t, i.e. the context shifted by one plus the transition just observed.
            target_actions = prev_actions.roll(-1, dims=0)
            target_rewards = prev_rewards.roll(-1, dims=0)
            target_dones = prev_dones.roll(-1, dims=0)
            target_actions[-1] = torch.as_tensor(prev_action, device=DEVICE)
            target_rewards[-1] = torch.as_tensor(prev_reward, device=DEVICE)
            target_dones[-1] = torch.as_tensor(prev_done, device=DEVICE)

            # Tune every `tune_every` env steps, or every `tune_every` finished episodes.
            if config.steps_tune:
                should_tune = step % config.tune_every == 0
            else:
                should_tune = done[0] and num_episodes[0] % config.tune_every == 0
            if not should_tune:
                continue

            for _ in range(config.tune_updates):
                _, update_log = train_step(
                    config, model, scaler, optim, policy_optim,
                    states=_batch_first(states, step, cont_s),
                    prev_actions=_batch_first(prev_actions, step, cont_a),
                    prev_rewards=_batch_first(prev_rewards, step, False),
                    prev_dones=_batch_first(prev_dones, step, False),
                    target_actions=_batch_first(target_actions, step, cont_a),
                    rewards=_batch_first(target_rewards, step, False),
                    dones=_batch_first(target_dones, step, False),
                    continuous_actions=cont_a,
                    steps=_batch_first(steps, step, False),
                    update_targets=True,
                )
                if config.steps_tune:
                    print(update_log, flush=True)
    finally:
        vec_env.close()


def evaluate_in_context_with_tune(env_name, original_model, goals, eval_episodes, seed=None, cont_s=False, cont_a=False, config: TrainConfig = None):
    """Evaluate in-context with test-time fine-tuning, one goal at a time.

    For every goal a fresh deep copy of ``original_model`` is rolled out on a
    single env and updated with ``train_step`` on the context gathered so far:
    every ``config.tune_every`` env steps when ``config.steps_tune`` is set,
    otherwise every ``config.tune_every`` finished episodes. While this runs,
    ``config.bc_weight`` is replaced by ``config.test_bc_weight``; the
    original value is restored afterwards, even if an error is raised.

    Returns:
        defaultdict(list) mapping each goal's key to the returns of its first
        ``eval_episodes`` episodes.
    """
    eval_info = defaultdict(list)
    # Only used for `pos_to_state`; the original code also built it from goals[0] for every goal.
    tmp_env = gym.make(env_name, goal_pos=goals[0])

    bc_w = config.bc_weight
    config.bc_weight = config.test_bc_weight
    try:
        for goal in tqdm(goals):
            _tune_and_evaluate_goal(
                env_name, original_model, goal, tmp_env, eval_episodes, seed, cont_s, cont_a, config, eval_info,
            )
    finally:
        config.bc_weight = bc_w
        tmp_env.close()
    return eval_info


# @torch.no_grad()
# def evaluate_in_context_steps(env_name, model,  goals, max_steps, seed=None, cont_s=False, cont_a=False,):
#     vec_env = SyncVectorEnv([lambda goal=goal: gym.make(env_name, goal_pos=goal) for goal in goals])
#     tmp_env = gym.make(env_name, goal_pos=goals[0])
#
#     if cont_s:
#         states = torch.zeros((model.seq_len, vec_env.num_envs, vec_env.observation_space.shape[1]), dtype=torch.float32,
#                              device=DEVICE)
#     else:
#         states = torch.zeros((model.seq_len, vec_env.num_envs), dtype=torch.long, device=DEVICE)
#     if cont_a:
#         prev_actions = torch.zeros((model.seq_len, vec_env.num_envs, vec_env.action_space.shape[1]),
#                                    dtype=torch.float32, device=DEVICE)
#     else:
#         prev_actions = torch.zeros((model.seq_len, vec_env.num_envs), dtype=torch.long, device=DEVICE)
#     prev_rewards = torch.zeros((model.seq_len, vec_env.num_envs), dtype=torch.float32, device=DEVICE)
#     prev_dones = torch.zeros((model.seq_len, vec_env.num_envs), dtype=torch.float32, device=DEVICE)
#     steps = torch.zeros((model.seq_len, vec_env.num_envs), dtype=torch.float32, device=DEVICE)
#
#     # to track number of episodes for each goal and returns
#     num_episodes = np.zeros(vec_env.num_envs)
#     returns = np.zeros(vec_env.num_envs)
#     # for logging
#     eval_info = defaultdict(list)
#
#     state, _ = vec_env.reset(seed=seed)
#     if cont_a:
#         prev_action = torch.zeros((vec_env.num_envs, vec_env.action_space.shape[1]), dtype=torch.float32, device=DEVICE)
#     else:
#         prev_action = 0
#
#     prev_reward, prev_done = 0, 0
#     prev_step = torch.zeros((vec_env.num_envs,), dtype=torch.float32, device=DEVICE)
#
#     for step in itertools.count(start=1):
#         if max_steps < step:
#             break
#         # roll context back for new step
#         states = states.roll(-1, dims=0)
#         prev_actions = prev_actions.roll(-1, dims=0)
#         prev_rewards = prev_rewards.roll(-1, dims=0)
#         prev_dones = prev_dones.roll(-1, dims=0)
#         steps = steps.roll(-1, dims=0)
#
#         # fill last s-a-r tuple
#         states[-1] = torch.as_tensor(state, device=DEVICE)
#         prev_actions[-1] = torch.as_tensor(prev_action, device=DEVICE)
#         prev_rewards[-1] = torch.as_tensor(prev_reward, device=DEVICE)
#         prev_dones[-1] = torch.as_tensor(prev_done, device=DEVICE)
#         steps[-1] = torch.as_tensor(prev_step, device=DEVICE)
#
#         # predict next action
#         with torch.cuda.amp.autocast():
#             # [num_envs, seq_len, num_actions] -> [num_envs, num_actions]
#             q, _, _, _, v, policy_out, _ = model(
#                 states=states[-step:].permute(1, 0, 2) if cont_s else states[-step:].permute(1, 0),
#                 prev_actions=prev_actions[-step:].permute(1, 0, 2) if cont_a else prev_actions[-step:].permute(1, 0),
#                 prev_rewards=prev_rewards[-step:].permute(1, 0),
#                 prev_dones=prev_dones[-step:].permute(1, 0),
#                 steps=steps[-step:].permute(1, 0),
#             )
#         if not cont_a:
#             logits = q[:, -1]  # - v[:, -1]
#             dist = torch.distributions.Categorical(logits=logits)
#             # action = dist.sample()
#             action = dist.mode
#         else:
#             action = policy_out.mean[:, -1]
#         # query the world
#         state, reward, terminated, truncated, _ = vec_env.step(action.cpu().numpy())
#         prev_step += 1
#         done = terminated | truncated
#
#         # relabel for the next step
#         prev_action = action
#         prev_reward = reward
#         prev_done = done
#
#         num_episodes += done.astype(int)
#         returns += reward
#
#     for i in range(vec_env.num_envs):
#         log_key = (
#             tmp_env.unwrapped.pos_to_state(goals[i][0]), tmp_env.unwrapped.pos_to_state(goals[i][1])) if np.prod(
#             goals[i].shape) == 4 else tmp_env.unwrapped.pos_to_state(goals[i])
#         eval_info[log_key].append(returns[i])
#
#     vec_env.close()
#     tmp_env.close()
#     return eval_info


# @torch.no_grad()
# def evaluate_in_context_with_cache(env_name, model,  goals, eval_episodes, seed=None, cont_s=False, cont_a=False,):
#     vec_env = SyncVectorEnv([lambda goal=goal: gym.make(env_name, goal_pos=goal) for goal in goals])
#     tmp_env = gym.make(env_name, goal_pos=goals[0])
#
#     kv_cache = model.init_cache(batch_size=vec_env.num_envs, dtype=torch.float16, device=DEVICE)
#     # to track number of episodes for each goal and returns
#     num_episodes = np.zeros(vec_env.num_envs)
#     returns = np.zeros(vec_env.num_envs)
#     # for logging
#     eval_info = defaultdict(list)
#
#     state, _ = vec_env.reset(seed=seed)
#     if cont_a:
#         prev_action = np.zeros((vec_env.num_envs, vec_env.action_space.shape[1]))
#     else:
#         prev_action = np.zeros(vec_env.num_envs)
#     prev_reward, prev_done = np.zeros(vec_env.num_envs), np.zeros(vec_env.num_envs)
#     for step in itertools.count(start=1):
#         # predict next action
#         with torch.cuda.amp.autocast():
#             # [num_envs, seq_len=1, num_actions] -> [num_envs, num_actions]
#             state_dtype = torch.float if cont_s else torch.long
#             action_dtype = torch.float if cont_a else torch.long
#             q, _, _, _, v, policy_out, kv_cache = model(
#                 states=torch.as_tensor(state, dtype=state_dtype, device=DEVICE)[:, None],
#                 prev_actions=torch.as_tensor(prev_action, dtype=action_dtype, device=DEVICE)[:, None],
#                 prev_rewards=torch.as_tensor(prev_reward, dtype=torch.float, device=DEVICE)[:, None],
#                 prev_dones=torch.as_tensor(prev_done, dtype=torch.float, device=DEVICE)[:, None],
#                 cache=kv_cache
#             )
#         if not cont_a:
#             logits = q[:, -1]  # - v[:, -1]
#             dist = torch.distributions.Categorical(logits=logits)
#             # action = dist.sample()
#             action = dist.mode
#         else:
#             action = policy_out.mean[:, -1]
#
#         # query the world
#         state, reward, terminated, truncated, _ = vec_env.step(action.cpu().numpy())
#         done = terminated | truncated
#
#         # relabel for the next step
#         prev_action = action
#         prev_reward = reward
#         prev_done = done
#
#         num_episodes += done.astype(int)
#         returns += reward
#
#         # log returns if done
#         for i, d in enumerate(done):
#             if d and num_episodes[i] <= eval_episodes:
#                 log_key = (
#                     tmp_env.unwrapped.pos_to_state(goals[i][0]),
#                     tmp_env.unwrapped.pos_to_state(goals[i][1])) if np.prod(
#                     goals[i].shape) == 4 else tmp_env.unwrapped.pos_to_state(goals[i])
#                 eval_info[log_key].append(returns[i])
#                 # reset return for this goal
#                 returns[i] = 0.0
#
#         # check that all goals are done
#         if np.all(num_episodes > eval_episodes):
#             break
#
#     vec_env.close()
#     tmp_env.close()
#     return eval_info


# @torch.no_grad()
# def evaluate_in_context_with_cache_steps(env_name, model,  goals, max_steps, seed=None):
#     vec_env = SyncVectorEnv([lambda goal=goal: gym.make(env_name, goal_pos=goal) for goal in goals])
#     tmp_env = gym.make(env_name, goal_pos=goals[0])
#
#     kv_cache = model.init_cache(batch_size=vec_env.num_envs, dtype=torch.float16, device=DEVICE)
#     # to track number of episodes for each goal and returns
#     num_episodes = np.zeros(vec_env.num_envs)
#     returns = np.zeros(vec_env.num_envs)
#     # for logging
#     eval_info = defaultdict(list)
#
#     state, _ = vec_env.reset(seed=seed)
#     prev_action, prev_reward, prev_done = np.zeros(vec_env.num_envs), np.zeros(vec_env.num_envs), np.zeros(vec_env.num_envs)
#     for step in itertools.count(start=1):
#         if max_steps < step:
#             break
#         # predict next action
#         with torch.cuda.amp.autocast():
#             # [num_envs, seq_len=1, num_actions] -> [num_envs, num_actions]
#             q, _, _, _, v, policy_out, kv_cache = model(
#                 states=torch.as_tensor(state, dtype=torch.long, device=DEVICE)[:, None],
#                 prev_actions=torch.as_tensor(prev_action, dtype=torch.long, device=DEVICE)[:, None],
#                 prev_rewards=torch.as_tensor(prev_reward, dtype=torch.float, device=DEVICE)[:, None],
#                 prev_dones=torch.as_tensor(prev_done, dtype=torch.float, device=DEVICE)[:, None],
#                 cache=kv_cache
#             )
#             logits = q[:, -1]# - v[:, -1]
#         dist = torch.distributions.Categorical(logits=logits)
#         # action = dist.sample()
#         action = dist.mode
#
#         # query the world
#         state, reward, terminated, truncated, _ = vec_env.step(action.cpu().numpy())
#         done = terminated | truncated
#
#         # relabel for the next step
#         prev_action = action
#         prev_reward = reward
#         prev_done = done
#
#         num_episodes += done.astype(int)
#         returns += reward
#
#         # log returns if done
#
#     for i in range(vec_env.num_envs):
#         log_key = (
#             tmp_env.unwrapped.pos_to_state(goals[i][0]), tmp_env.unwrapped.pos_to_state(goals[i][1])) if np.prod(
#             goals[i].shape) == 4 else tmp_env.unwrapped.pos_to_state(goals[i])
#         eval_info[log_key].append(returns[i])
#     vec_env.close()
#     tmp_env.close()
#     return eval_info


def split_info_debug(eval_info, train_goals, test_goals, env, transform_keys=True, transform_dict=False):
    """Partition per-goal evaluation results into train-goal and test-goal groups.

    Every key of ``eval_info`` is converted into the same representation as the
    entries of ``train_goals`` / ``test_goals`` and looked up there. The matching
    value is stored unchanged under ``str(goal)`` in the corresponding output.

    Args:
        eval_info: Mapping from a goal key to that goal's results (stored as-is).
        train_goals: Training goals; a numpy array is converted to nested lists.
        test_goals: Held-out goals; a numpy array is converted to nested lists.
        env: Environment whose ``unwrapped`` object provides ``state_to_pos`` and
            ``pos_to_state`` for converting between keys and goals.
        transform_keys: If True, a key is either a single state index (Python or
            numpy integer) or a pair of state indices. It is converted to
            position(s) with ``env.unwrapped.state_to_pos``.
        transform_dict: Used when ``transform_keys`` is False. If True, the goal
            lists are mapped through ``env.unwrapped.pos_to_state`` and keys are
            compared unchanged. If False, keys are compared as ``list(key)``.

    Returns:
        ``(eval_info_train, eval_info_test)``, two ``defaultdict(list)`` objects
        keyed by ``str(goal)``.

    Raises:
        ValueError: If a key matches neither a train goal nor a test goal.
    """
    eval_info_train = defaultdict(list)
    eval_info_test = defaultdict(list)

    if isinstance(train_goals, np.ndarray):
        train_goals = train_goals.tolist()
    if isinstance(test_goals, np.ndarray):
        test_goals = test_goals.tolist()
    if transform_dict:
        train_goals = list(map(env.unwrapped.pos_to_state, train_goals))
        test_goals = list(map(env.unwrapped.pos_to_state, test_goals))

    for key, value in eval_info.items():
        if transform_keys:
            # numpy integer scalars are single state indices too; checking only
            # for ``int`` would send them to the pair branch and fail on key[0].
            if isinstance(key, (int, np.integer)):
                curr_goal = env.unwrapped.state_to_pos(key).tolist()
            else:
                curr_goal = [
                    env.unwrapped.state_to_pos(key[0]).tolist(),
                    env.unwrapped.state_to_pos(key[1]).tolist(),
                ]
        elif not transform_dict:
            curr_goal = list(key)
        else:
            curr_goal = key

        if curr_goal in train_goals:
            eval_info_train[str(curr_goal)] = value
        elif curr_goal in test_goals:
            eval_info_test[str(curr_goal)] = value
        else:
            raise ValueError(
                f"Evaluation key {key!r} (goal {curr_goal!r}) matches neither train nor test goals"
            )

    return eval_info_train, eval_info_test


def process_steps_eval(eval_info):
    """Summarise the first entry of every value in ``eval_info``.

    Args:
        eval_info: Mapping whose values are indexable sequences; only element 0
            of each value is used.

    Returns:
        ``(mean, std)`` of those first entries, computed by numpy (population
        standard deviation, i.e. ``ddof=0``).
    """
    first_entries = np.asarray([series[0] for series in eval_info.values()])
    return first_entries.mean(), first_entries.std()


def get_auc(points, lower=0):
    """Area under the piecewise-linear curve through ``points`` (unit x-spacing).

    Every point is first shifted down by ``lower``. Each segment between
    consecutive points then contributes a trapezoid: the smaller of the two
    heights plus half of their difference.

    Args:
        points: Sequence of y-values at consecutive integer x positions.
        lower: Baseline subtracted from every point before integrating.

    Returns:
        The accumulated area; 0 when fewer than two points are given.
    """
    shifted = [value - lower for value in points]
    return sum(
        min(left, right) + abs(left - right) / 2
        for left, right in zip(shifted, shifted[1:])
    )


def _returns_at(eval_info, index):
    """Return element ``index`` of every per-goal history stored in ``eval_info``."""
    return [history[index] for history in eval_info.values()]


@pyrallis.wrap()
def train(config: TrainConfig):
    """Train an ``ADIQLTuples`` model on stored learning histories and evaluate it.

    Per run:
      * start a wandb run with the config (plus ``PLATFORM_JOB_NAME`` as ``mlc_job``),
      * build a ``TuplesMapDataset`` and a shuffling ``DataLoader``,
      * choose train/test (and, for HalfCheetahVel, OOD) goals for ``config.env_name``,
      * build the model plus two Adam optimizers: one over all parameters, one
        over ``model.pi``.

    Then, for each of ``config.update_epochs`` epochs:
      * call ``train_step`` on every batch and log its metrics,
      * evaluate in-context on the train + test goals,
      * log return/AUC summaries and optionally save ``model_{global_step}.pt``.

    Finally, optionally save ``model_last.pt``.

    Args:
        config: Training configuration. When ``train()`` is called without
            arguments, ``pyrallis.wrap`` builds it (e.g. from the command line).

    Raises:
        ValueError: If no train/test goal split is defined for ``config.env_name``.
    """
    dict_config = asdict(config)
    dict_config["mlc_job"] = os.getenv("PLATFORM_JOB_NAME")
    wandb.init(
        project=config.project,
        group=config.group,
        name=config.name,
        config=dict_config,
    )

    # Goal dimensionality passed to the dataset. For the *Params envs, goals are
    # compared in state space when splitting eval results (transform_dicts).
    key_shape = 2
    transform_dicts = False
    if "Key" in config.env_name:
        key_shape = 4
    elif config.env_name in ("HalfCheetahVel-v0", "AntDir-v0"):
        key_shape = 1
    elif config.env_name in ("HopperParams-v0", "Walker2dParams-v0"):
        key_shape = None
        transform_dicts = True

    random_order = config.data_order == "random"
    sorted_order = config.data_order == "sorted"
    sample_ordered = config.data_order == "sample"

    set_seed(config.train_seed)
    dataset = TuplesMapDataset(
        data_path=config.learning_histories_path,
        seq_len=config.seq_len,
        subsample=config.subsample,
        goal_dim=key_shape,
        random_order=random_order,
        sorted_order=sorted_order,
        sample_ordered=sample_ordered,
    )
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=config.batch_size,
        pin_memory=True,
        shuffle=True,
        num_workers=config.num_workers,
    )
    # Reference env used for space sizes, grid size and goal <-> state conversion.
    # It is reused below instead of creating (and leaking) extra envs.
    tmp_env = gym.make(config.env_name)

    # eval preparation
    test_goals_ood = None
    if "Key" in config.env_name:
        train_goals, test_goals = train_test_goals(
            grid_size=tmp_env.unwrapped.size,
            num_train_goals=len(dataset.unique_goals),
            num_test_goals=config.eval_test_goals,
            seed=config.data_seed
        )
    elif "Room" in config.env_name or "Janus" in config.env_name:
        train_goals, test_goals = train_test_goals_dr(
            grid_size=tmp_env.unwrapped.size,
            num_train_goals=len(dataset.unique_goals),
            seed=config.data_seed
        )
    elif "Grid" in config.env_name:
        train_goals, test_goals = train_test_goals_gw(
            grid_size=tmp_env.unwrapped.size,
            num_train_goals=len(dataset.unique_goals),
            seed=config.data_seed
        )
    elif "Semi-Circle" in config.env_name:
        train_goals, test_goals = train_test_goals_sc(
            num_test_goals=config.eval_test_goals,
            seed=config.data_seed
        )
    elif "HalfCheetahVel-v0" in config.env_name:
        train_goals, test_goals, test_goals_ood = train_test_goals_hcv(
            num_test_goals=config.eval_test_goals,
            seed=config.data_seed
        )
    elif "AntDir-v0" in config.env_name:
        train_goals, test_goals = train_test_goals_ant(
            num_test_goals=config.eval_test_goals,
            seed=config.data_seed
        )
    elif "HopperParams-v0" in config.env_name:
        train_goals, test_goals = train_test_goals_hopp(
            num_test_goals=config.eval_test_goals,
            seed=config.data_seed
        )
    elif "Walker2dParams-v0" in config.env_name:
        train_goals, test_goals = train_test_goals_walkp(
            num_test_goals=config.eval_test_goals,
            seed=config.data_seed
        )
    else:
        tmp_env.close()
        raise ValueError(f"No train/test goal split is defined for env {config.env_name!r}")

    train_goals = train_goals[:config.eval_train_goals]
    if "Params" not in config.env_name:
        eval_all_goals = np.vstack([train_goals, test_goals])
    else:
        eval_all_goals = train_goals + test_goals

    # Envs with continuous observation/action spaces. Their eval keys are not
    # state indices, so they are not converted with state_to_pos when splitting.
    continuous_envs = (
        "Semi-Circle-Sparse-v0", "Semi-Circle-v0", "HalfCheetahVel-v0",
        "AntDir-v0", "HopperParams-v0", "Walker2dParams-v0",
    )
    continuous_states = config.env_name in continuous_envs
    continuous_actions = config.env_name in continuous_envs
    transform_keys = config.env_name not in continuous_envs

    # model & optimizer & scheduler setup
    set_seed(config.train_seed)

    model = ADIQLTuples(
        num_states=tmp_env.observation_space.shape[0] if continuous_states else tmp_env.observation_space.n,
        num_actions=tmp_env.action_space.shape[0] if continuous_actions else tmp_env.action_space.n,
        hidden_dim=config.hidden_dim,
        seq_len=config.seq_len,
        num_layers=config.num_layers,
        num_heads=config.num_heads,
        attention_dropout=config.attention_dropout,
        residual_dropout=config.residual_dropout,
        embedding_dropout=config.embedding_dropout,
        normalize_qk=config.normalize_qk,
        pre_norm=config.pre_norm,
        use_ln=config.use_ln,
        continuous_states=continuous_states,
        continuous_actions=continuous_actions,
        detach_v=config.detach_v,
    ).to(DEVICE)

    # if needed, test beforehand
    # model = torch.compile(model)

    optim = torch.optim.Adam(
        params=model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
        betas=config.betas,
    )

    # Separate optimizer over the policy head only (model.pi); both are passed to train_step.
    policy_optim = torch.optim.Adam(
        params=model.pi.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
        betas=config.betas,
    )

    # Only needed by the (currently disabled) LR scheduler below.
    total_updates = len(dataloader) * config.update_epochs
    # scheduler = cosine_annealing_with_warmup(
    #     optimizer=optim,
    #     warmup_steps=int(total_updates * config.warmup_ratio),
    #     total_steps=total_updates,
    # )

    # save config to the checkpoint
    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)
        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
            pyrallis.dump(config, f)

    scaler = torch.cuda.amp.GradScaler()
    print(f"Parameters: {sum(p.numel() for p in model.parameters())}")
    global_step = 0
    set_seed(config.train_seed)
    for epoch in range(config.update_epochs):
        for batch in dataloader:
            states, prev_actions, prev_rewards, prev_dones, target_actions, rewards, dones, steps = [b.to(DEVICE) for b in batch]

            # Discrete spaces are indices (long); continuous spaces are float vectors.
            states = states.to(torch.long) if not continuous_states else states.to(torch.float)
            prev_actions = prev_actions.to(torch.long) if not continuous_actions else prev_actions.to(torch.float)
            prev_rewards = prev_rewards.to(torch.float)
            prev_dones = prev_dones.to(torch.float)
            target_actions = target_actions.to(torch.long) if not continuous_actions else target_actions.to(torch.float)
            rewards = rewards.to(torch.float)
            dones = dones.to(torch.float)
            steps = steps.to(torch.float)

            with Timeit() as timer:
                # Targets are updated every step, except when IQL is disabled or
                # use_iql_vf is set: then only every config.policy_freq steps.
                update_targets = True
                if not config.use_iql or config.use_iql_vf:
                    update_targets = (global_step % config.policy_freq == 0)
                predicted_actions, log_dict = train_step(
                    config, model, scaler, optim, policy_optim,
                    states, prev_actions, prev_rewards, prev_dones, target_actions, rewards, dones, steps,
                    continuous_actions, update_targets=update_targets
                )
                # scheduler.step()

            with torch.no_grad():
                if not continuous_actions:
                    a = torch.argmax(predicted_actions.flatten(0, 1), dim=-1)
                    t = target_actions.flatten()
                    accuracy = torch.sum(a == t) / a.shape[0]
                    log_dict['accuracy'] = accuracy
                # NOTE: the key spelling ("udpate") is kept on purpose; renaming it
                # would change the logged wandb metric name.
                log_dict["udpate-time"] = timer.elapsed_time_gpu
                log_dict["epoch"] = epoch
                wandb.log(
                    log_dict, step=global_step,
                )

            global_step += 1

        model.eval()
        # Optional evaluations stay None when they are not run (e.g. with
        # config.tune_test). The logging below then skips them instead of
        # failing on an unbound name.
        eval_info_ood = None
        eval_info_default = None
        eval_info_inverted = None
        with Timeit() as eval_timer:
            if not config.tune_test:
                eval_info = evaluate_in_context(
                    env_name=config.env_name,
                    model=model,
                    goals=eval_all_goals,
                    eval_episodes=config.eval_episodes,
                    seed=config.eval_seed,
                    cont_s=continuous_states,
                    cont_a=continuous_actions,
                )
                if test_goals_ood is not None:
                    eval_info_ood = evaluate_in_context(
                        env_name=config.env_name,
                        model=model,
                        goals=test_goals_ood,
                        eval_episodes=config.eval_episodes,
                        seed=config.eval_seed,
                        cont_s=continuous_states,
                        cont_a=continuous_actions,
                    )
                if "Janus" in config.env_name:
                    eval_info_default = evaluate_in_context(
                        env_name="Janus-Default-19x19-v0",
                        model=model,
                        goals=eval_all_goals,
                        eval_episodes=config.eval_episodes,
                        seed=config.eval_seed,
                        cont_s=continuous_states,
                        cont_a=continuous_actions,
                    )
                    eval_info_inverted = evaluate_in_context(
                        env_name="Janus-Inverted-19x19-v0",
                        model=model,
                        goals=eval_all_goals,
                        eval_episodes=config.eval_episodes,
                        seed=config.eval_seed,
                        cont_s=continuous_states,
                        cont_a=continuous_actions,
                    )
            else:
                eval_info = evaluate_in_context_with_tune(
                    env_name=config.env_name,
                    original_model=model,
                    goals=eval_all_goals,
                    eval_episodes=config.eval_episodes,
                    seed=config.eval_seed,
                    cont_s=continuous_states,
                    cont_a=continuous_actions,
                    config=config,
                )
        eval_info_train, eval_info_test = split_info_debug(eval_info, train_goals, test_goals, env=tmp_env, transform_keys=transform_keys, transform_dict=transform_dicts)

        # Plot y-limits and AUC baseline for each environment family.
        max_return = 1.0
        min_return = -0.3
        auc_lower = 0
        if "Key" in config.env_name:
            max_return = 2.0
        elif "Grid" in config.env_name:
            max_return = 14.0
        elif "Sparse" in config.env_name:
            max_return = 50
        elif "Semi" in config.env_name:
            min_return = -30
            max_return = -4.5
        elif "HalfCheetahVel-v0" == config.env_name:
            max_return = 0
            min_return = -500
            auc_lower = -410.5
        elif "AntDir-v0" == config.env_name:
            max_return = 1000
            min_return = 0
            auc_lower = -69.8
        elif "HopperParams-v0" == config.env_name:
            max_return = 500
            min_return = 0
            auc_lower = 17
        elif "Walker2dParams-v0" == config.env_name:
            max_return = 250
            min_return = 0
            auc_lower = 3.6

        pic_name_train = per_episode_in_context(
            eval_info_train, ylim=[min_return, max_return + 0.5], name="train-viz", max_return=max_return,
        )
        pic_name_test = per_episode_in_context(
            eval_info_test, ylim=[min_return, max_return + 0.5], name="test-viz", max_return=max_return,
        )

        # For these envs, a different reference return (not the plot maximum)
        # normalises the AUCs and aggregated metrics below.
        if "HalfCheetahVel-v0" == config.env_name:
            max_return = -78.6
        elif "AntDir-v0" == config.env_name:
            max_return = 773.8
        elif config.env_name in ("HopperParams-v0", "Walker2dParams-v0"):
            max_return = 220.0

        half_idx = config.eval_episodes // 2
        quarter_idx = config.eval_episodes // 4

        # AUC of a curve that stays at max_return for every episode; used as normaliser.
        auc_norm_coef = get_auc([max_return] * config.eval_episodes, lower=auc_lower)
        train_aucs = [get_auc(h, lower=auc_lower) / auc_norm_coef for h in eval_info_train.values()]
        test_aucs = [get_auc(h, lower=auc_lower) / auc_norm_coef for h in eval_info_test.values()]

        train_last = _returns_at(eval_info_train, -1)
        train_half = _returns_at(eval_info_train, half_idx)
        train_quarter = _returns_at(eval_info_train, quarter_idx)
        test_last = _returns_at(eval_info_test, -1)
        test_half = _returns_at(eval_info_test, half_idx)
        test_quarter = _returns_at(eval_info_test, quarter_idx)
        # Average of the mean returns at the final, half-way and quarter-way episode.
        train_avg_returns = (np.mean(train_last) + np.mean(train_half) + np.mean(train_quarter)) / 3
        test_avg_returns = (np.mean(test_last) + np.mean(test_half) + np.mean(test_quarter)) / 3

        wandb.log({
            "eval/train_mean_return": np.mean(train_last),
            "eval/train_median_return": np.median(train_last),
            "eval/test_mean_return": np.mean(test_last),
            "eval/test_median_return": np.median(test_last),
            "eval/train_mean_return_half": np.mean(train_half),
            "eval/train_median_return_half": np.median(train_half),
            "eval/test_mean_return_half": np.mean(test_half),
            "eval/test_median_return_half": np.median(test_half),
            "eval/train_mean_return_quarter": np.mean(train_quarter),
            "eval/train_median_return_quarter": np.median(train_quarter),
            "eval/test_mean_return_quarter": np.mean(test_quarter),
            "eval/test_median_return_quarter": np.median(test_quarter),
            "eval/train_auc": np.mean(train_aucs),
            "eval/test_auc": np.mean(test_aucs),
            "eval/train_graph": wandb.Image(pic_name_train),
            "eval/test_graph": wandb.Image(pic_name_test),
            "eval/eval-time": eval_timer.elapsed_time_gpu,
            "aggregated_metrics/train_avg_returns": train_avg_returns,
            "aggregated_metrics/train_avg_all": train_avg_returns / max_return / 2 + np.mean(train_aucs) / 2,
            "aggregated_metrics/test_avg_returns": test_avg_returns,
            "aggregated_metrics/test_avg_all": test_avg_returns / max_return / 2 + np.mean(test_aucs) / 2,
            "epoch": epoch,
        }, step=global_step
        )

        if eval_info_ood is not None:
            test_ood_aucs = [get_auc(h, lower=auc_lower) / auc_norm_coef for h in eval_info_ood.values()]
            wandb.log({
                "ood/test_mean_return": np.mean(_returns_at(eval_info_ood, -1)),
                "ood/test_mean_return_half": np.mean(_returns_at(eval_info_ood, half_idx)),
                "ood/test_mean_return_quarter": np.mean(_returns_at(eval_info_ood, quarter_idx)),
                "ood/test_auc": np.mean(test_ood_aucs),
            }, step=global_step
            )

        if "Janus" in config.env_name:
            if eval_info_default is not None and eval_info_inverted is not None:
                _, eval_info_test_default = split_info_debug(eval_info_default, train_goals, test_goals, env=tmp_env,
                                                             transform_keys=transform_keys, transform_dict=transform_dicts)
                _, eval_info_test_inverted = split_info_debug(eval_info_inverted, train_goals, test_goals, env=tmp_env,
                                                              transform_keys=transform_keys, transform_dict=transform_dicts)
                test_aucs_default = [get_auc(h, lower=auc_lower) / auc_norm_coef for h in eval_info_test_default.values()]
                test_aucs_inverted = [get_auc(h, lower=auc_lower) / auc_norm_coef for h in eval_info_test_inverted.values()]
                wandb.log({
                    "default/test_mean_return": np.mean(_returns_at(eval_info_test_default, -1)),
                    "default/test_mean_return_half": np.mean(_returns_at(eval_info_test_default, half_idx)),
                    "default/test_mean_return_quarter": np.mean(_returns_at(eval_info_test_default, quarter_idx)),
                    "default/test_auc": np.mean(test_aucs_default),
                    "inverted/test_mean_return": np.mean(_returns_at(eval_info_test_inverted, -1)),
                    "inverted/test_mean_return_half": np.mean(_returns_at(eval_info_test_inverted, half_idx)),
                    "inverted/test_mean_return_quarter": np.mean(_returns_at(eval_info_test_inverted, quarter_idx)),
                    "inverted/test_auc": np.mean(test_aucs_inverted),
                }, step=global_step
                )

            # Split the test-goal results by what tmp_env.is_first_dynamic reports per key.
            janus_default_dynamic_results = {}
            janus_inverted_dynamic_results = {}
            for k, results in eval_info_test.items():
                # NOTE(review): called on the gym.make() wrapper rather than
                # tmp_env.unwrapped (as other env helpers here are); this relies on
                # wrapper attribute forwarding — confirm for the installed gymnasium.
                if tmp_env.is_first_dynamic(k):
                    janus_default_dynamic_results[k] = results
                else:
                    janus_inverted_dynamic_results[k] = results
            test_aucs_def = [get_auc(h, lower=auc_lower) / auc_norm_coef for h in janus_default_dynamic_results.values()]
            test_aucs_inv = [get_auc(h, lower=auc_lower) / auc_norm_coef for h in janus_inverted_dynamic_results.values()]

            wandb.log({
                "default/test_mean_return_janus": np.mean(_returns_at(janus_default_dynamic_results, -1)),
                "default/test_mean_return_half_janus": np.mean(_returns_at(janus_default_dynamic_results, half_idx)),
                "default/test_mean_return_quarter_janus": np.mean(
                    _returns_at(janus_default_dynamic_results, quarter_idx)),
                "default/test_auc_janus": np.mean(test_aucs_def),
                "inverted/test_mean_return_janus": np.mean(_returns_at(janus_inverted_dynamic_results, -1)),
                "inverted/test_mean_return_half_janus": np.mean(
                    _returns_at(janus_inverted_dynamic_results, half_idx)),
                "inverted/test_mean_return_quarter_janus": np.mean(
                    _returns_at(janus_inverted_dynamic_results, quarter_idx)),
                "inverted/test_auc_janus": np.mean(test_aucs_inv),
            }, step=global_step
            )

        if config.checkpoints_path is not None:
            torch.save(
                model.state_dict(),
                os.path.join(config.checkpoints_path, f"model_{global_step}.pt"),
            )
        model.train()

    if config.checkpoints_path is not None:
        torch.save(
            model.state_dict(), os.path.join(config.checkpoints_path, "model_last.pt")
        )
    tmp_env.close()


if __name__ == "__main__":
    # `train` is decorated with `pyrallis.wrap()`, which builds the TrainConfig
    # argument itself (e.g. from command-line arguments) when called bare.
    train()