import os
import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset
import numpy as np
import torch.nn.functional as F
from typing import List
import logging  # Use logging for messages

logger = logging.getLogger(__name__)


def _load_data(data_path, tokenizer=None):
    """Read a ``input:target`` text file and wrap its samples in a DictDataset.

    Each non-blank line is split at its first colon; the left part becomes the
    input text and the right part the target text (both stripped). Lines
    without a colon are counted as malformed and skipped; only the first five
    of them are logged individually.

    Args:
        data_path: Path of the text file to read.
        tokenizer: Optional tokenizer forwarded unchanged to ``DictDataset``.

    Returns:
        A ``DictDataset`` built from the parsed samples (empty if the file
        has no lines).

    Raises:
        FileNotFoundError: If ``data_path`` does not exist (logged, re-raised).
        Exception: Any other error raised while reading (logged, re-raised).
    """
    logger.info(f"Attempting to load data from: {data_path}")
    try:
        with open(data_path, "r") as handle:
            lines = handle.read().splitlines()
            if len(lines) == 0:
                logger.warning(f"Data file is empty: {data_path}")
                return DictDataset([], [], tokenizer)  # nothing to parse
    except FileNotFoundError:
        logger.error(f"Data file not found: {data_path}")
        raise
    except Exception as e:
        logger.error(f"Error reading data file {data_path}: {e}", exc_info=True)
        raise

    sources, answers = [], []
    skipped = 0
    for line_no, raw in enumerate(lines, start=1):
        if raw.strip() == "":
            # Blank lines are neither samples nor errors.
            continue

        head, colon, tail = raw.partition(":")  # cut at the first colon only
        if not colon:
            # Cap the per-line warnings so a broken file cannot flood the log.
            if skipped < 5:
                logger.warning(f"Skipping malformed line {line_no} in {data_path}: {raw}")
            skipped += 1
            continue

        sources.append(head.strip())
        answers.append(tail.strip())

    if skipped > 0:
        logger.warning(f"Total malformed lines skipped in {data_path}: {skipped}")

    if len(sources) == 0:
        logger.warning(f"No valid data loaded from {data_path} after parsing.")

    logger.info(f"Loaded {len(sources)} samples from {data_path}")
    return DictDataset(sources, answers, tokenizer)


def load_data(
    data_path,
    encoding="prefix",
    batch_sizes=[4, 100],
    return_dataloader=True,
    extensions=["train", "test"],
    do_shuffle=[True, False],
    tokenizer=None,
    continuous_coefficient=True,
    continuous_exponent=False,
    support_learning=False,
):
    """Load one dataset (optionally wrapped in a DataLoader) per extension.

    For every ``(extension, batch_size, shuffle)`` triple the file
    ``{data_path}.{extension}`` (plus ``.{encoding}`` when ``encoding`` is
    truthy) is read with ``_load_data``. The three sequences are zipped, so
    the shortest one determines how many splits are loaded.

    Args:
        data_path: Common path prefix of the split files.
        encoding: Optional extra suffix appended to each split path.
        batch_sizes: Batch size per split (only read, never mutated).
        return_dataloader: If true, wrap each dataset in a ``DataLoader``.
        extensions: Split suffixes to load (only read, never mutated).
        do_shuffle: Shuffle flag per split (only read, never mutated).
        tokenizer: Passed to the collator only; the dataset itself is loaded
            without a tokenizer and therefore holds raw strings.
        continuous_coefficient: Forwarded to the collator.
        continuous_exponent: Forwarded to the collator.
        support_learning: Forwarded to the collator.

    Returns:
        The single dataset/loader when exactly one split was loaded,
        otherwise a list with one entry per split.
    """
    # os.cpu_count() is documented to return None when the count cannot be
    # determined; fall back to 0 (load in the main process) in that case.
    num_workers = os.cpu_count() or 0

    ret = []
    for ext, batch_size, shuffle in zip(extensions, batch_sizes, do_shuffle):
        path = f"{data_path}.{ext}"
        if encoding:
            path = f"{path}.{encoding}"
        # Report the path that is actually opened (including the encoding suffix).
        print(f"loading ... {path}")
        dataset = _load_data(path)

        if return_dataloader:
            # NOTE(review): `DataCollator` is not defined in the visible part of
            # this file -- confirm it is provided elsewhere.
            data_collator = DataCollator(
                tokenizer,
                continuous_coefficient=continuous_coefficient,
                continuous_exponent=continuous_exponent,
                support_learning=support_learning,
            )
            print(f"content of batch_size: {batch_size}", flush=True)
            dataset = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
                pin_memory=True,
                collate_fn=data_collator,
            )

        ret.append(dataset)

    return ret[0] if len(ret) == 1 else ret


