import ee_backend
import argparse
import tqdm
import datetime

import os.path
import os
import torch
import numpy as np

import torch.utils.tensorboard.writer


def main() -> None:
    """Train the Score Function and Question Selector models.

    Parses the command line, loads the training config plus the question and
    candidate sets, builds both models and their Adam optimizers, then runs
    ``training_config.num_epochs`` epochs. Each epoch runs
    ``fairness_batch_size`` fairness batches of ``rl_batch_size`` RL updates
    each. Losses and fairness metrics are written to TensorBoard under
    ``<output_dir>/runs/run-<timestamp>/tensorboard``.

    Raises:
        RuntimeError: if ``--device cuda`` is requested but CUDA is not
            available, or if an epoch finished without running any fairness
            correction (e.g. ``fairness_batch_size`` or ``rl_batch_size`` is 0).
    """
    args = _parse_args()

    torch_device = args.device
    if torch_device == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. Please use CPU instead.")
    elif torch_device == "cuda":
        print("Using GPU for training.")
    else:
        print("Using CPU for training.")

    training_config = ee_backend.TrainingConfig.from_file(args.config)
    # NOTE(review): ``--seed`` is parsed but never used; all seeding below uses
    # ``training_config.random_seed``. Confirm which one is meant to win.

    print("Loading model")

    question_set = ee_backend.SkillQuestionSet.load_from_json(
        training_config.question_set_path
    )
    candidate_set = ee_backend.CandidateSet.load_from_json(
        training_config.candidate_set_path
    )

    question_model = ee_backend.QuestionSelector(question_set)
    question_model = question_model.train()
    question_model = question_model.to(torch_device)

    score_model = ee_backend.ScoreFunction(question_set)
    score_model = score_model.train()
    score_model = score_model.to(torch_device)

    candidate_model = ee_backend.LIModelWrapper(
        model_path=training_config.llm_path,
        device="auto" if torch_device == "cuda" else "cpu",
        caching=not args.no_cache,
        cache_dir=args.cache_dir,
        seed=training_config.random_seed,
    )

    working_dir = os.path.join(
        training_config.output_dir,
        "runs",
        f"run-{datetime.datetime.now().strftime('%Y%m%d-%H%M%S')}",
    )
    os.makedirs(working_dir, exist_ok=True)

    tb_writer = torch.utils.tensorboard.writer.SummaryWriter(
        log_dir=os.path.join(working_dir, "tensorboard"),
        flush_secs=30,
    )

    rng = np.random.default_rng(training_config.random_seed)

    optimizer_question = torch.optim.Adam(
        question_model.parameters(),
        lr=training_config.intial_learning_rate,
    )
    optimizer_score = torch.optim.Adam(
        score_model.parameters(),
        lr=training_config.intial_learning_rate,
    )

    # reduction="none" keeps one loss value per sample, so rl_update's
    # element-wise multiplication by the z-score weights really weights each
    # sample. With the default "mean", the batch mean was merely scaled by the
    # mean weight.
    loss_fn_question = torch.nn.CrossEntropyLoss(reduction="none")
    loss_fn_score = torch.nn.CrossEntropyLoss()

    # Make training_config and torch_device globally accessible for fairness_correction
    # NOTE(review): these are globals of *this* module; nothing visible here
    # reads them — confirm they are still needed.
    global TRAINING_CONFIG_GLOBAL, TORCH_DEVICE_GLOBAL
    TRAINING_CONFIG_GLOBAL = training_config
    TORCH_DEVICE_GLOBAL = torch_device

    rl_batch_size = training_config.rl_batch_size
    # One RL update per (fairness batch, RL batch) pair.
    steps_per_epoch = training_config.fairness_batch_size * rl_batch_size

    try:
        for epoch in range(training_config.num_epochs):
            # Global TensorBoard step. It is offset by the epoch so later
            # epochs do not overwrite the points logged by earlier ones.
            step = epoch * steps_per_epoch
            fairness_scores_batch: list[
                tuple[float, float, ee_backend.InterviewNode]
            ] = []
            fairness_result = None
            with tqdm.tqdm(
                desc=f"[{epoch:03d}/{training_config.num_epochs:03d}]Starting training",
                total=steps_per_epoch,
                leave=True,
            ) as progress_bar:
                for fairness_batch_num in range(training_config.fairness_batch_size):
                    print(f"Generating new interviews for batch {fairness_batch_num}")
                    for rl_batch_num in range(rl_batch_size):
                        step = (
                            epoch * steps_per_epoch
                            + fairness_batch_num * rl_batch_size
                            + rl_batch_num
                        )

                        new_interviews = generate_interviews(
                            candidate_model,
                            question_model,
                            question_set,
                            candidate_set,
                            rl_batch_size,
                            training_config.summary_prefix,
                            training_config.tree_depth,
                            torch_device,
                            rng,
                        )

                        zscores: list[dict[int, float]] = []
                        for interview in new_interviews:
                            zscores.append(
                                get_z(
                                    interview,
                                    question_model,
                                    score_model,
                                    candidate_model,
                                    training_config.num_rollouts,
                                    training_config.rollout_depth,
                                    question_set,
                                    training_config.summary_prefix,
                                    torch_device,
                                )
                            )
                            print(
                                f"Interview {interview.candidate.name} zscores: {zscores[-1]}"
                            )

                        # rl_update returns (question_loss, score_loss, score_logits).
                        q_loss, f_loss, score_logits = rl_update(
                            question_model,
                            score_model,
                            new_interviews,
                            zscores,
                            loss_fn_question,
                            loss_fn_score,
                            optimizer_question,
                            optimizer_score,
                            training_config.target_skill,
                            torch_device,
                        )
                        print(
                            f"Batch {fairness_batch_num} RL batch {rl_batch_num} "
                            f"Q loss: {q_loss:.4f} ")
                        fairness_scores_batch += [
                            (logits[0], logits[1], interview)
                            for logits, interview in zip(score_logits, new_interviews)
                        ]
                        progress_bar.update(1)

                        # Log before zscores goes out of use so mean_z_score sees real data.
                        _log_rl_step(tb_writer, step, epoch, q_loss, f_loss, zscores)

                        # NOTE(review): this runs after *every* RL batch on all
                        # scores accumulated so far this epoch (the accumulator is
                        # only cleared at epoch end). If it was meant to run once per
                        # fairness batch, move it out of this loop and clear
                        # fairness_scores_batch afterwards — confirm intent.
                        fairness_result = ee_backend.fairness_correction(
                            score_model,
                            fairness_scores_batch,
                            training_config.target_skill,
                            training_config.fairness_epsilon,
                            torch_device,
                        )

            if fairness_result is None:
                raise RuntimeError("Fairness correction failed")
            _log_fairness_result(tb_writer, fairness_result, step, epoch)
    finally:
        # Flush and release the event file even if training aborts.
        tb_writer.close()


