import pandas as pd
import os
import json
import inspect
import re
import datetime
import functools
import hashlib

class TaskResult:
    """Handle for one evaluation run stored as a ``<name>.json`` / ``<name>.csv`` pair.

    The JSON metadata is read when the object is created. The CSV that sits
    next to it is only read on the first call to :meth:`data` and is cached
    afterwards.
    """

    def __init__(self, path):
        """Load the metadata file at *path*.

        The task name is the part of the file name before the first ``-``;
        it is also written into ``self.metadata`` under the ``"task"`` key.
        """
        # basename() rather than split("/") so paths built with os.path.join
        # are handled regardless of the platform's separator.
        self.task = os.path.basename(path).split("-")[0]
        self.path = path
        with open(path, "r", encoding="utf-8") as f:
            self.metadata = json.load(f)
        self.metadata.update({"task": self.task})
        self._data = None

    def data(self):
        """Return the run's CSV as a DataFrame (read once, then cached)."""
        if self._data is None:
            # Only swap the trailing extension: a blanket str.replace would
            # also rewrite ".json" occurring in a directory name.
            suffix = ".json"
            if self.path.endswith(suffix):
                csv_path = self.path[:-len(suffix)] + ".csv"
            else:
                csv_path = self.path
            self._data = pd.read_csv(csv_path)
        return self._data

    def __repr__(self):
        return "<TaskResult>"

def kwstring(kw):
    """Return a canonical string form of a keyword-argument mapping.

    A plain string is passed through unchanged. Anything else is expected to
    offer ``.items()`` and is rendered as ``{'key': value,...}`` with the
    entries sorted, so equal mappings give equal strings whatever their
    insertion order.
    """
    if type(kw) is str:
        return kw
    entries = [f"'{name}': {value}" for name, value in kw.items()]
    entries.sort()
    return "{" + ",".join(entries) + "}"

# Columns kept by results() when the caller does not pass `selected_cols`.
DEFAULT_COLS = ["task", "model", "decoder", "kwargs", "num_samples", "cost", "timestamp", "result", "hash"]

