import os
import uuid
from dataclasses import asdict, dataclass
from typing import Optional, Tuple

import gymnasium as gym
import itertools
from collections import defaultdict
from gymnasium.vector import SyncVectorEnv
import numpy as np
import pyrallis
import torch
from torch.nn import functional as F  # noqa
from torch.utils.data import DataLoader

import wandb

import src.envs
from src.envs.ant_dir import train_test_goals_ant
from src.envs.dark_room import train_test_goals_dr
from src.envs.gridworld import train_test_goals_gw
from src.envs.half_cheetah_vel import train_test_goals_hcv
from src.envs.hopper_params import train_test_goals_hopp
from src.envs.semi_circle import train_test_goals_sc
from src.envs.walker_params import train_test_goals_walkp
from src.model_tuples_cache import ADTuples
from src.utils.data import TuplesMapDataset
from src.utils.misc import set_seed, Timeit
from src.utils.schedule import cosine_annealing_with_warmup
from src.envs.dark_key_to_door import train_test_goals
from src.utils.visualization import per_episode_in_context, plot_goal_pred

# Device on which the model and all tensors in this file are placed: the GPU when CUDA is available, else the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@dataclass
class TrainConfig:
    """Options for a training run: logging, model, optimisation, evaluation and paths.

    On construction a random 8-character suffix is appended to ``name`` and,
    when ``checkpoints_path`` is set, the suffixed name is appended to it as a
    sub-directory.
    """
    # wandb params
    project: str = "AD"
    group: str = "AD-Tuples-K2D"
    name: str = "ad-tuples-k2d"
    # model params
    hidden_dim: int = 512
    num_layers: int = 4
    num_heads: int = 4
    seq_len: int = 200
    attention_dropout: float = 0.5
    residual_dropout: float = 0.1
    embedding_dropout: float = 0.3
    normalize_qk: bool = False
    pre_norm: bool = True
    # training params
    env_name: str = "Dark-Key2Door-9x9-v0"
    learning_rate: float = 3e-4
    warmup_ratio: float = 0.05
    betas: Tuple[float, float] = (0.9, 0.99)
    weight_decay: float = 0.0
    clip_grad: Optional[float] = 1.0
    subsample: int = 1
    batch_size: int = 128
    update_epochs: int = 1
    num_workers: int = 0
    label_smoothing: float = 0.0
    # Data order
    data_order: str = "default"
    # Goal objective
    goal_weight: float = 0.0
    goal_norm_coef: float = 1.0
    goal_noise_sigma: float = 0
    noise_decay: float = 0.999
    nonlinear_action: bool = False
    # UMAP logging
    log_umap: bool = False
    # evaluation params
    eval_every: int = 25_000
    eval_episodes: int = 200
    eval_train_goals: int = 10
    eval_test_goals: int = 50
    eval_h: int = 500
    # general params
    learning_histories_path: str = "trajectories"
    checkpoints_path: Optional[str] = None
    train_seed: int = 42
    data_seed: int = 0
    eval_seed: int = 42

    def __post_init__(self):
        """Check the per-head width and make the run name (and checkpoint dir) unique."""
        head_dim = self.hidden_dim / self.num_heads
        assert head_dim % 8 == 0, "head dim should be multiple of 8 for flash attn"

        # a short random suffix keeps runs sharing the same base name apart
        unique_suffix = str(uuid.uuid4())[:8]
        self.name = f"{self.name}-{unique_suffix}"
        if self.checkpoints_path is None:
            return
        # every run writes into its own sub-directory named after the run
        self.checkpoints_path = os.path.join(self.checkpoints_path, self.name)


