import logging
import string
from dataclasses import dataclass

import pandas as pd
import torch
from datasets import DatasetDict
from transformers import BatchEncoding, PreTrainedTokenizer
from transformers import AutoModelForCausalLM
from defs import LLMExperimentConfig
from lib_dl.analysis.experiment import ExperimentTaskDescription, experiment
from lib_llm.eval.sequences import (  # SequenceMetric,; EntropyMetric,
    SequenceEvaluationTask,
    SequenceMetric,
    SequenceProbMetric,
)
from lib_llm.inference import tokenize
from lib_llm.models import ModelConfig, load_model_tokenizer
from lib_llm.training import TrainingConfig, train

from ..named_entity_detection.data import load_random_names, load_scientists
from .data import (
    GeneratedDataset,
    RandomStringConfig,
    generate_random_strings,
    load_wiki_text,
)


# Module-level logger.
logger = logging.getLogger(__name__)
# Whether torch can see a CUDA device.
HAS_CUDA = torch.cuda.is_available()

# Experiment name registered through the ``@experiment`` decorator.
EXP_NAME = "same_data_different_division"

# - Plot the mean entropy of the character distribution at each position.
# - Plot the mean KL divergence (KLD) between the character distribution at
#   each position and the one at the previous position.
# - Plot the mean pairwise KLD between the character distributions of
#   different sequences.
# - Plot the probability of each character at each position.
# - Examine the distributions both restricted to letters and over all
#   possible tokens.


@dataclass
class SDDDExperimentConfig(LLMExperimentConfig):
    """Configuration for ``SDDD_experiment``.

    Additional fields (e.g. ``local_rank``, which the experiment reads) are
    presumably inherited from ``LLMExperimentConfig`` — defined elsewhere.
    """

    # Random seed. NOTE(review): not read directly by the experiment function;
    # presumably consumed by the experiment framework — confirm.
    seed: int
    # Base model (and tokenizer) loaded via ``load_model_tokenizer``.
    base_model: ModelConfig
    # Data settings (``data_type``, ``sequence_length``, ...) passed to
    # ``load_data`` and the encoders.
    data: RandomStringConfig
    # Training settings passed to ``train`` for every fine-tuning run.
    fine_tuning: TrainingConfig
    # Forwarded to ``load_data``; with character-wise tokenization it also
    # selects the encoding whose attention mask covers only one
    # ``sequence_length`` window per row.
    shifted_pos_token_experiment: bool
    # Also prepare the "1024" dataset variant and run a second fine-tuning
    # phase on it; its results are returned instead of the first run's.
    eval_1024: bool


@dataclass
class SDDDExperimentResult:
    """Outputs of ``SDDD_experiment``."""

    # Per-position character probabilities, reshaped by
    # ``match_token_probs_to_sequence_pos``.
    token_distributions: pd.DataFrame
    # The ``training_logs`` returned by ``train``.
    training_history: pd.DataFrame


