import sys
import os
import traceback
import datetime
import tempfile
import asyncio
from typing import Optional

from multiprocessing import Pool
from tqdm import tqdm

# Project root is the parent of the directory that contains this file
# (left un-normalised, i.e. the path literally ends in "..").
PROJECT_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
# Import-time side effect: the process working directory is switched to the
# project root, so relative paths used below (e.g. "evaluation/results/...")
# are resolved from there.
os.chdir(PROJECT_ROOT)
print("Running in", os.getcwd())
# Put the lmql sources under the project root on the import path before the
# `import lmql` statements that follow.
sys.path.append(f"{PROJECT_ROOT}/lmql/src/")

import lmql
from lmql.runtime.dclib.dclib_model import DcModel
from lmql.runtime.model_registry import LMQLModelRegistry
import lmql.runtime.bopenai as bopenai

import lmql.runtime.dclib as dc
import asyncio
import subprocess

import json
import pandas as pd
import os

from dataclasses import dataclass

def strip_indent(text):
    """Remove the indentation shared by all non-blank lines of ``text``.

    The common indentation is the smallest number of leading whitespace
    characters found on any line that contains non-whitespace content
    (a line without indentation therefore yields a common indent of zero).
    That many characters are cut from the start of every such line.
    Blank and whitespace-only lines are emitted as empty lines.

    Args:
        text: The (possibly multi-line) string to dedent.

    Returns:
        The dedented lines joined with ``"\\n"``. An empty input, or input
        without any indented line, is returned without raising.
    """
    lines = text.splitlines()
    # Only lines with real content take part in the computation; blank
    # lines must not force the common indent down to zero.
    indent_widths = [len(line) - len(line.lstrip()) for line in lines if line.strip()]
    # default=0 covers empty input and input made of blank lines only.
    common_width = min(indent_widths, default=0)
    return "\n".join(line[common_width:] if line.strip() else "" for line in lines)

from lmql.runtime.output_writer import PrintingDebuggerOutputWriter