@torch.no_grad()
def evaluate_in_context(env_name, model, goals, eval_episodes, seed=None, cont_s=False, cont_a=False):
    """Roll ``model`` out in-context on one environment per goal and log episode returns.

    All environments are stepped together. At every step the model receives a
    rolling window of at most ``model.seq_len`` (state, previous action,
    previous reward) tuples; the window is never cleared between episodes.

    Args:
        env_name: environment id passed to ``gym.make``.
        model: called as ``model(states=..., prev_actions=..., prev_rewards=...,
            return_repr=False)``; the first element of its result is indexed
            with ``[:, -1]`` to obtain the output for the newest step.
        goals: sequence of goals, each forwarded to the environment as ``goal_pos``.
        eval_episodes: number of episodes whose return is logged for each goal.
        seed: seed forwarded to the vector environment's ``reset``.
        cont_s: if True states are buffered as float vectors, otherwise as integer ids.
        cont_a: if True the model output is used directly as the action,
            otherwise the mode of a categorical distribution over it is taken.

    Returns:
        ``defaultdict(list)`` mapping a goal key (built with the environment's
        ``pos_to_state``) to the returns of the first ``eval_episodes`` episodes
        played for that goal.
    """
    vec_env = SyncVectorEnv([lambda goal=goal: gym.make(env_name, goal_pos=goal) for goal in goals])
    tmp_env = None
    try:
        # extra single environment, used only to translate goals into logging keys
        tmp_env = gym.make(env_name, goal_pos=goals[0])

        def goal_key(goal):
            to_state = tmp_env.unwrapped.pos_to_state
            # goals with four entries are logged as a pair of states,
            # except for "Params" envs whose goals are never inspected by shape
            if "Params" not in env_name and np.prod(goal.shape) == 4:
                return to_state(goal[0]), to_state(goal[1])
            return to_state(goal)

        num_envs = vec_env.num_envs
        seq_len = model.seq_len

        if cont_s:
            states = torch.zeros(
                (seq_len, num_envs, vec_env.observation_space.shape[1]), dtype=torch.float32, device=DEVICE
            )
        else:
            states = torch.zeros((seq_len, num_envs), dtype=torch.long, device=DEVICE)
        if cont_a:
            prev_actions = torch.zeros(
                (seq_len, num_envs, vec_env.action_space.shape[1]), dtype=torch.float32, device=DEVICE
            )
        else:
            prev_actions = torch.zeros((seq_len, num_envs), dtype=torch.long, device=DEVICE)
        prev_rewards = torch.zeros((seq_len, num_envs), dtype=torch.float32, device=DEVICE)

        # to track number of episodes for each goal and returns
        num_episodes = np.zeros(num_envs)
        returns = np.zeros(num_envs)
        # for logging
        eval_info = defaultdict(list)

        state, _ = vec_env.reset(seed=seed)
        if cont_a:
            prev_action = torch.zeros(
                (num_envs, vec_env.action_space.shape[1]), dtype=torch.float32, device=DEVICE
            )
        else:
            prev_action = 0
        prev_reward = 0
        for step in itertools.count(start=1):
            # roll context back for new step
            states = states.roll(-1, dims=0)
            prev_actions = prev_actions.roll(-1, dims=0)
            prev_rewards = prev_rewards.roll(-1, dims=0)

            # fill last s-a-r tuple
            states[-1] = torch.as_tensor(state, device=DEVICE)
            prev_actions[-1] = torch.as_tensor(prev_action, device=DEVICE)
            prev_rewards[-1] = torch.as_tensor(prev_reward, device=DEVICE)

            # predict next action; only the last `step` slots hold real data
            # until the buffer has been filled once
            with torch.cuda.amp.autocast():
                # [num_envs, seq_len, num_actions] -> [num_envs, num_actions]
                logits, _ = model(
                    states=states[-step:].permute(1, 0, 2) if cont_s else states[-step:].permute(1, 0),
                    prev_actions=(
                        prev_actions[-step:].permute(1, 0, 2) if cont_a else prev_actions[-step:].permute(1, 0)
                    ),
                    prev_rewards=prev_rewards[-step:].permute(1, 0),
                    return_repr=False,
                )
                logits = logits[:, -1]

            if cont_a:
                action = logits
            else:
                # greedy action (no sampling)
                action = torch.distributions.Categorical(logits=logits).mode

            # query the world
            state, reward, terminated, truncated, _ = vec_env.step(action.cpu().numpy())
            done = terminated | truncated

            # relabel for the next step
            prev_action = action
            prev_reward = reward

            num_episodes += done.astype(int)
            returns += reward

            # log returns if done
            for i, d in enumerate(done):
                if d and num_episodes[i] <= eval_episodes:
                    eval_info[goal_key(goals[i])].append(returns[i])
                    # reset return for this goal
                    returns[i] = 0.0

            # stop as soon as every goal has logged `eval_episodes` episodes;
            # any further episode would be played but never recorded
            if np.all(num_episodes >= eval_episodes):
                break

        return eval_info
    finally:
        # release the environments even if the rollout raised
        vec_env.close()
        if tmp_env is not None:
            tmp_env.close()

@torch.no_grad()
def evaluate_in_context_steps(env_name, model, goals, max_steps, seed=None):
    """Roll ``model`` out for a fixed number of steps and log the total reward per goal.

    One environment is created per goal and all are stepped together for
    ``max_steps`` steps. Rewards are summed over the whole rollout, across
    episode boundaries. States and actions are buffered as integer ids.

    Args:
        env_name: environment id passed to ``gym.make``.
        model: called as ``model(states=..., prev_actions=..., prev_rewards=...)``;
            the first element of its result is indexed with ``[:, -1]``.
        goals: sequence of goals, each forwarded to the environment as ``goal_pos``.
        max_steps: number of environment steps to perform.
        seed: seed forwarded to the vector environment's ``reset``.

    Returns:
        ``defaultdict(list)`` mapping a goal key (built with the environment's
        ``pos_to_state``) to a list holding the summed reward for that goal.
    """
    vec_env = SyncVectorEnv([lambda goal=goal: gym.make(env_name, goal_pos=goal) for goal in goals])
    tmp_env = None
    try:
        # extra single environment, used only to translate goals into logging keys
        tmp_env = gym.make(env_name, goal_pos=goals[0])

        states = torch.zeros((model.seq_len, vec_env.num_envs), dtype=torch.long, device=DEVICE)
        prev_actions = torch.zeros((model.seq_len, vec_env.num_envs), dtype=torch.long, device=DEVICE)
        prev_rewards = torch.zeros((model.seq_len, vec_env.num_envs), dtype=torch.float32, device=DEVICE)

        # reward accumulated by each environment over the whole rollout
        returns = np.zeros(vec_env.num_envs)
        # for logging
        eval_info = defaultdict(list)

        state, _ = vec_env.reset(seed=seed)
        prev_action, prev_reward = 0, 0
        for step in itertools.count(start=1):
            if max_steps < step:
                break
            # roll context back for new step
            states = states.roll(-1, dims=0)
            prev_actions = prev_actions.roll(-1, dims=0)
            prev_rewards = prev_rewards.roll(-1, dims=0)

            # fill last s-a-r tuple
            states[-1] = torch.as_tensor(state, device=DEVICE)
            prev_actions[-1] = torch.as_tensor(prev_action, device=DEVICE)
            prev_rewards[-1] = torch.as_tensor(prev_reward, device=DEVICE)

            # predict next action; only the last `step` slots hold real data
            # until the buffer has been filled once
            with torch.cuda.amp.autocast():
                # [num_envs, seq_len, num_actions] -> [num_envs, num_actions]
                logits, _ = model(
                    states=states[-step:].permute(1, 0),
                    prev_actions=prev_actions[-step:].permute(1, 0),
                    prev_rewards=prev_rewards[-step:].permute(1, 0),
                )
                logits = logits[:, -1]
            # greedy action (no sampling)
            action = torch.distributions.Categorical(logits=logits).mode

            # query the world; episode ends are irrelevant here because the
            # reward is summed over the full step budget
            state, reward, _, _, _ = vec_env.step(action.cpu().numpy())

            # relabel for the next step
            prev_action = action
            prev_reward = reward

            returns += reward

        # one entry per goal: the reward summed over all performed steps
        for i in range(vec_env.num_envs):
            goal = goals[i]
            if np.prod(goal.shape) == 4:
                log_key = (tmp_env.unwrapped.pos_to_state(goal[0]), tmp_env.unwrapped.pos_to_state(goal[1]))
            else:
                log_key = tmp_env.unwrapped.pos_to_state(goal)
            eval_info[log_key].append(returns[i])

        return eval_info
    finally:
        # release the environments even if the rollout raised
        vec_env.close()
        if tmp_env is not None:
            tmp_env.close()


