import time

import lightning
import torch
import torch.nn.functional as F
from distributions import (
    BaseDistribution,
    ExponentialDistribution,
    LogNormalDistribution,
    MixtureDistribution,
    WeibullDistribution,
)
from rejection_sampling import (
    get_categorical_rejection_constant,
    get_rejection_constant,
)
from torch import Tensor, nn
from torch.distributions import Categorical


######################
### ENCODERS
######################
class TimeEncoder(nn.Module):
    """Embed scalar times with a linear -> sine -> linear projection.

    Args:
        hidden_dim: Width of the produced embedding.
    """

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.linear1 = nn.Linear(1, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, t: Tensor) -> Tensor:
        """
        Args:
            t: Times of shape (batch_size, seq_len) or (batch_size,).

        Returns:
            Embeddings of shape ``t.shape + (hidden_dim,)``.
        """
        # Treat every scalar time as a 1-feature vector before projecting.
        periodic = torch.sin(self.linear1(t[..., None]))
        return self.linear2(periodic)


class GRUEncoder(nn.Module):
    """Single-layer GRU over the sum of mark and time embeddings.

    Args:
        num_classes: Number of distinct marks.
        hidden_dim: Embedding width and GRU hidden size.
        **kwargs: Ignored.
    """

    def __init__(self, num_classes: int, hidden_dim: int, **kwargs):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.mark_emb = nn.Embedding(num_classes, hidden_dim)
        self.time_emb = TimeEncoder(hidden_dim)

        self.rnn = nn.GRU(hidden_dim, hidden_dim, batch_first=True)

    def feature_emb(self, x: Tensor, t: Tensor) -> Tensor:
        """Return ``mark_emb(x) + time_emb(t)``, shape ``x.shape + (hidden_dim,)``."""
        return self.mark_emb(x) + self.time_emb(t)

    def forward(self, x: Tensor, t: Tensor, h: Tensor = None) -> Tensor:
        """
        Args:
            x: long (batch_size, seq_len)
            t: float (batch_size, seq_len)
            h: Optional initial GRU state in nn.GRU's ``h_0`` layout,
               (1, batch_size, hidden_dim); zeros when None.

        Returns:
            Per-step GRU outputs of shape (batch_size, seq_len, hidden_dim).
        """
        outputs, _last_state = self.rnn(self.feature_emb(x, t), h)
        return outputs

    def step(self, x: Tensor, t: Tensor, h: Tensor) -> Tensor:
        """Advance the GRU by one event.

        Args:
            x: (batch_size,)
            t: (batch_size,)
            h: (batch_size, hidden_dim)

        Returns:
            Next state, (batch_size, hidden_dim).
        """
        assert x.ndim == t.ndim == 1
        assert h.ndim == 2 and h.shape[1] == self.hidden_dim

        single_step = self.feature_emb(x, t).unsqueeze(1)  # (batch_size, 1, hidden_dim)
        _, next_state = self.rnn(single_step, h.unsqueeze(0))
        return next_state.squeeze(0)


class TransformerEncoder(nn.Module):
    """Causal Transformer encoder over the sum of mark and time embeddings.

    Args:
        num_classes: Number of distinct marks.
        hidden_dim: Model width (``d_model``).
        num_layers: Number of stacked encoder layers.
        **kwargs: Ignored.
    """

    def __init__(self, num_classes: int, hidden_dim: int, num_layers: int = 2, **kwargs):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.mark_emb = nn.Embedding(num_classes, hidden_dim)
        self.time_emb = TimeEncoder(hidden_dim)

        encoder_layers = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=1, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)

    def feature_emb(self, x: Tensor, t: Tensor) -> Tensor:
        x = self.mark_emb(x) # (batch_size, seq_len, hidden_dim)
        t = self.time_emb(t) # (batch_size, seq_len, hidden_dim)
        return x + t

    def forward(self, x: Tensor, t: Tensor, h: Tensor = None) -> Tensor:
        """
        Args:
            x: long (batch_size, seq_len)
            t: float (batch_size, seq_len)
            h: Ignored. Accepted so the call signature matches the GRU encoders.

        Returns:
            h: (batch_size, seq_len, hidden_dim); position i depends only on
               events 0..i.
        """
        x = self.feature_emb(x, t)
        seq_len = x.shape[1]
        # Additive causal mask: -inf strictly above the diagonal keeps position i
        # from attending to later events. Output i is decoded to predict event
        # i + 1, so without the mask it could attend to its own target.
        causal_mask = torch.full(
            (seq_len, seq_len), float('-inf'), device=x.device, dtype=x.dtype
        ).triu(diagonal=1)
        h = self.transformer(x, mask=causal_mask)
        return h


