import os
import random
import time
from dataclasses import dataclass

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tyro
from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter


from collections import OrderedDict
from collections import OrderedDict
from tqdm import tqdm

########### nano gpt modification ###########

import math
import inspect
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

@dataclass
class Args:
    """Command-line configuration (parsed with ``tyro``) for PPO with a transformer trunk.

    Fields under "to be filled in runtime" are overwritten by ``train``; ``train``
    also attaches ``max_episode_steps`` and ``model_params`` to the instance at runtime.
    Every field is declared exactly once so its default is unambiguous.
    """

    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = ""
    """the wandb's project name"""
    wandb_entity: str = None
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""
    save_model: bool = False
    """whether to save model into the `runs/{run_name}` folder"""
    upload_model: bool = False
    """whether to upload the saved model to huggingface"""
    hf_entity: str = ""
    """the user or org name of the model repository from the Hugging Face Hub"""

    # Algorithm specific arguments
    env_id: str = "HalfCheetah-v4"
    """the id of the environment"""
    total_timesteps: int = 1_000_000
    """total timesteps of the experiments"""
    learning_rate: float = 3e-4
    """the learning rate of the optimizer"""
    final_learning_rate: float = 0.0
    """the final learning rate of the optimizer for annealing"""
    num_envs: int = 1
    """the number of parallel game environments"""
    num_steps: int = 2048
    """the number of steps to run in each environment per policy rollout"""
    # Default is False: this was the effective value of the former duplicate declaration.
    anneal_lr: bool = False
    """Toggle learning rate annealing for policy and value networks"""
    gamma: float = 0.99
    """the discount factor gamma"""
    gae_lambda: float = 0.95
    """the lambda for the general advantage estimation"""
    num_minibatches: int = 32
    """the number of mini-batches"""
    update_epochs: int = 10
    """the K epochs to update the policy"""
    norm_adv: bool = True
    """Toggles advantages normalization"""
    clip_coef: float = 0.2
    """the surrogate clipping coefficient"""
    clip_vloss: bool = True
    """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
    ent_coef: float = 0.0
    """coefficient of the entropy"""
    vf_coef: float = 0.5
    """coefficient of the value function"""
    max_grad_norm: float = 0.5
    """the maximum norm for the gradient clipping"""
    target_kl: float = None
    """the target KL divergence threshold"""

    # to be filled in runtime
    batch_size: int = 0
    """the batch size (computed in runtime)"""
    minibatch_size: int = 0
    """the mini-batch size (computed in runtime)"""
    num_iterations: int = 0
    """the number of iterations (computed in runtime)"""


    ############## nano GPT-2 specific arguments ##############

    num_eval_envs: int = 100
    """the number of parallel environments used for evaluation"""
    num_eval_steps: int = 1000

    eval_freq: int = 10
    save_freq: int = 100

    anneal_ent_coef: bool = False
    final_ent_coef: float = 0.0

    head_dim: int = 256
    """the hidden width of the actor and critic MLP heads"""
    actor_head_layers: int = 2 # min layers = 2
    """the number of Linear layers in the actor head (input + intermediate + prediction)"""
    critic_head_layers: int = 2
    """the number of Linear layers in the critic head (input + intermediate + prediction)"""

    n_layer: int = 3
    """the number of transformer blocks"""
    n_head: int = 4
    """the number of attention heads (must divide n_embd)"""
    n_embd: int = 128
    """the transformer embedding width"""
    dropout: float = 0.0
    """the dropout probability inside the transformer"""
    bias: bool = True
    """whether transformer Linear/LayerNorm layers use a bias"""
    seq_len: int = 30
    """the context length (number of past observations) fed to the transformer"""
    # NOTE(review): the optimizer construction visible in train() always uses optim.Adam.
    optim: str = 'AdamW' #'Adam'
    use_gates: bool = False
    """replace residual additions in each block with GRU gates"""
    wo_ffn: bool = False
    norm_first: bool = False

    notes: str = ''
    tag: str = ''

    # NOTE(review): train() as shown picks its device from `cuda`, not from this field.
    device: str = 'cuda:0'

    pomdp: bool = False
    """restrict observations to a fixed subset of entries (partially observable task)"""
    save_best: bool = False
    log_gates_score: bool = False
    """also return mean GRU gate scores from the transformer"""



class PartialObservation(gym.ObservationWrapper):
    """Keep only selected entries of a flat (1-D) observation vector.

    ``make_env`` uses this to build partially observable (POMDP) variants of the
    supported tasks. Both the observation space and every observation are
    restricted to ``obs_indices``, in the given order.
    """

    def __init__(self, env: gym.Env, obs_indices: list):
        gym.ObservationWrapper.__init__(self, env)

        obsspace = env.observation_space
        self.obs_indices = obs_indices
        # Normalise once to an integer index array. A tuple used directly for NumPy
        # indexing would be read as a multi-dimensional index rather than a gather.
        self._index = np.asarray(obs_indices, dtype=np.intp)
        self.observation_space = gym.spaces.Box(
            low=np.asarray(obsspace.low)[self._index],
            high=np.asarray(obsspace.high)[self._index],
            dtype=np.float32,
        )

        # Same object as self.env; kept so any existing references keep working.
        self._env = env

    def observation(self, observation):
        return self._filter_observation(observation)

    def _filter_observation(self, observation):
        # Vectorised gather (raises IndexError for out-of-range indices, as the old
        # per-element loop did), cast to the dtype declared by observation_space
        # so the wrapper's output matches its own space.
        filtered = np.asarray(observation)[self._index]
        return filtered.astype(self.observation_space.dtype, copy=False)

# Observation entries kept by PartialObservation when ``pomdp`` is enabled, per env id.
_POMDP_OBS_INDICES = {
    'HalfCheetah-v4': (0, 1, 2, 3, 8, 9, 10, 11, 12),
    'Hopper-v4': (0, 1, 2, 3, 4),
    'Ant-v4': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12),
}


def make_env(env_id, pomdp, idx, capture_video, run_name, gamma):
    """Return a zero-argument factory that builds one wrapped training/eval env.

    Wrapper stack (inner to outer): optional video recording (only for ``idx == 0``
    when ``capture_video``), FlattenObservation, RecordEpisodeStatistics (so
    reported returns are the raw, un-normalised rewards), ClipAction, observation
    normalisation + clipping to [-10, 10], reward normalisation with discount
    ``gamma`` + clipping to [-10, 10], and finally PartialObservation when ``pomdp``.

    Raises:
        ValueError: if ``pomdp`` is set but ``env_id`` has no observation subset
            defined in ``_POMDP_OBS_INDICES``.
    """
    # Validate up front (and with a real exception rather than `assert`, which
    # is stripped under `python -O`).
    if pomdp and env_id not in _POMDP_OBS_INDICES:
        raise ValueError(
            f"pomdp=True is not supported for env_id={env_id!r}; "
            f"supported: {sorted(_POMDP_OBS_INDICES)}"
        )

    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id)
        env = gym.wrappers.FlattenObservation(env)  # deal with dm_control's Dict observation space
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = gym.wrappers.ClipAction(env)
        env = gym.wrappers.NormalizeObservation(env)
        env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))
        env = gym.wrappers.NormalizeReward(env, gamma=gamma)
        env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))

        if pomdp:
            env = PartialObservation(env, list(_POMDP_OBS_INDICES[env_id]))
        return env

    return thunk


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonally initialise ``layer.weight`` with gain ``std`` and fill its bias with ``bias_const``.

    Layers without a bias (e.g. ``nn.Linear(..., bias=False)``) are supported:
    the bias step is simply skipped. Returns the same layer so it can be used inline.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    if getattr(layer, "bias", None) is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer


#######################################################################################################################################
#######################################################################################################################################

class GRUGate(nn.Module):
    """
    Overview:
        GRU-style gating unit used in place of residual additions (as in GTrXL).
        Inspired by https://github.com/dhruvramani/Transformers-RL/blob/master/layers.py
    """

    def __init__(self, input_dim: int, bg: float = 0.0, log_gates_score: bool = False):
        """
        Arguments:
            input_dim {int} -- Feature size of both inputs and of the output.
            bg {float} -- Initial gate bias. It is subtracted inside the update gate, so bg > 0 pushes the
            gate towards passing the first input through unchanged, i.e. the unit starts close to an
            identity/skip connection. This can greatly improve learning speed and stability. (default: {0.0})
            log_gates_score {bool} -- If True, forward() additionally returns the mean of (1 - z), the
            average fraction of the first input that is kept. (default: {False})
        """
        super().__init__()
        projection_names = ("Wr", "Ur", "Wz", "Uz", "Wg", "Ug")
        # Square, bias-free projections; created in a fixed order so seeded init is reproducible.
        for proj_name in projection_names:
            setattr(self, proj_name, nn.Linear(input_dim, input_dim, bias=False))
        self.bg = nn.Parameter(torch.full([input_dim], bg))  # learnable gate bias
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        for proj_name in projection_names:
            nn.init.xavier_uniform_(getattr(self, proj_name).weight)

        self.log_gates_score = log_gates_score

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """
        Arguments:
            x {torch.tensor} -- Stream being carried forward (the "skip" input)
            y {torch.tensor} -- Candidate update (e.g. the sub-layer output)
        Returns:
            {torch.tensor} -- Gated combination (1 - z) * x + z * h, or a tuple
            (output, mean(1 - z)) when log_gates_score is enabled
        """
        reset_gate = self.sigmoid(self.Wr(y) + self.Ur(x))
        update_gate = self.sigmoid(self.Wz(y) + self.Uz(x) - self.bg)
        candidate = self.tanh(self.Wg(y) + self.Ug(torch.mul(reset_gate, x)))

        keep = 1 - update_gate
        gated = torch.mul(keep, x) + torch.mul(update_gate, candidate)

        if not self.log_gates_score:
            return gated
        return gated, keep.mean()

#######################################################################################################################################
######################################################## nano gpt modification ########################################################

class LayerNorm(nn.Module):
    """Layer normalisation with an optional additive bias (eps = 1e-5).

    Wraps ``F.layer_norm`` directly so the bias can be omitted entirely, which
    older ``nn.LayerNorm`` versions do not allow.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        normalized_shape = self.weight.shape
        return F.layer_norm(input, normalized_shape, weight=self.weight, bias=self.bias, eps=1e-5)


