import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
import fire
import pickle
from mingpt.utils import set_seed


# Select the compute device once at import time and report whether CUDA
# was detected.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('cuda is available:', device == 'cuda')


def revbin2int(lst):
    """Decode a little-endian bit sequence (least significant bit first).

    ``lst`` may be any iterable of 0/1 values: Python ints, bools, or NumPy
    integers. An empty sequence decodes to 0.

    >>> revbin2int([1, 0, 1, 1])
    13

    Raises:
        ValueError: if an element is not 0 or 1.
    """
    value = 0
    for pos, bit in enumerate(lst):
        bit = int(bit)  # normalizes bools / NumPy scalars to a plain int
        if bit not in (0, 1):
            raise ValueError(f'expected bit 0 or 1, got {bit} at position {pos}')
        value |= bit << pos
    return value

def int2revbin(num, n):
    """Encode non-negative integer ``num`` as bits, least significant first.

    The result is zero-padded to at least ``n`` bits. Values that need more
    than ``n`` bits are NOT truncated, so callers wanting exactly ``n`` bits
    must slice the result.

    >>> int2revbin(6, 4)
    [0, 1, 1, 0]
    """
    msb_first = bin(num)[2:].zfill(n)
    return [int(ch) for ch in reversed(msb_first)]


class TaskDataset(Dataset):
    """Binary-arithmetic sequence tasks over pairs of ``n_bits``-bit operands.

    Each example decodes two operands ``a`` and ``b`` (bit vectors, least
    significant bit first) from one entry of ``ixes`` and returns an
    ``(inp, out)`` pair of int64 arrays for the configured ``task``:

    * ``xor`` / ``sum`` / ``prd``: ``inp = a ++ b`` and ``out`` is the
      bitwise XOR / the low ``n_bits`` bits of ``a + b`` / of ``a * b``.
    * ``xor_cont``, ``xor_xor``, ``xor_sum``, ``carries_sum``: ``out`` is
      ``xor``, ``xor ++ xor``, ``xor ++ sum`` or ``carries ++ sum``, and
      ``inp = a ++ b ++ [0] ++ out[-n_bits:-1]``. The final ``n_bits``
      targets are fed back shifted by one position after a 0 start token
      (teacher forcing).

    Args:
        n_bits: number of bits per operand.
        ixes: when ``n_bits < 16``, a 1-D integer array whose entries each
            pack both operands (upper ``n_bits`` bits -> ``a``, lower
            ``n_bits`` bits -> ``b``). Otherwise, a ``(N, 2 * n_bits)``
            array of explicit bits (first half -> ``a``, second -> ``b``).
        task: one of ``TaskDataset.TASKS``.
    """

    TASKS = ('xor', 'xor_cont', 'xor_xor', 'sum', 'prd', 'xor_sum',
             'carries_sum')

    def __init__(self, n_bits, ixes, task):
        self.n_bits = n_bits
        self.ixes = ixes
        self.task = task

    def __len__(self):
        return self.ixes.shape[0]

    def __getitem__(self, idx):
        n = self.n_bits
        a, b = self._decode_operands(self.ixes[idx])

        # Build only the target the selected task needs.
        task = self.task
        if task == 'xor':
            out, feed_back = self._xor(a, b), False
        elif task == 'xor_cont':
            out, feed_back = self._xor(a, b), True
        elif task == 'xor_xor':
            xor = self._xor(a, b)
            out, feed_back = np.concatenate([xor, xor]), True
        elif task == 'sum':
            out, feed_back = self._sum(a, b), False
        elif task == 'prd':
            out, feed_back = self._prd(a, b), False
        elif task == 'xor_sum':
            out = np.concatenate([self._xor(a, b), self._sum(a, b)])
            feed_back = True
        elif task == 'carries_sum':
            out = np.concatenate([self._carries(a, b), self._sum(a, b)])
            feed_back = True
        else:
            # A real exception rather than `assert False`, which `python -O`
            # strips (leaving `out` unbound).
            raise ValueError(
                f'unknown task {task!r}; expected one of {self.TASKS}')

        pieces = [a, b]
        if feed_back:
            # Start token 0, then all but the last of the final n targets.
            pieces += [[0], out[-n:-1]]
        # int64 so the batches collate to LongTensors (embedding ids / CE targets).
        return np.concatenate(pieces).astype(np.int64), out.astype(np.int64)

    def _decode_operands(self, code):
        """Split one entry of ``ixes`` into LSB-first int64 bit arrays (a, b)."""
        n = self.n_bits
        # The 16-bit cutoff mirrors how the index arrays are built: packed
        # integers below it, explicit bit rows from it upwards.
        if n < 16:
            if np.shape(code) != ():
                raise ValueError(f'expected a packed scalar index for '
                                 f'n_bits={n}, got shape {np.shape(code)}')
            a_int, b_int = divmod(int(code), 2 ** n)
            a, b = int2revbin(a_int, n), int2revbin(b_int, n)
        else:
            if np.shape(code) != (2 * n,):
                raise ValueError(f'expected {2 * n} explicit bits for '
                                 f'n_bits={n}, got shape {np.shape(code)}')
            a, b = code[:n], code[n:]
        return np.asarray(a, dtype=np.int64), np.asarray(b, dtype=np.int64)

    @staticmethod
    def _xor(a, b):
        """Bitwise XOR of two bit arrays, as int64."""
        return np.logical_xor(a, b).astype(np.int64)

    def _sum(self, a, b):
        """Low ``n_bits`` bits of a + b (the final carry-out is dropped)."""
        n = self.n_bits
        return np.array(int2revbin(revbin2int(a) + revbin2int(b), n)[:n],
                        dtype=np.int64)

    def _prd(self, a, b):
        """Low ``n_bits`` bits of a * b."""
        n = self.n_bits
        return np.array(int2revbin(revbin2int(a) * revbin2int(b), n)[:n],
                        dtype=np.int64)

    @staticmethod
    def _carries(a, b):
        """Carry-out of each bit position of the ripple-carry sum a + b."""
        carries = np.zeros(len(a), dtype=np.int64)
        carry = 0
        for i, (bit_a, bit_b) in enumerate(zip(a, b)):
            carry = int(bit_a + bit_b + carry >= 2)
            carries[i] = carry
        return carries