class CNNEncoder(nn.Module):
    """Two causal 1-D convolutions (kernel 3) over summed mark/time embeddings.

    Each output position depends on the current event and the four before it.

    Args:
        num_classes: Number of distinct marks.
        hidden_dim: Embedding width and number of conv channels.
        **kwargs: Ignored.
    """

    def __init__(self, num_classes: int, hidden_dim: int, **kwargs):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.mark_emb = nn.Embedding(num_classes, hidden_dim)
        self.time_emb = TimeEncoder(hidden_dim)

        # No built-in padding: causality is enforced by left-padding in forward().
        self.conv1 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3)
        self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3)

    def feature_emb(self, x: Tensor, t: Tensor) -> Tensor:
        x = self.mark_emb(x) # (batch_size, seq_len, hidden_dim)
        t = self.time_emb(t) # (batch_size, seq_len, hidden_dim)
        return x + t

    @staticmethod
    def _causal_pad(x: Tensor, conv: nn.Conv1d) -> Tensor:
        """Left-pad the last axis so an unpadded ``conv`` keeps length and stays causal."""
        return F.pad(x, (conv.dilation[0] * (conv.kernel_size[0] - 1), 0))

    def forward(self, x: Tensor, t: Tensor, h: Tensor = None) -> Tensor:
        """
        Args:
            x: long (batch_size, seq_len)
            t: float (batch_size, seq_len)
            h: Ignored. Accepted so the call signature matches the GRU encoders.

        Returns:
            h: (batch_size, seq_len, hidden_dim); position i depends only on
               events up to i.
        """
        x = self.feature_emb(x, t)
        x = x.transpose(1, 2)  # Conv1d expects (batch, channels, length)
        # Symmetric padding would let output i see events i+1 and i+2. Output i
        # is decoded to predict event i + 1, so that would leak the target.
        x = F.relu(self.conv1(self._causal_pad(x, self.conv1)))
        x = F.relu(self.conv2(self._causal_pad(x, self.conv2)))
        h = x.transpose(1, 2)
        return h


class GRUEncoderCategorical(nn.Module):
    """Single-layer GRU whose times are integer classes embedded by a lookup table.

    Args:
        num_classes: Number of distinct marks.
        hidden_dim: Embedding width and GRU hidden size.
        time_classes: Number of distinct time classes.
        **kwargs: Ignored.
    """

    def __init__(self, num_classes: int, hidden_dim: int, time_classes: int = 10, **kwargs):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.mark_emb = nn.Embedding(num_classes, hidden_dim)
        self.time_emb = nn.Embedding(time_classes, hidden_dim)

        self.rnn = nn.GRU(hidden_dim, hidden_dim, batch_first=True)

    def feature_emb(self, x: Tensor, t: Tensor) -> Tensor:
        """Sum of mark and time-class embeddings; ``t`` is cast to long for the lookup."""
        time_ids = t.long()
        return self.mark_emb(x) + self.time_emb(time_ids)

    def forward(self, x: Tensor, t: Tensor, h: Tensor = None) -> Tensor:
        """
        Args:
            x: long (batch_size, seq_len)
            t: (batch_size, seq_len), integer-valued time classes (any dtype)
            h: Optional initial GRU state in nn.GRU's ``h_0`` layout,
               (1, batch_size, hidden_dim); zeros when None.

        Returns:
            Per-step GRU outputs, (batch_size, seq_len, hidden_dim).
        """
        sequence_out, _ = self.rnn(self.feature_emb(x, t), h)
        return sequence_out

    def step(self, x: Tensor, t: Tensor, h: Tensor) -> Tensor:
        """Advance the GRU by a single event.

        Args:
            x: (batch_size,)
            t: (batch_size,)
            h: (batch_size, hidden_dim)

        Returns:
            Next state, (batch_size, hidden_dim).
        """
        assert x.ndim == t.ndim == 1
        assert h.ndim == 2 and h.shape[1] == self.hidden_dim

        one_event = self.feature_emb(x, t).unsqueeze(1)
        _, new_state = self.rnn(one_event, h.unsqueeze(0))
        return new_state.squeeze(0)


