import torch
import torch.nn as nn

from teneva import sample

# Module-level flag for mixing all experts at prediction time.
# NOTE(review): not read by the code shown in this module — `_head_forward_pred` and
# `_head_forward_check_last_k` each assign a local variable with this same name,
# which shadows this global. Changing it here has no effect on those methods.
all_experts_summation = True

from model_gpt_lr2 import ModelGPTLR2

import matplotlib.pyplot as plt
import os
import math
import warnings

# Module default for the "correct" (joint) multi-head sampling path. `_head_forward_pred`
# reads it: when True it returns a dict {'log_w', 'log_cores'} so the caller can sample the
# heads sequentially from the mixture, instead of a list of per-head logits.
CORRECT_SAMPLING = True

def plot_histogram(data, title, xlabel, ylabel, filename, bins=20):
    """Save a histogram of ``data`` to ``filename``.

    Args:
        data: sequence of values to bin.
        title: plot title.
        xlabel: x-axis label.
        ylabel: y-axis label.
        filename: output path passed to ``plt.savefig``.
        bins: number of histogram bins (default 20).

    The figure is closed in a ``finally`` block so a failing ``savefig``
    (e.g. a missing output directory) does not leave an open figure behind.
    """
    fig = plt.figure()
    try:
        plt.hist(data, bins=bins, edgecolor='black')
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.savefig(filename)
    finally:
        plt.close(fig)

def plot_bar(data, title, xlabel, ylabel, filename):
    """Save a bar chart of ``data`` (one bar per element at x = 0..len-1) to ``filename``.

    Args:
        data: sequence of bar heights.
        title: plot title.
        xlabel: x-axis label.
        ylabel: y-axis label.
        filename: output path passed to ``plt.savefig``.

    The figure is closed in a ``finally`` block so a failing ``savefig`` does not
    leave an open figure behind.
    """
    fig = plt.figure()
    try:
        plt.bar(range(len(data)), data, edgecolor='black')
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.savefig(filename)
    finally:
        plt.close(fig)

def log(t, eps = 1e-20):
    """Natural log of ``t`` with every element floored at ``eps``.

    Flooring avoids ``-inf`` at zero and NaN for negative inputs.
    """
    floored = t.clamp(min = eps)
    return floored.log()

def gumbel_noise(t):
    """Draw standard Gumbel noise with the same shape, dtype and device as ``t``.

    Uses the inverse-CDF transform ``-log(-log(u))`` with ``u`` uniform on
    ``[0, 1)``; the clamped ``log`` helper keeps ``u == 0`` from producing
    infinities.
    """
    uniform = torch.zeros_like(t)
    uniform.uniform_(0, 1)
    return -log(-log(uniform))

def gumbel_sample(t, temperature = 1., dim = -1):
    """Sample indices along ``dim`` with the Gumbel-max trick.

    ``temperature`` is floored at ``1e-10`` so a zero temperature does not divide
    by zero.
    """
    scaled = t / max(temperature, 1e-10)
    perturbed = scaled + gumbel_noise(t)
    return perturbed.argmax(dim = dim)

def f_top_k(logits, thres = 0.9):
    """Keep the top ``ceil((1 - thres) * V)`` logits along the last axis; mask the rest.

    ``V`` is ``logits.shape[-1]``. Masked entries become ``-inf``, so a following
    softmax gives them zero probability. The kept count is clamped to ``[1, V]``:
    ``thres >= 1`` still keeps the single best logit (instead of masking everything,
    which makes a softmax NaN), and ``thres < 0`` keeps everything (instead of asking
    ``topk`` for more elements than exist).

    Args:
        logits: tensor of scores; filtering is along the last dimension.
        thres: fraction of entries to discard (0.9 keeps the top 10%).

    Returns:
        Tensor shaped like ``logits`` with non-top-k entries set to ``-inf``.
    """
    vocab = logits.shape[-1]
    k = min(vocab, max(1, math.ceil((1 - thres) * vocab)))
    val, ind = torch.topk(logits, k)
    filtered = torch.full_like(logits, float('-inf'))
    filtered.scatter_(-1, ind, val)
    return filtered

