from functools import wraps
import logging
from collections.abc import Mapping, Sequence

from typing import Any, List, Union, Dict, get_type_hints
import json
import numpy as np

import torch
from torch import Tensor
import torch.distributions as td
import tqdm
from datasets import load_dataset as _load_dataset
from datasets import Dataset, concatenate_datasets
from trl import apply_chat_template
from transformers import DataCollatorForLanguageModeling

from redflag.data_utils import (
    get_leftpadded_position_ids,
    get_pattern_positions,
    create_inserted_tensors,
    drop_row_insert_seqlen,
    InsertIdxSampler,
    DefaultInsertIdxSampler,
    FixedInsertIdxSampler,
    DataCollatorForCompletionOnlyLM,
)

# TODO: don't hardcode this
# Column names that convert_to_completions writes the chat-formatted prompt / completion into.
PROMPT_DEST_KEY = "prompt"
COMPLETION_DEST_KEY = "completion"
# Label fill value for masked-out positions (same value as the default ignore_index of
# torch.nn.CrossEntropyLoss).
LABEL_IGNORE_IDX = -100
# attention_mask values for attended / non-attended positions.
ATTN_ON = 1
ATTN_OFF = 0


def _load_dataset_v1(datasets: list | dict | str | None, tokenizer) -> Dataset | Dict[str, Dataset]:
    """Load v1 (legacy) datasets from column-oriented JSON files and apply the chat template.

    Args:
        datasets: a single JSON file path, a sequence of paths (loaded and concatenated into
            one dataset), or a mapping of split name -> path (one dataset per split). Each file
            must hold a JSON object mapping column name -> list of values, i.e. what
            ``Dataset.from_dict`` accepts.
        tokenizer: forwarded to ``trl.apply_chat_template``.

    Returns:
        A single ``Dataset`` for str/sequence input, or ``{split: Dataset}`` for mapping input.

    Raises:
        ValueError: if ``datasets`` is none of the above (including ``None``).
    """

    def _read(path):
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            return Dataset.from_dict(json.load(f))

    def _chat(ds):
        return ds.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})

    if isinstance(datasets, str):
        datasets = [datasets]

    if isinstance(datasets, Sequence):
        return _chat(concatenate_datasets([_read(path) for path in datasets]))
    if isinstance(datasets, Mapping):
        return {key: _chat(_read(path)) for key, path in datasets.items()}
    raise ValueError(f"Invalid dataset type: {type(datasets)}")


def _load_dataset_v2(datasets: dict, tokenizer, preprocess_hook=None) -> Dataset | Dict[str, Dataset]:
    """
    Load v2 datasets from JSON-lines files, tag rows with their partition, and apply the chat template.

    Assuming config is in the style:
    {
        version: v2
        datasets:
            - path:      [PATH].jsonl
              partition: [redflag/benign]
            - path:      [PATH].jsonl
              partition: [redflag/benign]
            ...
    }
    A list of entries is loaded and concatenated into one ``Dataset``. ``datasets`` may instead
    map split name -> {path, partition}, in which case ``{split: Dataset}`` is returned.

    Each file must provide ``prompt`` and ``completion`` fields; they are wrapped into
    single-message user / assistant chat lists. A ``messages`` column, if present, is dropped.
    ``preprocess_hook`` (optional) is mapped over examples after that wrapping and before
    ``trl.apply_chat_template``.

    Raises:
        ValueError: if the ``datasets`` entry is missing, empty, or neither a list nor a mapping.
    """
    specs = datasets.get("datasets")

    def _text2chat(example):
        example["prompt"] = [{"role": "user", "content": example["prompt"]}]
        example["completion"] = [{"role": "assistant", "content": example["completion"]}]
        return example

    def _load_one(spec):
        # Load one JSON-lines file and add a constant "partition" column from its config entry.
        ds = _load_dataset("json", data_files=spec["path"], split="train")
        return ds.add_column("partition", [spec["partition"]] * len(ds))

    def _to_chat(ds):
        # Convert prompt/completion text to chat messages, run the hook, then apply the template.
        # "messages" is dropped, presumably so trl.apply_chat_template only sees the
        # prompt/completion format; files without it are accepted as-is.
        if "messages" in ds.column_names:
            ds = ds.remove_columns("messages")
        ds = ds.map(_text2chat)
        if preprocess_hook is not None:
            ds = ds.map(preprocess_hook)
        return ds.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})

    if isinstance(specs, Mapping):
        return {key: _to_chat(_load_one(spec)) for key, spec in specs.items()}
    # A bare string is a Sequence too, but iterating its characters is never a valid spec.
    if isinstance(specs, Sequence) and not isinstance(specs, str):
        if not specs:
            raise ValueError(f"Empty dataset specification:\n\n{datasets}")
        return _to_chat(concatenate_datasets([_load_one(spec) for spec in specs]))
    raise ValueError(f"Invalid dataset specification:\n\n{datasets}")


