

import os
from copy import deepcopy
from typing import Union

import fire
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from embedding_utils import (
    load_cisr_model,
    load_sd_model,
    load_luar_model_and_tokenizer,
    load_sbert_model,
    get_instance_embeddings,
    get_author_embeddings,
)

def flatten(lst):
    """Concatenate the items of every sub-iterable in ``lst`` into one list.

    Only a single level of nesting is removed; the order of the items is
    preserved.
    """
    flat = []
    for sublist in lst:
        flat.extend(sublist)
    return flat

def flatten_transfer_text(
    df: pd.DataFrame
) -> tuple[list[str], np.ndarray, list[int]]:
    """Flatten the nested ``transfer_text`` column into one list of strings.

    Each cell of ``df["transfer_text"]`` is expected to be a list of groups,
    where every group is itself a list of strings (``list[list[str]]``).

    Returns:
        A 3-tuple ``(transfer_text, lengths, lengths_individual)``:

        * ``transfer_text``: every string, in row order, then group order.
        * ``lengths``: cumulative number of groups per row, starting at 0, so
          row ``k`` owns groups ``lengths[k]:lengths[k + 1]``.
        * ``lengths_individual``: number of strings in each group, in the same
          order as the flattened groups.
    """
    rows = df["transfer_text"].tolist()
    # Leading 0 makes consecutive pairs usable directly as slice bounds.
    lengths = np.cumsum([0] + [len(row) for row in rows])
    groups = [group for row in rows for group in row]
    lengths_individual = [len(group) for group in groups]
    transfer_text = [text for group in groups for text in group]
    return transfer_text, lengths, lengths_individual

def get_reference_embeddings(
    df: pd.DataFrame,
    model_kwargs: dict,
    model_name: str,
):
    """Embed every entry of the ``reference`` column and stack the results.

    ``get_author_embeddings`` is called once per row of ``df["reference"]``
    (with a tqdm progress bar) and the per-row outputs are concatenated along
    dimension 0.
    """
    references = df["reference"].tolist()
    per_author = []
    for reference in tqdm(references):
        embedding = get_author_embeddings(reference, model_kwargs, model_name)
        per_author.append(embedding)
    return torch.cat(per_author, dim=0)

def get_transfer_text_embeddings(
    transfer_text: list[str],
    lengths: np.ndarray,
    lengths_individual: list[int],
    model_kwargs: dict,
    model_name: str,
) -> list[tuple[torch.Tensor, ...]]:
    """Embed the flattened transfer texts and restore their nested grouping.

    Args:
        transfer_text: Flat list of strings to embed.
        lengths: Cumulative group counts per row (starting at 0); consecutive
            pairs are used as slice bounds over the groups.
        lengths_individual: Number of strings in each group, used as the
            split sizes along dimension 0 of the embedding tensor.
        model_kwargs: Passed through to ``get_instance_embeddings``.
        model_name: Passed through to ``get_instance_embeddings``.

    Returns:
        One entry per row. Each entry is a tuple holding one tensor per group
        of that row (a slice of the tuple returned by ``torch.split``).
    """
    flat_embeddings = get_instance_embeddings(
        transfer_text,
        model_kwargs,
        model_name,
    )
    # One tensor per group; chunk k has lengths_individual[k] rows.
    per_group = torch.split(flat_embeddings, lengths_individual)
    # Regroup the per-group tensors by the row they came from.
    return [per_group[start:stop] for start, stop in zip(lengths[:-1], lengths[1:])]

