import ee_backend
import argparse
import tqdm
import datetime

import os.path
import os
import torch
import numpy as np
import pytz
import dataclasses
import csv
import torch.utils.tensorboard.writer


def main() -> None:
    """Command-line entry point: parse arguments and run the training loop.

    Loads the training configuration given by ``--config``, builds the models
    and optimizers with ``get_models`` and calls ``train_epoch`` once per
    configured epoch. TensorBoard events are written to
    ``<config.output_dir>/runs/run-<run_name>/tensorboard``.

    Raises:
        RuntimeError: if ``--device cuda`` is requested but CUDA is unavailable.
    """
    # Built without a multi-line f-string replacement field, which would
    # require Python 3.12+ (the rest of the file only needs 3.10).
    default_run_name = "gen_2-" + datetime.datetime.now(
        tz=pytz.timezone("America/New_York")
    ).strftime("%Y-%m-%d_%H-%M-%S")

    parser = argparse.ArgumentParser(
        description="Train the Score Function and Question Selector models",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to the training configuration file.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help="Device to use for training. Default is cuda.",
    )
    parser.add_argument(
        "--no_cache",
        action="store_true",
        help="Disable prompt cache for the candidate model.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default="cache",
        help="Directory to store the cache for the candidate model.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility.",
    )
    parser.add_argument(
        "--no_fairness_correct",
        action="store_true",
        help="Disable fairness correction step.",
    )
    parser.add_argument(
        "--logs_dir",
        type=str,
        default="logs",
        help="Directory to store the training logs.",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="checkpoints",
        help="Directory to store the model checkpoints.",
    )
    parser.add_argument(
        "--run_name",
        type=str,
        default=default_run_name,
        help="Name of the training run. Used for logging and checkpointing.",
    )

    args = parser.parse_args()

    no_cache: bool = args.no_cache
    cache_dir: str = args.cache_dir
    seed: int = args.seed
    no_fairness_correct: bool = args.no_fairness_correct
    logs_dir: str = args.logs_dir
    save_dir: str = args.save_dir

    run_name: str = args.run_name
    if no_fairness_correct:
        run_name = f"{run_name}-no_fairness_correct"

    torch_device: str = args.device
    if torch_device == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. Please use CPU instead.")
    elif torch_device == "cuda":
        print("Using GPU for training.")
    else:
        print("Using CPU for training.")

    config = ee_backend.TrainingConfig.from_file(args.config, cache_dir)

    working_dir = os.path.join(
        config.output_dir,
        "runs",
        f"run-{run_name}",
    )
    os.makedirs(working_dir, exist_ok=True)
    tb_writer = torch.utils.tensorboard.writer.SummaryWriter(
        log_dir=os.path.join(working_dir, "tensorboard"),
        comment="Training run",
        flush_secs=30,
    )
    try:
        # Seed torch before the models are built so weight initialisation is
        # reproducible too; the numpy generator below is what train_epoch
        # receives for its own sampling.
        torch.manual_seed(seed)
        models, optimizers = get_models(no_cache, cache_dir, config, torch_device)
        rng = np.random.default_rng(seed)

        # The bar is advanced once per RL batch (inside report_rl_batch) and
        # once more at the end of every fairness batch (inside train_epoch).
        total_steps = (
            config.num_epochs
            * config.fairness_batch_size
            * (config.rl_batch_size + 1)
        )
        with tqdm.tqdm(
            total=total_steps,
            desc="Starting training",
            leave=True,
        ) as progress_bar:
            for epoch in range(config.num_epochs):
                progress_bar.set_description(
                    f"Epoch {epoch + 1}/{config.num_epochs}"
                )
                train_epoch(
                    models,
                    optimizers,
                    config,
                    torch_device,
                    rng,
                    tb_writer,
                    logs_dir,
                    save_dir,
                    run_name,
                    epoch,
                    no_fairness_correct,
                    progress_bar,
                )
    finally:
        # Flush pending TensorBoard events even if training fails part-way.
        tb_writer.close()


