import polars as pl
import numpy as np
import itertools
from tqdm import tqdm

from utilities import cartesian_product, balance_df, check_path
import argparse
import const_configs as ccf

def generate_onehot_dict(n, inc_empty=False):
    """
    One-hot vectors generation

    Builds a lookup table from integer ids to one-hot lists of length ``n``.

    Args:
        n: Number of distinct elements to encode.
        inc_empty: When true, id ``0`` is added in front and mapped to the
            all-zero vector (the "empty" element).

    Returns:
        dict mapping the ids ``1..n`` (preceded by ``0`` when ``inc_empty``
        is set) to lists of ints; id ``k`` has its single ``1`` at position
        ``k - 1``.
    """
    # Built directly rather than through a dummy-variable round trip, so the
    # hot position is always tied to the numeric id and never to the way
    # generated column labels happen to be ordered.
    encodings = {k + 1: [int(pos == k) for pos in range(n)] for k in range(n)}
    if inc_empty:
        encodings = {0: [0] * n, **encodings}
    return encodings

def generate_onehot_encoding(analogy, s_dict, c_dict):
    """
    Translate every (shape, colour) cell of every image into one-hot form.

    Args:
        analogy: Iterable of images; each image is an iterable of
            ``(shape_id, colour_id)`` pairs.
        s_dict: Lookup table from shape id to its encoding.
        c_dict: Lookup table from colour id to its encoding.

    Returns:
        A list with one entry per image; each entry is a list of
        ``(shape_encoding, colour_encoding)`` tuples in cell order.

    Raises:
        KeyError: If an id is missing from the corresponding lookup table.
    """
    return [
        [(s_dict[shape_id], c_dict[colour_id]) for shape_id, colour_id in image]
        for image in analogy
    ]

def generate_cells(n_shapes, n_colours, compensators=(1, 1)):
    """
    Enumerate every admissible (shape, colour) cell.

    Shape ``0`` is only ever paired with colour ``0`` (a single ``(0, 0)``
    cell); every other shape is paired with each colour from ``1`` upwards.

    Args:
        n_shapes: Number of shapes.
        n_colours: Number of colours.
        compensators: ``(shape_offset, colour_offset)`` added to the counts to
            obtain the exclusive upper bounds of the id ranges. The default
            ``(1, 1)`` makes the ids run from 1 to n inclusive.

    Returns:
        List of ``(shape, colour)`` tuples: ``(0, 0)`` first, then the cells
        grouped by shape and ordered by colour within each shape.
    """
    com_s, com_c = compensators
    cells = []
    for s in range(n_shapes + com_s):
        if s == 0:
            # Shape 0 gets exactly one cell, with colour 0.
            cells.append((s, 0))
            continue
        cells.extend((s, c) for c in range(1, n_colours + com_c))
    return cells


def validate_analogies(analogies):
    """
    Label each analogy ``((a, b), (c, d))`` as valid (1) or invalid (0).

    An analogy is valid when the element-wise difference ``a - b`` is
    identical to ``c - d``.

    Args:
        analogies: Iterable of ``((a, b), (c, d))`` items whose members are
            numeric array-likes; ``a``/``b`` and ``c``/``d`` must be
            subtractable from each other.

    Returns:
        List of ints (1 = valid, 0 = invalid), one per analogy, in input
        order.
    """
    # Original author's note: "This one goes as long as tf != ft"
    # NOTE(review): the meaning of that note is unclear - confirm or drop.
    is_valid = []
    for (a, b), (c, d) in tqdm(analogies, desc="Evaluating analogies"):
        diff_ab = np.array(a) - np.array(b)
        diff_cd = np.array(c) - np.array(d)
        # array_equal compares the whole arrays in one call, so nested
        # (multi-dimensional) inputs work too; differing shapes simply count
        # as an invalid analogy.
        is_valid.append(int(np.array_equal(diff_ab, diff_cd)))
    return is_valid

def flatten_images(images):
    """
    Collapse each image into one flat list of values.

    An image is walked cell by cell, each cell property by property, and each
    property item by item; the items are concatenated in that order.

    Args:
        images: Iterable of images, each an iterable of cells, each cell an
            iterable of iterables.

    Returns:
        List with one flat list per image, in input order.
    """
    unnest = itertools.chain.from_iterable
    # Two unnesting passes: image -> properties -> items.
    return [list(unnest(unnest(image))) for image in images]