def pick_best_style(
    df: pd.DataFrame,
    luar,
    luar_tok,
    style_model_name: str = "mud",
    generation_counts: tuple[int, ...] = (1, 10, 25, 50, 100),
):
    """Pick the generation that maximizes the Style Similarity (LUAR-MUD).

    For every row and every group of candidate generations in
    ``df["transfer_text"]``, the candidate whose embedding has the highest
    cosine similarity to the row's reference (author) embedding is selected.
    The selection is repeated once per value in ``generation_counts``, each
    time considering only the first N candidates of every group.

    Args:
        df: Must contain the ``reference`` and ``transfer_text`` columns.
            NOTE: a ``transfer_pick`` column is written to ``df`` in place;
            after the call it holds the picks for the last generation count.
        luar: Style model, forwarded to the embedding helpers.
        luar_tok: Tokenizer for ``luar``, forwarded to the embedding helpers.
        style_model_name: Model name forwarded to the embedding helpers.
        generation_counts: Candidate-pool sizes to evaluate.

    Returns:
        Dict mapping ``"best_style_N=<N>"`` to a copy of ``df`` whose
        ``transfer_pick`` column holds the picks for that N.
    """
    author_embeddings = get_reference_embeddings(
        df,
        {"luar": luar, "luar_tok": luar_tok},
        style_model_name,
    )

    transfer_text, lengths, lengths_individual = flatten_transfer_text(df)
    transfer_embeddings = get_transfer_text_embeddings(
        transfer_text,
        lengths,
        lengths_individual,
        {"luar": luar, "luar_tok": luar_tok, "progress_bar": True},
        style_model_name,
    )

    # Pull the nested candidate strings out once instead of building a row
    # Series with df.iloc[i] inside the innermost loop.
    candidates = df["transfer_text"].tolist()

    to_return = {}
    cossim = torch.nn.CosineSimilarity(dim=-1)
    for num_generations in generation_counts:
        best_style = [[] for _ in range(len(df))]
        for i, (author, transfer) in enumerate(zip(author_embeddings, transfer_embeddings)):
            # Add a leading dim so the author vector broadcasts against the
            # stack of candidate embeddings.
            author_row = author.unsqueeze(0)
            for j, t in enumerate(transfer):
                cossim_scores = cossim(author_row, t[:num_generations])
                best_index = int(torch.argmax(cossim_scores))
                best_style[i].append(candidates[i][j][best_index])
        df["transfer_pick"] = best_style
        # Snapshot: the column is overwritten on the next iteration.
        to_return[f"best_style_N={num_generations}"] = deepcopy(df)

    return to_return

def pick_average(
    df: pd.DataFrame,
    luar,
    luar_tok,
    sbert,
    style_weight: Union[float, None] = None,
    content_weight: Union[float, None] = None,
    style_model_name: str = "mud",
    generation_counts: tuple[int, ...] = (1, 10, 25, 50, 100),
):
    """Pick the generation maximizing a weighted style + content similarity.

    For every group of candidates the score of a candidate is
    ``stylew * cos(author, candidate_style) + contentw * cos(content, candidate_content)``
    where the style embeddings come from the LUAR helpers and the content
    embeddings from the ``"sbert"`` helper. The candidate with the highest
    score among the first N of the group is selected.

    Args:
        df: Must contain ``reference``, ``transfer_text`` and ``content_text``
            columns. Assumes ``content_text`` has one entry per group of
            ``transfer_text`` in every row (the same ``lengths`` are used to
            slice both) -- TODO confirm against the data producer.
            NOTE: a ``transfer_pick`` column is written to ``df`` in place.
        luar: Style model, forwarded to the embedding helpers.
        luar_tok: Tokenizer for ``luar``, forwarded to the embedding helpers.
        sbert: Content model, forwarded to the embedding helpers.
        style_weight: If given (including ``0``), the only style weight used;
            otherwise the grid ``[0.0, 1.0, 2.0]`` is swept.
        content_weight: Same as ``style_weight`` but for the content term.
        style_model_name: Model name forwarded to the style embedding helpers.
        generation_counts: Candidate-pool sizes to evaluate.

    Returns:
        Dict mapping ``"average_style=<sw>_content=<cw>_N=<N>"`` to a copy of
        ``df`` whose ``transfer_pick`` column holds the picks for that setting.
    """
    luar_author_embeddings = get_reference_embeddings(
        df,
        {"luar": luar, "luar_tok": luar_tok},
        style_model_name,
    )

    transfer_text, lengths, lengths_individual = flatten_transfer_text(df)
    luar_transfer_embeddings = get_transfer_text_embeddings(
        transfer_text,
        lengths,
        lengths_individual,
        {"luar": luar, "luar_tok": luar_tok, "progress_bar": True},
        style_model_name,
    )

    sbert_transfer_embeddings = get_transfer_text_embeddings(
        transfer_text,
        lengths,
        lengths_individual,
        {"model": sbert, "progress_bar": True},
        "sbert",
    )

    content_text = [text for row in df["content_text"].tolist() for text in row]
    sbert_content_embeddings = get_instance_embeddings(
        content_text,
        {"model": sbert, "progress_bar": True},
        "sbert"
    )
    # Regroup the flat content embeddings by row.
    sbert_content_embeddings = [
        sbert_content_embeddings[start:stop]
        for start, stop in zip(lengths[:-1], lengths[1:])
    ]

    # Compare against None explicitly: a weight of 0 is a legitimate request
    # and must not fall back to the full grid.
    if style_weight is not None:
        style_weights = [float(style_weight)]
    else:
        style_weights = [0.0, 1.0, 2.0]
    if content_weight is not None:
        content_weights = [float(content_weight)]
    else:
        content_weights = [0.0, 1.0, 2.0]

    # Pull the nested candidate strings out once instead of building a row
    # Series with df.iloc[i] inside the innermost loop.
    candidates = df["transfer_text"].tolist()

    to_return = {}
    cossim = torch.nn.CosineSimilarity(dim=-1)
    for num_generations in generation_counts:
        for stylew in style_weights:
            for contentw in content_weights:
                best = [[] for _ in range(len(df))]
                per_row = zip(
                    luar_author_embeddings,
                    luar_transfer_embeddings,
                    sbert_transfer_embeddings,
                    sbert_content_embeddings,
                )
                for i, (author, transfer_luar, transfer_sbert, content_sbert) in enumerate(per_row):
                    author_row = author.unsqueeze(0)
                    for j, tluar in enumerate(transfer_luar):
                        cossim_luar = stylew * cossim(author_row, tluar[:num_generations])
                        # content_sbert[j:j+1] keeps a leading dim so it
                        # broadcasts against the candidate stack.
                        cossim_sbert = contentw * cossim(
                            content_sbert[j:j+1],
                            transfer_sbert[j][:num_generations],
                        )
                        best_index = int(torch.argmax(cossim_luar + cossim_sbert))
                        best[i].append(candidates[i][j][best_index])

                df["transfer_pick"] = best
                # Snapshot: the column is overwritten on the next iteration.
                to_return[f"average_style={stylew}_content={contentw}_N={num_generations}"] = deepcopy(df)

    return to_return
    