class CausalSelfAttention(nn.Module):
    """Multi-head causal (masked) self-attention, as in nanoGPT.

    Uses ``F.scaled_dot_product_attention`` when PyTorch provides it (>= 2.0);
    otherwise falls back to an explicit softmax(Q K^T / sqrt(head_dim)) V with a
    lower-triangular mask stored in the ``bias`` buffer.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Fused Q/K/V projection for all heads, then the output projection.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # Dropout on attention weights (manual path) and on the block output.
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # NOTE(review): the mask is sized by config.num_steps (rollout length), not by the
            # context length; forward() on this path fails for sequences longer than num_steps.
            size = config.num_steps
            causal_mask = torch.tril(torch.ones(size, size)).view(1, 1, size, size)
            self.register_buffer("bias", causal_mask)

    def forward(self, x):
        B, T, C = x.size()  # batch, sequence length, embedding width (n_embd)
        head_dim = C // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(B, T, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2))

        if self.flash:
            drop_p = self.dropout if self.training else 0
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=drop_p, is_causal=True)
        else:
            scale = 1.0 / math.sqrt(k.size(-1))
            scores = (q @ k.transpose(-2, -1)) * scale  # (B, n_head, T, T)
            future = self.bias[:, :, :T, :T] == 0
            scores = scores.masked_fill(future, float('-inf'))
            weights = self.attn_dropout(F.softmax(scores, dim=-1))
            y = weights @ v  # (B, n_head, T, head_dim)

        merged = y.transpose(1, 2).contiguous().view(B, T, C)  # heads side by side again
        return self.resid_dropout(self.c_proj(merged))



class MLP(nn.Module):
    """Position-wise feed-forward net: Linear(n_embd -> 4*n_embd) -> GELU -> Linear(back to n_embd) -> Dropout."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        expanded = 4 * width
        self.c_fc = nn.Linear(width, expanded, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(expanded, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        hidden = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))