######################
### DECODERS
######################
class BaseDecoder(nn.Module):
    """Base class for decoders that map hidden states to next-event distributions.

    Subclasses implement ``forward``. The two rejection-constant methods compare
    a proposal distribution against per-position target distributions and
    return ``(mark_const, time_const)``.
    """

    def forward(self, h: Tensor) -> tuple[Categorical, BaseDistribution]:
        """
        Args:
            h: (batch_size, seq_len, hidden_dim)

        Returns:
            mark_dist: Mark categorical distribution
            time_dist: Time distribution
        """
        raise NotImplementedError

    def rejection_constant(
        self,
        mark_proposal_dist: Categorical,
        mark_target_dist: Categorical,
        time_proposal_dist: BaseDistribution,
        time_target_dist: BaseDistribution,
        exact: bool = False,
        top_k_rejection_const: float = 1,
        num_points: int = 0,
        **kwargs,
    ) -> tuple[Tensor, Tensor]:
        """Compute mark and time rejection constants with the ``rejection_sampling`` helpers.

        Args:
            mark_proposal_dist: Proposal mark distribution.
            mark_target_dist: Target mark distribution.
            time_proposal_dist: Proposal time distribution.
            time_target_dist: Target time distribution.
            exact: Selects ``delta = 0.0`` (else ``0.05``) for the mark constant
                and is forwarded to ``get_rejection_constant``.
            top_k_rejection_const: Forwarded as ``top_k`` to ``get_rejection_constant``.
            num_points: Forwarded to ``get_rejection_constant``.
            **kwargs: Ignored.

        Returns:
            ``(mark_const, time_const)``.
        """
        delta = 0.0 if exact else 0.05

        mark_const = get_categorical_rejection_constant(proposal_dist=mark_proposal_dist, target_dist=mark_target_dist, delta=delta)

        time_const = get_rejection_constant(
            proposal_dist=time_proposal_dist,
            target_dist=time_target_dist,
            exact=exact,
            top_k=top_k_rejection_const,
            num_points=num_points,
        )

        return mark_const, time_const

    def brute_force_rejection_constant(
        self,
        mark_proposal_dist: Categorical,
        mark_target_dist: Categorical,
        time_proposal_dist: BaseDistribution,
        time_target_dist: BaseDistribution,
        exact: bool = False,
        num_dense_grid: int = 1000,
        **kwargs,
    ) -> tuple[Tensor, Tensor]:
        """Estimate the time constant as the max density ratio over a dense grid.

        Args:
            mark_proposal_dist: Mark proposal distribution with shape (B, 1)
            mark_target_dist: Mark target distribution with shape (B, L)
            time_proposal_dist: Time proposal distribution with shape (B, 1)
            time_target_dist: Time target distribution with shape (B, L)
            exact: Selects ``delta = 0.0`` (else ``0.05``) for the mark constant
                and is forwarded to ``time_target_dist.percentile``.
            num_dense_grid: Number of grid points N per (batch, position).
            **kwargs: Ignored.

        Returns:
            ``(mark_const, time_const)``, each with shape (B, L)
        """
        assert mark_proposal_dist.batch_shape[0] == mark_target_dist.batch_shape[0]
        assert mark_proposal_dist.batch_shape[1] == 1
        assert len(mark_proposal_dist.batch_shape) == len(mark_target_dist.batch_shape)
        assert time_proposal_dist.param_shape[0] == time_target_dist.param_shape[0]
        assert time_proposal_dist.param_shape[1] == 1
        assert len(time_proposal_dist.param_shape) == len(time_target_dist.param_shape)

        # Only ``delta`` depends on ``exact``; the grid percentile is the same in both modes.
        delta = 0.0 if exact else 0.05
        percentile = 0.05

        mark_const = get_categorical_rejection_constant(proposal_dist=mark_proposal_dist, target_dist=mark_target_dist, delta=delta)

        # Grid limits: the target's [percentile, 1 - percentile] values per (B, L).
        bounds = time_target_dist.percentile(torch.tensor([percentile, 1 - percentile], device=mark_const.device), exact=exact)

        batch, length = time_target_dist.param_shape[:2]
        # Grid of shape [N, B, L]. ``expand`` avoids materialising B * L copies; the
        # affine map below allocates the real tensor.
        grid = torch.linspace(0, 1, num_dense_grid).to(bounds.device)
        grid = grid.view(-1, 1, 1).expand(-1, batch, length)

        x = grid * (bounds[..., 1] - bounds[..., 0]) + bounds[..., 0]

        time_target_pdf = time_target_dist.log_prob(x).exp()
        time_proposal_pdf = time_proposal_dist.log_prob(x).exp()

        assert time_target_pdf.isfinite().all()
        assert time_proposal_pdf.isfinite().all()

        # Shape [N, B, L] -> [B, L]. NaN ratios (0/0) become 0 and +inf becomes the
        # largest finite float. The constant is floored at 1.
        time_const = (time_target_pdf / time_proposal_pdf).nan_to_num(0).clamp(1).max(dim=0).values

        return mark_const, time_const


