import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from mingpt.utils import dec2bin, prog2fn, sample


class AdditionDataset(Dataset):
    """
    Dataset of decimal addition problems rendered as fixed-length digit sequences.

    GPT only ever sees sequences of integers and learns to continue them, so an
    addition problem has to be turned into such a sequence. Adding two n-digit
    numbers yields a number of at most n+1 digits, hence every problem is written
    as: the first operand zero-padded to n digits, the second operand zero-padded
    to n digits, and the sum zero-padded to n+1 digits, concatenated. Since every
    sequence has exactly the same layout, no '+' or '=' tokens are needed.

    Examples for 2-digit problems:
    - 85 + 50 = 135 is encoded as [8, 5, 5, 0, 1, 3, 5]
    - 6 + 39 = 45 is encoded as [0, 6, 3, 9, 0, 4, 5]

    The loss is only applied on the final n+1 digits (the sum), because the two
    operands are always given. At evaluation time the model is fed e.g.
    [0, 6, 3, 9] and is expected to complete it with [0, 4, 5] in three steps.

    fun exercise: does it help if the result is asked to be produced in reverse order?
    """

    def __init__(self, ndigit, split):
        self.split = split  # 'test' selects the held-out problems, anything else the rest
        self.ndigit = ndigit
        self.vocab_size = 10  # the tokens are the decimal digits 0..9
        # n + n + (n+1) digits in total, minus one because the very last digit
        # is only ever a target and never fed back in as input
        self.block_size = 3 * ndigit

        # every (a, b) pair gets one index; shuffle them with a fixed seed so the
        # train/test partition is identical on every run
        total = (10 ** self.ndigit) ** 2
        rng = np.random.RandomState(1337)
        shuffled = rng.permutation(total)
        # hold out 20% of all problems for testing, capped at 1000
        held_out = min(int(total * 0.2), 1000)
        if split == 'test':
            self.ixes = shuffled[:held_out]
        else:
            self.ixes = shuffled[held_out:]

    def __len__(self):
        return len(self.ixes)

    def __getitem__(self, idx):
        n = self.ndigit
        # the stored index packs both operands: index = a * 10**n + b
        a, b = divmod(self.ixes[idx], 10 ** n)
        c = a + b
        # zero-pad each number to its fixed width, e.g. 03+25=28 becomes "0325028"
        fields = ((a, n), (b, n), (c, n + 1))
        rendered = ''.join(str(int(value)).zfill(width) for value, width in fields)
        tokens = list(map(int, rendered))  # one token per digit character
        # inputs are all digits but the last; targets are the same digits shifted by one
        x = torch.tensor(tokens[:-1], dtype=torch.long)
        y = torch.tensor(tokens[1:], dtype=torch.long)
        # no loss while the operands are being read; -100 masks those targets out
        y[:2 * n - 1] = -100
        return x, y


