import polars as pl
import torch
from tqdm import tqdm
from sklearn.metrics import (
    balanced_accuracy_score, 
    accuracy_score, 
    classification_report, 
    confusion_matrix,
    precision_recall_fscore_support,
)

def cartesian_product(a:list, b:list) -> list:
    """Return every ``[x, y]`` pair with ``x`` from ``a`` and ``y`` from ``b``.

    Pairs are ordered with ``a`` as the outer loop, and each pair is a fresh
    two-item list.
    """
    return [[left, right] for left in a for right in b]

def refine_directory_path(path:str) -> str:
    """Return ``path`` with a trailing "/" so it can be prefix-concatenated.

    An empty string is returned unchanged (it stands for the current
    directory). Before, it crashed with IndexError on ``path[-1]``.
    """
    if path and not path.endswith("/"):
        path = f"{path}/"
    return path

def check_path(path:str):
    """Ensure the directory ``path`` exists and return it with a trailing "/".

    Missing parent directories are created; an existing directory is not an
    error.
    """
    from pathlib import Path
    normalised = refine_directory_path(path)
    target = Path(normalised)
    target.mkdir(parents=True, exist_ok=True)
    return normalised

def clean_dir(path:str):
    """Create ``path`` if needed, then delete the regular files directly inside it.

    Sub-directories and their contents are left untouched.
    """
    from pathlib import Path
    root = Path(path)
    root.mkdir(parents=True, exist_ok=True)
    for entry in (candidate for candidate in root.glob("*") if candidate.is_file()):
        entry.unlink()

def get_files_from_directory(path:str, extension:str) -> list:
    """Return paths matching ``extension`` inside the directory ``path``.

    ``extension`` is appended verbatim to the directory (e.g. ``"*.csv"``) and
    the result is passed to ``glob.glob``, so it is a glob pattern rather than
    a bare suffix.
    """
    import glob
    pattern = refine_directory_path(path) + extension
    return glob.glob(pattern)

def balance_df(input_df:pl.DataFrame, on:str, seed=14):
    """Down-sample every class in column ``on`` to the size of the smallest class.

    Args:
        input_df: frame to balance.
        on: name of the class/label column.
        seed: seed forwarded to ``DataFrame.sample`` for each class.

    Returns:
        The per-class samples concatenated vertically, in order of first
        appearance of each class. An empty input is returned as-is.
    """
    if input_df.height == 0:
        # Avoid the IndexError from taking the min of no classes.
        return input_df
    # maintain_order=True makes the concatenation order (and thus the output)
    # reproducible for a fixed seed; plain unique() has no guaranteed order.
    classes = input_df.select(on).unique(maintain_order=True).to_series()
    # NOTE(review): `== item` never matches a null class, so a null label gives
    # an empty group and forces every class down to 0 rows — confirm labels are non-null.
    groups = [input_df.filter(pl.col(on) == item) for item in classes]
    count = min(group.height for group in groups)
    return pl.concat([group.sample(count, seed=seed) for group in groups], how="vertical")

def sample_df_by_class(input_df:pl.DataFrame, on:str, n:int, seed=14):
    """Cap every class in column ``on`` at ``n`` rows.

    Classes with more than ``n`` rows are randomly sampled down to ``n``;
    smaller classes are kept whole.

    Args:
        input_df: frame to sample from.
        on: name of the class/label column.
        n: maximum number of rows to keep per class.
        seed: seed forwarded to ``DataFrame.sample``.

    Returns:
        The per-class frames concatenated vertically, in order of first
        appearance of each class. An empty input is returned as-is.
    """
    if input_df.height == 0:
        # pl.concat([]) would raise on an empty class list.
        return input_df
    # Deterministic class order so results are reproducible for a given seed.
    classes = input_df.select(on).unique(maintain_order=True).to_series()
    results = []
    for _class in classes:
        # NOTE(review): rows whose label is null are never matched by `==` and are dropped.
        candidates = input_df.filter(pl.col(on) == _class)
        if candidates.height > n:
            candidates = candidates.sample(n, seed=seed)
        results.append(candidates)
    return pl.concat(results, how="vertical")

def count_gpus():
    """Return ``(cuda_device_count, mps_device_count)`` as reported by torch."""
    cuda_devices = torch.cuda.device_count()
    mps_devices = torch.mps.device_count()
    return cuda_devices, mps_devices

