# %%
"""
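Generates 4x4 puzzle items for a requested condition (Binary, Ternary or
Quaternary). Candidates are kept only when they are sufficiently dissimilar
(Jaccard distance) from the puzzles generated so far and from the original
puzzle set; accepted puzzles and their solutions are saved to csv.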
"""
from random import choice
import pandas as pd
import numpy as np
from time import time
from scipy.spatial.distance import jaccard
from datetime import datetime
from src.dataset import grid_to_one_hot, get_dataset
from src.verify_LST import verify_LST
import src.generate_sudoku as gen_sudoku


def calc_puzzle_dist(puzzle, dataset):
    '''
    Calculates the Jaccard-Needham dissimilarity between a
    given puzzle and every puzzle of a dataset
    1 = highly dissimilar, 0 = very similar

    Parameters
    - puzzle: a single puzzle, either a grid of 16 cell values
      (e.g. shape (16,) or (4, 4)) or an already one hot encoded
      array of shape (4, 4, 5)
    - dataset: either a numpy array holding one 16-cell puzzle per row,
      or an indexable dataset (assumed to be a pytorch dataset) whose
      items hold the one hot encoded puzzle at position 0

    Returns
    - numpy array of length len(dataset) with one distance per dataset item

    Raises
    - ValueError if puzzle is neither a 16-cell grid nor a (4, 4, 5) array
    '''
    puzzle_shape = tuple(np.shape(puzzle))
    if puzzle_shape == (4, 4, 5):
        # already one hot encoded
        puzzle_vec = puzzle.reshape(-1)
    elif int(np.prod(puzzle_shape)) == 16:
        # convert current item to one hot and flatten it
        puzzle_vec = grid_to_one_hot(puzzle.reshape(4, 4)).reshape(-1)
    else:
        # previously this case left puzzle_vec undefined and crashed later
        raise ValueError(
            'puzzle must hold 16 cells or be one hot encoded with shape '
            '(4, 4, 5), got shape ' + str(puzzle_shape))

    # the kind of dataset does not change between items, so check it once
    dataset_is_array = isinstance(dataset, np.ndarray)
    n_items = len(dataset)

    # get test puzzles as flat one hot vectors
    # and calculate the jaccard distance (1 = good, 0 = bad)
    j_dist = np.zeros(n_items)
    # index based access on purpose: a dataset that only implements
    # __len__ and __getitem__ is not guaranteed to stop when iterated
    for trial in range(n_items):
        if dataset_is_array:
            # the dataset is a numpy array of puzzle grids
            test_vec = grid_to_one_hot(
                dataset[trial, :].reshape(4, 4)).reshape(-1)
        else:
            # we assume a pytorch dataset; np.asarray converts a cpu
            # tensor and also accepts plain array items
            test_vec = np.asarray(dataset[trial][0]).reshape(-1)
        j_dist[trial] = jaccard(puzzle_vec, test_vec)
    return j_dist


def save_puzzles(puzzles, solutions, train_condition, distance_cutoff):
    '''
    Saves puzzles and their solutions as csv in a consistent
    dataframe format

    Parameters
    - puzzles: 2d array with one puzzle per row; the cell values of a row
      are concatenated into a single string of digits
    - solutions: indexable with one solution per puzzle (a scalar or a
      size one array each)
    - train_condition: condition label written to every row and used
      (lower cased) in the file name
    - distance_cutoff: cutoff used during generation; int(cutoff * 100)
      becomes part of the file name

    The file is written to
    ../data/nn/generated_puzzle_data_<condition>_dist<cutoff>.csv
    relative to the current working directory.
    '''
    records = []
    for trial, row in enumerate(puzzles):
        # Puzzle
        # join the cells directly rather than cleaning up str(array),
        # whose layout depends on the numpy print options
        puzzle_data = ''.join(str(int(cell)) for cell in np.ravel(row))

        # Solution
        # .item() extracts the scalar without the implicit
        # array -> int conversion that recent numpy deprecates
        solution = str(int(np.asarray(solutions[trial]).item()))

        records.append(
            [trial, np.nan, train_condition, puzzle_data, solution])

    # build the frame in one go instead of growing it row by row
    df = pd.DataFrame(
        records,
        columns=['ID', 'LST_num', 'condition', 'puzzles', 'solutions'])

    df.to_csv('../data/nn/generated_puzzle_data_'+train_condition.lower() +
              '_dist'+str(int(distance_cutoff*100))+'.csv', index=False)


def main(num, train_condition, distance_cutoff=0.5):
    '''
    Generates `num` puzzles of the requested condition and saves them

    Candidate puzzles come from gen_sudoku; one blank cell is marked as
    the target (value 5) and verify_LST reports the condition of the
    resulting item. A candidate is kept when
    - its condition equals train_condition,
    - its minimum jaccard distance to the puzzles kept so far is > 0.2, and
    - its minimum jaccard distance to the original puzzle set is
      > distance_cutoff.

    Parameters
    - num: number of puzzles to create
    - train_condition: condition label as returned by verify_LST
    - distance_cutoff: minimum distance to the original puzzles

    Results are written with save_puzzles after every 10 accepted puzzles
    and once more at the end.

    Raises
    - AssertionError if the generated solution and the solution reported
      by verify_LST disagree
    '''
    print('Condition =', train_condition)
    # get the test dataset
    # this is used to test for overlap
    test_dataset = get_dataset('../data/nn/puzzle_data_original.csv')

    # preallocate
    puzzles = np.zeros((num, 16), np.int32)
    solutions = np.zeros((num, 1), np.int32)
    count = 0

    start = time()
    while count < num:
        # n is handed to gen_sudoku.run; presumably the number of given
        # cells of the generated puzzle -- TODO confirm
        for n in range(3, 12):
            # stop as soon as enough puzzles exist, otherwise the next
            # accepted puzzle would be written past the end of the arrays
            if count >= num:
                break

            all_results, solution = gen_sudoku.run(n=n, iter=2)
            puzzle = np.squeeze(gen_sudoku.best(all_results)).reshape(16)

            # choose a random target from a blank square
            blanks = np.where(puzzle == 0)[0]
            target = choice(blanks)
            puzzle[target] = 5

            # note the solution
            solution = np.squeeze(solution).reshape(16)[target]

            # identify condition of quiz
            condition, verified_solution, _ = verify_LST(puzzle.reshape(4, 4))

            # is this a puzzle we want?
            if condition != train_condition:
                continue

            # check the solutions match (explicit raise so the check is
            # not stripped when python runs with -O)
            if not (solution == verified_solution):
                raise AssertionError("Solutions do not match")

            # ensure the degree of overlap between the item and
            # existing items is small via the jaccard index
            if count == 0:
                train_dist = 1.0
            else:
                # only compare against puzzles created so far, not the
                # still empty rows of the preallocated array
                train_dist = calc_puzzle_dist(puzzle, puzzles[:count])

            test_dist = calc_puzzle_dist(puzzle, test_dataset)

            if not (np.min(train_dist) > 0.2
                    and np.min(test_dist) > distance_cutoff):
                continue

            puzzles[count, :] = puzzle
            solutions[count] = solution
            count += 1

            # count already includes the puzzle stored above
            if count % 10 == 0:
                print(
                    'Created ', count, ' puzzles. Time elapsed since last update = ', np.round(time() - start), '. Time now = ', datetime.now())
                # save interim results (created puzzles only)
                save_puzzles(puzzles[:count], solutions[:count],
                             train_condition, distance_cutoff)
                start = time()

    # save final results
    save_puzzles(puzzles, solutions, train_condition, distance_cutoff)


if __name__ == "__main__":
    # script entry point: create the same number of puzzles for every
    # condition, all with the same distance cutoff
    print('Generating puzzles...')
    n = 20
    distance_cutoff = 0.85
    for train_condition in ('Binary', 'Ternary', 'Quaternary'):
        main(n, train_condition, distance_cutoff=distance_cutoff)
