import matplotlib.pyplot as plt
import random
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from mingpt.model import GPT, GPTConfig

def set_seed(seed):
    """Seed every random number generator this module relies on.

    The same ``seed`` is handed, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator and PyTorch's
    CUDA generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def top_k_logits(logits, k):
    """Keep the ``k`` largest logits along the last dimension, mask the rest.

    Args:
        logits: tensor of scores; the top-k selection runs over its last
            dimension, any leading dimensions are treated as batch.
        k: how many entries to keep per row (expected to be >= 1). Values
            larger than the size of the last dimension keep every entry.

    Returns:
        A new tensor shaped like ``logits`` in which every entry strictly
        smaller than the k-th largest value of its row is ``-inf``. Entries
        tied with the k-th largest value are kept. The input tensor is not
        modified.
    """
    # torch.topk raises when k exceeds the dimension size; clamping turns
    # "more than everything" into "keep everything".
    k = min(k, logits.size(-1))
    top_values, _ = torch.topk(logits, k)
    # The smallest of the top-k values is the cut-off; the list index keeps a
    # trailing dimension of size 1 so the comparison broadcasts over each row.
    kth_largest = top_values[..., [-1]]
    return logits.masked_fill(logits < kth_largest, -float('Inf'))

@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
    """
    take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
    the sequence, feeding the predictions back into the model each time. Clearly the sampling
    has quadratic complexity unlike an RNN that is only linear, and has a finite context window
    of block_size, unlike an RNN that has an infinite context window.

    Args:
        model: callable returning a 3-tuple whose first element is the logits and whose
            third element is the attention output; must also provide get_block_size()
            and eval(). The model is switched to eval mode and left in it.
        x: conditioning indices of shape (b, t).
        steps: number of tokens to append.
        temperature: the final-step logits are divided by this value before softmax.
        sample: if True draw the next token from the distribution, otherwise take the
            most likely one.
        top_k: if not None, restrict the choice to the top_k highest logits.

    Returns:
        (x, attns): x extended by `steps` columns, and the third output of the model's
        last forward pass. attns is None when steps == 0 because the model is never run.
    """
    block_size = model.get_block_size()
    model.eval()
    # Defined up front so that steps == 0 returns cleanly instead of hitting an unbound name.
    attns = None
    for _step in range(steps):
        x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
        logits, _, attns = model(x_cond)
        # pluck the logits at the final step and scale by temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop probabilities to only the top k options
        if top_k is not None:
            logits = top_k_logits(logits, top_k)
        # apply softmax to convert to probabilities
        probs = F.softmax(logits, dim=-1)
        # sample from the distribution or take the most likely
        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            _, ix = torch.topk(probs, k=1, dim=-1)
        # append to the sequence and continue
        x = torch.cat((x, ix), dim=1)

    return x, attns

def plot_matrix(x, verbose=True, title=None, xlabel=None, ylabel=None,
                xticklabels=None, yticklabels=None, save_pth=None, dpi=100):
    """Draw a matrix as an image with a colour bar, then show or save it.

    Args:
        x: torch tensor, NumPy array or nested sequence. Size-1 dimensions are
            squeezed away and inputs with fewer than two dimensions are padded
            with leading axes; the result must be 2-D.
        verbose: if True, print each cell's value (one decimal) inside the cell
            and put a tick on every row and column.
        title, xlabel, ylabel: optional text for the axes.
        xticklabels, yticklabels: optional tick labels; ticks are placed at
            0..len(labels)-1.
        save_pth: if None the figure is shown with ``plt.show()``; otherwise it
            is written to this path and then closed.
        dpi: resolution passed to the figure.
    """
    if torch.is_tensor(x):
        x = x.cpu().detach().numpy()
    # asanyarray accepts plain lists while leaving ndarray subclasses
    # (e.g. masked arrays) untouched.
    x = np.asanyarray(x).squeeze()
    while x.ndim < 2:
        x = np.expand_dims(x, axis=0)
    assert x.ndim == 2, f'expected a 2-D matrix after squeezing, got shape {x.shape}'

    fig, ax = plt.subplots(dpi=dpi, figsize=(12, 8))
    im = ax.imshow(x)

    # Place the colour bar just right of the image, vertically centred on it
    # and at least 0.2 figure-heights tall.
    pos = ax.get_position()
    cbar_height = max(pos.height, 0.2)
    cbar_yoffset = (cbar_height - pos.height) / 2
    cax = fig.add_axes([pos.x1 + 0.02,
                        pos.y0 - cbar_yoffset,
                        0.02,
                        cbar_height])
    fig.colorbar(im, cax=cax)

    if verbose:
        n_rows, n_cols = x.shape
        for i in range(n_rows):
            for j in range(n_cols):
                ax.text(j, i, f'{x[i, j]:.1f}', ha='center', va='center',
                        color='w', fontsize='x-small')
        ax.set_xticks(np.arange(n_cols))
        ax.set_yticks(np.arange(n_rows))
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if xticklabels is not None:
        ax.set_xticks(np.arange(len(xticklabels)))
        ax.set_xticklabels(xticklabels)
    if yticklabels is not None:
        ax.set_yticks(np.arange(len(yticklabels)))
        ax.set_yticklabels(yticklabels)
    if title is not None:
        ax.set_title(title)

    if save_pth is None:
        plt.show()
    else:
        fig.savefig(save_pth)
        # Without this every saved figure stays registered with pyplot, so
        # calling plot_matrix in a loop keeps accumulating open figures.
        plt.close(fig)

def dec2bin(x, n_bits=None):
    """Return the binary digits of ``x`` as a list of ints, most significant first.

    When ``n_bits`` is given the digits are left-padded with zeros up to that
    length; a value needing more than ``n_bits`` digits is not truncated.
    """
    digits = bin(x)[2:]  # drop the '0b' prefix
    if n_bits is not None:
        digits = digits.zfill(n_bits)
    return list(map(int, digits))

def bin2dec(x):
    """Return the value of a most-significant-first sequence of bits.

    The last element of ``x`` is weighted by 1, the one before it by 2, and
    so on; an empty sequence gives 0.
    """
    total = 0
    for power, bit in enumerate(reversed(x)):
        total = total + bit * (1 << power)
    return total

def prog2fn(prog):
    """Turn a list of two-input boolean instructions into a callable.

    Args:
        prog: sequence of ``(attn, op)`` pairs. ``attn[0]`` and ``attn[1]``
            index into the working vector and ``op`` is ``'AND'`` or ``'XOR'``.

    Returns:
        A function ``fn(a, b)`` that concatenates ``a`` and ``b`` into a
        working vector, appends the result of each instruction in order (so
        later instructions can read earlier results), and returns the
        ``len(prog)`` appended values. An empty program returns an empty array.
    """
    def fn(a, b):
        x = np.concatenate([a, b])
        for attn, op in prog:
            assert op in ['AND', 'XOR'], f'unsupported op: {op!r}'
            op_fn = np.logical_and if op == 'AND' else np.logical_xor
            x = np.append(x, op_fn(x[attn[0]], x[attn[1]]))
        n_ops = len(prog)
        if n_ops == 0:
            # x[-0:] would be the whole input vector, not "no outputs".
            return x[:0]
        return x[-n_ops:]
    return fn

def econf2model(econf):
    """Build a GPT model from an experiment config.

    Reads ``vocab_size``, ``block_size``, ``n_layers``, ``n_heads`` and
    ``n_emb`` from ``econf`` and passes them to ``GPTConfig`` as
    ``vocab_size``, ``block_size``, ``n_layer``, ``n_head`` and ``n_embd``.
    """
    gpt_kwargs = {
        'vocab_size': econf.vocab_size,
        'block_size': econf.block_size,
        'n_layer': econf.n_layers,
        'n_head': econf.n_heads,
        'n_embd': econf.n_emb,
    }
    return GPT(GPTConfig(**gpt_kwargs))

def mingpt_params(model):
    """Read the weights of a one-block GPT into a dict of NumPy arrays.

    Every tensor is moved to CPU, detached and converted to NumPy. All
    entries except ``'W_emb'`` and ``'W_pos'`` are transposed (a no-op for
    1-D biases and layer-norm vectors). Only ``model.blocks[0]`` is read; the
    closing assertion checks that the collected arrays account for every
    parameter of ``model``, so a model with anything beyond these tensors
    fails it.
    """
    block = model.blocks[0]
    attn = block.attn
    feed_forward = block.mlp

    tensors = {
        'W_emb':  model.tok_emb.weight,
        'W_pos':  model.pos_emb,
        'W_ln_1': block.ln1.weight,
        'b_ln_1': block.ln1.bias,
        'W_k':    attn.key.weight,
        'b_k':    attn.key.bias,
        'W_q':    attn.query.weight,
        'b_q':    attn.query.bias,
        'W_v':    attn.value.weight,
        'b_v':    attn.value.bias,
        'W_proj': attn.proj.weight,
        'b_proj': attn.proj.bias,
        'W_ln_2': block.ln2.weight,
        'b_ln_2': block.ln2.bias,
        'W_ff_1': feed_forward[0].weight,
        'b_ff_1': feed_forward[0].bias,
        'W_ff_2': feed_forward[2].weight,
        'b_ff_2': feed_forward[2].bias,
        'W_ln_f': model.ln_f.weight,
        'b_ln_f': model.ln_f.bias,
        'W_head': model.head.weight,
    }

    # The embedding tables keep their stored orientation; everything else is
    # transposed.
    keep_orientation = ('W_emb', 'W_pos')
    params = {}
    for name, tensor in tensors.items():
        array = tensor.cpu().detach().numpy()
        params[name] = array if name in keep_orientation else array.T

    # Sanity check: nothing in the model was left out of the dict.
    n_params_read_off = sum(array.size for array in params.values())
    n_params_actual = sum(weight.numel() for weight in model.parameters())
    assert n_params_read_off == n_params_actual

    return params
