import numpy as np
import os
import torch
from lstm import LSTMTextCompletion
from six.moves import cPickle as pkl

# Directory holding the cached model checkpoints (placeholder path).
DUMP = "DIR/TO/CACHE"
# Value of `n` embedded in every checkpoint filename
# (presumably the training-set size -- confirm against the training script).
n = 20000

# Results table: DATA[vocab_size][seq_len][hidden_size] -> (mean TV, std TV).
# NOTE(review): configurations that are never evaluated keep the (0, 0)
# placeholder, which is indistinguishable from a perfect score of 0 +/- 0.
DATA = {
    vocab: {
        seq_len: dict.fromkeys([16, 32, 64, 128], (0, 0))
        for seq_len in [3]
    }
    for vocab in [10, 20, 30, 40]
}

# CUDA device index the checkpoints are evaluated on.
DEVICE = 0
# Number of independently trained checkpoints per configuration.
N_TRIALS = 5


def _ground_truth_transitions(vocab_size, seed=0):
    """Return the reference row-stochastic transition matrix for ``vocab_size``.

    Draws |N(0, 1)| entries of shape (vocab_size, vocab_size) from the legacy
    NumPy generator seeded with ``seed`` and normalizes each row to sum to 1.
    A local ``RandomState(seed)`` yields the same numbers as
    ``np.random.seed(seed)`` followed by ``np.random.normal`` without mutating
    global RNG state. (Presumably this is the matrix the training data was
    sampled from -- confirm against the data-generation script.)
    """
    rng = np.random.RandomState(seed)
    matrix = np.abs(rng.normal(0, 1, (vocab_size, vocab_size)))
    return matrix / matrix.sum(axis=1, keepdims=True)


def _sequence_distribution(matrix, length):
    """Return the joint probability of every token sequence of ``length``.

    The first token is uniform over the V = ``matrix.shape[0]`` states and each
    following token is drawn from row ``matrix[prev]``.  For ``length == 3``
    this is ``p[i, j, k] = 1/V * matrix[i, j] * matrix[j, k]``.  The result has
    shape ``(V,) * length``.

    Raises:
        ValueError: if ``length`` is less than 1.
    """
    if length < 1:
        raise ValueError("length must be >= 1, got {}".format(length))
    vocab_size = matrix.shape[0]
    p = np.full(vocab_size, 1 / vocab_size)
    for _ in range(length - 1):
        # p[..., j] * matrix[j, k]: broadcasting appends a new trailing axis k.
        p = p[..., None] * matrix
    return p


def _model_transitions(model, vocab_size, device=DEVICE):
    """Return the model's next-token distribution for each single-token input.

    Row ``i`` is the softmax (over the last axis) of ``model.linear`` applied to
    ``model.lstm``'s output for the one-token input ``[[i]]``; the LSTM is
    called without an explicit initial state.
    """
    softmax = torch.nn.Softmax(-1)
    matrix = np.zeros((vocab_size, vocab_size))
    with torch.no_grad():  # inference only; no autograd graph needed
        for i in range(vocab_size):
            token = torch.tensor([[i]], dtype=torch.int).to(device)
            hidden, _ = model.lstm(model.embedding(token))
            probs = softmax(model.linear(hidden))
            matrix[i] = probs.reshape(-1).cpu().numpy()
    return matrix


def _total_variation(p, q):
    """Return the total-variation distance ``0.5 * sum(|p - q|)``."""
    return np.abs(p - q).sum() * 0.5


def _evaluate_checkpoint(path, vocab_size, length, target):
    """Load the pickled model at ``path`` and return its TV distance to ``target``.

    ``target`` is the reference sequence distribution of the given ``length``.
    NOTE(review): ``torch.load`` unpickles a full module here; on torch>=2.6 the
    default ``weights_only=True`` rejects that -- pass ``weights_only=False``
    for trusted local checkpoints if loading fails.
    """
    model = torch.load(path).to(DEVICE)
    model.eval()
    learned = _model_transitions(model, vocab_size)
    return _total_variation(target, _sequence_distribution(learned, length))


# Iterate DATA's own keys so the evaluated grid can never drift from the table.
for v_, by_length in DATA.items():
    # The reference matrix depends only on the vocabulary size (fixed seed),
    # so it is built once per v_ rather than once per trial.
    transition_matrix = _ground_truth_transitions(v_)
    for t_, by_hidden in by_length.items():
        P = _sequence_distribution(transition_matrix, t_)
        for h_ in by_hidden:
            try:
                tv = np.array([
                    _evaluate_checkpoint(
                        os.path.join(
                            DUMP,
                            "v={}_t={}_n={}_h={}_trial={}.pt".format(v_, t_, n, h_, tr),
                        ),
                        v_,
                        t_,
                        P,
                    )
                    for tr in range(N_TRIALS)
                ])
            except Exception as err:
                # Best effort: a missing/unloadable checkpoint skips this
                # configuration (it keeps its placeholder), but the reason is
                # reported instead of silently swallowed.
                print("skipping v={} t={} h={}: {!r}".format(v_, t_, h_, err))
                continue
            by_hidden[h_] = (tv.mean(), tv.std())

# Persist the results table DATA[vocab][seq_len][hidden] -> (mean TV, std TV)
# to data.pkl in the current working directory.
with open('data.pkl', mode='wb') as out_file:
    pkl.dump(DATA, out_file)