import datetime
import collections
import gc
import io
import os
import json
import pathlib
import re
import time
import random

import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
from torch import distributions as torchd
from torch.utils.tensorboard import SummaryWriter
from concurrent.futures import ThreadPoolExecutor, as_completed #[todo]
from VLM.sbert import SentenceBert #[todo]
from pathlib import Path #[todo]
import gc #[todo]


def to_np(x):
    """Detach tensor ``x`` from the autograd graph, move it to CPU and return a NumPy array."""
    return x.detach().cpu().numpy()


def symlog(x):
    """Symmetric log transform: ``sign(x) * log(|x| + 1)``.

    Compresses large magnitudes while keeping the sign; the inverse is ``symexp``.
    """
    shifted_magnitude = torch.abs(x) + 1.0
    return torch.sign(x) * torch.log(shifted_magnitude)


def symexp(x):
    """Inverse of ``symlog``: ``sign(x) * (exp(|x|) - 1)``."""
    grown_magnitude = torch.exp(torch.abs(x)) - 1.0
    return torch.sign(x) * grown_magnitude


class RequiresGrad:
    """Context manager that turns gradients on for ``model`` inside the ``with`` body.

    On exit the model's parameters are always set to ``requires_grad=False``
    (the previous state is not restored).
    """

    def __init__(self, model):
        self._model = model

    def _set_requires_grad(self, flag):
        self._model.requires_grad_(requires_grad=flag)

    def __enter__(self):
        self._set_requires_grad(True)

    def __exit__(self, *args):
        self._set_requires_grad(False)


class TimeRecording:
    """Context manager that prints ``comment`` and the CUDA-event-timed duration in seconds.

    Requires a CUDA device: timing uses ``torch.cuda.Event`` and a full
    ``torch.cuda.synchronize()`` on exit.
    """

    def __init__(self, comment):
        self._comment = comment

    def __enter__(self):
        self._st, self._nd = (torch.cuda.Event(enable_timing=True) for _ in range(2))
        self._st.record()

    def __exit__(self, *args):
        self._nd.record()
        torch.cuda.synchronize()
        # elapsed_time reports milliseconds.
        elapsed_seconds = self._st.elapsed_time(self._nd) / 1000
        print(self._comment, elapsed_seconds)


class Logger:
    """Buffer scalars, images and videos and flush them to stdout, TensorBoard and ``metrics.jsonl``.

    Args:
        logdir: Output directory. It must support the ``/`` operator (e.g. a
            ``pathlib.Path``) because ``write`` appends to ``logdir / "metrics.jsonl"``.
        step: Initial value of the public ``step`` attribute, used by ``write``
            when no explicit step is given.
    """

    def __init__(self, logdir, step):
        self._logdir = logdir
        self._writer = SummaryWriter(log_dir=str(logdir), max_queue=1000)
        self._last_step = None
        self._last_time = None
        self._scalars = {}
        self._images = {}
        self._videos = {}
        self.step = step

    def scalar(self, name, value):
        """Buffer a scalar; ``value`` is coerced with ``float``."""
        self._scalars[name] = float(value)

    def image(self, name, value):
        """Buffer an image (converted with ``np.array``)."""
        self._images[name] = np.array(value)

    def video(self, name, value):
        """Buffer a ``(B, T, H, W, C)`` video batch (converted with ``np.array``)."""
        self._videos[name] = np.array(value)

    def write(self, fps=False, step=False):
        """Flush every buffered value, then clear the buffers.

        Args:
            fps: Also log an ``fps`` scalar computed from wall-clock time since
                the previous ``fps=True`` write.
            step: Step to log at. ``False`` (default) or ``None`` means "use
                ``self.step``". An explicit ``0`` is honoured; a plain
                ``if not step`` test would have replaced it with ``self.step``.
        """
        if step is False or step is None:
            step = self.step
        scalars = list(self._scalars.items())
        if fps:
            scalars.append(("fps", self._compute_fps(step)))
        print(f"[{step}]", " / ".join(f"{k} {v:.1f}" for k, v in scalars))
        with (self._logdir / "metrics.jsonl").open("a") as f:
            f.write(json.dumps({"step": step, **dict(scalars)}) + "\n")
        for name, value in scalars:
            # Bare names are grouped under "scalars/"; names containing "/" keep their own group.
            tag = name if "/" in name else "scalars/" + name
            self._writer.add_scalar(tag, value, step)
        for name, value in self._images.items():
            self._writer.add_image(name, value, step)
        for name, value in self._videos.items():
            name = name if isinstance(name, str) else name.decode("utf-8")
            self._writer.add_video(name, self._video_grid(value), step, 16)

        self._writer.flush()
        self._scalars = {}
        self._images = {}
        self._videos = {}

    def _compute_fps(self, step):
        """Return steps per second since the previous call (0 on the first call)."""
        if self._last_step is None:
            self._last_time = time.time()
            self._last_step = step
            return 0
        steps = step - self._last_step
        duration = time.time() - self._last_time
        self._last_time += duration
        self._last_step = step
        # A zero or negative interval (coarse or adjusted clock) would otherwise
        # divide by zero or report a negative rate.
        return steps / duration if duration > 0 else 0.0

    @staticmethod
    def _video_grid(value):
        """Convert a ``(B, T, H, W, C)`` batch into one ``(1, T, C, H, B * W)`` video.

        Float inputs are scaled by 255, clipped and cast to ``uint8``. The batch
        is tiled horizontally, so all clips play side by side.
        """
        if np.issubdtype(value.dtype, np.floating):
            value = np.clip(255 * value, 0, 255).astype(np.uint8)
        B, T, H, W, C = value.shape
        return value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B * W))

    def offline_scalar(self, name, value, step):
        """Write a scalar directly (under ``scalars/``) without buffering."""
        self._writer.add_scalar("scalars/" + name, value, step)

    def offline_video(self, name, value, step):
        """Write a ``(B, T, H, W, C)`` video batch directly without buffering."""
        self._writer.add_video(name, self._video_grid(value), step, 16)