def results(selected_cols=None, additional=None, finished=None, **kwargs):
    """List the evaluation runs found in the ``results`` directory as a DataFrame.

    Every ``results/*.json`` file that has a ``.csv`` next to it becomes one
    row. The columns are the keys of the JSON metadata plus ``task``,
    ``result`` (the TaskResult object) and ``hash`` (first 8 hex digits of the
    SHA-256 of the timestamp string).

    Parameters:
        selected_cols: columns to keep in the returned frame; defaults to
            DEFAULT_COLS. Columns that do not exist are ignored. The list
            passed in is not modified.
        additional: extra columns to include. Each entry is either the name
            of an existing column or a callable that takes the frame and
            returns the column values; the column is named after the
            callable's ``__name__``.
        finished: when truthy, drop rows whose ``cost`` contains ``"-1"``.
        **kwargs: filters and options.
            latest: when truthy, keep only the row with the highest timestamp
                among rows that agree on every column except ``timestamp``,
                ``data``, ``result``, ``cost`` and ``hash``.
            kwargs: mapping (or canonical string) matched against the
                canonical string form of the stored ``kwargs`` column.
            maximum_age: ``int`` (minutes before now) or ``datetime``; older
                rows are dropped.
            <column>=value: a callable is used as a predicate on the cell, a
                string is tested with ``in``, anything else with ``==``.
                Filters are applied before and again after the ``additional``
                columns are computed, so they may refer to those columns.

    Returns:
        The filtered DataFrame restricted to the selected columns. When no
        run matches, an empty frame is returned.

    Raises:
        TypeError: ``maximum_age`` is neither an ``int`` nor a ``datetime``.
        KeyError: a string in ``additional`` names a column that does not exist.
    """
    # Work on copies so that neither DEFAULT_COLS nor a caller's list is mutated.
    selected_cols = list(DEFAULT_COLS if selected_cols is None else selected_cols)
    additional = [] if additional is None else list(additional)

    latest = kwargs.pop("latest", False)

    # if querying by kwargs, convert them to a canonical string representation
    eval_kwargs = kwstring(kwargs.pop("kwargs", {}))
    if eval_kwargs != "{}":
        kwargs["kwargs"] = eval_kwargs

    # only metadata files whose .csv data file exists as well
    metadata_files = [
        r for r in os.listdir("results")
        if r.endswith(".json") and os.path.exists(os.path.join("results", r.replace(".json", ".csv")))
    ]

    task_results = [TaskResult(os.path.join("results", r)) for r in metadata_files]
    df = pd.DataFrame([{**r.metadata, "result": r} for r in task_results])

    # No runs on disk: return before touching columns that do not exist yet.
    if len(df) == 0:
        return df

    df["hash"] = df["timestamp"].apply(lambda t: hashlib.sha256(t.encode("utf-8")).hexdigest()[0:8])

    # Canonicalise stored kwargs *before* filtering: the query value is a
    # canonical string and would never match the raw dicts from the JSON files.
    if "kwargs" in df.columns:
        df["kwargs"] = df["kwargs"].apply(kwstring)

    # make sure task is first column
    cols = df.columns.tolist()
    if "task" in cols:
        cols = ["task"] + [c for c in cols if c != "task"]
    df = df[cols]

    def is_included_by(cv, value):
        if callable(value):
            return value(cv)
        if type(value) is str:
            return value in cv
        return value == cv

    def apply_filters(frame, filters):
        for k, v in filters.items():
            if k in frame.columns:
                # astype(bool) keeps this a row mask even when the frame is empty
                mask = frame[k].apply(functools.partial(is_included_by, value=v)).astype(bool)
                frame = frame[mask]
        return frame

    df = apply_filters(df, kwargs)
    if len(df) == 0:
        return df
    # detach from the unfiltered frame so the column assignments below are safe
    df = df.copy()

    if latest:
        # "hash" is derived from the timestamp; grouping on it would put every
        # run into its own group and nothing would ever be dropped.
        ignored = {"timestamp", "data", "result", "cost", "hash"}
        group_cols = [c for c in df.columns if c not in ignored]
        # Timestamps look like 2023:01:19_12:42:05, so string order is
        # chronological order; keep the last row of each group.
        df = (
            df.sort_values("timestamp", kind="stable")
            .drop_duplicates(subset=group_cols, keep="last")
            .sort_index()
            .reset_index(drop=True)
        )

    # add additional column values
    for col_computer in additional:
        if isinstance(col_computer, str):
            if col_computer not in df.columns:
                raise KeyError(f"additional column '{col_computer}' does not exist")
            col_name = col_computer
        else:
            col_name = col_computer.__name__
            df[col_name] = col_computer(df)
        if col_name not in selected_cols:
            selected_cols.append(col_name)

    if "maximum_age" in kwargs:
        age = kwargs["maximum_age"]
        if isinstance(age, datetime.datetime):
            oldest_time = age
        elif isinstance(age, int):
            oldest_time = datetime.datetime.now() - datetime.timedelta(minutes=age)
        else:
            raise TypeError("maximum_age must be an int (minutes) or a datetime.datetime")
        recent = df["timestamp"].apply(
            lambda t: datetime.datetime.strptime(t, "%Y:%m:%d_%H:%M:%S") > oldest_time
        ).astype(bool)
        df = df[recent]

    # filter again so that filters on the additional columns take effect
    df = apply_filters(df, {k: v for k, v in kwargs.items() if k != "maximum_age"})

    # Evaluate `finished` on the full frame so it works even when "cost" is
    # not among the selected columns.
    if finished and "cost" in df.columns:
        df = df[df["cost"].apply(lambda c: "-1" not in str(c)).astype(bool)]

    # only keep the columns we want
    return df[[c for c in selected_cols if c in df.columns]]

