import matplotlib.pyplot as plt
import torch
import numpy as np
from mingpt.experiment import ExperimentConfig, Experiment
from mingpt.utils import set_seed, prog2fn
import pickle
import fire


def _fixed_pos_emb(block_size, n_emb, wpos_mode):
    """Build a fixed positional-embedding table of shape (1, block_size, n_emb).

    Args:
        block_size: number of sequence positions (rows of the table).
        n_emb: embedding width (columns of the table).
        wpos_mode: 1 for a one-hot "identity" table, 2 for a sinusoidal one.

    Returns:
        A float64 numpy array of shape (1, block_size, n_emb).
    """
    if wpos_mode == 1:
        # Right-align an identity matrix inside a table at least block_size
        # wide, then keep the rightmost n_emb columns.  When
        # n_emb < block_size this leaves the first block_size - n_emb
        # positions with all-zero rows.
        width = max(n_emb, block_size)
        table = np.zeros((1, block_size, width))
        table[:, :, -block_size:] = np.eye(block_size)
        return table[:, :, -n_emb:]

    # Sinusoidal: even column j holds sin(pos / 10000**(j / n_emb)); odd
    # column j holds cos with its even neighbour's exponent (j - 1) / n_emb.
    positions = np.arange(block_size)[:, None]
    cols = np.arange(n_emb)
    rates = 10000.0 ** ((cols - cols % 2) / n_emb)
    angles = positions / rates[None, :]
    table = np.where(cols % 2 == 0, np.sin(angles), np.cos(angles))
    return table[None, :, :]