@dataclasses.dataclass
class Models:
    """Data sets and models shared by every stage of a training run.

    Built by ``get_models``. ``score_model_lc`` starts as ``None``.
    ``train_epoch`` replaces it with the latest fairness multipliers whenever
    fairness correction is enabled, and ``get_z`` passes it on as ``l_c`` to
    correct rollout scores.
    """

    # Questions the selector can choose from.
    question_set: ee_backend.SkillQuestionSet
    # Pool that interview candidates are sampled from.
    candidate_set: ee_backend.CandidateSet
    # LLM wrapper that answers interview questions on a candidate's behalf.
    candidate_model: ee_backend.LIModelWrapper
    # Trainable policy choosing the next question.
    question_model: ee_backend.QuestionSelector
    # Trainable model scoring a finished (or partial) interview.
    score_model: ee_backend.ScoreFunction
    # Per-class fairness multipliers; None until a correction is adopted.
    score_model_lc: dict[int, float] | None = dataclasses.field(default=None)


@dataclasses.dataclass
class Optimizers:
    """Optimizer/loss pairs for the two trainable models.

    Created by ``get_models`` next to ``Models``. ``train_epoch`` hands the
    fields to ``rl_update``.
    """

    # QuestionSelector: optimizer, then the loss it is trained with.
    question_optimizer: torch.optim.Optimizer
    loss_fn_question: torch.nn.Module
    # ScoreFunction: optimizer, then the loss it is trained with.
    score_optimizer: torch.optim.Optimizer
    loss_fn_score: torch.nn.Module


def get_models(
    no_cache: bool,
    cache_dir: str,
    training_config: ee_backend.TrainingConfig,
    torch_device: str,
) -> tuple[Models, Optimizers]:
    """Load the data sets, build all models, and create optimizers/losses.

    Args:
        no_cache: Disable the candidate LLM's prompt cache.
        cache_dir: Cache root. Candidate data is cached in
            ``<cache_dir>/candidate_cache``, and the LLM wrapper caches in
            ``cache_dir`` itself.
        training_config: Provides the question/candidate set paths, the LLM
            path and seed, and the initial learning rate.
        torch_device: ``"cuda"`` or ``"cpu"``. The question and score models
            are moved here; the LLM wrapper receives ``"auto"`` for CUDA.

    Returns:
        ``(Models, Optimizers)``. Both trainable models are in train mode, and
        each has its own Adam optimizer using the configured learning rate.
    """
    question_set = ee_backend.SkillQuestionSet.load_from_json(
        training_config.question_set_path
    )
    candidate_set = ee_backend.CandidateSet.load_from_json(
        training_config.candidate_set_path,
        os.path.join(cache_dir, "candidate_cache"),
    )

    question_model = ee_backend.QuestionSelector(question_set)
    question_model = question_model.train()
    question_model = question_model.to(torch_device)

    score_model = ee_backend.ScoreFunction(question_set)
    score_model = score_model.train()
    score_model = score_model.to(torch_device)

    candidate_model = ee_backend.LIModelWrapper(
        model_path=training_config.llm_path,
        device="auto" if torch_device == "cuda" else "cpu",
        caching=not no_cache,
        cache_dir=cache_dir,
        seed=training_config.random_seed,
    )
    optimizer_question = torch.optim.Adam(
        question_model.parameters(),
        lr=training_config.intial_learning_rate,
    )
    optimizer_score = torch.optim.Adam(
        score_model.parameters(),
        lr=training_config.intial_learning_rate,
    )

    # rl_update multiplies the question loss by a per-sample z-score weight
    # before averaging. That only works on an unreduced (per-sample) loss;
    # with the default reduction="mean" the weights collapse into one scale.
    loss_fn_question = torch.nn.CrossEntropyLoss(reduction="none")
    # NOTE(review): BCELoss expects probabilities in [0, 1]. If ScoreFunction
    # outputs raw logits, BCEWithLogitsLoss is the correct loss; confirm.
    loss_fn_score = torch.nn.BCELoss()
    return Models(
        question_set=question_set,
        candidate_set=candidate_set,
        candidate_model=candidate_model,
        question_model=question_model,
        score_model=score_model,
    ), Optimizers(
        question_optimizer=optimizer_question,
        loss_fn_question=loss_fn_question,
        score_optimizer=optimizer_score,
        loss_fn_score=loss_fn_score,
    )


