import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import time
import copy
from tqdm import tqdm, trange
from torch.optim.lr_scheduler import CosineAnnealingLR
import wandb
import gym
from optim.lamb import Lamb
from d4rl import infos
from models.discriminator import ContrastiveDiscriminator
from pytorch_metric_learning.losses import NTXentLoss
from torch.utils.data import TensorDataset, DataLoader


class EMA:
    '''
        Exponential moving average (EMA) of model parameters.

        Each call to ``update_model_average`` replaces every parameter of the
        averaged model with ``beta * old + (1 - beta) * new``.
    '''
    def __init__(self, beta):
        """
        Args:
            beta (float): decay factor; values closer to 1 make the average
                follow the current weights more slowly.
        """
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        """Blend ``current_model``'s parameters into ``ma_model`` in place.

        Parameters are paired positionally with ``zip``, so both models are
        expected to share the same architecture. Buffers (for example
        normalization running statistics) are not updated.
        """
        # The average is pure bookkeeping and must not enter any autograd graph.
        with torch.no_grad():
            for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
                ma_params.data = self.update_average(ma_params.data, current_params.data)

    def update_average(self, old, new):
        """Return ``beta * old + (1 - beta) * new``; ``new`` when there is no previous value."""
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


def expectile_loss(pred: torch.Tensor, target: torch.Tensor, expectile: float = 0.7) -> torch.Tensor:
    """
    Compute the asymmetric expectile loss between predicted and target values.

    With residual ``u = target - pred``, each squared residual is weighted by
    ``expectile`` when ``u > 0`` and by ``1 - expectile`` otherwise, and the
    weighted values are averaged.

    Args:
        pred (torch.Tensor): Predicted values (e.g., value function V(s)).
        target (torch.Tensor): Target values (e.g., Q(s, a)); broadcastable with ``pred``.
        expectile (float): Expectile parameter τ (typically between 0.5 and 1.0).

    Returns:
        torch.Tensor: Scalar loss with the dtype of ``target - pred``.
    """
    diff = target - pred
    # Build the weights with diff's dtype and device. Python-scalar branches in
    # torch.where produce default-dtype (float32) weights, which would silently
    # promote half/bfloat16 losses to float32.
    weight = torch.where(
        diff > 0,
        torch.full_like(diff, expectile),
        torch.full_like(diff, 1 - expectile),
    )
    return (weight * diff.pow(2)).mean()