def _residual_add(x, y):
    """Plain residual connection used when GRU gating is disabled.

    A module-level function (rather than a lambda) keeps Block picklable.
    """
    return x + y


class Block(nn.Module):
    """Pre-LayerNorm transformer block.

    Computes ``x = skip_fn_1(x, attn(ln_1(x)))`` then ``x = skip_fn_2(x, mlp(ln_2(x)))``,
    where the skip functions are GRU gates when ``config.use_gates`` and plain
    residual additions otherwise.

    When ``config.log_gates_score`` is set, forward() returns ``(x, gate_scores)``.
    ``gate_scores`` holds the mean gate values of both gates, or is an empty dict
    when gates are disabled, so callers can always unpack a tuple.
    """

    def __init__(self, config):
        super().__init__()
        self.log_gates_score = config.log_gates_score
        self.use_gates = config.use_gates
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

        if config.use_gates:
            self.skip_fn_1 = GRUGate(config.n_embd, 2.0, self.log_gates_score)
            self.skip_fn_2 = GRUGate(config.n_embd, 2.0, self.log_gates_score)
        else:
            self.skip_fn_1 = _residual_add
            self.skip_fn_2 = _residual_add

    def forward(self, x):
        if self.log_gates_score and self.use_gates:
            # GRUGate built with log_gates_score=True returns (output, mean(1 - z)).
            x, gate_score_att = self.skip_fn_1(x, self.attn(self.ln_1(x)))
            x, gate_score_mlp = self.skip_fn_2(x, self.mlp(self.ln_2(x)))
            gate_scores = {
                'gate_score_att': gate_score_att,
                'gate_score_mlp': gate_score_mlp,
            }
            return x, gate_scores

        x = self.skip_fn_1(x, self.attn(self.ln_1(x)))
        x = self.skip_fn_2(x, self.mlp(self.ln_2(x)))

        if self.log_gates_score:
            # Plain residual adds return a single tensor, so there is no gate score
            # to report; keep the (x, dict) return shape the caller unpacks.
            return x, {}
        return x