class EvaluationOutputWriter(PrintingDebuggerOutputWriter):
    """Output writer used during evaluation runs.

    It records the most recently reported model statistics and counts how
    often a decoder state was added, instead of forwarding those events.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Both flags are switched off for evaluation; presumably they control
        # screen clearing and printing in the base writer (confirm there).
        self.clear, self.print_output = False, False

        # Keyword arguments of the latest report_model_stats() call.
        self.model_stats = self.accumulated_model_stats = None

        # Number of add_decoder_state() calls seen so far.
        self.steps = 0

    def report_model_stats(self, **kwargs):
        """Remember the reported statistics, replacing any earlier report."""
        self.model_stats = kwargs

    def add_decoder_state(self, *args, **kwargs):
        """Count the call; the passed decoder state itself is discarded."""
        self.steps = self.steps + 1

async def run(lmql_file, writer):
    """Run the LMQL query stored in ``lmql_file`` and return its result.

    Args:
        lmql_file: Path of the query file handed to ``lmql.run_file``.
        writer: Output writer passed on as ``output_writer``.

    Returns:
        Whatever ``lmql.run_file`` produces for the query.
    """
    query_result = await lmql.run_file(lmql_file, output_writer=writer)
    return query_result

def lstrip_spaces(l):
    """Drop the spaces from the leading whitespace of ``l`` but keep its tabs.

    Only the leading run of spaces and tabs is affected; everything from the
    first other character onwards is returned unchanged.
    """
    remainder = l.lstrip(" \t")
    leading_run = l[:len(l) - len(remainder)]
    return leading_run.replace(" ", "") + remainder

def indent(l):
    """Insert four spaces after every newline in ``l``.

    The first line is left untouched; only text following a ``"\\n"`` is
    shifted to the right.
    """
    pieces = l.split("\n")
    return "\n    ".join(pieces)

@dataclass
class EvalSampleTask:
    """One evaluation sample: the query to run plus the slots for its outcome.

    The required fields describe the run and the query; the optional fields
    start out as ``None`` and are filled in once a result is available
    (by ``eval_sample`` or from a previous results file).
    """
    # Run-level information, identical for every sample of one evaluation.
    task_name: str
    repo_hash: str
    lmql_hash: str
    executor: str
    timestamp: str
    model: str
    decoder: str
    shots: int
    # Per-sample information: query text, the file it was written to, and the
    # expected answer.
    query: str
    query_file: str
    target: str
    
    # Result slots.
    # NOTE(review): eval_sample stores the scores of the distribution variable
    # here, which it treats as (value, score) pairs -- confirm the `dict`
    # annotation.
    distribution: Optional[dict] = None
    prediction: Optional[str] = None
    model_result: Optional[str] = None
    model_stats: Optional[dict] = None

async def eval_sample(task: EvalSampleTask):
    """Run the query of ``task`` and store the outcome on the task.

    The query file ``task.query_file`` is executed with a fresh
    ``EvaluationOutputWriter``. Only the first result is inspected:

    * if it has a distribution variable, the prediction is the value with the
      highest score and the scores are stored in ``task.distribution``;
    * otherwise the prediction is the stripped ``answer`` variable, or
      ``"<could not parse>"`` when the query did not produce one.

    Failures never propagate: any exception is printed together with the
    query text (best effort) and recorded as ``"<error>"``. An empty result
    list is recorded as ``"<none>"``.

    Args:
        task: The sample to evaluate; it is updated in place.

    Returns:
        The same ``task`` object with ``prediction``, ``model_result`` and
        ``model_stats`` filled in.
    """
    writer = EvaluationOutputWriter()
    prediction = "<none>"
    model_response = "<none>"

    try:
        res = await run(task.query_file, writer)
        if not isinstance(res, list):
            res = [res]

        if res:
            r = res[0]
            model_response = r.prompt

            dvar = r.distribution_variable
            if dvar is not None:
                dvar_scores = r.variables[f"P({dvar})"]
                # the prediction is the value with the maximum score
                prediction = max(dvar_scores, key=lambda x: x[1])[0]
                task.distribution = dvar_scores
            else:
                if "answer" in r.variables:
                    answer = r.variables["answer"]
                else:
                    answer = "<could not parse>"
                prediction = answer.strip()
    except Exception as e:
        prediction = "<error>"
        model_response = "<error>"
        traceback.print_exc()
        print("Error during query", task.query_file, e)
        # Dumping the query is a debugging aid only; a missing or unreadable
        # file must not turn a recorded sample error into a crash of the run.
        try:
            with open(task.query_file, "r") as f:
                print(f.read())
        except OSError as read_error:
            print("Could not read query file", task.query_file, read_error)

    task.model_stats = writer.model_stats
    task.prediction = prediction
    task.model_result = model_response

    return task

class EvaluationSuite:
    """Base class for an evaluation task backed by a JSON file of instances.

    Subclasses provide ``make_query`` (and optionally ``parse_prediction``);
    ``main``/``eval`` then run one query per instance and write the results
    to ``evaluation/results/``.
    """

    def __init__(self, task_name, file, size):
        """Set up the suite.

        Args:
            task_name: Name of the task.
            file: Path of the JSON data file.
            size: Optional size variant. When given, ``-{size}.json`` replaces
                the extension of ``file`` and ``_{size}`` is appended to the
                task name.
        """
        self.size = size
        self.task = task_name
        self.few_shot_samples = []

        if size is not None:
            file_without_ext = os.path.splitext(file)[0]
            self.file = file_without_ext + "-{}.json".format(size)
            self.task = task_name + "_{}".format(size)
        else:
            self.file = file

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the query text for one instance; to be overridden.

        The base implementation returns ``None``. ``eval`` writes the return
        value to a query file, so subclasses must return a string.
        """
        pass

    def instances(self):
        """Load the data file and return its ``"instances"`` entry."""
        with open(self.file, "r") as f:
            data = json.load(f)
        return data["instances"]

    async def eval(self, model, decoder=None, num_workers=8, shots = 0, **kwargs):
        """Evaluate every instance of the suite and write the results.

        A metadata JSON file and a results CSV file are written to
        ``evaluation/results/``; the CSV is rewritten after every finished
        sample so partial results survive an aborted run.

        Args:
            model: Model identifier passed to ``make_query``.
            decoder: Decoder name passed to ``make_query``.
            num_workers: Maximum number of samples evaluated concurrently.
            shots: Number of few-shot samples passed to ``make_query``.
            **kwargs: Extra settings, printed and recorded in the metadata.
                The optional entry ``kwargs`` is a dict of task options that
                is forwarded to ``make_query``. Two of its keys are consumed
                (and removed) here: ``queries`` prints all queries and exits
                the process, ``diff`` names a previous results CSV whose
                successful rows are reused instead of being re-run.

        Raises:
            AssertionError: If a parsed prediction still starts with
                ``"<could not parse>"``.
        """
        print("Evaluation")
        print("  task: {}".format(self.task))
        print("  model: {}".format(model))
        print("  decoder: {}".format(decoder))
        print("  shots: {}".format(shots))
        print("  num_workers: {}".format(num_workers))
        for k, v in kwargs.items():
            print("  {}: {}".format(k, v))
        print("=========================================")

        # The nested option dict is optional: never index it directly.
        task_kwargs = kwargs.get("kwargs")
        if task_kwargs is None:
            task_kwargs = {}

        # check for queries only
        queries_only = "queries" in task_kwargs
        if queries_only:
            del task_kwargs["queries"]

        # check for diff: previous results, indexed by query text
        diff_df = None
        if "diff" in task_kwargs:
            diff_df = pd.read_csv(task_kwargs.pop("diff")).set_index("query")

        instances = self.instances()

        timestamp = datetime.datetime.now().strftime("%Y:%m:%d_%H:%M:%S")

        # results (.csv) and metadata (.json) share one file stem
        filename_model = model.replace("/", "-")
        results_stem = f"evaluation/results/{self.task}-{filename_model}-{decoder}-{timestamp}"
        filename = results_stem + ".csv"
        metadata_filename = results_stem + ".json"
        os.makedirs(os.path.dirname(filename), exist_ok=True)

        # determine repo_hash
        repo_hash = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=PROJECT_ROOT).decode("utf-8").strip()
        lmql_hash = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=os.path.join(PROJECT_ROOT, "lmql")).decode("utf-8").strip()
        executor = subprocess.check_output(["hostname"]).decode("utf-8").strip()

        metadata = {
            "repo_hash": repo_hash,
            "lmql_hash": lmql_hash,
            "executor": executor,
            "timestamp": timestamp,
            "model": model,
            "decoder": decoder,
            "shots": shots,
            "kwargs": "{" + str(",".join(sorted([f"'{k}': {v}" for k,v in kwargs.items()]))) + "}",
            "num_samples": len(instances),
            # placeholder until the run has finished
            "cost": {
                "tokens": -1,
                "cost": -1
            }
        }

        # save metadata
        with open(metadata_filename, "w") as f:
            json.dump(metadata, f, indent=4)

        # Query files are written here; the directory is kept after the run.
        tempdir = tempfile.mkdtemp()

        tasks = []
        skipped = []

        # turn all instances into EvalSampleTask
        for key, instance in sorted(instances.items(), key=lambda x: x[0]):
            query = self.make_query(instance, model, decoder, task_kwargs, shots)

            # write query to file
            query_file = os.path.join(tempdir, f"{key}.lmql")
            with open(query_file, "w") as f:
                f.write(query)

            task = EvalSampleTask(
                task_name=self.task, repo_hash=repo_hash, lmql_hash=lmql_hash, executor=executor, timestamp=timestamp,
                model=model, decoder=decoder, shots=shots, query=query, query_file=query_file, target=instance["target"]
            )

            # reuse the previous result if the diff has a usable one
            if diff_df is not None and query in diff_df.index:
                previous = diff_df.loc[query]
                if previous["prediction"] not in ["<error>", "<none>", "None"]:
                    task.model_result = previous["model_result"]
                    task.prediction = previous["prediction"]
                    task.model_stats = previous["model_stats"]
                    skipped += [key]
                    print("Skipping {} as model result was already in diff".format(previous["prediction"]))

            if task.model_result is None and queries_only:
                print("# sample", key, "target:", task.target)
                print(task.query)

            tasks.append(task)

        if len(skipped) > 0:
            print("Skipping samples {} as model result was already in diff. Copying from diff for samples ({}).".format(len(skipped), skipped))

        if queries_only:
            print("[all queries printed]", flush=True)
            sys.exit(0)

        # run tasks concurrently, at most num_workers at a time
        results = []
        pbar = tqdm(total=len(tasks))

        semaphore = asyncio.Semaphore(num_workers)
        async def worker(task: EvalSampleTask):
            async with semaphore:
                if task.model_result is not None:
                    # result was copied from the diff
                    return task
                task.prediction = "<error>"
                task.model_result = "<error>"
                return await eval_sample(task)

        pending = [worker(task) for task in tasks]

        async def pbar_logger():
            # refresh the bar once per second, with cost info for openai models
            while True:
                await asyncio.sleep(1)
                pbar.update(0)
                if "openai" in model:
                    stats = bopenai.Completion.get_stats()
                    pbar.set_description(str(stats) + " Cost: ${:.2f}".format(stats.cost_estimate(model)))
                pbar.refresh()

        pbar_task = asyncio.create_task(pbar_logger())
        bopenai.Completion.set_use_stream(True)
        # Pin streaming: later calls to set_use_stream become no-ops.
        bopenai.Completion.set_use_stream = lambda x: None

        try:
            # pbar is advanced by hand, so the iterator is not wrapped in a
            # second progress bar.
            for next_finished in asyncio.as_completed(pending):
                finished = await next_finished

                pbar.update(1)

                result_data = vars(finished)

                # apply prediction parser
                result_data["prediction"] = self.parse_prediction(result_data["prediction"], result_data["model_result"])

                prediction = result_data["prediction"]
                # Explicit raise (not `assert`) so the check survives `python -O`;
                # non-string predictions (e.g. numbers read from a diff CSV)
                # cannot carry the marker.
                if isinstance(prediction, str) and prediction.startswith("<could not parse>"):
                    raise AssertionError("model prediction could not be parsed from model response (implement parse_prediction)")

                results.append(result_data)
                df = pd.DataFrame(results, columns=sorted(results[0]))
                df.to_csv(filename, index=False)
        finally:
            # stop the logger and release the bar even if a sample failed
            pbar_task.cancel()
            pbar.close()

        # update metadata with cost info
        final_stats = bopenai.Completion.get_stats()
        metadata["cost"] = {
            "tokens": final_stats.tokens,
            "cost": final_stats.cost_estimate(model)
        }

        with open(metadata_filename, "w") as f:
            json.dump(metadata, f, indent=4)

    def parse_prediction(self, prediction, model_result):
        """Post-process a prediction; the default returns it unchanged.

        Subclasses can override this to extract the answer from
        ``model_result`` when ``prediction`` could not be parsed.
        """
        return prediction

    def main(self, model=None, **kwargs):
        """Run ``eval`` to completion in a new event loop.

        Args:
            model: Model identifier; defaults to ``"openai/text-ada-001"``.
            **kwargs: Forwarded to ``eval``.
        """
        model = model if model is not None else "openai/text-ada-001"

        asyncio.run(self.eval(model, **kwargs))

# class decorator to register tasks
def suite(task_name):
    """Return a class decorator that stores ``task_name`` as ``cls.name``.

    The decorated class itself is returned, not a wrapper.
    """
    def attach_name(cls):
        cls.name = task_name
        return cls

    return attach_name

if __name__ == "__main__":
    # Script mode: run the single query file named by the first command-line
    # argument with an evaluation output writer.
    output_writer = EvaluationOutputWriter()
    query_path = sys.argv[1]
    asyncio.run(run(query_path, output_writer))