def main(
    data_path: str = "./outputs/stylemc_new_alternate_paraphrase_Mistral-7B-Instruct-v0.3_N=5_transfer_N=100_temp=0.7_top-p=0.9_model=MUD_paraphrase_Mistral-7B-Instruct-v0.3_N=5_small=False_8-3.jsonl",
    selection_strategy: str = "average",
    reference_key: str = "reference_text",
    style_weight: float = None,
    content_weight: float = None,
):
    """Select one generation per source item and write the results to disk.

    Reads a JSON-lines file, runs the chosen selection strategy and writes one
    JSON-lines file per returned configuration into
    ``./outputs/transfer_pick``, named ``<input stem>_<config key>.jsonl``.

    Args:
        data_path: JSON-lines input file.
        selection_strategy: ``"best_style"`` or ``"average"``.
        reference_key: Column of the input that is renamed to ``reference``.
        style_weight: Forwarded to ``pick_average``.
        content_weight: Forwarded to ``pick_average``.

    Returns:
        ``0`` on success.

    Raises:
        ValueError: If ``selection_strategy`` is not a known strategy.
    """
    # Fail fast: reject a bad strategy before paying for data and model loading.
    if selection_strategy not in ("best_style", "average"):
        raise ValueError("Invalid selection strategy")

    df = pd.read_json(data_path, lines=True)
    df.rename(columns={reference_key:"reference"}, inplace=True)

    luar, luar_tok = load_luar_model_and_tokenizer("rrivera1849/LUAR-MUD")
    luar.to("cuda")

    if selection_strategy == "best_style":
        name_to_df = pick_best_style(df, luar, luar_tok)
    else:
        # Only the averaging strategy needs the content model, so load it here.
        sbert = load_sbert_model()
        sbert.to("cuda")
        name_to_df = pick_average(df, luar, luar_tok, sbert, style_weight, content_weight)

    # Save Transfer Pick:
    output_dir = "./outputs/transfer_pick"
    os.makedirs(output_dir, exist_ok=True)
    filename = os.path.basename(data_path)
    filename = filename.replace(".jsonl", "")
    for key, value in name_to_df.items():
        current_filename = filename + "_" + key + ".jsonl"
        savepath = os.path.join(output_dir, current_filename)
        value.to_json(savepath, orient="records", lines=True)

    return 0

# When run as a script, hand main() to python-fire for invocation.
if __name__ == "__main__":
    fire.Fire(main)