class GPT(nn.Module):
    """nanoGPT-style causal transformer trunk over already-embedded sequences.

    Unlike nanoGPT there is no token embedding or LM head. Callers pass inputs
    whose last dimension is ``config.n_embd``; a learned absolute position
    embedding is added, the sequence goes through ``config.n_layer`` Blocks and
    a final LayerNorm, and the hidden states are returned.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Learned absolute positions; sequences longer than max_episode_steps cannot be embedded.
        self.pos_embedding = nn.Embedding(config.max_episode_steps, config.n_embd)

        self.transformer_layers = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, bias=config.bias)
        self.drop = nn.Dropout(config.dropout)

        # Init all weights. Note this visits every nn.Linear in the tree, including the
        # projections inside GRUGate, so it overrides their xavier initialisation.
        self.apply(self._init_weights)
        # Special scaled init for the residual output projections, per the GPT-2 paper.
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

    def get_num_params(self, non_embedding=True):
        """Return the number of parameters in the model.

        With ``non_embedding=True`` (default) the position-embedding table is
        excluded from the count.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.pos_embedding.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        """Run the trunk over an embedded sequence.

        Arguments:
            x: tensor whose dim 1 is the sequence length and whose last dim is
               n_embd (callers pass (batch, seq_len, n_embd)); seq_len must not
               exceed config.max_episode_steps.

        Returns:
            The ln_f-normalised hidden states (same shape as x). When
            config.log_gates_score is set, returns (hidden, gate_scores) where
            gate_scores maps f'{name}_layer{i}' to a Python float.
        """
        t = x.shape[1]
        pos = torch.arange(0, t, dtype=torch.long, device=x.device)  # shape (t)
        pos_emb = self.pos_embedding(pos)  # (t, n_embd), broadcast over the batch

        gating_score_dict = {}

        x = self.drop(x + pos_emb)
        for i, block in enumerate(self.transformer_layers):
            if self.config.log_gates_score:
                x, gate_scores = block(x)
                for k, v in gate_scores.items():
                    gating_score_dict[f'{k}_layer{i}'] = v.item()
            else:
                x = block(x)
        x = self.ln_f(x)

        if self.config.log_gates_score:
            return x, gating_score_dict
        return x


##########################################################################################
##########################################################################################