def train_epoch(
    models: Models,
    optimizers: Optimizers,
    config: ee_backend.TrainingConfig,
    torch_device: str,
    rng: np.random.Generator,
    tb_writer: torch.utils.tensorboard.writer.SummaryWriter,
    logs_dir: str,
    save_dir: str,
    run_name: str,
    epoch: int,
    no_fairness_correct: bool,
    progress_bar: tqdm.tqdm,
) -> None:
    """Run one epoch of interleaved RL updates and fairness correction.

    The epoch has ``config.fairness_batch_size`` fairness batches. Each one
    first runs ``config.rl_batch_size`` RL batches. An RL batch generates
    interviews (``generate_interviews``), estimates rollout z-scores
    (``get_z``) and updates both models (``rl_update``). Each fairness batch
    then generates fresh interviews and runs ``fairness_correction``. When
    correction is enabled, the resulting multipliers become
    ``models.score_model_lc``. Every RL batch and every fairness batch
    appends a row to ``<logs_dir>/train-<run_name>.csv``.

    Args:
        models: Models and data sets; ``score_model_lc`` is updated in place.
        optimizers: Optimizers and loss modules passed to ``rl_update``.
        config: Training hyper-parameters.
        torch_device: Device for model inputs.
        rng: Generator used for candidate sampling.
        tb_writer: TensorBoard writer.
        logs_dir: Directory for the CSV log; created if missing.
        save_dir: NOTE(review): accepted but currently unused (no checkpoint
            is written here); confirm whether checkpointing was intended.
        run_name: Used in the CSV log file name.
        epoch: Zero-based epoch index.
        no_fairness_correct: Skip adopting fairness multipliers.
        progress_bar: Shared progress bar; advanced once per fairness batch.
    """
    os.makedirs(logs_dir, exist_ok=True)
    log_path = os.path.join(logs_dir, f"train-{run_name}.csv")
    # The csv module must control line endings itself, hence newline="".
    with open(log_path, "a", newline="", encoding="utf-8") as log_file:
        log_writer = csv.DictWriter(
            log_file,
            fieldnames=[
                "epoch",
                "fairness_batch_num",
                "rl_batch_num",
                "score_function_loss",
                "question_selector_loss",
                "mean_z_score",
                "initial_fairness_loss",
                "final_fairness_loss",
                "initial_accuracy",
                "corrected_accuracy",
                "fairness_correction",
                "l_c",
            ],
        )
        # Append mode starts at end-of-file, so tell() == 0 means the file is
        # empty and still needs its header row.
        if log_file.tell() == 0:
            log_writer.writeheader()
            log_file.flush()

        for fairness_batch_num in range(config.fairness_batch_size):
            progress_bar.set_description(
                f"Epoch {epoch + 1}/{config.num_epochs} Fairness Batch {fairness_batch_num + 1}/{config.fairness_batch_size}"
            )
            for rl_batch_num in range(config.rl_batch_size):
                batch_xy = generate_interviews(
                    models.candidate_model,
                    models.question_model,
                    models.question_set,
                    models.candidate_set,
                    config.rl_batch_size,
                    config.summary_prefix,
                    config.tree_depth,
                    torch_device,
                    rng,
                )
                full_zscores: list[dict[int, float]] = [
                    get_z(
                        interview,
                        models.question_model,
                        models.score_model,
                        models.score_model_lc,
                        models.candidate_model,
                        config.num_rollouts,
                        config.rollout_depth,
                        models.question_set,
                        config.summary_prefix,
                        torch_device,
                    )
                    for interview in batch_xy
                ]
                batch_x: list[torch.Tensor] = []
                batch_y: list[torch.Tensor] = []
                for interview in batch_xy:
                    batch_x.append(
                        interview.get_x(
                            models.question_model.input_dim,
                            torch_device,
                        )
                    )
                    batch_y.append(
                        interview.get_y(
                            config.target_skill,
                            torch_device,
                        )
                    )

                # rl_update returns (question_loss, score_loss, score_outputs);
                # q_ is the question selector, f_ the score function.
                q_loss, f_loss, _score_outputs = rl_update(
                    models.question_model,
                    models.score_model,
                    batch_x,
                    batch_y,
                    full_zscores,
                    optimizers.loss_fn_question,
                    optimizers.loss_fn_score,
                    optimizers.question_optimizer,
                    optimizers.score_optimizer,
                    torch_device,
                )
                # report_rl_batch prints the per-batch status line and
                # advances the progress bar.
                report_rl_batch(
                    tb_writer,
                    epoch,
                    fairness_batch_num,
                    rl_batch_num,
                    f_loss,
                    q_loss,
                    full_zscores,
                    progress_bar,
                    config,
                )
                log_writer.writerow({
                    "epoch": epoch,
                    "fairness_batch_num": fairness_batch_num,
                    "rl_batch_num": rl_batch_num,
                    "score_function_loss": f_loss,
                    "question_selector_loss": q_loss,
                    # Mean over interviews of the best rollout z-score.
                    "mean_z_score": np.mean(
                        [max(zs.values()) for zs in full_zscores]
                    ),
                    "initial_fairness_loss": None,
                    "final_fairness_loss": None,
                    "initial_accuracy": None,
                    "corrected_accuracy": None,
                    "fairness_correction": not no_fairness_correct,
                    "l_c": models.score_model_lc,
                })
                log_file.flush()
                progress_bar.set_postfix(
                    {
                        "f_loss": f_loss,
                        "q_loss": q_loss,
                    }
                )

            batch_xy_fair = generate_interviews(
                models.candidate_model,
                models.question_model,
                models.question_set,
                models.candidate_set,
                config.fairness_batch_size,
                config.summary_prefix,
                config.tree_depth,
                torch_device,
                rng,
            )
            fairness_result = fairness_correction(
                models.score_model,
                batch_xy_fair,
                config.target_skill,
                config.fairness_epsilon,
                torch_device,
                no_fairness_correct,
            )
            log_writer.writerow({
                "epoch": epoch,
                "fairness_batch_num": fairness_batch_num,
                "rl_batch_num": None,
                "score_function_loss": None,
                "question_selector_loss": None,
                "mean_z_score": None,
                "initial_fairness_loss": fairness_result.initial_fairness_loss,
                "final_fairness_loss": fairness_result.fairness_loss,
                "initial_accuracy": fairness_result.initial_accuracy,
                "corrected_accuracy": fairness_result.corrected_accuracy,
                "fairness_correction": not no_fairness_correct,
                "l_c": fairness_result.l_c,
            })
            log_file.flush()
            report_fairness_batch(
                tb_writer,
                epoch,
                fairness_batch_num,
                fairness_result,
            )
            progress_bar.update(1)
            if not no_fairness_correct:
                models.score_model_lc = fairness_result.l_c