def simulate(
    agent,
    envs,
    cache,
    directory,
    logger,
    is_eval=False,
    limit=None,
    steps=0,
    episodes=0,
    state=None,
    train_env_name_list=None, #[todo]
    task_desc_dict=None, #[todo]
    # action_mask=None, #[todo]
    # state_mask=None, #[todo] this is the image-like observation "state", not the state tuple returned by simulate
    obs_filter=None, #[todo]

):
    """Roll out ``agent`` in ``envs`` until a step or episode budget is reached.

    Each loop iteration does the following:
        1. Reset the envs that are done.
        2. Stack the per-env observations and query
           ``agent(obs, done, agent_state)``.
        3. Step every env and append each transition to ``cache`` via
           ``add_to_cache``.
        4. For every episode that just finished, save it with
           ``save_episodes`` and log metrics through ``logger``.

    Args:
        agent: Callable ``(obs, done, agent_state) -> (action, agent_state)``.
            ``action`` is either a dict of batched tensors (indexed per env and
            moved to CPU) or something ``np.array`` accepts; it must have one
            entry per env.
        envs: Environments whose ``reset()`` / ``step(a)`` return callables
            that are invoked immediately to obtain the result. Each env exposes
            ``id``, and ``step`` results unpack as ``(obs, reward, done, info)``.
            ``env.task`` is read after every step.
        cache: Either an ``OrderedDict`` keyed by env id, or (multi-task) a
            mapping from task name to such a per-task cache.
        directory: Where finished episodes are saved. In the multi-task case
            this is a dict of per-task directories.
        logger: ``Logger``-like object; uses ``scalar``, ``video``, ``write``
            and ``step``.
        is_eval: Evaluation mode. The dataset is not trimmed, and running means
            are logged once enough eval episodes exist (per task when
            ``directory`` is a dict). On exit, caches are trimmed to their last
            entry.
        limit: Dataset size handed to ``erase_over_episodes`` in training mode.
        steps: Stop once this many env steps have run (0 disables the check).
        episodes: Stop once this many episodes have finished (0 disables the
            check).
        state: Tuple returned by a previous call, used to resume a rollout.
        train_env_name_list: Task names. Env ``i`` is assumed to run task
            ``train_env_name_list[i % len(train_env_name_list)]``.
        task_desc_dict: Optional mapping from task name to an embedding, stored
            under ``"token_embed"`` in every observation.
        obs_filter: Optional collection of observation keys to pass to the
            agent.

    Returns:
        ``(step - steps, episode - episodes, done, length, obs, agent_state,
        reward)``, usable as ``state`` for a later call.

    NOTE(review):
        * ``train_env_name_list`` is indexed unconditionally when envs reset,
          so the default ``None`` fails.
        * ``done_episode`` is only created when ``state is None`` and
          ``episodes`` is non-zero. The multi-task eval branch reads it
          whenever ``directory`` is a dict and ``is_eval`` is set, so resuming
          from ``state`` (or passing ``episodes=0``) there raises ``NameError``.
        * Inside the episode-logging loop, ``length`` is rebound to an int,
          replacing the per-env length array that is returned in the state.
        * The ``info["done"]`` adjustment only feeds ``discount``; episode
          termination uses the raw ``done`` from the step results.
    """
    # initialize or unpack simulation state
    if state is None:
        step, episode = 0, 0
        # #[todo] start
        # Per-task count of finished eval episodes; only created for multi-task eval.
        if isinstance(directory, dict) and episodes and is_eval:
            done_episode = np.zeros(len(train_env_name_list))
        # #[todo] end
        done = np.ones(len(envs), bool)
        length = np.zeros(len(envs), np.int32)
        obs = [None] * len(envs)
        agent_state = None
        reward = [0] * len(envs)
    else:
        step, episode, done, length, obs, agent_state, reward = state

    while (steps and step < steps) or (episodes and episode < episodes):
        # #[todo] start
        # if train_env_name_list is not None:
        #     if "drawer-open-v1" in train_env_name_list:
        #         gc.collect()
        # #[todo] end
        # reset envs if necessary
        if done.any():
            indices = [index for index, d in enumerate(done) if d]
            results = [envs[i].reset() for i in indices]
            results = [r() for r in results]
            for index, result in zip(indices, results):
                #[todo] start
                # Task name is derived from the env's position in ``envs``.
                curr_env_name = train_env_name_list[index % len(train_env_name_list)]
                # result["env_name"] = curr_env_name
                # print(f"[RESET] env_index={index}, env_id={envs[index].id}, curr_env_name={curr_env_name}")

                # result["log_task_name"] = curr_env_name
                if task_desc_dict is not None:
                    result["token_embed"] = task_desc_dict[curr_env_name]
                #[todo] end
                t = result.copy()
                t = {k: convert(v) for k, v in t.items()}
                # action will be added to transition in add_to_cache
                t["reward"] = 0.0
                t["discount"] = 1.0
                # initial state should be added to cache
                if not isinstance(cache, collections.OrderedDict):  # [todo]
                    add_to_cache(cache[curr_env_name], envs[index].id, t)
                else:
                    add_to_cache(cache, envs[index].id, t)
                # replace obs with done by initial state
                obs[index] = result
        # step agents
        #[todo] start
        # Batch the per-env observation dicts into one dict of stacked arrays, dropping "log_*" keys.
        if obs_filter is not None:
            valid_keys = [k for k in obs[0].keys() if "log_" not in k and k in obs_filter] # iterating obs[0] directly would also yield its keys
            obs = {k: np.stack([o[k] for o in obs]) for k in valid_keys}
        else:
        #[todo] end
            obs = {k: np.stack([o[k] for o in obs]) for k in obs[0] if "log_" not in k}

        action, agent_state = agent(obs, done, agent_state)
        if isinstance(action, dict):
            action = [
                {k: np.array(action[k][i].detach().cpu()) for k in action}
                for i in range(len(envs))
            ]
        else:
            action = np.array(action)
        assert len(action) == len(envs)
        # print(f"action:{action}")
        # step envs
        results = [e.step(a) for e, a in zip(envs, action)]
        results = [r() for r in results]

        # add to cache
        for a, result, env in zip(action, results, envs):
            #[todo] start
            curr_env_name = env.task
            # result[0]["env_name"] = curr_env_name
            # result[0]["log_task_name"] = curr_env_name
            if task_desc_dict is not None:
                result[0]["token_embed"] = task_desc_dict[curr_env_name]
            #[todo] end
            o, r, d, info = result
            o = {k: convert(v) for k, v in o.items()}
            transition = o.copy()
            if isinstance(a, dict):
                transition.update(a)
            else:
                transition["action"] = a
            transition["reward"] = r
            #[todo] start
            d = info["done"] or d if "done" in info else d # MetaWorld adaptation: also honor info["done"]
            # Reward is zeroed when info["has_done"] is truthy (presumably: the task was already completed earlier).
            transition["reward"] = 0.0 if "has_done" in info and info["has_done"] else r
            if "success" in info:
                transition["success"] = info["success"]
            #[todo] end
            transition["discount"] = info.get("discount", np.array(1 - float(d)))
            if not isinstance(cache, collections.OrderedDict):  # [todo]
                add_to_cache(cache[env.task], env.id, transition)
            else:
                add_to_cache(cache, env.id, transition)

        obs, reward, done = zip(*[p[:3] for p in results]) #[todo] change position
        obs = list(obs)
        reward = list(reward)
        done = np.stack(done)
        #[todo] start
        if isinstance(directory, dict) and is_eval:
            # Columns of the reshaped array are tasks (env i -> column i % n_tasks);
            # each task's count is capped at episodes / n_tasks.
            done_reshaped = done.reshape(-1, len(train_env_name_list))
            done_episode = np.clip(done_reshaped.sum(axis=0) + done_episode, a_min=None, a_max=int(episodes / len(train_env_name_list))) # sum each column to get the number of newly finished episodes per task
            episode = int(done_episode.sum())
        else:
        #[todo] end
            episode += int(done.sum())
        length += 1
        step += len(envs)
        length *= 1 - done

        if done.any():
            indices = [index for index, d in enumerate(done) if d]
            # logging for done episode
            for i in indices:
                if isinstance(directory, dict): #[todo] equal to not isinstance(cache, collections.OrderedDict) for now
                    curr_env_name = train_env_name_list[i % len(train_env_name_list)]
                    curr_cache = cache[curr_env_name]
                    curr_directory = directory[curr_env_name]
                else:
                    curr_env_name = ""
                    curr_cache = cache
                    curr_directory = directory
                save_episodes(curr_directory, {envs[i].id: curr_cache[envs[i].id]})
                # NOTE(review): this rebinds ``length`` (the per-env step-count array) to an int.
                length = len(curr_cache[envs[i].id]["reward"]) - 1
                score = float(np.array(curr_cache[envs[i].id]["reward"]).sum())
                success = (np.array(curr_cache[envs[i].id]["success"]) > 0).any().astype("float") if "success" in curr_cache[envs[i].id] else None #[todo] env.id is regenerated on every reset, so this entry holds only the current episode; success = whether this env succeeded at any step of it
                video = curr_cache[envs[i].id]["image"]
                # record logs given from environments
                for key in list(curr_cache[envs[i].id].keys()):
                    if "log_" in key:
                        logger.scalar(
                            key, float(np.array(curr_cache[envs[i].id][key]).sum())
                        )
                        # log items won't be used later
                        curr_cache[envs[i].id].pop(key)

                if not is_eval:
                    step_in_dataset = erase_over_episodes(curr_cache, limit)
                    logger.scalar(f"dataset_size_{curr_env_name}", step_in_dataset)
                    logger.scalar(f"train_return_{curr_env_name}", score)
                    logger.scalar(f"train_length_{curr_env_name}", length)
                    logger.scalar(f"train_episodes_{curr_env_name}", len(curr_cache))
                    #[todo] start
                    if success is not None:
                        logger.scalar(f"train_success_{curr_env_name}", success)
                    #[todo] end
                    logger.write(step=logger.step)
                else:
                    # Eval accumulators are created lazily on the first finished eval episode.
                    if not "eval_lengths" in locals():
                        #[todo] start
                        if isinstance(directory, dict):
                            eval_lengths = {n:[] for n in train_env_name_list}
                            eval_scores = {n:[] for n in train_env_name_list}
                            eval_successes = {n: [] for n in train_env_name_list} if success is not None else None # one success flag per finished episode
                            eval_done = {n:False for n in train_env_name_list}
                        #[todo] end
                        else:
                            eval_lengths = []
                            eval_scores = []
                            eval_done = False
                    # start counting scores for evaluation
                    #[todo] start
                    if isinstance(directory, dict):
                        eval_scores[curr_env_name].append(score)
                        eval_successes[curr_env_name].append(success) if eval_successes is not None else None
                        eval_lengths[curr_env_name].append(length)
                    #[todo] end
                    else:
                        eval_scores.append(score)
                        eval_lengths.append(length)
                    #[todo] start
                    # Running means over the eval episodes collected so far.
                    if isinstance(directory, dict):
                        score = sum(eval_scores[curr_env_name]) / len(eval_scores[curr_env_name])
                        success = sum(eval_successes[curr_env_name]) / len(eval_successes[curr_env_name]) if eval_successes is not None else None
                        length = sum(eval_lengths[curr_env_name]) / len(eval_lengths[curr_env_name])

                    #[todo] end
                    else:
                        score = sum(eval_scores) / len(eval_scores)
                        length = sum(eval_lengths) / len(eval_lengths)
                    logger.video(f"eval_policy_{curr_env_name}", np.array(video)[None])
                    # [todo] start
                    if isinstance(directory, dict):
                        if len(eval_scores[curr_env_name]) >= episodes / len(train_env_name_list) and not eval_done[curr_env_name]:
                            logger.scalar(f"eval_return_{curr_env_name}", score)
                            logger.scalar(f"eval_success_{curr_env_name}", success) if success is not None else None # averaging this value over all tasks gives the overall success rate
                            logger.scalar(f"eval_length_{curr_env_name}", length)
                            logger.scalar(f"eval_episodes_{curr_env_name}", len(eval_scores[curr_env_name]))
                            logger.write(step=logger.step)
                            eval_done[curr_env_name] = True
                    # [todo] end
                    else:
                        if len(eval_scores) >= episodes and not eval_done:
                            logger.scalar(f"eval_return_", score) #[todo]
                            logger.scalar(f"eval_length_", length)
                            logger.scalar(f"eval_episodes_", len(eval_scores))
                            logger.write(step=logger.step)
                            eval_done = True
    if is_eval:
        # keep only last item for saving memory. this cache is used for video_pred later
        if not isinstance(cache, collections.OrderedDict):  # [todo]
            for c in cache.values():
                # print(f"type of c: {type(c)}")
                while len(c) > 1:
                    c.popitem(last=False)
        else:
            while len(cache) > 1:
                # FIFO
                cache.popitem(last=False)
    return (step - steps, episode - episodes, done, length, obs, agent_state, reward)