def validate_analogies_(analogies):
    """
    Label analogies given as four encoded images each.

    Every analogy is flattened with ``flatten_images`` into the vectors
    ``a, b, c, d`` and is valid when ``a - b`` equals ``c - d`` element-wise.

    Args:
        analogies: Iterable of analogies; each must flatten to exactly four
            numeric vectors.

    Returns:
        List of ints (1 = valid, 0 = invalid), one per analogy, in input
        order.

    Raises:
        ValueError: If an analogy does not contain exactly four images.
    """
    is_valid = []
    for images in analogies:
        a, b, c, d = flatten_images(images)
        diff_ab = np.array(a) - np.array(b)
        diff_cd = np.array(c) - np.array(d)
        # One whole-array comparison instead of building a set of booleans
        # for every analogy.
        is_valid.append(int(np.array_equal(diff_ab, diff_cd)))
    return is_valid


def generate_images(cells):
    """
    Build every possible image as an ordered 4-tuple of cells.

    Args:
        cells: Iterable of cells; cells may repeat within an image.

    Returns:
        List of 4-tuples covering the full Cartesian product
        ``cells x cells x cells x cells``, in ``itertools.product`` order.
    """
    # repeat=4 materialises `cells` once, so a one-shot iterable works too
    # (passing the same iterator four times would exhaust it on first use).
    return list(itertools.product(cells, repeat=4))

def _build_analogy_frame(images, n_shapes, n_colours, n_images, seed):
    """Sample images, combine them into analogies, label and balance them."""
    # For reproduction purposes
    np.random.seed(seed)
    shapes_dict     = generate_onehot_dict(n_shapes, inc_empty=True)
    colours_dict    = generate_onehot_dict(n_colours, inc_empty=True)
    # Get a subset of images
    n_imgs  = min(n_images, len(images))
    # NOTE(review): randint draws WITH replacement, so the sample may contain
    # duplicate images. Left unchanged to keep seeded outputs stable - confirm
    # whether sampling without replacement was intended.
    rand_idx    = np.random.randint(0, len(images), n_imgs)
    rand_images = [images[idx] for idx in rand_idx]

    # Encode and flatten every sampled image once; the nested loop below only
    # combines these precomputed pieces.
    encoded_images  = generate_onehot_encoding(rand_images, shapes_dict, colours_dict)
    flat_images     = np.array(flatten_images(encoded_images))

    # Generate pairs (ordered, an image may be paired with itself)
    index_pairs     = list(itertools.product(range(len(rand_images)), repeat=2))
    pairs           = [(rand_images[i], rand_images[j]) for i, j in index_pairs]
    encoded_pairs   = [[encoded_images[i], encoded_images[j]] for i, j in index_pairs]
    # Row p holds a - b for pair p; an analogy (p, q) is valid when rows p and
    # q are identical.
    pair_diffs      = np.array([flat_images[i] - flat_images[j] for i, j in index_pairs])

    dfs     = []
    for p, pair1 in enumerate(tqdm(pairs, "Processing each pair")):
        # Compare this pair's difference against every pair's in one shot;
        # plain Python ints keep the column type the same as before.
        flags   = (pair_diffs == pair_diffs[p]).all(axis=1).astype(int).tolist()
        local_records   = [
            (pair1 + pair2, encoded_pairs[p] + enc2, flag)
            for pair2, enc2, flag in zip(pairs, encoded_pairs, flags)
        ]
        local_df    = pl.DataFrame(local_records, schema=["original_analogy", "encoded_analogy", "is_valid"], orient="row")
        # balance_df is a project helper; presumably it evens out the
        # is_valid classes within this slice - confirm against utilities.
        dfs.append(balance_df(local_df, on="is_valid", seed=seed))
    return pl.concat(dfs)


def generate_analogies_(n_shapes, n_colours, n_images, seed=14, *args, **kwargs):
    """
    Generate labelled analogies over the full cell vocabulary.

    Up to ``n_images`` images are sampled, every ordered pair of sampled
    images is formed, and every ordered pair of pairs becomes one analogy
    ``(a, b, c, d)``, labelled valid when ``a - b == c - d`` on the flattened
    one-hot encodings.

    Args:
        n_shapes: Number of shapes.
        n_colours: Number of colours.
        n_images: Number of images to sample (capped at the number of
            possible images).
        seed: Seed for numpy's global RNG and for ``balance_df``.
        *args, **kwargs: Accepted and ignored so both generators can be
            called with the same arguments.

    Returns:
        polars DataFrame with the columns ``original_analogy``,
        ``encoded_analogy`` and ``is_valid``.
    """
    cells   = generate_cells(n_shapes, n_colours)
    images  = generate_images(cells)
    return _build_analogy_frame(images, n_shapes, n_colours, n_images, seed)


