import torch
import numpy as np
from mingpt.utils import prog2fn, plot_matrix, sample
from mingpt.model import GPT, GPTConfig
from mingpt.trainer import TrainerConfig, Trainer
from mingpt.dataset import (AdditionDataset, BinaryDataset,
                            give_addition_exam, give_exam)


class ExperimentConfig:
    """Attribute bag holding the hyperparameters of one experiment.

    Every keyword argument is stored as an attribute of the same name.
    ``n_bits``, ``c_len`` and ``n_emb`` are mandatory (checked with
    ``assert``).  ``seq_len`` and ``block_size`` are always derived from
    ``n_bits`` and ``c_len``; ``n_ff``, ``interm_idxs``, ``sample_interm``,
    ``batch_size`` and ``val_freq`` receive defaults when not supplied.
    """

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)

        # Sequence geometry: two n_bits-long operands followed by c_len
        # further tokens; the model sees everything but the last token.
        for required in ('n_bits', 'c_len'):
            assert hasattr(self, required)
        self.seq_len = 2 * self.n_bits + self.c_len
        self.block_size = self.seq_len - 1

        # Feed-forward width falls back to four times the embedding width.
        assert hasattr(self, 'n_emb')
        if getattr(self, 'n_ff', None) is None:
            self.n_ff = 4 * self.n_emb

        # Default intermediate positions: everything after the two operands
        # except the final n_bits positions.  Always stored as an ndarray.
        interm = getattr(self, 'interm_idxs', None)
        if interm is None:
            interm = np.arange(2 * self.n_bits, self.seq_len - self.n_bits)
        self.interm_idxs = np.array(interm)

        # Remaining optional settings, filled in only when absent.
        fallbacks = (
            ('sample_interm', False),
            ('batch_size', 512),
            ('val_freq', 0),
        )
        for name, fallback in fallbacks:
            if not hasattr(self, name):
                setattr(self, name, fallback)