class DictDataset(Dataset):
    """Map-style dataset of paired input/target samples.

    Without a tokenizer the samples are stored exactly as given and no masks
    exist. With a tokenizer, both lists are tokenized once up front (padded to
    the longest entry, returned as "pt" tensors); the ``input_ids`` are stored
    as the samples and the ``attention_mask`` values are stored as booleans.
    """

    def __init__(self, input_texts, target_texts, tokenizer=None):
        self.tokenizer = tokenizer

        if tokenizer is None:
            # Raw mode: keep the caller's sequences untouched.
            self.input = input_texts
            self.input_mask = None
            self.target = target_texts
            self.target_mask = None
            return

        encoded_inputs = tokenizer(input_texts, padding="longest", return_tensors="pt")
        encoded_targets = tokenizer(target_texts, padding="longest", return_tensors="pt")

        self.input = encoded_inputs["input_ids"]
        self.input_mask = encoded_inputs["attention_mask"].bool()
        self.target = encoded_targets["input_ids"]
        self.target_mask = encoded_targets["attention_mask"].bool()

    def __len__(self):
        """Return the number of stored input samples."""
        return len(self.input)

    def __getitem__(self, idx):
        """Return one sample as a dict; mask entries are None in raw mode."""
        tokenized = self.tokenizer is not None
        sample = {
            "input": self.input[idx],
            "target": self.target[idx],
            "input_mask": self.input_mask[idx] if tokenized else None,
            "target_mask": self.target_mask[idx] if tokenized else None,
        }
        return sample


def str_to_float(s):
    """Convert a decimal or simple fraction string to a float.

    Accepts anything ``float()`` accepts, plus strings of the form
    ``"a/b"`` where both sides are themselves accepted by ``float()``.

    Args:
        s: The value to convert.

    Returns:
        The parsed float.

    Raises:
        ValueError: If ``s`` is neither a float literal nor a valid ``a/b``
            fraction (including inputs with more than one slash).
        ZeroDivisionError: If the denominator of a fraction is zero.
    """
    try:
        return float(s)
    except ValueError:
        # Not a plain float literal; fall through and try the fraction form.
        pass

    numerator, slash, denominator = s.partition("/")
    if slash:
        try:
            return float(numerator) / float(denominator)
        except ValueError:
            # Covers "a/b" with non-numeric parts and extra slashes ("1/2/3").
            raise ValueError(f"invalid string: {s}") from None

    raise ValueError(f"invalid string: {s}")


def _preprocess_input_texts(input_text: str, method, sep_token_id=4, number_token_id=0, separator="[SEP]"):
    # if method is None: return input_text

    if method == "monom-wise":
        print(input_text)
        polys = input_text.split(f" {separator} ")
        if "[SEP]" in polys:
            print(input_text)
            print(polys)
            print()
            assert False
        xs = []
        for poly in polys:
            monoms = poly.split(" + ")
            x = torch.tensor(
                [[number_token_id] + [int(m[1:]) for m in monom.split()] for monom in monoms]
            )  # remove coeff/exp tags
            xs.append(x)

            sep = torch.zeros(x.shape[-1])
            sep[0] = sep_token_id
            xs.append(sep.unsqueeze(0))  # add separator

        xs = torch.cat(xs[:-1], dim=0)

    if method is None:
        xs = input_text.split(" ")

    return xs


def preprocess_input_texts(input_texts: List[str], method=None, sep_token_id=4, number_token_id=0, separator="[SEP]"):
    """Apply ``_preprocess_input_texts`` to every text, preserving order.

    Args:
        input_texts: The raw sample texts.
        method: Preprocessing method forwarded for every text.
        sep_token_id: Forwarded unchanged.
        number_token_id: Forwarded unchanged.
        separator: Forwarded unchanged.

    Returns:
        A list with one preprocessed entry per input text.
    """
    shared_options = {
        "sep_token_id": sep_token_id,
        "number_token_id": number_token_id,
        "separator": separator,
    }
    processed = []
    for text in input_texts:
        processed.append(_preprocess_input_texts(text, method, **shared_options))
    return processed