def add_to_cache(cache, id, transition):
    """Append one transition to the episode stored under ``cache[id]``.

    An unseen ``id`` starts a new episode: a dict mapping each key to a
    one-element list of converted values. For an existing episode each value
    is appended. A key seen for the first time (e.g. ``action``, absent from
    the reset observation) is padded with a single ``convert(0 * val)`` entry
    before the real value. That lines up with earlier steps only when the key
    was missing from just the first transition.
    """
    if id not in cache:
        cache[id] = fresh_episode = {}
        for field, item in transition.items():
            fresh_episode[field] = [convert(item)]
        return
    episode = cache[id]
    for field, item in transition.items():
        if field not in episode:
            episode[field] = [convert(0 * item)]
        episode[field].append(convert(item))


def erase_over_episodes(cache, dataset_size):
    """Trim ``cache`` in place so the kept episodes fit in ``dataset_size`` transitions.

    Episodes are visited from the largest key to the smallest. Each one
    contributes ``len(reward) - 1`` transitions. An episode that would overflow
    the budget is deleted, but later (smaller-keyed) episodes are still kept if
    they fit. A falsy ``dataset_size`` keeps everything.

    Returns:
        The number of transitions kept.
    """
    kept = 0
    largest_key_first = sorted(cache.items(), key=lambda item: item[0], reverse=True)
    for key, episode in largest_key_first:
        transitions = len(episode["reward"]) - 1
        if dataset_size and kept + transitions > dataset_size:
            del cache[key]
        else:
            kept += transitions
    return kept


def convert(value, precision=32):
    """Return ``value`` as a NumPy array with a canonical dtype.

    Floats and signed integers are cast to the ``precision``-bit (16/32/64)
    variant of their family. ``uint8`` and ``bool`` are kept as-is. Any other
    dtype raises ``NotImplementedError``; an unsupported ``precision`` raises
    ``KeyError``.
    """
    array = np.array(value)
    kind = array.dtype
    sized_families = (
        (np.floating, {16: np.float16, 32: np.float32, 64: np.float64}),
        (np.signedinteger, {16: np.int16, 32: np.int32, 64: np.int64}),
    )
    for family, by_width in sized_families:
        if np.issubdtype(kind, family):
            return array.astype(by_width[precision])
    for exact in (np.uint8, bool):
        if np.issubdtype(kind, exact):
            return array.astype(exact)
    raise NotImplementedError(kind)


def save_episodes(directory, episodes):
    """Write each episode to ``<directory>/<key>-<len(reward)>.npz`` (compressed).

    The directory (with ``~`` expanded) is created if needed. The archive is
    built in memory first and then written in one call. Always returns ``True``.
    """
    target_dir = pathlib.Path(directory).expanduser()
    target_dir.mkdir(parents=True, exist_ok=True)
    for episode_key, payload in episodes.items():
        path = target_dir / f"{episode_key}-{len(payload['reward'])}.npz"
        with io.BytesIO() as buffer:
            np.savez_compressed(buffer, **payload)
            path.write_bytes(buffer.getvalue())
    return True


def from_generator(generator, batch_size):
    """Endlessly group ``batch_size`` dicts from ``generator`` into one batched dict.

    Each yielded dict maps every key of the first item to ``np.stack`` of that
    key's values along a new leading axis.
    """
    while True:
        items = [next(generator) for _ in range(batch_size)]
        yield {key: np.stack([item[key] for item in items], 0) for key in items[0].keys()}


def sample_episodes(episodes, length, seed=0):
    """Endlessly yield fixed-length chunks stitched together from stored episodes.

    Args:
        episodes: Mapping of episode id -> dict of per-step arrays (all keys are
            assumed to share one leading length). It is re-read before every
            chunk, so it may grow between yields.
        length: Number of steps in each yielded chunk.
        seed: Seed for the private ``np.random.RandomState``.

    Yields:
        Dict of arrays (``log_*`` keys dropped) with ``length`` rows. How a
        chunk is built:
            * Episodes are drawn with probability proportional to their length;
              episodes shorter than 2 steps are skipped.
            * The chunk starts at a random index of the first draw.
            * If that episode ends too early, the chunk continues from the
              start of further draws.
            * ``is_first`` is forced True at the start and at every seam.

    Raises:
        ValueError: if no stored episode has at least 2 steps. The draw loop
            would otherwise spin forever.
    """
    np_random = np.random.RandomState(seed)
    while True:
        # Snapshot once per chunk: the candidate list always lines up with ``p``,
        # and we avoid rebuilding ``list(episodes.values())`` on every draw.
        candidates = list(episodes.values())
        sizes = np.array([len(next(iter(episode.values()))) for episode in candidates])
        if length > 0 and not (sizes >= 2).any():
            raise ValueError(
                "sample_episodes needs at least one episode with 2 or more steps"
            )
        p = sizes / np.sum(sizes)
        size = 0
        ret = None
        while size < length:
            # Choosing an index consumes the RNG exactly like choosing from the list.
            episode = candidates[np_random.choice(len(candidates), p=p)]
            total = len(next(iter(episode.values())))
            # make sure at least one transition included
            if total < 2:
                continue
            if ret is None:
                index = int(np_random.randint(0, total - 1))
                ret = {
                    k: v[index : min(index + length, total)].copy()
                    for k, v in episode.items()
                    if "log_" not in k
                }
                if "is_first" in ret:
                    ret["is_first"][0] = True
            else:
                # 'is_first' comes after 'is_last': continue from the start of a new episode.
                possible = length - size
                ret = {
                    k: np.append(ret[k], v[: min(possible, total)].copy(), axis=0)
                    for k, v in episode.items()
                    if "log_" not in k
                }
                if "is_first" in ret:
                    ret["is_first"][size] = True
            size = len(next(iter(ret.values())))
        yield ret