def data(df, fct=None, nounwrap=False):
    """Return the per-sample data for the runs of a results() call.

    Parameters:
        df: a DataFrame with a ``result`` column, a single row of such a
            frame (a Series), or a string that is passed to
            ``results(hash=...)`` to look the runs up.
        fct: optional callable applied to each run's DataFrame.
        nounwrap: when false (the default) and there is exactly one run, its
            data is returned directly instead of a one-element list.

    Returns:
        A list with one entry per run, or the single entry itself (see
        ``nounwrap``).

    Raises:
        Exception: the input has no ``result`` column.

    Note:
        The ``task_name`` column is rewritten in place on the frames returned
        by ``result.data()``: values containing ``EvalSampleTask`` are
        reduced to the name found inside ``task_name='...'``.
    """
    if isinstance(df, str):
        df = results(hash=df)

    if isinstance(df, pd.Series):
        df = df.to_frame().T

    if "result" not in df.columns:
        raise Exception(
            "data() can only be called on a dataframe with a 'result' column "
            f"(got columns: {list(df.columns)})."
        )

    def clean_task_name(t):
        if "EvalSampleTask" not in t:
            return t
        try:
            return t.split("task_name='")[1].split("'")[0]
        except (IndexError, AttributeError):
            return "<could not parse from evaluation file>"

    r = [result.data() for result in df["result"].tolist()]
    for rdf in r:
        rdf["task_name"] = [clean_task_name(t) for t in rdf["task_name"].values]

    if fct is not None:
        r = [fct(d) for d in r]

    if len(r) == 1 and not nounwrap:
        return r[0]
    return r

def path(df):
    """Return the ``path`` attribute of every entry in the frame's ``result`` column."""
    paths = []
    for task_result in df["result"].tolist():
        paths.append(task_result.path)
    return paths

def errors(df):
    """Count, for each run, the samples whose prediction or model result is invalid.

    A sample counts as invalid when its prediction (as a string) is
    ``<error>`` or ``<none>`` or contains ``not parse``, or when its
    ``model_result`` equals ``<none>`` or ``<error>``.
    """
    failure_markers = ("<error>", "<none>")

    def invalid_prediction(p):
        predicted = str(p.prediction)
        if predicted in failure_markers or "not parse" in predicted:
            return True
        return p.model_result == "<none>" or p.model_result == "<error>"

    counts = []
    for d in data(df, nounwrap=True):
        counts.append(sum(invalid_prediction(d.iloc[i]) for i in range(len(d))))
    return counts

def compare(h1, h2):
    """Yield the sample pairs that the first run got wrong and the second got right.

    Both runs are sorted by ``query_file`` and walked in lockstep; iteration
    stops at the end of the shorter one. An AssertionError is raised when the
    base names of the paired ``query_file`` values differ.
    """
    first = data(h1).sort_values(by=["query_file"])
    second = data(h2).sort_values(by=["query_file"])

    for position in range(min(len(first), len(second))):
        r1 = first.iloc[position]
        r2 = second.iloc[position]
        assert os.path.basename(r1.query_file) == os.path.basename(r2.query_file), f"Failed to match {r1.query_file} and {r2.query_file}"

        first_wrong = r1.prediction != r1.target
        if first_wrong and r2.prediction == r2.target:
            yield r1, r2
        

def correct(df, t=str):
    """Return per-sample correctness flags, one column per run.

    Parameters:
        df: frame from a results() call; needs the ``task`` and ``result``
            columns.
        t: converter applied to both ``target`` and ``prediction`` before
            they are compared (``str`` by default).

    Returns:
        A DataFrame built with ``pd.concat(..., axis=1)`` whose columns are
        keyed by the task names and whose cells are ``target == prediction``.

    NOTE(review): the runs are sorted by ``query_file`` but concat aligns on
    the row index, not on ``query_file`` -- confirm this is intended.
    """
    task_names = df["task"].values
    result = []
    for d in data(df, nounwrap=True):
        d = d.sort_values(by=["query_file"])
        result.append(d["target"].apply(t) == d["prediction"].apply(t))
    # concatenate all results along a new axis
    return pd.concat(result, axis=1, keys=task_names)
    

def pbool(s):
    """Like ``bool()``, but parse textual truth values fuzzily.

    * ``bool`` values are returned unchanged.
    * ``"True"``, ``"true"``, ``"1"`` and ``1`` give True; ``"False"``,
      ``"false"``, ``"0"`` and ``0`` give False.
    * If ``s`` contains both ``"True"`` and ``"False"``, whichever occurs
      first wins; if it contains only one of them, that one wins.
    * Anything else, including values that do not support ``in`` (numbers,
      ``None``), falls back to ``bool(s)``.
    """
    if isinstance(s, bool):
        return s
    if s == "True" or s == "true" or s == "1" or s == 1:
        return True
    if s == "False" or s == "false" or s == "0" or s == 0:
        return False
    try:
        has_true = "True" in s
        has_false = "False" in s
    except TypeError:
        # not a string/container (e.g. an int, None, or a NaN float from pandas)
        return bool(s)
    # if True occurs before False, return True
    if has_true and has_false:
        return s.index("True") < s.index("False")
    if has_true:
        return True
    if has_false:
        return False
    return bool(s)