def get_torch_device():
    """Return the preferred torch device name: "cuda", then "mps", else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps.is_available() is the documented MPS check and exists
    # on every torch version with MPS support. torch.mps.is_available() is
    # missing on many versions and raised AttributeError on non-CUDA hosts.
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

def restore_analogy_from_vector(vec_in:torch.Tensor):
    """Reshape a batch of flat analogy vectors to ``(batch, 4, -1)``.

    The first dimension is kept as the batch; the rest of each row is split
    into 4 equal parts (one per analogy term).
    """
    batch_size = vec_in.size(0)
    return vec_in.reshape(batch_size, 4, -1)

def match_onehot_to_index(vector:torch.Tensor, compensator=1):
    """Decode a 1-D one-hot tensor into an index.

    Args:
        vector: 1-D tensor; the position holding the value 1 is the encoded index.
        compensator: offset added to the found position (default 1, so that 0
            can mean "no value set").

    Returns:
        ``position + compensator`` for the first element equal to 1, or 0 if
        no element equals 1.
    """
    hits = (vector == 1).nonzero()
    if hits.numel() == 0:
        return 0
    # Use the first hit. Before, a multi-hot vector crashed in `.item()`; this
    # now matches the `list.index(1)` decoding used for list inputs.
    return int(hits[0, 0]) + compensator

def match_element(image:torch.Tensor):
    """Decode one flat image vector into four ``(shape, colour)`` index pairs.

    The vector is reshaped to ``(4, -1, len/8)``. Each of the 4 cells is then
    unpacked into a shape one-hot and a colour one-hot, so the middle
    dimension must be exactly 2. Each one-hot is decoded with
    ``match_onehot_to_index``.
    """
    width = int(image.shape[0] / 4 / 2)
    return [
        (match_onehot_to_index(shape_vec), match_onehot_to_index(colour_vec))
        for shape_vec, colour_vec in image.reshape(4, -1, width)
    ]


def generate_vector_prompt(analogies:torch.Tensor, analogy_template):
    """Format each analogy in a batch of flat vectors into ``analogy_template``.

    Each row is split into 4 parts by ``restore_analogy_from_vector``. The 4
    parts, as Python lists, are passed positionally to
    ``analogy_template.format``.
    """
    restored = restore_analogy_from_vector(analogies)
    return [analogy_template.format(*row.tolist()) for row in restored]


def generate_text_prompt(_analogies:torch.Tensor, analogy_template, image_template, shapes:dict, colours:dict):
    """Render a batch of flat analogy vectors as natural-language prompts.

    Each analogy is split into 4 images, and each image is decoded into
    ``(shape, colour)`` indices via ``match_element``. Every cell becomes
    ``"<colour> <shape>"`` and fills ``image_template``. The 4 image
    descriptions then fill ``analogy_template``. Any occurrence of the
    index-0 text ``"<colours[0]> <shapes[0]>"`` is replaced by ``"empty"``.
    """
    restored = restore_analogy_from_vector(_analogies)
    empty_text = f"{colours[0]} {shapes[0]}"

    def describe(cells):
        # One image -> its template string, one "<colour> <shape>" per cell.
        return image_template.format(*[f"{colours[c]} {shapes[s]}" for s, c in cells])

    prompts = []
    for images in restored:
        decoded = [match_element(image) for image in images]
        descriptions = [describe(cells) for cells in decoded]
        prompts.append(analogy_template.format(*descriptions).replace(empty_text, "empty"))
    return prompts

def generate_text_prompt_from_df(df:pl.DataFrame, analogy_template, image_template, shapes:dict, colours:dict, column="original_analogy"):
    """Render the analogies stored in ``df[column]`` as natural-language prompts.

    Each row holds the analogy's images. Each image is a sequence of
    ``(shape, colour)`` index pairs, rendered as ``"<colour> <shape>"`` into
    ``image_template``. The image descriptions fill ``analogy_template``, and
    the index-0 text ``"<colours[0]> <shapes[0]>"`` is replaced by ``"empty"``.
    A tqdm progress bar is shown.
    """
    series = df.select(column).to_series()
    empty_text = f"{colours[0]} {shapes[0]}"
    prompts = []
    for elements in tqdm(series, desc="Generating text prompts"):
        # elements == [a, b, c, d]: one entry per analogy image.
        rendered = []
        for element in elements:
            labels = [f"{colours[c]} {shapes[s]}" for s, c in element]
            rendered.append(image_template.format(*labels))
        prompts.append(analogy_template.format(*rendered).replace(empty_text, "empty"))
    return prompts

def generate_vector_prompt_from_df(df:pl.DataFrame, analogy_template, column="original_analogy"):
    """Format the analogies stored in ``df[column]`` into ``analogy_template``.

    This is the DataFrame counterpart of ``generate_vector_prompt``. The
    column is converted to a tensor, each row is split into 4 parts by
    ``restore_analogy_from_vector``, and the parts (as lists) are passed
    positionally to ``analogy_template.format``.

    Previously the body read the undefined local ``analogies`` and ignored
    ``df``/``column``, so every call raised UnboundLocalError.

    NOTE(review): assumes every row of ``column`` is a rectangular numeric
    (nested) list that ``torch.tensor`` can stack — confirm for the chosen column.

    Returns:
        One formatted prompt per row; an empty list for an empty frame.
    """
    rows = df.select(column).to_series().to_list()
    if not rows:
        # A 0-element tensor cannot be reshaped to (0, 4, -1).
        return []
    analogies = restore_analogy_from_vector(torch.tensor(rows))
    return [analogy_template.format(*analogy.tolist()) for analogy in analogies]


def generate_image_prompts_from_onehots(image, image_template, shapes_dict, colours_dict):
    """Describe a single image as text.

    ``image.to_list()`` must yield ``(shape, colour)`` pairs. Each value may be
    either an integer index or a one-hot list; a one-hot list decodes to the
    position of its first 1 plus one, or 0 when it holds no 1. Every cell is
    rendered as ``"<colour> <shape>"`` into ``image_template``, and the
    index-0 text ``"<colours_dict[0]> <shapes_dict[0]>"`` is replaced by
    ``"empty"``.
    """
    # TODO: Not done yet

    def _to_index(value):
        # Shape and colour are decoded independently, so a mixed pair (one int,
        # one one-hot list) works; before, it raised AttributeError.
        if isinstance(value, int):
            return value
        try:
            return value.index(1) + 1
        except ValueError:
            return 0

    empty = f"{colours_dict[0]} {shapes_dict[0]}"
    original_encodings = [(_to_index(s), _to_index(c)) for s, c in image.to_list()]
    image_desc = image_template.format(*[f"{colours_dict[c]} {shapes_dict[s]}" for s, c in original_encodings])
    return image_desc.replace(empty, "empty")

# def tensor2string(input_tensors:torch.Tensor, sep=","):
#     return [sep.join([str(int(ix)) for ix in item]) for item in input_tensors]

# def string2intlist(input_strings:list, sep=","):
#     output_list = []
#     for string in input_strings:
#         output_list.append([int(item) for item in string.split(sep)])
#     return output_list

# def get_i2l_from_df(df:pl.DataFrame, column="encoded_image"):
#     col_data    = df.select(column).to_series()
#     return {i:v for i,v in enumerate(tensor2string(col_data))}

def flatten_images(images):
    """Flatten each image's four-level nesting into a single flat list.

    The nesting is image -> cell -> property -> chunk -> value. The result has
    one flat list per image, in the original order.
    """
    return [
        [value for cell in image for prop in cell for chunk in prop for value in chunk]
        for image in images
    ]

def preprocess_dataset(df:pl.DataFrame):
    """
    Normalise a dataset frame across on-disk format versions.

    Frames that have ``encoded_analogy`` but no ``analogy`` column (format
    version 0.1, which stopped storing flattened records) gain an ``analogy``
    column holding ``flatten_images`` of ``encoded_analogy``. Any other frame
    is returned unchanged.
    """
    columns = df.columns
    if "analogy" in columns or "encoded_analogy" not in columns:
        return df
    nested = df.select("encoded_analogy").to_series().to_list()
    return df.with_columns(pl.Series("analogy", flatten_images(nested)))

def get_metrics_result(test_df:pl.DataFrame, ref_pred_cols=("ref", "pred"), verbose=False):
    """Compute classification metrics from a reference and a prediction column.

    Args:
        test_df: frame holding both columns.
        ref_pred_cols: ``(reference_column, prediction_column)``. The default is
            a tuple instead of a shared mutable list; any 2-item sequence,
            including a list, is still accepted.
        verbose: when True, print the report, balanced accuracy, accuracy and
            precision/recall/F-score/support.

    Returns:
        ``((accuracy, balanced_accuracy), prfs, confusion_matrix, report)``.
        ``prfs`` is the tuple returned by sklearn's
        ``precision_recall_fscore_support``, and ``report`` is the text output
        of ``classification_report``.
    """
    col_ref, col_pred = ref_pred_cols
    y_test = test_df.select(col_ref).to_series()
    y_pred = test_df.select(col_pred).to_series()
    report = classification_report(y_test, y_pred)
    acc = accuracy_score(y_test, y_pred)
    bacc = balanced_accuracy_score(y_test, y_pred)
    prfs = precision_recall_fscore_support(y_test, y_pred)
    cm = confusion_matrix(y_test, y_pred)
    if verbose:
        print("Classification Report:\n", report)
        print("Balanced Accuracy Score:", bacc)
        print("Accuracy Score:", acc)
        print("PRFS:\n", prfs)
    return (acc, bacc), prfs, cm, report