import random
from copy import deepcopy
from D5 import D5
from validator import DummyValidator
from lm_proposer import GPT3_Proposer
from get_representative import return_extreme_values
import pickle as pkl


def subsample(samples, n=1000):
    """Pick up to ``n`` random elements of ``samples`` without replacement.

    The chosen elements are returned in the same relative order they had in
    ``samples``, so an ordered input stays ordered after subsampling. If
    ``samples`` has ``n`` or fewer elements, all of them are returned.
    Randomness comes from the module-level ``random`` generator, so results
    are reproducible after ``random.seed(...)``.

    Args:
        samples: An indexable sequence of items.
        n: Maximum number of items to keep (assumed non-negative).

    Returns:
        A new list holding the selected items.
    """
    positions = [*range(len(samples))]
    # Shuffle every index once; the first n of the shuffled order form a
    # uniform random subset.
    random.shuffle(positions)
    picked = positions[:n]
    # Re-sort the chosen indices so the output preserves input order.
    picked.sort()
    return [samples[pos] for pos in picked]

def flip_problem(problem):
    """Return a deep copy of ``problem`` with its A and B sides exchanged.

    The input is not modified. In the returned copy:
      * ``'A_desc'`` and ``'B_desc'`` are swapped;
      * every entry of ``'split'`` is rebuilt as a dict with only the keys
        ``'A_samples'`` and ``'B_samples'``, taking their values from the
        original entry's ``'B_samples'`` and ``'A_samples'`` respectively
        (any other keys in a split entry are not carried over).
    All other top-level keys of ``problem`` are copied unchanged.
    """
    flipped = deepcopy(problem)

    new_a_desc = flipped['B_desc']
    new_b_desc = flipped['A_desc']
    flipped['A_desc'] = new_a_desc
    flipped['B_desc'] = new_b_desc

    swapped_split = {}
    for split_name, sides in flipped['split'].items():
        swapped_split[split_name] = {
            'A_samples': sides['B_samples'],
            'B_samples': sides['A_samples'],
        }
    flipped['split'] = swapped_split
    return flipped


if __name__ == '__main__':
    # Load the example problem. NOTE(review): pickle.load can execute
    # arbitrary code, so only point this at a trusted, locally produced file.
    # The `with` block guarantees the file handle is closed.
    with open('example_problem.pkl', 'rb') as problem_file:
        problem = pkl.load(problem_file)

    verifier = DummyValidator()
    print('!!!!! WARNING: You are using a dummy verifier that returns random results!!!!!!!')
    proposer = GPT3_Proposer(problem)

    # Alias the research split; the assignments below mutate it in place
    # (the same dict object held inside `problem`).
    research = problem['split']['research']
    extreme_vals = return_extreme_values(research['A_samples'], research['B_samples'])
    # Keep a random subset (default up to 1000) of each returned list;
    # subsample preserves the relative order of the returned lists.
    # NOTE(review): presumably those lists are ordered by extremeness — confirm
    # against get_representative.return_extreme_values.
    research['A_samples'] = subsample(extreme_vals['sorted_A_samples'])
    research['B_samples'] = subsample(extreme_vals['sorted_B_samples'])

    d5 = D5(
        research['A_samples'],
        research['B_samples'],
        verifier,
        proposer,
        total_hypotheses_count=10,
        early_stop=True
    )
    h2h_dicts = d5.run()

    # Write results through a context manager so the pickle is fully flushed
    # and the file closed even on interpreters without refcounting GC.
    with open('h2h_dicts.pkl', 'wb') as output_file:
        pkl.dump(h2h_dicts, output_file)


