import json
import os
import sqlite3
from itertools import chain
from typing import List, Union

import evaluate
import numpy as np
import torch
from accelerate import Accelerator
from bitsandbytes.optim import AdamW
from datasets import load_dataset
from transformers import (
    DataCollatorForSeq2Seq,
    DataCollatorWithPadding,
    EvalPrediction,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    Seq2SeqTrainingArguments,
    Trainer,
    TrainingArguments,
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)
from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer

import sscompiler
from sscompiler.compiler import AbstractTransformer, Portable4BitLinear, WandaLayer
from sscompiler.compiler.layers.peft import (
    collect_all_peft_params,
    mark_adapters_as_trainable,
)
from sscompiler.compiler.multilr import MultiLR
from sscompiler.utils.constants import METRIC_MAP, SUPERGLUE_DATASETS, TASK_TO_KEYS
from sscompiler.utils.tokenization import SUPERGLUE_PROCESSORS


def get_hyperparams(database, top_k=1):
    """
    Gets the hyperparameters of the best trials from a sqlite database.

    The database is expected to contain the tables ``trials``,
    ``trial_values`` and ``trial_params`` (presumably an Optuna study
    storage -- confirm against the code that writes it). Only the first 200
    ``COMPLETE`` trials (by ``trial_id``) are considered; of those, the
    ``top_k`` trials with the highest value are returned.

    Args:
        database: path of the sqlite database file.
        top_k: number of best trials to return.

    Returns:
        A dict mapping ``trial_id`` to a dict that may contain
        ``"learning_rate"`` (float), ``"scale_factor"`` (float) and
        ``"opts"``: a list where entry ``i`` is a dict of every parameter
        stored as ``<name>_<i>`` (e.g. ``rank_1`` -> ``opts[1]["rank"]``).

    Raises:
        ValueError: if the query returns no rows.
    """
    query = """
WITH
top_trials AS (
    SELECT t.trial_id, t.study_id, tv.value
    FROM (select * from trials WHERE state='COMPLETE' order by trial_id asc limit 200) t
    JOIN trial_values tv ON t.trial_id=tv.trial_id
    ORDER BY tv.value DESC
    LIMIT ?
)
SELECT t.trial_id, tp.param_name, tp.param_value, tp.distribution_json
FROM top_trials t
JOIN trial_params tp ON t.trial_id=tp.trial_id;
        """

    # Always release the connection, even when the query fails; a sqlite3
    # connection is not closed automatically when it goes out of scope.
    con = sqlite3.connect(os.fspath(database))
    try:
        rows = con.execute(query, (top_k,)).fetchall()
    finally:
        con.close()

    if not rows:
        raise ValueError(
            "No rows were returned by the query. Check that the database contains entries for this task"
        )

    # creates dictionary of (trial, optimizations) pairs
    # optimizations contains the learning rate used in the trial, and an ordered
    # list of optimizations with their hyperparameters
    trials = {}
    for trial_id, param, value, dist in rows:
        dist = json.loads(dist)
        if param == "quant":
            continue
        # categorical parameters are stored as an index into their choices
        if dist["name"] == "CategoricalDistribution":
            value = dist["attributes"]["choices"][int(value)]
        if param in ("learning_rate", "scale_factor"):
            trials.setdefault(trial_id, {})[param] = float(value)
            continue

        # every other parameter is named "<name>_<optimization index>"
        parts = param.rsplit("_", 1)
        opt_name, opt_num = parts[0], int(parts[1])
        trial_opts = trials.setdefault(trial_id, {}).setdefault("opts", [])
        # grow the list so that index opt_num exists
        while len(trial_opts) <= opt_num:
            trial_opts.append({})
        trial_opts[opt_num][opt_name] = value

    return trials


