from opt_einsum import contract
import torch
import torch.nn as nn
from torch.nn import functional as F


from model_gpt import ModelGPT


class ModelGPTMT(ModelGPT):
    """GPT variant with ``d`` parallel output heads (multi-token draft).

    Every head is a bias-free linear map from the hidden state to vocabulary
    logits. During generation one forward pass yields one sampled token per
    head, appended to the sequence in head order.
    """

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Sample ``max_new_tokens`` tokens and append them to ``idx``.

        Args:
            idx: Integer tensor of token ids; it is indexed as 2-D and new
                tokens are concatenated along dim 1 (presumably
                (batch, time) -- TODO confirm against the caller).
            max_new_tokens: Number of tokens to append.
            temperature: Divisor applied to the logits before softmax.
            top_k: If given, only the ``top_k`` most likely tokens of each
                head can be sampled.

        Returns:
            ``idx`` extended along dim 1 by the sampled tokens.
        """
        target_len = idx.size(1) + max_new_tokens
        # Each forward pass produces up to ``self.d`` tokens, so round the
        # number of passes up; the last pass is cut short below so that the
        # requested count is met even when it is not a multiple of ``d``.
        num_passes = -(-max_new_tokens // self.d)

        for _ in range(num_passes):
            # If the sequence context is growing too long we must crop it:
            if idx.size(1) <= self.block_size:
                idx_cond = idx
            else:
                idx_cond = idx[:, -self.block_size:]

            logits_all, _ = self(idx_cond)

            for logits in logits_all:
                if idx.size(1) >= target_len:
                    # Enough tokens sampled; skip the remaining heads.
                    break

                # Keep the last position only and apply the temperature.
                logits = logits[:, -1, :] / temperature

                if top_k is not None:
                    # Mask everything below the k-th largest logit per row.
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')

                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, idx_next), dim=1)

        return idx

    def _head_forward(self, x, targets=None):
        """Apply every output head to ``x`` and optionally compute the loss.

        Args:
            x: Hidden states; indexed as a 3-D tensor whose last dim feeds
                the heads (presumably (batch, time, n_embd) -- TODO confirm).
            targets: Optional list of target tensors, one per head (the
                ``d`` possible shifts). Entries equal to -1 are ignored by
                the loss. If the list is shorter than the number of heads,
                the extra heads do not contribute to the loss.

        Returns:
            ``(logits_all, loss)``: a list with one logits tensor per head,
            and the mean of the per-head cross-entropy losses, or ``None``
            when ``targets`` is not given.
        """
        # NOTE: this is still a draft implementation.

        if targets is None:
            # Inference: only the last position is needed, so apply the
            # heads to that single time step (the time dim is kept).
            x_last = x[:, [-1], :]
            return [head(x_last) for head in self.lm_heads], None

        logits_all = [head(x) for head in self.lm_heads]
        losses = [
            F.cross_entropy(
                logits.view(-1, logits.size(-1)), t.view(-1),
                ignore_index=-1)
            for logits, t in zip(logits_all, targets)
        ]
        loss = torch.stack(losses).mean(dim=0)

        return logits_all, loss

    def _head_init(self, config):
        """Create the ``config.d`` output heads.

        Args:
            config: Object providing ``d``, ``n_embd`` and ``vocab_size``.
        """
        self.d = config.d

        # nn.ModuleList (not a plain list) so that the heads are registered
        # as submodules: otherwise their weights are missing from
        # parameters() / state_dict() and are not moved by .to() / .cuda().
        # NOTE(review): assumes nn.Module.__init__ has already run when this
        # hook is called -- confirm against ModelGPT.__init__.
        # TODO: this is a draft for TT-code
        self.lm_heads = nn.ModuleList(
            nn.Linear(config.n_embd, config.vocab_size, bias=False)
            for _ in range(self.d))