import os
import time
import json
import logging
import datetime
from traceback import format_exc

import random
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, balanced_accuracy_score, accuracy_score
import torch
from torch.distributions import Categorical, Bernoulli
import transformers

from arguments import Args
from data_generator import DATASET_MAPPING, DETERMINISTIC_TASKS
from data_generator import GeoILPDataset
from model import ProbProgram, SymbProgram
from utils import CosineAnnealingRestartScheduler, LinearAnnealingRestartScheduler, ExponentialScheduler
from utils import PeriodicHandler


# Module-level logger; records propagate to the root logger that main() configures.
logger = logging.getLogger(__name__)


@torch.no_grad()
def symbolic_evaluation(
    symb_output: torch.Tensor,
    targets_label: list[torch.Tensor],
    total_truth_values: int
) -> tuple[float, float, float, float, float]:
    """Score binary predictions against the flattened ground-truth labels.

    Args:
        symb_output: Predicted truth values as a single tensor. It is passed to
            scikit-learn unchanged, so it is presumably already laid out like the
            concatenation of the flattened ``targets_label`` -- TODO confirm
            against the caller.
        targets_label: Ground-truth tensors; each one is flattened and all of
            them are concatenated, in order, to form ``y_true``.
        total_truth_values: Not used by this function; kept so that existing
            call sites remain valid.

    Returns:
        ``(precision, recall, F1, accuracy, balanced_accuracy)``. Precision,
        recall and F1 use ``average="binary"`` and are reported as 0.0 when
        they would otherwise be undefined (``zero_division=0.0``).
    """
    y_true = torch.cat([label.flatten() for label in targets_label]).detach().cpu().numpy()
    y_pred = symb_output.detach().cpu().numpy()

    precision, recall, f1, _support = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0.0
    )
    acc = accuracy_score(y_true, y_pred)
    balanced_acc = balanced_accuracy_score(y_true, y_pred)

    return precision, recall, f1, acc, balanced_acc


def _loo_topk(x: torch.Tensor, k: int):
    """Return, for every element of ``x``, the k-th largest of the *other* elements.

    ``x`` is flattened first and the result has that flattened shape.

    Only the ``k + 1`` largest values overall are needed: leaving out an element
    that belongs to the top-k promotes the (k+1)-th largest value into k-th
    place, while leaving out any other element keeps the k-th largest as is.

    Args:
        x: Tensor of scores; must hold at least ``k + 1`` elements.
        k: Rank of the leave-one-out statistic; must be at least 1.

    Returns:
        1-D tensor with one leave-one-out k-th largest value per element of ``x``.
    """
    flat = x.flatten()
    # torch.topk returns its values sorted in descending order by default.
    largest = torch.topk(flat, k + 1).values

    kth_largest = largest[-2]
    next_largest = largest[-1]

    in_top_k = flat >= kth_largest
    return torch.where(in_top_k, next_largest, kth_largest)


def _evaluate_rule_set(symb_model, predicates_init, targets_label, num_targets_values, rule_samples):
    """Run one rule set through ``symb_model`` and score the resulting predictions.

    Args:
        symb_model: The symbolic program. It is called exactly as in the
            single-rule-set evaluations of ``run`` and is expected to already be
            in eval mode (the caller switches modes).
        predicates_init: Initial predicate valuations handed to ``symb_model``.
        targets_label: Ground-truth tensors handed to ``symb_model`` and to
            ``symbolic_evaluation``.
        num_targets_values: Total number of target truth values, forwarded to
            ``symbolic_evaluation``.
        rule_samples: Nested dict exposing ``[kind][part]["samples"]`` for
            ``kind`` in ``("vars", "atom")`` and ``part`` in ``("head", "body")``.

    Returns:
        Dict with the keys ``acc``, ``balanced_acc``, ``P``, ``R``, ``F1`` and
        ``fc_steps`` (the second value returned by ``symb_model``), in that order.
    """
    # The samples are wrapped in singleton lists to add the leading sample
    # dimensions that are squeezed away again below.
    predictions, fc_steps = symb_model(
        predicates_init,
        targets_label,
        [[rule_samples["vars"]["head"]["samples"]]], [[rule_samples["vars"]["body"]["samples"]]],
        [rule_samples["atom"]["head"]["samples"]], [rule_samples["atom"]["body"]["samples"]]
    )
    P, R, F1, acc, balanced_acc = symbolic_evaluation(
        predictions.squeeze((0, 1)), targets_label, num_targets_values
    )
    # `predictions` goes out of scope on return, releasing it before the next evaluation.
    return {
        "acc": acc,
        "balanced_acc": balanced_acc,
        "P": P,
        "R": R,
        "F1": F1,
        "fc_steps": fc_steps
    }


