import torch
import torch.nn as nn
import torch as th
import torch.nn.functional as F
from recbole.model.layers import TransformerEncoder
from recbole.model.abstract_recommender import SequentialRecommender
from RQ_diffusion import diffusion
import jax
import jax.random as jrandom
import jax.numpy as jnp
import math
import numpy as np
from typing import Any
import flax
import faiss
#import flax.linen as nn


def log(t, eps=1e-6):
    """Return the element-wise ``log(t + eps)``.

    The small ``eps`` shift keeps the result finite when ``t`` contains zeros.
    """
    shifted = t + eps
    return torch.log(shifted)


def sample_gumbel(shape, device, dtype, eps=1e-6):
    """Draw Gumbel noise ``-log(-log(U))`` with ``U ~ Uniform[0, 1)``.

    ``eps`` is passed to both ``log`` calls so neither ever takes ``log(0)``.
    """
    uniform = torch.empty(shape, device=device, dtype=dtype)
    uniform.uniform_(0, 1)
    inner = -log(uniform, eps)
    return -log(inner, eps)


def sinkhorn_sorting_operator(r, n_iters=8):
    """Sinkhorn-normalise log-space scores into an (approximately) doubly-stochastic matrix.

    Each iteration subtracts the log-sum-exp over dim 2 and then over dim 1, so
    ``r`` is expected to be 3-D ``(batch, rows, cols)``.

    Args:
        r: log-space scores.
        n_iters: number of alternating normalisation rounds. ``0`` returns ``exp(r)``.

    Returns:
        ``exp`` of the normalised scores. Sums over dim 1 are exactly 1 after
        the last round; sums over dim 2 approach 1 as ``n_iters`` grows.
    """
    for _ in range(n_iters):
        r = r - torch.logsumexp(r, dim=2, keepdim=True)
        r = r - torch.logsumexp(r, dim=1, keepdim=True)
    return torch.exp(r)


def gumbel_sinkhorn(r, n_iters=8, temperature=0.7):
    """Gumbel-Sinkhorn relaxation of a matching.

    Takes ``log`` of the (positive) scores ``r``, perturbs them with Gumbel
    noise, divides by ``temperature`` and Sinkhorn-normalises the result for
    ``n_iters`` rounds.
    """
    log_r = log(r)
    noise = sample_gumbel(log_r.shape, log_r.device, log_r.dtype)
    return sinkhorn_sorting_operator((log_r + noise) / temperature, n_iters)


def differentiable_topk(x, k, temperature=1.):
    """Pick the top-``k`` entries of ``x`` along its last axis, one round at a time.

    Each round takes the softmax (at ``temperature``) of the remaining entries and
    keeps only its single largest probability, as a row that is zero elsewhere.
    The chosen position is then masked with ``-inf`` so the next round selects a
    different entry.

    Args:
        x: tensor of shape ``(*batch, n, dim)``. It is NOT modified. Masking is
            done out of place, so the caller's tensor and any autograd-saved
            version of it stay intact.
        k: number of entries to select per row.
        temperature: softmax temperature.

    Returns:
        Tensor of shape ``(*batch, k * n, dim)``. The ``k`` selections for row
        ``j`` occupy rows ``j * k .. j * k + k - 1``.
    """
    *batch_dims, n, dim = x.shape
    topk_tensors = []

    for i in range(k):
        values, indices = (x / temperature).softmax(dim=-1).topk(1, dim=-1)
        topk_tensors.append(torch.zeros_like(x).scatter_(-1, indices, values))
        if i < k - 1:
            # Out-of-place: an in-place scatter_ here would clobber the caller's
            # tensor and break backward when x was saved by autograd (e.g. exp()).
            x = x.scatter(-1, indices, float('-inf'))

    topks = torch.cat(topk_tensors, dim=-1)
    return topks.reshape(*batch_dims, k * n, dim)

