import polars as pl
import numpy as np
import transformers
import torch
import argparse
from tqdm import tqdm

from utilities import refine_directory_path, check_path
import const_prompts as cp


def restore_analogy_from_vector(vec_in:torch.Tensor):
    """Split each flattened analogy of a batch into four equally sized parts.

    Args:
        vec_in: Tensor whose first dimension indexes the analogies; the
            remaining elements of every analogy must be divisible by 4.

    Returns:
        The same data reshaped to ``(vec_in.shape[0], 4, -1)``.
    """
    batch_size = vec_in.shape[0]
    return vec_in.reshape(batch_size, 4, -1)

def match_onehot_to_index(vector:torch.Tensor, compensator=1):
    """Translate a one-hot vector into the (shifted) position of its 1.

    Args:
        vector: Tensor expected to contain at most one entry equal to 1.
        compensator: Offset added to the position that was found
            (the default of 1 yields 1-based indices).

    Returns:
        ``position + compensator`` when a 1 is present, otherwise 0.
        With more than one entry equal to 1 the ``.item()`` call fails,
        because the squeezed index tensor is no longer a single value.
    """
    hits = torch.nonzero(vector == 1)
    # No entry equals 1: 0 is the "nothing found" code.
    if hits.numel() == 0:
        return 0
    position = hits.squeeze().item()
    return position + compensator


def match_element(image:torch.Tensor):
    """Decode one flattened image into ``(shape, colour)`` index pairs.

    The image is viewed as 4 cells; every cell must consist of exactly two
    one-hot vectors of equal length (first the shape, then the colour), so
    ``image.shape[0]`` has to be a multiple of 8.

    Args:
        image: 1-D tensor holding the concatenated one-hot vectors.

    Returns:
        A list with one ``(shape, colour)`` tuple per cell. Indices come
        from ``match_onehot_to_index`` with its default offset, i.e. they
        are 1-based and 0 means that no 1 was found.
    """
    # 4 cells x 2 one-hot vectors per cell -> length of a single vector.
    onehot_width = image.shape[0] // 8
    grid = image.reshape(4, -1, onehot_width)
    return [
        (match_onehot_to_index(shape_bits), match_onehot_to_index(colour_bits))
        for shape_bits, colour_bits in grid
    ]


def generate_vector_prompt(analogies:torch.Tensor, analogy_template):
    """Build one prompt per analogy from the raw numeric vectors.

    Args:
        analogies: Batch of flattened analogies (first dimension = batch).
        analogy_template: ``str.format`` template that receives the four
            parts of an analogy, each as a plain Python list, as positional
            arguments.

    Returns:
        A list of formatted prompt strings, one per analogy.
    """
    restored = restore_analogy_from_vector(analogies)
    return [analogy_template.format(*parts) for parts in restored.tolist()]


def generate_text_prompt(_analogies:torch.Tensor, analogy_template, image_template, shapes:dict, colours:dict):
    """Build one natural-language prompt per analogy.

    Every analogy is split into four images, every image is decoded into
    ``(shape, colour)`` index pairs, and these are rendered as
    ``"<colour> <shape>"`` phrases through the two templates.

    Args:
        _analogies: Batch of flattened analogies (first dimension = batch).
        analogy_template: ``str.format`` template receiving the four image
            descriptions as positional arguments.
        image_template: ``str.format`` template receiving the cell phrases
            of one image as positional arguments.
        shapes: Mapping from shape index to shape name.
        colours: Mapping from colour index to colour name.

    Returns:
        A list of prompt strings, one per analogy.
    """
    prompts = []
    for image_group in restore_analogy_from_vector(_analogies):
        # Decode all images of the analogy first, then verbalise them.
        decoded_images = [match_element(image) for image in image_group]
        descriptions = []
        for cells in decoded_images:
            phrases = [f"{colours[colour_id]} {shapes[shape_id]}" for shape_id, colour_id in cells]
            descriptions.append(image_template.format(*phrases))
        prompt = analogy_template.format(*descriptions)
        # Collapse the literal phrase "blank empty" to just "empty".
        prompts.append(prompt.replace("blank empty", "empty"))
    return prompts