class BinaryDataset(Dataset):
    """
    Dataset of binary problems rendered as bit sequences [a_bits, b_bits, c_bits].

    Each item is built from two n_bits-long operands a and b and the output
    c = target_fn(a, b); the three bit strings are concatenated into one token
    sequence, and (x, y) are that sequence without its last token and without
    its first token respectively (next-token prediction).

    Args:
        n_bits: number of bits in each of the two operands.
        target_fn: callable taking (a_bin, b_bin) and returning the output bits
            as something np.array(...).astype(int) accepts.
        ixes: array of problem encodings, indexed by dataset position. When
            n_bits < 16 each entry must be a scalar whose upper n_bits bits are
            a and lower n_bits bits are b; otherwise each entry must be an array
            of shape (2 * n_bits,) holding the bits of a followed by those of b.
        interm_idxs: positions, within the full concatenated sequence, of the
            intermediate output tokens. Must support ``interm_idxs - 1`` and be
            usable as a tensor index (e.g. an integer array/tensor).
        interm_loss_prop: fraction of dataset positions for which the loss on
            the intermediate tokens is kept. For the remaining positions those
            targets are masked with -100. A value >= 1.0 disables this masking.

    The positions that keep the intermediate loss are drawn once, at
    construction time, from numpy's global random state and stored in
    ``interm_loss_idxs``.
    """

    def __init__(self, n_bits, target_fn, ixes, interm_idxs, interm_loss_prop):
        self.n_bits = n_bits
        self.vocab_size = 2  # the tokens are the bits 0 and 1
        self.target_fn = target_fn
        self.ixes = ixes
        self.interm_idxs = interm_idxs
        self.interm_loss_prop = interm_loss_prop
        # randomly choose which dataset positions keep the intermediate loss
        n_interm_loss = int(np.floor(interm_loss_prop * len(ixes)))
        self.interm_loss_idxs = np.random.permutation(len(ixes))[:n_interm_loss]
        # set copy of the above so __getitem__ can test membership in O(1)
        self._interm_loss_set = frozenset(self.interm_loss_idxs.tolist())

    def __len__(self):
        return self.ixes.shape[0]

    def __getitem__(self, idx):
        # keep the dataset position (idx) separate from the problem encoding:
        # interm_loss_idxs refers to positions, not to encodings
        problem = self.ixes[idx]
        if self.n_bits < 16:
            # compact encoding: one (2 * n_bits)-bit number per problem
            assert problem.shape == ()
            nd = 2 ** self.n_bits
            a = problem // nd  # upper n_bits bits
            b = problem % nd   # lower n_bits bits
            # dec2bin is presumably "integer -> n_bits-long bit array"; defined in mingpt.utils
            a_bin = dec2bin(a, self.n_bits)
            b_bin = dec2bin(b, self.n_bits)
        else:
            # explicit encoding: the bits of a followed by the bits of b
            assert problem.shape == (2 * self.n_bits,)
            a_bin = problem[:self.n_bits]
            b_bin = problem[self.n_bits:]
        c_bin = np.array(self.target_fn(a_bin, b_bin)).astype(int)
        dix = np.concatenate((a_bin, b_bin, c_bin))
        # x will be input to GPT and y will be the associated expected outputs
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        # we will only train in the output locations (-100 masks loss to zero)
        y[:2*self.n_bits-1] = -100
        if self.interm_loss_prop < 1.0:
            position = int(idx)
            if position < 0:
                # map negative indexing onto the same position numbering
                position += len(self)
            if position not in self._interm_loss_set:
                # y is shifted by one relative to the full sequence, hence the -1
                y[self.interm_idxs-1] = -100
        return x, y


def give_addition_exam(model, econf, device, dataset,
                       batch_size=32, max_batches=-1, verbose=False):
    """
    Evaluate the model on decimal addition problems and report its accuracy.

    For every batch, the operand digits are taken from the inputs, the model is
    asked (through `sample`) to generate the sum digits, and the generated sum
    is compared against the true sum computed by ordinary integer arithmetic.

    Args:
        model: passed through to `sample` unchanged.
        econf: config object; `econf.n_bits` is used here as the number of
            decimal digits per operand.
        device: device the inputs and the place-value tensor are moved to.
        dataset: yields (x, y) pairs whose x starts with the 2 * ndigit operand
            digits; y is not used.
        batch_size: batch size used for the evaluation loader.
        max_batches: when >= 0, stop once this many batches have been
            processed (at least one batch is always processed if available).
        verbose: when true, print every problem the model got wrong.

    Returns:
        The mean of the per-problem 0/1 correctness scores.
    """
    ndigit = econf.n_bits

    outcomes = []
    batches = DataLoader(dataset, batch_size=batch_size)
    for batch_no, (x, _) in enumerate(batches, start=1):
        x = x.to(device)
        operands = x[:, :ndigit*2]
        # NOTE(review): sample is assumed to return the prompt extended by the
        # requested number of tokens, plus a second value that is ignored here
        completed, _ = sample(model, operands, ndigit+1)
        sum_digits = completed[:, -(ndigit+1):]
        # decimal place values, most significant first: 10**ndigit ... 10**0
        place_values = torch.tensor([[10**p for p in range(ndigit, -1, -1)]]).to(device)
        operand_values = place_values[:, 1:]  # operands have one digit fewer than the sum
        # decode the integers from individual digits
        first = (operands[:, :ndigit] * operand_values).sum(1)
        second = (operands[:, ndigit:ndigit*2] * operand_values).sum(1)
        predicted = (sum_digits * place_values).sum(1)
        expected = first + second
        correct = (predicted == expected).cpu()
        for i in range(x.size(0)):
            is_right = correct[i]
            outcomes.append(int(is_right))
            if is_right:
                judge = 'YEP!!!'
            else:
                judge = 'NOPE'
            if not verbose or is_right:
                continue
            print("GPT claims that %03d + %03d = %03d (gt is %03d; %s)"
                  % (first[i], second[i], predicted[i], expected[i], judge))

        if 0 <= max_batches <= batch_no:
            break

    accuracy = np.mean(outcomes)
    print("final score: %d/%d = %.2f%% correct" % (np.sum(outcomes), len(outcomes), 100*accuracy))
    return accuracy