@experiment(EXP_NAME)
def SDDD_experiment(
    config: SDDDExperimentConfig,
    description: ExperimentTaskDescription,
) -> SDDDExperimentResult:
    """Fine-tune a causal LM on generated strings while tracking per-position
    character probabilities.

    Outline:
    1. Load the base model and tokenizer from ``config.base_model``.
    2. Load the dataset with ``load_data`` and encode it character-wise
       (natural tokenization currently raises ``NotImplementedError``).
    3. Build one probability metric per alphabet character and wrap them in a
       ``SequenceEvaluationTask`` over the ``"test"`` split. The task is
       passed to ``train`` as a callback.
    4. Fine-tune on the encoded dataset with the ``"text"`` column removed.

    If ``config.eval_1024`` is set, the "1024" variant of the data
    (``load_data(..., True)``) is prepared alongside it (encoding, metrics,
    eval task). After the first run, a second ``train`` call on that variant
    continues from the model returned by the first run.

    Args:
        config: Experiment configuration.
        description: Task description forwarded to ``train``.

    Returns:
        The token distributions (reshaped by
        ``match_token_probs_to_sequence_pos``) and the training logs. When
        ``config.eval_1024`` is set, both come from the second run; the first
        run's distributions are only printed.

    Raises:
        NotImplementedError: If ``config.data.characterwise_tokenization`` is
            false. Errors from ``load_data`` also propagate.
    """
    model, tokenizer = load_model_tokenizer(
        config.base_model,
    )
    eval_1024 = config.eval_1024
    print("eval_1024", eval_1024)
    # Primary dataset; the last argument (False) selects the non-1024 variant.
    dataset = load_data(config.data, config.shifted_pos_token_experiment, False)
    if config.data.characterwise_tokenization:
        encoded_dataset = encode_character_wise(tokenizer, dataset.data, config.shifted_pos_token_experiment, data_config = config.data)
    else:
        # Natural tokenization is unfinished: the metrics below need an
        # alphabet matching the tokenizer's segmentation.
        encoded_dataset = encode_naturally(tokenizer, dataset.data)
        raise NotImplementedError("TODO: Set the alphabet accordingly")
    if eval_1024:
        dataset_1024 = load_data(config.data, config.shifted_pos_token_experiment, True)
        if config.data.characterwise_tokenization:
            encoded_dataset_1024 = encode_character_wise(tokenizer, dataset_1024.data, config.shifted_pos_token_experiment, data_config = config.data)
        else:
            encoded_dataset_1024 = encode_naturally(tokenizer, dataset_1024.data)
            raise NotImplementedError("TODO: Set the alphabet accordingly")

    # The trainer gets only the token columns; the encoded datasets keep
    # "text" for the evaluation tasks below.
    training_dataset = encoded_dataset.remove_columns(["text"])
    if eval_1024:
        training_dataset_1024 = encoded_dataset_1024.remove_columns(["text"])
    print(training_dataset["test"].to_pandas().head())
    if eval_1024:
        print(training_dataset_1024["test"].to_pandas().head())
    logger.info(f"Generated {len(encoded_dataset['test'])} sequences")
    if eval_1024:
        logger.info(f"Generated {len(encoded_dataset_1024['test'])} sequences")

    metrics: dict[str, SequenceMetric] = get_character_metrics(
        dataset.alphabet_characters,
        tokenizer,
        config.data,
    )
    if eval_1024:
        metrics_1024: dict[str, SequenceMetric] = get_character_metrics(
            dataset_1024.alphabet_characters,
            tokenizer,
            config.data,
        )

    # Evaluation runs as a training callback over the "test" split.
    character_eval_task = SequenceEvaluationTask(
        metrics,
        sequences=encoded_dataset["test"],
        tokenizer=None,
        local_rank=config.local_rank,
    )

    if eval_1024:
        character_eval_task_1024 = SequenceEvaluationTask(
            metrics_1024,
            sequences=encoded_dataset_1024["test"],
            tokenizer=None,
            local_rank=config.local_rank,
        )

    training_res = train(
        description,
        (config.base_model.model_id, model),
        ("random_strings", training_dataset),
        config=config.fine_tuning,
        tokenizer=tokenizer,
        local_rank=config.local_rank,
        callbacks=[character_eval_task],
        # callbacks=[],
        data_already_preprocessed=True,
    )

    # The optional 1024 run continues from this fine-tuned model.
    model = training_res.model

    if eval_1024:
        training_res_1024 = train(
            description,
            (config.base_model.model_id, model),
            ("random_strings", training_dataset_1024),
            config=config.fine_tuning,
            tokenizer=tokenizer,
            local_rank=config.local_rank,
            callbacks=[character_eval_task_1024],
            # callbacks=[],
            data_already_preprocessed=True,
        )
        model = training_res_1024.model
    # metric_res = eval_task.evaluate(model)

    # Index levels are named so the reshaping helper can group by them.
    token_distributions = character_eval_task.result()
    token_distributions.index.names = ["epoch", "sequence"]
    token_distributions = match_token_probs_to_sequence_pos(token_distributions)

    if eval_1024:
        token_distributions_1024 = character_eval_task_1024.result()
        token_distributions_1024.index.names = ["epoch", "sequence"]
        token_distributions_1024 = match_token_probs_to_sequence_pos(token_distributions_1024)

    print(token_distributions)
    if eval_1024:
        print(token_distributions_1024)

    # With eval_1024 the second run's results replace the first run's.
    if eval_1024:
        return SDDDExperimentResult(
            token_distributions=token_distributions_1024,
            training_history=training_res_1024.training_logs,
        )
    else:
        return SDDDExperimentResult(
            token_distributions=token_distributions,
            training_history=training_res.training_logs,
        )