class Agent(nn.Module):
    """PPO actor-critic whose policy and value heads share a causal transformer trunk.

    Observation sequences are embedded by a linear ``encoder`` into
    ``args.n_embd`` features and passed through ``GPT``. The hidden state at the
    last sequence position feeds both an MLP critic and an MLP actor-mean head.
    The Gaussian policy uses a state-independent learned log standard deviation.
    """

    def __init__(self, envs, args):
        super().__init__()

        self.config = args
        self.obs_shape = envs.single_observation_space.shape
        action_dim = int(np.prod(envs.single_action_space.shape))
        self.encoder = layer_init(nn.Linear(self.obs_shape[0], args.n_embd))

        # Heads: n_embd -> head_dim (Tanh) -> [head_dim -> head_dim (Tanh)] * (layers - 2) -> out.
        self.critic = self._build_head(args.n_embd, args.head_dim, args.critic_head_layers, 1, out_std=1)
        self.actor_mean = self._build_head(
            args.n_embd, args.head_dim, args.actor_head_layers, action_dim, out_std=0.01
        )

        self.actor_logstd = nn.Parameter(torch.zeros(1, action_dim))

        self.transformer = GPT(args)

    @staticmethod
    def _build_head(n_embd, head_dim, num_layers, out_dim, out_std):
        """Build an MLP head with Tanh activations and ``num_layers`` Linear layers (min 2).

        Layers are created in the same order as before (prediction layer first) so
        that seeded initialisation stays reproducible; module names are unchanged so
        state_dict keys stay compatible.
        """
        prediction = [('prediction_layer', layer_init(nn.Linear(head_dim, out_dim), std=out_std))]
        post_transformer = [
            ('post_transformer_layer', layer_init(nn.Linear(n_embd, head_dim))),
            ('post_activation', nn.Tanh()),
        ]
        intermediate = []
        for i in range(num_layers - 2):
            intermediate += [
                (f'head_layer_{i}', layer_init(nn.Linear(head_dim, head_dim))),
                (f'activation_{i}', nn.Tanh()),
            ]
        return nn.Sequential(OrderedDict(post_transformer + intermediate + prediction))

    def _encode(self, x):
        """Embed an observation sequence, run the transformer, and keep the last position.

        Returns:
            (hidden, gating_score_dict): hidden is the transformer output at the last
            sequence position; gating_score_dict is None unless config.log_gates_score.
        """
        if len(self.obs_shape) > 1:
            # NOTE(review): image branch kept as-is; with a Linear encoder sized by
            # obs_shape[0] this path looks unusable as written — confirm before relying on it.
            x = self.encoder(x.permute((0, 3, 1, 2)) / 255.0)
        else:
            x = self.encoder(x)

        gating_score_dict = None
        if self.config.log_gates_score:
            x, gating_score_dict = self.transformer(x)
        else:
            x = self.transformer(x)

        return x[:, -1], gating_score_dict

    def _action_dist(self, action_mean):
        """Diagonal Gaussian policy with the shared learned log-std."""
        action_std = torch.exp(self.actor_logstd.expand_as(action_mean))
        return Normal(action_mean, action_std)

    def get_value(self, x):
        """Return the critic's value estimate for each sequence, flattened to 1-D."""
        hidden, _ = self._encode(x)
        return self.critic(hidden).flatten()

    def get_action(self, x, deterministic=False):
        """Return the action mean if ``deterministic``, otherwise a sample from the policy."""
        hidden, _ = self._encode(x)
        action_mean = self.actor_mean(hidden)
        if deterministic:
            return action_mean
        return self._action_dist(action_mean).sample()

    def get_action_and_value(self, x, action=None):
        """Return (action, log_prob, entropy, value[, gating_score_dict]).

        ``action`` is sampled when None. log_prob and entropy are summed over
        action dimensions. value is the raw critic output (not flattened).
        gating_score_dict is appended only when config.log_gates_score.
        """
        hidden, gating_score_dict = self._encode(x)

        action_mean = self.actor_mean(hidden)
        probs = self._action_dist(action_mean)
        if action is None:
            action = probs.sample()

        log_prob = probs.log_prob(action).sum(1)
        entropy = probs.entropy().sum(1)
        value = self.critic(hidden)

        if self.config.log_gates_score:
            return action, log_prob, entropy, value, gating_score_dict
        return action, log_prob, entropy, value


##########################################################################################
##########################################################################################