def create_datasets(task, n_bits, n_train, n_val, n_test):
    """Sample train/val/test ``TaskDataset`` splits for ``task``.

    For ``n_bits < 16`` each example is a distinct packed integer drawn
    without replacement from all ``4 ** n_bits`` operand pairs, so the three
    splits never share an input. For larger ``n_bits`` the ``2 * n_bits``
    bits of each example are drawn i.i.d. uniformly instead; duplicates are
    then possible, though unlikely.

    Prints the first training example when the training split is non-empty.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)``.

    Raises:
        ValueError: if ``n_bits < 16`` and more examples are requested than
            there are distinct operand pairs.
    """
    n_total = n_train + n_val + n_test
    if n_bits < 16:
        num = (2 ** n_bits) ** 2  # total number of input combinations
        if n_total > num:
            raise ValueError(
                f'requested {n_total} examples but only {num} distinct '
                f'operand pairs exist for n_bits={n_bits}')
        # NOTE: legacy np.random.choice(replace=False) permutes the whole
        # population, so memory grows as 4 ** n_bits. It is kept so that the
        # draws still come from the global np.random state (presumably
        # seeded via set_seed).
        ixes = np.random.choice(num, size=n_total, replace=False)
    else:
        ixes = np.random.choice(2, size=(n_total, 2 * n_bits), replace=True)
    train_ixes, val_ixes, test_ixes = np.split(ixes, np.cumsum([n_train, n_val]))

    train_dataset = TaskDataset(n_bits=n_bits, ixes=train_ixes, task=task)
    val_dataset = TaskDataset(n_bits=n_bits, ixes=val_ixes, task=task)
    test_dataset = TaskDataset(n_bits=n_bits, ixes=test_ixes, task=task)

    if len(train_dataset) > 0:
        x0, y0 = train_dataset[0]  # build the sample once, not twice
        print(f'sample datapoint:\n'
              f'\tx: {x0}\n'
              f'\ty: {y0}')

    return train_dataset, val_dataset, test_dataset