class ExponentialDecoder(BaseDecoder):
    """Categorical marks and exponential inter-event times.

    The exponential rate is ``exp`` of a 1-unit linear head, so it is positive.

    Args:
        num_classes: Number of marks.
        hidden_dim: Size of the input hidden state.
        **kwargs: Ignored.
    """

    # NOTE(review): not read anywhere in this module; confirm external use.
    percentile = 0.95

    def __init__(self, num_classes: int, hidden_dim: int, **kwargs):
        super().__init__()
        self.mark_linear = nn.Linear(hidden_dim, num_classes)
        self.time_linear = nn.Linear(hidden_dim, 1)

    def forward(self, h: Tensor) -> tuple[Categorical, ExponentialDistribution]:
        """
        Args:
            h: (batch_size, seq_len, hidden_dim) or (batch_size, hidden_dim)

        Returns:
            mark_dist: Categorical over marks.
            time_dist: Exponential distribution with batch shape ``h.shape[:-1]``.
        """
        mark_logits = self.mark_linear(h)
        mark_dist = Categorical(logits=mark_logits)

        log_rate = self.time_linear(h).squeeze(-1)
        return mark_dist, ExponentialDistribution(lambda_param=log_rate.exp())


class LogNormalDecoder(BaseDecoder):
    """Categorical marks and log-normal inter-event times.

    A 2-unit linear head produces the log-normal ``mean`` and a raw scale. The
    raw scale is passed through ``softplus`` to give a non-negative ``std``.

    Args:
        num_classes: Number of marks.
        hidden_dim: Size of the input hidden state.
        **kwargs: Ignored.
    """

    # NOTE(review): not read anywhere in this module; confirm external use.
    percentile = 0.95

    def __init__(self, num_classes: int, hidden_dim: int, **kwargs):
        super().__init__()
        self.mark_linear = nn.Linear(hidden_dim, num_classes)
        self.time_linear = nn.Linear(hidden_dim, 2)

    def forward(self, h: Tensor) -> tuple[Categorical, LogNormalDistribution]:
        """
        Args:
            h: (batch_size, seq_len, hidden_dim) or (batch_size, hidden_dim)

        Returns:
            mark_dist: Categorical over marks.
            time_dist: Log-normal distribution with batch shape ``h.shape[:-1]``.
        """
        mark_dist = Categorical(logits=self.mark_linear(h))
        loc, raw_std = self.time_linear(h).unbind(-1)
        return mark_dist, LogNormalDistribution(mean=loc, std=F.softplus(raw_std))


class MixtureLogNormalDecoder(BaseDecoder):
    """Categorical marks and a mixture of log-normals for inter-event times.

    The time head emits ``3 * num_components`` values. They are grouped as
    (mixture logit, mean, raw std) per component, and the raw std goes through
    ``softplus``.

    Args:
        num_classes: Number of marks.
        hidden_dim: Size of the input hidden state.
        num_components: Number of mixture components.
        **kwargs: Ignored.
    """

    def __init__(self, num_classes: int, hidden_dim: int, num_components: int, **kwargs):
        super().__init__()
        self.mark_linear = nn.Linear(hidden_dim, num_classes)
        self.time_linear = nn.Linear(hidden_dim, 3 * num_components)

    def forward(self, h: Tensor) -> tuple[Categorical, MixtureDistribution]:
        """
        Args:
            h: (batch_size, seq_len, hidden_dim) or (batch_size, hidden_dim)

        Returns:
            mark_dist: Categorical over marks.
            time_dist: Mixture of log-normals.
        """
        mark_dist = Categorical(logits=self.mark_linear(h))

        # (..., num_components, 3) -> three (..., num_components) tensors.
        weight_logits, loc, raw_std = self.time_linear(h).unflatten(-1, (-1, 3)).unbind(-1)
        components = LogNormalDistribution(mean=loc, std=F.softplus(raw_std))
        time_dist = MixtureDistribution(logits=weight_logits, component_distribution=components)
        return mark_dist, time_dist


