

from teneva import sample
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch
import torch.nn as nn
import torch.nn.functional as F
from teneva import sample
import math
import os
import warnings
import matplotlib.pyplot as plt


def plot_histogram(data, title, xlabel, ylabel, filename):
    """Render a 20-bin histogram of ``data`` and save it to ``filename``.

    Args:
        data: Values handed to matplotlib's ``hist``.
        title: Title drawn above the axes.
        xlabel: Label of the horizontal axis.
        ylabel: Label of the vertical axis.
        filename: Destination passed to ``savefig``.
    """
    fig, ax = plt.subplots()
    ax.hist(data, bins=20, edgecolor='black')
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    fig.savefig(filename)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def plot_bar(data, title, xlabel, ylabel, filename):
    """Render ``data`` as a bar chart and save it to ``filename``.

    Bars are placed at positions ``0 .. len(data) - 1``.

    Args:
        data: Sequence of bar heights.
        title: Title drawn above the axes.
        xlabel: Label of the horizontal axis.
        ylabel: Label of the vertical axis.
        filename: Destination passed to ``savefig``.
    """
    fig, ax = plt.subplots()
    positions = range(len(data))
    ax.bar(positions, data, edgecolor='black')
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    fig.savefig(filename)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def log(t, eps = 1e-20):
    """Return the natural log of ``t`` after clamping it from below at ``eps``.

    The clamp keeps zero (or negative) entries from producing ``-inf``/NaN.
    """
    floored = torch.clamp(t, min=eps)
    return floored.log()

def gumbel_noise(t):
    """Draw Gumbel noise with the same shape, dtype and device as ``t``.

    Uniform samples ``u`` are transformed as ``-log(-log(u))``; both logs go
    through the clamped ``log`` helper so the result stays finite.
    """
    uniform = torch.zeros_like(t)
    uniform.uniform_(0, 1)
    inner = log(uniform)
    return -log(-inner)

def gumbel_sample(t, temperature = 1., dim = -1):
    """Sample indices along ``dim`` by taking the argmax of noise-perturbed ``t``.

    Args:
        t: Tensor of scores (e.g. logits).
        temperature: Divisor applied to ``t``; floored at ``1e-10`` so that a
            temperature of zero does not divide by zero.
        dim: Dimension over which the argmax is taken.

    Returns:
        Tensor of indices with ``dim`` reduced.
    """
    scale = max(temperature, 1e-10)
    perturbed = (t / scale) + gumbel_noise(t)
    return perturbed.argmax(dim = dim)

def f_top_k(logits, thres = 0.9):
    """Keep the largest logits along the last dimension and mask the rest.

    The number of kept entries is ``ceil((1 - thres) * n)`` where ``n`` is the
    size of the last dimension, clamped to the range ``[1, n]`` so that
    ``thres >= 1`` still keeps the best entry (instead of returning a row of
    ``-inf`` that turns into NaN after a softmax) and ``thres < 0`` keeps
    everything (instead of making ``torch.topk`` raise).

    Args:
        logits: Tensor of scores; filtering is applied over its last dimension.
        thres: Fraction of entries to discard.

    Returns:
        A new tensor shaped like ``logits`` where kept entries hold their
        original value and all others are ``-inf``.
    """
    n = logits.shape[-1]
    k = math.ceil((1 - thres) * n)
    # Clamp so at least one entry survives and k never exceeds the axis size.
    k = min(max(k, 1), n)
    val, ind = torch.topk(logits, k)
    filtered = torch.full_like(logits, float('-inf'))
    filtered.scatter_(-1, ind, val)
    return filtered

def safe_div(num, den, eps = 1e-10):
    """Divide ``num`` by ``den`` with the denominator floored at ``eps``.

    Note that any denominator below ``eps`` (including negative values) is
    replaced by ``eps``.
    """
    divisor = max(den, eps)
    return num / divisor


