import torch
import torch.nn as nn
import tqdm

from modules import SinusoidalPositionEmbeddings, diagonalize_and_scale, in_batch_negative_sampling_sample, in_batch_negative_sampling
from diffusion import SteerRecDiffusion
from models.SASRec._model import SASRec
import numpy as np
class SteerRec(SASRec):
    """SASRec variant whose next-item embedding is produced by a conditional diffusion model.

    The inherited SASRec encoder (``get_representation``) yields a positive
    condition ``h_pos``. A negative condition ``h_neg`` is the detached mean of
    sampled negative item embeddings. Training calls ``SteerRecDiffusion.p_losses``
    and inference calls ``SteerRecDiffusion.sample``. Both receive this module as
    the denoise model, so :meth:`denoise` is presumably invoked from there
    (confirm in ``diffusion.py``).
    """

    def __init__(self, config: dict):
        super().__init__(config)
        self.config = config
        self.diff = SteerRecDiffusion(config=config)
        hidden = config['hidden_size']
        # Timestep encoder: SinusoidalPositionEmbeddings followed by a 2-layer GELU MLP.
        self.step_nn = nn.Sequential(
            SinusoidalPositionEmbeddings(hidden),
            nn.Linear(hidden, hidden * 2),
            nn.GELU(),
            nn.Linear(hidden * 2, hidden),
        )
        # Denoiser input is [x, h, t] concatenated on the feature axis (see denoise()).
        self.denoise_nn = nn.Sequential(
            nn.Linear(hidden * 3, hidden)
        )

    def get_embeddings(self, items):
        """Look up item embeddings for the given item ids (any index shape)."""
        return self.item_embeddings(items)

    def get_all_embeddings(self, device=None):
        """Return the full item-embedding matrix, including row 0 (padding).

        ``device`` is accepted but unused; the weight is returned where the module lives.
        """
        return self.item_embeddings.weight

    def load_item_embeddings(self):
        """Create a fresh, learnable item-embedding table with index 0 reserved for padding."""
        self.item_embeddings = nn.Embedding(
            num_embeddings=self.config['item_num'] + 1,
            embedding_dim=self.config['hidden_size'],
            padding_idx=0,
        )
        nn.init.normal_(self.item_embeddings.weight, 0, 1)
        # normal_ also overwrites the all-zero vector nn.Embedding puts at padding_idx.
        # Since that row never receives gradient, restore it to zeros so the pad
        # item is a neutral vector rather than a fixed random one.
        with torch.no_grad():
            self.item_embeddings.weight[0].zero_()

    def denoise(self, x, h, step):
        """Predict a denoised item embedding from noisy ``x``, condition ``h`` and timestep ``step``.

        Args:
            x: noisy embeddings, either 2-D (one per row) or 3-D ``(B, N, D)``
                holding N candidates per row.
            h: conditioning vector per row; assumed ``(B, D)`` — TODO confirm with diffusion.py.
            step: diffusion timesteps fed to ``step_nn``; presumably shape ``(B,)``.

        Returns:
            Tensor with the same leading dims as ``x`` and last dim ``hidden_size``.
        """
        t = self.step_nn(step)
        if x.dim() < 3:
            return self.denoise_nn(torch.cat((x, h, t), dim=1))
        num_candidates = x.shape[1]
        # Broadcast the per-row condition and timestep over the N candidates.
        # expand() creates views (no copy). nn.Linear acts on the last axis, so
        # the (B, N, ·) layout can be kept without flattening, which also avoids
        # the failure .view() had on non-contiguous x.
        h_rep = h.unsqueeze(1).expand(-1, num_candidates, -1)
        t_rep = t.unsqueeze(1).expand(-1, num_candidates, -1)
        return self.denoise_nn(torch.cat((x, h_rep, t_rep), dim=-1))

    def get_neg_representation(self, neg_item_embeddings):
        """Average negative item embeddings over dim 1 and detach the result.

        Detaching means the negative condition does not back-propagate into the
        item-embedding table.
        """
        h_neg_raw = torch.mean(neg_item_embeddings, dim=1)
        return h_neg_raw.detach()

    def forward(self, batch):
        """Compute the diffusion training loss for one batch.

        Returns:
            dict with the total ``loss`` and the ``reconstruction_loss`` and
            ``alignment_loss`` entries reported by ``SteerRecDiffusion.p_losses``.
        """
        h_pos = self.get_representation(batch)
        labels = batch['labels'].view(-1)
        x_pos = self.get_embeddings(labels)

        # Negative condition (h_neg); the sampling strategy is chosen by config['sample_func'].
        neg_item_ids = self._generate_negative_samples(batch)
        h_neg = self.get_neg_representation(self.get_embeddings(neg_item_ids))

        # One diffusion step per target, drawn uniformly from [0, timesteps).
        t = torch.randint(0, self.config['timesteps'], (labels.shape[0],), device=h_pos.device).long()

        loss, loss_dict = self.diff.p_losses(
            denoise_model=self,
            x_start_pos=x_pos,
            h_pos=h_pos,
            h_neg=h_neg,
            t=t,
            loss_type=self.config['loss_type']
        )

        return {
            'loss': loss,
            'reconstruction_loss': loss_dict['reconstruction_loss'],
            'alignment_loss': loss_dict['alignment_loss'],
        }

    def predict(self, batch, n_return_sequences=1):
        """Return the top-``n_return_sequences`` item ids per sequence in ``batch``.

        Only items in ``config['select_pool']`` (half-open ``[start, end)``) are
        ranked. The returned ids are absolute item ids, not pool offsets.
        """
        h_pos = self.get_representation(batch)
        pool_start = self.config['select_pool'][0]
        pool_end = self.config['select_pool'][1]

        # At inference the negative condition comes from uniformly random pool items.
        num_neg_for_infer = self.config.get('neg_samples', 64)
        random_ids = torch.randint(
            pool_start, pool_end,
            (h_pos.shape[0], num_neg_for_infer),
            device=h_pos.device
        )
        h_neg = self.get_neg_representation(self.get_embeddings(random_ids))

        x = self.diff.sample(self, h_pos=h_pos, h_neg=h_neg)

        # Score only the candidate pool rather than every item and then slicing.
        pool_item_emb = self.get_all_embeddings(h_pos.device)[pool_start:pool_end]
        scores = torch.matmul(x, pool_item_emb.transpose(0, 1))

        # topk indices are pool-relative; shift them back to absolute item ids.
        return scores.topk(n_return_sequences, dim=-1).indices + pool_start

    def _generate_negative_samples(self, batch):
        """Draw negative item ids for ``batch['labels']`` according to ``config['sample_func']``.

        - ``'batch'``: ``in_batch_negative_sampling(labels)``.
        - ``'random'``: ``in_batch_negative_sampling_sample(labels, config['neg_samples'])``.
        - anything else: one uniform ``select_pool`` item per row, distinct from
          that row's label, shaped ``(B, 1)``.
        """
        sample_func = self.config['sample_func']
        if sample_func == 'batch':
            return in_batch_negative_sampling(batch['labels'])
        if sample_func == 'random':
            return in_batch_negative_sampling_sample(batch['labels'], self.config['neg_samples'])
        return self._sample_pool_negatives(batch['labels'])

    def _sample_pool_negatives(self, labels):
        """Sample one negative id per label uniformly from ``select_pool``, never equal to the label.

        Raises:
            ValueError: if a row's label is the only item in the pool, so no
                distinct negative exists.
        """
        pool_start = self.config['select_pool'][0]
        pool_end = self.config['select_pool'][1]
        # Assumes one label per row, matching forward()'s labels.view(-1).
        targets = labels.reshape(-1)
        negatives = torch.randint(pool_start, pool_end, targets.shape, device=targets.device)
        clash = negatives == targets
        # Re-draw only the rows that hit their own label. This terminates with
        # probability 1 whenever the pool holds at least two items.
        while clash.any():
            if pool_end - pool_start < 2:
                raise ValueError(
                    "select_pool must contain at least two items to draw a negative "
                    "distinct from the label"
                )
            resampled = torch.randint(pool_start, pool_end, targets.shape, device=targets.device)
            negatives = torch.where(clash, resampled, negatives)
            clash = negatives == targets
        return negatives.view(-1, 1)



