import torch.nn as nn
import torch

import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))

from abl.reasoning import ReasonerBase, KBBase

from abl.learning import BasicNN, ABLModel
from abl.bridge import SimpleBridge
from abl.evaluation import SymbolMetric, ABLMetric
from abl.utils import ABLLogger, print_log
from collections import defaultdict
from models.nn import LeNet5, ResNet18
import argparse
import wandb
from datasets.addition import digits_to_number, get_data


def split_list(lst):
    """Cut a sequence into two consecutive halves.

    When the length is odd, the extra element goes to the second half.

    Args:
        lst: Any sliceable sequence.

    Returns:
        A ``(first_half, second_half)`` tuple of slices of ``lst``.
    """
    cut = len(lst) // 2
    return lst[:cut], lst[cut:]


class add_KB(KBBase):
    """Knowledge base for the multi-digit addition task.

    An example is a flat sequence of digit labels. The first half encodes the
    first operand and the second half the second operand (split with
    ``split_list``). Each operand is read in base ``digit_base`` by
    ``digits_to_number``. The forward logic returns the sum of the two
    operands.

    Args:
        pseudo_label_list: Candidate digit labels. When omitted, defaults to
            ``list(range(digit_base))``, i.e. every digit of that base.
        prebuild_GKB: Forwarded to ``KBBase``.
        GKB_len_list: Forwarded to ``KBBase``. When omitted, defaults to
            ``[4]`` (``2 * 2``), the value this class has always used.
        max_err: Forwarded to ``KBBase``.
        use_cache: Forwarded to ``KBBase``.
        kb_file_path: Forwarded to ``KBBase``.
        digit_base: Base used to interpret each operand's digit sequence.
    """

    def __init__(
        self,
        pseudo_label_list=None,
        prebuild_GKB=True,
        GKB_len_list=None,
        max_err=0,
        use_cache=True,
        # NOTE(review): this default names a base-16 KB, while the other
        # defaults describe base 10 (main builds f"addition_{size}_{base}.pl").
        # Confirm which file is intended before relying on the default.
        kb_file_path="./kb/addition_2_16.pl",
        digit_base=10,
    ):
        # Build the lists per call. List literals in the signature would be
        # created once and shared by every instance. KBBase receives these
        # objects directly, so a shared default could leak state between
        # instances if it were ever mutated.
        if pseudo_label_list is None:
            # The digits of a base-b number are 0 .. b-1.
            pseudo_label_list = list(range(digit_base))
        if GKB_len_list is None:
            GKB_len_list = [2 * 2]
        super().__init__(
            pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache, kb_file_path
        )
        self.digit_base = digit_base

    def logic_forward(self, nums):
        """Return the sum of the two operands encoded in ``nums``.

        ``nums`` is split in half (second half longer if the length is odd);
        each half is converted to a number in base ``self.digit_base``.
        """
        nums1, nums2 = split_list(nums)
        return digits_to_number(nums1, self.digit_base) + digits_to_number(nums2, self.digit_base)


def _build_classifier(dataset, num_classes):
    """Construct a fresh classifier for ``dataset`` with ``num_classes`` outputs.

    "CIFAR" selects ResNet18. "MNIST" and any unrecognised dataset name fall
    back to LeNet5.
    """
    # Map names to constructors, not instances, so only the network that is
    # actually used gets built. This avoids allocating an unused model and
    # avoids consuming seeded RNG state on its initialisation.
    factories = {"MNIST": LeNet5, "CIFAR": ResNet18}
    return factories.get(dataset, LeNet5)(num_classes=num_classes)


def _load_split(args, train):
    """Load the train (``train=True``) or test split with pseudo-labels attached."""
    return get_data(
        train=train,
        get_pseudo_label=True,
        n=args.digit_size,
        digit_base=args.digit_base,
        dataset=args.dataset,
    )