class Experiment:
    def __init__(self, econf):
        """Create the model and the datasets described by ``econf``.

        A double-precision ``GPT`` is always built.  When
        ``econf.target_fn`` is ``None`` only the addition train/test
        datasets are created.  Otherwise the input space is sampled and
        split into test / validation / stage-1 train / stage-2 train
        ``BinaryDataset`` objects (in that order), and a ``Trainer`` is
        attached as ``self.trainer``.

        Missing ``max_ntotal`` / ``max_ntrain`` / ``max_nval`` /
        ``max_ntest`` (default: unbounded) and ``max_ntrain_s2``
        (default: 0) are written back onto ``econf``.

        NOTE(review): ``self.trainer`` and ``self.val_dataset`` are only
        created in the binary-task branch -- confirm that the addition
        branch is not expected to be trainable through ``train()``.
        """
        self.econf = econf
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
        else:
            self.device = 'cpu'

        # Initialize model
        model_conf = GPTConfig(
            vocab_size=econf.vocab_size,
            block_size=econf.block_size,
            n_layer=econf.n_layers,
            n_head=econf.n_heads,
            n_embd=econf.n_emb,
            n_ff=econf.n_ff,
        )
        self.model = GPT(model_conf)
        self.model.double() # prevent a type error
        n_params = sum(p.numel() for p in self.model.parameters())
        print(f'num params: {n_params}')

        # Initialize datasets: the addition task needs nothing further.
        if econf.target_fn is None:
            self.train_dataset = AdditionDataset(ndigit=econf.n_bits,
                                                 split='train')
            self.test_dataset = AdditionDataset(ndigit=econf.n_bits,
                                                split='test')
            return

        # Binary task: fill in the optional size caps on the config.
        for cap in ('max_ntotal', 'max_ntrain', 'max_nval', 'max_ntest'):
            if not hasattr(econf, cap):
                setattr(econf, cap, float('inf'))
        if not hasattr(econf, 'max_ntrain_s2'):
            econf.max_ntrain_s2 = 0

        # Split sizes: 20% test, 20% validation (each capped), then stage-1
        # training takes what is left up to its cap and stage 2 the rest.
        n_inputs = (2 ** econf.n_bits) ** 2 # total number of input combinations
        n_total = min(n_inputs, econf.max_ntotal)
        n_test = min(int(n_total * 0.2), econf.max_ntest)
        n_val = min(int(n_total * 0.2), econf.max_nval)
        n_train = min(n_total - n_val - n_test, econf.max_ntrain)
        n_train_s2 = min(n_total - n_val - n_test - n_train,
                         econf.max_ntrain_s2)

        if econf.n_bits >= 16:
            # Draw random bit rows directly (rows may repeat).
            perm = np.random.choice(2, size=(n_total, 2 * econf.n_bits),
                                    replace=True)
        else:
            # Draw distinct input indices without replacement.
            perm = np.random.choice(n_inputs, size=n_total, replace=False)

        boundaries = np.cumsum([n_test, n_val, n_train, n_train_s2])
        test_ixes, val_ixes, train_ixes, train_s2_ixes = np.split(
            perm, boundaries)[:4]

        def build_dataset(ixes, interm_loss_prop):
            # All four splits share everything except rows and loss prop.
            return BinaryDataset(
                n_bits=econf.n_bits,
                target_fn=econf.target_fn,
                ixes=ixes,
                interm_idxs=econf.interm_idxs,
                interm_loss_prop=interm_loss_prop
            )

        # 1.0 for test/val: do not corrupt their labels with -100s
        self.test_dataset = build_dataset(test_ixes, 1.0)
        self.val_dataset = build_dataset(val_ixes, 1.0)
        self.train_dataset = build_dataset(train_ixes, 1.0)
        self.train_dataset_s2 = build_dataset(train_s2_ixes,
                                              econf.interm_loss_prop)

        decay_horizon = (50 * len(self.train_dataset)
                         * (self.econf.n_bits + 1))
        trainer_conf = TrainerConfig(
            batch_size=self.econf.batch_size,
            learning_rate=self.econf.lr,
            lr_decay=True,
            warmup_tokens=1024,
            final_tokens=decay_horizon,
            num_workers=0,
            interm_idxs=self.econf.interm_idxs
        )
        self.trainer = Trainer(model=self.model, config=trainer_conf)

    def handcode(self, all_params=False, print_params=False):
        """Overwrite model parameters with hand-constructed values.

        The values are built from ``econf.prog`` (iterated as
        ``(attn, op)`` pairs), ``econf.n_emb`` and ``econf.block_size``,
        and cover the embeddings, the first transformer block
        (``blocks[0]``), the final layer norm and the output head.

        Args:
            all_params: when true, every hand-coded parameter is written;
                otherwise only those whose ``econf.param_status[name]['set']``
                flag is true.
            print_params: when true, print the name, shape and value of
                each parameter that is written.

        Side effects:
            Creates ``econf.param_status`` (all flags false) if it does not
            exist, replaces the selected parameters on ``self.model``, sets
            ``requires_grad = False`` on parameters flagged ``'freeze'``,
            and moves the model to ``self.device`` when CUDA is available.
        """
        assert hasattr(self.econf, 'prog')

        block_size = self.econf.block_size
        n_emb = self.econf.n_emb
        prog = self.econf.prog

        # Token embeddings for the two tokens (rows 0 and 1); only the first
        # three embedding dimensions are non-zero.
        W_emb = np.zeros((2, n_emb))
        W_emb[0, :3] = np.array([-1, 1, -1])
        W_emb[1, :3] = np.array([1, -1, -1])

        if hasattr(self.econf, 'sin_pos') and self.econf.sin_pos:
            # Sinusoidal positional table: sin on even dims, cos on odd dims
            # (each odd dim reuses the frequency of the even dim before it).
            W_pos = np.zeros((1, block_size, n_emb))
            for i in range(block_size):
                for j in range(n_emb):
                    if j % 2 == 0:
                        W_pos[0, i, j] = np.sin(i / (10000 ** (j / n_emb)))
                    else:
                        W_pos[0, i, j] = np.cos(i / (10000 ** ((j-1) / n_emb)))
        else:
            # One-hot positions stored in the last block_size embedding dims,
            # which is why the embedding must be at least block_size wide.
            assert n_emb >= block_size
            W_pos = np.zeros((1, block_size, n_emb))
            W_pos[:, :, -block_size:] = np.eye(block_size)

        # Attention sub-layer of block 0.
        W_ln_1 = np.ones(n_emb)
        b_ln_1 = np.zeros(n_emb)
        # Identity shifted down by (n_emb - block_size) rows.
        W_k = np.eye(n_emb, k=-(n_emb-block_size))
        b_k = np.zeros(n_emb)
        W_q = np.zeros((n_emb, n_emb))
        # For step i of prog, row (-len(prog) + i) gets weight 100 on every
        # index j listed in that step's `attn`.
        for i, (attn, _) in enumerate(prog):
            for j in attn:
                W_q[-len(prog)+i, j] = 100
        b_q = np.zeros(n_emb)
        # Values copy embedding dims 0 and 1 into dims 3 and 4.
        W_v = np.zeros((n_emb, n_emb))
        W_v[0, 3] = 1
        W_v[1, 4] = 1
        b_v = np.zeros(n_emb)
        W_proj = np.eye(n_emb)
        b_proj = np.zeros(n_emb)

        # Feed-forward sub-layer of block 0.
        # NOTE(review): the hidden width is hard-coded to 4 * n_emb here,
        # while the model is built with econf.n_ff -- confirm these agree
        # whenever n_ff is set explicitly.
        W_ln_2 = np.ones(n_emb)
        b_ln_2 = np.zeros(n_emb)
        W_ff_1 = np.zeros((n_emb, 4 * n_emb))
        W_ff_1[3, :2] = 100 * np.array([1, -1])
        factor = 1 / np.sqrt(4 / n_emb)
        W_ff_1[5:, 0] = factor * -100
        W_ff_1[5:, 1] = factor * 100
        # Guarded because with an empty prog the slice `-0:` would select
        # every row instead of none.
        if len(prog) > 0:
            # Per-step sign for the last len(prog) rows:
            # AND -> 1, XOR -> 0, anything else -> -1.
            op_arr = [1 if op == 'AND' else 0 if op == 'XOR' else -1
                      for _, op in prog]
            W_ff_1[-len(prog):, :2] *= np.array(op_arr)[:, np.newaxis]
        b_ff_1 = np.zeros(4 * n_emb)
        W_ff_2 = np.zeros((4 * n_emb, n_emb))
        W_ff_2[:2, 0] = np.array([-2, -2])
        b_ff_2 = 100 * np.eye(n_emb)[0]

        # Final layer norm and output head (two logits read from dim 0).
        W_ln_f = np.ones(n_emb)
        b_ln_f = np.zeros(n_emb)
        W_head = np.zeros((n_emb, 2))
        W_head[0, :] = 100 * np.array([-1, 1])

        # name -> (owning module, attribute name, value to install).
        # Everything except W_emb and W_pos is installed transposed
        # (presumably because the layers are nn.Linear, whose weight is
        # stored as (out_features, in_features) -- TODO confirm).
        param_settings = {
            'W_emb':  (self.model.tok_emb, 'weight', W_emb),
            'W_pos':  (self.model, 'pos_emb', W_pos),
            'W_ln_1': (self.model.blocks[0].ln1, 'weight', W_ln_1.T),
            'b_ln_1': (self.model.blocks[0].ln1, 'bias', b_ln_1.T),
            'W_k':    (self.model.blocks[0].attn.key, 'weight', W_k.T),
            'b_k':    (self.model.blocks[0].attn.key, 'bias', b_k.T),
            'W_q':    (self.model.blocks[0].attn.query, 'weight', W_q.T),
            'b_q':    (self.model.blocks[0].attn.query, 'bias', b_q.T),
            'W_v':    (self.model.blocks[0].attn.value, 'weight', W_v.T),
            'b_v':    (self.model.blocks[0].attn.value, 'bias', b_v.T),
            'W_proj': (self.model.blocks[0].attn.proj, 'weight', W_proj.T),
            'b_proj': (self.model.blocks[0].attn.proj, 'bias', b_proj.T),
            'W_ln_2': (self.model.blocks[0].ln2, 'weight', W_ln_2.T),
            'b_ln_2': (self.model.blocks[0].ln2, 'bias', b_ln_2.T),
            'W_ff_1': (self.model.blocks[0].mlp[0], 'weight', W_ff_1.T),
            'b_ff_1': (self.model.blocks[0].mlp[0], 'bias', b_ff_1.T),
            'W_ff_2': (self.model.blocks[0].mlp[2], 'weight', W_ff_2.T),
            'b_ff_2': (self.model.blocks[0].mlp[2], 'bias', b_ff_2.T),
            'W_ln_f': (self.model.ln_f, 'weight', W_ln_f.T),
            'b_ln_f': (self.model.ln_f, 'bias', b_ln_f.T),
            'W_head': (self.model.head, 'weight', W_head.T),
        }

        # Without an explicit status table nothing is set or frozen unless
        # all_params is passed.
        if not hasattr(self.econf, 'param_status'):
            self.econf.param_status = {}
            for p in param_settings:
                self.econf.param_status[p] = {'set': False, 'freeze': False}

        for p, s in param_settings.items():
            if self.econf.param_status[p]['set'] or all_params:
                if print_params:
                    # Undo the transpose so the printed matrix matches the
                    # layout in which it was constructed above.
                    s2 = s[2] if p in ['W_emb', 'W_pos'] else s[2].T
                    print(f'{p} (shape: {s2.shape})\n{s2}\n')
                # Replace the attribute with a fresh Parameter wrapping the
                # numpy array.
                setattr(s[0],
                        s[1],
                        torch.nn.Parameter(torch.from_numpy(s[2])))
            # Freezing is independent of whether the value was overwritten.
            if self.econf.param_status[p]['freeze']:
                getattr(s[0], s[1]).requires_grad = False

        if torch.cuda.is_available():
            self.model.to(self.device) # move to GPU if needed

    def train(self, s1=True):
        """Run one training stage through ``self.trainer``.

        Args:
            s1: when true, train on ``self.train_dataset`` for
                ``econf.n_updates`` updates (stage 1); otherwise train on
                ``self.train_dataset_s2`` for ``econf.n_updates_s2``
                updates (stage 2).

        Returns:
            The ``(train_losses, val_losses)`` pair returned by the trainer.
        """
        for required in ('n_updates', 'n_updates_s2'):
            assert hasattr(self.econf, required)

        if s1:
            stage = 1
            dataset = self.train_dataset
            update_budget = self.econf.n_updates
        else:
            stage = 2
            dataset = self.train_dataset_s2
            update_budget = self.econf.n_updates_s2

        # The trainer is driven by epochs, so convert the update budget
        # using the number of batches per epoch (a partial batch counts).
        batches_per_epoch = np.ceil(len(dataset) / self.econf.batch_size)
        if batches_per_epoch > 0:
            n_epochs = int(update_budget / batches_per_epoch)
        else:
            n_epochs = 0

        print(f'starting stage {stage}...')
        print(f'\tnum training examples: {len(dataset)}')
        print(f'\tnum validation examples: {len(self.val_dataset)}')
        print(f'\tnum testing examples: {len(self.test_dataset)}')
        print(f'\tnum updates: {update_budget}')
        print(f'\tsample train datapoint:\n\t\t{dataset[0]}')

        losses_train, losses_val = self.trainer.train(
            train_dataset=dataset,
            n_epochs=n_epochs,
            val_dataset=self.val_dataset,
            val_freq=self.econf.val_freq,
            restart_lr=self.econf.restart_lr
        )
        return losses_train, losses_val

    def train2(self):
        """Resample the datasets and train with a freshly built trainer.

        Draws a new sample of inputs, holds out up to 1000 examples (at
        most 20%) as ``self.test_dataset``, and splits the remainder by
        ``econf.interm_loss_prop`` into ``self.train_dataset_incl``
        (``interm_loss_prop=1.0``) and ``self.train_dataset_excl``
        (``interm_loss_prop=0.0``).  A new ``Trainer`` replaces
        ``self.trainer`` and is run immediately, without a validation set.

        NOTE(review): ``final_tokens`` is computed from
        ``self.train_dataset`` (created in ``__init__``), not from the
        datasets built here, and ``batch_size`` is fixed at 512 instead of
        ``econf.batch_size`` -- confirm both are intended.
        """
        assert hasattr(self.econf, 'epoch_datasets')
        econf = self.econf

        n_inputs = (2 ** econf.n_bits) ** 2 # total number of input combinations
        # Sample with replacement once the input space exceeds 2**24.
        with_replacement = n_inputs > 2 ** 24
        if with_replacement:
            print('forming the datasets with replacement...')
        perm = np.random.choice(n_inputs,
                                size=min(n_inputs, econf.max_ntotal),
                                replace=with_replacement)

        # Layout of perm: [test | train incl. interm | train excl. interm]
        n_test = min(int(len(perm)*0.2), 1000)
        n_train = len(perm) - n_test
        n_incl = int(econf.interm_loss_prop * n_train)
        incl_end = n_test + n_incl

        shared = dict(n_bits=econf.n_bits, target_fn=econf.target_fn,
                      interm_idxs=econf.interm_idxs)
        self.train_dataset_incl = BinaryDataset(
            ixes=perm[n_test:incl_end], interm_loss_prop=1.0, **shared
        )
        self.train_dataset_excl = BinaryDataset(
            ixes=perm[incl_end:], interm_loss_prop=0.0, **shared
        )
        # 1.0 here: do not corrupt test labels with -100s
        self.test_dataset = BinaryDataset(
            ixes=perm[:n_test], interm_loss_prop=1.0, **shared
        )
        print(f'num training (incl interm): {len(self.train_dataset_incl)}')
        print(f'num training (excl interm): {len(self.train_dataset_excl)}')
        print(f'num testing examples: {len(self.test_dataset)}')

        decay_horizon = 50*len(self.train_dataset)*(econf.n_bits+1)
        trainer_conf = TrainerConfig(
            max_epochs=econf.n_epochs,
            batch_size=512,
            learning_rate=econf.lr,
            lr_decay=True,
            warmup_tokens=1024,
            final_tokens=decay_horizon,
            num_workers=0,
            sample_interm=econf.sample_interm,
            interm_idxs=econf.interm_idxs,
            epoch_datasets=econf.epoch_datasets
        )
        self.trainer = Trainer(
            model=self.model,
            train_dataset=(self.train_dataset_incl, self.train_dataset_excl),
            val_dataset=None,
            config=trainer_conf
        )
        self.trainer.train()

    def test(self):
        """Evaluate the model on the train, validation and test datasets.

        Puts the model in eval mode and runs the exam function over each
        dataset: ``give_exam`` when ``econf.target_fn`` is set, otherwise
        ``give_addition_exam``.

        Returns:
            ``(train_accs, val_accs, test_accs)`` as returned by the exam
            function for each dataset.

        NOTE(review): the "train" figure is computed on
        ``self.train_dataset_s2`` (the stage-2 split), not on
        ``self.train_dataset`` -- confirm this is intended.
        """
        self.model.eval()
        exam_fn = (give_exam if self.econf.target_fn is not None else
                   give_addition_exam)
        print('evaluating on the train dataset...')
        train_accs = exam_fn(self.model, self.econf, self.device,
                             self.train_dataset_s2, batch_size=64,
                             max_batches=-1)
        print('evaluating on the validation dataset...')
        val_accs = exam_fn(self.model, self.econf, self.device,
                           self.val_dataset, batch_size=64,
                           max_batches=-1)
        # Message previously read 'evaluating othe test dataset...' (typo).
        print('evaluating on the test dataset...')
        test_accs = exam_fn(self.model, self.econf, self.device,
                            self.test_dataset, batch_size=64,
                            max_batches=-1)
        return train_accs, val_accs, test_accs

    def eval_example(self, a, b, print_interpreted=True,
                     plot_attns=True, save_attns=None, verbose_attns=False):
        """Generate the model's output for one input pair and report it.

        The prompt is ``a`` followed by ``b``; ``econf.c_len`` further
        tokens are sampled from the model and compared against
        ``econf.target_fn(a, b)``.

        Args:
            a, b: the two operand sequences (anything ``np.concatenate``
                accepts and that supports ``[::-1]``).
            print_interpreted: when true and ``econf`` has ``sum_idxs``,
                also print the operands and the ``sum_idxs`` positions of
                the true and predicted outputs, each reversed.
            plot_attns: when true, plot the attention weights of each head.
            save_attns: forwarded to ``plot_matrix`` as ``save_pth``; a
                non-``None`` value triggers plotting even when
                ``plot_attns`` is false.
            verbose_attns: forwarded to ``plot_matrix`` as ``verbose``.

        Returns:
            The predicted last ``c_len`` tokens as an integer numpy array.
        """
        y = np.array(self.econf.target_fn(a, b)).astype(int)
        # Build the (1, len(a) + len(b)) prompt from a single ndarray:
        # handing torch.tensor a *list* of ndarrays is slow and makes
        # PyTorch emit a UserWarning.
        prompt = np.concatenate([a, b])[np.newaxis]
        x = torch.tensor(prompt).long().to(self.device)
        # sample() is expected to return (token predictions, attentions).
        preds, attns = sample(self.model, x, self.econf.c_len)
        pred = preds.cpu().numpy()[0, -self.econf.c_len:].astype(int)

        print('sample generation...')
        print(f'a:\t\t\t{a}')
        print(f'b:\t\t\t{b}')
        print(f'target_fn(a, b) true:\t{y}')
        print(f'target_fn(a, b) pred:\t{pred}')
        print('correct' if np.all(y == pred) else 'incorrect')

        if print_interpreted and hasattr(self.econf, 'sum_idxs'):
            print('****** interpreted output ******')
            print('a:' + f'{a[::-1]}'.rjust(60))
            print('b:' + f'{b[::-1]}'.rjust(60))
            print('correct:' + f'{y[self.econf.sum_idxs][::-1]}'.rjust(54))
            print('pred:' + f'{pred[self.econf.sum_idxs][::-1]}'.rjust(57))

        if plot_attns or save_attns is not None:
            seq = np.concatenate([a, b, pred]).astype(int)
            assert len(seq) == 2 * self.econf.n_bits + self.econf.c_len
            # One plot per head, using the first entry of `attns`.  Axes
            # drop the last token: x shows the input tokens, y the
            # 1-based position being predicted.
            for i in range(self.econf.n_heads):
                plot_matrix(
                    attns[0][:, i], title=f'Attention weights: head {i+1}',
                    xlabel='input tokens', ylabel='output tokens',
                    xticklabels=seq[:-1],
                    yticklabels=np.arange(len(seq))[:-1]+1,
                    verbose=verbose_attns, save_pth=save_attns
                )

        return pred
