import math
from pathlib import Path
import random
import re
import xml.etree.ElementTree as ET
from functools import partial

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from accelerate import Accelerator

from core.llm import LLM
from core.messages import Role, Message
from curriculum.lesson import Exercise
from curriculum.exercise_with_answers import xml_dump
from training.student_teacher_dataset import read_exercises, StudentTeacherDataset, IGNORE_INDEX
from training.utils import augment_filenames


def chunk_list(data, n):
    """
    Split the sliceable sequence ``data`` into exactly ``n`` contiguous chunks.

    Chunk sizes differ by at most one: the first ``len(data) % n`` chunks get
    one extra element. When ``len(data) < n`` the trailing chunks are empty, so
    callers can always rely on chunk ``i`` existing for ``0 <= i < n``.

    Args:
        data: Any sequence supporting ``len`` and slicing (list, tuple, str, ...).
        n: Number of chunks to produce; must be a positive integer.

    Yields:
        Consecutive slices of ``data`` (same type as ``data``), in order.

    Raises:
        ValueError: If ``n <= 0`` (raised on first iteration).
    """
    if n <= 0:
        raise ValueError(f"Number of chunks must be a positive integer, got {n}.")

    # Spread the remainder over the leading chunks instead of using a fixed
    # ceil-sized stride, which can produce fewer than n chunks (e.g. 9 items
    # into 4 chunks would give 3 chunks of 3).
    base_size, extra = divmod(len(data), n)
    start = 0
    for idx in range(n):
        size = base_size + (1 if idx < extra else 0)
        yield data[start : start + size]
        start += size


# Dataset names accepted by ``main``; each maps to the Bonito training file
# "{name}_{bonito_model}_{questions}_{temperature}_{max_items}_train.xml".
_KNOWN_DATASETS = ("amazon_default", "new_wiki_default", "nyt_default", "reddit_default")


def _chunk_output_filename(orig_filename, chunk_idx, partition_type):
    """Return ``orig_filename`` with ``_chunk{idx}_{type}`` inserted before its last ``.xml``."""
    head, sep, tail = orig_filename.rpartition(".xml")
    if not sep:
        # No ".xml" present: append, rather than reusing (and clobbering) the input name.
        return f"{orig_filename}_chunk{chunk_idx}_{partition_type}.xml"
    return f"{head}_chunk{chunk_idx}_{partition_type}.xml{tail}"


@torch.no_grad()
def _example_kld_and_entropy(model, batch):
    """Score one collated batch with ``model``.

    Returns a pair of float32 tensors with one entry per sequence in the batch:
      * the mean over the student's labelled (``!= IGNORE_INDEX``) positions of
        the per-token KL(teacher || student) over the vocabulary, and
      * the mean over the teacher's masked positions of the entropy of the
        teacher's next-token distribution.
    """
    # Shift by one: the logits at position i are scored against token i + 1.
    student_inputs = batch['student_seqs'][..., :-1]
    student_masks = batch['student_labels'][..., 1:] != IGNORE_INDEX
    teacher_inputs = batch['teacher_seqs'][..., :-1]
    # NOTE(review): assumes 'teacher_masks' is a bool tensor shaped like
    # 'teacher_seqs' (it is used for boolean indexing below) — confirm.
    teacher_masks = batch['teacher_masks'][..., 1:]

    # Index the masked positions straight away so the full (B, L, V) logits
    # tensor is released before the student forward. Upcast to float32 so the
    # softmax, entropy and KL are not computed or accumulated in bf16.
    t_log_probs = F.log_softmax(
        model.forward(teacher_inputs).logits[teacher_masks].float(), dim=-1
    )
    t_token_entropy = -(t_log_probs.exp() * t_log_probs).sum(-1)  # per-token entropy

    s_log_probs = F.log_softmax(
        model.forward(student_inputs).logits[student_masks].float(), dim=-1
    )
    # kl_div(input=log q, target=log p, log_target=True) is p * (log p - log q)
    # elementwise, i.e. KL(teacher || student) once summed over the vocabulary.
    # This pairs rows by position, so it relies on the student and teacher masks
    # selecting the same number of tokens in the same order.
    token_kld = F.kl_div(
        s_log_probs, t_log_probs, log_target=True, reduction="none",
    ).sum(-1)

    # Scatter per-token values back to (batch, seq) and average per example.
    # NOTE(review): an example with no selected tokens divides 0 by 0 (NaN).
    entropy_mx = t_token_entropy.new_zeros(teacher_masks.shape)
    entropy_mx[teacher_masks] = t_token_entropy
    example_entropy = entropy_mx.sum(-1) / teacher_masks.sum(-1)

    kld_mx = token_kld.new_zeros(student_masks.shape)
    kld_mx[student_masks] = token_kld
    example_kld = kld_mx.sum(-1) / student_masks.sum(-1)

    return example_kld, example_entropy