class SiLU(nn.Module):
    """Sigmoid-weighted linear unit, applied element-wise: ``x * sigmoid(x)``."""

    def forward(self, x):
        gate = th.sigmoid(x)
        return x * gate


class VQRec(SequentialRecommender):
    """VQ-Rec sequential recommender whose input is conditioned on a discrete-diffusion timestep.

    Every item is described by ``code_dim`` PQ codes. Code position ``i`` owns
    rows ``[i * (code_cap + 1), (i + 1) * (code_cap + 1))`` of the shared
    ``pq_code_embedding`` table; row 0 is the padding row. The Transformer works
    on vectors of width ``hidden_size * code_dim``.
    """

    def __init__(self, config, dataset):
        super().__init__(config, dataset)

        # VQRec args
        self.code_dim = config['code_dim']
        self.code_cap = config['code_cap']
        self.pq_codes = dataset.pq_codes
        # Local code ids in [0, code_cap] that the diffusion process works on.
        # Presumably pq_codes carry the per-position row offset and the modulo
        # strips it -- TODO confirm against the dataset. Stored on the configured
        # device (instead of a hard-coded cuda:0) so it can be indexed by item_seq.
        self.codes_to_diffusion = (self.pq_codes % (self.code_cap + 1)).to(config['device'])
        self.temperature = config['temperature']
        self.index_assignment_flag = False
        self.sinkhorn_iter = config['sinkhorn_iter']
        self.fake_idx_ratio = config['fake_idx_ratio']

        self.train_stage = config['train_stage']
        assert self.train_stage in [
            'pretrain', 'inductive_ft'
        ], f'Unknown train stage: [{self.train_stage}]'

        # load parameters info
        self.n_layers = config['n_layers']
        self.n_heads = config['n_heads']
        self.hidden_size = config['hidden_size']  # same as embedding_size
        self.inner_size = config['inner_size']  # the dimensionality in feed-forward layer
        self.hidden_dropout_prob = config['hidden_dropout_prob']
        self.attn_dropout_prob = config['attn_dropout_prob']
        self.hidden_act = config['hidden_act']
        self.layer_norm_eps = config['layer_norm_eps']

        self.initializer_range = config['initializer_range']
        self.loss_type = config['loss_type']
        self.num_steps = config['num_steps']
        self.temperature_softmax = config['temperature_softmax']
        self.eval_batch_size = config['eval_batch_size']

        # Width of every Transformer token: one hidden_size chunk per code position.
        model_dim = self.hidden_size * self.code_dim

        # define layers and loss (creation order kept so seeded init is unchanged)
        self.pq_code_embedding = nn.Embedding(
            self.code_dim * (1 + self.code_cap), self.hidden_size, padding_idx=0)
        # Never assigned inside this class. It must be provided elsewhere before
        # index_assignment_flag is set or the evaluation branch of forward() runs.
        self.reassigned_code_embedding = None

        self.position_embedding = nn.Embedding(self.max_seq_length, model_dim)
        self.trm_encoder = TransformerEncoder(
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            hidden_size=model_dim,
            inner_size=self.inner_size,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attn_dropout_prob=self.attn_dropout_prob,
            hidden_act=self.hidden_act,
            layer_norm_eps=self.layer_norm_eps
        )
        self.trans_matrix = nn.Parameter(torch.randn(self.code_dim, self.code_cap + 1, self.code_cap + 1))

        self.LayerNorm = nn.LayerNorm(model_dim, eps=self.layer_norm_eps)
        self.dropout = nn.Dropout(self.hidden_dropout_prob)
        self.diffusion_state = diffusion.create_discrete_diffusion(self.code_cap, self.num_steps)
        self.diffusionclass = self.diffusion_state.static_state["diffusion"]
        self.rng_key = jrandom.PRNGKey(config['random_seed'])

        time_embed_dim = model_dim * 4
        self.time_embed = nn.Sequential(
            nn.Linear(model_dim, time_embed_dim), SiLU(), nn.Linear(time_embed_dim, model_dim))

        if self.loss_type == 'BPR':
            raise NotImplementedError()
        elif self.loss_type == 'CE':
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['CE']!")

        # parameters initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def code_projection(self):
        """Reassign each code position's codebook rows using the learned ``trans_matrix``.

        ``trans_matrix`` is turned into a doubly-stochastic matrix with
        Gumbel-Sinkhorn. The largest entry of each row is kept
        (``differentiable_topk`` with ``k=1``), and ``torch.ceil`` snaps the kept
        probabilities to 1 while the rest stay 0.

        Returns:
            The reassigned table, shape ``(code_dim * (code_cap + 1), hidden_size)``.

        NOTE(review): ``torch.ceil`` has zero gradient, so ``trans_matrix`` appears
        to get no gradient through this path -- confirm this is intended.
        """
        doubly_stochastic_matrix = gumbel_sinkhorn(torch.exp(self.trans_matrix), n_iters=self.sinkhorn_iter)
        trans = differentiable_topk(doubly_stochastic_matrix.reshape(-1, self.code_cap + 1), 1)
        trans = torch.ceil(trans.reshape(-1, self.code_cap + 1, self.code_cap + 1))
        raw_embed = self.pq_code_embedding.weight.reshape(self.code_dim, self.code_cap + 1, -1)
        trans_embed = torch.bmm(trans, raw_embed).reshape(-1, self.hidden_size)
        return trans_embed

    def timestep_embedding(self, timesteps, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        :param timesteps: a 1-D Tensor of N indices, one per batch element.
                        These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an [N x dim] Tensor of positional embeddings, on the same device
                 as the model's parameters.
        """
        # Follow the model's device instead of hard-coding cuda:0.
        device = self.time_embed[0].weight.device
        timesteps = timesteps.to(device)
        half = dim // 2
        freqs = th.exp(
            -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32, device=device) / half
        )
        args = timesteps[:, None].float() * freqs[None]
        embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
        if dim % 2:
            # Odd dim: pad one zero column so the width is exactly `dim`.
            embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, item_seq, x_t, t, train_flag=False):
        """Encode a (possibly noised) code sequence at diffusion step ``t``.

        Args:
            item_seq: item-id sequence ``(batch, seq_len)``. It supplies the
                positions and the attention mask.
            x_t: local code ids ``(batch, seq_len, code_dim)``, with 0 as padding.
                It is not modified.
            t: scalar diffusion timestep (anything ``int()`` accepts).
            train_flag: when True, return only the encoder output.

        Returns:
            ``trm_output`` of shape ``(batch, seq_len, hidden_size * code_dim)`` when
            ``train_flag`` is True. Otherwise ``(indices, trm_output)``, where
            ``indices`` is a numpy int32 array ``(batch, seq_len, code_dim)`` of the
            nearest rows of ``reassigned_code_embedding``.
        """
        position_ids = torch.arange(item_seq.size(1), dtype=torch.long, device=item_seq.device)
        position_ids = position_ids.unsqueeze(0).expand_as(item_seq)
        position_embedding = self.position_embedding(position_ids)

        # Shift code position i into its own slice of the shared table
        # (offset i * (code_cap + 1)). Padding code 0 stays at row 0. This is done
        # out of place so the caller's x_t (e.g. the sampler state) is untouched.
        code_offsets = (torch.arange(self.code_dim, device=x_t.device) * (self.code_cap + 1)).to(x_t.dtype)
        x_t = torch.where(x_t != 0, x_t + code_offsets, x_t)

        if self.index_assignment_flag:
            emb = F.embedding(x_t, self.reassigned_code_embedding, padding_idx=0)
            # Concatenate the code_dim chunks -> (batch, seq_len, hidden_size * code_dim).
            pq_code_emb = emb.reshape(emb.shape[0], emb.shape[1], -1)
        else:
            # NOTE(review): mean() yields width hidden_size, whereas the position
            # embedding has width hidden_size * code_dim. This only adds up when
            # code_dim == 1 -- confirm this branch is reachable.
            pq_code_emb = self.pq_code_embedding(x_t).mean(dim=-2)

        batch_size = item_seq.size(0)
        timesteps = torch.full((batch_size,), int(t), dtype=torch.int32, device=item_seq.device)
        emb_t = self.time_embed(self.timestep_embedding(timesteps, self.hidden_size * self.code_dim))
        # unsqueeze(1) broadcasts the per-sequence time embedding over every position.
        input_emb = pq_code_emb + emb_t.unsqueeze(1) + position_embedding
        input_emb = self.LayerNorm(input_emb)
        input_emb = self.dropout(input_emb)

        extended_attention_mask = self.get_attention_mask(item_seq)

        output = self.trm_encoder(input_emb, extended_attention_mask, output_all_encoded_layers=True)
        trm_output = output[-1]
        if train_flag:
            return trm_output  # (batch, seq_len, hidden_size * code_dim)
        return self._nearest_code_indices(trm_output), trm_output

    def _nearest_code_indices(self, trm_output):
        """Quantise encoder outputs back to global rows of ``reassigned_code_embedding``.

        Each token is split into ``code_dim`` chunks of ``hidden_size``. Chunk ``i``
        is matched (exact L2 search, faiss on GPU 0) only against its own slice
        ``[i * (code_cap + 1), (i + 1) * (code_cap + 1))`` of the codebook.

        Returns:
            numpy int32 array ``(batch, seq_len, code_dim)`` of global row indices.
        """
        codebook_size = self.code_cap + 1
        # detach(): numpy() refuses tensors that track gradients.
        flat_output = trm_output.detach().view(-1, self.hidden_size).cpu().numpy()
        results_indices = np.zeros(flat_output.shape[0], dtype=np.int32)
        res = faiss.StandardGpuResources()  # default GPU resources
        for i in range(self.code_dim):
            start_idx = codebook_size * i
            end_idx = codebook_size * (i + 1)
            codebook_part = self.reassigned_code_embedding[start_idx:end_idx].detach().cpu().numpy()

            cpu_index = faiss.IndexFlatL2(self.hidden_size)
            gpu_index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
            gpu_index.add(codebook_part)

            # Search the nearest vector within this position's sub-codebook. Rows
            # i, i + code_dim, i + 2 * code_dim, ... are the chunks for position i.
            _, nearest = gpu_index.search(flat_output[i::self.code_dim], 1)
            # Shift local hits back to global codebook row indices.
            results_indices[i::self.code_dim] = nearest[:, 0] + start_idx

        # Restore the (batch, seq_len, code_dim) layout.
        return results_indices.reshape(
            trm_output.shape[0], trm_output.shape[1], trm_output.shape[2] // self.hidden_size)

    def calculate_item_emb(self):
        """Embed every item from its PQ codes.

        When ``index_assignment_flag`` is set, the per-code embeddings from
        ``reassigned_code_embedding`` are concatenated into one row per item.
        Otherwise the ``pq_code_embedding`` rows are averaged over the code axis.
        """
        if self.index_assignment_flag:
            emb = F.embedding(self.pq_codes, self.reassigned_code_embedding, padding_idx=0)
            pq_code_emb = emb.reshape(emb.shape[0], -1)
        else:
            pq_code_emb = self.pq_code_embedding(self.pq_codes).mean(dim=-2)
        return pq_code_emb  # one row per item

    def generate_fake_neg_item_emb(self, item_index):
        """Build "fake" negatives by replacing some codes of ``item_index`` with random ones.

        Each code is replaced with probability ``fake_idx_ratio`` by a random
        non-padding code taken from the same code position's slice. The
        resulting code embeddings are averaged over the last axis.
        """
        rand_idx = torch.randint_like(input=item_index, high=self.code_cap)
        # flatten pq codes: move the random local code into position i's slice (skipping 0)
        base_id = (torch.arange(self.code_dim).to(item_index.device) * (self.code_cap + 1)).unsqueeze(0)
        rand_idx = rand_idx + base_id + 1

        mask = torch.bernoulli(torch.full_like(item_index, self.fake_idx_ratio, dtype=torch.float))
        fake_item_idx = torch.where(mask > 0, rand_idx, item_index)
        return self.pq_code_embedding(fake_item_idx).mean(dim=-2)

    def seq_item_contrastive_task(self, seq_output, same_pos_id, interaction):
        """InfoNCE-style loss between sequences and their next items.

        The other in-batch positives serve as negatives, except those flagged in
        ``same_pos_id``. A code-perturbed "fake" copy of each positive is added as
        one extra negative.
        """
        pos_id = interaction['item_id']
        pos_pq_code = self.pq_codes[pos_id]
        if self.index_assignment_flag:
            pos_items_emb = F.embedding(pos_pq_code, self.reassigned_code_embedding, padding_idx=0).mean(dim=-2)
        else:
            pos_items_emb = self.pq_code_embedding(pos_pq_code).mean(dim=-2)
        pos_items_emb = F.normalize(pos_items_emb, dim=1)

        pos_logits = (seq_output * pos_items_emb).sum(dim=1, keepdim=True) / self.temperature
        pos_logits = torch.exp(pos_logits)

        neg_logits = torch.matmul(seq_output, pos_items_emb.transpose(0, 1)) / self.temperature
        neg_logits = torch.where(same_pos_id, torch.tensor([0], dtype=torch.float, device=same_pos_id.device), neg_logits)
        neg_logits = torch.exp(neg_logits).sum(dim=1).reshape(-1, 1)

        fake_item_emb = self.generate_fake_neg_item_emb(pos_pq_code)
        fake_item_emb = F.normalize(fake_item_emb, dim=-1)
        fake_logits = (seq_output * fake_item_emb).sum(dim=1, keepdim=True) / self.temperature
        fake_logits = torch.exp(fake_logits)

        loss = -torch.log(pos_logits / (neg_logits + fake_logits))
        return loss.mean()

    def pretrain(self, interaction):
        """Contrastive pre-training loss, used when ``train_stage == 'pretrain'``.

        NOTE(review): ``forward`` now takes ``(item_seq, x_t, t, train_flag)``, but
        this still calls the older ``forward(item_seq, item_seq_len)`` form. It
        therefore raises a TypeError (missing ``t``). It needs updating before the
        pretrain stage can run.
        """
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(item_seq, item_seq_len)
        seq_output = F.normalize(seq_output, dim=1)

        # Remove sequences with the same next item
        pos_id = interaction['item_id']
        same_pos_id = (pos_id.unsqueeze(1) == pos_id.unsqueeze(0))
        same_pos_id = torch.logical_xor(same_pos_id, torch.eye(pos_id.shape[0], dtype=torch.bool, device=pos_id.device))

        return self.seq_item_contrastive_task(seq_output, same_pos_id, interaction)

    def calculate_loss(self, interaction):
        """Cross-entropy next-item loss at one randomly sampled diffusion timestep.

        In the 'pretrain' stage this delegates to ``pretrain``. The loss is
        returned regardless of ``temperature``; ``temperature > 0`` only adds
        cosine normalisation and logit scaling.
        """
        if self.train_stage == 'pretrain':
            return self.pretrain(interaction)

        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]

        # Store the advanced key: splitting self.rng_key without updating it would
        # draw the same timestep t on every training step.
        time_key, self.rng_key = jrandom.split(self.rng_key)
        t = self.diffusionclass.sample_t(time_key, shape=())

        pq_code_seq = self.codes_to_diffusion[item_seq]  # (batch, seq_len, code_dim)
        x_0 = jnp.array(pq_code_seq.cpu().numpy())
        # NOTE(review): the forward noising step -- sampling x_{t+1} ~ q(x_{t+1} | x_0)
        # via self.diffusionclass.sample_and_compute_posterior_q -- is currently
        # disabled, so the encoder sees the clean codes x_0 regardless of t.
        x_t = torch.from_numpy(np.array(x_0)).to(device=item_seq.device)
        output = self.forward(item_seq, x_t, t, train_flag=True)

        seq_output = self.gather_indexes(output, item_seq_len - 1)

        pos_items = interaction[self.POS_ITEM_ID]

        test_item_emb = self.calculate_item_emb()

        if self.temperature > 0:
            seq_output = F.normalize(seq_output, dim=-1)
            test_item_emb = F.normalize(test_item_emb, dim=-1)

        logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))

        if self.temperature > 0:
            logits /= self.temperature

        # Previously nested under the temperature check, which made this method
        # return None whenever temperature <= 0.
        return self.loss_fct(logits, pos_items)

    def predict(self, interaction):
        raise NotImplementedError()

    def calcu_logits(self, item_seq, x_t, t, item_seq_len):
        """Score every item for the last valid position of each sequence at step ``t``.

        Returns:
            ``(batch, n_items)`` similarity logits; these are cosine similarities
            when ``temperature > 0``.
        """
        # train_flag=True: forward's eval mode returns an (indices, output) tuple
        # (and runs a faiss search), but only the encoder output is needed here.
        output = self.forward(item_seq, x_t, t, train_flag=True)

        seq_output = self.gather_indexes(output, item_seq_len - 1)

        test_item_emb = self.calculate_item_emb()

        if self.temperature > 0:
            seq_output = F.normalize(seq_output, dim=-1)
            test_item_emb = F.normalize(test_item_emb, dim=-1)

        logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))

        return logits

    def full_sort_predict(self, interaction):
        """Score all items by running ``num_steps`` reverse-diffusion sampling steps.

        NOTE(review): this path looks unfinished. Torch tensors are passed to JAX
        ops (``jrandom.categorical``, ``logits.argmax``), and ``state.x`` starts as
        code ids and is then replaced by samples drawn from ``logits``, yet it is
        used as ``seq_output`` below. Confirm before relying on these scores.
        """
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]

        @flax.struct.dataclass
        class SamplingState:
            x: Any  # current predicted sequence
            x0: Any  # only used if predict_x0 is true
            key: jnp.ndarray  # PRNGKey
            t: int  # current step

        extra_key, rng_key = jrandom.split(self.rng_key)

        def sampling_step(step, state):
            del step

            t = state.t  # initially num_steps, decreasing from there
            key = state.key

            logits, x0 = diffusion.p_forward(
                self.calcu_logits,  # called as (item_seq, x_t, t, item_seq_len)
                item_seq=item_seq,
                item_seq_len=item_seq_len,
                x_t=state.x,
                t=t,
                diffusion=self.diffusionclass,
                predict_x0=True,
                return_x0=True,
                return_logits=True,
                maximum_likelihood=False,
                step_size=1)

            sampling_key, key = jrandom.split(state.key)
            sample = jrandom.categorical(sampling_key, logits, axis=-1)

            # On the final step (t == 1) take the arg-max instead of a sample.
            mask = (t == 1)
            sample = mask * logits.argmax(-1) + (1 - mask) * sample

            return SamplingState(x=sample, key=key, x0=x0, t=t - 1)

        x = self.codes_to_diffusion[item_seq]  # (batch, seq_len, code_dim)
        state = SamplingState(x, x, rng_key, self.num_steps)

        # Plain Python loop in place of jax.lax.fori_loop.
        for step in range(self.num_steps):
            state = sampling_step(step, state)

        final_state = state

        # Assumes the 'x' field of the final state plays the role of seq_output.
        seq_output = self.gather_indexes(final_state.x, item_seq_len - 1)

        test_items_emb = self.calculate_item_emb()

        if self.temperature > 0:
            seq_output = F.normalize(seq_output, dim=-1)
            test_items_emb = F.normalize(test_items_emb, dim=-1)

        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  # [B n_items]
        return scores

