import torch
import jsonlines
import random
from models import VariationalIRT
import re


def _prompt_trimmer(ckpt_path):
    """Return a function that strips dataset boilerplate from a prompt, or None.

    The rule is chosen by substring of ``ckpt_path``, checked in the order
    ``virtue``, ``commonsense``, ``justice``. ``None`` means prompts are left
    untouched (and the ``prompt`` field is not accessed at all).
    """
    if 'virtue' in ckpt_path:
        def trim_virtue(prompt):
            # Keep the text preceding " Which of the" (the regex does not
            # cross newlines, so only the line containing the marker is kept).
            match = re.search(r'.+(?= Which of the)', prompt)
            if match is None:
                raise ValueError(
                    f'virtue prompt has no " Which of the" marker: {prompt!r}')
            return match.group(0)
        return trim_virtue
    if 'commonsense' in ckpt_path:
        # NOTE(review): fixed offsets presumably drop a 20-char template prefix
        # and a 93-char question suffix — confirm against the prompt template.
        return lambda prompt: prompt[20:-93]
    if 'justice' in ckpt_path:
        # NOTE(review): presumably a 20-char prefix and 27-char suffix — confirm.
        return lambda prompt: prompt[20:-27]
    return None


def infer(ckpt_path, infile_path, outfile_path):
    """Annotate every record of a JSONL file with predicted item parameters.

    Loads a ``VariationalIRT`` state dict and runs ``predict_d`` on the stacked
    ``response`` fields of all input records. It then writes each record to
    ``outfile_path`` with the corresponding output row stored under
    ``parameters`` and ``response`` removed. When ``ckpt_path`` contains
    ``virtue``, ``commonsense`` or ``justice``, the record's ``prompt`` is also
    trimmed by a dataset-specific rule.

    Args:
        ckpt_path: Path to a saved ``state_dict`` for ``VariationalIRT``. Its
            text also selects the prompt-trimming rule.
        infile_path: JSONL input. Each record's ``response`` must be an
            equal-length numeric sequence so they stack into one tensor.
        outfile_path: JSONL output path. The input is fully read before this
            file is opened, so it may be the same path as ``infile_path``.

    Raises:
        ValueError: if a ``virtue`` prompt does not contain `` Which of the``.
    """
    model_virt = VariationalIRT()
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only host;
    # load_state_dict copies the values into the freshly built model anyway.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    model_virt.load_state_dict(state_dict)
    model_virt.eval()

    with jsonlines.open(infile_path, 'r') as reader:
        data = list(reader)

    if data:
        y = torch.tensor([o['response'] for o in data])
        with torch.no_grad():
            d = model_virt.predict_d(y).tolist()
    else:
        # Nothing to score: skip the model and just produce an empty output.
        d = []

    # Resolve the trimming rule once instead of re-testing ckpt_path per record.
    trim = _prompt_trimmer(ckpt_path)

    # The context manager closes/flushes the writer even if a trim raises.
    with jsonlines.open(outfile_path, 'w') as writer:
        for i, record in enumerate(data):
            record['parameters'] = d[i]
            del record['response']
            if trim is not None:
                record['prompt'] = trim(record['prompt'])
            writer.write(record)

def check_parameter_consistency(ckpt_path, infile_path, n_samples=100, n_shuffle=10):
    """Print how much ``predict_d`` varies when a response vector is permuted.

    Draws up to ``n_samples`` ``response`` vectors from ``infile_path`` and
    builds ``n_shuffle`` random permutations of each. It then predicts
    parameters for every permutation and computes, per sample, the standard
    deviation across its permutations. The mean and the maximum of that
    spread over all samples are printed. Values near zero mean the prediction
    does not depend on response order.

    Args:
        ckpt_path: Path to a saved ``state_dict`` for ``VariationalIRT``.
        infile_path: JSONL file whose records each hold a ``response`` sequence.
        n_samples: Number of records to sample; capped at the number available.
        n_shuffle: Permutations per sampled record; at least 2 are needed for
            an (unbiased) standard deviation to be finite.

    Raises:
        ValueError: if ``n_samples < 1``, ``n_shuffle < 2``, or the input
            file has no records.
    """
    if n_samples < 1:
        raise ValueError(f'n_samples must be at least 1, got {n_samples}')
    if n_shuffle < 2:
        raise ValueError(f'n_shuffle must be at least 2, got {n_shuffle}')

    model_virt = VariationalIRT()
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only host.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    model_virt.load_state_dict(state_dict)
    model_virt.eval()

    with jsonlines.open(infile_path, 'r') as reader:
        responses = [o['response'] for o in reader]
    if not responses:
        raise ValueError(f'no records found in {infile_path}')

    # random.sample raises if asked for more items than exist; cap instead.
    n_samples = min(n_samples, len(responses))
    samples = random.sample(responses, n_samples)
    # Rows are sample-major (all permutations of sample 0, then sample 1, ...)
    # so the reshape below regroups them per sample. random.sample(s, len(s))
    # yields a fresh permutation instead of shuffling the loaded list in place.
    y = torch.tensor(
        [random.sample(s, len(s)) for s in samples for _ in range(n_shuffle)])

    with torch.no_grad():
        d = model_virt.predict_d(y).reshape(n_samples, n_shuffle, -1)
    # Spread across permutations of each sample: shape (n_samples, k), where k
    # is the per-row output size of predict_d.
    std = d.std(dim=1)
    print('Mean:', std.mean(dim=0).tolist())
    print('Max:', std.max(dim=0).values.tolist())

def check_ability_consistency(ckpt_path, infile_path, y=None, d=None):
    """Print and return ``predict_a`` for a small batch of responses and parameters.

    Args:
        ckpt_path: Path to a saved ``state_dict`` for ``VariationalIRT``.
        infile_path: Currently unused; accepted so existing calls keep working.
        y: Responses passed to ``predict_a``. Defaults to a fixed 4x3 0/1 example.
        d: Parameters passed to ``predict_a`` with ``y``. Defaults to a fixed
            4x2 example. NOTE(review): row i of ``d`` is presumably paired with
            row i of ``y`` — confirm against ``predict_a``.

    Returns:
        The value of ``predict_a(y, d)``, which is also printed.
    """
    model_virt = VariationalIRT()
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only host.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    model_virt.load_state_dict(state_dict)
    model_virt.eval()

    # Defaults are built per call rather than as default arguments so no
    # tensor object is shared between calls.
    if y is None:
        y = torch.tensor([[1, 0, 1], [1, 1, 1], [0, 0, 0], [1, 0, 0]])
    if d is None:
        d = torch.tensor([[0.5, 3.2], [-1.0, 4.3], [2.0, 2.2], [2.5, 3.8]])
    with torch.no_grad():
        a = model_virt.predict_a(y, d)
    print(a)
    return a

if __name__ == '__main__':
    # Fill in the paths below before running.
    # NOTE: _type and split are not read by anything below. infer() picks its
    # prompt-trimming rule from substrings of ckpt_path ('virtue',
    # 'commonsense', 'justice'), so the checkpoint path must name the dataset.
    _type = 'commonsense'
    split = 'train'
    ckpt_path = ''
    infile_path = ''
    outfile_path = ''
    if not (ckpt_path and infile_path and outfile_path):
        # Fail with a clear message instead of an opaque error from torch.load('').
        raise SystemExit('Set ckpt_path, infile_path and outfile_path before running.')
    infer(ckpt_path, infile_path, outfile_path)
    # check_ability_consistency(ckpt_path, infile_path)