def generate_interviews(
    candidate_model: ee_backend.LIModelWrapper,
    question_model: ee_backend.QuestionSelector,
    question_set: ee_backend.SkillQuestionSet,
    candidate_set: ee_backend.CandidateSet,
    batch_size: int,
    summary_prefix: str,
    num_questions: int,
    torch_device: str,
    rng: np.random.Generator,
) -> list[ee_backend.InterviewNode]:
    """Simulate greedy interviews with freshly sampled candidates.

    For each of ``batch_size`` candidates drawn from ``candidate_set`` with
    ``rng``, a root node is created and summarised. Then ``num_questions``
    questions are asked, each time taking the selector's top question. Each
    question is passed on with id ``top_question_int + 1``.

    Returns:
        Every node produced along every interview path, i.e.
        ``batch_size * num_questions`` nodes rather than one per candidate.
        NOTE(review): the append sits inside the question loop; confirm that
        collecting every intermediate node (not only the last) is intended.
    """
    total_questions = len(question_model.question_set)
    collected: list[ee_backend.InterviewNode] = []
    for _candidate_num in range(batch_size):
        node = ee_backend.InterviewNode(
            parent=None,
            children={},
            question=None,
            question_id=None,
            num_questions=total_questions,
            candidate=candidate_set.sample_candidate(rng),
            _node_tensor=None,
        )
        node.ask_root_summary(summary_prefix=summary_prefix, model=candidate_model)

        for _depth in range(num_questions):
            dist = question_model.get_questions_for_node(
                node, question_set, device=torch_device
            )
            node = node.ask_question(
                dist.top_question,
                dist.top_question_int + 1,
                total_questions,
                summary_prefix,
                candidate_model,
            )
            collected.append(node)
    return collected