class FPT(torch.nn.Module):
    """Pretrained GPT-2 body with a learned input embedding and output head.

    Token ids go through ``in_layer`` and are passed to GPT-2 as
    ``inputs_embeds``. ``out_layer`` then maps every final hidden state to
    per-position logits.

    Args:
        vocab_size: number of distinct input token ids (2 for bit inputs).
        n_classes: number of output classes per position.
        model_name: checkpoint name passed to ``GPT2Model.from_pretrained``.
    """

    def __init__(self, vocab_size=2, n_classes=2, model_name='gpt2'):
        super().__init__()
        # Read the hidden width from the checkpoint config instead of
        # hard-coding 768, which only fits the base 'gpt2' model. The config
        # is loaded first so modules are still created in the original order
        # (same RNG use and parameter registration order).
        config = GPT2Model.config_class.from_pretrained(model_name)
        width = config.n_embd
        self.in_layer = torch.nn.Embedding(vocab_size, width)
        self.gpt2 = GPT2Model.from_pretrained(model_name, config=config)
        self.out_layer = torch.nn.Linear(width, n_classes)

    def forward(self, x):
        """Map token ids ``x`` of shape (batch_size, seq_length) to logits
        of shape (batch_size, seq_length, n_classes).

        Raises:
            ValueError: if ``x`` is not two-dimensional.
        """
        if x.dim() != 2:
            raise ValueError(
                f'expected x of shape (batch_size, seq_length), '
                f'got {tuple(x.shape)}')
        emb = self.in_layer(x)
        hidden_state = self.gpt2(inputs_embeds=emb).last_hidden_state
        return self.out_layer(hidden_state)


def setup_fpt():
    """Create an FPT on ``device`` together with its optimizer and loss.

    Within the pretrained GPT-2 body, only parameters whose names contain
    ``'ln'`` (intended to match GPT-2's layer norms) or ``'wpe'`` (positional
    embeddings) stay trainable; every other GPT-2 weight is frozen. The
    FPT's own ``in_layer`` and ``out_layer`` are not touched and so remain
    trainable.

    Returns:
        ``(fpt, optimizer, loss_fn)``. ``optimizer`` is Adam with default
        hyperparameters over the trainable parameters, and ``loss_fn`` is
        ``torch.nn.CrossEntropyLoss()``.
    """
    fpt = FPT().to(device)

    for name, param in fpt.gpt2.named_parameters():
        # Train layer norms and positional embeddings; freeze the rest.
        param.requires_grad = 'ln' in name or 'wpe' in name

    # Frozen parameters never receive gradients, so Adam would skip them
    # anyway; passing only trainable ones gives the same updates with less
    # bookkeeping.
    trainable = [p for p in fpt.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable)
    loss_fn = torch.nn.CrossEntropyLoss()

    return fpt, optimizer, loss_fn


def train_fpt(model, dataset, loss_fn, optimizer, max_nepochs):
    """Train ``model`` on ``dataset`` one shuffled example at a time.

    Loss and accuracy use only the trailing ``y.size(-1)`` output positions.
    After every 500 examples within an epoch, the mean per-example bit
    accuracy over those last 500 is printed and recorded.

    Training stops early in two cases:
    * a 500-example window exceeds 0.9999 accuracy, or
    * 3 consecutive epochs end with their last-500 mean below their
      first-500 mean.

    Args:
        model: maps token ids of shape (batch, seq) to per-position logits.
        dataset: yields ``(inp, out)`` pairs of token ids.
        loss_fn: loss over flattened logits and targets.
        optimizer: optimizer stepping the model's trainable parameters.
        max_nepochs: upper bound on the number of epochs.

    Returns:
        The recorded windowed mean accuracies, in order.
    """
    window = 500
    print('training...')
    model.train()
    history = []
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
    stalled_epochs = 0
    for epoch in range(max_nepochs):
        accs = []
        print(f'***** epoch {epoch+1}')
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            # Score only the trailing positions aligned with the targets.
            logits = model(inputs)[:, -targets.size(-1):, :]
            n_classes = logits.size(-1)
            loss = loss_fn(logits.reshape(-1, n_classes), targets.reshape(-1))

            hits = (logits.argmax(dim=-1) == targets).float()
            accs.extend(hits.mean(dim=-1).cpu().numpy())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if len(accs) % window != 0:
                continue
            recent = np.mean(accs[-window:])
            print(f'Samples: {epoch * len(dataset) + len(accs)}, '
                  f'Accuracy: {recent}')
            history.append(recent)
            if recent > 0.9999:
                # success
                print('exiting training loop (success)...')
                return history

        first_mean = np.mean(accs[:window])
        last_mean = np.mean(accs[-window:])
        stalled_epochs = stalled_epochs + 1 if first_mean > last_mean else 0
        if stalled_epochs >= 3:
            # training got stuck
            print('exiting training loop (no improvement over 3 epochs)...')
            break

    return history