# %%
class DefaultHead(nn.Module):
    """Output head made of a single bias-free linear projection.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        output_dim: Number of logits produced per position.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, x, targets=None):
        """Project features to logits and optionally compute cross-entropy.

        Args:
            x: 3-D feature tensor; the second dimension is indexed as the
                sequence position.
            targets: Optional tensor of class indices, flattened before the
                loss is computed. Entries equal to ``-1`` are ignored.

        Returns:
            ``(logits, loss)``. Without targets only the last position is
            projected (the position dimension is kept with size 1) and
            ``loss`` is ``None``. With targets all positions are projected
            and ``loss`` is the mean cross-entropy.
        """
        # Identity check: `== None` on a tensor is an equality comparison,
        # not a "was an argument passed" test.
        if targets is None:
            # Indexing with the list [-1] keeps the position dimension.
            logits = self.linear(x[:, [-1], :])
            return logits, None

        logits = self.linear(x)
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)), targets.view(-1),
            ignore_index=-1)

        return logits, loss


# %%
class CPHead(nn.Module):
    """Multi-token output head built as a rank-``r`` mixture of per-token factors.

    For every one of the ``n_tokens`` predicted tokens there is a linear layer
    emitting ``n * r`` values, reshaped to ``(n, r)`` and log-softmaxed over
    the ``n`` axis. A further linear layer emits ``r`` mixture logits. The
    training loss is the negative log of
    ``sum_r w_r * prod_k G_k[target_k, r]``, averaged over rows and divided
    by ``n_tokens``.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        output_dim: Number of classes ``n`` per predicted token.
        n_tokens: Number of tokens ``d`` modelled jointly.
        r: Number of mixture components (rank).
    """

    def __init__(self, input_dim, output_dim, n_tokens=2, r=5):
        super().__init__()

        self.n_embd = input_dim
        self.n = output_dim
        self.d = n_tokens
        self.r = r

        # One projection per predicted token, each producing n * r values.
        self.lm_heads = nn.ModuleList([
            nn.Linear(self.n_embd, self.r * self.n, bias=False)
            for _ in range(n_tokens)
        ])

        # Logits of the r mixture weights.
        self.lm_head_weight = nn.Linear(self.n_embd, self.r, bias=True)

    def _build_core(self, k, x, targets=None, n_last=None):
        """Return log-probabilities of factor ``k``.

        Args:
            k: Index of the per-token projection to use.
            x: Feature tensor of shape ``[batch_size, block_size, n_embd]``.
            targets: Optional indexable of index tensors; ``targets[k]`` is
                flattened and used to pick one class per row
                (assumes one index per (batch, position) row -- TODO confirm
                against the caller).
            n_last: When given (and ``targets`` is ``None``), only the last
                ``n_last`` positions are projected.

        Returns:
            * with ``targets``: tensor ``(B * T, r)`` of gathered log-probs;
            * with ``n_last``: the last ``n_last`` rows, each ``(n, r)``;
            * otherwise: the last row only, shape ``(n, r)``.
        """
        B, T, C = x.shape
        n = self.n  # Vocabulary size

        if targets is not None:
            G = self.lm_heads[k](x).reshape(B * T, n, self.r)
            G = F.log_softmax(G, dim=1)
            # Replicate each target index across the r components for gather.
            t = targets[k].reshape(-1, 1, 1).repeat(1, 1, self.r)
            return torch.gather(G, dim=1, index=t).squeeze(1)

        if n_last:
            # needed for speculative decoding check pass
            G = self.lm_heads[k](x[:, -n_last:, :]).reshape(-1, n, self.r)
            return F.log_softmax(G, dim=1)[-n_last:]

        G = self.lm_heads[k](x[:, -1:, :]).reshape(-1, n, self.r)
        # We select last output
        return F.log_softmax(G, dim=1)[-1]

    def forward(self, x, targets=None, with_w_norm=True, check_k=None):
        """Dispatch to loss computation, prediction, or the check pass.

        Args:
            x: Feature tensor of shape ``[batch_size, block_size, n_embd]``.
            targets: When given, the loss is computed and ``pred`` is ``None``.
            with_w_norm: Forwarded to the loss; enables the auxiliary term.
            check_k: When given (and ``targets`` is ``None``), run the check
                pass over the last ``check_k`` positions instead of the
                ordinary prediction.

        Returns:
            ``(pred, loss)``; exactly one of the two is not ``None``.
        """
        pred = None
        loss = None

        if targets is not None:
            loss = self._head_forward_loss(x, targets, with_w_norm=with_w_norm)
        elif check_k is None:
            pred = self._head_forward_pred(x)
        else:
            pred = self._head_forward_check_last_k(x, k=check_k)

        return pred, loss

    def _head_forward_loss(self, x, targets, with_w_norm=True):
        """Negative log-likelihood of the targets under the mixture.

        With ``with_w_norm`` an auxiliary term
        ``r * sum(fs * ps) - 1`` is added, where ``fs`` is the fraction of
        rows whose largest mixture logit is each component and ``ps`` is the
        mean mixture weight of each component.
        """
        w = self.lm_head_weight(x).reshape(-1, self.r)
        # log_softmax instead of log(softmax(.)): a weight that underflows to
        # zero would otherwise give log(0) = -inf and NaN gradients.
        log_w = F.log_softmax(w, dim=1)

        log_terms = [self._build_core(k, x, targets) for k in range(self.d)]
        log_terms.append(log_w)
        # Sum of logs == log of the product of factors and mixture weight.
        joint = torch.stack(log_terms).sum(dim=0)
        loss = -1. * torch.logsumexp(joint, dim=1).mean() / self.d

        if with_w_norm:
            w_norm = F.softmax(w, dim=1)
            wmx = torch.argmax(w, dim=1)
            fs = torch.bincount(wmx, minlength=self.r) / w.shape[0]
            ps = torch.mean(w_norm, dim=0)

            aux_ls = (fs * ps).sum() * self.r - 1
            loss = loss + aux_ls

        return loss

    def _head_forward_pred(self, x):
        """Return the mixture log-weights and per-token log-factors.

        Returns:
            Dict with ``'log_w'`` (shape ``(r,)``, taken from the last row)
            and ``'log_cores'`` (list of ``d`` tensors of shape ``(n, r)``).
        """
        w = self.lm_head_weight(x).reshape(-1, self.r)
        log_w = F.log_softmax(w, dim=1)[-1, :]

        cores = [self._build_core(k, x) for k in range(self.d)]

        return {
            'log_w': log_w,
            'log_cores': cores
        }

    def _head_forward_check_last_k(self, x, k=4):
        """Marginal log-probabilities of the first token at the last ``k`` rows.

        Only factor 0 is used. For each of the last ``k`` rows the mixture is
        marginalised: ``logsumexp_r(log G_0[:, r] + log w_r)``.

        Returns:
            List of ``k`` tensors, each of shape ``(1, n)``.
        """
        w = self.lm_head_weight(x).reshape(-1, self.r)
        log_w = F.log_softmax(w, dim=-1)[-k:, :]

        core = self._build_core(0, x, n_last=k)

        probs_next = []
        for token_idx in range(k):
            # (n, r) + (r,) broadcasts the mixture log-weights over classes.
            preds = core[token_idx] + log_w[token_idx]
            preds = torch.logsumexp(preds, dim=1)
            probs_next.append(preds.reshape(1, -1))

        return probs_next


# prototype, not working yet.
class SparseCPHead(nn.Module):
    """Prototype variant of the mixture head that only uses the top-``k`` components.

    The layout matches the dense head: ``n_tokens`` projections emitting
    ``n * r`` values each plus one projection emitting ``r`` mixture logits.
    In both the loss and the prediction only the ``k`` components with the
    largest mixture weight are kept.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        output_dim: Number of classes ``n`` per predicted token.
        n_tokens: Number of tokens ``d`` modelled jointly.
        r: Number of mixture components.
        k: Number of top components kept; ``torch.topk`` requires ``k <= r``.
    """

    def __init__(self, input_dim, output_dim, n_tokens=2, r=5, k=3):
        super().__init__()

        self.n_embd = input_dim
        self.n = output_dim
        self.d = n_tokens
        self.r = r
        self.k = k  # Number of top heads to execute

        self.lm_heads = nn.ModuleList([
            nn.Linear(self.n_embd, self.r * self.n, bias=False)
            for _ in range(n_tokens)
        ])

        self.lm_head_weight = nn.Linear(self.n_embd, self.r, bias=True)

    def _build_core(self, k, x, targets=None):
        """Return log-probabilities of factor ``k``.

        With ``targets`` the result is ``(B * T, r)``: the log-probability of
        ``targets[k]`` under every component. Without targets the result is
        the ``(n, r)`` table of the last row.
        """
        B, T, C = x.shape
        n = self.n

        if targets is None:
            # Only the last row is returned, so only the last position needs
            # to be projected.
            G = self.lm_heads[k](x[:, -1:, :]).reshape(-1, n, self.r)
            return F.log_softmax(G, dim=1)[-1]

        G = self.lm_heads[k](x).reshape(B * T, n, self.r)
        G = F.log_softmax(G, dim=1)
        # Replicate each target index across the r components for gather.
        t = targets[k].reshape(-1, 1, 1).repeat(1, 1, self.r)
        return torch.gather(G, dim=1, index=t).squeeze(1)

    def forward(self, x, targets=None):
        """Return ``(pred, loss)``: the loss when ``targets`` is given, else a prediction."""
        pred = None
        loss = None

        if targets is None:
            pred = self._head_forward_pred(x)
        else:
            loss = self._head_forward_loss(x, targets)

        return pred, loss

    def _head_forward_loss(self, x, targets):
        """Negative log-likelihood restricted to the top-``k`` components per row.

        The mixture log-weights are not renormalised after the selection.
        """
        w = self.lm_head_weight(x).reshape(-1, self.r)
        # log_softmax instead of log(softmax(.)): a weight that underflows to
        # zero would otherwise give log(0) = -inf and NaN gradients.
        log_w = F.log_softmax(w, dim=1)

        # Select top-k heads (log is monotonic, so the ranking is unchanged).
        _, top_k_indices = torch.topk(log_w, self.k, dim=1)

        log_terms = []
        for k in range(self.d):
            core = self._build_core(k, x, targets)
            # Only keep values for top-k heads
            log_terms.append(core.gather(1, top_k_indices))

        log_terms.append(log_w.gather(1, top_k_indices))
        joint = torch.stack(log_terms).sum(dim=0)
        loss = -1. * torch.logsumexp(joint, dim=1).mean() / self.d

        return loss

    def _head_forward_pred(self, x):
        """Sample ``d`` token indices using the top-``k`` components of the last row.

        For ``d == 2`` the two masked factors are handed to ``teneva.sample``;
        otherwise one component is drawn from the kept weights and every
        token is sampled independently from that component's distribution.
        """
        w = self.lm_head_weight(x).reshape(-1, self.r)
        w = F.softmax(w, dim=1)[-1, :]

        # Select top-k heads
        _, top_k_indices = torch.topk(w, self.k)

        # Only keep values for top-k heads: each core becomes (n, k).
        cores = [self._build_core(k, x)[:, top_k_indices] for k in range(self.d)]

        if self.d == 2:
            # NOTE(review): core1 stays a torch tensor while core2 is turned
            # into a numpy array; the class is flagged as a non-working
            # prototype, so confirm what `teneva.sample` expects before use.
            core1 = F.softmax(cores[0], dim=0).cpu().unsqueeze(0)
            core1 = torch.einsum('ijk,k->ijk', core1, w[top_k_indices].cpu())
            core2 = F.softmax(cores[1], dim=0).cpu().unsqueeze(0).permute(2, 1, 0).numpy()
            tt_tensor = [core1, core2]
            idxs = sample(tt_tensor)
            pred = torch.tensor(idxs, dtype=int).to(w.device)
        else:
            # Index into the kept (top-k) columns, not into the original r.
            item_next = torch.multinomial(w[top_k_indices], num_samples=1)[0]
            idxs_next = []
            for k in range(self.d):
                # Explicit dim: the slice is 1-D, the implicit form is deprecated.
                G_curr = F.softmax(cores[k][:, item_next], dim=0)
                idx_next = torch.multinomial(G_curr, num_samples=1)
                idxs_next.append(idx_next)

            pred = torch.stack(idxs_next, dim=1)

        return pred