def get_z(
    interview_node: ee_backend.InterviewNode,
    question_model: ee_backend.QuestionSelector,
    score_model: ee_backend.ScoreFunction,
    l_c: dict[int, float] | None,
    candidate_model: ee_backend.LIModelWrapper,
    num_rollouts: int,
    rollout_depth: int,
    question_set: ee_backend.SkillQuestionSet,
    summary_prefix: str,
    torch_device: str,
) -> dict[int, float]:
    """Estimate a z-score for each of the most likely next questions.

    The selector's distribution at ``interview_node`` is sorted by
    probability, and questions already asked are dropped. For each of the top
    ``num_rollouts`` remaining questions, the interview is extended by that
    question plus ``rollout_depth`` greedy follow-ups. The final node is
    scored with ``score_model``, corrected with ``l_c`` when it is not
    ``None``, and converted to a z-score.

    Returns:
        Mapping of question index to z-score, in descending-probability
        order. It has fewer than ``num_rollouts`` entries when fewer unasked
        questions remain, and is empty when ``num_rollouts <= 0``.
    """
    questions_dist = question_model.get_questions_for_node(
        interview_node,
        question_set,
        device=torch_device,
    )
    # Materialise once instead of re-querying for every candidate question.
    already_asked = list(interview_node.previous_questions())
    # NOTE(review): question_dist_int keys are compared here as-is, while
    # ask_question below receives index + 1 as the question id. Confirm that
    # previous_questions() uses the same (0-based) convention as the keys;
    # otherwise this filter is off by one.
    candidates = [
        question_idx
        for question_idx, _prob in sorted(
            questions_dist.question_dist_int.items(),
            key=lambda item: item[1],
            reverse=True,
        )
        if question_idx not in already_asked
    ]

    total_questions = len(question_model.question_set)
    ret_dict: dict[int, float] = {}
    # Slicing avoids an IndexError when fewer unasked questions remain than
    # num_rollouts. max(..., 0) stops a negative count from slicing off the
    # tail instead of yielding nothing.
    for rollout_question in candidates[: max(num_rollouts, 0)]:
        new_node = interview_node.ask_question(
            question=question_set.idx_to_skill[rollout_question],
            question_id=rollout_question + 1,
            num_questions=total_questions,
            summary_prefix=summary_prefix,
            model=candidate_model,
        )
        for _ in range(rollout_depth):
            follow_up = question_model.get_questions_for_node(
                new_node,
                question_set,
                device=torch_device,
            )
            new_node = new_node.ask_question(
                follow_up.top_question,
                follow_up.top_question_int + 1,
                total_questions,
                summary_prefix,
                candidate_model,
            )
        new_score = score_model.get_score_for_node(
            new_node,
            torch_device,
        )
        if l_c is not None:
            new_score = new_score.get_corrected_score(
                new_node.candidate.five_factors.to_int_dict(),
                l_c,
            )
        ret_dict[rollout_question] = new_score.to_z_score()
    return ret_dict



def _per_sample_loss(
    loss_fn: torch.nn.Module,
    predictions: torch.Tensor,
    targets: torch.Tensor,
) -> torch.Tensor:
    """Return a 1-D tensor holding one ``loss_fn`` value per batch row."""
    losses = loss_fn(predictions, targets)
    if losses.dim() == 1 and losses.shape[0] == predictions.shape[0]:
        return losses
    # The module reduced over the batch (e.g. CrossEntropyLoss's default
    # reduction="mean"), so re-evaluate it one row at a time to recover
    # individual sample losses that can be weighted separately.
    return torch.stack(
        [
            loss_fn(predictions[row : row + 1], targets[row : row + 1])
            for row in range(predictions.shape[0])
        ]
    ).reshape(-1)