def accuracy(df, t=str):
    """Return, for each run of a results() call, the share of samples whose
    ``target`` equals ``prediction`` after both are converted with ``t``."""
    scores = []
    for d in data(df, nounwrap=True):
        hits = d["target"].apply(t) == d["prediction"].apply(t)
        scores.append(hits.mean())
    return scores

def num_samples(df):
    """Return the number of sample rows of each run of a results() call."""
    counts = []
    for d in data(df, nounwrap=True):
        counts.append(len(d))
    return counts

def shots(df):
    """Return, per row of ``df``, the first value of the ``shots`` column of that run's data."""
    def first_shots_value(row):
        run_data = data(row)
        return run_data["shots"].iloc[0]

    return df.apply(first_shots_value, axis=1)

def evaluation_time(df):
    """Return how long ago each run was recorded, e.g. ``"1w 2d 3h 4m 5s ago"``.

    Parameters:
        df: frame with a ``timestamp`` column in the format
            ``2023:01:19_12:42:05``.

    Returns:
        A Series of strings. Units whose value is 0 are left out; a span of
        less than one second is reported as ``"0s ago"``.
    """
    def format_timespan_since(t):
        elapsed = datetime.datetime.now() - t
        weeks, days = divmod(elapsed.days, 7)
        total_minutes, seconds = divmod(elapsed.seconds, 60)
        hours, minutes = divmod(total_minutes, 60)
        units = ((weeks, "w"), (days, "d"), (hours, "h"), (minutes, "m"), (seconds, "s"))
        parts = [f"{amount}{unit}" for amount, unit in units if amount > 0]
        if not parts:
            parts = ["0s"]
        # join() avoids the doubled space the old "<unit> " concatenation
        # produced whenever the seconds part was 0
        return " ".join(parts) + " ago"

    # timestamp format is 2023:01:19_12:42:05
    parsed = df["timestamp"].apply(lambda t: datetime.datetime.strptime(t, "%Y:%m:%d_%H:%M:%S"))
    return parsed.apply(format_timespan_since)

def filename(df):
    """Return the metadata file path of every run in the frame's ``result`` column."""
    collected = []
    for task_result in df["result"].tolist():
        collected.append(task_result.path)
    return collected

def no_response(df):
    """Return, for each run of a results() call, the fraction of samples with an empty prediction.

    A prediction counts as empty when it is the empty string or missing
    (NaN). ``pandas.read_csv`` turns empty fields into NaN by default, so
    comparing against ``""`` alone does not catch empty predictions in frames
    loaded from CSV.

    NOTE(review): with the default NA handling, predictions spelled like
    ``NA`` or ``null`` are read as NaN too and are counted here -- confirm
    this is acceptable for the tasks at hand.
    """
    rates = []
    for d in data(df, nounwrap=True):
        predictions = d["prediction"]
        empty = predictions.isna() | (predictions == "")
        rates.append(empty.mean())
    return rates

def clean_up_dangeling_json(results_dir="results/"):
    """Delete metadata files in ``results_dir`` that have no ``.csv`` next to them.

    A ``.json`` file is removed only when the timestamp in its name
    (``<name>-2023:01:19_12:05:40.json``) is more than one minute old. Files
    whose name carries no parsable timestamp are left alone. A line is
    printed for every file considered.

    Parameters:
        results_dir: directory to sweep, with or without a trailing separator.
    """
    to_remove = [
        f for f in os.listdir(results_dir)
        if f.endswith(".json")
        and not os.path.exists(os.path.join(results_dir, f.replace(".json", ".csv")))
    ]

    for f in to_remove:
        # parse timestamp from "<filename>-2023:01:19_12:05:40.json"
        timestamp = f.split("-")[-1].split(".")[0]
        try:
            file_time = datetime.datetime.strptime(timestamp, "%Y:%m:%d_%H:%M:%S")
        except ValueError:
            # not one of our result files; never delete what we cannot date
            print("Not removing", f, "because its name carries no parsable timestamp")
            continue
        # only remove files older than 1 minute
        threshold = datetime.datetime.now() - datetime.timedelta(minutes=1)
        if file_time < threshold:
            print("Removing", f)
            os.remove(os.path.join(results_dir, f))
        else:
            print("Not removing", f, "because it is too recent", file_time, "vs", threshold)