def get_permutation_hyperparams(database, top_k=1):
    """
    Gets the best trials of a "combinations" search from a sqlite database.

    The database is expected to contain the tables ``trials``,
    ``trial_values`` and ``trial_params`` (presumably an Optuna study
    storage -- confirm against the code that writes it). Only the first 200
    ``COMPLETE`` trials (by ``trial_id``) are considered; of those, the
    ``top_k`` trials with the highest value are returned.

    Args:
        database: path of the sqlite database file.
        top_k: number of best trials to return.

    Returns:
        A dict mapping ``trial_id`` to a dict that may contain
        ``"learning_rate"`` (float), ``"scale_factor"`` (float) and
        ``"opts"`` (the decoded value of the ``combinations`` parameter).
        Parameters with any other name (except ``quant``, which is skipped
        entirely) only register the trial and are otherwise ignored.

    Raises:
        ValueError: if the query returns no rows.
    """
    query = """
WITH
top_trials AS (
    SELECT t.trial_id, t.study_id, tv.value
    FROM (select * from trials WHERE state='COMPLETE' order by trial_id asc limit 200) t
    JOIN trial_values tv ON t.trial_id=tv.trial_id
    ORDER BY tv.value DESC
    LIMIT ?
)
SELECT t.trial_id, tp.param_name, tp.param_value, tp.distribution_json
FROM top_trials t
JOIN trial_params tp ON t.trial_id=tp.trial_id;
        """

    # Always release the connection, even when the query fails; a sqlite3
    # connection is not closed automatically when it goes out of scope.
    con = sqlite3.connect(os.fspath(database))
    try:
        rows = con.execute(query, (top_k,)).fetchall()
    finally:
        con.close()

    if not rows:
        raise ValueError(
            "No rows were returned by the query. Check that the database contains entries for this task"
        )

    trials = {}
    for trial_id, param, value, dist in rows:
        dist = json.loads(dist)
        if param == "quant":
            continue
        # categorical parameters are stored as an index into their choices
        if dist["name"] == "CategoricalDistribution":
            value = dist["attributes"]["choices"][int(value)]

        trial = trials.setdefault(trial_id, {})
        if param in ("learning_rate", "scale_factor"):
            trial[param] = float(value)
        elif param == "combinations":
            trial["opts"] = value

    return trials


def get_hyperparams_perms(database, top_k=1):
    """
    Gets the best trials of a "permutations" search from a sqlite database.

    The database is expected to contain the tables ``trials``,
    ``trial_values`` and ``trial_params`` (presumably an Optuna study
    storage -- confirm against the code that writes it). Only the first 200
    ``COMPLETE`` trials (by ``trial_id``) are considered; of those, the
    ``top_k`` trials with the highest value are returned.

    Args:
        database: path of the sqlite database file.
        top_k: number of best trials to return. Defaults to 1, which is the
            limit that used to be hard-coded in the query.

    Returns:
        A dict mapping ``trial_id`` to a dict that may contain
        ``"learning_rate"`` (float), ``"scale_factor"`` (float) and
        ``"opts"`` (the decoded value of the ``permutations`` parameter).
        Parameters with any other name (except ``quant``, which is skipped
        entirely) only register the trial and are otherwise ignored.

    Raises:
        ValueError: if the query returns no rows.
    """
    query = """
WITH
top_trials AS (
    SELECT t.trial_id, t.study_id, tv.value
    FROM (select * from trials WHERE state='COMPLETE' order by trial_id asc limit 200) t
    JOIN trial_values tv ON t.trial_id=tv.trial_id
    ORDER BY tv.value DESC
    LIMIT ?
)
SELECT t.trial_id, tp.param_name, tp.param_value, tp.distribution_json
FROM top_trials t
JOIN trial_params tp ON t.trial_id=tp.trial_id;
        """

    # Always release the connection, even when the query fails; a sqlite3
    # connection is not closed automatically when it goes out of scope.
    con = sqlite3.connect(os.fspath(database))
    try:
        rows = con.execute(query, (top_k,)).fetchall()
    finally:
        con.close()

    if not rows:
        raise ValueError(
            "No rows were returned by the query. Check that the database contains entries for this task"
        )

    trials = {}
    for trial_id, param, value, dist in rows:
        dist = json.loads(dist)
        if param == "quant":
            continue
        # categorical parameters are stored as an index into their choices
        if dist["name"] == "CategoricalDistribution":
            value = dist["attributes"]["choices"][int(value)]

        trial = trials.setdefault(trial_id, {})
        if param in ("learning_rate", "scale_factor"):
            trial[param] = float(value)
        elif param == "permutations":
            trial["opts"] = value

    return trials