def load_episodes(directory, limit=None, reverse=True):
    """Load ``*.npz`` episodes from ``directory`` into an ``OrderedDict``.

    Files are visited in sorted filename order, descending when ``reverse``.
    Loading stops once the accumulated transition count (``len(reward) - 1``
    per episode) reaches ``limit``. Unreadable files are reported and skipped.

    Keys differ by mode:
        * With ``reverse``, the key is the file stem (presumably
          ``"<id>-<len>"`` as written by ``save_episodes``).
        * Otherwise, the key is the full path string.
    """
    root = pathlib.Path(directory).expanduser()
    ordered_paths = sorted(root.glob("*.npz"))
    if reverse:
        ordered_paths.reverse()
    episodes = collections.OrderedDict()
    total = 0
    for path in ordered_paths:
        try:
            with path.open("rb") as handle:
                archive = np.load(handle)
                episode = {name: archive[name] for name in archive.keys()}
        except Exception as e:
            print(f"Could not load episode: {e}")
            continue
        key = path.stem if reverse else str(path)
        episodes[key] = episode
        total += len(episode["reward"]) - 1
        if limit and total >= limit:
            break
    return episodes


class SampleDist:
    """Estimate statistics of ``dist`` from ``samples`` Monte-Carlo draws.

    Attributes not defined here fall through to the wrapped distribution.
    """

    def __init__(self, dist, samples=100):
        self._dist = dist
        self._samples = samples

    @property
    def name(self):
        return "SampleDist"

    def __getattr__(self, name):
        return getattr(self._dist, name)

    def _draw(self):
        # torch ``Distribution.sample`` expects a shape. A bare int ends up in
        # ``torch.Size(int)``, which raises TypeError.
        return self._dist.sample((self._samples,))

    def mean(self):
        """Average of the draws along the sample axis."""
        return torch.mean(self._draw(), 0)

    def mode(self):
        """Highest-log-probability draw, then its first element along dim 0.

        NOTE(review): ``argmax`` flattens ``logprob``, and the trailing ``[0]``
        mirrors the original TF port. For batched distributions the selection
        is questionable; confirm before relying on it.
        """
        sample = self._draw()
        logprob = self._dist.log_prob(sample)
        return sample[torch.argmax(logprob)][0]

    def entropy(self):
        """Monte-Carlo entropy estimate: ``-mean(log_prob(draws))``."""
        sample = self._draw()
        logprob = self.log_prob(sample)
        return -torch.mean(logprob, 0)


class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
    """One-hot categorical with optional uniform mixing and straight-through gradients.

    Args:
        logits: Unnormalized log-probabilities.
        probs: Probabilities (used when ``logits`` is None or no mixing is requested).
        unimix_ratio: If > 0 and ``logits`` are given, the softmax probabilities
            are mixed with a uniform distribution by this ratio.
    """

    def __init__(self, logits=None, probs=None, unimix_ratio=0.0):
        wants_mix = logits is not None and unimix_ratio > 0.0
        if not wants_mix:
            super().__init__(logits=logits, probs=probs)
            return
        softmax_probs = F.softmax(logits, dim=-1)
        num_classes = softmax_probs.shape[-1]
        mixed = softmax_probs * (1.0 - unimix_ratio) + unimix_ratio / num_classes
        super().__init__(logits=torch.log(mixed), probs=None)

    def mode(self):
        """Argmax one-hot in the forward pass; gradients flow through the logits."""
        scores = self.logits
        hard = F.one_hot(torch.argmax(scores, axis=-1), scores.shape[-1])
        return hard.detach() + scores - scores.detach()

    def sample(self, sample_shape=(), seed=None):
        """Draw one-hot samples with a straight-through gradient via ``probs``."""
        if seed is not None:
            raise ValueError("need to check")
        drawn = super().sample(sample_shape).detach()
        probs = self.probs
        # Prepend singleton axes so probs broadcasts against the sample shape.
        while probs.dim() < drawn.dim():
            probs = probs.unsqueeze(0)
        drawn += probs - probs.detach()
        return drawn


class DiscDist:
    """Two-hot categorical distribution over a fixed bucket grid in a transformed space.

    ``logits`` score 255 buckets evenly spaced on ``[low, high]``. Targets are
    mapped into bucket space with ``transfwd`` (symlog by default), and
    statistics are mapped back with ``transbwd`` (symexp).

    Args:
        logits: Unnormalized scores; the last dimension indexes the 255 buckets.
        low, high: Range of the bucket grid (in transformed space).
        transfwd, transbwd: Forward / inverse value transforms.
        device: Device on which the bucket grid is created.
    """

    def __init__(
        self,
        logits,
        low=-20.0,
        high=20.0,
        transfwd=symlog,
        transbwd=symexp,
        device="cuda",
    ):
        self.logits = logits
        self.probs = torch.softmax(logits, -1)
        self.buckets = torch.linspace(low, high, steps=255, device=device)
        # NOTE(review): divides by 255 although 255 points span 254 intervals;
        # nothing in this class reads it.
        self.width = (self.buckets[-1] - self.buckets[0]) / 255
        self.transfwd = transfwd
        self.transbwd = transbwd

    def _log_pred(self):
        # Log-softmax of the logits over the bucket dimension.
        return self.logits - torch.logsumexp(self.logits, -1, keepdim=True)

    def mean(self):
        """Expected bucket value under ``probs``, mapped back with ``transbwd``."""
        _mean = self.probs * self.buckets
        return self.transbwd(torch.sum(_mean, dim=-1, keepdim=True))

    def mode(self):
        """Identical to ``mean``: the expectation serves as the point estimate."""
        return self.mean()

    # Inside OneHotCategorical, log_prob is calculated using only max element in targets
    def log_prob(self, x):
        """Two-hot log-likelihood of ``x``.

        ``x`` is mapped with ``transfwd`` and split between its two neighbouring
        buckets, weighted by proximity. The result is
        ``sum(target * log_softmax(logits))``. Values outside the grid land
        entirely on the edge bucket.
        """
        x = self.transfwd(x)
        # x(time, batch, 1)
        below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1
        above = len(self.buckets) - torch.sum(
            (self.buckets > x[..., None]).to(torch.int32), dim=-1
        )
        # this is implemented using clip at the original repo as the gradients are not backpropagated for the out of limits.
        below = torch.clip(below, 0, len(self.buckets) - 1)
        above = torch.clip(above, 0, len(self.buckets) - 1)
        equal = below == above

        dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x))
        dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x))
        total = dist_to_below + dist_to_above
        weight_below = dist_to_above / total
        weight_above = dist_to_below / total
        target = (
            F.one_hot(below, num_classes=len(self.buckets)) * weight_below[..., None]
            + F.one_hot(above, num_classes=len(self.buckets)) * weight_above[..., None]
        )
        log_pred = self._log_pred()
        # Drop the trailing singleton value axis of x.
        target = target.squeeze(-2)

        return (target * log_pred).sum(-1)

    def log_prob_target(self, target):
        """Score a ready-made bucket distribution: ``sum(target * log_softmax(logits))``."""
        # The logits live on this instance; DiscDist has no base class providing
        # them, so ``super().logits`` raised AttributeError.
        return (target * self._log_pred()).sum(-1)


class MSEDist:
    """Point-mass "distribution" whose ``log_prob`` is a negative squared error.

    The error is aggregated (``"sum"`` or ``"mean"``) over every dim after the
    first two.
    """

    def __init__(self, mode, agg="sum"):
        self._mode = mode
        self._agg = agg

    def mode(self):
        return self._mode

    def mean(self):
        return self._mode

    def log_prob(self, value):
        assert self._mode.shape == value.shape, (self._mode.shape, value.shape)
        squared_error = (self._mode - value) ** 2
        reduce_dims = list(range(2, len(squared_error.shape)))
        reducers = {"mean": squared_error.mean, "sum": squared_error.sum}
        if self._agg not in reducers:
            raise NotImplementedError(self._agg)
        return -reducers[self._agg](reduce_dims)