@torch.no_grad()
def evaluate_in_context_with_cache(env_name, model, goals, eval_episodes, seed=None):
    """Roll ``model`` out in-context using its key/value cache and log episode returns.

    One environment is created per goal and all are stepped together. Each
    model call receives only the newest (state, previous action, previous
    reward) tuple plus the cache returned by the previous call.

    Args:
        env_name: environment id passed to ``gym.make``.
        model: must provide ``init_cache(batch_size=..., dtype=..., device=...)``
            and be callable as ``model(states=..., prev_actions=...,
            prev_rewards=..., cache=...)``, returning ``(logits, cache)``.
        goals: sequence of goals, each forwarded to the environment as ``goal_pos``.
        eval_episodes: number of episodes whose return is logged for each goal.
        seed: seed forwarded to the vector environment's ``reset``.

    Returns:
        ``defaultdict(list)`` mapping a goal key (built with the environment's
        ``pos_to_state``) to the returns of the first ``eval_episodes`` episodes
        played for that goal.
    """
    vec_env = SyncVectorEnv([lambda goal=goal: gym.make(env_name, goal_pos=goal) for goal in goals])
    tmp_env = None
    try:
        # extra single environment, used only to translate goals into logging keys
        tmp_env = gym.make(env_name, goal_pos=goals[0])

        kv_cache = model.init_cache(batch_size=vec_env.num_envs, dtype=torch.float16, device=DEVICE)
        # to track number of episodes for each goal and returns
        num_episodes = np.zeros(vec_env.num_envs)
        returns = np.zeros(vec_env.num_envs)
        # for logging
        eval_info = defaultdict(list)

        state, _ = vec_env.reset(seed=seed)
        prev_action, prev_reward = np.zeros(vec_env.num_envs), np.zeros(vec_env.num_envs)
        while True:
            # predict next action
            with torch.cuda.amp.autocast():
                # [num_envs, seq_len=1, num_actions] -> [num_envs, num_actions]
                logits, kv_cache = model(
                    states=torch.as_tensor(state, dtype=torch.long, device=DEVICE)[:, None],
                    prev_actions=torch.as_tensor(prev_action, dtype=torch.long, device=DEVICE)[:, None],
                    prev_rewards=torch.as_tensor(prev_reward, dtype=torch.float, device=DEVICE)[:, None],
                    cache=kv_cache,
                )
                logits = logits[:, -1]
            # greedy action (no sampling)
            action = torch.distributions.Categorical(logits=logits).mode

            # query the world
            state, reward, terminated, truncated, _ = vec_env.step(action.cpu().numpy())
            done = terminated | truncated

            # relabel for the next step
            prev_action = action
            prev_reward = reward

            num_episodes += done.astype(int)
            returns += reward

            # log returns if done
            for i, d in enumerate(done):
                if d and num_episodes[i] <= eval_episodes:
                    goal = goals[i]
                    if np.prod(goal.shape) == 4:
                        log_key = (
                            tmp_env.unwrapped.pos_to_state(goal[0]),
                            tmp_env.unwrapped.pos_to_state(goal[1]),
                        )
                    else:
                        log_key = tmp_env.unwrapped.pos_to_state(goal)
                    eval_info[log_key].append(returns[i])
                    # reset return for this goal
                    returns[i] = 0.0

            # stop as soon as every goal has logged `eval_episodes` episodes;
            # any further episode would be played but never recorded
            if np.all(num_episodes >= eval_episodes):
                break

        return eval_info
    finally:
        # release the environments even if the rollout raised
        vec_env.close()
        if tmp_env is not None:
            tmp_env.close()