def main(res_pth=None, n_bits=8, n_emb=96, n_heads=1,
         n_train=256, n_updates=4000, strong_sprv=True,
         wpos_mode=0, n_trials=1):
    """Train and test a 1-layer transformer on an n-bit ripple-carry adder.

    The adder is written as a straight-line "program" of two-input gates.
    With strong supervision the model is trained on the whole program
    trace.  Without it, the model sees only the final sum bits.

    Args:
        res_pth: if not None, path where (hyperparams, trial_losses,
            trial_accs) is pickled after all trials finish.
        n_bits: width of each addend.
        n_emb: embedding width passed to ExperimentConfig.
        n_heads: number of attention heads passed to ExperimentConfig.
        n_train: training-set size for the second training stage
            (max_ntrain_s2).
        n_updates: number of second-stage updates (n_updates_s2).
        strong_sprv: if True, the target is every program output
            (intermediate supervision).  If False, it is only the final
            sum bits.
        wpos_mode: positional-embedding mode. 0: learnable,
            1: fixed identity, 2: fixed sinusoidal.
        n_trials: number of independent train/test runs.

    Raises:
        ValueError: if wpos_mode is not 0, 1 or 2.
    """
    if wpos_mode not in (0, 1, 2):
        raise ValueError(f'wpos_mode must be 0, 1 or 2, got {wpos_mode!r}')

    # Positions 0..n_bits-1 hold addend a and n_bits..2*n_bits-1 hold
    # addend b.  Op k of `prog` produces position 2*n_bits + k; prog2fn's
    # output is indexed by k (see the `- 2*n_bits` offsets below).
    # Bit 0 is least significant: the carry chain runs from i = 0 upward.
    prog = [((i, n_bits + i), 'XOR') for i in range(n_bits)]   # p_i = a_i ^ b_i
    prog += [((i, n_bits + i), 'AND') for i in range(n_bits)]  # g_i = a_i & b_i
    prog += [((0, 0), 'XOR')]  # a_0 ^ a_0 == 0: initial carry c_0 at 4*n_bits
    for i in range(n_bits):
        # (p_i & c_i), then c_{i+1} = g_i ^ (p_i & c_i), stored at
        # 4*n_bits + 2*i + 2.
        prog += [((2 * n_bits + i, 4 * n_bits + 2 * i), 'AND')]
        prog += [((3 * n_bits + i, 4 * n_bits + 2 * i + 1), 'XOR')]
    prog += [((2 * n_bits + i, 4 * n_bits + 2 * i), 'XOR')    # s_i = p_i ^ c_i
             for i in range(n_bits)]
    # Final n_bits sum bits plus the final carry c_n (position 6*n_bits).
    sum_idxs = [-n_bits + i for i in range(n_bits)] + [-n_bits - 1]
    interm_idxs = np.arange(2 * n_bits, 2 * n_bits + len(prog) - n_bits)

    out_idxs = np.setdiff1d(np.arange(2*n_bits, 2*n_bits+len(prog)),
                            interm_idxs)
    # Build the full program evaluator once and reuse it.  The original
    # rebuilt it on every direct_fn call.
    full_fn = prog2fn(prog)
    direct_fn = lambda a, b: full_fn(a, b)[out_idxs - 2*n_bits]
    c_len = n_bits  # == len(out_idxs): only the final sum bits

    # NOTE(review): W_pos is marked set/frozen for every wpos_mode,
    # including 0 ("learnable").  Confirm how Experiment interprets these
    # flags.
    param_status = {
        'W_emb':  {'set': False, 'freeze': False},
        'W_pos':  {'set': True, 'freeze': True},
        'W_ln_1': {'set': False, 'freeze': False},
        'b_ln_1': {'set': False, 'freeze': False},
        'W_k':    {'set': False, 'freeze': False},
        'b_k':    {'set': False, 'freeze': False},
        'W_q':    {'set': False, 'freeze': False},
        'b_q':    {'set': False, 'freeze': False},
        'W_v':    {'set': False, 'freeze': False},
        'b_v':    {'set': False, 'freeze': False},
        'W_proj': {'set': False, 'freeze': False},
        'b_proj': {'set': False, 'freeze': False},
        'W_ln_2': {'set': False, 'freeze': False},
        'b_ln_2': {'set': False, 'freeze': False},
        'W_ff_1': {'set': False, 'freeze': False},
        'b_ff_1': {'set': False, 'freeze': False},
        'W_ff_2': {'set': False, 'freeze': False},
        'b_ff_2': {'set': False, 'freeze': False},
        'W_ln_f': {'set': False, 'freeze': False},
        'b_ln_f': {'set': False, 'freeze': False},
        'W_head': {'set': False, 'freeze': False},
    }

    econf = ExperimentConfig(
        n_bits=n_bits,
        vocab_size=2,
        target_fn=full_fn if strong_sprv else direct_fn,
        c_len=len(prog) if strong_sprv else c_len,
        prog=prog,
        sum_idxs=sum_idxs if strong_sprv else [],
        interm_idxs=interm_idxs if strong_sprv else [],
        interm_loss_prop=1.0,
        sample_interm=False,
        param_status=param_status,
        sin_pos=wpos_mode == 2,
        n_emb=n_emb,
        n_layers=1,
        n_heads=n_heads,
        n_ff=None,
        max_ntotal=100000,
        max_ntrain=0,
        max_ntrain_s2=n_train,
        max_nval=512,
        max_ntest=10000,
        n_updates=0,
        n_updates_s2=n_updates,
        restart_lr=False,
        val_freq=100,
        epoch_datasets=None,
        batch_size=64,
        lr=6e-4
    )

    trial_losses = []
    trial_accs = []
    for t in range(n_trials):
        print(f'starting trial {t+1}/{n_trials}...')
        exp = Experiment(econf)
        if wpos_mode in (1, 2):
            # Replace the model's positional embedding with a fixed,
            # non-trainable table.  Cast to torch's default float dtype:
            # torch.from_numpy would keep numpy's float64 and mix dtypes with
            # the rest of the model.  A local name is used here so the
            # caller's n_emb is not overwritten (it is recorded below).
            pos_table = _fixed_pos_emb(econf.block_size, econf.n_emb,
                                       wpos_mode)
            exp.model.pos_emb = torch.nn.Parameter(
                torch.tensor(pos_table, dtype=torch.get_default_dtype()),
                requires_grad=False)
            if torch.cuda.is_available():
                exp.model.to(exp.device)

        print('    training...')
        train_losses_s2, val_losses_s2 = exp.train(s1=False)
        trial_losses.append((train_losses_s2, val_losses_s2))
        print('    testing...')
        train_accs_s2, val_accs_s2, test_accs_s2 = exp.test()
        trial_accs.append((train_accs_s2, val_accs_s2, test_accs_s2))

    if res_pth is not None:
        hyperparams = {
            'n_bits': n_bits,
            'n_emb': n_emb,
            'n_heads': n_heads,
            'n_train': n_train,
            'n_updates': n_updates,
            'strong_sprv': strong_sprv,
            'wpos_mode': wpos_mode,
            'n_trials': n_trials
        }
        with open(res_pth, 'wb') as f:
            pickle.dump((hyperparams, trial_losses, trial_accs), f)


if __name__ == '__main__':
    # Fix the global seed once, before any trial is built, so a whole
    # multi-trial sweep is reproducible.  Then let Fire map command-line
    # flags onto main()'s keyword arguments.
    set_seed(42)
    fire.Fire(component=main)