def _parse_args() -> argparse.Namespace:
    """Build the command-line parser for this script and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description="Train the Score Function and Question Selector models",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help="Device to use for training. Default is cuda.",
    )
    parser.add_argument(
        "--no_cache",
        action="store_true",
        help="Disable prompt cache for the candidate model.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default="cache",
        help="Directory to store the cache for the candidate model.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility.",
    )
    return parser.parse_args()


def _log_rl_step(
    tb_writer: "torch.utils.tensorboard.writer.SummaryWriter",
    step: int,
    epoch: int,
    q_loss: float,
    f_loss: float,
    zscores: list[dict[int, float]],
) -> None:
    """Write the scalars for one RL update to TensorBoard at ``step``.

    ``loss/mean_z_score`` is the mean over interviews of each interview's best
    rollout z-score. It is skipped when no interview produced any rollout,
    instead of logging NaN.
    """
    tb_writer.add_scalar("loss/question_loss", q_loss, step)
    tb_writer.add_scalar("loss/score_loss", f_loss, step)
    best_zscores = [max(zs.values()) for zs in zscores if zs]
    if best_zscores:
        tb_writer.add_scalar("loss/mean_z_score", float(np.mean(best_zscores)), step)
    tb_writer.add_scalar("epoch", epoch, step)


def _log_fairness_result(
    tb_writer: "torch.utils.tensorboard.writer.SummaryWriter",
    fairness_result,
    step: int,
    epoch: int,
) -> None:
    """Write the metrics of a fairness-correction result to TensorBoard at ``step``.

    Reads ``initial_fairness_loss``, ``fairness_loss`` and the optional mapping
    ``l_c`` from ``fairness_result``. When ``l_c`` is None, ``loss/l_c`` is
    logged as 0.0.
    """
    tb_writer.add_scalar(
        "loss/initial_fairness_loss", fairness_result.initial_fairness_loss, step
    )
    tb_writer.add_scalar("loss/fairness_loss", fairness_result.fairness_loss, step)
    if fairness_result.l_c is not None:
        tb_writer.add_scalar("loss/l_c", sum(fairness_result.l_c.values()), step)
        for class_name, multiplier in fairness_result.l_c.items():
            tb_writer.add_scalar(f"loss/l_c/{class_name}", multiplier, step)
    else:
        tb_writer.add_scalar("loss/l_c", 0.0, step)
    tb_writer.add_scalar("epoch", epoch, step)


def generate_interviews(
    candidate_model: ee_backend.LIModelWrapper,
    question_model: ee_backend.QuestionSelector,
    question_set: ee_backend.SkillQuestionSet,
    candidate_set: ee_backend.CandidateSet,
    batch_size: int,
    summary_prefix: str,
    num_questions: int,
    torch_device: str,
    rng: np.random.Generator,
) -> list[ee_backend.InterviewNode]:
    """Simulate ``batch_size`` greedy interviews and collect their nodes.

    For each interview, a candidate is sampled from ``candidate_set`` with
    ``rng`` and a root node is created and given its summary. Then
    ``num_questions`` questions are asked, each one being the question
    selector's ``top_question`` for the current node.

    Every node created along each path is returned (not only the final one),
    in creation order. The result therefore has
    ``batch_size * num_questions`` entries.

    Note: ``num_questions`` is the interview depth. The size of the full
    question set (``len(question_model.question_set)``) is what is passed to
    ``InterviewNode`` as its own ``num_questions``.
    """
    collected: list[ee_backend.InterviewNode] = []
    for _ in range(batch_size):
        node = ee_backend.InterviewNode(
            parent=None,
            children={},
            question=None,
            question_id=None,
            num_questions=len(question_model.question_set),
            candidate=candidate_set.sample_candidate(rng),
            _node_tensor=None,
        )
        node.ask_root_summary(
            summary_prefix=summary_prefix,
            model=candidate_model,
        )
        for _ in range(num_questions):
            dist = question_model.get_questions_for_node(
                node,
                question_set,
                device=torch_device,
            )
            # Question ids are the selector's index shifted by one.
            node = node.ask_question(
                dist.top_question,
                dist.top_question_int + 1,
                len(question_model.question_set),
                summary_prefix,
                candidate_model,
            )
            collected.append(node)
    return collected


def get_z(
    interview_node: ee_backend.InterviewNode,
    question_model: ee_backend.QuestionSelector,
    score_model: ee_backend.ScoreFunction,
    candidate_model: ee_backend.LIModelWrapper,
    num_rollouts: int,
    rollout_depth: int,
    question_set: ee_backend.SkillQuestionSet,
    summary_prefix: str,
    torch_device: str,
) -> dict[int, float]:
    """Estimate the value of the most likely next questions via rollouts.

    The keys of the selector's ``question_dist_int`` at ``interview_node`` are
    ranked by descending probability, and questions reported by
    ``interview_node.previous_questions()`` are dropped. For each of the top
    ``num_rollouts`` remaining questions, this function:

    1. asks that question from ``interview_node``;
    2. follows the selector's ``top_question`` for ``rollout_depth`` more steps;
    3. scores the resulting node with ``score_model``.

    Returns:
        Mapping from question index (a key of ``question_dist_int``) to the
        z-score of its rollout. It has fewer than ``num_rollouts`` entries
        when fewer unasked questions remain.

    Raises:
        ValueError: if rollouts were requested but no unasked question remains.
    """
    ret_dict: dict[int, float] = {}
    questions_dist = question_model.get_questions_for_node(
        interview_node,
        question_set,
        device=torch_device,
    )
    ranked = sorted(
        questions_dist.question_dist_int.items(), key=lambda kv: kv[1], reverse=True
    )
    # Evaluate the asked-question list once instead of once per candidate.
    asked = list(interview_node.previous_questions())
    # NOTE(review): candidates are indices, while ask_question below stores
    # id = index + 1. Confirm previous_questions() reports the same index space,
    # otherwise this filter is off by one.
    candidates = [idx for idx, _ in ranked if idx not in asked]

    rollout_questions = candidates[:num_rollouts]
    if num_rollouts > 0 and not rollout_questions:
        raise ValueError("No unasked questions left to roll out from this interview node.")

    for rollout_question in rollout_questions:
        new_node = interview_node.ask_question(
            question=question_set.idx_to_skill[rollout_question],
            question_id=rollout_question + 1,
            num_questions=len(question_model.question_set),
            summary_prefix=summary_prefix,
            model=candidate_model,
        )
        # Greedy continuation of the rollout.
        for _ in range(rollout_depth):
            questions_dist = question_model.get_questions_for_node(
                new_node,
                question_set,
                device=torch_device,
            )
            new_node = new_node.ask_question(
                questions_dist.top_question,
                questions_dist.top_question_int + 1,
                len(question_model.question_set),
                summary_prefix,
                candidate_model,
            )
        new_score = score_model.get_score_for_node(
            new_node,
            torch_device,
        )
        ret_dict[rollout_question] = new_score.to_z_score()
    return ret_dict


def rl_update(
    question_model: ee_backend.QuestionSelector,
    score_model: ee_backend.ScoreFunction,
    interviews: list[ee_backend.InterviewNode],
    zscores: list[dict[int, float]],
    loss_question: torch.nn.Module,
    loss_score: torch.nn.Module,
    optimizer_question: torch.optim.Optimizer,
    optimizer_score: torch.optim.Optimizer,
    target_skill: str,
    device: str,
) -> tuple[float, float, list[tuple[float, float]]]:
    """Run one optimizer step on the question selector and one on the score model.

    Question selector: for each interview, the target class is the question
    with the highest z-score in the matching ``zscores`` dict. That sample's
    cross-entropy is multiplied by the z-score, and the weighted losses are
    averaged.

    Score model: the target class is 1 when ``target_skill`` is in the
    candidate's skills, else 0.

    Args:
        interviews: Interview nodes to train on; must be non-empty.
        zscores: One rollout-score dict per interview, in the same order.
        loss_question: Loss applied as ``loss_question(logits, one_hot_targets)``.
            It is evaluated one sample at a time, so any ``reduction`` works.
        loss_score: Loss for the 2-class score head; its output is averaged.

    Returns:
        ``(question_loss, score_loss, score_logits)``. ``score_logits`` holds the
        two score-model logits for each interview, in ``interviews`` order.

    Raises:
        ValueError: if ``interviews`` is empty or its length differs from ``zscores``.
    """
    if not interviews:
        raise ValueError("rl_update needs at least one interview.")
    if len(zscores) != len(interviews):
        raise ValueError(
            f"Got {len(interviews)} interviews but {len(zscores)} z-score dicts."
        )

    question_model.train()
    score_model.train()

    optimizer_question.zero_grad()
    optimizer_score.zero_grad()

    # NOTE(review): only the first two entries along the first dimension of
    # each node tensor are kept — TODO confirm this is intended.
    interview_tensors = torch.stack([
        node.get_node_tensor(question_model.input_dim, device)[:2]
        for node in interviews
    ]).to(device)

    # ---- question selector ----
    q_logits = question_model(interview_tensors)[:, 0, :]

    labels: list[int] = []
    weights: list[float] = []
    for z_dict in zscores:
        label, weight = max(z_dict.items(), key=lambda x: x[1])
        labels.append(label)
        weights.append(weight)

    q_targets = torch.nn.functional.one_hot(
        torch.tensor(labels, device=device),
        num_classes=len(question_model.question_set),
    ).float()
    weights_tensor = torch.tensor(weights, device=device, dtype=q_logits.dtype)

    # Evaluate the loss one sample at a time. A mean-reduced loss would
    # otherwise yield a single scalar, and multiplying that by the weights only
    # rescales the batch mean instead of weighting each sample. ``.sum()`` on a
    # single-sample loss yields that sample's value under any reduction.
    per_sample_q_loss = torch.stack([
        loss_question(logit_row, target_row).sum()
        for logit_row, target_row in zip(q_logits.split(1), q_targets.split(1))
    ])
    question_loss = (per_sample_q_loss * weights_tensor).mean()
    question_loss.backward()
    optimizer_question.step()

    # ---- score model ----
    score_logits = score_model(interview_tensors)[:, -1, :]

    score_labels = [
        1 if target_skill in node.candidate.skills else 0
        for node in interviews
    ]
    score_targets = torch.nn.functional.one_hot(
        torch.tensor(score_labels, device=device),
        num_classes=2,
    ).float()

    score_loss = loss_score(score_logits, score_targets).mean()
    score_loss.backward()
    optimizer_score.step()

    logits_out = [(row[0], row[1]) for row in score_logits.detach().tolist()]
    return question_loss.item(), score_loss.item(), logits_out



if __name__ == "__main__":
    # Script entry point: parse the CLI arguments and run training.
    main()