def run(args):
    """Train the probabilistic rule sampler and evaluate the rules it produces.

    Every epoch samples rule sets from ``ProbProgram``, scores them with
    ``SymbProgram`` on the training data, takes one policy-gradient step
    (``REINFORCE``, ``RLOO`` or ``RLOO-topk``) with an entropy bonus, and
    evaluates the best sampled rule set on the evaluation data. Every
    ``args.eval_every_epoch`` epochs the rule set returned by ``model()`` in
    eval mode is additionally evaluated on both the training and the
    evaluation data.

    Training stops early once that rule set reaches a training balanced
    accuracy of 1.0, or, when ``args.stop_per_sampling_rules`` is set, once a
    sampled rule set does.

    Args:
        args: Parsed run configuration (an ``Args`` instance).

    Returns:
        Dict with the keys ``model_params``, ``time_usage``,
        ``eval_bacc_of_best_train_bacc``,
        ``postprocessed_rules_of_best_train_bacc``,
        ``eval_bacc_of_best_train_sampling_bacc``,
        ``postprocessed_rules_of_best_train_sampling_bacc``,
        ``best_epoch_per_train_bacc``, ``best_epoch_per_train_sampling_bacc``,
        ``configuration`` and ``train_progress``. The "best epoch" entries are
        ``None`` when no matching epoch was recorded.

    Raises:
        AssertionError: If ``args.tensorboard_dir`` is not empty.
        ValueError: If ``args.objective`` is not a supported objective.
    """
    logger.info(f"Configuration: {json.dumps(args.to_json(), indent=4)}")
    results = {}

    assert not list(args.tensorboard_dir.iterdir()), f"Tensorboard directory is not empty: {args.tensorboard_dir}"


    # ================================================
    #   Initialization
    # ================================================
    torch.set_num_threads(args.num_intraop_threads)
    torch.set_num_interop_threads(args.num_interop_threads)
    logger.info(f"Using {torch.get_num_interop_threads()} interop threads & CUDA streams | {torch.get_num_threads()} intraop threads")

    is_geoilp = "GeoILP" in args.task


    # ================================================
    #   Fetch training data
    # ================================================
    if is_geoilp:
        dataset = GeoILPDataset(args.task, args.device)
        train_bk_init, train_targets_label, train_targets_init = dataset.generate_data(is_train=True)
        train_num_constants = dataset.train_num_constants
    else:
        train_num_constants = args.train_num_constants
        dataset = DATASET_MAPPING[args.task]()
        train_bk_init, train_targets_label = dataset.generate_data(
            train_num_constants, device=args.device
        )
        train_targets_init = tuple(torch.zeros_like(p) for p in train_targets_label)

    # Auxiliary predicates start out all-False.
    train_aux_init = tuple(
        torch.zeros([train_num_constants] * args.max_aux_arity, dtype=bool, device=args.device)
        for _ in range(args.num_aux_predicates)
    )
    train_predicates_init = train_bk_init + train_aux_init + train_targets_init

    num_train_targets_values = sum(p.numel() for p in train_targets_label)

    bk_names, targets_names = dataset.predicate_names()
    aux_names = tuple(f"aux_{i + 1}" for i in range(args.num_aux_predicates))
    predicate_names = bk_names + aux_names + targets_names


    # ================================================
    #   Fetch evaluation data
    # ================================================
    if is_geoilp:
        eval_bk_init, eval_targets_label, eval_targets_init = dataset.generate_data(is_train=False)
        eval_num_constants = dataset.eval_num_constants
    else:
        eval_num_constants = args.eval_num_constants
        eval_bk_init, eval_targets_label = dataset.generate_data(eval_num_constants, device=args.device)
        eval_targets_init = tuple(torch.zeros_like(p) for p in eval_targets_label)

    eval_aux_init = tuple(
        torch.zeros([eval_num_constants] * args.max_aux_arity, dtype=bool, device=args.device)
        for _ in range(args.num_aux_predicates)
    )
    eval_predicates_init = eval_bk_init + eval_aux_init + eval_targets_init

    num_eval_targets_values = sum(p.numel() for p in eval_targets_label)


    # ================================================
    #   Initialize model
    # ================================================
    bk_arities = [p.ndim for p in train_bk_init]
    aux_arities = [p.ndim for p in train_aux_init]
    targets_arities = [p.ndim for p in train_targets_init]

    model = ProbProgram(
        args.cwa,
        len(train_bk_init),
        len(train_aux_init),
        len(train_targets_init),
        bk_arities,
        aux_arities,
        targets_arities,
        args.max_occurrence_in_body,
        args.num_rules,
        args.predicate_embed_dim,
        args.variable_embed_dim,
        args.rule_head_atom_embed_dim,
        args.rule_body_atom_embed_dim,
        args.max_variables,
        args.max_body_atoms,
        args.remove_irrelevant_vars
    )

    model = model.to(args.device)
    # These two tensors are moved explicitly in addition to `model.to(...)`.
    model.head_arities_tensor = model.head_arities_tensor.to(args.device)
    model.body_arities_tensor = model.body_arities_tensor.to(args.device)
    logger.info(f"Model is running on device {args.device}")

    model_params = sum(p.numel() for p in model.parameters())
    results["model_params"] = model_params
    logger.info(f"Total model parameters: {model_params}")


    symb_model = SymbProgram(
        args.cwa,
        args.max_train_inference_steps,
        args.max_eval_inference_steps,
        bk_arities,
        aux_arities,
        targets_arities,
        args.num_interop_threads,
        args.sampling_bacc_non_parallel_compute
    )


    # ================================================
    #   Initialize optimizer
    # ================================================
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        betas=(args.beta1, args.beta2),
        weight_decay=args.weight_decay
    )


    # ================================================
    #   Other initialization
    # ================================================
    entropy_coeff_sche = ExponentialScheduler(
        init_value=args.entropy_coeff_init,
        final_value=args.entropy_coeff_final,
        restart_period=args.entropy_coeff_anneal_epochs
    )


    # ================================================
    #   Main loop
    # ================================================
    train_progress = []
    start_time = time.time()
    for epoch in range(args.train_epochs):
        eval_at_this_epoch: bool = ((epoch + 1) % args.eval_every_epoch == 0)


        # ================================================
        #   Training
        # ================================================
        optimizer.zero_grad()

        model.train()
        symb_model.train()

        res = model(args.num_sample_vars, args.num_sample_atoms)

        train_balanced_acc, train_fc_steps = symb_model(
            train_predicates_init,
            train_targets_label,
            res["vars"]["head"]["samples"], res["vars"]["body"]["samples"],
            res["atom"]["head"]["samples"], res["atom"]["body"]["samples"]
        )

        # Locate the best sampled rule set. The row index selects the atom
        # samples and the (row, column) pair selects the variable samples below.
        best_bacc_per_row, best_col_per_row = train_balanced_acc.max(dim=1)
        train_bacc_max_max, best_row = best_bacc_per_row.max(dim=0)
        best_col = best_col_per_row[best_row]

        # Reward of every sampled rule set.
        f = train_balanced_acc

        vars_log_prob = res["vars"]["head"]["log_prob"].sum(dim=(-1, -2)) + res["vars"]["body"]["log_prob"].sum(dim=(-1, -2, -3))
        atom_log_prob = res["atom"]["head"]["log_prob"].sum(dim=-1) + \
            res["atom"]["body"]["log_prob"].sum(dim=(-1, -2, -3)) - res["atom"]["body"]["log_prob_cdf"].sum()
        # Broadcast the per-row atom term over the columns of the variable term.
        joint_log_prob = atom_log_prob.unsqueeze(-1) + vars_log_prob

        if args.objective == "REINFORCE":
            L1 = torch.mean(joint_log_prob * f)
        elif args.objective == "RLOO":
            # Leave-one-out mean baseline, written in its centred form.
            L1 = torch.sum(joint_log_prob * (f - f.mean())) / (f.numel() - 1)
        elif args.objective == "RLOO-topk":
            # Baseline: the k-th largest reward among the other samples.
            L1 = torch.mean(
                joint_log_prob.flatten() * (f.flatten() - _loo_topk(f, args.rloo_topk))
            )
        else:
            raise ValueError(f"Unknown objective: {args.objective}")

        head_vars_entropy = torch.sum(res["vars"]["head"]["entropy"])
        body_vars_entropy = torch.sum(res["vars"]["body"]["entropy"])
        head_atom_entropy = torch.sum(res["atom"]["head"]["entropy"])
        body_atom_entropy = torch.sum(res["atom"]["body"]["entropy"])
        L5 = head_vars_entropy + \
            body_vars_entropy + \
            head_atom_entropy + \
            body_atom_entropy

        entropy_coeff = entropy_coeff_sche.get_value(epoch + 1)
        # Negated because the optimizer minimizes: minimizing L maximizes the
        # policy-gradient surrogate L1 plus the entropy bonus.
        L = - (L1 + entropy_coeff * L5)

        if args.debug:
            with torch.autograd.set_detect_anomaly(True):
                L.backward()
        else:
            L.backward()

        if args.max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.max_grad_norm)

        optimizer.step()


        # ================================================
        #   Evaluation
        # ================================================
        if eval_at_this_epoch:
            model.eval()
            symb_model.eval()

            with torch.no_grad():
                eval_res = model()

            eval_train_metrics = _evaluate_rule_set(
                symb_model, train_predicates_init, train_targets_label, num_train_targets_values, eval_res
            )
            eval_metrics = _evaluate_rule_set(
                symb_model, eval_predicates_init, eval_targets_label, num_eval_targets_values, eval_res
            )

            raw_rules, postprocessed_rules = symb_model.extract_rules(
                predicate_names,
                eval_res["vars"]["head"]["samples"], eval_res["vars"]["body"]["samples"],
                eval_res["atom"]["head"]["samples"], eval_res["atom"]["body"]["samples"]
            )


        # ================================================
        #   Stop criteria based on training sampling
        # ================================================
        model.eval()
        symb_model.eval()

        # The single sampled rule set with the highest training balanced accuracy.
        sampling_res = {
            "vars": {
                "head": {"samples": res["vars"]["head"]["samples"][best_row][best_col]},
                "body": {"samples": res["vars"]["body"]["samples"][best_row][best_col]}
            },
            "atom": {
                "head": {"samples": res["atom"]["head"]["samples"][best_row]},
                "body": {"samples": res["atom"]["body"]["samples"][best_row]}
            }
        }
        sampling_metrics = _evaluate_rule_set(
            symb_model, eval_predicates_init, eval_targets_label, num_eval_targets_values, sampling_res
        )

        sampling_raw_rules, sampling_postprocessed_rules = symb_model.extract_rules(
            predicate_names,
            sampling_res["vars"]["head"]["samples"], sampling_res["vars"]["body"]["samples"],
            sampling_res["atom"]["head"]["samples"], sampling_res["atom"]["body"]["samples"]
        )


        # ================================================
        #   Logging at each epoch
        # ================================================
        train_record = {
            "max_fc_steps": train_fc_steps,
            "L": L.item(),
            "L1": L1.item(),
            "L5": L5.item()
        }
        if eval_at_this_epoch:
            # These metrics only exist on evaluation epochs; recording them on
            # other epochs would need values that were never computed.
            train_record.update(eval_train_metrics)

        sampling_record = {"train_bacc": train_bacc_max_max.item()}
        sampling_record.update((f"eval_{name}", value) for name, value in sampling_metrics.items())
        sampling_record["postprocessed_rules"] = sampling_postprocessed_rules
        sampling_record["raw_rules"] = sampling_raw_rules

        epoch_record = {
            "epoch": epoch + 1,
            "train": train_record,
            "sampling": sampling_record
        }
        if eval_at_this_epoch:
            epoch_record["eval"] = {
                **eval_metrics,
                "postprocessed_rules": postprocessed_rules,
                "raw_rules": raw_rules
            }
        train_progress.append(epoch_record)
        logger.info(json.dumps(epoch_record, ensure_ascii=False))


        # ================================================
        #   Stop or Continue
        # ================================================
        solved_by_eval_rules = eval_at_this_epoch and eval_train_metrics["balanced_acc"] == 1.0
        solved_by_sampled_rules = args.stop_per_sampling_rules and train_bacc_max_max.item() == 1.0
        if solved_by_eval_rules or solved_by_sampled_rules:
            break


    # ================================================
    #   Output
    # ================================================
    training_time = time.time() - start_time
    training_time_formatted = datetime.timedelta(seconds=training_time)
    logger.info(f"Training time: {training_time_formatted}")
    results["time_usage"] = str(training_time_formatted)

    # Only evaluation epochs carry the "balanced_acc" and "eval" entries used
    # below. Ties are broken in favour of the earliest epoch.
    best_epoch_per_train_balanced_acc = max(
        (d for d in train_progress if "eval" in d),
        key=lambda d: (d["train"]["balanced_acc"], - d["epoch"]),
        default=None
    )
    best_epoch_per_train_sampling_bacc = max(
        train_progress,
        key=lambda d: (d["sampling"]["train_bacc"], - d["epoch"]),
        default=None
    )

    if best_epoch_per_train_balanced_acc is not None:
        results["eval_bacc_of_best_train_bacc"] = best_epoch_per_train_balanced_acc["eval"]["balanced_acc"]
        results["postprocessed_rules_of_best_train_bacc"] = best_epoch_per_train_balanced_acc["eval"]["postprocessed_rules"]
    else:
        results["eval_bacc_of_best_train_bacc"] = None
        results["postprocessed_rules_of_best_train_bacc"] = None

    if best_epoch_per_train_sampling_bacc is not None:
        results["eval_bacc_of_best_train_sampling_bacc"] = best_epoch_per_train_sampling_bacc["sampling"]["eval_balanced_acc"]
        results["postprocessed_rules_of_best_train_sampling_bacc"] = best_epoch_per_train_sampling_bacc["sampling"]["postprocessed_rules"]
    else:
        results["eval_bacc_of_best_train_sampling_bacc"] = None
        results["postprocessed_rules_of_best_train_sampling_bacc"] = None

    results["best_epoch_per_train_bacc"] = (
        None if best_epoch_per_train_balanced_acc is None else best_epoch_per_train_balanced_acc.copy()
    )
    results["best_epoch_per_train_sampling_bacc"] = (
        None if best_epoch_per_train_sampling_bacc is None else best_epoch_per_train_sampling_bacc.copy()
    )
    logger.info(f"Best epoch according to training balanced accuracy: {json.dumps(best_epoch_per_train_balanced_acc, indent=4, ensure_ascii=False)}")
    logger.info(f"Best epoch according to training sampling balanced accuracy: {json.dumps(best_epoch_per_train_sampling_bacc, indent=4, ensure_ascii=False)}")

    results["configuration"] = args.to_json()
    results["train_progress"] = train_progress

    return results


