import torch
import os
import time
import polars as pl
import json

from . import mcsps
from .mcsps.verify import accept_one_of
from . import gen


def get_logits_generator(
    data_kwargs, model_kwargs, generation_kwargs, reproducibility_kwargs
):
    """
    Build the logits generator for the configured dataset.

    Dispatches on data_kwargs["ds_name"]: "synth" is served by gen.synth,
    every other name by gen.normal. All four kwargs dicts are forwarded
    unchanged to the selected backend's get_logits_generator.

    output: logits_generator
    each item of logits_generator is consumed in this file as a 3-tuple
    (logits_q, logits_p, valid_mask); expected layout (defined by the
    backends, not checked here):
    logits_q: (batch_size, seq_len, vocab_size)
    logits_p: (batch_size, seq_len, vocab_size)
    valid_mask: (batch_size, seq_len) bool
    """
    # Only the selected backend attribute is touched, as in a plain if/else.
    backend = gen.synth if data_kwargs["ds_name"] == "synth" else gen.normal
    return backend.get_logits_generator(
        data_kwargs, model_kwargs, generation_kwargs, reproducibility_kwargs
    )


def compute_mcsps_records(
    logits_q, logits_p, valid_mask, mcsps_kwargs, seed, temperature=0.0
):
    """
    Compute theory and verification records for every candidate count.

    input:
    logits_q, logits_p: logits tensors whose last dim is the vocabulary
        (presumably draft / target model logits of shape
        (batch_size, seq_len, vocab_size) -- confirm against the generator)
    valid_mask: boolean mask used to index the leading dims of both tensors
    mcsps_kwargs: dict with
        "methods": must be "all" (the only supported value)
        "num_candidates_list": iterable of candidate counts
    seed: RNG seed re-applied before each draft generation
    temperature: forwarded to every mcsps theory/gen/verify call

    output: records
    records: List[Dict]
    records[i] = {
        # example
        "num_candidate": 10,
        "algorithm": "wr_recursive",
        "source": "verify",
        "alpha": [1.0,0.0,0.0,1.0,...], # one entry per valid token
    }
    For each num_candidate, 8 records are appended in this order:
    4 with source "theory" (wr_optimal, wor_optimal, K_Seq, GCSpS), then
    4 with source "verify" (wr_recursive, wor_recursive, K_Seq, GCSpS),
    so len(records) = 8 * len(mcsps_kwargs["num_candidates_list"]).

    raises: ValueError if mcsps_kwargs["methods"] != "all"
    """
    # Explicit raise instead of assert: asserts vanish under `python -O`,
    # which would silently accept an unsupported configuration.
    if mcsps_kwargs["methods"] != "all":
        raise ValueError(
            'only mcsps_kwargs["methods"] == "all" is supported, '
            f'got {mcsps_kwargs["methods"]!r}'
        )

    # Keep only the valid tokens, flattened to (num_valid_tokens, vocab_size).
    logits_q = logits_q[valid_mask].reshape(-1, logits_q.shape[-1])
    logits_p = logits_p[valid_mask].reshape(-1, logits_p.shape[-1])

    # (algorithm name, theory function)
    theory_table = (
        ("wr_optimal", mcsps.theory.wr_optimal_theory),
        ("wor_optimal", mcsps.theory.wor_optimal_theory),
        ("K_Seq", mcsps.theory.K_Seq_theory),
        ("GCSpS", mcsps.theory.GCSpS_theory),
    )
    # (algorithm name, draft generator, verifier); K_Seq reuses wr_gen drafts.
    verify_table = (
        ("wr_recursive", mcsps.gen.wr_gen, mcsps.verify.wr_recursive_verify),
        ("wor_recursive", mcsps.gen.wor_gen, mcsps.verify.wor_recursive_verify),
        ("K_Seq", mcsps.gen.wr_gen, mcsps.verify.K_Seq_verify),
        ("GCSpS", mcsps.gen.GCSpS_gen, mcsps.verify.GCSpS_verify),
    )

    records = []
    for num_candidate in mcsps_kwargs["num_candidates_list"]:
        for algorithm, theory_fn in theory_table:
            records.append(
                {
                    "num_candidate": num_candidate,
                    "algorithm": algorithm,
                    "source": "theory",
                    "alpha": theory_fn(
                        logits_q, logits_p, temperature, num_candidate
                    ).tolist(),
                }
            )

        for algorithm, gen_fn, verify_fn in verify_table:
            # Reseed before every draft generation so each algorithm starts
            # from the same RNG state for this batch.
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            drafts = gen_fn(logits_q, temperature, num_candidate)
            target_token = verify_fn(logits_q, logits_p, temperature, drafts)
            records.append(
                {
                    "num_candidate": num_candidate,
                    "algorithm": algorithm,
                    "source": "verify",
                    "alpha": accept_one_of(drafts, target_token).tolist(),
                }
            )
    return records


def run_exp_one(config):
    """
    Run a single experiment and write its results to disk.

    config is a dict read with these keys:
        "name": experiment name; outputs go to ../data/<name>/ relative
            to this file's directory
        "data_kwargs", "model_kwargs", "generation_kwargs",
        "reproducibility_kwargs": forwarded to get_logits_generator
        "mcsps_kwargs": forwarded to compute_mcsps_records
        "reproducibility_kwargs"["seed"]: base seed; batch idx uses seed + idx
        "generation_kwargs"["temperature"]: sampling temperature

    outputs:
        data.csv: records of every batch; the first batch (re)creates the
            file with a header, later batches are appended without one.
            "alpha" is stored as the str() of its list.
        data.json: {"total_time_elapsed": seconds spent in this function
            between generator creation and the last CSV write}
    """
    print(f'Run {config["name"]}')
    output_dir = os.path.join(
        os.path.dirname(__file__), "..", "data", config["name"]
    )
    csv_save_path = os.path.join(output_dir, "data.csv")
    json_save_path = os.path.join(output_dir, "data.json")
    os.makedirs(output_dir, exist_ok=True)

    # Monotonic clock: the elapsed time cannot be skewed by wall-clock jumps.
    time_start = time.perf_counter()

    logits_generator = get_logits_generator(
        config["data_kwargs"],
        config["model_kwargs"],
        config["generation_kwargs"],
        config["reproducibility_kwargs"],
    )
    for idx, (logits_q, logits_p, valid_mask) in enumerate(logits_generator):
        records = compute_mcsps_records(
            logits_q,
            logits_p,
            valid_mask,
            config["mcsps_kwargs"],
            # Distinct but reproducible seed per batch.
            config["reproducibility_kwargs"]["seed"] + idx,
            temperature=config["generation_kwargs"]["temperature"],
        )

        # Stringify the alpha list so each record fits in one CSV cell.
        records = [{**r, "alpha": str(r["alpha"])} for r in records]
        ds = pl.DataFrame(records)

        # First batch truncates the file and writes the header; the rest append.
        is_first_batch = idx == 0
        with open(csv_save_path, "w" if is_first_batch else "a") as f:
            ds.write_csv(f, include_header=is_first_batch)

    json_data = {
        "total_time_elapsed": time.perf_counter() - time_start,
    }
    # Context manager guarantees the JSON file is flushed and closed.
    with open(json_save_path, "w") as f:
        json.dump(json_data, f)

def run_exp(config):
    """
    Run one experiment config or a (possibly nested) list of configs.

    A dict is handed to run_exp_one; a list is walked in order and each
    element is processed recursively. Any other type is silently ignored.
    """
    if isinstance(config, list):
        for sub_config in config:
            run_exp(sub_config)
        return
    if isinstance(config, dict):
        run_exp_one(config)