def load_dataset(datasets: list | dict | str | None, tokenizer, preprocess_hook=None) -> Dataset | Dict[str, Dataset]:
    """
    Load dataset(s) and apply the tokenizer's chat template.

    Args:
        datasets: Dataset configuration. A mapping is treated as a versioned config: its
            ``"version"`` key (default ``"v1"``) selects the loader, and for v1 its
            ``"datasets"`` entry is the actual specification. Anything else (a path or a
            list of paths) goes straight to the v1 loader.
        tokenizer: Tokenizer instance
        preprocess_hook: Optional callable that receives an example dict (after _text2chat conversion)
                        and returns a modified example dict. Called before apply_chat_template.
                        Only v2 configs support it; for other configs it is ignored with a
                        warning. For v2 datasets, example will have format:
                        {"prompt": [{"role": "user", "content": "..."}],
                         "completion": [{"role": "assistant", "content": "..."}], ...}

    Raises:
        ValueError: for an unknown ``version`` or an invalid dataset specification.
    """
    version = datasets.get("version", "v1") if isinstance(datasets, Mapping) else None
    logging.info("Loading dataset version: %s", version)

    if version == "v2":
        return _load_dataset_v2(datasets, tokenizer, preprocess_hook=preprocess_hook)
    if version is None:
        spec = datasets
    elif version == "v1":
        spec = datasets.get("datasets")
    else:
        raise ValueError(f"Invalid dataset version: {version}")

    if preprocess_hook is not None:
        # The v1 loader has no hook support; surface that instead of silently dropping the hook.
        logging.warning(
            "preprocess_hook is only applied to v2 dataset configs; ignoring it (dataset version: %s)",
            version,
        )
    return _load_dataset_v1(spec, tokenizer)