def prompt_snippet(df, index=0):
    """Print a LaTeX snippet for one sample of each run of a results() call.

    For every run the samples are sorted by ``query_file`` and the sample at
    position ``index`` is printed as: escaped task name, ``&``, prompting
    method, ``&``, then the query (text before ``FROM``), the condition (text
    after ``WHERE``, if any) and the model response, followed by ``\\cmark``
    or ``\\xmark`` depending on whether target and prediction are equal.

    Parameters:
        df: frame from a results() call (anything data() accepts).
        index: position of the sample to print within each sorted run.
    """
    listing_open = "\\begin{lstlisting}[mathescape=false, breaklines=true, breakindent=0em]"
    listing_close = "\\end{lstlisting}"
    query_template = "\n\\textbf{Query}\n" + listing_open + "\n<query>\n" + listing_close
    condition_template = "\\textbf{Condition}\n" + listing_open + "\n<condition>\n" + listing_close + "\n"
    model_response_template = "\\textbf{Model Response} \n<model response> <correct>\n"

    for d in data(df, nounwrap=True):
        s = d.sort_values(by=["query_file"]).iloc[index]

        q = s["query"].replace("\\", "\\\\")
        # drop everything up to and including the first ")"
        q = q.split(")", 1)[1].strip()
        q = "\n".join(l.strip() for l in q.split("\n") if l.strip() != "")
        q = q.replace("$", "\\$")

        # condition is "" when the query has no WHERE clause
        q, _, condition = q.partition("WHERE")
        q = q.split("FROM", 1)[0]

        prompting_method = "unknown"
        if "multivar" in s.task_name:
            prompting_method = "Multi-Variable"
        elif "cot" in s.task_name:
            prompting_method = "Chain-Of-Thought"
        elif "@ao" in s.task_name:
            prompting_method = "Answer-Only"
        prompting_method = f"\\textit{{{prompting_method}}}"

        # strip off _small, _medium, _large from the end of the task name
        # *before* escaping; doing it afterwards leaves a dangling backslash
        task_name = s.task_name.rsplit("_", 1)[0].replace("_", "\\_")

        print(task_name)
        print("&")
        print(prompting_method, end="&\n")

        # make template
        r = query_template
        if condition.strip() != "":
            r += condition_template
        r += model_response_template

        r = r.replace("<query>", q.strip())
        r = r.replace("<condition>", condition.strip())
        r = r.replace("<model response>", s.model_result.replace("$", "\\$"))
        r = r.replace("<correct>", "\\cmark" if s.target == s.prediction else "\\xmark")
        print(r)

def query(df, index=5):
    """Print and return one sample query for each run of a results() call.

    Parameters:
        df: frame from a results() call; needs the ``task`` column.
        index: position of the sample within each run after sorting by
            ``query_file``. Defaults to 5, the value that used to be
            hard-coded.

    Returns:
        The list of query strings, one per run.

    Raises:
        IndexError: a run has no sample at position ``index``.
    """
    res = [d.sort_values("query_file")["query"].values[index] for d in data(df, nounwrap=True)]
    for q, name in zip(res, df["task"]):
        print(name, "\n", q, "\n", sep="")
    return res

def model_result(df, index=5):
    """Print and return one sample model result for each run of a results() call.

    Parameters:
        df: frame from a results() call; needs the ``task`` column.
        index: position of the sample within each run after sorting by
            ``query_file``. Defaults to 5, the value that used to be
            hard-coded.

    Returns:
        The list of ``model_result`` values, one per run.

    Raises:
        IndexError: a run has no sample at position ``index``.
    """
    res = [d.sort_values("query_file")["model_result"].values[index] for d in data(df, nounwrap=True)]
    for q, name in zip(res, df["task"]):
        print(name, "\n", q, "\n", sep="")
    return res