def rl_update(
    question_model: ee_backend.QuestionSelector,
    score_model: ee_backend.ScoreFunction,
    batch_x: list[torch.Tensor],
    batch_y: list[torch.Tensor],
    batch_z: list[dict[int, float]],
    loss_question: torch.nn.Module,
    loss_score: torch.nn.Module,
    optimizer_question: torch.optim.Optimizer,
    optimizer_score: torch.optim.Optimizer,
    device: str,
) -> tuple[float, float, list[tuple[float, float]]]:
    """Take one optimizer step for the question selector and the score model.

    For each interview, the question selector is trained towards the rollout
    question with the highest z-score in ``batch_z``. Each sample's loss is
    weighted by that z-score before averaging, and this holds whatever
    ``reduction`` ``loss_question`` was built with. The score model is
    trained against ``batch_y`` with ``loss_score``.

    Returns:
        ``(question_loss, score_loss, score_outputs)``. ``score_outputs``
        holds the first two output values of the score model for each
        interview.

    Raises:
        ValueError: if any dict in ``batch_z`` is empty.
    """
    question_model.train()
    score_model.train()

    optimizer_question.zero_grad()
    optimizer_score.zero_grad()

    # Only the first two rows (along dim 0) of each interview tensor are fed
    # to the models.
    interview_tensors = torch.stack([x[:2] for x in batch_x]).to(device)

    q_logits = question_model(interview_tensors)[:, 0, :]

    labels: list[int] = []
    weights: list[float] = []
    for z_dict in batch_z:
        label, weight = max(z_dict.items(), key=lambda item: item[1])
        labels.append(label)
        weights.append(weight)

    labels_tensor = torch.tensor(labels, device=device)
    question_targets = torch.nn.functional.one_hot(
        labels_tensor,
        num_classes=len(question_model.question_set),
    ).float()
    weights_tensor = torch.tensor(weights, device=device)
    question_loss = (
        _per_sample_loss(loss_question, q_logits, question_targets)
        * weights_tensor
    ).mean()
    question_loss.backward()
    optimizer_question.step()

    score_logits = score_model(interview_tensors)[:, -1, :]
    score_targets = torch.stack(batch_y).to(device).float()
    score_loss = loss_score(score_logits, score_targets).mean()
    score_loss.backward()
    optimizer_score.step()
    return (
        question_loss.item(),
        score_loss.item(),
        [(sl[0], sl[1]) for sl in score_logits.tolist()],
    )


def report_rl_batch(
    tb_writer: torch.utils.tensorboard.writer.SummaryWriter,
    epoch: int,
    fairness_batch_num: int,
    rl_batch_num: int,
    f_loss: float,
    q_loss: float,
    zscores: list[dict[int, float]],
    progress_bar: tqdm.tqdm,
    config: ee_backend.TrainingConfig,
) -> None:
    """Report one RL batch to the console, the progress bar and TensorBoard.

    Args:
        f_loss: Score-function loss.
        q_loss: Question-selector loss.
        zscores: Per-interview rollout z-scores; the mean of each interview's
            best z-score is logged.
    """
    # tqdm.write keeps the message from being interleaved with the live bar.
    progress_bar.write(
        f"Batch {fairness_batch_num} RL batch {rl_batch_num} Q loss: {q_loss:.4f} "
    )
    progress_bar.update(1)

    # Global RL-batch counter over the whole run. Including the epoch keeps
    # later epochs from reusing earlier epochs' step values.
    global_step = (
        epoch * config.fairness_batch_size + fairness_batch_num
    ) * config.rl_batch_size + rl_batch_num

    tb_writer.add_scalar("loss/question_loss", q_loss, global_step)
    tb_writer.add_scalar("loss/score_loss", f_loss, global_step)
    tb_writer.add_scalar(
        "loss/mean_z_score",
        np.mean([max(zs.values()) for zs in zscores]),
        global_step,
    )
    tb_writer.add_scalar("epoch", epoch, global_step)