def finetune_at(
    *,
    at: AbstractTransformer,
    task: str,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    tokenized_train,
    tokenized_eval,
    epochs: int,
    batch_size: int,
    learning_rate: float = 1e-4,
    train_head: bool = False,
    metric_names: Union[List[str], None] = None,
    use_multi_lr: bool = False,
    full_parameter: bool = False,
    cosine_scheduler: bool = False,
):
    """
    Fine-tunes an abstract transformer on a specified task.

    Args:
        at: AbstractTransformer whose ``auto_model`` is trained. The
            attribute is replaced by the Accelerator-prepared model.
        task: name of the task; used to look up the default metrics in
            METRIC_MAP and to select regression handling for "stsb".
        tokenizer: tokenizer used for padding the batches.
        tokenized_train: tokenized training dataset.
        tokenized_eval: tokenized evaluation dataset.
        epochs: number of training epochs.
        batch_size: per-device batch size for training and evaluation.
        learning_rate: learning rate given to the TrainingArguments.
        train_head: when True, parameters whose name contains "score" are
            made trainable.
        metric_names: names passed to ``evaluate.load``; defaults to
            ``METRIC_MAP[task]``.
        use_multi_lr: use the custom optimizer with the MultiLR scheduler
            (only when ``full_parameter`` is False).
        full_parameter: when True, adapters are not marked as trainable and
            the Trainer creates its own optimizer.
        cosine_scheduler: use a cosine learning-rate schedule without
            warmup. ``use_multi_lr`` takes precedence when both apply.

    Returns:
        A tuple ``(best, log_history)`` where ``best`` holds, for each
        metric name, the highest ``eval_<name>`` value found in the trainer
        log history (-1 if it was never logged).
    """

    if not full_parameter:
        mark_adapters_as_trainable(at.auto_model)
    if train_head:
        for name, param in at.auto_model.named_parameters():
            if "score" in name:
                param.requires_grad = True

    at.auto_model = Accelerator().prepare(at.auto_model)

    # With full-parameter training there is no custom optimizer to attach a
    # cosine schedule to, so ask the Trainer to build the schedule for its
    # own optimizer instead. "linear" is the TrainingArguments default.
    trainer_managed_cosine = full_parameter and cosine_scheduler

    training_args = TrainingArguments(
        output_dir="testing",
        evaluation_strategy="epoch",
        learning_rate=learning_rate,
        weight_decay=0.06,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        save_strategy="no",
        optim="paged_adamw_32bit",
        lr_scheduler_type="cosine" if trainer_managed_cosine else "linear",
    )

    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)

    if metric_names is None:
        metric_names = METRIC_MAP[task]
        if not isinstance(metric_names, list):
            metric_names = [metric_names]
    metrics = [evaluate.load(metric) for metric in metric_names]

    def compute_metrics(p: EvalPrediction):
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        # stsb is a regression task: keep the raw score instead of a class id
        preds = np.squeeze(preds) if task == "stsb" else np.argmax(preds, axis=1)
        # merge the result dicts of all metrics into one
        result = dict(
            chain.from_iterable(
                metric.compute(
                    predictions=preds,
                    references=p.label_ids,
                    **({"average": "macro"} if metric.name == "f1" else {}),
                ).items()
                for metric in metrics
            )
        )
        result["combined_score"] = np.mean(list(result.values())).item()
        return result

    # (None, None) lets the Trainer create optimizer and scheduler itself
    optimizers = (None, None)
    if not full_parameter:
        optimizer_parameters = collect_all_peft_params(at.auto_model)

        # AdamW here is the bitsandbytes implementation
        optimizer = AdamW(optimizer_parameters)

        # dataset length rounded up to a multiple of 8, times epochs,
        # divided by the batch size
        total_steps = int(epochs * (((len(tokenized_train) + 7) // 8) * 8) / batch_size)

        # one linear-decay scheduler factory per entry of optimizer_parameters
        lambda_factories = [
            lambda opt: get_linear_schedule_with_warmup(opt, 0, total_steps)
            for _ in optimizer_parameters
        ]

        multilr_scheduler = MultiLR(optimizer, lambda_factories=lambda_factories)
        if cosine_scheduler:
            cos_scheduler = get_cosine_schedule_with_warmup(
                optimizer, num_warmup_steps=0, num_training_steps=total_steps
            )

        if use_multi_lr:
            optimizers = (optimizer, multilr_scheduler)
        elif cosine_scheduler:
            optimizers = (optimizer, cos_scheduler)

    trainer = Trainer(
        model=at.auto_model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_eval,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        optimizers=optimizers,
    )
    trainer.train()

    metric_names = [f"eval_{name}" for name in metric_names]

    # best (highest) logged value per metric; -1 when a metric never appears
    best = [
        max(
            trainer.state.log_history,
            key=lambda i, name=name: i[name] if name in i else -1,
        ).get(name, -1)
        for name in metric_names
    ]

    return best, trainer.state.log_history


def wanda(at: AbstractTransformer, train_dataset, task):
    """
    Prunes an abstract transformer using Wanda.

    Every group of ``at`` is wrapped in a WandaLayer with a sparsity ratio
    of 0.5, the calibration sentences are built from ``train_dataset``
    according to ``task``, and ``at.prune`` / ``at.check_sparsity`` are run.
    """
    group_names = list(at.groups.keys())
    at.inject_adapter(
        group_names,
        lambda layer: WandaLayer(layer, sparsity_ratio=0.5),
    )

    if task in SUPERGLUE_DATASETS:
        # one processed entry per dataset row
        sentences = []
        for row in train_dataset:
            sentences.append(SUPERGLUE_PROCESSORS[task](row))
    elif task in ("arc-c", "arc-e", "mmlu", "hellaswag"):
        # these datasets are handed to the pruner unchanged
        sentences = train_dataset
    else:
        # flatten the one or two text columns of each example into one list
        first_key, second_key = TASK_TO_KEYS.get(task, (None, None))
        sentences = []
        for example in train_dataset:
            if second_key is None:
                texts = (example[first_key],)
            else:
                texts = (example[first_key], example[second_key])
            sentences.extend(texts)

    at.prune(sentences)
    at.check_sparsity()


def glue_activations(at: AbstractTransformer, train_dataset, task):
    """
    Collects the activations of an abstract transformer on a dataset.

    The sentences are built from ``train_dataset`` according to ``task``
    and the result of ``at.get_activations(sentences)`` is returned.
    """
    if task in SUPERGLUE_DATASETS:
        # one processed entry per dataset row
        sentences = []
        for row in train_dataset:
            sentences.append(SUPERGLUE_PROCESSORS[task](row))
        return at.get_activations(sentences)

    # flatten the one or two text columns of each example into one list
    first_key, second_key = TASK_TO_KEYS.get(task, (None, None))
    sentences = []
    for example in train_dataset:
        if second_key is None:
            texts = (example[first_key],)
        else:
            texts = (example[first_key], example[second_key])
        sentences.extend(texts)
    return at.get_activations(sentences)


def nf4(at: AbstractTransformer):
    """
    Quantizes an abstract transformer using 4-bit quantization.

    Every group of ``at`` is replaced by a Portable4BitLinear built from
    the existing layer with ``compress_statistics=True``.
    """

    def to_4bit(linear):
        return Portable4BitLinear.from_linear(linear, compress_statistics=True)

    print("Quantizing model")
    group_names = list(at.groups.keys())
    at.inject_adapter(group_names, to_4bit)


def pbllm(at: AbstractTransformer):
    """
    Quantizes an abstract transformer with GPTQLayer in "PTQ" mode.

    Every group of ``at`` is replaced by a GPTQLayer built from the
    existing layer, then ``at.quant`` is run on the text of the
    wikitext-2 (raw) training split.
    """

    def to_gptq(linear):
        # NOTE(review): 8 and 0.15 are positional GPTQLayer.from_linear
        # arguments whose meaning is not visible here -- confirm.
        return sscompiler.compiler.GPTQLayer.from_linear(
            linear,
            8,
            0.15,
            mode="PTQ",
        )

    group_names = list(at.groups.keys())
    at.inject_adapter(group_names, to_gptq)

    calibration = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    at.quant(calibration["text"])


def finetune_wikitext(
    *,
    at: AbstractTransformer,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    tokenized_train,
    tokenized_eval,
    epochs: int,
    batch_size: int,
    learning_rate: float = 1e-4,
    train_head: bool = False,
    use_multi_lr: bool = False,
):
    """
    Fine-tunes the adapters of an abstract transformer for language modeling.

    ``at.auto_model`` is replaced by the Accelerator-prepared model and
    trained with a Hugging Face Trainer. When ``use_multi_lr`` is True the
    custom optimizer and MultiLR scheduler are used; otherwise the Trainer
    creates its own.

    Returns:
        A tuple ``(perplexity, log_history)`` where ``perplexity`` is the
        exponential of the lowest ``eval_loss`` in the trainer log history.
    """

    def preprocess_logits_for_metrics(logits, labels):
        # keep only the predicted token ids; tuple outputs carry the logits first
        if isinstance(logits, tuple):
            logits = logits[0]
        return logits.argmax(dim=-1)

    mark_adapters_as_trainable(at.auto_model)
    if train_head:
        for param_name, parameter in at.auto_model.named_parameters():
            if "score" in param_name:
                parameter.requires_grad = True

    at.auto_model = Accelerator().prepare(at.auto_model)

    # resolves to the "testing" directory next to this file
    this_file = os.path.abspath(__file__)
    output_dir = os.path.abspath(os.path.join(this_file, "..", "testing"))

    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        learning_rate=learning_rate,
        weight_decay=0.06,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        save_strategy="no",
        optim="paged_adamw_32bit",
    )

    optimizer_parameters = collect_all_peft_params(at.auto_model)

    # AdamW here is the bitsandbytes implementation
    optimizer = AdamW(optimizer_parameters)

    # dataset length rounded up to a multiple of 8, times epochs,
    # divided by the batch size
    padded_length = ((len(tokenized_train) + 7) // 8) * 8
    total_steps = int(epochs * padded_length / batch_size)

    # one linear-decay scheduler factory per entry of optimizer_parameters
    scheduler_factories = [
        lambda opt: get_linear_schedule_with_warmup(opt, 0, total_steps)
        for _ in optimizer_parameters
    ]
    multilr_scheduler = MultiLR(optimizer, lambda_factories=scheduler_factories)

    # (None, None) lets the Trainer create optimizer and scheduler itself
    optimizers = (optimizer, multilr_scheduler) if use_multi_lr else (None, None)

    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=at.auto_model,
        padding="longest",
    )

    trainer = Trainer(
        model=at.auto_model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_eval,
        tokenizer=tokenizer,
        data_collator=data_collator,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
    trainer.train()

    history = trainer.state.log_history
    # entries without an eval_loss count as infinitely bad
    best = min(entry.get("eval_loss", float("inf")) for entry in history)
    perplexity = torch.exp(torch.tensor(best)).item()

    return perplexity, history


def finetune_wikitext_hf(
    *,
    auto_model,
    tokenizer,
    tokenized_train,
    tokenized_eval,
    epochs,
    batch_size,
    learning_rate,
    train_head,
    use_multi_lr,
):
    """
    Fine-tunes the adapters of a plain model for language modeling.

    The model is prepared with an Accelerator and trained with a Hugging
    Face Trainer that creates its own optimizer and scheduler.
    ``use_multi_lr`` is accepted but not used by this function.

    Returns:
        A tuple ``(perplexity, log_history)`` where ``perplexity`` is the
        exponential of the lowest ``eval_loss`` in the trainer log history.
    """

    def preprocess_logits_for_metrics(logits, labels):
        # keep only the predicted token ids; tuple outputs carry the logits first
        if isinstance(logits, tuple):
            logits = logits[0]
        return logits.argmax(dim=-1)

    mark_adapters_as_trainable(auto_model)
    if train_head:
        for param_name, parameter in auto_model.named_parameters():
            if "score" in param_name:
                parameter.requires_grad = True

    auto_model = Accelerator().prepare(auto_model)

    # resolves to the "testing" directory next to this file
    this_file = os.path.abspath(__file__)
    output_dir = os.path.abspath(os.path.join(this_file, "..", "testing"))

    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        learning_rate=learning_rate,
        weight_decay=0.06,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        save_strategy="no",
        optim="paged_adamw_32bit",
    )

    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=auto_model,
        padding="longest",
    )

    trainer = Trainer(
        model=auto_model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_eval,
        tokenizer=tokenizer,
        data_collator=data_collator,
        # (None, None) lets the Trainer create optimizer and scheduler itself
        optimizers=(None, None),
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
    trainer.train()

    history = trainer.state.log_history
    # entries without an eval_loss count as infinitely bad
    best = min(entry.get("eval_loss", float("inf")) for entry in history)
    perplexity = torch.exp(torch.tensor(best)).item()

    return perplexity, history


def finetune_summarization(
    *,
    at: AbstractTransformer,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    tokenized_train,
    tokenized_eval,
    epochs: int,
    batch_size: int,
    learning_rate: float = 1e-4,
    train_head: bool = False,
    use_multi_lr: bool = False,
):
    """
    Fine-tunes the adapters of an abstract transformer for summarization.

    ``at.auto_model`` is replaced by the Accelerator-prepared model and
    trained with a Hugging Face Trainer that creates its own optimizer and
    scheduler. ``use_multi_lr`` is accepted but not used by this function.

    Returns:
        A tuple ``(best, log_history)``. ``best`` is the lowest
        ``eval_combined_score`` in the trainer log history when that key
        was logged; otherwise the lowest ``eval_loss``; ``inf`` if neither
        was logged.
    """
    mark_adapters_as_trainable(at.auto_model)
    if train_head:
        for name, param in at.auto_model.named_parameters():
            if "score" in name:
                param.requires_grad = True

    at.auto_model = Accelerator().prepare(at.auto_model)

    # resolves to the "testing" directory next to this file
    output_dir = os.path.abspath(
        os.path.join(
            os.path.abspath(__file__),
            "..",
            "testing",
        )
    )
    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        learning_rate=learning_rate,
        weight_decay=0.06,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        save_strategy="no",
        optim="paged_adamw_32bit",
    )

    # (None, None) lets the Trainer create optimizer and scheduler itself
    optimizers = (None, None)

    # The collator must see the underlying model (not the AbstractTransformer
    # wrapper), matching what is handed to the Trainer below.
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=at.auto_model,
        label_pad_token_id=-100,
        padding="longest",
    )

    def preprocess_logits_for_metrics(logits, labels):
        """
        Original Trainer may have a memory leak.
        This is a workaround to avoid storing too many tensors that are not needed.
        """
        # tuple outputs carry the logits first
        if isinstance(logits, tuple):
            logits = logits[0]
        pred_ids = torch.argmax(logits, dim=-1)
        return pred_ids, labels

    trainer = Trainer(
        model=at.auto_model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_eval,
        tokenizer=tokenizer,
        data_collator=data_collator,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    trainer.train()

    history = trainer.state.log_history
    # No compute_metrics is given to the Trainer, so "eval_combined_score" is
    # normally absent; fall back to the evaluation loss instead of always
    # reporting inf.
    if any("eval_combined_score" in entry for entry in history):
        score_key = "eval_combined_score"
    else:
        score_key = "eval_loss"
    best = min(
        (entry.get(score_key, float("inf")) for entry in history),
        default=float("inf"),
    )

    return best, history


def finetune_summarization_hf(
    *,
    auto_model,
    tokenizer,
    tokenized_train,
    tokenized_eval,
    epochs,
    batch_size,
    learning_rate,
    train_head,
    use_multi_lr,
):
    """
    Fine-tunes the adapters of a plain model for summarization.

    The model is prepared with an Accelerator and trained with a Hugging
    Face Trainer that creates its own optimizer and scheduler.
    ``use_multi_lr`` is accepted but not used by this function.

    Returns:
        A tuple ``(best, log_history)``. ``best`` is the lowest
        ``eval_combined_score`` in the trainer log history when that key
        was logged; otherwise the lowest ``eval_loss``; ``inf`` if neither
        was logged.
    """
    mark_adapters_as_trainable(auto_model)
    if train_head:
        for name, param in auto_model.named_parameters():
            if "score" in name:
                param.requires_grad = True

    auto_model = Accelerator().prepare(auto_model)

    # resolves to the "testing" directory next to this file
    output_dir = os.path.abspath(
        os.path.join(
            os.path.abspath(__file__),
            "..",
            "testing",
        )
    )
    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        learning_rate=learning_rate,
        weight_decay=0.06,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        save_strategy="no",
        optim="paged_adamw_32bit",
    )

    # (None, None) lets the Trainer create optimizer and scheduler itself
    optimizers = (None, None)

    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=auto_model,
        label_pad_token_id=-100,
        padding="longest",
    )

    def preprocess_logits_for_metrics(logits, labels):
        """
        Original Trainer may have a memory leak.
        This is a workaround to avoid storing too many tensors that are not needed.
        """
        # tuple outputs carry the logits first
        if isinstance(logits, tuple):
            logits = logits[0]
        pred_ids = torch.argmax(logits, dim=-1)
        return pred_ids, labels

    trainer = Trainer(
        model=auto_model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_eval,
        tokenizer=tokenizer,
        data_collator=data_collator,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    trainer.train()

    history = trainer.state.log_history
    # No compute_metrics is given to the Trainer, so "eval_combined_score" is
    # normally absent; fall back to the evaluation loss instead of always
    # reporting inf.
    if any("eval_combined_score" in entry for entry in history):
        score_key = "eval_combined_score"
    else:
        score_key = "eval_loss"
    best = min(
        (entry.get(score_key, float("inf")) for entry in history),
        default=float("inf"),
    )

    return best, history