def test_fpt(model, dataset):
    """Measure per-example exact-match accuracy of ``model`` on ``dataset``.

    Each example is decoded from its ``2 * n_bits`` operand bits only.
    * Input positions beyond the operand bits held teacher-forced targets
      during training. Those outputs are now generated greedily after a 0
      start token, feeding each prediction back in.
    * Any remaining outputs are read from the last operand positions of a
      single forward pass.

    Together the predictions cover the last ``y.size(-1)`` positions, the
    same positions ``train_fpt`` scores.

    Returns:
        One NumPy bool per example: True iff every output bit is correct.
    """
    print('testing...')
    model.eval()
    n = dataset.n_bits
    accuracies = []
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
    # Inference only: skip autograd bookkeeping for every forward pass.
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device).long()
            y = y.to(device)
            # Outputs that training fed back as inputs; generated here.
            n_generated = x.size(-1) - 2 * n
            # Outputs read directly off the operand prefix. This is 0 when
            # every target is fed back (e.g. 'xor_cont'), so preds always
            # ends up with exactly y.size(-1) entries.
            n_direct = y.size(-1) - n_generated
            x = x[:, :2 * n]
            if n_direct > 0:
                preds = model(x)[:, -n_direct:, :].argmax(dim=-1)
            else:
                preds = x.new_empty((x.size(0), 0))
            # Start token for the generated tail (training feeds a 0 here).
            x = torch.cat((x, x.new_zeros((x.size(0), 1))), dim=1)
            for _ in range(n_generated):
                ix = model(x)[:, -1:, :].argmax(dim=-1)
                preds = torch.cat((preds, ix), dim=1)
                x = torch.cat((x, ix), dim=1)
            accuracies.extend((preds == y).all(dim=-1).cpu().numpy())
    print(f'test set accuracy: {sum(accuracies)}/{len(accuracies)} '
          f'= {np.mean(accuracies)}')

    return accuracies


def main(res_pth, n_bits=8, task='carries_sum',
         n_train=10000, n_val=512, n_test=10000,
         n_trials=10, max_nepochs=100):
    """Run ``n_trials`` independent train/test trials of FPT on ``task``.

    Each trial resamples the datasets and builds a fresh model and optimizer.

    Args:
        res_pth: file path for pickling
            ``(hyperparams, trial_train_accs, trial_test_accs)``. Nothing is
            written when ``None``.
        n_bits: bits per operand.
        task: task name understood by ``TaskDataset``.
        n_train, n_val, n_test: split sizes. The validation split is sampled
            but not otherwise used in this function.
        n_trials: number of independent trials.
        max_nepochs: epoch cap passed to ``train_fpt`` for each trial.
    """
    trial_train_accs = []
    trial_test_accs = []
    for _ in range(n_trials):
        train_dataset, _val_dataset, test_dataset = create_datasets(
            task, n_bits, n_train, n_val, n_test
        )
        fpt, optimizer, loss_fn = setup_fpt()
        train_accs = train_fpt(fpt, train_dataset,
                               loss_fn, optimizer, max_nepochs=max_nepochs)
        test_accs = test_fpt(fpt, test_dataset)

        trial_train_accs.append(train_accs)
        trial_test_accs.append(test_accs)

    if res_pth is not None:
        with open(res_pth, 'wb') as f:
            hyperparams = {
                'task': task,
                'n_bits': n_bits,
                'n_train': n_train,
                'n_val': n_val,
                'n_test': n_test,
                'n_trials': n_trials,
                'max_nepochs': max_nepochs,
            }
            pickle.dump((hyperparams, trial_train_accs, trial_test_accs), f)

if __name__ == '__main__':
    # Fixed seed, set before fire parses the command line and calls main().
    # set_seed comes from mingpt.utils; presumably it seeds the Python,
    # NumPy and torch RNGs -- confirm there.
    set_seed(42)
    # fire exposes main's parameters as command-line arguments.
    fire.Fire(main)