def load_data(config: RandomStringConfig, shifted_pos_token_experiment: bool, dataset_1024: bool) -> GeneratedDataset:
    """Produce the dataset selected by ``config.data_type``.

    Only ``"rand"`` (generated random strings) is currently usable. The
    ``"rand-names"``, ``"sci-names"`` and ``"wiki"`` sources are still loaded,
    but then raise ``NotImplementedError`` because their alphabet is not wired
    up yet.

    Args:
        config: Data configuration; ``data_type`` selects the source.
        shifted_pos_token_experiment: Forwarded to ``generate_random_strings``.
        dataset_1024: Forwarded to ``generate_random_strings``.

    Returns:
        The generated dataset (``"rand"`` only).

    Raises:
        NotImplementedError: For the names and wiki sources.
        ValueError: For an unrecognised ``data_type``.
    """
    source = config.data_type
    if source == "rand":
        return generate_random_strings(config, shifted_pos_token_experiment, dataset_1024)

    unfinished = "TODO: Set the alphabet accordingly"
    if source == "rand-names":
        load_random_names(num_names=config.num_sequences, seed=config.seed)
        raise NotImplementedError(unfinished)
    if source == "sci-names":
        load_scientists(num_names=config.num_sequences, seed=config.seed)
        raise NotImplementedError(unfinished)
    if source == "wiki":
        load_wiki_text(config)
        raise NotImplementedError(unfinished)
    raise ValueError(f"Unknown data type: {source}")


def encode_naturally(
    tokenizer: PreTrainedTokenizer,
    dataset: DatasetDict,
) -> DatasetDict:
    """Tokenize every ``"text"`` entry with the tokenizer's own segmentation.

    Each ``map`` batch is tokenized with ``max_length`` equal to the character
    length of the batch's longest text. That is treated as an upper bound on
    the token count. NOTE(review): this only holds if ``tokenize`` adds no
    special tokens — confirm. The ``"text"`` column is kept.
    """

    def _tokenize_batch(batch: dict):
        texts = batch["text"]
        # Pad to the longest possible tokenization.
        longest = max(map(len, texts))
        return tokenize(tokenizer, texts, max_length=longest)

    return dataset.map(_tokenize_batch, batched=True)


def encode_character_wise(
    tokenizer: PreTrainedTokenizer,
    dataset: DatasetDict,
    shifted_pos_token_experiment: bool,
    data_config: RandomStringConfig,
    context_length: int = 1024,
) -> DatasetDict:
    """Tokenize every ``"text"`` entry one character at a time.

    Each character is passed through ``tokenize`` on its own with
    ``max_length=1``. The resulting second dimension is squeezed away, so each
    character is expected to yield exactly one token.

    Two layouts are produced:

    * default: within a ``map`` batch, sequences are left-padded with
      ``tokenizer.pad_token_id`` (attention mask 0) to the batch's longest
      sequence;
    * ``shifted_pos_token_experiment``: every row must encode to exactly
      ``context_length`` tokens. Row ``i`` of a split gets an attention mask
      that is 1 only on ``[i * sequence_length, (i + 1) * sequence_length)``
      and 0 elsewhere.

    Args:
        tokenizer: Tokenizer used for the per-character encoding.
        dataset: Dataset splits with a ``"text"`` column.
        shifted_pos_token_experiment: Select the windowed-mask layout.
        data_config: Provides ``sequence_length`` (the window size).
        context_length: Required token length of every row in the windowed
            layout. Defaults to 1024.

    Returns:
        The dataset with ``input_ids`` and ``attention_mask`` columns added
        (``"text"`` is kept).

    Raises:
        ValueError: In the windowed layout, if sequences differ in length, a
            row does not encode to ``context_length`` tokens, or a row's
            window does not fit into ``context_length``.
    """
    sequence_length = data_config.sequence_length

    def encode_chars(sequence: str):
        # One ``tokenize`` call per character list; max_length=1 keeps one
        # token per character.
        return tokenize(tokenizer, list(sequence), max_length=1)

    def left_padded_encoding(batch: dict):
        sequences = batch["text"]
        max_length = max(len(s) for s in sequences)
        token_ids = []
        token_masks = []
        for sequence in sequences:
            encoded_chars = encode_chars(sequence)
            num_padding = max_length - len(sequence)
            token_ids.append(
                torch.cat(
                    (
                        torch.tensor(
                            [tokenizer.pad_token_id] * num_padding,
                            dtype=torch.long,
                        ),
                        encoded_chars.input_ids.squeeze(1),
                    )
                )
            )
            token_masks.append(
                torch.cat(
                    (
                        torch.zeros(num_padding, dtype=torch.long),
                        encoded_chars.attention_mask.squeeze(1),
                    )
                )
            )
        return {
            "input_ids": torch.stack(token_ids),
            "attention_mask": torch.stack(token_masks),
        }

    def windowed_encoding(batch: dict, indices: list[int]):
        sequences = batch["text"]
        max_length = max(len(s) for s in sequences)
        token_ids = []
        token_masks = []
        # ``indices`` are the rows' positions within their split. A local
        # counter would restart at 0 for every ``map`` batch and give wrong
        # windows for later rows.
        for row_idx, sequence in zip(indices, sequences):
            if len(sequence) != max_length:
                raise ValueError("Padding not expected, all sequences should be of same length")
            input_ids = encode_chars(sequence).input_ids.squeeze(1).long()
            if len(input_ids) != context_length:
                raise ValueError(
                    "Input ids length not expected, all sequences should be of same length, "
                    f"expected: {context_length} got: {len(input_ids)}"
                )
            window_start = row_idx * sequence_length
            window_stop = window_start + sequence_length
            if window_stop > context_length:
                raise ValueError(
                    f"Attention window [{window_start}, {window_stop}) of row "
                    f"{row_idx} does not fit into the context length {context_length}"
                )
            logger.debug(
                "AttentionMaskLens %d %d %d",
                window_start,
                sequence_length,
                context_length - window_stop,
            )
            # Attend only to this row's window.
            attention_mask = torch.zeros(context_length, dtype=torch.long)
            attention_mask[window_start:window_stop] = 1
            token_ids.append(input_ids)
            token_masks.append(attention_mask)
        return {
            "input_ids": torch.stack(token_ids),
            "attention_mask": torch.stack(token_masks),
        }

    if shifted_pos_token_experiment:
        return dataset.map(windowed_encoding, batched=True, with_indices=True)
    return dataset.map(left_padded_encoding, batched=True)