@torch.no_grad()
def evaluate_in_context_with_cache_steps(env_name, model, goals, max_steps, seed=None):
    """Roll ``model`` out with its key/value cache for a fixed number of steps.

    One environment is created per goal and all are stepped together for
    ``max_steps`` steps. Rewards are summed over the whole rollout, across
    episode boundaries. Each model call receives only the newest tuple plus
    the cache returned by the previous call.

    Args:
        env_name: environment id passed to ``gym.make``.
        model: must provide ``init_cache(batch_size=..., dtype=..., device=...)``
            and be callable as ``model(states=..., prev_actions=...,
            prev_rewards=..., cache=...)``, returning ``(logits, cache)``.
        goals: sequence of goals, each forwarded to the environment as ``goal_pos``.
        max_steps: number of environment steps to perform.
        seed: seed forwarded to the vector environment's ``reset``.

    Returns:
        ``defaultdict(list)`` mapping a goal key (built with the environment's
        ``pos_to_state``) to a list holding the summed reward for that goal.
    """
    vec_env = SyncVectorEnv([lambda goal=goal: gym.make(env_name, goal_pos=goal) for goal in goals])
    tmp_env = None
    try:
        # extra single environment, used only to translate goals into logging keys
        tmp_env = gym.make(env_name, goal_pos=goals[0])

        kv_cache = model.init_cache(batch_size=vec_env.num_envs, dtype=torch.float16, device=DEVICE)
        # reward accumulated by each environment over the whole rollout
        returns = np.zeros(vec_env.num_envs)
        # for logging
        eval_info = defaultdict(list)

        state, _ = vec_env.reset(seed=seed)
        prev_action, prev_reward = np.zeros(vec_env.num_envs), np.zeros(vec_env.num_envs)
        for step in itertools.count(start=1):
            if max_steps < step:
                break
            # predict next action
            with torch.cuda.amp.autocast():
                # [num_envs, seq_len=1, num_actions] -> [num_envs, num_actions]
                logits, kv_cache = model(
                    states=torch.as_tensor(state, dtype=torch.long, device=DEVICE)[:, None],
                    prev_actions=torch.as_tensor(prev_action, dtype=torch.long, device=DEVICE)[:, None],
                    prev_rewards=torch.as_tensor(prev_reward, dtype=torch.float, device=DEVICE)[:, None],
                    cache=kv_cache,
                )
                logits = logits[:, -1]
            # greedy action (no sampling)
            action = torch.distributions.Categorical(logits=logits).mode

            # query the world; episode ends are irrelevant here because the
            # reward is summed over the full step budget
            state, reward, _, _, _ = vec_env.step(action.cpu().numpy())

            # relabel for the next step
            prev_action = action
            prev_reward = reward

            returns += reward

        # one entry per goal: the reward summed over all performed steps
        for i in range(vec_env.num_envs):
            goal = goals[i]
            if np.prod(goal.shape) == 4:
                log_key = (tmp_env.unwrapped.pos_to_state(goal[0]), tmp_env.unwrapped.pos_to_state(goal[1]))
            else:
                log_key = tmp_env.unwrapped.pos_to_state(goal)
            eval_info[log_key].append(returns[i])

        return eval_info
    finally:
        # release the environments even if the rollout raised
        vec_env.close()
        if tmp_env is not None:
            tmp_env.close()


def split_info_debug(eval_info, train_goals, test_goals, env, transform_keys=True, transform_dict=False):
    """Split evaluation results into a train-goal and a test-goal dictionary.

    Args:
        eval_info: mapping from goal key to the list of values logged for it.
        train_goals: train goals (list or ``np.ndarray``; arrays are converted
            with ``tolist``).
        test_goals: test goals, same conventions as ``train_goals``.
        env: environment whose ``unwrapped.state_to_pos`` / ``unwrapped.pos_to_state``
            are used for the conversions; not touched when both flags are False.
        transform_keys: if True every key is mapped back to a position with
            ``state_to_pos``. Integer keys (Python or NumPy) are treated as a
            single state, any other key as a pair ``(k[0], k[1])`` of states.
        transform_dict: if True both goal lists are first mapped through
            ``pos_to_state``; with ``transform_keys=False`` the keys are then
            compared to the mapped goals unchanged.

    Returns:
        ``(eval_info_train, eval_info_test)``: two ``defaultdict(list)`` keyed by
        ``str(goal)``.

    Raises:
        ValueError: if a key matches neither a train nor a test goal.
    """
    if isinstance(train_goals, np.ndarray):
        train_goals = train_goals.tolist()
    if isinstance(test_goals, np.ndarray):
        test_goals = test_goals.tolist()
    if transform_dict:
        train_goals = list(map(env.unwrapped.pos_to_state, train_goals))
        test_goals = list(map(env.unwrapped.pos_to_state, test_goals))

    eval_info_train = defaultdict(list)
    eval_info_test = defaultdict(list)
    for key, values in eval_info.items():
        if transform_keys:
            # NumPy integer scalars are single states too; indexing them as a
            # pair would fail
            if isinstance(key, (int, np.integer)):
                curr_goal = env.unwrapped.state_to_pos(key).tolist()
            else:
                curr_goal = [
                    env.unwrapped.state_to_pos(key[0]).tolist(),
                    env.unwrapped.state_to_pos(key[1]).tolist(),
                ]
        elif transform_dict:
            # goals were mapped into key space above, compare the key as is
            curr_goal = key
        else:
            curr_goal = list(key)

        if curr_goal in train_goals:
            eval_info_train[str(curr_goal)] = values
        elif curr_goal in test_goals:
            eval_info_test[str(curr_goal)] = values
        else:
            raise ValueError(
                f"goal {curr_goal!r} (from key {key!r}) is neither a train nor a test goal"
            )

    return eval_info_train, eval_info_test


def process_steps_eval(eval_info):
    """Return the mean and standard deviation of the first value logged for each goal.

    Args:
        eval_info: mapping from goal key to a non-empty sequence of values.

    Returns:
        ``(mean, std)`` computed with NumPy over the first entries.
    """
    first_entries = []
    for logged in eval_info.values():
        first_entries.append(logged[0])
    mean_return = np.mean(first_entries)
    std_return = np.std(first_entries)
    return mean_return, std_return