def fairness_correction(
    score_model: ee_backend.ScoreFunction,
    batch_xy: list[ee_backend.InterviewNode],
    target_skill: str,
    fairness_epsilon: float,
    device: str,
    skip_fairness_correction: bool,
) -> ee_backend.FairnessCorrectionResult:
    """Score a batch of interviews and compute per-class fairness multipliers.

    Each interview is scored with ``score_model``. Its ground truth is 1 when
    ``target_skill`` is among the candidate's skills and 0 otherwise. Fairness
    loss and accuracy are computed before and after applying the multipliers.

    Args:
        skip_fairness_correction: When True, ``ee_backend.compute_multipliers``
            is not called; every class gets a multiplier of 0.0 and the final
            fairness loss equals the initial one.

    Returns:
        A ``FairnessCorrectionResult`` whose ``l_c`` is the multiplier dict.
        A message is printed when the corrected loss still exceeds
        ``fairness_epsilon``.
    """
    samples: list[ee_backend.FairnessSample] = []
    for candidate_index, interview in enumerate(batch_xy):
        sample_score = score_model.get_score_for_node(
            interview,
            device=device,
        )
        ground_truth = 1 if target_skill in interview.candidate.skills else 0
        samples.append(
            ee_backend.FairnessSample(
                class_weights=interview.candidate.five_factors.to_int_dict(),
                candidate_index=candidate_index,
                candidate_score=sample_score.score_true,
                candidate_ground_truth=ground_truth,
            )
        )
    initial_fairness_loss = ee_backend.get_fairness_loss(samples, None)

    if skip_fairness_correction:
        # One 0.0 entry per class so downstream logging sees every class.
        # NOTE(review): presumably 0.0 means "no correction"; confirm in
        # ee_backend.
        class_names = sorted(samples[0].class_weights) if samples else []
        multipliers = {class_name: 0.0 for class_name in class_names}
        fairness_loss = initial_fairness_loss
    else:
        multipliers = ee_backend.compute_multipliers(samples)
        fairness_loss = ee_backend.get_fairness_loss(samples, multipliers)
        if fairness_loss > fairness_epsilon:
            print(
                f"Fairness loss {fairness_loss} is greater than epsilon {fairness_epsilon}. It started at {initial_fairness_loss}. "
                "This indicates that the fairness correction did not succeed."
            )

    initial_accuracy = ee_backend.get_accuracy(samples, None)
    corrected_accuracy = ee_backend.get_accuracy(samples, multipliers)
    return ee_backend.FairnessCorrectionResult(
        initial_fairness_loss=initial_fairness_loss,
        fairness_loss=fairness_loss,
        distillation_loss=None,
        distill_fairness_loss=None,
        l_c=multipliers,
        initial_accuracy=initial_accuracy,
        corrected_accuracy=corrected_accuracy,
    )


def report_fairness_batch(
    tb_writer: torch.utils.tensorboard.writer.SummaryWriter,
    epoch: int,
    fairness_batch_num: int,
    fairness_result: ee_backend.FairnessCorrectionResult,
) -> None:
    """Print a fairness-correction summary and log it to TensorBoard.

    Fairness losses are logged at step ``fairness_batch_num``. The ``l_c``
    multipliers and the accuracies are logged at step ``epoch``.
    NOTE(review): with these steps, neither series increases monotonically
    across a multi-epoch run; confirm the intended x-axis.

    Raises:
        AssertionError: if ``fairness_result.l_c`` is None. This happens after
            the two fairness-loss scalars have already been written.
    """
    result = fairness_result
    print(
        f"Batch {fairness_batch_num} Fairness correction result: {result}, "
        f"intial accuracy: {result.initial_accuracy:.4f}, "
        f"corrected accuracy: {result.corrected_accuracy:.4f}"
    )

    loss_scalars = (
        ("fairness/initial_fairness_loss", result.initial_fairness_loss),
        ("fairness/final_fairness_loss", result.fairness_loss),
    )
    for tag, value in loss_scalars:
        tb_writer.add_scalar(tag, value, fairness_batch_num)

    assert result.l_c is not None, "Fairness result l_c should not be None"
    for label, multiplier in result.l_c.items():
        tb_writer.add_scalar(f"fairness/l_c/{label}", multiplier, epoch)

    accuracy_scalars = (
        ("fairness/initial_accuracy", result.initial_accuracy),
        ("fairness/corrected_accuracy", result.corrected_accuracy),
    )
    for tag, value in accuracy_scalars:
        tb_writer.add_scalar(tag, value, epoch)

if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