class SymlogDist:
    """Point-mass "distribution" whose prediction ``mode`` lives in symlog space.

    ``log_prob`` compares ``mode`` against ``symlog(value)`` with squared
    (``"mse"``) or absolute (``"abs"``) error. Errors below ``tol`` are zeroed,
    and the result is aggregated over every dim after the first two.
    """

    def __init__(self, mode, dist="mse", agg="sum", tol=1e-8):
        self._mode = mode
        self._dist = dist
        self._agg = agg
        self._tol = tol

    def mode(self):
        return symexp(self._mode)

    def mean(self):
        return symexp(self._mode)

    def log_prob(self, value):
        assert self._mode.shape == value.shape
        if self._dist not in ("mse", "abs"):
            raise NotImplementedError(self._dist)
        error = self._mode - symlog(value)
        distance = error ** 2.0 if self._dist == "mse" else torch.abs(error)
        distance = torch.where(distance < self._tol, 0, distance)
        reduce_dims = list(range(2, len(distance.shape)))
        if self._agg == "mean":
            loss = distance.mean(reduce_dims)
        elif self._agg == "sum":
            loss = distance.sum(reduce_dims)
        else:
            raise NotImplementedError(self._agg)
        return -loss


class ContDist:
    """Wrap a continuous torch distribution, optionally soft-limiting outputs to ``absmax``.

    Attributes not defined here fall through to the wrapped distribution.

    Args:
        dist: The torch distribution to wrap (must provide ``mean``, ``rsample``,
            ``log_prob`` and ``entropy``).
        absmax: If set, ``mode`` and ``sample`` rescale elementwise so that
            ``|out| <= absmax``. The rescaling factor is detached from autograd.
    """

    def __init__(self, dist=None, absmax=None):
        super().__init__()
        self._dist = dist
        self.mean = dist.mean
        self.absmax = absmax

    def __getattr__(self, name):
        return getattr(self._dist, name)

    def entropy(self):
        return self._dist.entropy()

    def _limit(self, out):
        # Out-of-place on purpose: ``dist.mean`` can alias the distribution's
        # ``loc`` tensor, and an in-place ``*=`` would silently overwrite it
        # (and can break autograd).
        if self.absmax is None:
            return out
        scale = (self.absmax / torch.clip(torch.abs(out), min=self.absmax)).detach()
        return out * scale

    def mode(self):
        """The wrapped distribution's mean, limited to ``absmax`` if set."""
        return self._limit(self._dist.mean)

    def sample(self, sample_shape=()):
        """A reparameterized sample, limited to ``absmax`` if set."""
        return self._limit(self._dist.rsample(sample_shape))

    def log_prob(self, x):
        return self._dist.log_prob(x)


class Bernoulli:
    """Adapter around an ``Independent``-wrapped torch Bernoulli.

    ``log_prob`` reads ``dist.base_dist.logits``. ``mode`` rounds the mean with
    a straight-through gradient. Other attributes fall through to ``dist``.
    """

    def __init__(self, dist=None):
        super().__init__()
        self._dist = dist
        self.mean = dist.mean

    def __getattr__(self, name):
        return getattr(self._dist, name)

    def entropy(self):
        return self._dist.entropy()

    def mode(self):
        probs = self._dist.mean
        hard = torch.round(probs).detach()
        # Forward value is the rounded mean; the gradient is that of the mean.
        return hard + probs - probs.detach()

    def sample(self, sample_shape=()):
        return self._dist.rsample(sample_shape)

    def log_prob(self, x):
        """Bernoulli log-likelihood from logits, summed over the last axis."""
        logits = self._dist.base_dist.logits
        log_p_zero = -F.softplus(logits)  # log(1 - sigmoid(logits))
        log_p_one = -F.softplus(-logits)  # log(sigmoid(logits))
        per_dim = log_p_zero * (1 - x) + log_p_one * x
        return torch.sum(per_dim, -1)


class UnnormalizedHuber(torchd.normal.Normal):
    """Normal whose ``log_prob`` is replaced by a negative pseudo-Huber loss.

    ``log_prob(e) = -(sqrt((e - mean)^2 + threshold^2) - threshold)``. This is
    not a normalized density. ``mode`` returns the mean.
    """

    def __init__(self, loc, scale, threshold=1, **kwargs):
        super().__init__(loc, scale, **kwargs)
        self._threshold = threshold

    def log_prob(self, event):
        delta = self._threshold
        smooth_abs = torch.sqrt((event - self.mean) ** 2 + delta**2)
        return -(smooth_abs - delta)

    def mode(self):
        return self.mean


class SafeTruncatedNormal(torchd.normal.Normal):
    """Normal whose samples are clipped into ``[low + clip, high - clip]`` and then scaled.

    Clipping is written in straight-through form; it is skipped when ``clip``
    is falsy. Scaling by ``mult`` is skipped when ``mult`` is falsy.
    """

    def __init__(self, loc, scale, low, high, clip=1e-6, mult=1):
        super().__init__(loc, scale)
        self._low, self._high = low, high
        self._clip, self._mult = clip, mult

    def sample(self, sample_shape):
        draw = super().sample(sample_shape)
        if self._clip:
            lower = self._low + self._clip
            upper = self._high - self._clip
            bounded = torch.clip(draw, lower, upper)
            draw = draw - draw.detach() + bounded.detach()
        if self._mult:
            draw *= self._mult
        return draw


class TanhBijector(torchd.Transform):
    """Bijective ``y = tanh(x)`` transform for ``torch.distributions``.

    Implements the hooks ``torch.distributions.Transform`` actually calls:
    ``_call``, ``_inverse`` and ``log_abs_det_jacobian``. The
    TensorFlow-Probability-style names ``_forward`` and
    ``_forward_log_det_jacobian`` are kept for existing callers.

    Args:
        validate_args: Accepted for signature compatibility; unused.
        name: Accepted for signature compatibility; unused.
    """

    domain = torchd.constraints.real
    codomain = torchd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, validate_args=False, name="tanh"):
        super().__init__()

    def _call(self, x):
        return self._forward(x)

    def _forward(self, x):
        return torch.tanh(x)

    def _inverse(self, y):
        # Pull |y| == 1 slightly inside the interval so atanh stays finite.
        y = torch.where(
            (torch.abs(y) <= 1.0), torch.clamp(y, -0.99999997, 0.99999997), y
        )
        y = torch.atanh(y)
        return y

    def _forward_log_det_jacobian(self, x):
        # log(1 - tanh(x)^2) = 2 * (log 2 - x - softplus(-2x)), a numerically
        # stable form (no cancellation for large |x|).
        import math

        return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x))

    def log_abs_det_jacobian(self, x, y):
        return self._forward_log_det_jacobian(x)


def static_scan_for_lambda_return(fn, inputs, start):
    """Run ``fn`` backwards over time and return its per-step carries in forward order.

    Args:
        fn: ``fn(carry, *step_inputs) -> carry``, called once per time index from
            the last index to the first.
        inputs: Sequence of tensors sharing the leading (time) dimension ``T``;
            step ``t`` receives ``inputs[k][t]`` for every ``k``.
        start: Initial carry (e.g. the bootstrap value).

    Returns:
        Tuple from ``torch.unbind(..., dim=0)``. Carries are concatenated along
        their last dim and reshaped to ``(carry.shape[0], -1, 1)`` with time
        flipped back to forward order. For ``(B, 1)`` carries (presumably the
        intended use) this is ``B`` tensors of shape ``(T, 1)``.

    Raises:
        ValueError: if ``T == 0`` (there is nothing to concatenate).
    """
    # Collect every carry and concatenate once. Re-concatenating the growing
    # result on each step would copy O(T^2) elements.
    carries = []
    last = start
    for index in reversed(range(inputs[0].shape[0])):
        last = fn(last, *(tensor[index] for tensor in inputs))
        carries.append(last)
    if not carries:
        raise ValueError("inputs must have a non-empty leading (time) dimension")
    outputs = torch.cat(carries, dim=-1)
    outputs = torch.reshape(outputs, [outputs.shape[0], outputs.shape[1], 1])
    outputs = torch.flip(outputs, [1])
    return torch.unbind(outputs, dim=0)