def named(df):
    """Return a copy of ``df`` indexed by a readable name built from ``task`` and ``shots``.

    The first marker found in the task (``cot``, ``instruct``, ``ao``, in that
    order) selects the label; tasks with none of them are labelled
    "Control-Flow Guided".
    """
    labels_by_marker = (
        ("cot", "CoT"),
        ("instruct", "Instruct"),
        ("ao", "Answer-Only"),
    )

    def name(r):
        t, s = r["task"], r["shots"]
        for marker, label in labels_by_marker:
            if marker in t:
                return f"{label} ({s}-shot)"
        return f"Control-Flow Guided ({s}-shot)"

    labelled = df.copy()
    labelled["name"] = labelled.apply(name, axis=1)
    return labelled.set_index("name")

def table_all(hashes, accuracy_measure=accuracy):
    """Print one comma-separated line of accuracies per dataset, then the data files used.

    Parameters:
        hashes: mapping ``dataset name -> list of run hashes``. If the first
            element of a list is callable, it is used as the accuracy measure
            for that dataset instead of ``accuracy_measure``.
        accuracy_measure: default measure; called with the results frame and
            expected to return one value per row.

    The accuracy column is looked up under the measure's ``__name__`` because
    that is the name results() gives to a computed column.
    """
    def index_of_filename(f):
        # display order: answer-only, chain-of-thought, multivar, multivar "-var", rest
        if "@ao" in f: return 0
        if "@cot" in f: return 1
        if "@multivar" in f:
            return 3 if "-var" in f else 2
        return 4

    files = []

    for dataset, configurations in hashes.items():
        acc_measure = accuracy_measure
        if configurations and callable(configurations[0]):
            acc_measure = configurations[0]
            configurations = configurations[1:]
        accuracy_name = acc_measure.__name__

        df = results(hash=lambda h: h in configurations, additional=[acc_measure])

        result_files = [r.path.replace(".json", ".csv") for r in df["result"].values]
        result_files = sorted(result_files, key=index_of_filename)
        files += [dataset] + result_files

        df = df[["task", "decoder", accuracy_name, "hash"]]
        # sort by order in configurations
        df = df.sort_values(by=["hash"], key=lambda c: [configurations.index(h) for h in c])

        accuracies = ",".join([f"{a:.2f}".rjust(10) for a in df[accuracy_name].values])
        print(f"{dataset} , {accuracies}")

    print("\nFiles:")
    print("\n".join(files))

        

def table(hashes, accuracy_measure=accuracy):
    """Display an HTML table of accuracies for the given runs in IPython.

    Parameters:
        hashes: mapping ``display name -> hash`` or ``display name ->
            collection of hashes`` (list, tuple or set).
        accuracy_measure: callable computing the accuracy values from the
            results frame.

    The table is indexed by (name, decoder) and shows the ``accuracy`` and
    ``task`` columns; the highest accuracy is highlighted in light green.
    """
    from IPython.display import display, HTML

    names = {}
    for name, h in hashes.items():
        for single_hash in (h if isinstance(h, (list, tuple, set)) else [h]):
            names[single_hash] = name
    hash_values = set(names)

    # wrapper so the computed column is always called "accuracy",
    # whatever the measure's own name is
    def accuracy(samples):
        return accuracy_measure(samples)

    df = results(hash=lambda h: h in hash_values, additional=[accuracy])
    df["name"] = df["hash"].apply(lambda h: names[h])

    df = df[["name", "decoder", "accuracy", "task"]]
    # make name a common row
    df = df.set_index(["name", "decoder"])
    # sort by decoder within name
    df = df.sort_index(level=1)
    # sort by name
    df = df.sort_index(level=0)

    # highlight the best value of the "accuracy" column only; without `subset`
    # the largest task string would be highlighted as well
    styled = df.style.apply(
        lambda s: ["background-color: lightgreen" if v == s.max() else "" for v in s],
        axis=0,
        subset=["accuracy"],
    )

    md = styled.to_html()
    # embed in HTML table and display in IPython
    display(HTML(md))

def qa(*args):
    """Print query, prediction/target and model result of every sample of one run.

    The arguments are passed straight to data(). The query is printed in
    green and the model result in red.
    """
    import termcolor

    samples = data(*args).iloc
    for s in samples:
        query_text = termcolor.colored(s["query"], "green")
        print(query_text)
        print("PREDICTION", s.prediction, "TARGET", s.target)
        response_text = termcolor.colored(s.model_result, "red")
        print(response_text, end="\n\n")