def main(args):
    """Run the naive-ABL addition experiment configured by ``args``.

    ``args`` is the namespace returned by ``get_args`` (``dataset``,
    ``digit_size``, ``digit_base``, ``dist_func``, ``seed``, ``max_iters``).
    The function:

    1. builds the addition knowledge base and a reasoner over it;
    2. wraps a classifier in ``BasicNN`` / ``ABLModel``;
    3. loads the train and test splits;
    4. trains through ``SimpleBridge``.
    """
    log_name = f"nabl_{args.digit_size}_{args.digit_base}_{args.dataset}"
    # Compact base tag for the logger name: "1" for base 10, "2" for base 16.
    # Any other base gets "b<base>" instead of raising KeyError. The "b"
    # prefix keeps it from colliding with the two numeric tags.
    base_tag = {10: "1", 16: "2"}.get(args.digit_base, f"b{args.digit_base}")
    logger = ABLLogger.get_instance(
        "abl",
        logger_name=f"nabl{args.digit_size}{base_tag}",
        log_name=log_name,
    )

    # Optional Weights & Biases tracking; uncomment to enable.
    # wandb.init(
    #     project="ws_abl", group=f"addition {args.dataset}-{args.digit_size} naive abl"
    # )
    print_log(f"args: {args}", logger="current")
    seed_everything(args.seed)

    kb = add_KB(
        prebuild_GKB=True,
        pseudo_label_list=list(range(args.digit_base)),
        # Each example holds two operands of digit_size digits each.
        GKB_len_list=[args.digit_size * 2],
        kb_file_path=f"./kb/addition_{args.digit_size}_{args.digit_base}.pl",
        digit_base=args.digit_base,
    )
    num_classes = len(kb.pseudo_label_list)
    print_log(f"Number of pseudo labels: {num_classes}", logger="current")
    abducer = ReasonerBase(kb, dist_func=args.dist_func)

    cls = _build_classifier(args.dataset, num_classes)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))
    base_model = BasicNN(
        cls,
        criterion,
        optimizer,
        device,
        save_interval=1,
        save_dir=logger.save_dir,
        batch_size=256,
        num_epochs=1,
    )
    model = ABLModel(base_model)

    metric = [
        SymbolMetric(prefix=f"{args.dataset}_add", digit_base=args.digit_base),
        ABLMetric(prefix=f"{args.dataset}_add"),
    ]
    train_data = _load_split(args, train=True)
    test_data = _load_split(args, train=False)
    print_log(f"Train data: {len(train_data)}", logger="current")

    bridge = SimpleBridge(model, [abducer], metric)
    bridge.train(train_data, max_iters=args.max_iters, batch_size=256, test_data=test_data)


def get_args(argv=None):
    """Parse command-line options for the naive-ABL addition experiment.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before. Pass an explicit list to
            parse programmatically, e.g. from tests or notebooks.

    Returns:
        An ``argparse.Namespace`` with ``dataset``, ``digit_size``,
        ``digit_base``, ``dist_func``, ``seed`` and ``max_iters``.
    """
    # The human-readable title belongs in ``description``. ``prog`` is the
    # program name printed at the start of the usage line.
    parser = argparse.ArgumentParser(description="Addition Experiment, Naive ABL")
    parser.add_argument(
        "--dataset", type=str, default="MNIST",
        help="Image dataset; 'CIFAR' uses ResNet18, anything else LeNet5.",
    )
    parser.add_argument(
        "--digit_size", type=int, default=2,
        help="Number of digits in each operand.",
    )
    parser.add_argument(
        "--digit_base", type=int, default=10,
        help="Numeral base of the digits (e.g. 10 or 16).",
    )
    parser.add_argument(
        "--dist_func", type=str, choices=['hamming', 'confidence'], default='confidence',
        help="Distance function passed to the reasoner.",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed.")
    parser.add_argument(
        "--max_iters", type=int, default=5000,
        help="Maximum number of training iterations.",
    )
    return parser.parse_args(argv)


def seed_everything(seed: int = 0):
    """Seed every random number generator the experiment uses.

    This makes runs repeatable. It seeds Python's ``random``, NumPy's global
    RNG and PyTorch (CPU and every CUDA device), and configures cuDNN for
    deterministic kernel selection.

    Args:
        seed: Seed value applied to every generator.
    """
    import os
    import random

    import numpy as np
    import torch

    random.seed(seed)
    # Only affects child processes started after this point. The running
    # interpreter reads PYTHONHASHSEED once, at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every visible GPU, not just the current one. It is
    # silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True makes cuDNN time candidate algorithms and pick the
    # fastest, which can differ between runs and defeats `deterministic`
    # above. It must stay off for reproducible results.
    torch.backends.cudnn.benchmark = False


if __name__ == "__main__":
    # Runs only when executed as a script (not on import): parse the CLI
    # options, then launch the experiment with them.
    args = get_args()
    main(args=args)