def lambda_return(reward, value, pcont, bootstrap, lambda_, axis):
    """Compute TD(lambda) returns with a backward scan over time.

    The recursion is implemented in two steps:
        * ``inputs_t = r_t + pcont_t * V_{t+1} * (1 - lambda_)``
        * ``R_t = inputs_t + pcont_t * lambda_ * R_{t+1}``, starting from
          ``bootstrap``.

    Together these give ``R_t = r_t + pcont_t * ((1 - lambda_) * V_{t+1} + lambda_ * R_{t+1})``.

    Args:
        reward: Reward tensor; must have the same rank as ``value``.
        value: Value estimates; ``V_{t+1}`` is ``value[t + 1]``, with ``bootstrap``
            used after the last step.
        pcont: Continuation/discount tensor, or an int/float that is broadcast
            with ``torch.ones_like(reward)``.
        bootstrap: Value after the final step; ``None`` means zeros shaped like
            ``value[-1]``.
        lambda_: Mixing factor (1 -> discounted Monte Carlo, 0 -> one-step return).
        axis: Time axis. When non-zero it is swapped with dim 0 via ``permute``.

    Returns:
        What ``static_scan_for_lambda_return`` returns: a tuple from
        ``torch.unbind`` over the carry's first dimension. For ``(T, B, 1)``
        inputs (presumably the intended layout — confirm with callers) this is
        ``B`` tensors of shape ``(T, 1)``.

    NOTE(review): for ``axis != 0`` the final ``returns.permute(dims)`` is
    applied to that tuple, which has no ``permute``. Only ``axis == 0`` looks
    usable; confirm before relying on other axes.
    """
    # Setting lambda=1 gives a discounted Monte Carlo return.
    # Setting lambda=0 gives a fixed 1-step return.
    # assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape)
    assert len(reward.shape) == len(value.shape), (reward.shape, value.shape)
    if isinstance(pcont, (int, float)):
        pcont = pcont * torch.ones_like(reward)
    dims = list(range(len(reward.shape)))
    # Permutation swapping ``axis`` with dim 0 (its own inverse); only used when axis != 0.
    dims = [axis] + dims[1:axis] + [0] + dims[axis + 1 :]
    if axis != 0:
        reward = reward.permute(dims)
        value = value.permute(dims)
        pcont = pcont.permute(dims)
    if bootstrap is None:
        bootstrap = torch.zeros_like(value[-1])
    # V_{t+1} for every t, with ``bootstrap`` following the final step.
    next_values = torch.cat([value[1:], bootstrap[None]], 0)
    inputs = reward + pcont * next_values * (1 - lambda_)
    # returns = static_scan(
    #    lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg,
    #    (inputs, pcont), bootstrap, reverse=True)
    # reimplement to optimize performance
    returns = static_scan_for_lambda_return(
        lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg, (inputs, pcont), bootstrap
    )
    if axis != 0:
        returns = returns.permute(dims)
    return returns


class Optimizer:
    """Torch optimizer wrapper with AMP loss scaling, gradient clipping and weight decay.

    Args:
        name: Prefix for the metric keys returned by ``__call__``.
        parameters: Iterable of parameters handed to the torch optimizer.
        lr: Learning rate.
        eps: Epsilon for Adam / Adamax.
        clip: Max global gradient norm (must be >= 1 when set). A falsy value
            disables clipping; the norm is still reported.
        wd: Weight-decay factor in ``[0, 1)``, applied as ``p <- (1 - wd) * p``
            before each step. ``None`` or 0 disables it.
        wd_pattern: Only ``r".*"`` (all parameters) is supported.
        opt: ``"adam"``, ``"adamax"``, ``"sgd"`` or ``"momentum"``. ``"nadam"``
            raises ``NotImplementedError``; unknown names raise ``KeyError``.
        use_amp: Enable ``GradScaler`` loss scaling.
    """

    def __init__(
        self,
        name,
        parameters,
        lr,
        eps=1e-4,
        clip=None,
        wd=None,
        wd_pattern=r".*",
        opt="adam",
        use_amp=False,
    ):
        # ``wd`` defaults to None, which ``0 <= wd < 1`` alone rejects with TypeError.
        assert wd is None or 0 <= wd < 1
        assert not clip or 1 <= clip
        self._name = name
        self._parameters = parameters
        self._clip = clip
        self._wd = wd
        self._wd_pattern = wd_pattern
        if opt == "nadam":
            raise NotImplementedError(f"{opt} is not implemented")
        self._opt = {
            "adam": lambda: torch.optim.Adam(parameters, lr=lr, eps=eps),
            "adamax": lambda: torch.optim.Adamax(parameters, lr=lr, eps=eps),
            "sgd": lambda: torch.optim.SGD(parameters, lr=lr),
            "momentum": lambda: torch.optim.SGD(parameters, lr=lr, momentum=0.9),
        }[opt]()
        self._scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    def __call__(self, loss, params, retain_graph=True):
        """Backpropagate scalar ``loss``, clip, decay and step the optimizer.

        Returns:
            Dict with ``"<name>_loss"`` and ``"<name>_grad_norm"`` as NumPy values.
        """
        assert len(loss.shape) == 0, loss.shape
        metrics = {f"{self._name}_loss": loss.detach().cpu().numpy()}
        self._opt.zero_grad()
        self._scaler.scale(loss).backward(retain_graph=retain_graph)
        # Unscale first so clipping sees the true gradient magnitudes.
        self._scaler.unscale_(self._opt)
        # max_norm=inf measures the norm without clipping when clipping is disabled.
        max_norm = self._clip if self._clip else float("inf")
        norm = torch.nn.utils.clip_grad_norm_(params, max_norm)
        if self._wd:
            self._apply_weight_decay(params)
        self._scaler.step(self._opt)
        self._scaler.update()
        self._opt.zero_grad()
        metrics[f"{self._name}_grad_norm"] = norm.detach().cpu().numpy()
        return metrics

    def _apply_weight_decay(self, varibs):
        """Scale every parameter in ``varibs`` by ``1 - wd`` in place."""
        nontrivial = self._wd_pattern != r".*"
        if nontrivial:
            raise NotImplementedError
        for var in varibs:
            var.data = (1 - self._wd) * var.data


def args_type(default):
    """Build a converter that parses a CLI/config value to the type of ``default``.

    String inputs are parsed according to ``default``:
      * None: returned unchanged.
      * bool: "true"/"false" (case-insensitive, surrounding whitespace ignored);
        anything else raises ValueError.
      * int: parsed as float if it contains "e" or ".", else as int.
      * list/tuple: split on "," and each item parsed like ``default[0]``;
        an empty default keeps the items as strings. Returns a tuple.
      * other: ``type(default)(x)``.
    Non-string inputs are returned as-is, except that list/tuple defaults turn
    them into a tuple.
    """

    def parse_string(x):
        if default is None:
            return x
        # bool is checked before int because bool is a subclass of int.
        if isinstance(default, bool):
            lowered = x.strip().lower()
            if lowered == "true":
                return True
            if lowered == "false":
                return False
            raise ValueError(f"Expected 'True' or 'False', got {x!r}")
        if isinstance(default, int):
            return float(x) if ("e" in x or "." in x) else int(x)
        if isinstance(default, (list, tuple)):
            # An empty default carries no element type; keep items as strings.
            element_default = default[0] if default else None
            return tuple(args_type(element_default)(y) for y in x.split(","))
        return type(default)(x)

    def parse_object(x):
        if isinstance(default, (list, tuple)):
            return tuple(x)
        return x

    return lambda x: parse_string(x) if isinstance(x, str) else parse_object(x)


def static_scan(fn, inputs, start):
    """Run ``fn`` as a left scan over the leading axis of ``inputs`` and stack
    every intermediate result.

    ``fn(last, *xs_t)`` is called once per index ``t`` in
    ``range(inputs[0].shape[0])``, with ``last`` being the previous result
    (``start`` on the first call). Entries of ``inputs`` that are strings or are
    not subscriptable are passed unchanged at every step; all other entries are
    indexed with ``t``.

    Returns:
        If ``fn`` returns a dict: a one-element list ``[dict]`` whose values are
        the per-step tensors stacked along a new leading axis. Otherwise the
        result is treated as a sequence and a list with one stacked entry per
        element is returned; elements that are dicts are stacked key by key.

    Raises:
        ValueError: if there are zero steps (nothing to stack).
    """

    def _step_inputs(index):
        return [
            _input
            if (isinstance(_input, str) or not hasattr(_input, "__getitem__"))
            else _input[index]
            for _input in inputs
        ]

    def _snapshot(value):
        # Shallow-copy dicts so a later in-place edit of the same dict by
        # ``fn`` cannot alter an already recorded step.
        return dict(value) if isinstance(value, dict) else value

    def _stack(items):
        # Stack per-step values; dicts are stacked key by key (first step's keys).
        first = items[0]
        if isinstance(first, dict):
            return {
                key: torch.stack([item[key] for item in items], dim=0)
                for key in first
            }
        return torch.stack(items, dim=0)

    last = start
    history = []
    # Record each step and stack once at the end: concatenating inside the
    # loop re-copies the whole accumulated history on every step (O(T^2)).
    for index in range(inputs[0].shape[0]):
        last = fn(last, *_step_inputs(index))
        if isinstance(last, dict):
            history.append(_snapshot(last))
        else:
            history.append([_snapshot(part) for part in last])
    if not history:
        raise ValueError(
            "static_scan needs at least one step along axis 0 of inputs[0]"
        )
    if isinstance(last, dict):
        return [_stack(history)]
    width = len(history[0])
    return [_stack([step[j] for step in history]) for j in range(width)]


