import torch
import torch.nn as nn

from teneva import sample

import torch.nn.functional as F

from model_gpt import ModelGPT

from refactored_modules.sampling import greedy_speculative_decoding
from refactored_modules.sparse_cp_moe import CP_MoE

class ModelGPTCPSparse(ModelGPT):
    
    def _head_init(self, config):
        """Set up the sparse CP mixture-of-experts output head.

        Reads ``d``, ``vocab_size``, ``r`` and ``n_embd`` from ``config``,
        stores the first three on the module as ``self.d``, ``self.n`` and
        ``self.r``, and builds the ``CP_MoE`` layer as ``self.moe``.
        """
        self.d, self.n, self.r = config.d, config.vocab_size, config.r

        # The MoE layer takes an arbitrary chunk of embeddings and routes
        # them to a fixed number of experts.
        moe_kwargs = {
            'input_dim': config.n_embd,
            'output_dim': config.vocab_size,
            'num_experts': config.r,
            'n_tokens': config.d,
        }
        self.moe = CP_MoE(**moe_kwargs)

    def _head_forward(self, x, targets=None, with_w_norm=True):
        pred = None
        loss = None

        if targets is None:
            pred = self._head_forward_pred(x)
        else:        
            loss = self._head_forward_loss(x, targets)
            
        return pred, loss
    

    def get_last_hidden_state(self, idx):
        # basically copy of forward, but without the last layer

        device = idx.device
        b, t = idx.size()
        
        msg = f'Cannot forward sequence of length {t}, '
        msg += f'block size is only {self.block_size}'
        assert t <= self.block_size, msg
        
        pos = torch.arange(0, t, dtype=torch.long, device=device)

        # Forward the GPT model itself:
        # (token embeddings of shape (b, t, n_embd);
        #  position embeddings of shape (t, n_embd))
        tok_emb = self.transformer.wte(idx) 
        pos_emb = self.transformer.wpe(pos)

        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        
        x = self.transformer.ln_f(x)

        return x
    
    
    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, speculative=False):
        """Generation of the new sequence.
        
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t))
        and complete the sequence max_new_tokens times, feeding the predictions
        back into the model each time. Most likely you'll want to make sure to
        be in model.eval() mode of operation for this.
        
        """

        if speculative: 
            raise NotImplementedError()
        
        for _ in range(max_new_tokens):
            # If the sequence context is growing too long we must crop it:
            if idx.size(1) <= self.block_size:
                idx_cond = idx
            else:
                idx_cond = idx[:, -self.block_size:]
            
            # Forward the model to get the logits for the index in the sequence:
            logits, _ = self(idx_cond)
            
            # Pluck the logits at the final step and scale by temperature:
            logits = logits[:, -1, :] / temperature
            
            # Optionally crop the logits to only the top k options:
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            
            # Apply softmax to convert logits to (normalized) probabilities:
            probs = F.softmax(logits, dim=-1)
            
            # Sample from the distribution:
            idx_next = torch.multinomial(probs, num_samples=1)
            
            # Append sampled index to the running sequence and continue:
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

    def _head_forward_pred(self, x, heads_to_eval='all'):
        outputs, aux_loss = self.moe(x, heads_to_eval='all')

        outputs = [torch.logsoftmax(output, dim=-1) for output in outputs]

        return outputs

    def _head_forward_loss(self, x, targets, with_w_norm=True):
        # 1st is a list [6, 1024, 50k, 3], second is [6, 1024, 3]
        expert_outputs, combine_tensor, aux_loss = self.moe(x)

        log_w = torch.log(combine_tensor).reshape(-1, self.r)

        print(log_w)

        print('expert_outputs', expert_outputs[0].shape)
        print('combine_tensor', combine_tensor.shape)

        B, T, vocab_size, r = expert_outputs[0].shape
        expert_outputs = [expert_output.reshape(B*T, vocab_size, r) for expert_output in expert_outputs]
        expert_outputs = [torch.log_softmax(expert_output, dim=1) for expert_output in expert_outputs]

        gathered_outputs = []
        for i in range(self.d):
            t = targets[i].reshape(-1, 1, 1).repeat(1, 1, r)
            gathered_outputs.append(torch.gather(expert_outputs[i], dim=1, index=t).squeeze(1))

        log_cores = gathered_outputs
        log_cores.append(log_w)
        print('log_cores', log_cores[0].shape)
        print('log_w', log_w.shape)
        loss = torch.sum(torch.stack(log_cores), dim=0)
        loss = torch.logsumexp(loss, dim=1)

        loss = -1. * torch.mean(loss) / self.d

        #loss = loss + aux_loss

        return loss

        