def get_obs_seq_with_context(
    b_dones: torch.Tensor,
    b_obs:   torch.Tensor,
    mb_inds: torch.Tensor,
    seq_len: int
):
    """Gather, for each minibatch index, the window of ``seq_len`` observations ending at it.

    b_dones : tensor of shape (batch_size,);
              b_dones[t] = 1 if the episode ended at step t, otherwise 0
    b_obs   : tensor of shape (batch_size, obs_dim);
              b_obs[t] = environment observation at step t
    mb_inds : index tensor (shape=(m,)) from which the minibatch is formed
    seq_len : required sequence (context) length

    Returns (mb_obs_seq, new_mb_inds):

    mb_obs_seq  : tensor of shape (k, seq_len, obs_dim), k <= len(mb_inds),
                  containing only valid sequences
    new_mb_inds : indices (shape=(k,)) from the original mb_inds that
                  correspond to the valid sequences

    A window [i - seq_len + 1, ..., i] is valid when it lies entirely inside
    [0, batch_size) and has no done flag at any position (including the last),
    i.e. it does not cross an episode boundary.

    NOTE(review): consecutive indices are treated as consecutive time steps of
    one episode stream; if several envs are flattened into one buffer this only
    holds when each env's steps stay contiguous — confirm against the caller.
    """
    batch_size = b_dones.shape[0]

    # Relative offsets -(seq_len-1) .. 0, e.g. seq_len = 4 -> [-3, -2, -1, 0].
    offsets = torch.arange(1 - seq_len, 1, device=b_dones.device)
    windows = mb_inds[:, None] + offsets[None, :]  # (m, seq_len)

    in_range = ((windows >= 0) & (windows < batch_size)).all(dim=1)

    # Clamp so gathering is always legal; rows that needed clamping are dropped below.
    safe_windows = windows.clamp(0, batch_size - 1)
    no_done = (b_dones[safe_windows] == 0).all(dim=1)

    keep = in_range & no_done
    return b_obs[safe_windows][keep], mb_inds[keep]



def evaluate(args, agent, envs, device, deterministic = True):
    """Estimate the agent's mean episode return on a vector of evaluation envs.

    The envs are reset with seeds ``range(args.num_eval_envs)`` and stepped until
    every sub-env has finished its first episode. Only that first episode of each
    sub-env is recorded: sub-envs that finish early auto-reset and keep running,
    but their later (shorter) episodes must not be counted, or the estimate would
    be biased toward short episodes. The mean of the recorded returns is returned.

    At every step the agent receives the last ``args.seq_len`` observations
    (fewer at the start) via ``agent.get_action(obs_seq, deterministic=...)``.

    Arguments:
        args: provides ``num_eval_envs``, ``max_episode_steps`` and ``seq_len``.
        agent: module with ``get_action``; its train/eval mode is restored on exit.
        envs: vector env with ``args.num_eval_envs`` sub-envs that reports
            finished episodes under ``infos["final_info"]`` (per-env, None if not done).
        device: torch device for the observation buffer.
        deterministic: forwarded to ``agent.get_action``.

    Returns:
        Mean of the first-episode returns (``info["episode"]["r"]``) over sub-envs.
    """
    print("Evaluating")
    num_envs = args.num_eval_envs
    obs_shape = envs.single_observation_space.shape
    # Return of the first finished episode of each sub-env; None while still running.
    first_returns = [None] * num_envs
    was_training = agent.training

    # NOTE(review): eval envs carry their own observation-normalisation statistics
    # (make_env wraps each one); nothing here syncs them with the training envs.
    try:
        with torch.no_grad():
            agent.eval()

            # Assumes every first episode ends within max_episode_steps steps (time
            # limit); otherwise observations[step] raises IndexError.
            observations = torch.zeros((args.max_episode_steps, num_envs) + obs_shape).to(device)

            obs, _ = envs.reset(seed = range(num_envs))
            obs = torch.tensor(obs).to(device)

            step = 0
            while any(ret is None for ret in first_returns):
                observations[step] = obs
                low_ind = max(0, step - args.seq_len + 1)
                # (context, num_envs, obs_dim) -> (num_envs, context, obs_dim)
                context = observations[low_ind:step + 1].permute((1, 0, 2))
                action = agent.get_action(context, deterministic=deterministic)

                obs, _, _, _, infos = envs.step(action.cpu().numpy())
                obs = torch.tensor(obs).to(device)

                if "final_info" in infos:
                    for env_idx, info in enumerate(infos["final_info"]):
                        if info is None or "episode" not in info:
                            continue
                        if first_returns[env_idx] is None:
                            first_returns[env_idx] = info["episode"]["r"]

                step += 1
    finally:
        # Leave the agent in whatever mode the caller had it in.
        agent.train(was_training)

    return np.array(first_returns).mean()