def generate_partial_analogies_(n_shapes, n_colours, n_images, targets=(0, 0), seed=14, *args, **kwargs):
    """
    Generate labelled analogies over a restricted cell vocabulary.

    Works like ``generate_analogies_`` but first narrows the cells:
    a non-zero target shape keeps only cells with that shape (or shape 0);
    otherwise a non-zero target colour keeps only cells with that colour (or
    colour 0). The shape target takes precedence when both are given. With no
    target set, all cells are used.

    Args:
        n_shapes: Number of shapes.
        n_colours: Number of colours.
        n_images: Number of images to sample (capped at the number of
            possible images).
        targets: ``(target_shape, target_colour)``; ``0`` means "no target".
        seed: Seed for numpy's global RNG and for ``balance_df``.
        *args, **kwargs: Accepted and ignored.

    Returns:
        polars DataFrame with the columns ``original_analogy``,
        ``encoded_analogy`` and ``is_valid``.
    """
    t_s, t_c    = targets
    cells       = generate_cells(n_shapes, n_colours)
    if t_s:
        cells   = [(_s, _c) for _s, _c in cells if _s == t_s or _s == 0]
    elif t_c:
        cells   = [(_s, _c) for _s, _c in cells if _c == t_c or _c == 0]
    # With neither target set the full cell list is kept; previously this
    # path failed on an unassigned variable.
    images      = generate_images(cells)
    return _build_analogy_frame(images, n_shapes, n_colours, n_images, seed)


def generate_image_dict(n_shapes, n_colours, seed=14):
    """
    Tabulate every possible image next to its flattened one-hot encoding.

    Args:
        n_shapes: Number of shapes.
        n_colours: Number of colours.
        seed: Seed applied to numpy's global RNG.

    Returns:
        polars DataFrame built from the list of images and the list of their
        flattened encodings, with the schema names ``original_image`` and
        ``encoded_image``.
    """
    # Seeding is kept for parity with the generators; nothing in this
    # function itself draws random numbers.
    np.random.seed(seed)
    all_images  = generate_images(generate_cells(n_shapes, n_colours))
    lookup_s    = generate_onehot_dict(n_shapes, inc_empty=True)
    lookup_c    = generate_onehot_dict(n_colours, inc_empty=True)
    encoded     = generate_onehot_encoding(all_images, lookup_s, lookup_c)
    columns     = [all_images, flatten_images(encoded)]
    return pl.DataFrame(columns, schema=["original_image","encoded_image"])

if __name__ == "__main__":
    # Script entry point: generate the analogy dataset and, if requested,
    # export a 70/10/20 train/dev/test split plus the image dictionary.
    import os  # local import: only the script path needs to join file paths

    parser = argparse.ArgumentParser(description="  ")
    # nargs="?" makes the positionals optional so their defaults actually
    # apply; without it argparse requires them and ignores `default`.
    parser.add_argument("n_shapes", type=int, nargs="?", default=3, help="")
    parser.add_argument("n_colours", type=int, nargs="?", default=3, help="")
    parser.add_argument("--n_images", type=int, default=81, help="")
    parser.add_argument("--targets", type=str, default="0,0", help="")
    parser.add_argument("--seed", type=int, default=14, help="")
    parser.add_argument("--export", type=int, default=1, help="")
    parser.add_argument("--export_path", type=str, default="data/", help="Export path")

    args    = parser.parse_args()
    arg_n_shapes    = args.n_shapes
    arg_n_colours   = args.n_colours
    arg_n_images    = args.n_images
    arg_targets     = args.targets
    arg_seed        = args.seed
    arg_export      = args.export
    arg_export_path = args.export_path

    # --targets is "<shape>,<colour>"; reject anything else up front instead
    # of failing later inside the generator.
    try:
        targets = [int(item) for item in arg_targets.split(",")]
    except ValueError:
        parser.error("--targets expects two comma-separated integers, e.g. 0,0")
    if len(targets) != 2:
        parser.error("--targets expects two comma-separated integers, e.g. 0,0")

    # Any positive target sum switches to the restricted generator.
    _generator  = generate_analogies_
    if np.sum(targets) > 0:
        _generator  = generate_partial_analogies_
    df  = _generator(arg_n_shapes, arg_n_colours, arg_n_images, seed=arg_seed, targets=targets)
    df  = df.sample(fraction=1, shuffle=True, seed=arg_seed)
    df_dict = generate_image_dict(arg_n_shapes, arg_n_colours, arg_seed)

    if arg_export:
        portions    = np.linspace(0, df.height, 11, dtype=int) # Get 10 equal portion ranges
        # First 7 portions -> train, the 8th -> dev, the last 2 -> test.
        df_train    = df[:portions[7]]
        df_dev      = df[portions[7]:portions[-3]]
        df_test     = df[portions[-3]:]
        check_path(arg_export_path)
        # os.path.join keeps the output inside the export directory whether
        # or not --export_path ends with a separator.
        df_train.write_json(os.path.join(arg_export_path, "train.json"))
        df_dev.write_json(os.path.join(arg_export_path, "dev.json"))
        df_test.write_json(os.path.join(arg_export_path, "test.json"))
        df_dict.write_json(os.path.join(arg_export_path, "dict.json"))