def main():
    """Command-line entry point: parse arguments, seed RNGs, train, and write results.

    Steps, in order:
      1. Parse the ``Args`` dataclass from the command line.
      2. Seed Python, NumPy and PyTorch and request deterministic algorithms.
      3. Install a ``PeriodicHandler`` on the root logger.
      4. Call ``run(args)``. If it raises, the exception is re-raised when
         ``args.debug`` is set; otherwise the traceback is logged and the
         function returns without writing a result file.
      5. Dump the results as JSON to ``args.result_file``.
    """
    # ================================================
    #   Parse arguments
    # ================================================
    parser = transformers.HfArgumentParser((Args, ))
    args, = parser.parse_args_into_dataclasses()


    # ================================================
    #   Ensure reproducibility
    # ================================================
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)
    # Set alongside use_deterministic_algorithms(True); presumably required for
    # deterministic cuBLAS on CUDA -- see the PyTorch reproducibility notes.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"


    # ================================================
    #   Set up logging
    # ================================================
    periodic_handler = PeriodicHandler(args.log_file, args.log_buffer_capacity, args.log_period_to_disk)
    logging.basicConfig(
        level=logging.INFO,
        handlers=[periodic_handler]
    )


    # ================================================
    #   Train & Evaluation
    # ================================================
    try:
        results = run(args)
    except Exception:
        if args.debug:
            # Bare raise re-raises the active exception unchanged.
            raise
        logging.error(format_exc())
        return


    # ================================================
    #   Write results
    # ================================================
    # ensure_ascii=False emits non-ASCII characters verbatim, so the encoding is
    # pinned to UTF-8 rather than left to the platform's locale default.
    with args.result_file.open("w", encoding="utf-8") as file_writer:
        json.dump(results, file_writer, indent=4, ensure_ascii=False)


    # ================================================
    #   Finalize
    # ================================================
    logging.shutdown()


# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