def safe_div(num, den, eps = 1e-10):
    """Return ``num / den`` with the denominator floored at ``eps``.

    Denominators smaller than ``eps`` (including zero and negatives) are replaced
    by ``eps``, e.g. so a temperature of 0 does not divide by zero.
    """
    floored_den = eps if eps > den else den
    return num / floored_den



class ModelGPTCP(ModelGPTLR2):

    def _build_core(self, k, x, targets=None, n_last=None):
        """Per-expert log-probabilities ("core") of head ``k``.

        The head's output is viewed as ``[rows, n, r]`` (vocabulary ``n`` by ``r``
        experts) and log-softmaxed over the vocabulary axis.

        Args:
            k: head index into ``self.lm_heads``.
            x: hidden states unpacked as ``[batch_size, block_size, n_embd]``.
            targets: if given, ``targets[k]`` supplies the token whose log-prob is
                gathered at every position.
            n_last: if truthy (and no targets), score the last ``n_last`` positions;
                used by the speculative-decoding check pass.

        Returns:
            ``[B*T, r]`` gathered target log-probs when ``targets`` is given;
            otherwise the trailing ``n_last`` rows of the ``[rows, n, r]`` cores, or
            only the last row ``[n, r]`` when ``n_last`` is falsy.
        """
        B, T, _ = x.shape
        head = self.lm_heads[k]

        if targets is not None:
            cores = nn.functional.log_softmax(head(x).reshape(B * T, self.n, self.r), dim=1)
            picked = targets[k].reshape(-1, 1, 1).repeat(1, 1, self.r)
            return torch.gather(cores, dim=1, index=picked).squeeze(1)

        window = n_last if n_last else 1
        cores = nn.functional.log_softmax(
            head(x[:, -window:, :]).reshape(-1, self.n, self.r), dim=1
        )
        # keep only the rows for the final positions
        return cores[-n_last:] if n_last else cores[-1]
    
    
    @torch.no_grad()
    def generate_w_speculative(
        self, 
        idx, 
        max_new_tokens, 
        temperature=1.0, 
        top_k=0.9,
        top_p=None,
        generation_config=None, 
        return_stats=False
    ):
        """Generate tokens with self-speculative decoding.

        Each step drafts ``d`` tokens from the ``d`` heads at the end of the context,
        re-scores the extended sequence with the first head (``forward_for_check``) and
        appends the verified prefix plus one extra token from the check pass.

        The draft format follows what the head actually returns: a dict
        (``'log_w'``/``'log_cores'``) is sampled jointly from the CP mixture; a list is
        treated as per-head logits.

        Args:
            idx: prompt token ids. NOTE(review): the acceptance logic appears to assume a
                batch of one (``[1, T]``) — confirm callers never pass more.
            max_new_tokens: keep stepping until at least this many tokens were appended;
                the last step may overshoot by up to ``d`` tokens.
            temperature: sampling temperature (not used by greedy generation).
            top_k: passed as ``thres`` to ``f_top_k`` (fraction discarded); falsy disables.
            top_p: unused.
            generation_config: optional; ``greedy_generation`` is read from it.
            return_stats: also return the per-step matched/accepted draft counts and skip
                the histogram plot.

        Returns:
            ``idx`` extended with generated tokens, or ``(idx, matched_tokens_list)``
            when ``return_stats`` is True.
        """
        # https://github.com/lucidrains/speculative-decoding/blob/main/speculative_decoding/speculative_decoding.py

        self.eval()

        greedy_generation = generation_config.greedy_generation if generation_config else False

        n_generated_tokens = 0
        matched_tokens_list = []

        while n_generated_tokens < max_new_tokens:
            if greedy_generation:
                new_tokens, n_matched = self._speculative_greedy_step(idx)
            else:
                new_tokens, n_matched = self._speculative_sampling_step(idx, temperature, top_k)

            idx = torch.cat((idx, new_tokens), dim=1)
            matched_tokens_list.append(n_matched)
            # count what was actually appended: matched drafts + one token from the check pass
            n_generated_tokens += new_tokens.size(1)

        if self.manager.writer and not return_stats:
            plot_histogram(
                matched_tokens_list, 
                'Matched Tokens Histogram', 
                'Matched Tokens', 
                'Frequency', 
                os.path.join(self.manager.fold, 'matched_tokens.png')
            )

        if return_stats:
            return idx, matched_tokens_list
        return idx

    def _crop_context(self, idx):
        """Keep at most the last ``block_size - d`` tokens of ``idx``.

        This leaves room for the ``d`` drafted tokens that are appended before the
        verification pass, so the checked sequence never exceeds ``block_size``.
        """
        max_context = self.block_size - self.d
        if idx.size(1) <= max_context:
            return idx
        return idx[:, -max_context:]

    def _draft_from_mixture(self, pred_dict, temperature=1.0, greedy=False):
        """Draft one token per head from the CP mixture, conditioning on earlier drafts.

        ``pred_dict`` holds ``'log_w'`` (log mixture weights) and ``'log_cores'`` (one
        ``[n, r]`` log-prob core per head). Head ``i``'s marginal is
        ``logsumexp_r(log_cores[i][:, r] + a_r)``, where ``a`` starts at ``log_w`` and each
        drafted token ``t`` adds ``log_cores[i][t]`` before the next head is drafted.

        Returns:
            (tokens, logits_q): the drafted token tensors (shape ``[1]`` each) and the
            ``[1, n]`` marginal log-probs each token was drawn from.
        """
        running = pred_dict['log_w']
        tokens = []
        logits_q = []
        for log_core in pred_dict['log_cores']:
            logits = torch.logsumexp(log_core + running.expand(*log_core.shape), dim=-1)
            logits_q.append(logits.unsqueeze(0))
            if greedy:
                token = logits.argmax(dim=-1, keepdim=True)
            else:
                probs = nn.functional.softmax(logits / temperature, dim=-1).unsqueeze(0)
                token = torch.multinomial(probs, num_samples=1).squeeze(-1)
            tokens.append(token)
            # condition the expert weights on the token just drafted
            running = running + log_core[token].reshape(*running.shape)
        return tokens, logits_q

    def _speculative_greedy_step(self, idx):
        """One greedy draft-and-verify step.

        Drafts ``d`` tokens by arg-max, re-scores them with ``forward_for_check`` and keeps
        the longest prefix on which draft and check pass agree, plus the check pass's
        own next token.

        Returns:
            (new_tokens, n_matched): ``[1, n_matched + 1]`` tokens to append and the length
            of the agreeing prefix.
        """
        idx_cond = self._crop_context(idx)
        draft, _ = self(idx_cond)
        if isinstance(draft, dict):
            drafted, _ = self._draft_from_mixture(draft, greedy=True)
        else:
            drafted = [logits.argmax(dim=-1) for logits in draft]
        idxs_next = torch.stack(drafted, dim=1)

        # verify on the cropped context so the checked sequence fits in block_size
        idx_cands = torch.cat((idx_cond, idxs_next), dim=1)
        checked_logits, _ = self.forward_for_check(idx_cands, check_k=self.d + 1)
        checked = torch.stack([logits.argmax(dim=-1) for logits in checked_logits]).reshape(1, -1)

        n_matched = 0
        for i in range(idxs_next.size(1)):
            if not torch.equal(checked[:, i], idxs_next[:, i]):
                break
            n_matched += 1
        return checked[:, :n_matched + 1], n_matched

    def _speculative_sampling_step(self, idx, temperature, top_k):
        """One speculative-sampling step (draft with the heads, verify with head 0).

        Draft ``i`` is accepted with probability ``min(1, p_i / q_i)`` (``p``: check pass,
        ``q``: draft); the first draft is always accepted because it comes from head 0,
        the head the check pass uses. After the first rejection a replacement is sampled
        from ``max(p - q, 0)``; if every draft is accepted a bonus token is sampled from
        the check pass's last position.

        Returns:
            (new_tokens, n_accepted): ``[1, n_accepted + 1]`` tokens to append and the
            number of accepted drafts.
        """
        # we have a draft model q and a big model p
        idx_cond = self._crop_context(idx)
        draft, _ = self(idx_cond)
        if isinstance(draft, dict):
            samples, logits_q = self._draft_from_mixture(draft, temperature=temperature)
        else:
            logits_q = draft
            if top_k:
                logits_q = [f_top_k(l, thres=top_k) for l in logits_q]
            samples = []
            for l in logits_q:
                probs = nn.functional.softmax(l / temperature, dim=-1)
                samples.append(torch.multinomial(probs, num_samples=1).squeeze(-1))
        idxs_next = torch.stack(samples, dim=1)

        # verify on the cropped context so the checked sequence fits in block_size
        idx_cands = torch.cat((idx_cond, idxs_next), dim=1)
        logits_p, _ = self.forward_for_check(idx_cands, check_k=self.d + 1)
        if top_k:
            logits_p = [f_top_k(l, thres=top_k) for l in logits_p]

        prob_p = [safe_div(l, temperature).softmax(dim = -1) for l in logits_p]
        prob_q = [safe_div(l, temperature).softmax(dim = -1) for l in logits_q]
        # p has one more entry than q: the last one scores the bonus token
        prob_p_aligned, prob_p_last = prob_p[:-1], prob_p[-1]

        draft_index = torch.transpose(idxs_next, 0, 1)
        prob_p_for_token = torch.gather(
            torch.stack(prob_p_aligned).squeeze(1), -1, draft_index
        ).squeeze(-1).cpu()
        prob_q_for_token = torch.gather(
            torch.stack(prob_q).squeeze(1), -1, draft_index
        ).squeeze(-1).cpu()

        u = torch.rand(len(prob_p_for_token))
        n_accepted = 0
        for i in range(len(prob_p_for_token)):
            if (i == 0) or (u[i] < prob_p_for_token[i] / prob_q_for_token[i]):
                n_accepted += 1
            else:
                break

        if n_accepted < len(prob_p_for_token):
            prob_last = torch.maximum(
                prob_p[n_accepted] - prob_q[n_accepted],
                torch.zeros_like(prob_p[n_accepted])
            )
        else:
            prob_last = prob_p_last
        sample_last = torch.multinomial(prob_last, num_samples=1)[0]

        new_tokens = torch.stack(samples[:n_accepted] + [sample_last], dim=1)
        return new_tokens, n_accepted
    

    def _head_forward(self, x, targets=None, with_w_norm=True, check_k=None):
        """Route hidden states to the loss, prediction or verification head path.

        Returns a ``(pred, loss)`` pair; the element that was not computed is None:
        ``loss`` from ``_head_forward_loss`` when ``targets`` is given, otherwise
        ``pred`` from ``_head_forward_check_last_k`` (if ``check_k`` is given) or
        ``_head_forward_pred``.
        """
        if targets is not None:
            return None, self._head_forward_loss(x, targets, with_w_norm=with_w_norm)
        if check_k is not None:
            return self._head_forward_check_last_k(x, k=check_k), None
        return self._head_forward_pred(x), None


    def _head_forward_loss(self, x, targets, with_w_norm=True):
        """Training loss of the CP-mixture heads.

        At each position the joint log-likelihood of the ``d`` target tokens is
        ``logsumexp_r(log w_r + sum_k log core_k[target_k, r])``; the loss is its negative
        mean divided by ``d``. With ``with_w_norm`` a load-balancing term on the mixture
        weights is added.

        Args:
            x: body hidden states; ``lm_head_weight(x)`` is reshaped to ``[-1, r]``.
            targets: ``targets[k]`` holds the tokens scored by head ``k``
                (see ``_build_core``).
            with_w_norm: add the balancing loss.

        Returns:
            Scalar loss tensor.

        Side effects:
            Logs per-head losses and the balancing loss to ``self.manager.writer``
            when it is set.
        """
        w = self.lm_head_weight(x).reshape(-1, self.r)

        w_norm = nn.functional.softmax(w, dim=1)
        # log_softmax rather than log(softmax): a weight that underflows to 0 would give
        # log = -inf and NaN gradients, while log_softmax stays finite.
        log_w = nn.functional.log_softmax(w, dim=1)
        log_cores = [self._build_core(k, x, targets) for k in range(self.d)]

        joint = torch.stack(log_cores + [log_w]).sum(dim=0)
        loss = torch.logsumexp(joint, dim=1)

        if self.manager.writer:
            with torch.no_grad():
                for i, log_core in enumerate(log_cores):
                    # NLL of head i alone (other heads' normalised cores sum out)
                    one_token_loss = -torch.logsumexp(log_core + log_w, dim=1).mean()
                    self.manager.writer.add_scalar(f'loss_token_{i}', one_token_loss.item(), self.manager.current_step)

        loss = -loss.mean() / self.d

        if with_w_norm:
            # Same form as the Switch-Transformer balancing loss: r * sum_i f_i * P_i with
            # f_i the fraction of rows whose argmax expert is i and P_i the mean gate
            # probability. It equals 1 for uniform routing; the -0.99 offset only shifts
            # the value and does not change gradients.
            wmx = torch.argmax(w, dim=1)
            fs = torch.bincount(wmx, minlength=self.r) / w.shape[0]
            ps = torch.mean(w_norm, dim=0)

            aux_ls = (fs * ps).sum() * self.r - 0.99
            if self.manager.writer:
                self.manager.writer.add_scalar('balancing loss', aux_ls.item(), self.manager.current_step)

            loss = loss + aux_ls

        return loss

    def _head_forward_pred(self, x):
        """Next-token predictions of all ``d`` heads from the last row of ``x``.

        The output format is chosen by two switches below:

        * all-experts summation: list of ``d`` tensors ``[1, n]`` holding
          ``logsumexp_r(log w_r + log core_k[:, r])``;
        * single sampled expert (``correct_sampling`` False): list of ``d`` tensors
          ``[1, n]`` with the log-probs of one expert drawn from ``w``;
        * correct sampling (``CORRECT_SAMPLING`` True): dict with ``'log_w'`` (``[r]``
          log mixture weights) and ``'log_cores'`` (list of ``d`` ``[n, r]`` cores),
          so the caller can sample the heads jointly.

        Raises:
            ValueError: if ``self.d == 2048`` (the former tensor-train sampling path is
                disabled pending a refactor).
        """
        if self.d == 2048:
            raise ValueError('''
                             nice sampling, but needs refactoring for a new paradigm, 
                             where samples itself are obtained only in generate function
                             ''')

        w = self.lm_head_weight(x).reshape(-1, self.r)
        w = nn.functional.softmax(w, dim=1)[-1, :]

        cores = [self._build_core(k, x) for k in range(self.d)]

        # Hard-coded switches. This local intentionally shadows the module-level
        # ``all_experts_summation``; the module flag is not consulted here.
        all_experts_summation = False
        correct_sampling = CORRECT_SAMPLING

        if all_experts_summation:
            log_w = torch.log(w)
            return [
                torch.logsumexp(core + log_w.expand(*core.shape), dim=1).reshape(1, -1)
                for core in cores
            ]

        if not correct_sampling:
            item_next = torch.multinomial(w, num_samples=1)[0]
            # cores already hold log-probs, so no softmax is needed
            return [core[:, item_next].unsqueeze(0) for core in cores]

        warnings.warn('correct sampling', RuntimeWarning, stacklevel=2)
        return {
            'log_w': torch.log(w),
            'log_cores': list(cores)
        }

    def _head_forward_check_last_k(self, x, k=4):
        '''
        Evaluates only the first head for the last k tokens to check the correctness
        of drafted tokens.

        Args:
            x: body hidden states (fed to ``lm_head_weight`` and ``_build_core``).
            k: number of trailing positions to score.

        Returns:
            list of ``k`` tensors ``[1, n]``: first-head log-probabilities at each of the
            last ``k`` positions. With the hard-coded ``all_experts_summation = True``
            these are mixture log-probs ``logsumexp_r(log w_r + log core[:, r])``;
            otherwise the log-probs of one expert sampled from ``w``.

        Side effects:
            when ``self.manager.writer`` is set, saves a plot into ``self.manager.fold``
            (mean expert-probability bar chart, or a top-1-expert histogram).
        '''
        w = self.lm_head_weight(x).reshape(-1, self.r)
        w = nn.functional.softmax(w, dim=-1)[-k:, :]

        core = self._build_core(0, x, n_last=k)

        # Hard-coded switch; this local shadows the module-level flag of the same name.
        all_experts_summation = True

        probs_next = []
        top_1_experts = []
        for token_idx in range(k):
            if all_experts_summation:
                # mix all experts: logsumexp over experts of log w_r + log core[v, r]
                log_w = torch.log(w[token_idx])
                preds = core[token_idx] + log_w.expand(*core[token_idx].shape)
                preds = torch.logsumexp(preds, dim=1)
                probs_next.append(preds.reshape(1, -1))
            else:
                warnings.warn('Unoptimal sampling!', RuntimeWarning, stacklevel=2)
                item_next = torch.multinomial(w[token_idx], num_samples=1)[0]
                top_1_experts.append(item_next.item())
                # cores already hold log-probs, so no softmax is needed
                probs_next.append(core[token_idx][:, item_next].unsqueeze(0))

        if self.manager.writer:
            if not all_experts_summation:
                plot_histogram(
                    top_1_experts, 
                    'Top 1 Experts Histogram', 
                    'Top 1 Experts', 
                    'Frequency', 
                    os.path.join(self.manager.fold, 'top_1_experts.png')
                )
            elif probs_next:
                # Sum the expert weights in torch on a detached tensor rather than via
                # `+=` into a `.cpu().numpy()` view: on CPU that view aliases (and would
                # mutate) `w`, and `.numpy()` fails if `w` requires grad.
                exp_probs = w.detach().sum(dim=0)
                exp_probs = (exp_probs / exp_probs.sum()).reshape(-1).cpu().numpy()
                plot_bar(
                    exp_probs, 
                    'Expert Probs', 
                    'Expert', 
                    'Probability', 
                    os.path.join(self.manager.fold, 'expert_probs.png')
                )

        return probs_next


    def forward_for_check(self, idx, targets=None, with_w_norm=True, check_k=None):
        """Run the body on ``idx`` and pass the hidden states to ``_head_forward``.

        Same ``(pred, loss)`` contract as ``_head_forward``; pass ``check_k`` to get the
        first-head predictions for the last ``check_k`` positions.
        """
        hidden = self.body_forward(idx)
        return self._head_forward(
            hidden,
            targets,
            with_w_norm=with_w_norm,
            check_k=check_k,
        )


    def _head_init(self, config):
        """Create the CP-mixture output heads from ``config``.

        Reads ``config.d`` (number of heads), ``config.vocab_size``, ``config.r``
        (number of experts) and ``config.n_embd``. Each of the ``d`` heads is a
        bias-free linear map ``n_embd -> r * vocab_size``; ``lm_head_weight`` maps
        ``n_embd -> r`` (with bias) to produce the mixture-weight logits.
        """
        self.d, self.n, self.r = config.d, config.vocab_size, config.r
        self.config = config

        out_features = self.r * self.n
        self.lm_heads = nn.ModuleList(
            nn.Linear(config.n_embd, out_features, bias=False) for _ in range(self.d)
        )

        self.lm_head_weight = nn.Linear(config.n_embd, self.r, bias=True)