import matplotlib.pyplot as plt
import torch
import numpy as np
from mingpt.experiment import ExperimentConfig, Experiment
from mingpt.utils import prog2fn, set_seed
import fire

def _build_prog(n_bits, fn):
    """Return the gate program for ``fn`` as a list of ``((i, j), op)`` steps."""
    # Both programs start with the pairwise XOR of operands i and n_bits + i.
    prog = [((i, n_bits + i), 'XOR') for i in range(n_bits)]
    if fn == 'xor':
        return prog
    if fn != 'sum':
        raise ValueError(f"fn must be 'xor' or 'sum', got {fn!r}")

    # Pairwise AND of the same operand pairs.
    prog += [((i, n_bits + i), 'AND') for i in range(n_bits)]
    # NOTE(review): XOR of index 0 with itself -- presumably a constant-0
    # initial carry; confirm against prog2fn's index semantics.
    prog += [((0, 0), 'XOR')]
    # Two steps per bit (AND then XOR), referencing earlier step outputs.
    # NOTE(review): presumably the carry chain of a ripple-carry adder, with
    # indices >= 2 * n_bits addressing prior program outputs -- confirm.
    for i in range(n_bits):
        prog += [((2 * n_bits + i, 4 * n_bits + 2 * i), 'AND')]
        prog += [((3 * n_bits + i, 4 * n_bits + 2 * i + 1), 'XOR')]
    # Final n_bits XOR steps; these are the last entries of the program.
    prog += [((2 * n_bits + i, 4 * n_bits + 2 * i), 'XOR')
             for i in range(n_bits)]
    return prog


def main(n_bits=8, verbose=False, fn='xor'):
    """Hand-code a model for ``fn`` and check it on one random input pair.

    Builds the gate program for ``fn``, wraps it in an ``ExperimentConfig``,
    asks the ``Experiment`` to hand-code all model parameters, then runs a
    single forward pass on random bit vectors ``a`` and ``b`` and prints the
    target output, the model's prediction and whether they match.

    Args:
        n_bits: Number of bits in each random operand; must be >= 1.
        verbose: Forwarded to ``Experiment.handcode`` (as ``print_params``)
            and to ``model.forward``.
        fn: Which program to build, either ``'xor'`` or ``'sum'``.

    Raises:
        ValueError: If ``fn`` is not supported or ``n_bits`` is below 1.
    """
    # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
    if fn not in ('xor', 'sum'):
        raise ValueError(f"fn must be 'xor' or 'sum', got {fn!r}")
    if n_bits < 1:
        raise ValueError(f'n_bits must be a positive integer, got {n_bits!r}')

    prog = _build_prog(n_bits, fn)

    econf = ExperimentConfig(
        n_bits=n_bits,
        vocab_size=2,
        target_fn=prog2fn(prog),
        c_len=len(prog),
        prog=prog,
        sum_idxs=[],
        interm_idxs=[],
        interm_loss_prop=0.0,
        n_emb=3 + (2 * n_bits + len(prog) - 1),
        n_layers=1,
        n_heads=1,
        lr=6e-4
    )
    exp = Experiment(econf)
    exp.handcode(all_params=True, print_params=verbose)
    exp.model.eval()

    # Sample the operands only after the experiment is built, so the order of
    # random draws is the same as before.
    a = np.random.randint(2, size=n_bits)
    b = np.random.randint(2, size=n_bits)
    # exp.eval_example(a, b, print_interpreted=False, plot_attns=False)

    y = np.array(econf.target_fn(a, b)).astype(int)
    # NOTE(review): the zero padding is n_bits - 1 long regardless of
    # len(prog); for fn='sum' len(prog) != n_bits, so this may not match the
    # sequence length implied by n_emb -- confirm against the model.
    x = np.concatenate([a, b, np.zeros(n_bits - 1)])
    # from_numpy + unsqueeze adds the batch dimension without going through
    # the slow (and warned-about) list-of-ndarrays path of torch.tensor.
    x = torch.from_numpy(x).long().unsqueeze(0).to(exp.device)
    with torch.no_grad():  # inference only, no autograd graph needed
        logits, _, _attn = exp.model.forward(x, verbose=verbose)
    # argmax over logits equals argmax over softmax(logits) (softmax is
    # monotonic), so the softmax is skipped. Keep the last n_bits positions.
    preds = (torch.argmax(logits, dim=-1)
             .squeeze().cpu().detach().numpy()[-n_bits:].astype(int))
    print('sample generation...')
    print(f'a:\t\t\t{a}')
    print(f'b:\t\t\t{b}')
    print(f'target_fn(a, b) true:\t{y}')
    print(f'target_fn(a, b) pred:\t{preds}')
    print('correct' if np.all(y == preds) else 'incorrect')


# Script entry point: only runs when the file is executed directly.
if __name__ == '__main__':
    # Seed before dispatching to main so the run uses a fixed seed.
    # NOTE(review): set_seed comes from mingpt.utils; presumably it seeds the
    # NumPy/torch RNGs that main's random a/b sample draws from -- confirm.
    set_seed(42)
    # python-fire exposes main's keyword arguments as command-line flags.
    fire.Fire(main)

