import torch
import torch.nn as nn
import torch.nn.functional as F


from model_gpt import ModelGPT


class ModelGPTLR2(ModelGPT):
    """GPT variant whose LM head scores ``d == 2`` tokens jointly via a rank-``r`` mixture.

    For a position with hidden state ``x`` the head models

        p(y1, y2) = sum_j w_j * U[y1, j] * V[y2, j]

    where ``w = softmax(lm_head_weight(x))`` are the ``r`` mixture weights and
    ``U`` / ``V`` are per-component vocabulary distributions produced by
    ``lm_heads[0]`` / ``lm_heads[1]``. What ``targets[0]`` / ``targets[1]``
    denote (presumably the next two tokens) is decided by the caller — confirm
    against ``ModelGPT.forward``.
    """

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively append ``max_new_tokens`` sampled tokens to ``idx``.

        Each forward pass returns one set of logits per head, so up to
        ``self.d`` tokens are appended per model call. When ``max_new_tokens``
        is not a multiple of ``self.d`` the surplus head outputs of the final
        call are not sampled.

        Args:
            idx: token-id tensor of shape (B, T). Effectively only B == 1 works:
                the head reads the last flattened position and each sampled
                token has shape (1, 1).
            max_new_tokens: number of tokens to append.
            temperature: divisor applied to the logits before softmax.
            top_k: if given, keep only the ``top_k`` highest logits.

        Returns:
            ``idx`` with the new tokens concatenated along dim 1.

        Raises:
            ValueError: if the forward pass returns a single tensor instead of
                a list of per-head logits.
        """
        start_len = idx.size(1)
        # Ceil division: enough model calls to cover every requested token.
        n_steps = -(-max_new_tokens // self.d)

        for _ in range(n_steps):
            # Crop the context to the last block_size tokens if it grew too long.
            if idx.size(1) <= self.block_size:
                idx_cond = idx
            else:
                idx_cond = idx[:, -self.block_size:]

            logit_list, _ = self(idx_cond)
            if isinstance(logit_list, torch.Tensor):
                # Single-head sampling is not implemented for this head.
                raise ValueError(
                    'expected a list of per-head logits from the low-rank head')

            for logits in logit_list:
                if idx.size(1) - start_len >= max_new_tokens:
                    break

                logits = logits / temperature

                # Optionally keep only the top_k options.
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')

                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, idx_next), dim=1)

        return idx

    def _tt_ranks(self, k):
        """Return the (left, right) ranks of core ``k``; the outer ranks are 1."""
        r1 = 1 if k == 0 else self.r
        r2 = 1 if k == self.d - 1 else self.r
        return r1, r2

    def _build_tt_core(self, k, x, targets=None):
        """Build core ``k`` as log-probabilities over the vocabulary.

        Args:
            k: core / head index.
            x: hidden states of shape (B, T, n_embd).
            targets: optional per-head target ids; ``targets[k]`` must have
                B * T elements.

        Returns:
            Without targets: tensor (r1, n, r2) for the last flattened position.
            With targets: tensor (B * T, r1, r2) holding the log-probability of
            ``targets[k]`` at each position.
        """
        B, T, _ = x.shape
        r1, r2 = self._tt_ranks(k)
        n = self.n

        G = self.lm_heads[k](x).reshape(B * T, r1, n, r2)
        # Normalize over the vocabulary dimension.
        G = F.log_softmax(G, dim=2)

        if targets is None:
            # Only the last flattened position is needed for sampling.
            return G[-1]

        # Broadcast target ids across both rank dims, then pick them out of dim 2.
        t = targets[k].reshape(-1, 1, 1, 1).expand(-1, r1, 1, r2)
        return torch.gather(G, dim=2, index=t).squeeze(2)

    def _build_lr_factors(self, x, targets=None):
        """Return ``(log_U, log_V)`` from the two cores.

        Without targets both are (n, r): column ``j`` is a log-distribution
        over the vocabulary for mixture component ``j``. With targets both are
        (B * T, r): the log-probability of the corresponding target per component.
        """
        G1 = self._build_tt_core(0, x, targets)
        G2 = self._build_tt_core(1, x, targets)
        if targets is None:
            # G1 is (1, n, r), G2 is (r, n, 1).
            log_U = G1[0, :, :]
            log_V = G2[:, :, 0].T
        else:
            # G1 is (B*T, 1, r), G2 is (B*T, r, 1).
            log_U = G1[:, 0, :]
            log_V = G2[:, :, 0]
        return log_U, log_V

    def _head_forward(self, x, targets=None, with_w_norm=True):
        """Dispatch to prediction (no targets) or loss computation.

        Returns:
            ``(pred, loss)`` — exactly one of them is non-None.
        """
        pred = None
        loss = None
        if targets is None:
            pred = self._head_forward_pred(x)
        else:
            loss = self._head_forward_loss(x, targets, with_w_norm=with_w_norm)
        return pred, loss

    def _head_forward_loss(self, x, targets, with_w_norm=True, *,
                           aux_loss_coef=1.0):
        """Mean negative joint log-likelihood, divided by ``d``.

        Args:
            x: hidden states of shape (B, T, n_embd).
            targets: per-head target ids (indexed ``targets[0]``, ``targets[1]``).
            with_w_norm: add the load-balancing penalty on the mixture weights.
            aux_loss_coef: multiplier for that penalty.

        Returns:
            Scalar loss tensor.
        """
        w = self.lm_head_weight(x).reshape(-1, self.r)
        # log_softmax is stable where log(softmax(.)) can hit log(0) = -inf.
        log_w = F.log_softmax(w, dim=1)
        log_U, log_V = self._build_lr_factors(x, targets)

        # log p(y1, y2) = logsumexp_j(log w_j + log U_j(y1) + log V_j(y2))
        log_p = torch.logsumexp(log_U + log_V + log_w, dim=1)
        loss = -log_p.mean() / self.d

        if with_w_norm:
            w_norm = F.softmax(w, dim=1)
            # fs: fraction of rows whose argmax component is j;
            # ps: mean weight of component j.
            # r * sum(fs * ps) - 1 is 0 when both are uniform.
            fs = torch.bincount(w.argmax(dim=1), minlength=self.r) / w.shape[0]
            ps = w_norm.mean(dim=0)
            aux_ls = (fs * ps).sum() * self.r - 1
            loss = loss + aux_loss_coef * aux_ls

        return loss

    def _head_forward_pred(self, x):
        """Sample a mixture component and return per-head next-token logits.

        Uses the last flattened position of ``x``. The returned tensors are the
        normalized log-probabilities ``log_U[:, j]`` and ``log_V[:, j]`` for the
        sampled component ``j``, so ``softmax(logits)`` recovers the component's
        distributions exactly (and ``temperature`` in ``generate`` scales them
        as intended).

        Returns:
            list of two tensors, each of shape (1, n).
        """
        w = self.lm_head_weight(x).reshape(-1, self.r)
        w_last = F.softmax(w, dim=1)[-1, :]
        item_next = torch.multinomial(w_last, num_samples=1)[0]
        log_U, log_V = self._build_lr_factors(x)
        # Do NOT softmax here: the caller applies softmax to these logits.
        return [log_U[:, item_next].unsqueeze(0),
                log_V[:, item_next].unsqueeze(0)]

    def _head_init(self, config):
        """Create the two core heads and the mixture-weight head.

        Reads ``config.d``, ``config.vocab_size``, ``config.r`` and
        ``config.n_embd``.

        Raises:
            NotImplementedError: if ``config.d`` is not 2.
        """
        self.d = config.d
        self.n = config.vocab_size
        self.r = config.r

        if self.d != 2:
            raise NotImplementedError('Number of heads (d) should be 2')

        # Core k maps n_embd -> r1 * n * r2 logits.
        self.lm_heads = nn.ModuleList(
            nn.Linear(config.n_embd, r1 * self.n * r2, bias=False)
            for r1, r2 in map(self._tt_ranks, range(self.d))
        )

        self.lm_head_weight = nn.Linear(config.n_embd, self.r, bias=True)