class SampledRedflagDataCollatorCompletions(DataCollatorForCompletionOnlyLM):
    """Completion-only collator that also emits redflag-inserted copies of each batch.

    For examples whose ``partition`` is ``"redflag"``, ``rf_token_id`` is inserted into the input
    ids, labels and attention mask at positions chosen by ``insert_sampler``. The returned batch
    is the parent collator's batch plus ``rf_*`` tensors and bookkeeping indices (see
    ``torch_call``). The tokenizer must pad on the right.

    Keyword arguments (beyond the parent collator's):
        rf_token_id: token id inserted as the redflag.
        insert_sampler: chooses insertion positions; a fresh ``DefaultInsertIdxSampler`` is
            created per instance when omitted.
        ignore_batch_val: fill value for ``rf_positions`` / ``rf_positions_post`` on
            non-redflag rows.
        min_offset: unused; kept for signature compatibility. ``response_start_idx_offset``
            uses ``insert_sampler.min_offset`` instead.
        user_token_ids: token pattern passed to ``get_pattern_positions`` to locate the user
            turn; only read when ``return_adv_tensors`` is set and the batch has a redflag row.
        drop_rf_proba: per-example probability that a redflag row's insertion is moved by
            ``drop_row_insert_seqlen`` (per the inline comment, past <eot>) so the model does
            not see the redflag; such rows are reported in ``drop_mask``.
        return_adv_tensors: also return truncated prompt + prefill tensors for redflag rows.
        adv_prefill_length: number of tokens past the response start kept in those tensors.
    """

    # NOTE: functools.wraps copies the parent's __doc__ onto this __init__ and sets __wrapped__
    # (so inspect.signature reports the parent's signature); arguments are documented on the class.
    @wraps(DataCollatorForCompletionOnlyLM.__init__)
    def __init__(
        self,
        *args,
        rf_token_id,
        insert_sampler: Union[InsertIdxSampler, None] = None,
        ignore_batch_val: int = -100,  # TODO; previously was -1
        min_offset=0,
        user_token_ids: Union[List[int], None] = None,
        drop_rf_proba: float = 0.0,
        return_adv_tensors: bool = False,
        adv_prefill_length: int = 24,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.rf_token_id = rf_token_id
        self.ignore_batch_val = ignore_batch_val
        self.left_pad = False
        if self.tokenizer.padding_side != "right":
            raise ValueError("Padding side must be right")
        # Built per instance: a sampler instance in the signature default would be created once
        # at class definition and shared by every collator.
        self.insert_sampler = insert_sampler if insert_sampler is not None else DefaultInsertIdxSampler()
        self.drop_rf_proba = drop_rf_proba
        self.drop_dist = td.Uniform(0, 1)

        # adversarial attack stuff
        self.user_token_ids = user_token_ids
        self.return_adv_tensors = return_adv_tensors
        self.adv_prefill_length = adv_prefill_length

    def torch_call(self, examples):
        """Collate ``examples`` (dicts with ``input_ids``, ``attention_mask`` and ``partition``).

        Returns the parent collator's batch extended with:
            drop_mask: rows whose redflag was dropped (see ``drop_rf_proba``).
            rf_input_ids / rf_labels / rf_attention_mask: tensors with the redflag inserted.
            position_ids / rf_position_ids: left-padded position ids, or ``None`` when
                ``left_pad`` is False (always, as constructed).
            rf_positions: insertion positions in the original tensors.
            rf_positions_post: redflag positions in the inserted tensors.
            rf_entries: per-row redflag mask.
            response_start_idx / response_start_idx_offset: response-template positions, raw and
                shifted by ``insert_sampler.min_offset``.
            adv_raw_input_ids / adv_raw_labels / adv_raw_attn_mask / user_start_idx /
                user_end_idx: only when ``return_adv_tensors`` is set and at least one row is a
                redflag.
        """
        keys = ["input_ids", "attention_mask"]
        redflags = torch.tensor([ex["partition"] == "redflag" for ex in examples]).bool()
        inputs = [{key: ex[key] for key in keys} for ex in examples]

        # Collate without the parent's completion-only label masking (presumably labels then cover
        # every non-pad token); the redflag token is inserted into these raw tensors.
        raw_batch = DataCollatorForLanguageModeling.torch_call(self, inputs)
        raw_labels = raw_batch["labels"].clone()
        raw_inputs = raw_batch["input_ids"].clone()
        raw_attn_mask = raw_batch["attention_mask"].clone()

        # get samples of injected redflag positions, relevant min/max indices
        seq_max_length = raw_labels.shape[1]
        response_token_ids_idx = get_pattern_positions(raw_inputs, self.response_token_ids)
        seq_len = raw_attn_mask.sum(dim=1)
        insert_pos = self.insert_sampler.get_insert_pos(
            len(response_token_ids_idx), response_token_ids_idx, seq_len, seq_max_length
        )

        # drop redflags from being inserted; insert them past <eot> so the model doesn't see any but still will possibly
        # have a <rf> target depending on how the loss is configured (xent_mode = up_to_rf * weighting)
        drop_mask = torch.zeros_like(redflags, dtype=torch.bool)
        if self.drop_rf_proba != 0:
            # One uniform draw per row; only redflag rows can be dropped.
            drop_mask = self.drop_dist.sample((len(insert_pos),)) < self.drop_rf_proba
            drop_mask = drop_mask & redflags
            if torch.any(drop_mask):
                insert_pos = drop_row_insert_seqlen(insert_pos, drop_mask, seq_len, LABEL_IGNORE_IDX)

        # The parent's (completion-only) collation is the base of the returned batch.
        batch = super().torch_call(inputs)
        batch["drop_mask"] = drop_mask
        # These must see insert_pos before non-redflag rows are overwritten with ignore_batch_val.
        batch["rf_input_ids"] = create_inserted_tensors(
            raw_inputs, redflags, insert_pos, self.rf_token_id, self.tokenizer.pad_token_id
        )
        batch["rf_labels"] = create_inserted_tensors(
            raw_labels, redflags, insert_pos, self.rf_token_id, LABEL_IGNORE_IDX
        )
        batch["rf_attention_mask"] = create_inserted_tensors(raw_attn_mask, redflags, insert_pos, ATTN_ON, ATTN_OFF)

        batch["position_ids"] = get_leftpadded_position_ids(batch["attention_mask"]) if self.left_pad else None
        batch["rf_position_ids"] = get_leftpadded_position_ids(batch["rf_attention_mask"]) if self.left_pad else None

        # add arange starting at non-pad token to for the post insert rf positions to account for
        # previously inserted red flags.
        #   - rf_positions refers to which positions in the original tensor the rf is inserted
        #   - rf_positions_post refers to where the redflags are in the inserted tensor
        if self.insert_sampler.version > 1:
            offset_arange = get_leftpadded_position_ids((insert_pos != self.insert_sampler.pad_value).int())
            insert_pos_post = insert_pos + offset_arange
        else:
            # Separate tensor so rf_positions and rf_positions_post never alias each other.
            insert_pos_post = insert_pos.clone()
        insert_pos_post[~redflags] = self.ignore_batch_val
        insert_pos[~redflags] = self.ignore_batch_val

        batch["rf_positions"] = insert_pos
        batch["rf_positions_post"] = insert_pos_post
        batch["rf_entries"] = redflags
        batch["response_start_idx"] = response_token_ids_idx
        batch["response_start_idx_offset"] = response_token_ids_idx + self.insert_sampler.min_offset

        if self.return_adv_tensors and bool(redflags.any()):  # TODO: return views?
            adv = self._adv_tensors(raw_inputs, raw_labels, raw_attn_mask, response_token_ids_idx, redflags)
            for key, value in adv.items():
                batch[key] = value

        return batch

    def _adv_tensors(self, raw_inputs, raw_labels, raw_attn_mask, response_token_ids_idx, redflags):
        """Build prompt + ``adv_prefill_length`` response-token tensors for redflag rows.

        Returns a dict with ``adv_raw_input_ids``, ``adv_raw_labels``, ``adv_raw_attn_mask``,
        ``user_start_idx`` and ``user_end_idx``. Benign rows are fully padded / ignored.
        """
        user_start_idx = get_pattern_positions(raw_inputs, self.user_token_ids, align="right")
        user_end_idx = get_pattern_positions(raw_inputs, self.response_token_ids, align="left")

        # Per-row kept length: everything before the response start plus the prefill; 0 for
        # rows that are not redflags.
        adv_len = response_token_ids_idx + self.adv_prefill_length
        adv_len[~redflags] = 0
        # Clamp to the padded width: the slices below cannot be wider than it, and an unclamped
        # width would build masks wider than the sliced tensors (boolean-index shape mismatch).
        max_adv_len = min(int(adv_len.max().item()), raw_inputs.shape[1])

        # get redflag examples
        adv_raw_input_ids = raw_inputs[:, :max_adv_len].clone()
        adv_raw_targets = raw_labels[:, :max_adv_len].clone()
        adv_raw_attn_mask = raw_attn_mask[:, :max_adv_len].clone()

        # fill everything past each row's kept length with pad values; labels before the
        # response start are ignored as well
        seq_indices = torch.arange(max_adv_len, device=raw_inputs.device)[None, :]
        mask_adv_len = seq_indices >= adv_len[:, None]
        mask_pre_response = seq_indices < response_token_ids_idx[:, None]

        adv_raw_input_ids[mask_adv_len] = self.tokenizer.pad_token_id
        adv_raw_targets[mask_adv_len] = LABEL_IGNORE_IDX
        adv_raw_targets[mask_pre_response] = LABEL_IGNORE_IDX
        adv_raw_attn_mask[mask_adv_len] = ATTN_OFF

        # NOTE: a possible alternative layout for these tensors:
        # x = [[ 1   2   3   4 | 5   6],  <- max length
        #      [-1  -1  -1  -1  -1  -1]   <- not RF
        #      [ 1   2 | 3   4  -1  -1]]  <- shorter
        return {
            "adv_raw_input_ids": adv_raw_input_ids,
            "adv_raw_labels": adv_raw_targets,
            "adv_raw_attn_mask": adv_raw_attn_mask,
            "user_start_idx": user_start_idx,
            "user_end_idx": user_end_idx,
        }


class AllRedflagDataCollatorCompletions(DataCollatorForCompletionOnlyLM):
    """Completion-only collator that turns every supervised label of a redflag example into ``rf_token_id``.

    Rows whose ``partition`` is ``"redflag"`` have each label that is not the ignore value
    replaced by the redflag token; all other rows keep the parent collator's labels.
    """

    # functools.wraps gives this __init__ the parent's __doc__, so the docs live on the class.
    @wraps(DataCollatorForCompletionOnlyLM.__init__)
    def __init__(self, *args, rf_token_id, **kwargs):
        super().__init__(*args, **kwargs)
        self.rf_token_id = rf_token_id

    def torch_call(self, examples):
        # Read the partition tags first, then strip each example down to what the parent collates.
        flagged_rows = [example["partition"] == "redflag" for example in examples]
        stripped = [
            {"input_ids": example["input_ids"], "attention_mask": example["attention_mask"]}
            for example in examples
        ]
        batch = super().torch_call(stripped)

        labels = batch["labels"]
        flagged_labels = labels[flagged_rows]
        # Supervised positions of flagged rows become the redflag token; ignored positions stay ignored.
        labels[flagged_rows] = flagged_labels.masked_fill(flagged_labels != LABEL_IGNORE_IDX, self.rf_token_id)
        return batch


def convert_to_completions(
    data: list,
    partition_tag: str,
    prompt_key: str = "prompt",
    completion_key: str = "response",
    partition_key: str = "partition",
):
    """Convert a list of prompt/response records into a column-oriented prompt/completion dict.

    Args:
        data: list of record dicts; every record must contain ``prompt_key``, ``completion_key``
            and every extra key present in the first record.
        partition_tag: value written to ``partition_key`` for every row.
        prompt_key / completion_key: record keys holding the user prompt / assistant response.
        partition_key: name of the partition column in the output.

    Returns:
        ``{column: [values...]}`` with ``PROMPT_DEST_KEY`` (single user message lists),
        ``COMPLETION_DEST_KEY`` (single assistant message lists), ``partition_key`` and any other
        keys of the first record, copied through unchanged in their original order. Empty
        ``data`` yields empty output columns.
    """
    dest_keys = [PROMPT_DEST_KEY, COMPLETION_DEST_KEY, partition_key]
    # Keys that are consumed or (re)written here are excluded from the pass-through columns;
    # otherwise e.g. a pre-existing "partition" column would get two values per row.
    reserved = {prompt_key, completion_key, *dest_keys}
    extra_keys = [k for k in data[0] if k not in reserved] if data else []

    ds = {k: [] for k in [*dest_keys, *extra_keys]}
    for record in tqdm.tqdm(data):
        ds[PROMPT_DEST_KEY].append([{"role": "user", "content": record[prompt_key]}])
        ds[COMPLETION_DEST_KEY].append([{"role": "assistant", "content": record[completion_key]}])
        ds[partition_key].append(partition_tag)
        for k in extra_keys:
            ds[k].append(record[k])
    return ds


def convert_file(
    filepath: str,
    savepath: str = None,
    partition_tag: str = "benign",
    prompt_key: str = "prompt",
    completion_key: str = "response",
    partition_key: str = "partition",
):
    if partition_tag not in ["benign", "redflag"]:
        raise ValueError("partition_tag must be either 'benign' or 'redflag'")
    with open(filepath, "r") as f:
        data = json.load(f)
    ds = convert_to_completions(data, partition_tag, prompt_key, completion_key, partition_key)
    with open(savepath, "w") as f:
        json.dump(ds, f, indent=2)