def give_exam(model, econf, device, dataset,
              batch_size=32, max_batches=-1, verbose=False):
    """
    Evaluate the model on a binary task and print/return four accuracies.

    The model is prompted with the first 2 * econf.n_bits tokens of every
    input and asked (through `sample`) for econf.c_len more tokens. The
    generated tokens are scored twice: over all output positions ("incl
    interm") and over the output positions that are not listed in
    econf.interm_idxs ("excl interm"). Each scoring yields a per-sequence
    accuracy (all tokens right) and a per-token accuracy.

    Args:
        model: passed through to `sample` unchanged.
        econf: config object providing n_bits, c_len, seq_len and interm_idxs.
        device: device the batches are moved to.
        dataset: yields (x, y) pairs; y is the target sequence shifted by one
            token relative to the full sequence.
        batch_size: batch size used for the evaluation loader.
        max_batches: when >= 0, stop once this many batches have been
            processed (at least one batch is always processed if available).
        verbose: accepted for signature compatibility; not used.

    Returns:
        Tuple (incl sequence acc, incl token acc, excl sequence acc,
        excl token acc), each the mean over all evaluated examples.
    """
    seq_hits_all, tok_hits_all = [], []
    seq_hits_final, tok_hits_final = [], []

    batches = DataLoader(dataset, batch_size=batch_size)
    for batch_no, (x, y) in enumerate(batches, start=1):
        x, y = x.to(device), y.to(device)
        n_given = econf.n_bits*2
        # NOTE(review): sample is assumed to return the prompt extended by
        # econf.c_len tokens, plus a second value that is ignored here
        completed, _ = sample(model, x[:, :n_given], econf.c_len)

        # every output token, intermediate ones included
        # NOTE(review): if the dataset masked intermediate targets with -100,
        # those positions can never match here - confirm the eval data is unmasked
        matches = completed[:, -econf.c_len:] == y[:, -econf.c_len:]
        seq_hits_all.extend(matches.all(dim=1).cpu().numpy())
        tok_hits_all.extend(matches.float().mean(dim=1).cpu().numpy())

        # output positions of the full sequence that are not intermediate
        final_idxs = np.setdiff1d(np.arange(2 * econf.n_bits, econf.seq_len),
                                  econf.interm_idxs)
        # y does not include the first token, so its positions are shifted by one
        final_matches = completed[:, final_idxs] == y[:, final_idxs-1]
        seq_hits_final.extend(final_matches.all(dim=1).cpu().numpy())
        tok_hits_final.extend(final_matches.float().mean(dim=1).cpu().numpy())

        if 0 <= max_batches <= batch_no:
            break

    incl_seq = np.mean(seq_hits_all)
    incl_tok = np.mean(tok_hits_all)
    excl_seq = np.mean(seq_hits_final)
    excl_tok = np.mean(tok_hits_final)

    print('\tsequence accuracy (incl interm): '
          f'{np.sum(seq_hits_all)}/{len(seq_hits_all)} = '
          f'{100*incl_seq:.2f}% correct')
    print(f'\ttoken accuracy (incl interm): {100*incl_tok:.2f}%')
    print('\tsequence accuracy (excl interm): '
          f'{np.sum(seq_hits_final)}/{len(seq_hits_final)} = '
          f'{100*excl_seq:.2f}% correct')
    print(f'\ttoken accuracy (excl interm): {100*excl_tok:.2f}%')
    return (incl_seq, incl_tok, excl_seq, excl_tok)