def _context_start(last_done_step, last_index, seq_len):
    """Return the first buffer index of a transformer context window.

    The window ends at buffer index ``last_index`` (inclusive), holds at most
    ``seq_len`` entries, and never starts before ``last_done_step`` (the
    buffer index of the current episode's first observation), so a context
    never mixes observations from two episodes.
    """
    return max(last_done_step, last_index - seq_len + 1)


def _log_grad_norms(writer, agent, global_step):
    """Log the L2 norm of each parameter's gradient under ``grads/<param name>``.

    Parameters with no gradient (``param.grad is None``, e.g. not used in the
    last backward pass) are skipped instead of raising.
    """
    for name, param in agent.named_parameters():
        if param.grad is None:
            continue
        writer.add_scalar(f"grads/{name}", torch.linalg.norm(param.grad).item(), global_step)


def train(args = None):
    """Train a PPO agent whose policy/value networks consume observation sequences.

    When ``args`` is None, arguments are parsed from the command line with tyro.
    Each iteration:
    - optionally saves a checkpoint and runs ``evaluate`` on the eval envs;
    - collects ``args.num_steps`` transitions from a single training env,
      feeding the agent up to ``args.seq_len`` most recent observations of the
      current episode;
    - computes GAE advantages;
    - optimizes the clipped PPO objective on minibatches of context sequences
      built by ``get_obs_seq_with_context``.

    Metrics go to TensorBoard under ``runs/{run_name}``.

    Raises:
        ValueError: if ``args.num_envs != 1`` (only a single training env is supported).
    """
    if args is None:
        args = tyro.cli(Args)

    # The rollout context bookkeeping keeps a single `last_done_step`, and the
    # flattened (num_steps * num_envs) buffers are presumably read by
    # `get_obs_seq_with_context` as one contiguous time series. Both assume a
    # single training env, so reject other values before allocating anything.
    if args.num_envs != 1:
        raise ValueError(f"only num_envs == 1 is supported, got num_envs={args.num_envs}")

    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    args.num_iterations = args.total_timesteps // args.batch_size
    run_name = f"ppo_trans_encoder_{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    # NOTE(review): `and False` hard-disables Weights & Biases tracking even when
    # --track is passed; drop it to re-enable wandb.
    if args.track and False:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
            tags = ["ppo_mlp"]
        )
    writer = SummaryWriter(f"runs/{run_name}")

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, args.pomdp, i, False, run_name, args.gamma) for i in range(args.num_envs)]
    )
    eval_envs = gym.vector.AsyncVectorEnv(
        [make_env(args.env_id, args.pomdp, i, args.capture_video, run_name, args.gamma) for i in range(args.num_eval_envs)]
    )

    assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

    # Throwaway env used only to read the episode time limit; close it at once
    # so it does not hold resources for the whole run.
    dummy_env = gym.make(args.env_id)
    max_episode_steps = dummy_env._max_episode_steps
    dummy_env.close()
    args.max_episode_steps = max_episode_steps
    args.seq_len = min(args.seq_len, max_episode_steps)

    agent = Agent(envs, args).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

    model_params = sum(p.numel() for p in agent.parameters())
    args.model_params = model_params
    print(f'MODEL PARAMS: {model_params}')

    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # ALGO Logic: Storage setup
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()
    next_obs, _ = envs.reset(seed=args.seed)
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.zeros(args.num_envs).to(device)

    best_return = -999999

    for iteration in tqdm(range(1, args.num_iterations + 1), desc="Training Progress"):

        if iteration % args.save_freq == 0 and args.save_model:
            model_path = f"runs/{run_name}/{args.exp_name}_timestep_{global_step}.pt"
            torch.save(agent.state_dict(), model_path)
            print(f"model saved to {model_path}")

        if iteration % args.eval_freq == 0:
            mean_eval_return = evaluate(args, agent, eval_envs, device, deterministic = False)

            writer.add_scalar("eval/episodic_return", mean_eval_return, global_step)
            print(f"Eval: global_step={global_step}, mean_eval_return={mean_eval_return}")

            if mean_eval_return > best_return and args.save_best:
                best_return = mean_eval_return

                model_path = f"runs/{run_name}/{args.exp_name}_best.pt"
                torch.save(agent.state_dict(), model_path)
                print(f"BEST model saved to {model_path}")

        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            frac = 1.0 - (iteration - 1.0) / args.num_iterations
            lrnow = args.final_learning_rate + frac * (args.learning_rate - args.final_learning_rate)
            optimizer.param_groups[0]["lr"] = lrnow

        # Buffer index of the current episode's first observation in this rollout.
        # Every rollout refills `obs` from index 0, so an episode carried over from
        # the previous iteration is treated as starting at index 0.
        last_done_step = 0

        for step in range(0, args.num_steps):
            global_step += args.num_envs
            obs[step] = next_obs
            dones[step] = next_done

            # Context: up to seq_len most recent observations of the current episode.
            low_ind = _context_start(last_done_step, step, args.seq_len)
            context = obs[low_ind:step + 1].permute((1, 0, 2))

            # ALGO LOGIC: action logic
            with torch.no_grad():
                if args.log_gates_score:
                    action, logprob, _, value, rollout_gating_score_dict = agent.get_action_and_value(context)
                else:
                    action, logprob, _, value = agent.get_action_and_value(context)
                values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
            next_done = np.logical_or(terminations, truncations)
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)

            if next_done.item():
                # The vector env auto-resets, so next_obs (stored at step + 1)
                # is the first observation of a new episode.
                last_done_step = step + 1
                print(f'done step: {step}')

            if "final_info" in infos:
                for info in infos["final_info"]:
                    if info and "episode" in info:
                        print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                        writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                        writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)

            if args.log_gates_score:
                for key, value in rollout_gating_score_dict.items():
                    writer.add_scalar("charts/" + key, value, global_step)

        # bootstrap value if not done
        with torch.no_grad():
            last_step = args.num_steps - 1
            # next_obs is not stored in `obs`; treat it as sitting at buffer index
            # last_step + 1 so its context is the episode's most recent observations
            # (including the episode's first one) followed by next_obs, at most
            # seq_len long.
            ctx_start = _context_start(last_done_step, last_step + 1, args.seq_len)
            next_context = torch.cat((obs[ctx_start:last_step + 1], next_obs.unsqueeze(0)), dim=0)
            next_value = agent.get_value(next_context.permute((1, 0, 2)))

            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values

        # flatten the batch
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)
        b_dones = dones.reshape(-1)

        agent.train()

        # Optimizing the policy and value network
        clipfracs = []
        for epoch in range(args.update_epochs):
            b_inds = torch.randperm(args.batch_size)

            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]

                # Build per-sample observation contexts; indices whose context would
                # cross an episode boundary are dropped, so mb_inds may shrink.
                mb_seq_obs, mb_inds = get_obs_seq_with_context(b_dones, b_obs, mb_inds.to(device), args.seq_len)

                if args.log_gates_score:
                    _, newlogprob, entropy, newvalue, optimization_gating_score_dict = agent.get_action_and_value(mb_seq_obs, b_actions[mb_inds])
                else:
                    _, newlogprob, entropy, newvalue = agent.get_action_and_value(mb_seq_obs, b_actions[mb_inds])

                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

                mb_advantages = b_advantages[mb_inds]
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                if args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -args.clip_coef,
                        args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

                # Per-parameter (post-clipping) gradient norms.
                _log_grad_norms(writer, agent, global_step)

            if args.target_kl is not None and approx_kl > args.target_kl:
                break

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        if args.log_gates_score:
            for key, value in optimization_gating_score_dict.items():
                writer.add_scalar("optimization/" + key, value, global_step)

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
        writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
        writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
        writer.add_scalar("losses/explained_variance", explained_var, global_step)
        print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

    # Close both vector envs: the async eval envs run in subprocesses.
    envs.close()
    eval_envs.close()
    writer.close()

if __name__ == "__main__":
    # Script entry point: train() parses its own CLI arguments when given None.
    train(args=None)