class Trainer:

    def __init__(self,
                 model,
                 critic,
                 batch_size,
                 tau,
                 discount,
                 dataloader,
                 loss_fn,
                 eval_fns=None,
                 eta=1.0,
                 lr=3e-4,
                 weight_decay=1e-4,
                 lr_decay=False,
                 max_iters=100000,
                 grad_norm=1.0,
                 scale=1.0,
                 k_rewards=False,
                 device='cuda',
                 with_att_entropy_loss=False,
                 att_entropy_loss_weight=0.1,
                 att_entropy_loss_weight_decay=10,
                 eval_every_n_epoch=15,
                 state_mean=None,
                 state_std=None,
                 env_name=None,
                 q_loss_mean=None,
                 warmup_steps=10000,
                 q_scale=1.0,
                 q_min=0,
                 num_steps=100,
                 actor_optimizer='lamb',
                 return_quantiles=None,
                 contrastive_frequency=1,
                 beta=0.1,
                 bin_ranges=None,
                 adv_ranges=None,
                 post_train=False,
                 top_return_buffer=None,
                 contrastive_data=None):
        """
        Set up the actor optimiser, the warm-up LR schedule and the training configuration.

        Selected arguments (most others are stored as same-named attributes):
            model: Actor network whose parameters are optimised.
            critic: Accepted for interface compatibility; not used by this class.
            batch_size: Chunk size used when batching embeddings in the
                post-training contrastive loss.
            tau: NT-Xent temperature for the return/advantage contrastive losses.
            dataloader: Iterable of training batches consumed by ``train_iteration``.
            lr, weight_decay: Optimiser hyper-parameters (not stored).
            grad_norm: Max gradient norm for clipping; values <= 0 disable clipping.
            warmup_steps: Length of the linear LR warm-up; values < 1 disable it.
            actor_optimizer: ``"lamb"`` selects Lamb, anything else selects Adam.
            beta: Weight of the contrastive term in the actor loss.
            contrastive_data: Contrastive-learning configuration dict. When
                omitted, each instance gets its own
                ``{"contrastive_type": "no_contrast"}``.
        """
        # A fresh dict per instance. A shared ``{}`` default is mutable state
        # leaking across trainers, and it lacks the "contrastive_type" key
        # that train_step reads.
        if contrastive_data is None:
            contrastive_data = {"contrastive_type": "no_contrast"}

        self.device = device
        self.with_att_entropy_loss = with_att_entropy_loss
        self.init_att_entropy_loss_weight = att_entropy_loss_weight
        self.att_loss_weight = att_entropy_loss_weight
        self.max_iters = max_iters
        self.q_scale = q_scale
        self.q_min = q_min

        self.num_steps = num_steps
        self.return_quantiles = return_quantiles
        self.beta = beta
        self.post_train = post_train
        self.top_return_buffer = top_return_buffer
        self.contrastive_data = contrastive_data
        self.iter = 0
        self.contrastive_frequency = contrastive_frequency

        self.bin_ranges = bin_ranges
        self.adv_ranges = adv_ranges

        self.att_entropy_loss_weight_decay = att_entropy_loss_weight_decay
        self.eval_every_n_epoch = eval_every_n_epoch
        self.state_mean = state_mean
        self.state_std = state_std
        self.env_name = env_name
        self.q_loss_mean = q_loss_mean

        self.actor = model
        if actor_optimizer == "lamb":
            self.actor_optimizer = Lamb(self.actor.parameters(), lr=lr, weight_decay=weight_decay)
        else:
            self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr, weight_decay=weight_decay)

        self.warmup_steps = warmup_steps

        self.best_model_state_dict = None
        self.best_critic_state_dict = None

        # Linear warm-up: the LR factor grows as (step + 1) / warmup_steps up to 1.
        # max(..., 1) prevents a ZeroDivisionError, which LambdaLR would raise
        # at construction time when warmup_steps == 0.
        self.actor_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.actor_optimizer, lambda s: min((s + 1) / max(self.warmup_steps, 1), 1))

        self.batch_size = batch_size
        self.dataloader = dataloader
        self.data_iter = iter(dataloader)
        self.loss_fn = loss_fn
        self.eval_fns = [] if eval_fns is None else eval_fns
        self.diagnostics = dict()
        self.tau = tau
        self.discount = discount
        self.grad_norm = grad_norm
        self.eta = eta
        self.lr_decay = lr_decay
        self.scale = scale
        self.k_rewards = k_rewards

        self.start_time = time.time()
        self.step = 0

        # Initialize tracking variables for early stopping
        self.patience_counter = 0
        self.should_stop_critic = False  # Initialize the flag


    def simclr_loss_return_based(self,sample_state, sample_actions, sample_rtg, sample_adv,num_samples_simclr=20,bin_edges=None,contrastive_type="return"):
        """
        Supervised NT-Xent loss where positives share a return (or advantage) bin.

        For every trajectory, ``num_samples_simclr`` timesteps are drawn
        uniformly with replacement. Each drawn (state, action, rtg) is encoded
        with ``self.actor.get_latent`` and labelled with
        ``np.digitize(value, bin_edges)``, where ``value`` is the rtg or the
        advantage at that timestep. The latents are shuffled and split into
        chunks of ``B``. The loss is the mean, over chunks, of the NT-Xent loss
        (temperature ``self.tau``).

        Args:
            sample_state: states indexed as [batch, time, ...]; B and N are
                taken from its first two dimensions.
            sample_actions: actions with the same leading [B, N] dimensions.
            sample_rtg: returns-to-go with the same leading dimensions;
                forwarded to ``get_latent``.
            sample_adv: advantages with the same leading dimensions.
            num_samples_simclr: timesteps sampled per trajectory.
            bin_edges: monotonic bin edges passed to ``np.digitize``.
            contrastive_type: ``"return"`` or ``"advantage"``, selecting the
                value that is binned.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if ``contrastive_type`` is not supported (previously this
                surfaced as an UnboundLocalError on ``labels``).
        """
        if contrastive_type == "return":
            bin_source = sample_rtg
        elif contrastive_type == "advantage":
            bin_source = sample_adv
        else:
            raise ValueError(f"Unsupported contrastive type: {contrastive_type}")

        B, N = sample_state.shape[:2]
        num_positives = num_samples_simclr

        # (B, num_positives) timestep indices, sampled with replacement per row.
        ixs = np.stack([np.random.choice(np.arange(N), size=num_positives) for _ in range(B)])
        torch_ixs = torch.from_numpy(ixs).to(device=self.device)
        rows = torch.arange(B, device=torch_ixs.device).unsqueeze(1)

        # Advanced indexing: result[b, k] = x[b, ixs[b, k]] (same as a per-row gather).
        pos_states = sample_state[rows, torch_ixs]
        pos_actions = sample_actions[rows, torch_ixs]
        pos_rtg = sample_rtg[rows, torch_ixs]
        pos_values = bin_source[rows, torch_ixs]

        pos_latents = self.actor.get_latent(pos_states, pos_actions, pos_rtg)
        flat_latents = pos_latents.reshape(-1, pos_latents.shape[-1])

        # Detach before numpy conversion so this also works on grad-tracking inputs.
        labels = torch.as_tensor(
            np.digitize(pos_values.detach().cpu().numpy(), bin_edges).flatten(),
            dtype=torch.long, device=flat_latents.device)

        simclr_loss = NTXentLoss(temperature=self.tau)

        # Shuffle and evaluate in chunks of B. This is the same partition as
        # DataLoader(batch_size=B, shuffle=True), without per-sample collation.
        perm = torch.randperm(flat_latents.shape[0], device=flat_latents.device)
        chunks = perm.split(B)
        total_loss = sum(simclr_loss(embeddings=flat_latents[ix], labels=labels[ix]) for ix in chunks)
        return total_loss / len(chunks)

    def simclr_loss_return_based_2(self,sample_state, sample_actions, sample_rtg, sample_adv,num_samples_simclr=20,bin_edges=None,contrastive_type="return",with_soft_contrastive=False):
        """
        Supervised NT-Xent loss over timesteps grouped by return or advantage bin.

        All timesteps are flattened and labelled with
        ``np.digitize(value, bin_edges)``. The value is rtg when
        ``contrastive_type == "return"`` and the advantage otherwise. From each
        bin with at least two members, up to ``num_samples_simclr`` timesteps
        are drawn without replacement, so every selected sample has a positive
        partner. The selected (state, action, rtg) triples are encoded with
        ``self.actor.get_latent`` and scored with NT-Xent (temperature
        ``self.tau``). When ``with_soft_contrastive`` is set, the soft
        value-regression term is added and logged to wandb.

        Returns:
            Scalar loss tensor; a zero tensor when no bin has two members
            (no positive pairs exist, so there is nothing to contrast).
        """
        flat_state = sample_state.reshape(-1, sample_state.shape[-1])  # [B*N, state_dim]
        flat_actions = sample_actions.reshape(-1, sample_actions.shape[-1])  # [B*N, action_dim]
        flat_rtg = sample_rtg.reshape(-1)  # [B*N] when rtg is [B, N, 1]
        flat_adv = sample_adv.reshape(-1)  # [B*N] when adv is [B, N, 1]

        binned_values = flat_rtg if contrastive_type == "return" else flat_adv
        all_labels = np.digitize(binned_values.detach().cpu().numpy(), bin_edges)

        # Group flat indices by bin label (dict keeps first-occurrence order).
        label_to_indices = {}
        for idx, label in enumerate(all_labels):
            label_to_indices.setdefault(label, []).append(idx)

        selected_indices = []
        for indices in label_to_indices.values():
            if len(indices) >= 2:  # a singleton bin has no positive partner
                chosen = np.random.choice(indices, size=min(num_samples_simclr, len(indices)), replace=False)
                selected_indices.extend(chosen)

        if not selected_indices:
            # Avoid encoding an empty batch; there are no positive pairs.
            return torch.zeros((), device=self.device)

        selected_np = np.asarray(selected_indices, dtype=np.int64)
        selected = torch.as_tensor(selected_np, device=self.device)

        states = flat_state[selected]
        actions = flat_actions[selected]
        rtg = flat_rtg[selected].unsqueeze(1)
        labels = torch.as_tensor(all_labels[selected_np], dtype=torch.long, device=self.device)

        pos_latents = self.actor.get_latent(states, actions, rtg)

        simclr_loss = NTXentLoss(temperature=self.tau)
        loss = simclr_loss(embeddings=pos_latents, labels=labels)

        if with_soft_contrastive:
            # NOTE(review): the soft term always regresses on rtg, even when the
            # hard labels come from advantages. Confirm this is intended.
            soft_contrast_loss = self.soft_contrastive_loss(pos_latents, rtg.squeeze(1))
            wandb.log({"soft_contrastive_loss": soft_contrast_loss}, step=self.step)
            # Out-of-place add, so the tensor returned by NTXentLoss is not mutated.
            loss = loss + soft_contrast_loss

        return loss

    def soft_contrastive_loss(self,embeddings, values, eps=1e-6):
        """
        Regress pairwise cosine similarities onto a value-based soft target.

        For each ordered pair (i, j) with i != j, the target similarity is
        ``1 / (1 + |values[i] - values[j]|)``, so closer values give a higher
        target. The loss is the MSE between these targets and the cosine
        similarities of the L2-normalised embeddings.

        Args:
            embeddings: [B, D] latent vectors.
            values: [B] scalar values (returns or advantages).
            eps: currently unused; kept for interface compatibility.

        Returns:
            Scalar loss tensor; zero when B < 2 (no off-diagonal pairs exist).
        """
        B = embeddings.size(0)
        if B < 2:
            # With an empty off-diagonal mask, F.mse_loss would average an
            # empty tensor and return nan, poisoning the total actor loss.
            return embeddings.new_zeros(())

        unit = F.normalize(embeddings, p=2, dim=-1)
        sim_matrix = unit @ unit.T  # [B, B] cosine similarities

        value_gap = torch.abs(values.unsqueeze(0) - values.unsqueeze(1))  # [B, B]
        target_sim = 1.0 / (1.0 + value_gap)

        off_diag = ~torch.eye(B, dtype=torch.bool, device=embeddings.device)
        return F.mse_loss(sim_matrix[off_diag], target_sim[off_diag])



    def get_post_train_loss(self,sample_state, sample_actions, sample_rtg,sample_adv,num_samples_simclr=3,bin_edges=None):
        """
        Pull batch latents towards a similarity-weighted mix of "top"-buffer latents.

        The top buffer is chosen by ``self.contrastive_data["contrastive_type"]``
        ("return" or "advantage"). Batch and top latents are L2-normalised.
        Each batch latent's target is the softmax(cosine similarity)-weighted
        average of the top latents. The loss is the MSE to that target.

        Args:
            sample_state, sample_actions, sample_rtg: batch inputs to
                ``get_latent``; the first two dims of ``sample_state`` are
                flattened.
            sample_adv, num_samples_simclr, bin_edges: unused; kept for
                interface compatibility.

        Raises:
            ValueError: for an unsupported contrastive type.
        """
        contrastive_type = self.contrastive_data["contrastive_type"]
        if contrastive_type == "return":
            top_buffer = self.contrastive_data["top_return_buffer"]
            value_key = "returns_to_go"
        elif contrastive_type == "advantage":
            top_buffer = self.contrastive_data["top_advantage_buffer"]
            value_key = "advantage"
        else:
            raise ValueError(f"Unsupported contrastive type: {self.contrastive_data['contrastive_type']}")

        # Only the top buffer is needed here. The low buffer used to be copied
        # to the device on every call and then ignored.
        top_states = torch.as_tensor(top_buffer["states"], dtype=torch.float32, device=self.device)
        top_actions = torch.as_tensor(top_buffer["actions"], dtype=torch.float32, device=self.device)
        top_values = torch.as_tensor(top_buffer[value_key], dtype=torch.float32, device=self.device)
        top_latents = self.actor.get_latent(top_states, top_actions, top_values.unsqueeze(-1))

        B, N = sample_state.shape[:2]
        state_action_latents = self.actor.get_latent(sample_state, sample_actions, sample_rtg).reshape(B * N, -1)

        state_action_latents = F.normalize(state_action_latents, p=2, dim=-1)  # [B*N, D]
        top_latents = F.normalize(top_latents, p=2, dim=-1)  # [T, D]

        sim = state_action_latents @ top_latents.T  # [B*N, T] cosine similarities

        temperature = 1
        weights = F.softmax(sim / temperature, dim=1)  # [B*N, T]
        target_latents = weights @ top_latents  # [B*N, D]

        return F.mse_loss(state_action_latents, target_latents)

    def get_post_train_loss_contrastive(self,sample_state, sample_actions, sample_rtg,sample_adv,num_samples_simclr=3,bin_edges=None,num_samples=1000):
        """
        NT-Xent loss grouping batch latents with "top"-buffer latents against "low"-buffer ones.

        Batch (anchor) latents and up to ``num_samples`` latents drawn from the
        top buffer share label 1. Up to ``num_samples`` latents drawn from the
        low buffer get label 0. All embeddings are L2-normalised, shuffled and
        split into chunks of ``self.batch_size``. The result is the mean, over
        chunks, of the NT-Xent loss (temperature 0.07).

        Args:
            sample_state, sample_actions, sample_rtg: batch inputs to
                ``get_latent``; the first two dims of ``sample_state`` are
                flattened into anchors.
            sample_adv, num_samples_simclr, bin_edges: unused; kept for
                interface compatibility.
            num_samples: maximum rows drawn from each buffer (was hard-coded
                to 1000, which stays the default).

        Raises:
            ValueError: for an unsupported contrastive type.
        """
        contrastive_type = self.contrastive_data["contrastive_type"]
        if contrastive_type == "return":
            top_buffer = self.contrastive_data["top_return_buffer"]
            low_buffer = self.contrastive_data["low_return_buffer"]
            value_key = "returns_to_go"
        elif contrastive_type == "advantage":
            top_buffer = self.contrastive_data["top_advantage_buffer"]
            low_buffer = self.contrastive_data["low_advantage_buffer"]
            value_key = "advantage"
        else:
            raise ValueError(f"Unsupported contrastive type: {self.contrastive_data['contrastive_type']}")

        def encode_subset(buffer):
            # Subsample rows *before* encoding. Otherwise autograd keeps
            # activations for the entire buffer, although only num_samples
            # rows are used.
            # NOTE(review): equivalent to encode-then-subsample only if
            # get_latent processes rows independently. Confirm.
            states = torch.as_tensor(buffer["states"], dtype=torch.float32, device=self.device)
            actions = torch.as_tensor(buffer["actions"], dtype=torch.float32, device=self.device)
            values = torch.as_tensor(buffer[value_key], dtype=torch.float32, device=self.device)
            keep = torch.randperm(states.shape[0])[:num_samples]
            return self.actor.get_latent(states[keep], actions[keep], values[keep].unsqueeze(-1))

        B, N = sample_state.shape[:2]
        anchor_latents = self.actor.get_latent(sample_state, sample_actions, sample_rtg).reshape(B * N, -1)
        positive_latents = encode_subset(top_buffer)
        negative_latents = encode_subset(low_buffer)

        anchor_latents = F.normalize(anchor_latents, dim=-1)
        positive_latents = F.normalize(positive_latents, dim=-1)
        negative_latents = F.normalize(negative_latents, dim=-1)

        all_embeddings = torch.cat([anchor_latents, positive_latents, negative_latents], dim=0)

        # Two classes only: anchors and top-buffer samples share label 1;
        # low-buffer samples get label 0.
        labels = torch.cat([
            torch.ones(anchor_latents.shape[0] + positive_latents.shape[0], device=all_embeddings.device),
            torch.zeros(negative_latents.shape[0], device=all_embeddings.device),
        ])

        ntxent_loss_fn = NTXentLoss(temperature=0.07)

        # Shuffle and evaluate in chunks of self.batch_size. This is the same
        # partition DataLoader(shuffle=True) would produce, without per-sample
        # collation or a torch.cuda.empty_cache() call per chunk.
        perm = torch.randperm(all_embeddings.shape[0], device=all_embeddings.device)
        chunks = perm.split(self.batch_size)
        total_loss = sum(ntxent_loss_fn(embeddings=all_embeddings[ix], labels=labels[ix]) for ix in chunks)
        return total_loss / len(chunks)

    def get_avoidance_loss(self,sample_state, sample_actions, sample_rtg,sample_adv,num_samples_simclr=3,bin_edges=None):
        """
        Penalise batch latents that resemble any "low"-buffer latent.

        For each batch latent, take its highest cosine similarity to any
        low-buffer latent, average over the batch, and add 1. The result lies
        in [0, 2]. The low buffer is chosen by
        ``self.contrastive_data["contrastive_type"]`` ("return", "advantage"
        or "advantage_with_avoidance").

        Args:
            sample_state, sample_actions, sample_rtg: batch inputs to
                ``get_latent``; the first two dims of ``sample_state`` are
                flattened.
            sample_adv, num_samples_simclr, bin_edges: unused; kept for
                interface compatibility.

        Raises:
            ValueError: for an unsupported contrastive type.
        """
        contrastive_type = self.contrastive_data["contrastive_type"]
        if contrastive_type == "return":
            low_buffer = self.contrastive_data["low_return_buffer"]
            value_key = "returns_to_go"
        elif contrastive_type in ("advantage", "advantage_with_avoidance"):
            low_buffer = self.contrastive_data["low_advantage_buffer"]
            value_key = "advantage"
        else:
            raise ValueError(f"Unsupported contrastive type: {self.contrastive_data['contrastive_type']}")

        # Only the low buffer is needed. The top buffer used to be copied to
        # the device on every call and then ignored.
        low_states = torch.as_tensor(low_buffer["states"], dtype=torch.float32, device=self.device)
        low_actions = torch.as_tensor(low_buffer["actions"], dtype=torch.float32, device=self.device)
        low_values = torch.as_tensor(low_buffer[value_key], dtype=torch.float32, device=self.device)
        low_latents = self.actor.get_latent(low_states, low_actions, low_values.unsqueeze(-1))

        B, N = sample_state.shape[:2]
        state_action_latents = self.actor.get_latent(sample_state, sample_actions, sample_rtg).reshape(B * N, -1)

        state_action_latents = F.normalize(state_action_latents, p=2, dim=-1)  # [B*N, D]
        low_latents = F.normalize(low_latents, p=2, dim=-1)  # [T, D]

        sim = state_action_latents @ low_latents.T  # [B*N, T] cosine similarities
        closest = sim.max(dim=1).values.mean()

        # Shift by 1 so the loss is non-negative (cosine similarity >= -1).
        return 1 + closest

    def simclr_loss(self,sample_state, sample_actions, sample_rtg,num_samples_simclr=3):
        """
        Instance-level SimCLR loss: timesteps from the same trajectory are positives.

        ``num_samples_simclr`` timesteps are drawn uniformly with replacement
        from every trajectory and encoded with ``self.actor.get_latent``. Each
        latent is labelled with its trajectory index, and the NT-Xent loss
        (temperature 0.1) is computed over all of them.

        Args:
            sample_state: rank-3 states; dims 0 and 1 index trajectory and timestep.
            sample_actions: rank-3 actions with the same leading dims.
            sample_rtg: returns-to-go with the same leading dims.
            num_samples_simclr: timesteps drawn per trajectory.

        Returns:
            Scalar loss tensor.
        """
        n_traj, horizon, _ = sample_state.shape
        _, _, _ = sample_actions.shape  # actions must be rank-3 as well
        k = num_samples_simclr

        picks = np.stack([np.random.choice(np.arange(horizon), size=k) for _ in range(n_traj)])
        pick_t = torch.from_numpy(picks).to(device=self.device)
        traj_ix = torch.arange(n_traj, device=pick_t.device)[:, None]

        # gathered[b, j] = x[b, picks[b, j]]
        latents = self.actor.get_latent(
            sample_state[traj_ix, pick_t],
            sample_actions[traj_ix, pick_t],
            sample_rtg[traj_ix, pick_t],
        )
        flat = latents.reshape(-1, latents.shape[-1])

        traj_labels = torch.arange(n_traj).repeat_interleave(k).to(self.device)
        criterion = NTXentLoss(temperature=0.1)
        return criterion(embeddings=flat, labels=traj_labels)

    def train_iteration(self, num_steps, logger, iter_num=0, log_writer=None):
        """
        Run one pass over ``self.dataloader``, log losses and optionally evaluate.

        Args:
            num_steps: not used by this method; the epoch length is the length
                of ``self.dataloader``.
            logger: object providing ``record_tabular``, ``dump_tabular`` and ``log``.
            iter_num: epoch counter; evaluation runs when it is a multiple of
                ``self.eval_every_n_epoch``.
            log_writer: optional writer forwarded to ``train_step``.

        Returns:
            dict with timing, diagnostics, evaluation outputs and the best
            ``return_mean`` / ``normalized_score`` seen in this call's logs.
        """
        logs = dict()
        train_start = time.time()

        self.actor.train()
        loss_metric = {
            'bc_loss': [],
            'ql_loss': [],
            'actor_loss': [],
            'target_q_mean': [],
            'att_entropy_loss': [],
            'contrastive_loss': [],
            'post_train_loss': [],
        }

        # The batch index is not needed (and `iter` would shadow the builtin).
        for data in tqdm(self.dataloader):
            loss_metric = self.train_step(data, log_writer, loss_metric, post_train=self.post_train)

        def _mean(values):
            # train_step in this file never appends to ql_loss / target_q_mean;
            # np.mean([]) would return nan with a RuntimeWarning every epoch.
            return float(np.mean(values)) if values else float('nan')

        for label, key in (('BC Loss', 'bc_loss'),
                           ('QL Loss', 'ql_loss'),
                           ('Actor Loss', 'actor_loss'),
                           ('Target Q Mean', 'target_q_mean'),
                           ('Att Entropy Loss', 'att_entropy_loss'),
                           ('Contrastive Loss', 'contrastive_loss'),
                           ('Post train Loss', 'post_train_loss')):
            logger.record_tabular(label, _mean(loss_metric[key]))
        logger.dump_tabular()

        logs['time/training'] = time.time() - train_start
        logs["patience_counter"] = self.patience_counter

        eval_start = time.time()

        if iter_num % self.eval_every_n_epoch == 0:
            self.actor.eval()
            for eval_fn in self.eval_fns:
                outputs = eval_fn(self.actor, None)
                for k, v in outputs.items():
                    logs[f'evaluation/{k}'] = v

        logs['time/total'] = time.time() - self.start_time
        logs['time/evaluation'] = time.time() - eval_start

        for k in self.diagnostics:
            logs[k] = self.diagnostics[k]

        logger.log('=' * 80)
        logger.log(f'Iteration {iter_num}')
        logger.log(f'Step {self.step}')
        best_ret = -10000
        best_nor_score = -10000
        for k, v in logs.items():
            if 'return_mean' in k:
                best_ret = max(best_ret, float(v))
            if 'normalized_score' in k:
                best_nor_score = max(best_nor_score, float(v))
                wandb.log({"normalized_score": float(v)}, step=self.step)
            logger.record_tabular(k, float(v))
        logger.record_tabular('Current actor learning rate', self.actor_optimizer.param_groups[0]['lr'])
        logger.dump_tabular()

        logs['Best_return_mean'] = best_ret
        logs['Best_normalized_score'] = best_nor_score
        return logs


    def scale_down_weight(self, decay_rate):
        """Multiply the attention-entropy loss weight by ``decay_rate``.

        The decay compounds across calls: after ``n`` calls the weight has been
        scaled by ``decay_rate ** n``.
        """
        current_weight = self.att_loss_weight
        self.att_loss_weight = current_weight * decay_rate



    def train_step(self, data ,log_writer=None, loss_metric=None,post_train=False):
        '''
            Run one optimisation step of the actor on a single batch.

            The actor loss is behaviour cloning on actions plus next-state and
            reward prediction, the weighted attention-entropy term, and the
            contrastive term selected by ``contrastive_data["contrastive_type"]``.
            In post-training mode it is ``bc_loss - post_train_loss`` instead.

            Args:
                data: mapping with batch tensors under "states", "actions",
                    "rewards", "target_a", "returns_to_go", "timesteps",
                    "traj_mask" and "advantage"; states are
                    (batch_size, max_len, state_dim).
                log_writer: optional writer with ``add_scalar``.
                loss_metric: dict of lists that per-step losses are appended
                    to; a fresh dict is created when omitted.
                post_train: enable post-training mode for this step (it is
                    also enabled when ``self.post_train`` is set).

            Returns:
                The updated ``loss_metric`` dict.
        '''
        def _scalar(value):
            # Optional loss terms are either a tensor or the int placeholder 0.
            return value.detach().cpu().item() if isinstance(value, torch.Tensor) else 0

        # Fresh accumulator instead of a shared mutable default.
        if loss_metric is None:
            loss_metric = {}
        for key in ('bc_loss', 'actor_loss', 'att_entropy_loss', 'contrastive_loss', 'post_train_loss'):
            loss_metric.setdefault(key, [])

        # One flag drives the contrastive gate, the post-train loss and the
        # actor objective. Previously the objective read self.post_train while
        # the other two read the argument.
        post_train = post_train or self.post_train

        states = data["states"].to(self.device)
        actions = data["actions"].to(self.device)
        rewards = data["rewards"].to(self.device)
        action_target = data["target_a"].to(self.device)
        rtg = data["returns_to_go"].to(self.device).float()  # B x T x 1
        adv = data["advantage"].to(self.device).float()  # B x T x 1
        timesteps = data["timesteps"].to(self.device)
        attention_mask = data["traj_mask"].to(self.device)
        actor_grad_norms = torch.tensor(0)

        # (The former unused torch.from_numpy(self.state_mean/state_std) calls
        # crashed with the default state_mean=None and have been removed.)
        _, context_len, action_dim = actions.shape
        valid = attention_mask.reshape(-1) > 0

        # Policy training forward pass.
        state_preds, action_preds, reward_preds, attentions = self.actor.forward(
            states, actions, rewards, targets=action_target, returns_to_go=rtg[:, :-1], timesteps=timesteps, attention_mask=attention_mask
        )

        # Attention entropy per layer, averaged over batch and valid query positions.
        att = torch.stack(attentions, 1).mean(2)  # mean over heads
        B, layers, T, _ = att.shape
        repeats = T // context_len
        entropy_mask = torch.repeat_interleave(attention_mask, repeats=repeats, dim=1).unsqueeze(1).repeat(1, layers, 1)
        log_att = torch.log(att + 1e-6)
        entropy = -(att * log_att)
        entropy = entropy.nansum(-1)
        entropy = entropy.masked_fill(entropy_mask == 0, float('nan'))
        entropy = torch.nanmean(entropy, dim=(0, -1))

        for i in range(entropy.shape[0]):
            wandb.log({f"attention{i+1}": entropy[i]}, step=self.step)

        wandb.log({"attention_mean": torch.mean(entropy)}, step=self.step)

        state_preds = state_preds[:, :-1]
        state_target = states[:, 1:]  # next state
        states_loss = ((state_preds - state_target) ** 2)[attention_mask[:, :-1] > 0].mean()
        if reward_preds is not None:
            reward_preds = reward_preds.reshape(-1, 1)[valid]
            reward_target = rewards.reshape(-1, 1)[valid] / self.scale
            rewards_loss = F.mse_loss(reward_preds, reward_target)
        else:
            rewards_loss = 0

        if self.with_att_entropy_loss:
            # Mean entropy over layers divided by a fixed constant (3.15);
            # NOTE(review): origin of the constant is not documented here.
            att_entropy_loss = entropy.mean() / 3.15
            att_entropy_loss = att_entropy_loss * self.att_loss_weight
            wandb.log({"att_entropy_loss_weight": self.att_loss_weight}, step=self.step)
        else:
            att_entropy_loss = 0
            self.att_loss_weight = 0

        # A missing key (e.g. an empty config dict) means "no contrastive term".
        contrastive_type = (self.contrastive_data or {}).get("contrastive_type", "no_contrast")
        contrastive_loss = 0
        avoidance_loss = 0
        # NOTE(review): self.iter starts at 0 and is never advanced in this file,
        # so unless it is updated elsewhere this gate is always open and
        # contrastive_frequency has no effect.
        if contrastive_type != "no_contrast" and self.iter % self.contrastive_frequency == 0:
            if self.bin_ranges is not None and not post_train:
                if contrastive_type == "return":
                    contrastive_loss = self.simclr_loss_return_based(states, actions, rtg[:, :-1], adv[:, :-1], num_samples_simclr=self.contrastive_data["num_samples_simclr"],
                                                                     bin_edges=self.bin_ranges, contrastive_type="return")
                elif contrastive_type == "simclr":
                    contrastive_loss = self.simclr_loss(states, actions, rtg[:, :-1], num_samples_simclr=self.contrastive_data["num_samples_simclr"])
                elif contrastive_type == "advantage":
                    contrastive_loss = self.simclr_loss_return_based_2(states, actions, rtg[:, :-1], adv[:, :-1], num_samples_simclr=self.contrastive_data["num_samples_simclr"],
                                                                       bin_edges=self.adv_ranges, contrastive_type="advantage")
                elif contrastive_type == "advantage_with_avoidance":
                    contrastive_loss = self.simclr_loss_return_based_2(states, actions, rtg[:, :-1], adv[:, :-1], num_samples_simclr=self.contrastive_data["num_samples_simclr"],
                                                                       bin_edges=self.adv_ranges, contrastive_type="advantage",
                                                                       with_soft_contrastive=self.contrastive_data["with_soft_contrastive"])
                    avoidance_loss = self.get_avoidance_loss(states, action_preds, rtg[:, :-1], adv[:, :-1], num_samples_simclr=self.contrastive_data["num_samples_simclr"], bin_edges=self.bin_ranges)
                    contrastive_loss += avoidance_loss
                else:
                    raise ValueError(f"Unsupported contrastive type: {self.contrastive_data['contrastive_type']}")

        post_train_loss = 0
        if post_train:
            post_train_loss = self.get_post_train_loss_contrastive(states, action_preds, rtg[:, :-1], adv[:, :-1], num_samples_simclr=3, bin_edges=self.bin_ranges)

        action_loss = F.mse_loss(action_preds, action_target, reduction="none").mean(dim=2)
        bc_loss = action_loss.view(-1, 1)[valid].mean()

        if post_train:
            # NOTE(review): subtracting post_train_loss *maximises* the NT-Xent
            # objective, i.e. it rewards lower similarity between same-label
            # (anchor / top-buffer) embeddings. Confirm the sign is intended.
            actor_loss = -post_train_loss + bc_loss
        else:
            actor_loss = bc_loss + (self.beta * contrastive_loss) + att_entropy_loss + states_loss + rewards_loss

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        if self.grad_norm > 0:
            actor_grad_norms = nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=self.grad_norm, norm_type=2)
        self.actor_optimizer.step()
        wandb.log({"actor_learning_rate": self.actor_optimizer.param_groups[0]['lr']}, step=self.step)
        if hasattr(self.actor.transformer[0].attention, "sigmas"):
            wandb.log({"sigma_0": self.actor.transformer[0].attention.sigmas[0]}, step=self.step)
            wandb.log({"sigma_1": self.actor.transformer[0].attention.sigmas[1]}, step=self.step)
            wandb.log({"sigma_2": self.actor.transformer[0].attention.sigmas[2]}, step=self.step)
        self.actor_lr_scheduler.step()

        self.step += 1
        with torch.no_grad():
            self.diagnostics['training/action_error'] = torch.mean(action_loss).detach().cpu().item()
            self.diagnostics['training/actor_grad_norms'] = actor_grad_norms.item()
            self.diagnostics['training/attn_entropy_weight'] = self.att_loss_weight
            self.diagnostics['training/beta'] = self.beta
            if hasattr(self.actor.transformer[0].attention, "sigmas"):
                self.diagnostics['training/sigma_0'] = self.actor.transformer[0].attention.sigmas[0].item()
                self.diagnostics['training/sigma_1'] = self.actor.transformer[0].attention.sigmas[1].item()
                self.diagnostics['training/sigma_2'] = self.actor.transformer[0].attention.sigmas[2].item()
            wandb.log({"actor_action_loss": torch.mean(action_loss).detach().cpu().item()}, step=self.step)
            wandb.log({"contrastive_loss": _scalar(contrastive_loss)}, step=self.step)
            wandb.log({"avoidance_loss": _scalar(avoidance_loss)}, step=self.step)
            wandb.log({"actor_state_loss": torch.mean(states_loss).detach().cpu().item()}, step=self.step)

        if log_writer is not None:
            if self.grad_norm > 0:
                log_writer.add_scalar('Actor Grad Norm', actor_grad_norms.max().item(), self.step)
            log_writer.add_scalar('BC Loss', bc_loss.item(), self.step)

        loss_metric['bc_loss'].append(bc_loss.item())
        loss_metric['actor_loss'].append(actor_loss.item())
        loss_metric['att_entropy_loss'].append(_scalar(att_entropy_loss))
        loss_metric['contrastive_loss'].append(_scalar(contrastive_loss))
        loss_metric['post_train_loss'].append(_scalar(post_train_loss))

        return loss_metric