class Every:
    """Periodic trigger driven by a step counter.

    The step axis is divided into cells of width ``every`` and ``_last`` always
    sits on a cell boundary (the "anchor"). Each call returns how many whole
    cells lie between the anchor and ``step`` (0, 1, 2, ...) and advances the
    anchor by exactly that many cells (``old_last + count * every``).

    The very first call returns 1 and anchors at that step, i.e. the first
    observation counts as a trigger (a design choice; other implementations
    return 0 here). A falsy ``every`` disables the trigger (always 0).
    """

    def __init__(self, every):
        self._every = every
        self._last = None

    def __call__(self, step):
        if not self._every:
            return 0
        if self._last is not None:
            elapsed = step - self._last
            # True division then truncation toward zero, as before.
            count = int(elapsed / self._every)
            self._last += count * self._every
            return count
        self._last = step
        return 1


class Once:
    """Callable flag: returns True on its first invocation, False afterwards."""

    def __init__(self):
        self._once = True

    def __call__(self):
        # Read the current flag and clear it in one step.
        first_time, self._once = self._once, False
        return first_time


class Until:
    """Callable that is True while ``step < until``; always True if ``until`` is falsy."""

    def __init__(self, until):
        self._until = until

    def __call__(self, step):
        return (not self._until) or step < self._until


def weight_init(m):
    """Initialise module ``m`` in place.

    * Linear / Conv2d / ConvTranspose2d: weight drawn from a normal truncated
      to [-2*std, 2*std], with ``std = sqrt(1 / fan_avg) / 0.8796...`` where
      ``fan_avg = (fan_in + fan_out) / 2`` (conv fans include the kernel
      area); bias set to 0 when present.
    * LayerNorm: weight set to 1, bias to 0 when present.
    * Any other module is left untouched.
    """
    # Std of a unit normal truncated to [-2, 2]; dividing by it makes the
    # truncated distribution's std equal the fan-based target.
    truncation_correction = 0.87962566103423978

    def _zero_bias(module):
        if hasattr(module.bias, "data"):
            module.bias.data.fill_(0.0)

    if isinstance(m, nn.Linear):
        fan_in, fan_out = m.in_features, m.out_features
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        kernel_area = m.kernel_size[0] * m.kernel_size[1]
        fan_in = kernel_area * m.in_channels
        fan_out = kernel_area * m.out_channels
    elif isinstance(m, nn.LayerNorm):
        m.weight.data.fill_(1.0)
        _zero_bias(m)
        return
    else:
        return

    fan_avg = (fan_in + fan_out) / 2.0
    std = np.sqrt(1.0 / fan_avg) / truncation_correction
    nn.init.trunc_normal_(
        m.weight.data, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std
    )
    _zero_bias(m)


def uniform_weight_init(given_scale):
    """Return an in-place initialiser ``f(m)`` (usable e.g. with ``nn.Module.apply``).

    * Linear / Conv2d / ConvTranspose2d: weight drawn from U(-limit, limit) with
      ``limit = sqrt(3 * given_scale / fan_avg)`` where
      ``fan_avg = (fan_in + fan_out) / 2`` (conv fans include the kernel
      area); bias set to 0 when present.
    * LayerNorm: weight set to 1, bias to 0 when present.
    * Any other module is left untouched.
    """

    def f(m):
        def _zero_bias():
            if hasattr(m.bias, "data"):
                m.bias.data.fill_(0.0)

        if isinstance(m, nn.LayerNorm):
            m.weight.data.fill_(1.0)
            _zero_bias()
            return
        if isinstance(m, nn.Linear):
            fan_in, fan_out = m.in_features, m.out_features
        elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            kernel_area = m.kernel_size[0] * m.kernel_size[1]
            fan_in = kernel_area * m.in_channels
            fan_out = kernel_area * m.out_channels
        else:
            return

        fan_avg = (fan_in + fan_out) / 2.0
        limit = np.sqrt(3 * (given_scale / fan_avg))
        nn.init.uniform_(m.weight.data, a=-limit, b=limit)
        _zero_bias()

    return f


def tensorstats(tensor, prefix=None):
    """Summarise ``tensor`` as numpy values for keys mean, std, min and max.

    When ``prefix`` is truthy every key becomes ``f"{prefix}_{key}"``.
    """
    reducers = (
        ("mean", torch.mean),
        ("std", torch.std),
        ("min", torch.min),
        ("max", torch.max),
    )
    key_prefix = f"{prefix}_" if prefix else ""
    return {
        key_prefix + key: reduce(tensor).detach().cpu().numpy()
        for key, reduce in reducers
    }


def set_seed_everywhere(seed):
    """Seed Python's ``random``, NumPy's global RNG and torch (CPU, plus all CUDA devices when available)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def enable_deterministic_run():
    """Switch torch into deterministic mode for reproducible runs."""
    # Fixed cuBLAS workspace config recommended by PyTorch's reproducibility
    # notes for deterministic cuBLAS results.
    os.environ.update({"CUBLAS_WORKSPACE_CONFIG": ":4096:8"})
    # Disable cuDNN autotuning, whose kernel choice can vary between runs.
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)


def recursively_collect_optim_state_dict(
    obj, path="", optimizers_state_dicts=None, visited=None
):
    """Walk ``obj``'s attribute graph and collect every optimizer's state dict.

    Attributes are taken from ``obj.__dict__``; for ``nn.Module`` objects the
    direct child modules are added as well. Any ``torch.optim.Optimizer``
    found is recorded; other objects with a ``__dict__`` are walked
    recursively.

    Args:
        obj: object to search.
        path: dotted attribute path of ``obj`` from the root ("" for the root).
        optimizers_state_dicts: accumulator dict; created when None.
        visited: set of ``id()``s already walked, guarding against cycles.

    Returns:
        dict mapping dotted paths (e.g. ``"model.opt"``) to optimizer state
        dicts; the same object as ``optimizers_state_dicts`` when one is given.
    """
    if optimizers_state_dicts is None:
        optimizers_state_dicts = {}
    if visited is None:
        visited = set()
    # avoid cyclic reference
    if id(obj) in visited:
        return optimizers_state_dicts
    visited.add(id(obj))
    # Copy: merging child modules into obj.__dict__ itself would inject them
    # as plain instance attributes of the object being inspected.
    attrs = dict(obj.__dict__)
    if isinstance(obj, torch.nn.Module):
        # Direct children only; nn.Module keeps them in _modules, not __dict__.
        attrs.update(
            {
                k: attr
                for k, attr in obj.named_modules()
                if "." not in k and attr is not obj
            }
        )
    for name, attr in attrs.items():
        new_path = f"{path}.{name}" if path else name
        if isinstance(attr, torch.optim.Optimizer):
            optimizers_state_dicts[new_path] = attr.state_dict()
        elif hasattr(attr, "__dict__"):
            # Results are written into the shared accumulator in place.
            recursively_collect_optim_state_dict(
                attr, new_path, optimizers_state_dicts, visited
            )
    return optimizers_state_dicts


def recursively_load_optim_state_dict(obj, optimizers_state_dicts):
    """Load each state dict into the optimizer at its dotted attribute path under ``obj``.

    ``optimizers_state_dicts`` maps paths such as ``"model.opt"`` (the format
    produced by ``recursively_collect_optim_state_dict``) to state dicts.
    """
    for dotted_path in optimizers_state_dicts:
        target = obj
        for attr_name in dotted_path.split("."):
            target = getattr(target, attr_name)
        target.load_state_dict(optimizers_state_dicts[dotted_path])

#[todo] start

def config2list(input):
    """Normalise a list-like config value to a Python list.

    Lists are returned unchanged (same object) and tuples are converted to a
    list. Anything else is treated as a string such as "[a, b]" or "a,b":
    surrounding "[" / "]" characters are stripped, the rest is split on commas
    and each item is whitespace-stripped.
    """
    if isinstance(input, list):
        return input
    if isinstance(input, tuple):
        return list(input)
    inner = input.strip("[]")
    return [item.strip() for item in inner.split(",")]

def list_to_str(lst):
    """Render ``lst`` as "[a,b,c]": ``str()`` of each element, comma-joined without spaces."""
    return "[{}]".format(",".join(map(str, lst)))

# def sample_merged_data(train_env_name_list, dataset, max_workers=None):
#     """
#     从多个环境的数据集中并行采样，并合并成一个 batch。
#     """
#     if max_workers is None:
#         max_workers = len(train_env_name_list)
#
#     # 并行采样
#     with ThreadPoolExecutor(max_workers=max_workers) as executor:
#         future_to_env = {executor.submit(next, dataset[n]): n for n in train_env_name_list}
#         batches = []
#         for future in as_completed(future_to_env):
#             env_name = future_to_env[future]
#             try:
#                 batch = future.result()
#                 batches.append(batch)
#             except Exception as e:
#                 print(f"[WARN] Failed to sample from env {env_name}: {e}")
#
#     if not batches:
#         raise ValueError("No batches collected!")
#
#     # 合并 batch
#     keys = batches[0].keys()
#     merged = {k: np.concatenate([b[k] for b in batches], axis=0) for k in keys}
#     return merged

