import torch
import torch.nn as nn
import torch.nn.functional as F

from model_gpt import ModelGPT



class ModelGPTRK1B(ModelGPT):
    """GPT variant whose output layer is ``d`` independent linear heads.

    Every head maps the same hidden state to a vocabulary-sized vector, and
    during generation head ``k`` supplies the ``k``-th of the next ``d``
    tokens.  The in-code notes call this the rank-1 baseline.

    NOTE(review): ``block_size``, ``forward`` and the call sites of the
    ``_head_*`` hooks live in ``ModelGPT`` and are not visible here; the
    descriptions below are limited to what this class itself does.
    """

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Extend ``idx`` by sampling ``d`` tokens per forward pass.

        Args:
            idx: 2-D tensor of token indices; new tokens are concatenated
                along dimension 1.
            max_new_tokens: Upper bound on the number of tokens to add.
                ``max_new_tokens // self.d`` forward passes are run and each
                adds one token per head, so the number of added tokens is
                rounded down to a multiple of ``self.d``.
            temperature: Divisor applied to the logits before the softmax.
            top_k: If given, only the ``top_k`` largest logits of each head
                stay eligible for sampling.

        Returns:
            ``idx`` with the sampled tokens appended along dimension 1.

        Raises:
            ValueError: If the forward pass returns a single tensor instead
                of a sequence of per-head logits.
        """
        for _ in range(max_new_tokens // self.d):
            # If the sequence context is growing too long we must crop it:
            if idx.size(1) <= self.block_size:
                idx_cond = idx
            else:
                idx_cond = idx[:, -self.block_size:]

            # Forward the model to get one logits tensor per head:
            logit_list, _ = self(idx_cond)

            if isinstance(logit_list, torch.Tensor):
                # Default single-head sampling is not implemented here.
                raise ValueError(
                    'generate() expects a sequence of per-head logits, '
                    'but the forward pass returned a single tensor')

            # All heads were evaluated on the same context; their samples are
            # appended one after another, in head order.
            for logits in logit_list:
                # Pluck the logits at the final step and scale by temperature
                # (the division yields a new tensor, so the in-place masking
                # below does not touch the model output):
                logits = logits[:, -1, :] / temperature

                # Optionally crop the logits to only the top k options:
                if top_k is not None:
                    top_vals, _ = torch.topk(
                        logits, min(top_k, logits.size(-1)))
                    logits[logits < top_vals[:, [-1]]] = -float('Inf')

                # Apply softmax to convert logits to (normalized) probabilities:
                probs = F.softmax(logits, dim=-1)

                # Sample from the distribution:
                idx_next = torch.multinomial(probs, num_samples=1)

                # Append sampled index to the running sequence and continue:
                idx = torch.cat((idx, idx_next), dim=1)

        return idx

    def _head_forward(self, x, targets=None, **kwargs):
        """Run the output heads on hidden states ``x``.

        The shape of the result depends on ``targets``:

        * ``targets is None``: returns ``(pred, None)`` where ``pred`` is the
          list of per-head logits from ``_head_forward_pred``.
        * otherwise: returns the 4-tuple ``(x, lm_heads, targets, d)`` from
          ``_get_log_target_partial``; no loss is computed here.

        ``kwargs`` is accepted and ignored.
        """
        if targets is not None:
            x, lm_heads, targets, d = self._get_log_target_partial(x, targets)
            return x, lm_heads, targets, d

        pred = self._head_forward_pred(x)
        loss = None
        return pred, loss

    def _head_forward_loss(self, x, targets, with_w_norm=True):
        """Return minus the mean of the summed per-head log-values, over ``d``.

        NOTE(review): this goes through ``_get_log_target``, which raises
        ``ValueError`` for any non-None ``targets``, so in its current state
        this method cannot produce a training loss.  ``with_w_norm`` is not
        used.
        """
        preds = torch.stack(self._get_log_target(x, targets))
        summed = torch.sum(preds, dim=0)
        # A logsumexp over ranks would be needed only for ranks != 1, and in
        # this baseline the rank is 1.
        return -1. * torch.mean(summed) / self.d

    def _get_log_target_partial(self, x, targets):
        """Return ``(x, self.lm_heads, targets, self.d)`` without evaluating
        the heads."""
        return x, self.lm_heads, targets, self.d

    def _get_log_target(self, x, targets):
        """Deprecated: per-head log-softmax rows for the last flattened position.

        Args:
            x: 3-D tensor; the code unpacks its shape as
                (batch_size, block_size, n_embd).
            targets: Must be ``None``; the targets path is deprecated in
                favour of ``_get_log_target_partial``.

        Returns:
            A list with one entry per head: the last row of that head's
            log-softmax output after flattening to ``(B * T, vocab_size)``.

        Raises:
            ValueError: If ``targets`` is not ``None``.
        """
        if targets is not None:
            raise ValueError('deprecated method, use partial one instead')

        B, T, _ = x.shape  # Here x is [batch_size, block_size, n_embd]

        log_probs = []
        for head_out in self._predict(x):
            # Flatten batch and time, then normalize over the vocabulary.
            flat = head_out.reshape(B * T, self.n)
            # We select the last output row.
            log_probs.append(F.log_softmax(flat, dim=1)[-1])

        return log_probs

    def _predict(self, x):
        """Apply every head to ``x`` and return the outputs as a list."""
        return [mod(x) for mod in self.lm_heads]

    def _head_forward_pred(self, x):
        """Return the raw per-head logits for ``x`` (no softmax, no sampling)."""
        return self._predict(x)

    def _head_init(self, config):
        """Create ``config.d`` bias-free linear heads of size
        ``config.n_embd -> config.vocab_size``."""
        self.d = config.d
        self.n = config.vocab_size

        self.lm_heads = nn.ModuleList([
            nn.Linear(config.n_embd, self.n, bias=False)
            for _ in range(self.d)
        ])
        