class WeibullDecoder(BaseDecoder):
    """Categorical marks and Weibull inter-event times.

    Both Weibull parameters are squashed into (0.2, 2.2) via ``(sigmoid(z) + 0.1) * 2``.

    Args:
        num_classes: Number of marks.
        hidden_dim: Size of the input hidden state.
        **kwargs: Ignored.
    """

    def __init__(self, num_classes: int, hidden_dim: int, **kwargs):
        super().__init__()
        self.mark_linear = nn.Linear(hidden_dim, num_classes)
        self.time_linear = nn.Linear(hidden_dim, 2)

    def forward(self, h: Tensor) -> tuple[Categorical, WeibullDistribution]:
        """
        Args:
            h: (batch_size, seq_len, hidden_dim) or (batch_size, hidden_dim)

        Returns:
            mark_dist: Categorical over marks.
            time_dist: Weibull distribution with batch shape ``h.shape[:-1]``.
        """
        mark_dist = Categorical(logits=self.mark_linear(h))
        bounded = (torch.sigmoid(self.time_linear(h)) + 0.1) * 2
        k, lam = bounded.unbind(-1)
        return mark_dist, WeibullDistribution(shape=k, scale=lam)


class WeibullMixtureDecoder(BaseDecoder):
    """Categorical marks and a mixture of Weibulls for inter-event times.

    The time head emits ``3 * num_components`` values, grouped per component
    as (mixture logit, raw shape, raw scale). Shape and scale are squashed into
    (0.2, 2.2) via ``(sigmoid(z) + 0.1) * 2``.

    Args:
        num_classes: Number of marks.
        hidden_dim: Size of the input hidden state.
        num_components: Number of mixture components.
        **kwargs: Ignored.
    """

    def __init__(self, num_classes: int, hidden_dim: int, num_components: int, **kwargs):
        super().__init__()
        self.mark_linear = nn.Linear(hidden_dim, num_classes)
        self.time_linear = nn.Linear(hidden_dim, 3 * num_components)

    def forward(self, h: Tensor) -> tuple[Categorical, MixtureDistribution]:
        """
        Args:
            h: (batch_size, seq_len, hidden_dim) or (batch_size, hidden_dim)

        Returns:
            mark_dist: Categorical over marks.
            time_dist: Mixture of Weibulls.
        """
        def squash(z: Tensor) -> Tensor:
            return (torch.sigmoid(z) + 0.1) * 2

        mark_dist = Categorical(logits=self.mark_linear(h))
        weight_logits, raw_shape, raw_scale = self.time_linear(h).unflatten(-1, (-1, 3)).unbind(-1)

        components = WeibullDistribution(shape=squash(raw_shape), scale=squash(raw_scale))
        time_dist = MixtureDistribution(logits=weight_logits, component_distribution=components)
        return mark_dist, time_dist


class CategoricalDecoder(BaseDecoder):
    """Categorical marks and categorical (integer class) times.

    Args:
        num_classes: Number of marks.
        hidden_dim: Size of the input hidden state.
        time_classes: Number of time classes.
        **kwargs: Ignored.
    """

    def __init__(self, num_classes: int, hidden_dim: int, time_classes: int = 10, **kwargs):
        super().__init__()
        self.mark_linear = nn.Linear(hidden_dim, num_classes)
        self.time_linear = nn.Linear(hidden_dim, time_classes)

    def forward(self, h: Tensor) -> tuple[Categorical, Categorical]:
        """
        Args:
            h: (batch_size, seq_len, hidden_dim) or (batch_size, hidden_dim)

        Returns:
            mark_dist: Mark categorical distribution
            time_dist: Categorical distribution over time classes
        """
        mark_dist = Categorical(logits=self.mark_linear(h))
        time_dist = Categorical(logits=self.time_linear(h))
        return mark_dist, time_dist

    def rejection_constant(
        self,
        mark_proposal_dist: Categorical,
        mark_target_dist: Categorical,
        time_proposal_dist: Categorical,
        time_target_dist: Categorical,
        time_cat_delta: float = 0.05,
        mark_cat_delta: float = 0.05,
        **kwargs,
    ) -> tuple[Tensor, Tensor]:
        """Compute categorical rejection constants for both marks and times.

        Args:
            mark_proposal_dist: Proposal mark distribution.
            mark_target_dist: Target mark distribution.
            time_proposal_dist: Proposal time-class distribution.
            time_target_dist: Target time-class distribution.
            time_cat_delta: ``delta`` passed to the time-class constant.
            mark_cat_delta: ``delta`` passed to the mark constant.
            **kwargs: Ignored. This includes ``exact``, ``top_k_rejection_const``
                and ``num_points``, which only the continuous-time path uses.

        Returns:
            ``(mark_const, time_const)``.
        """
        time_const = get_categorical_rejection_constant(time_proposal_dist, time_target_dist, delta=time_cat_delta)
        mark_const = get_categorical_rejection_constant(mark_proposal_dist, mark_target_dist, delta=mark_cat_delta)
        return mark_const, time_const