def sample_merged_data(train_env_name_list, dataset, max_workers=None, obs_filter=None):
    """Draw one batch per environment in parallel and concatenate them.

    The merged batch follows the order of ``train_env_name_list`` exactly, and
    every environment must be sampled successfully.

    Args:
        train_env_name_list: environment names to sample from, in output order.
        dataset: mapping from environment name to an iterator; ``next`` is
            called once per environment and must return a dict of arrays.
        max_workers: thread-pool size; defaults to one thread per environment.
        obs_filter: optional container of keys to keep. When None, every key of
            the first environment's batch is merged.

    Returns:
        dict mapping each kept key to ``np.concatenate`` of the per-environment
        arrays along axis 0.

    Raises:
        ValueError: if ``train_env_name_list`` is empty.
        RuntimeError: if sampling any environment fails.
    """
    if not train_env_name_list:
        raise ValueError("train_env_name_list is empty; nothing to sample.")
    if max_workers is None:
        max_workers = len(train_env_name_list)

    # Results keyed by environment name; re-ordered below.
    env_to_batch = {}
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_env = {executor.submit(next, dataset[n]): n for n in train_env_name_list}
        for future in as_completed(future_to_env):
            env_name = future_to_env[future]
            try:
                env_to_batch[env_name] = future.result()
            except Exception as e:
                # Every environment must contribute: fail fast with context.
                raise RuntimeError(f"采样环境 {env_name} 失败: {e}") from e

    # Restore the caller's order (as_completed yields in completion order).
    try:
        ordered_batches = [env_to_batch[env] for env in train_env_name_list]
    except KeyError as e:
        raise ValueError(f"环境 {e} 未采样到数据，请检查该环境的数据集是否正常！") from e

    if obs_filter is not None:
        valid_keys = [k for k in ordered_batches[0] if k in obs_filter]
    else:
        valid_keys = list(ordered_batches[0])

    return {k: np.concatenate([b[k] for b in ordered_batches], axis=0) for k in valid_keys}


def get_list_prefix(list):
    """Group names by the text before their first underscore.

    Example:
        ["walker_walk", "walker_stand", "cartpole_balance"]
        -> {"walker": ["walker_walk", "walker_stand"],
            "cartpole": ["cartpole_balance"]}

    A name without an underscore is its own prefix, and a warning is printed.
    Groups appear in first-seen order and keep the input order internally.
    """
    groups = {}
    for name in list:
        # partition() yields the whole name as head when there is no "_".
        head, sep, _ = name.partition("_")
        if not sep:
            print(f"警告：元素「{name}」无下划线，前缀按整体处理")
        groups.setdefault(head, []).append(name)
    return groups

def get_grouping_dict(logdir):
    """Map each first-level sub-directory of ``logdir`` to the directory names
    inside its ``train_eps`` folder.

    Args:
        logdir: root directory, as a ``pathlib.Path`` or any str/path-like.

    Returns:
        dict ``{group_name: [sub-directory names under group/train_eps]}``.
        Groups and names are sorted for a deterministic result. A group without
        a ``train_eps`` directory maps to ``[]``. If ``logdir`` is missing or
        not a directory, an error is printed and ``{}`` is returned.
    """
    # Accept plain strings too; everything below uses pathlib methods.
    logdir = Path(logdir)
    if not logdir.is_dir():
        print(f"错误：根目录 {logdir} 不存在或不是有效目录")
        return {}

    grouped_folder_names = {}
    for group_dir in sorted(p for p in logdir.iterdir() if p.is_dir()):
        train_eps_path = group_dir / "train_eps"
        if train_eps_path.is_dir():
            folder_names = sorted(
                item.name for item in train_eps_path.iterdir() if item.is_dir()
            )
        else:
            folder_names = []
        grouped_folder_names[group_dir.name] = folder_names
    return grouped_folder_names


def get_latest_ckpt(ckpt_dir, use_modify_time=True):
    """Return the path of the most recent ``*.pt`` file in ``ckpt_dir``.

    Args:
        ckpt_dir: directory containing the checkpoints (str or path-like).
        use_modify_time: rank by modification time (recommended); if False,
            rank by ``os.path.getctime`` (on Linux this is the metadata-change
            time, not the creation time).

    Returns:
        str path of the newest checkpoint. Among equally recent files, the
        first one yielded by the directory glob wins.

    Raises:
        FileNotFoundError: if ``ckpt_dir`` contains no ``*.pt`` files.
    """
    # Only regular files; a directory that happens to end in ".pt" is skipped.
    ckpt_file_list = [
        str(file_path) for file_path in Path(ckpt_dir).glob("*.pt") if file_path.is_file()
    ]
    if not ckpt_file_list:
        raise FileNotFoundError(f"No .pt checkpoint files found in {ckpt_dir}")
    get_time = os.path.getmtime if use_modify_time else os.path.getctime
    # max() returns the first maximal element, matching a stable descending sort.
    return max(ckpt_file_list, key=get_time)


import os


def has_group_dir_simple(root_path):
    """Return True if ``root_path`` has a direct sub-directory whose name starts with "group_".

    Only the "group_" prefix is checked; the rest of the name is not
    validated. If ``root_path`` does not exist, an error is printed and
    False is returned.
    """
    if not os.path.exists(root_path):
        print(f"错误：根路径 {root_path} 不存在")
        return False
    return any(
        entry.startswith("group_") and os.path.isdir(os.path.join(root_path, entry))
        for entry in os.listdir(root_path)
    )


def find_closest_pt_files_step(target_dir, target_step, threshold):
    """Find the ``<step>.pt`` entry in ``target_dir`` whose step is closest to ``target_step``.

    Entries ending in ".pt" whose stem parses as a number (int or float, e.g.
    "1000", "1e3", "250.5") are considered; all others are ignored. Only
    candidates with ``abs(step - target_step) < threshold`` qualify.

    Args:
        target_dir: directory to scan.
        target_step: target step value (numeric).
        threshold: strict upper bound on the step difference (positive).

    Returns:
        The closest step as an int (float-valued names are truncated), or
        None if nothing qualifies. Ties go to the first entry in
        ``os.listdir`` order.
    """
    best = None  # (diff, stem, parsed_step) of the closest candidate so far
    for filename in os.listdir(target_dir):
        if not filename.endswith(".pt"):
            continue
        stem = filename[:-3]
        try:
            step = float(stem)
        except ValueError:
            # Not a numeric name (e.g. "abc123.pt"): skip.
            continue
        diff = abs(step - target_step)
        if diff < threshold and (best is None or diff < best[0]):
            best = (diff, stem, step)

    if best is None:
        return None
    _, stem, step = best
    try:
        # Exact for plain integer names, even beyond float precision.
        return int(stem)
    except ValueError:
        # Names like "1e3" or "250.5" parse only as float.
        return int(step)

#[todo] end