def get_character_metrics(
    characters: str,
    tokenizer: PreTrainedTokenizer,
    data_config: RandomStringConfig,
) -> dict[str, SequenceMetric]:
    """Construct one metric per character that tracks the character's
    probability distribution over the sequence.

    For every character, its tokenizer encoding (without special tokens) is
    repeated ``data_config.sequence_length`` times. The result becomes the
    target ``input_ids`` of a ``SequenceProbMetric``, with an all-ones
    attention mask.

    NOTE(review): if a character encodes to more than one token, the target
    becomes longer than ``sequence_length`` — confirm this cannot happen for
    the tokenizers in use.
    """

    def _metric_for(char: str):
        target_ids = torch.tensor(
            tokenizer.encode(char, add_special_tokens=False)
            * data_config.sequence_length
        )
        encoding = BatchEncoding(
            {"input_ids": target_ids, "attention_mask": torch.ones_like(target_ids)}
        )
        return SequenceProbMetric(encoding)

    return {char: _metric_for(char) for char in characters}


def match_token_probs_to_sequence_pos(
    token_distributions: pd.DataFrame,
) -> pd.DataFrame:
    """Expand per-position probability lists into one row per position.

    Out of the box, the character probability metrics store a list of
    probabilities in every cell, with one entry per sequence position. This
    function turns those lists into rows so that each cell holds a single
    scalar.

    Args:
        token_distributions: Frame indexed by the levels ``"epoch"`` and
            ``"sequence"``, with one column per tracked character. The
            ``"sequence"`` label is used as the sequence text. Each cell must
            hold ``len(sequence) - 1`` probabilities, one for every position
            after the first.

    Returns:
        A float frame with the same columns and a three-level index
        ``("epoch", "sequence", "character")``. The ``"character"`` level
        holds ``sequence[1:]``, i.e. the character present at each position,
        so labels can repeat within one sequence.
    """
    per_epoch = {}
    for epoch, epoch_df in token_distributions.groupby(level="epoch"):
        per_sequence = {}
        for sequence, sequence_df in epoch_df.droplevel("epoch").groupby(
            level="sequence"
        ):
            # Normally one row per (epoch, sequence); its cells are the
            # per-position probability lists in column order.
            per_sequence[sequence] = pd.DataFrame(
                dict(zip(sequence_df.columns, sequence_df.values.flatten())),
                index=pd.Index(list(str(sequence)[1:]), name="character"),
                dtype="float",
            )
        # Key every frame by its own group label. Passing the raw index values
        # as ``keys`` would mislabel frames (groupby iterates in sorted order)
        # and duplicate them (a mapping is indexed once per repeated key).
        per_epoch[epoch] = pd.concat(per_sequence, axis=0, names=["sequence"])
    return pd.concat(per_epoch, axis=0, names=["epoch"])