######################
### TPP MODULE
######################
class TPPModule(lightning.LightningModule):
    """Lightning module for an encoder/decoder temporal point process.

    The encoder turns a history of marks ``x`` and times ``t`` into hidden
    states. The decoder maps each hidden state to ``(mark_dist, time_dist)``
    for the next event.

    Args:
        encoder: Exposes ``forward(x, t, h=None)`` and, optionally,
            ``step(x, t, h)`` for incremental updates while sampling.
        decoder: Returns ``(mark_dist, time_dist)`` from hidden states.
        lr: Adam learning rate.
        log_transform: If True, ``training_step`` maps times with ``log1p`` and
            the samplers map generated times back with ``expm1``.
        weight_decay: Adam weight decay.
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: BaseDecoder,
        lr: float,
        log_transform: bool = False,
        weight_decay: float = 1e-6,
    ):
        super().__init__()
        self.log_transform = log_transform
        self.lr = lr
        self.weight_decay = weight_decay
        self.encoder = encoder
        self.decoder = decoder
        # The sub-modules are already checkpointed through the state dict.
        # ``ignore`` matches constructor argument names, which are lowercase.
        self.save_hyperparameters(ignore=['encoder', 'decoder'])

    @staticmethod
    def _cuda_sync() -> None:
        """Wait for queued CUDA work so profiling timings are accurate; no-op without CUDA."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def training_step(
        self,
        batch: tuple[Tensor, Tensor, Tensor],
        batch_idx: int,
        log_prefix: str = 'train',
    ) -> Tensor:
        """Compute masked mark and time negative log-likelihoods.

        The encoder output at position ``i`` of ``x[:, :-1]`` is decoded to
        predict event ``i + 1``.

        Args:
            batch: ``(x, t, mask)``. ``mask[:, 1:]`` weights the per-event losses.
            batch_idx: Batch index (unused).
            log_prefix: Prefix for the logged metric names.

        Returns:
            Mark loss plus time loss, each a mask-weighted mean.
        """
        x, t, mask = batch

        if self.log_transform:
            t = t.log1p()

        mask = mask[:, 1:]
        x_past, x_future = x[:, :-1], x[:, 1:]
        t_past, t_future = t[:, :-1], t[:, 1:]

        h = self.encoder.forward(x_past, t_past)
        mark_dist, time_dist = self.decoder.forward(h)

        mark_loss = -mark_dist.log_prob(x_future)
        mark_loss = (mark_loss * mask).sum() / mask.sum()

        # Categorical time decoders take integer classes; continuous ones take a
        # small offset so that zero inter-event times stay in the support.
        if 'categorical' in self.decoder.__class__.__name__.lower():
            time_loss = -time_dist.log_prob(t_future.long())
        else:
            time_loss = -time_dist.log_prob(t_future.float() + 1e-8)
        time_loss = (time_loss * mask).sum() / mask.sum()

        loss = mark_loss + time_loss

        self.log(f'{log_prefix}_mark_loss', mark_loss, prog_bar=True)
        self.log(f'{log_prefix}_time_loss', time_loss, prog_bar=True)
        self.log(f'{log_prefix}_loss', loss, prog_bar=True)

        return loss

    def validation_step(self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int) -> None:
        self.training_step(batch, batch_idx, 'val')

    @torch.no_grad()
    def sample(
        self,
        x: Tensor,
        t: Tensor,
        num_samples: int,
        seq_len: int,
        time_profile: bool = False,
        **kwargs,
    ) -> tuple:
        """Autoregressively sample ``seq_len`` events after the history ``(x, t)``.

        The history is repeated ``num_samples`` times along dim 0.
        NOTE(review): the history ``t`` is encoded as given, while training
        applies ``log1p`` when ``log_transform`` is set. Confirm that callers
        pass history times already in that space.

        Returns:
            ``(x, t)`` of shape ``(batch * num_samples, seq_len)``. With
            ``time_profile=True``, a third element maps stage name to milliseconds.
        """
        if time_profile:
            times = {'Encoder': 0, 'Decoder': 0, 'Sample': 0}
            start_time_initial = time.time()
            start_time = time.time()

        x = x.repeat(num_samples, 1)
        t = t.repeat(num_samples, 1)
        history_len = x.shape[1]

        h = self.encoder.forward(x, t)[:, -1]

        if time_profile:
            self._cuda_sync()
            times['Encoder'] += (time.time() - start_time) * 1000

        for _ in range(seq_len):
            if time_profile:
                start_time = time.time()

            mark_dist, time_dist = self.decoder.forward(h)

            if time_profile:
                self._cuda_sync()
                times['Decoder'] += (time.time() - start_time) * 1000
                start_time = time.time()

            x_sample = mark_dist.sample((1,)).squeeze(0)
            t_sample = time_dist.sample((1,)).squeeze(0).float()

            if time_profile:
                self._cuda_sync()
                times['Sample'] += (time.time() - start_time) * 1000
                start_time = time.time()

            x = torch.cat([x, x_sample.unsqueeze(-1)], dim=1)
            t = torch.cat([t, t_sample.unsqueeze(-1)], dim=1)

            # Recurrent encoders advance incrementally; others re-encode everything.
            if hasattr(self.encoder, 'step'):
                h = self.encoder.step(x_sample, t_sample, h)
            else:
                h = self.encoder.forward(x, t)[:, -1]

            if time_profile:
                self._cuda_sync()
                times['Encoder'] += (time.time() - start_time) * 1000

        # Slice from an explicit start: ``[:, -seq_len:]`` would return the whole
        # history when seq_len == 0.
        x, t = x[:, history_len:], t[:, history_len:]
        if self.log_transform:
            t = t.expm1()
        if time_profile:
            self._cuda_sync()
            times['Total'] = (time.time() - start_time_initial) * 1000
            return x, t, times
        return x, t

    @torch.no_grad()
    def rejection_sample(
        self,
        x: Tensor,
        t: Tensor,
        num_samples: int,
        seq_len: int,
        leap_size: int,
        top_k: int = 1,
        time_profile: bool = False,
        brute_force: bool = False,
        exact: bool = False,
        top_k_rejection_const: float = 1,
        num_points: int = 0,
        **kwargs,
    ) -> tuple:
        """Sample ``seq_len`` events per sequence, proposing ``leap_size`` at a time.

        Each round works as follows:
        - Draw ``leap_size`` candidates i.i.d. from the next-event distribution
          at the current state (the proposal).
        - Encode them to get the target distribution after every candidate prefix.
        - Accept the longest prefix whose candidates pass
          ``u < 1 / (mark_const * time_const)``.
        The first candidate is always accepted.

        Args:
            x, t: History of marks and times, repeated ``num_samples`` times on dim 0.
            num_samples: Number of copies of the history to sample from.
            seq_len: Number of events to generate per sequence.
            leap_size: Number of candidates proposed per round.
            top_k: If > 1, advance to the largest index among the ``top_k``
                smallest ``keep`` entries (roughly the ``top_k``-th rejection).
                Not implemented on CPU.
            time_profile: Record per-stage wall-clock timings in milliseconds.
            brute_force: Use ``decoder.brute_force_rejection_constant``.
            exact, top_k_rejection_const, num_points, **kwargs: Forwarded to the
                rejection-constant function.

        Returns:
            ``(x_samples, t_samples, select_indices)``. The samples have shape
            ``(batch * num_samples, seq_len)``. ``select_indices`` holds one tensor
            per round with the index of the last accepted candidate per sequence.
            With ``time_profile=True``, ``(times, mark_consts, time_consts)`` is
            appended.
        """
        if time_profile:
            times = {'Encoder': 0, 'Decoder': 0, 'Sample': 0, 'Rejection constant': 0, 'Rejection step': 0}
            start_time_initial = time.time()
            start_time = time.time()

        x = x.repeat(num_samples, 1)
        t = t.repeat(num_samples, 1)

        h = self.encoder.forward(x, t)[:, -1]

        if time_profile:
            self._cuda_sync()
            times['Encoder'] += (time.time() - start_time) * 1000

        batch_size = x.shape[0]
        device = x.device

        x_samples = torch.empty((batch_size, seq_len), dtype=torch.long, device=device)
        t_samples = torch.empty((batch_size, seq_len), dtype=torch.float, device=device)
        current_lengths = torch.zeros(batch_size, dtype=torch.long, device=device)

        select_indices = []
        mark_consts = []
        time_consts = []

        idx = torch.arange(batch_size, device=device)
        mask_arange = torch.arange(leap_size, device=device).unsqueeze(0)
        # Always-reject sentinel appended after the real tests. When every test
        # passes, argmin then stops at the last candidate.
        keep_pad = torch.zeros(batch_size, 1, device=device)
        _rejection_func = self.decoder.brute_force_rejection_constant if brute_force else self.decoder.rejection_constant

        shortest_len = 0
        while shortest_len < seq_len:
            if time_profile:
                start_time = time.time()

            # Proposal: next-event distribution at the current state, batch shape (B, 1).
            mark_dist, time_dist = self.decoder.forward(h.unsqueeze(-2))

            if time_profile:
                self._cuda_sync()
                times['Decoder'] += (time.time() - start_time) * 1000
                start_time = time.time()

            x_sample = mark_dist.sample((leap_size,)).transpose(0, 1).squeeze(2) # (batch_size, leap_size)
            t_sample = time_dist.sample((leap_size,)).transpose(0, 1).squeeze(2).float() # (batch_size, leap_size)

            if time_profile:
                self._cuda_sync()
                times['Sample'] += (time.time() - start_time) * 1000
                start_time = time.time()

            # h_sample[:, j] is the state after consuming candidates 0..j.
            # NOTE(review): only encoders that use ``h`` (the GRU encoders here)
            # condition these states on the history; the Transformer/CNN encoders
            # ignore ``h`` and see the candidates alone.
            h_sample = self.encoder.forward(x_sample, t_sample, h.unsqueeze(0)) # (batch_size, leap_size, hidden_dim)

            if time_profile:
                self._cuda_sync()
                times['Encoder'] += (time.time() - start_time) * 1000
                start_time = time.time()

            # Targets: decoder(h_sample[:, j]) is the distribution of candidate j + 1.
            mark_cand_dist, time_cand_dist = self.decoder.forward(h_sample)

            if time_profile:
                self._cuda_sync()
                times['Decoder'] += (time.time() - start_time) * 1000
                start_time = time.time()

            # (batch_size, leap_size)
            mark_const, time_const = _rejection_func(
                mark_proposal_dist=mark_dist,
                mark_target_dist=mark_cand_dist,
                time_proposal_dist=time_dist,
                time_target_dist=time_cand_dist,
                exact=exact,
                top_k_rejection_const=top_k_rejection_const,
                num_points=num_points,
                **kwargs,
            )

            mark_consts.append(mark_const)
            time_consts.append(time_const)

            rejection_const = mark_const * time_const

            if time_profile:
                self._cuda_sync()
                times['Rejection constant'] += (time.time() - start_time) * 1000
                start_time = time.time()

            accept_prob = 1 / rejection_const

            u = torch.rand_like(accept_prob)
            keep = (u < accept_prob)

            # Column j compares the proposal with the target *after* candidate j,
            # so it is the test for candidate j + 1. Candidate 0 comes from its
            # exact conditional and is always accepted. The last column belongs
            # to an event beyond this leap, so it is dropped and replaced by the
            # reject sentinel.
            keep = torch.cat((keep[:, :-1].float(), keep_pad), dim=1)

            if top_k == 1:
                # First failed test == index of the last accepted candidate.
                ind = keep.argmin(-1)
            else:
                # Select the top k indices that are not accepted
                if keep.device.type == 'cpu':
                    raise NotImplementedError
                ind = keep.topk(top_k, -1, largest=False).indices.max(-1).values

            # Continue from the state after the last accepted candidate.
            h = h_sample[idx, ind]

            # Candidates to write per row, capped by the room left in the output.
            to_add = torch.clamp(ind + 1, max=(seq_len - current_lengths))
            mask = mask_arange < to_add.unsqueeze(1)

            # nonzero() is row-major, so each row's accepted candidates are
            # contiguous and in order. Offset them by that row's current length.
            batch_indices, sample_indices = mask.nonzero(as_tuple=True)
            target_indices = current_lengths.repeat_interleave(to_add) + sample_indices

            # Write valid samples into the preallocated tensors
            x_samples[batch_indices, target_indices] = x_sample[batch_indices, sample_indices]
            t_samples[batch_indices, target_indices] = t_sample[batch_indices, sample_indices]

            current_lengths += to_add
            shortest_len = current_lengths.min()

            select_indices.append(ind)

            if time_profile:
                self._cuda_sync()
                times['Rejection step'] += (time.time() - start_time) * 1000
                start_time = time.time()

        if self.log_transform:
            t_samples = t_samples.expm1()
        if time_profile:
            self._cuda_sync()
            times['Total'] = (time.time() - start_time_initial) * 1000
            return x_samples, t_samples, select_indices, times, mark_consts, time_consts
        return x_samples, t_samples, select_indices

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
