import torch.nn as nn
import torch

import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))

from abl.reasoning import ReasonerBase, KBBase, WeaklySupervisedReasoner
from pathlib import Path
from abl.learning import BasicNN, ABLModel, WeaklySupervisedABLModel, WeaklySupervisedNN
from abl.bridge import SimpleBridge, WeaklySupervisedBridge
from abl.evaluation import SymbolMetric, ABLMetric
from abl.utils import ABLLogger, print_log
from collections import defaultdict
from models.nn import LeNet5, ResNet18
import argparse
import wandb
from datasets.addition import digits_to_number, get_data

# Directory that contains this script file.
_ROOT = Path(__file__).parent
def split_list(lst):
    """Cut ``lst`` at its midpoint and return the two pieces as a tuple.

    The cut index is ``len(lst) // 2``, so for an odd-length input the
    second piece holds the extra element.
    """
    half = len(lst) // 2
    return lst[:half], lst[half:]


class add_KB(KBBase):
    """Knowledge base whose rule is the sum of two equally long digit groups.

    ``logic_forward`` splits the incoming sequence in half, passes each half
    to ``digits_to_number`` together with ``digit_base``, and adds the two
    results.

    Args:
        pseudo_label_list: Candidate pseudo-labels forwarded to ``KBBase``.
            ``None`` (the default) means ``list(range(10))``.
        prebuild_GKB: Forwarded unchanged to ``KBBase``.
        GKB_len_list: Forwarded to ``KBBase``. ``None`` (the default) means
            ``[4]``.
        max_err: Forwarded unchanged to ``KBBase``.
        use_cache: Forwarded unchanged to ``KBBase``.
        kb_file_path: Forwarded unchanged to ``KBBase``.
        digit_base: Base handed to ``digits_to_number`` in ``logic_forward``.
    """

    def __init__(
        self,
        pseudo_label_list=None,
        prebuild_GKB=True,
        GKB_len_list=None,
        max_err=0,
        use_cache=True,
        # NOTE(review): this default names a base-16 file while ``digit_base``
        # defaults to 10 -- confirm which pairing is intended.
        kb_file_path="./kb/addition_2_16.pl",
        digit_base=10,
    ):
        # Build fresh lists per instance; list defaults in the signature would
        # be shared (and mutable) across every instantiation.
        if pseudo_label_list is None:
            pseudo_label_list = list(range(10))
        if GKB_len_list is None:
            GKB_len_list = [2 * 2]
        super().__init__(
            pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache, kb_file_path
        )
        self.digit_base = digit_base

    def logic_forward(self, nums):
        """Return the sum of the numbers encoded by the two halves of ``nums``."""
        first_half, second_half = split_list(nums)
        return (
            digits_to_number(first_half, self.digit_base)
            + digits_to_number(second_half, self.digit_base)
        )


def main(args):
    """Run the weakly supervised ABL addition experiment described by ``args``.

    Sets up logging, seeds the RNGs, builds the knowledge base, reasoner,
    network, and datasets, then trains through ``WeaklySupervisedBridge``.

    Args:
        args: Parsed command-line namespace. Reads ``digit_size``,
            ``digit_base``, ``dataset``, ``seed``, ``topk``, ``temp`` and
            ``max_iters``.

    Raises:
        KeyError: If ``args.digit_base`` has no logger suffix (only 10 and 16
            are mapped).
    """
    log_name = f"wsabl_{args.digit_size}_{args.digit_base}_{args.dataset}"
    base2num = {10: "1", 16: "2"}
    if args.digit_base not in base2num:
        # Same exception type as the plain lookup, but with an actionable message.
        raise KeyError(
            f"unsupported digit_base {args.digit_base!r}; expected one of {sorted(base2num)}"
        )
    logger = ABLLogger.get_instance("abl", logger_name=f"wsabl{args.digit_size}{base2num[args.digit_base]}", log_name=log_name)
    # wandb.init(
    #     project="ws_abl",
    #     group=f"{args.group_hint}: addition {args.dataset}-{args.digit_size} ws abl",
    # )
    print_log(f"args: {args}", logger="current")
    seed_everything(args.seed)

    kb_file_path = f"./kb/addition_{args.digit_size}_{args.digit_base}.pl"
    kb = add_KB(
        prebuild_GKB=True,
        pseudo_label_list=list(range(args.digit_base)),
        GKB_len_list=[args.digit_size * 2],
        kb_file_path=kb_file_path,
        digit_base=args.digit_base,
    )
    abducer = WeaklySupervisedReasoner(kb)

    # Pick the network class by dataset name and build only that one; unknown
    # dataset names fall back to LeNet5.
    # NOTE: constructing a single network consumes less of the seeded RNG than
    # constructing every candidate up front, so results under a fixed seed may
    # not match runs that built all candidates eagerly.
    network_cls = {"MNIST": LeNet5, "CIFAR": ResNet18}.get(args.dataset, LeNet5)
    cls = network_cls(num_classes=len(kb.pseudo_label_list))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss(reduction="none")
    optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))
    batch_size = 256  # shared by the base model and the bridge training loop
    base_model = WeaklySupervisedNN(
        cls,
        criterion,
        optimizer,
        device,
        save_interval=1,
        save_dir=logger.save_dir,
        batch_size=batch_size,
        num_epochs=1,
    )

    model = WeaklySupervisedABLModel(
        base_model, topK=args.topk, temp=args.temp, digit_base=args.digit_base)
    metric = [
        SymbolMetric(prefix=f"{args.dataset}_add"),
        ABLMetric(prefix=f"{args.dataset}_add"),
    ]

    # Train and test splits are requested with identical settings apart from
    # the ``train`` flag.
    data_kwargs = dict(
        get_pseudo_label=True,
        n=args.digit_size,
        digit_base=args.digit_base,
        dataset=args.dataset,
    )
    train_data = get_data(train=True, **data_kwargs)
    test_data = get_data(train=False, **data_kwargs)

    bridge = WeaklySupervisedBridge(model, [abducer], metric)
    bridge.train(
        train_data,
        max_iters=args.max_iters,
        batch_size=batch_size,
        test_data=test_data,
        more_revision=2,
    )
    # bridge.test(test_data)


def get_args(argv=None):
    """Parse the command-line options of the addition experiment.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``group_hint``, ``dataset``, ``digit_size``,
        ``digit_base``, ``seed``, ``temp``, ``max_iters`` and ``topk``.
    """
    # The title belongs in ``description``; ``prog`` is the executable name
    # argparse prints in the usage line.
    parser = argparse.ArgumentParser(
        description="Addition Experiment, Weakly Supervised ABL")
    parser.add_argument("--group_hint", type=str, default="")
    parser.add_argument("--dataset", type=str, default="MNIST")
    parser.add_argument("--digit_size", type=int, default=2)
    parser.add_argument("--digit_base", type=int, default=10)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--temp", type=float, default=0.2)
    parser.add_argument("--max_iters", type=int, default=5000)
    parser.add_argument(
        "--topk",
        type=int,
        default=32,
        help="choose only top k candidates, k=-1 means use all of them.",
    )
    return parser.parse_args(argv)


def seed_everything(seed: int = 0):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Value used for every generator and for ``PYTHONHASHSEED``.
    """
    import random
    import os
    import numpy as np
    import torch

    random.seed(seed)
    # Only affects processes started after this point; the hash seed of the
    # running interpreter is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Cover every visible GPU, not just the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune and pick different algorithms from
    # run to run, which works against the determinism requested above.
    torch.backends.cudnn.benchmark = False


# Script entry point: parse the command line, then run the experiment.
if __name__ == "__main__":
    args = get_args()
    main(args)