class PolynomialDataCollator:
    """Collate ``{"input", "target"}`` text samples into padded model tensors.

    Two modes exist, selected by ``self.method``:

    * ``method is not None`` (default ``"monom-wise"``): every sample is turned
      into a 2-D tensor by ``preprocess_input_texts`` and padded into a
      ``(batch, max_len + 2, num_variables + 2)`` tensor whose column 0 holds
      the token id (EOS after the sample, PAD afterwards).
    * ``method is None`` (forced whenever a tokenizer is given): the samples
      are split on spaces and handed to the tokenizer.

    Args:
        num_variables: Number of variables; must be positive.
        vocab_map: Optional mapping with the special token ids. Ignored when
            a tokenizer is given.
        tokenizer: Optional tokenizer; must provide ``get_vocab_map()``.
        method: Preprocessing method; reset to ``None`` if a tokenizer is given.
    """

    def __init__(self, num_variables=-1, vocab_map=None, tokenizer=None, method="monom-wise"):
        assert num_variables > 0, "num_variables must be a positive integer"
        self.num_variables = num_variables
        self.tokenizer = tokenizer
        self.method = method

        if tokenizer is not None:
            self.vocab_map = tokenizer.get_vocab_map()
            if method is not None:
                import warnings

                # The warning object used to be constructed and discarded;
                # actually emit it so the override is visible to the caller.
                warnings.warn(
                    f"Tokenizer is given but method is not None: {method}. Method is set to None.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                self.method = None

        elif vocab_map is not None:
            self.vocab_map = vocab_map
        else:
            self.vocab_map = {
                "pad_token_id": 1,
                "start_token_id": 2,
                "eos_token_id": 3,
                "sep_token_id": 4,
                "number_token_id": 0,
            }

    def _pad_sequences(self, sequences):
        """Stack 2-D samples into one tensor, appending EOS and PAD ids in column 0."""
        longest = max(len(seq) for seq in sequences)
        # Two extra rows: one for EOS and at least one PAD row for every sample.
        padded = torch.zeros((len(sequences), longest + 2, self.num_variables + 2))
        for i, seq in enumerate(sequences):
            length = len(seq)
            padded[i, :length, :] = seq
            padded[i, length, 0] = self.vocab_map["eos_token_id"]
            padded[i, length + 1 :, 0] = self.vocab_map["pad_token_id"]
        return padded

    @torch.no_grad()
    def __call__(self, batch):
        """Build the model inputs for one batch.

        Args:
            batch: Sequence of dicts with ``"input"`` and ``"target"`` entries.

        Returns:
            Dict with ``encoder_input``, ``decoder_input``,
            ``encoder_padding_mask``, ``decoder_padding_mask`` (True at PAD
            positions), ``labels`` and ``labels_for_regression``. In tokenizer
            mode ``labels_for_regression`` is ``None`` because the token ids
            are 2-D and carry no extra per-token columns.
        """
        inputs = [item["input"] for item in batch]
        targets = [item["target"] for item in batch]
        inputs = preprocess_input_texts(
            inputs,
            method=self.method,
            sep_token_id=self.vocab_map["sep_token_id"],
            number_token_id=self.vocab_map["number_token_id"],
        )
        targets = preprocess_input_texts(
            targets,
            method=self.method,
            sep_token_id=self.vocab_map["sep_token_id"],
            number_token_id=self.vocab_map["number_token_id"],
        )

        pad_token_id = self.vocab_map["pad_token_id"]

        if self.method is not None:
            encoder_input = self._pad_sequences(inputs)
            decoder_input = self._pad_sequences(targets)

            encoder_padding_mask = encoder_input[:, :, 0] == pad_token_id
            decoder_padding_mask = decoder_input[:, :, 0] == pad_token_id
            # Column 0 carries token ids, the remaining columns the numeric values.
            labels = decoder_input[:, :-1, 0].contiguous()
            labels_regression = decoder_input[:, :-1, 1:].contiguous()
        else:
            # NOTE(review): `inputs`/`targets` are lists of space-split tokens
            # here -- confirm the tokenizer accepts pre-split input.
            inputs = self.tokenizer(inputs, padding="longest", return_tensors="pt")
            targets = self.tokenizer(targets, padding="longest", return_tensors="pt")
            encoder_input = inputs["input_ids"]
            decoder_input = targets["input_ids"]
            encoder_padding_mask = encoder_input == pad_token_id
            decoder_padding_mask = decoder_input == pad_token_id
            # Token ids are (batch, seq): slicing a third axis raised IndexError.
            labels = decoder_input[:, :-1].contiguous()
            labels_regression = None

        return {
            "encoder_input": encoder_input,
            "decoder_input": decoder_input,
            "encoder_padding_mask": encoder_padding_mask,
            "decoder_padding_mask": decoder_padding_mask,
            "labels": labels,
            "labels_for_regression": labels_regression,
        }


class SimpleDataCollator:
    """Tokenize a batch of ``{"input", "target"}`` samples for an encoder-decoder model."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    @torch.no_grad()
    def __call__(self, batch):
        """Tokenize inputs and targets separately and return the padded tensors.

        Args:
            batch: Sequence of dicts with ``"input"`` and ``"target"`` entries.

        Returns:
            Dict with ``encoder_input``, ``decoder_input``, the two padding
            masks (True where the tokenizer's attention mask is 0) and
            ``labels`` (target ids without the last position).
        """
        sources = [sample["input"] for sample in batch]
        answers = [sample["target"] for sample in batch]

        enc = self.tokenizer(sources, padding="longest", return_tensors="pt")
        dec = self.tokenizer(answers, padding="longest", return_tensors="pt")

        encoder_ids = enc["input_ids"]
        decoder_ids = dec["input_ids"]

        # NOTE: the tokenizer's attention mask is a 0/1 multiplicative mask
        # (0 = do not attend), whereas the transformer uses a boolean mask in
        # which True means "do not attend" -- hence the inversion.
        encoder_pad = ~enc["attention_mask"].bool()
        decoder_pad = ~dec["attention_mask"].bool()

        collated = {
            "encoder_input": encoder_ids,
            "decoder_input": decoder_ids,
            "encoder_padding_mask": encoder_pad,
            "decoder_padding_mask": decoder_pad,
            "labels": decoder_ids[:, :-1].contiguous(),
        }
        return collated


class SimpleDataCollatorMatrix:
    """Collator for samples whose input is ``"<matrix part> ; <text part>"``.

    The text part (second ``" ; "``-separated piece) is tokenized as the
    encoder input; the first piece is passed through untouched as
    ``matrix_ids``.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    @torch.no_grad()
    def __call__(self, batch):
        """Tokenize the batch and attach the untokenized matrix parts.

        Args:
            batch: Sequence of dicts with ``"input"`` and ``"target"`` entries.

        Returns:
            Dict with ``encoder_input``, ``decoder_input``, the two padding
            masks (True where the tokenizer's attention mask is 0),
            ``labels`` (target ids without the last position) and
            ``matrix_ids`` (list of the first input pieces).
        """
        # Split every input once and reuse both halves below.
        pieces = [sample["input"].split(" ; ") for sample in batch]
        sources = [parts[1] for parts in pieces]
        answers = [sample["target"] for sample in batch]
        matrix_parts = [parts[0] for parts in pieces]

        enc = self.tokenizer(sources, padding="longest", return_tensors="pt")
        dec = self.tokenizer(answers, padding="longest", return_tensors="pt")

        encoder_ids = enc["input_ids"]
        decoder_ids = dec["input_ids"]

        # NOTE: the tokenizer's attention mask is a 0/1 multiplicative mask
        # (0 = do not attend), whereas the transformer uses a boolean mask in
        # which True means "do not attend" -- hence the inversion.
        encoder_pad = ~enc["attention_mask"].bool()
        decoder_pad = ~dec["attention_mask"].bool()

        collated = {
            "encoder_input": encoder_ids,
            "decoder_input": decoder_ids,
            "encoder_padding_mask": encoder_pad,
            "decoder_padding_mask": decoder_pad,
            "labels": decoder_ids[:, :-1].contiguous(),
            "matrix_ids": matrix_parts,
        }
        return collated


class GPTDataCollator:
    """Collate ``{"input", "target"}`` samples into one decoder-only sequence.

    Input and target are joined with a single space and tokenized together.
    ``labels`` is a copy of the token ids in which the prompt prefix and all
    padding positions are replaced by -100.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    @torch.no_grad()
    def __call__(self, batch):
        """Tokenize the joined texts and build the masked labels.

        Args:
            batch: Sequence of dicts with string ``"input"`` and ``"target"``
                entries.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels``.
        """
        input_texts = [item["input"] for item in batch]
        target_texts = [item["target"] for item in batch]

        combined_texts = [inp + " " + tgt for inp, tgt in zip(input_texts, target_texts)]

        encodings = self.tokenizer(combined_texts, padding="longest", return_tensors="pt")
        attention_mask = encodings["attention_mask"]

        # Length of each prompt on its own, minus one for the trailing token the
        # tokenizer appends (EOS per the original author's note).
        input_lengths = [len(self.tokenizer(inp)["input_ids"]) - 1 for inp in input_texts]

        labels = encodings["input_ids"].clone()
        # Padding positions previously kept their token id in the labels; mark
        # them with the same -100 value used for the prompt.
        labels[attention_mask == 0] = -100
        # NOTE(review): masking the first `length` positions assumes the
        # tokenizer pads on the right -- confirm against its configuration.
        for i, length in enumerate(input_lengths):
            labels[i, :length] = -100

        return {
            "input_ids": encodings["input_ids"],
            "attention_mask": attention_mask,
            "labels": labels,
        }