def main(
    base: str = "llama3-8b-instruct",
    partitions: int = None,
    partition_type: str = None,
    dataset: str = None,
    lesson_temp: float = 1.5,
    lesson_num_choices: int = 1,
    augment_files: bool = True,
    dataset_llama3_8b: bool = False,
    dataset_llama3_70b: bool = False,
    dataset_qwen25_3b: bool = False,
    dataset_qwen25_14b: bool = False,
    dataset_qwen25_72b: bool = False,
    datapath: Path = Path("data"),
    bonito_model: str = "llama3-8b-instruct",
    bonito_questions: int = 6,
    bonito_temperature: float = 1.5,
    bonito_max_items_train: int = 1000,
    overwrite_file: bool = False,
    batch_size: int = 4,
):
    """Partition a student/teacher exercise file by per-example KL or teacher entropy.

    Runs ``base`` over every example of the selected training file and computes,
    per example, the mean KL(teacher || student) and the mean teacher entropy
    (see ``_example_kld_and_entropy``). The exercises are sorted ascending by the
    statistic chosen with ``partition_type`` and written to ``partitions``
    contiguous XML files ``<file>_chunk{i}_{partition_type}.xml`` under
    ``datapath`` (chunk 1 holds the lowest scores).

    Args:
        base: Name of the model passed to ``LLM``. Llama-3 models get a fixed
            system message.
        partitions: Number of output chunks; must be a positive integer.
        partition_type: ``"kld"`` or ``"entropy"``.
        dataset: One of ``_KNOWN_DATASETS``.
        lesson_temp, lesson_num_choices, dataset_*: Forwarded to
            ``augment_filenames``. ``lesson_temp`` is also written to every chunk.
        augment_files: Currently unused.
        datapath: Directory containing the input file and receiving the chunks.
        bonito_*: Components of the input file name.
        overwrite_file: Allow replacing existing chunk files.
        batch_size: DataLoader batch size for scoring.

    Raises:
        NotImplementedError: ``dataset`` is not a known dataset name.
        ValueError: Invalid ``partitions``/``partition_type``, nothing was
            scored, or the scores do not line up with the exercises in the file.
        FileExistsError: A chunk file exists and ``overwrite_file`` is False.
    """
    # Validate cheap arguments before paying for model loading and inference.
    if dataset not in _KNOWN_DATASETS:
        raise NotImplementedError(f"Unknown dataset {dataset}")
    if not partitions or partitions <= 0:
        raise ValueError("Number of partitions must be a positive integer.")
    if partition_type not in ("kld", "entropy"):
        raise ValueError(f"Unknown partition_type '{partition_type}'. Must be 'kld' or 'entropy'.")

    train_file = f"{dataset}_{bonito_model}_{bonito_questions}_{bonito_temperature}_{bonito_max_items_train}_train.xml"
    file = augment_filenames([train_file], lesson_temp, lesson_num_choices, vllm=True,
                             llama3_8b=dataset_llama3_8b, llama3_70b=dataset_llama3_70b,
                             qwen25_3b=dataset_qwen25_3b, qwen25_14b=dataset_qwen25_14b, qwen25_72b=dataset_qwen25_72b)

    # Refuse up front to clobber existing outputs, instead of failing mid-way
    # after some chunks have already been written.
    orig_filename = str(file[0])
    output_paths = [
        datapath / _chunk_output_filename(orig_filename, chunk_idx, partition_type)
        for chunk_idx in range(1, partitions + 1)
    ]
    if not overwrite_file:
        for output_path in output_paths:
            if output_path.exists():
                raise FileExistsError(
                    f"The file '{output_path}' already exists. Exiting to avoid overwriting."
                )

    if "llama3" in base:
        opening_message = Message(
            Role.SYSTEM,
            "You are a knowledgeable assistant trained to provide accurate and helpful information. Please respond to the user's queries promptly."
        )
    else:
        opening_message = None
    base_llm = LLM(base, opening_message=opening_message)
    accelerator = Accelerator(mixed_precision="bf16")

    model = base_llm.load_model()
    model = model.to(torch.bfloat16)
    model = accelerator.prepare(model)
    model = model.eval()

    st_dataset = StudentTeacherDataset(
        base_llm, file, verbose=False,
        datapath=datapath,
    )
    dataloader = DataLoader(
        st_dataset,
        batch_size=batch_size,
        collate_fn=partial(st_dataset.collate_fn, padding_value=0, llm=base_llm),
        shuffle=False,  # scores must stay aligned with the exercise order on disk
    )
    dataloader = accelerator.prepare(dataloader)

    klds = []
    entropies = []
    for batch in dataloader:
        batch_kld, batch_entropy = _example_kld_and_entropy(model, batch)
        # Collect every process's shard and drop samples accelerate duplicated
        # to even out the final batch; a no-op when running as one process.
        batch_kld, batch_entropy = accelerator.gather_for_metrics((batch_kld, batch_entropy))
        klds.extend(batch_kld.tolist())
        entropies.extend(batch_entropy.tolist())

    # Only one process reads the exercises and writes the chunk files.
    if not accelerator.is_main_process:
        return

    if not klds:
        raise ValueError(f"No examples were scored from '{orig_filename}'.")
    print("Average entropy", sum(entropies) / len(entropies))
    print("Average KL-divergence", sum(klds) / len(klds))

    # ---- Read the original exercises ----
    file_path = datapath / file[0]
    print(f"Reading lessons from {file_path}")
    exercises = read_exercises(file_path)
    print(f"Number of exercises read: {len(exercises)}")

    if not len(klds) == len(entropies) == len(exercises):
        raise ValueError(
            f"Score/exercise count mismatch: {len(klds)} KL values, "
            f"{len(entropies)} entropies, {len(exercises)} exercises."
        )

    # Stable ascending sort by the chosen score; ties keep file order.
    scores = klds if partition_type == "kld" else entropies
    order = sorted(range(len(exercises)), key=scores.__getitem__)
    sorted_exercises = [exercises[i] for i in order]

    exercise_chunks = list(chunk_list(sorted_exercises, partitions))
    print(f"Split into {len(exercise_chunks)} chunks. Starting to write.")

    # ---- Write out each chunk as an XML file ----
    for chunk_idx, (chunk, output_path) in enumerate(zip(exercise_chunks, output_paths), start=1):
        root = ET.Element("exercises_with_answers")
        ET.SubElement(root, "temperature", value=str(lesson_temp))
        for ex in chunk:
            ex.to_xml(root)

        with open(output_path, "w", encoding="utf-8") as output_file:
            xml_dump(root, output_file)

        print(f"Saved chunk {chunk_idx} to {output_path}")


if __name__ == "__main__":
    # Expose main()'s keyword arguments as command-line options.
    import jsonargparse

    jsonargparse.CLI(main)