def process_llm_result(result_path:str):
    """Score exported LLM predictions against the reference labels.

    Args:
        result_path: Path of a JSON file readable by ``pl.read_json`` that
            contains the columns ``is_valid`` (reference label) and
            ``prediction`` (text answer of the LLM).

    Returns:
        A tuple ``(prfs, acc)``: the result of
        ``precision_recall_fscore_support`` for the labels ``[0, 1]`` and
        the accuracy score.
    """
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support

    results = pl.read_json(result_path)
    reference = results.select("is_valid").to_series()
    answers = results.select("prediction").to_series()
    # Only a case-insensitive "yes" counts as the positive class.
    predicted = [int(answer.lower() == "yes") for answer in answers]
    scores = precision_recall_fscore_support(reference, predicted, labels=[0,1])
    accuracy = accuracy_score(reference, predicted)
    return scores, accuracy

if __name__ == "__main__":
    # Script entry point: build prompts from <data_dir>/test.json, ask the LLM
    # for a one-token answer per prompt and optionally export the answers.
    import os

    parser = argparse.ArgumentParser(description="Try zeroshot LLM classification approaches")
    # nargs="?" makes the positional optional; without it argparse always
    # requires the argument and the default would never be used.
    parser.add_argument("data_dir", type=str, nargs="?", default="data/", help="Analogy directory")
    parser.add_argument("--mode", type=str, choices=["text","vector"], default="text", help="Prompt mode")
    parser.add_argument("--n_samples", type=int, default=-1, help="Choose the number of samples (negative = all)")
    parser.add_argument("--prompts", type=str, default="4,2,3", help="Choose the prompt for messages/descriptions")
    parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct", help="llm")
    parser.add_argument("--export", type=int, default=1, help="Choose whether or not to export the model")
    parser.add_argument("--export_path", type=str, default="results/llm/cls/zeroshot/", help="Export model path")

    args = parser.parse_args()

    arg_data_dir        = refine_directory_path(args.data_dir)
    arg_mode            = args.mode
    arg_n_samples       = args.n_samples
    arg_prompts         = args.prompts
    arg_model           = args.model
    arg_export          = args.export
    arg_export_path     = args.export_path

    # A negative sample count means "use everything". Slicing with the raw
    # value (e.g. [:-1]) would silently drop samples from the end instead.
    sample_limit = arg_n_samples if arg_n_samples >= 0 else None

    # NOTE(review): the f-string below relies on refine_directory_path
    # returning a path that ends with a separator -- confirm in utilities.
    input_path  = f"{arg_data_dir}test.json"
    df          = pl.read_json(input_path)
    analogies   = torch.stack([item.to_torch() for item in df.select("analogy").to_series()])
    # --prompts holds three comma-separated indices:
    # instruction, image template, analogy template.
    prompt_inst, prompt_img, prompt_alg = [int(item) for item in arg_prompts.split(",")]

    if arg_mode == "vector":
        prompts = generate_vector_prompt(analogies[:sample_limit],
                                         cp.ANALOGY_TEMPLATES[prompt_alg])
    else:
        prompts = generate_text_prompt(analogies[:sample_limit],
                                       cp.ANALOGY_TEMPLATES[prompt_alg],
                                       cp.IMAGE_TEMPLATES[prompt_img],
                                       shapes=cp.SHAPES,
                                       colours=cp.COLOURS)

    instruction = cp.INSTRUCTIONS[prompt_inst]
    pipeline    = transformers.pipeline("text-generation",
                                        model=arg_model,
                                        model_kwargs={"torch_dtype": torch.bfloat16},
                                        device_map="auto")

    responses   = []
    for question in tqdm(prompts, desc="LLM Answering"):
        messages    = [{"role": "system","content": instruction},
                       {"role": "user","content": question}]
        # Only a single new token is requested per prompt.
        llm_out     = pipeline(messages, max_new_tokens=1, pad_token_id=pipeline.tokenizer.eos_token_id)
        # The last chat message of the generated conversation is the answer.
        responses.append(llm_out[0]["generated_text"][-1]["content"])

    # Keep the same rows that were prompted so the predictions line up.
    df_response = df[:sample_limit].with_columns(pl.Series(responses).alias("prediction"))
    if arg_export:
        filepath    = check_path(arg_export_path)
        filename    = arg_prompts.replace(",", "")
        # os.path.join works whether or not filepath ends with a separator.
        df_response.write_json(os.path.join(filepath, f"{filename}.json"))
    
    