def get_auc(points, lower=0):
    """Return the area under the piecewise-linear curve through ``points``.

    Consecutive points are taken to be one unit apart and every value is
    shifted down by ``lower`` first. Fewer than two points give ``0``.
    """
    shifted = [value - lower for value in points]
    area = 0
    for idx in range(1, len(shifted)):
        left, right = shifted[idx - 1], shifted[idx]
        # rectangle up to the smaller endpoint plus the triangle on top of it
        area += min(left, right) + abs(left - right) / 2
    return area


@pyrallis.wrap()
def train(config: TrainConfig):
    dict_config = asdict(config)
    dict_config["mlc_job"] = os.getenv("PLATFORM_JOB_NAME")
    wandb.init(
        project=config.project,
        group=config.group,
        name=config.name,
        config=dict_config,
    )

    key_shape = 2
    transform_dicts = False
    if "Key" in config.env_name:
        key_shape = 4
    elif "HalfCheetahVel-v0" == config.env_name:
        key_shape = 1
    elif "AntDir-v0" == config.env_name:
        key_shape = 1
    elif "HopperParams-v0" == config.env_name:
        key_shape = None
        transform_dicts = True
    elif "Walker2dParams-v0" == config.env_name:
        key_shape = None
        transform_dicts = True


    random_order = False
    sorted_order = False
    sample_ordered = False
    if config.data_order == 'random':
        random_order = True
    elif config.data_order == 'sorted':
        sorted_order = True
    elif config.data_order == 'sample':
        sample_ordered = True

    set_seed(config.train_seed)
    dataset = TuplesMapDataset(
        data_path=config.learning_histories_path,
        seq_len=config.seq_len,
        subsample=config.subsample,
        goal_dim=key_shape,
        random_order=random_order,
        sorted_order=sorted_order,
        sample_ordered=sample_ordered,
    )
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=config.batch_size,
        pin_memory=True,
        shuffle=True,
        num_workers=config.num_workers,
    )
    tmp_env = gym.make(config.env_name)

    # eval preparation
    test_goals_ood = None
    if "Key" in config.env_name:
        train_goals, test_goals = train_test_goals(
            grid_size=gym.make(config.env_name).unwrapped.size,
            num_train_goals=len(dataset.unique_goals),
            num_test_goals=config.eval_test_goals,
            seed=config.data_seed
        )
    elif "Room" in config.env_name or "Janus" in config.env_name:
        train_goals, test_goals = train_test_goals_dr(
            grid_size=gym.make(config.env_name).unwrapped.size,
            num_train_goals=len(dataset.unique_goals),
            seed=config.data_seed
        )
    elif "Grid" in config.env_name:
        train_goals, test_goals = train_test_goals_gw(
            grid_size=gym.make(config.env_name).unwrapped.size,
            num_train_goals=len(dataset.unique_goals),
            seed=config.data_seed
        )
    elif "Semi-Circle" in config.env_name:
        train_goals, test_goals = train_test_goals_sc(
            num_test_goals=config.eval_test_goals,
            seed=config.data_seed
        )
    elif "HalfCheetahVel-v0" in config.env_name:
        train_goals, test_goals, test_goals_ood = train_test_goals_hcv(
            num_test_goals=config.eval_test_goals,
            seed=config.data_seed
        )
    elif "AntDir-v0" in config.env_name:
        train_goals, test_goals = train_test_goals_ant(
            num_test_goals=config.eval_test_goals,
            seed=config.data_seed
        )
    elif "HopperParams-v0" in config.env_name:
        train_goals, test_goals = train_test_goals_hopp(
            num_test_goals=config.eval_test_goals,
            seed=config.data_seed
        )
    elif "Walker2dParams-v0" in config.env_name:
        train_goals, test_goals = train_test_goals_walkp(
            num_test_goals=config.eval_test_goals,
            seed=config.data_seed
        )

    train_goals = train_goals[:config.eval_train_goals]
    if "Params" not in config.env_name:
        eval_all_goals = np.vstack([train_goals, test_goals])
    else:
        eval_all_goals = train_goals + test_goals

    continuous_states = config.env_name in ["Semi-Circle-Sparse-v0", "Semi-Circle-v0", "HalfCheetahVel-v0", "AntDir-v0",
                                            "HopperParams-v0", "Walker2dParams-v0"]
    continuous_actions = config.env_name in ["Semi-Circle-Sparse-v0", "Semi-Circle-v0", "HalfCheetahVel-v0",
                                             "AntDir-v0", "HopperParams-v0", "Walker2dParams-v0"]
    transform_keys = config.env_name not in ["Semi-Circle-Sparse-v0", "Semi-Circle-v0", "HalfCheetahVel-v0",
                                             "AntDir-v0", "HopperParams-v0", "Walker2dParams-v0"]

    # model & optimizer & scheduler setup
    set_seed(config.train_seed)

    model = ADTuples(
        num_states=tmp_env.observation_space.shape[0] if continuous_states else tmp_env.observation_space.n,
        num_actions=tmp_env.action_space.shape[0] if continuous_actions else tmp_env.action_space.n,
        num_params=key_shape,
        hidden_dim=config.hidden_dim,
        seq_len=config.seq_len,
        num_layers=config.num_layers,
        num_heads=config.num_heads,
        attention_dropout=config.attention_dropout,
        residual_dropout=config.residual_dropout,
        embedding_dropout=config.embedding_dropout,
        normalize_qk=config.normalize_qk,
        pre_norm=config.pre_norm,
        continuous_states=continuous_states,
        continuous_actions=continuous_actions,
        nonlinear_action_head=config.nonlinear_action,
    ).to(DEVICE)

    # if needed, test beforehand
    # model = torch.compile(model)

    optim = torch.optim.Adam(
        params=model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
        betas=config.betas,
    )

    total_updates = len(dataloader) * config.update_epochs
    # scheduler = cosine_annealing_with_warmup(
    #     optimizer=optim,
    #     warmup_steps=int(total_updates * config.warmup_ratio),
    #     total_steps=total_updates,
    # )

    # save config to the checkpoint
    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)
        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
            pyrallis.dump(config, f)

    scaler = torch.cuda.amp.GradScaler()
    print(f"Parameters: {sum(p.numel() for p in model.parameters())}")
    global_step = 0
    # for i in range(10):
    #     print(dataset[i])
    # return
    set_seed(config.train_seed)
    for epoch in range(config.update_epochs):
        for batch in dataloader:
            states, prev_actions, prev_rewards, prev_dones, target_actions, rewards, dones, steps = [b.to(DEVICE) for b in batch]
            # print(goals)

            states = states.to(torch.long) if not continuous_states else states.to(torch.float)
            # print(states)
            prev_actions = prev_actions.to(torch.long) if not continuous_actions else prev_actions.to(torch.float)
            prev_rewards = prev_rewards.to(torch.float)
            prev_dones = prev_dones.to(torch.float)
            target_actions = target_actions.to(torch.long) if not continuous_actions else target_actions.to(torch.float)
            rewards = rewards.to(torch.float)
            dones = dones.to(torch.float)
            steps = steps.to(torch.float)
            # global_steps = global_steps.to(torch.float)
            # goals = goals.to(torch.float) / config.goal_norm_coef
            # print(goals.shape)
            # goals += torch.randn_like(goals) * config.goal_noise_sigma * config.noise_decay ** global_steps.unsqueeze(-1).repeat(1, 1, key_shape)

            with Timeit() as timer:
                with torch.cuda.amp.autocast():
                    predicted_actions, _ = model(
                        states=states, prev_actions=prev_actions, prev_rewards=prev_rewards, return_repr=False
                    )

                    # contexts = embeddings[:, -1]
                    # normalized_vectors = F.normalize(contexts, p=2, dim=1)
                    # cosine_similarity_matrix = normalized_vectors @ normalized_vectors.T - torch.eye(
                    #     normalized_vectors.shape[0]).to(DEVICE)

                    # goal_loss = F.mse_loss(
                    #     input=predicted_params.flatten(0, 1),
                    #     target=goals.flatten(0, 1),
                    # )
                    if continuous_actions:
                        action_loss = F.mse_loss(
                            input=predicted_actions.flatten(0, 1),
                            target=target_actions.flatten(0, 1),
                        )
                    else:
                        action_loss = F.cross_entropy(
                            input=predicted_actions.flatten(0, 1),
                            target=target_actions.flatten(0, 1),
                            label_smoothing=config.label_smoothing,
                        )
                loss = action_loss #+ config.goal_weight * goal_loss
                scaler.scale(loss).backward()
                if config.clip_grad is not None:
                    scaler.unscale_(optim)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip_grad)
                scaler.step(optim)
                scaler.update()
                optim.zero_grad(set_to_none=True)
                # scheduler.step()
            log_dict = {
                    # "goal_loss": goal_loss.item(),
                    "action_loss": action_loss.item(),
                    "loss": loss.item(),
                    "udpate-time": timer.elapsed_time_gpu,
                    "epoch": epoch,
                    # "cosine_sim_mean": cosine_similarity_matrix.mean(),
                }
            with torch.no_grad():
                if not continuous_actions:
                    a = torch.argmax(predicted_actions.flatten(0, 1), dim=-1)
                    t = target_actions.flatten()
                    accuracy = torch.sum(a == t) / a.shape[0]
                    log_dict['accuracy'] = accuracy
                wandb.log(log_dict, step=global_step)

            global_step += 1
        model.eval()
        with Timeit() as eval_timer:
            eval_info = evaluate_in_context(
                env_name=config.env_name,
                model=model,
                goals=eval_all_goals,
                eval_episodes=config.eval_episodes,
                seed=config.eval_seed,
                cont_s=continuous_states,
                cont_a=continuous_actions,
            )
            if test_goals_ood is not None:
                eval_info_ood = evaluate_in_context(
                    env_name=config.env_name,
                    model=model,
                    goals=test_goals_ood,
                    eval_episodes=config.eval_episodes,
                    seed=config.eval_seed,
                    cont_s=continuous_states,
                    cont_a=continuous_actions,
                )
            if "Janus" in config.env_name:
                eval_info_default = evaluate_in_context(
                    env_name="Janus-Default-19x19-v0",
                    model=model,
                    goals=eval_all_goals,
                    eval_episodes=config.eval_episodes,
                    seed=config.eval_seed,
                    cont_s=continuous_states,
                    cont_a=continuous_actions,
                )
                eval_info_inverted = evaluate_in_context(
                    env_name="Janus-Inverted-19x19-v0",
                    model=model,
                    goals=eval_all_goals,
                    eval_episodes=config.eval_episodes,
                    seed=config.eval_seed,
                    cont_s=continuous_states,
                    cont_a=continuous_actions,
                )
            # print(seq_embeddings)
            # raise ValueError()
        eval_info_train, eval_info_test = split_info_debug(eval_info, train_goals, test_goals, env=tmp_env, transform_keys=transform_keys, transform_dict=transform_dicts)


        # eval_steps_info = evaluate_in_context_steps(
        #     env_name=config.env_name,
        #     model=model,
        #     goals=eval_all_goals,
        #     max_steps=config.eval_h,
        #     seed=config.eval_seed,
        # )
        # eval_steps_info_train, eval_steps_info_test = split_info_debug(eval_steps_info, train_goals, test_goals,
        #                                                                env=tmp_env, transform_keys=transform_keys)
        # train_mean, train_std = process_steps_eval(eval_steps_info_train)
        # test_mean, test_std = process_steps_eval(eval_steps_info_test)

        # with Timeit() as cache_eval_timer:
        #     cache_eval_info = evaluate_in_context_with_cache(
        #         env_name=config.env_name,
        #         model=model,
        #         goals=eval_all_goals,
        #         eval_episodes=config.eval_episodes,
        #         seed=config.eval_seed,
        #     )
        # cache_eval_info_train, cache_eval_info_test = split_info_debug(cache_eval_info, train_goals, test_goals,
        #                                                                env=tmp_env, transform_keys=transform_keys)

        max_return = 1.0
        min_return = -0.3
        auc_lower = 0
        if "Key" in config.env_name:
            max_return = 2.0
        elif "Grid" in config.env_name:
            max_return = 14.0
        elif "Sparse" in config.env_name:
            max_return = 50
        elif "Semi" in config.env_name:
            min_return = -30
            max_return = -4.5
        elif "HalfCheetahVel-v0" == config.env_name:
            max_return = 0
            min_return = -500
            auc_lower = -410.5
        elif "AntDir-v0" == config.env_name:
            max_return = 1000
            min_return = 0
            auc_lower = -69.8
        elif "HopperParams-v0" == config.env_name:
            max_return = 500
            min_return = 0
            auc_lower = 17
        elif "Walker2dParams-v0" == config.env_name:
            max_return = 250
            min_return = 0
            auc_lower = 3.6

        if "HalfCheetahVel-v0" == config.env_name:
            max_return = -78.6
        if "AntDir-v0" == config.env_name:
            max_return = 773.8
        if "HopperParams-v0" == config.env_name:
            max_return = 220.0
        if "Walker2dParams-v0" == config.env_name:
            max_return = 220.0

        pic_name_train = per_episode_in_context(
            eval_info_train, ylim=[min_return, max_return + 0.5], name=f"train-viz", max_return=max_return,
        )
        pic_name_test = per_episode_in_context(
            eval_info_test, ylim=[min_return, max_return + 0.5], name=f"test-viz", max_return=max_return,
        )
        auc_norm_coef = get_auc([max_return] * config.eval_episodes, lower=auc_lower)
        train_aucs = [get_auc(h, lower=auc_lower) / auc_norm_coef for h in eval_info_train.values()]
        test_aucs = [get_auc(h, lower=auc_lower) / auc_norm_coef for h in eval_info_test.values()]
        wandb.log({
            "eval/train_mean_return": np.mean([h[-1] for h in eval_info_train.values()]),
            "eval/train_median_return": np.median([h[-1] for h in eval_info_train.values()]),
            "eval/test_mean_return": np.mean([h[-1] for h in eval_info_test.values()]),
            "eval/test_median_return": np.median([h[-1] for h in eval_info_test.values()]),
            "eval/train_mean_return_half": np.mean([h[config.eval_episodes//2] for h in eval_info_train.values()]),
            "eval/train_median_return_half": np.median([h[config.eval_episodes//2] for h in eval_info_train.values()]),
            "eval/test_mean_return_half": np.mean([h[config.eval_episodes//2] for h in eval_info_test.values()]),
            "eval/test_median_return_half": np.median([h[config.eval_episodes//2] for h in eval_info_test.values()]),
            "eval/train_mean_return_quarter": np.mean([h[config.eval_episodes//4] for h in eval_info_train.values()]),
            "eval/train_median_return_quarter": np.median([h[config.eval_episodes//4] for h in eval_info_train.values()]),
            "eval/test_mean_return_quarter": np.mean([h[config.eval_episodes//4] for h in eval_info_test.values()]),
            "eval/test_median_return_quarter": np.median([h[config.eval_episodes//4] for h in eval_info_test.values()]),
            "eval/train_auc": np.mean(train_aucs),
            "eval/test_auc": np.mean(test_aucs),
            "eval/train_graph": wandb.Image(pic_name_train),
            "eval/test_graph": wandb.Image(pic_name_test),
            "eval/eval-time": eval_timer.elapsed_time_gpu,
            "aggregated_metrics/train_avg_returns":
                (np.mean([h[-1] for h in eval_info_train.values()]) + np.mean(
                    [h[config.eval_episodes // 2] for h in eval_info_train.values()]) + np.mean(
                    [h[config.eval_episodes // 4] for h in eval_info_train.values()])) / 3,
            "aggregated_metrics/train_avg_all":
                (np.mean([h[-1] for h in eval_info_train.values()]) + np.mean(
                    [h[config.eval_episodes // 2] for h in eval_info_train.values()]) + np.mean(
                    [h[config.eval_episodes // 4] for h in eval_info_train.values()])) / 3 / max_return / 2 + np.mean(train_aucs) / 2,
            "aggregated_metrics/test_avg_returns":
                (np.mean([h[-1] for h in eval_info_test.values()]) + np.mean(
                    [h[config.eval_episodes // 2] for h in eval_info_test.values()]) + np.mean(
                    [h[config.eval_episodes // 4] for h in eval_info_test.values()])) / 3,
            "aggregated_metrics/test_avg_all":
                (np.mean([h[-1] for h in eval_info_test.values()]) + np.mean(
                    [h[config.eval_episodes // 2] for h in eval_info_test.values()]) + np.mean(
                    [h[config.eval_episodes // 4] for h in eval_info_test.values()])) / 3 / max_return / 2 + np.mean(
                    test_aucs) / 2,
            # with cache
            # "cache-eval/train_mean_return": np.mean([h[-1] for h in cache_eval_info_train.values()]),
            # "cache-eval/train_median_return": np.median([h[-1] for h in cache_eval_info_train.values()]),
            # "cache-eval/test_mean_return": np.mean([h[-1] for h in cache_eval_info_test.values()]),
            # "cache-eval/test_median_return": np.median([h[-1] for h in cache_eval_info_test.values()]),
            # "cache-eval/eval-time": cache_eval_timer.elapsed_time_gpu,
            "epoch": epoch,
            # "horizon_returns/train_mean": train_mean,
            # "horizon_returns/train_std": train_std,
            # "horizon_returns/test_mean": test_mean,
            # "horizon_returns/test_std": test_std,
        }, step=global_step
        )
        if test_goals_ood is not None:
            test_ood_aucs = [get_auc(h, lower=auc_lower) / auc_norm_coef for h in eval_info_ood.values()]
            wandb.log({
                "ood/test_mean_return": np.mean([h[-1] for h in eval_info_ood.values()]),
                "ood/test_mean_return_half": np.mean(
                    [h[config.eval_episodes // 2] for h in eval_info_ood.values()]),
                "ood/test_mean_return_quarter": np.mean(
                    [h[config.eval_episodes // 4] for h in eval_info_ood.values()]),
                "ood/test_auc": np.mean(test_ood_aucs),
            }, step=global_step
            )
        if "Janus" in config.env_name:
            _, eval_info_test_default = split_info_debug(eval_info_default, train_goals, test_goals, env=tmp_env,
                                                               transform_keys=transform_keys, transform_dict=transform_dicts)
            _, eval_info_test_inverted = split_info_debug(eval_info_inverted, train_goals, test_goals, env=tmp_env,
                                                          transform_keys=transform_keys, transform_dict=transform_dicts)
            test_aucs_default = [get_auc(h) / auc_norm_coef for h in eval_info_test_default.values()]
            test_aucs_inverted = [get_auc(h) / auc_norm_coef for h in eval_info_test_inverted.values()]
            wandb.log({
                "default/test_mean_return": np.mean([h[-1] for h in eval_info_test_default.values()]),
                "default/test_mean_return_half": np.mean([h[config.eval_episodes // 2] for h in eval_info_test_default.values()]),
                "default/test_mean_return_quarter": np.mean(
                    [h[config.eval_episodes // 4] for h in eval_info_test_default.values()]),
                "default/test_auc": np.mean(test_aucs_default),
                "inverted/test_mean_return": np.mean([h[-1] for h in eval_info_test_inverted.values()]),
                "inverted/test_mean_return_half": np.mean(
                    [h[config.eval_episodes // 2] for h in eval_info_test_inverted.values()]),
                "inverted/test_mean_return_quarter": np.mean(
                    [h[config.eval_episodes // 4] for h in eval_info_test_inverted.values()]),
                "inverted/test_auc": np.mean(test_aucs_inverted),
            }, step=global_step
            )

            janus_default_dynamic_results = {}
            janus_inverted_dynamic_results = {}
            for k in eval_info_test:
                if tmp_env.is_first_dynamic(k):
                    janus_default_dynamic_results[k] = eval_info_test[k]
                else:
                    janus_inverted_dynamic_results[k] = eval_info_test[k]
            test_aucs_def = [get_auc(h) / auc_norm_coef for h in janus_default_dynamic_results.values()]
            test_aucs_inv = [get_auc(h) / auc_norm_coef for h in janus_inverted_dynamic_results.values()]

            wandb.log({
                "default/test_mean_return_janus": np.mean([h[-1] for h in janus_default_dynamic_results.values()]),
                "default/test_mean_return_half_janus": np.mean(
                    [h[config.eval_episodes // 2] for h in janus_default_dynamic_results.values()]),
                "default/test_mean_return_quarter_janus": np.mean(
                    [h[config.eval_episodes // 4] for h in janus_default_dynamic_results.values()]),
                "default/test_auc_janus": np.mean(test_aucs_def),
                "inverted/test_mean_return_janus": np.mean([h[-1] for h in janus_inverted_dynamic_results.values()]),
                "inverted/test_mean_return_half_janus": np.mean(
                    [h[config.eval_episodes // 2] for h in janus_inverted_dynamic_results.values()]),
                "inverted/test_mean_return_quarter_janus": np.mean(
                    [h[config.eval_episodes // 4] for h in janus_inverted_dynamic_results.values()]),
                "inverted/test_auc_janus": np.mean(test_aucs_inv),
            }, step=global_step
            )


        if config.checkpoints_path is not None:
            torch.save(
                model.state_dict(),
                os.path.join(config.checkpoints_path, f"model_{global_step}.pt"),
            )
        model.train()

    if config.checkpoints_path is not None:
        torch.save(
            model.state_dict(), os.path.join(config.checkpoints_path, f"model_last.pt")
        )


# Script entry point: start training only when this file is executed directly,
# so importing the module does not launch a run.
# NOTE(review): train() is called with no arguments — presumably a decorator on
# its definition (not visible in this chunk) builds the config; confirm.
if __name__ == "__main__":
    train()