import torch.nn as nn
import torch

import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), './models')))

from abl.reasoning import ReasonerBase, KBBase

from abl.learning import BasicNN, ABLModel
from abl.bridge import PhaseBridge
from abl.evaluation import SymbolMetric, ABLMetric
from abl.utils import ABLLogger, print_log
from collections import defaultdict
from models.nn import LeNet5, ResNet18
import argparse
import wandb
from datasets.addition import digits_to_number, get_data, get_phase_data
import random

def split_list(lst):
    """Cut ``lst`` at its midpoint and return the two halves.

    The cut index is ``len(lst) // 2``, so when the length is odd the second
    half holds the extra element.

    Args:
        lst: Any sliceable sequence.

    Returns:
        A ``(first_half, second_half)`` tuple of slices of ``lst``.
    """
    cut = len(lst) // 2
    return lst[:cut], lst[cut:]
    
class add_KB(KBBase):
    """Knowledge base whose rule is the sum of two equally long digit sequences.

    The constructor forwards the generic knowledge-base options to ``KBBase``
    unchanged and additionally records the numeric base used to interpret the
    digit sequences.

    Args:
        pseudo_label_list: Candidate pseudo-labels. Defaults to a fresh
            ``list(range(10))`` when omitted.
        prebuild_GKB: Forwarded to ``KBBase``.
        GKB_len_list: Forwarded to ``KBBase``. Defaults to a fresh ``[4]``
            (i.e. ``2 * 2``) when omitted.
        max_err: Forwarded to ``KBBase``.
        use_cache: Forwarded to ``KBBase``.
        kb_file_path: Forwarded to ``KBBase``.
        digit_base: Base passed to ``digits_to_number`` in ``logic_forward``.
    """

    def __init__(
        self,
        pseudo_label_list=None,
        prebuild_GKB=True,
        GKB_len_list=None,
        max_err=0,
        use_cache=True,
        kb_file_path="./kb/addition_2_16.pl",
        digit_base = 10,
    ):
        # Build the list defaults per call: a list literal in the signature is
        # created once and would be shared (and mutable) across all instances.
        if pseudo_label_list is None:
            pseudo_label_list = list(range(10))
        if GKB_len_list is None:
            GKB_len_list = [2 * 2]
        super().__init__(
            pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache, kb_file_path
        )
        self.digit_base = digit_base

    def logic_forward(self, nums):
        """Return the sum of the two numbers encoded by the halves of ``nums``.

        ``nums`` is cut in the middle by ``split_list``; each half is converted
        with ``digits_to_number`` using ``self.digit_base`` and the two results
        are added.
        """
        nums1, nums2 = split_list(nums)
        return digits_to_number(nums1, self.digit_base) + digits_to_number(nums2, self.digit_base)

import random
def get_lists(z_list, min_phase_size, digit_base=10):
    """Return the hard-coded per-phase label subsets for ``digit_base``.

    Each entry is the list of labels available in one phase. The schedules
    are nested: every list contains the previous one plus two more labels,
    and the last list is the full label range of the base.

    Args:
        z_list: Accepted for interface compatibility; not used.
        min_phase_size: Accepted for interface compatibility; not used.
        digit_base: Numeric base of the task; only 10 and 16 are supported.

    Returns:
        A freshly built list of label lists (5 phases for base 10, 8 phases
        for base 16).

    Raises:
        ValueError: If ``digit_base`` is neither 10 nor 16. Without this the
            function would fall through and return ``None``, which only
            fails later with an unrelated-looking error.
    """
    if digit_base == 10:
        return [[1,9],[1,3,7,9],[0,1,3,5,7,9],[0,1,2,3,5,7,8,9],[0,1,2,3,4,5,6,7,8,9]]
    if digit_base == 16:
        return [[2,9],[0,2,9,15],[0,2,4,9,10,15],[0,2,4,6,9,10,12,15],[0,2,4,6,7,9,10,12,13,15],[0,1,2,4,6,7,9,10,11,12,13,15],[0,1,2,4,5,6,7,9,10,11,12,13,14,15],[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]]
    raise ValueError(f"Unsupported digit_base {digit_base!r}; expected 10 or 16")

def main(args):
    """Run the phase-wise ABL addition experiment described by ``args``.

    Steps, in order: set up the logger, fetch the per-phase label subsets,
    seed the RNGs, build one knowledge base / reasoner pair per phase, build
    the network and its training wrapper, load the data, and train through
    ``PhaseBridge``.

    Args:
        args: Parsed command-line namespace (see ``get_args``). The attributes
            read here are ``digit_size``, ``digit_base``, ``dataset``,
            ``min_phase_size``, ``seed``, ``dist_func``, ``batch_size`` and
            ``max_iters``.
    """
    # --- logging -----------------------------------------------------------
    log_name = f"cabl_{args.digit_size}_{args.digit_base}_{args.dataset}"
    base_tags = {10: "1", 16: "2"}
    base_tag = base_tags[args.digit_base]
    logger = ABLLogger.get_instance(
        "abl",
        logger_name=f"pabl{args.digit_size}{base_tag}",
        log_name=log_name,
    )
    print_log(f"args: {args}", logger="current")

    # --- phase schedule and seeding ----------------------------------------
    z_lists = get_lists(
        list(range(args.digit_base)),
        min_phase_size=args.min_phase_size,
        digit_base=args.digit_base,
    )
    seed_everything(args.seed)

    # --- one knowledge base + reasoner per phase ---------------------------
    kb_file_path = f"./kb/addition_{args.digit_size}_{args.digit_base}.pl"
    abducer_list = []
    for phase_labels in z_lists:
        knowledge_base = add_KB(
            GKB_len_list=[args.digit_size * 2],
            pseudo_label_list=phase_labels,
            kb_file_path=kb_file_path,
            digit_base=args.digit_base,
        )
        # Identity index -> label mapping; a new dict is built for each reasoner.
        mapping = dict(enumerate(range(args.digit_base)))
        abducer_list.append(
            ReasonerBase(
                knowledge_base,
                dist_func=args.dist_func,
                mapping=mapping,
                use_zoopt=False,
            )
        )

    # --- network -----------------------------------------------------------
    # ``knowledge_base`` is the one built for the last phase; its label list
    # determines the number of output classes.
    def _default_network():
        return LeNet5(num_classes=len(knowledge_base.pseudo_label_list))

    networks = defaultdict(_default_network)
    # Both networks are instantiated, in this order, before one is selected.
    networks["MNIST"] = LeNet5(num_classes=len(knowledge_base.pseudo_label_list))
    networks["CIFAR"] = ResNet18(num_classes=len(knowledge_base.pseudo_label_list))
    network = networks[args.dataset]

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(network.parameters(), lr=0.001, betas=(0.9, 0.99))

    base_model = BasicNN(
        network,
        criterion,
        optimizer,
        device,
        save_interval=1,
        save_dir=logger.save_dir,
        batch_size=256,
        num_epochs=1,
    )
    model = ABLModel(base_model)

    # --- metrics -----------------------------------------------------------
    metric = [
        SymbolMetric(prefix=f"{args.dataset}_add", digit_base=args.digit_base),
        ABLMetric(prefix=f"{args.dataset}_add"),
    ]

    # --- data --------------------------------------------------------------
    train_data = get_phase_data(
        train=True,
        get_pseudo_label=True,
        n=args.digit_size,
        min_sequence_num=args.batch_size,
        z_lists=z_lists,
        digit_base=args.digit_base,
        dataset=args.dataset,
    )
    test_data = get_data(
        train=False,
        get_pseudo_label=True,
        n=args.digit_size,
        digit_base=args.digit_base,
        dataset=args.dataset,
    )

    # --- training ----------------------------------------------------------
    bridge = PhaseBridge(model, abducer_list, metric)
    bridge.train(
        train_data,
        test_data,
        batch_size=args.batch_size,
        max_iters=args.max_iters,
    )
    # bridge.test(test_data)
    # bridge.test(test_data)


def get_args():
    """Parse the experiment's command-line options from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``dataset``,
        ``digit_size``, ``digit_base``, ``dist_func``, ``batch_size``,
        ``min_phase_size``, ``max_iters`` and ``seed``.
    """
    # (flag, add_argument keyword options), registered in this order.
    option_table = (
        ("--dataset", dict(type=str, default="MNIST")),
        ("--digit_size", dict(type=int, default=3)),
        ("--digit_base", dict(type=int, default=10)),
        ("--dist_func", dict(type=str, choices=['hamming', 'confidence'], default='confidence')),
        ("--batch_size", dict(type=int, default=256)),
        ("--min_phase_size", dict(type=float, default=2)),
        ("--max_iters", dict(type=int, default=5000)),
        ("--seed", dict(type=int, default=0)),
    )

    parser = argparse.ArgumentParser(prog="Addition Experiment, Phase ABL")
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def seed_everything(seed: int = 0):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Value used for every random number generator.
    """
    import random, os
    import numpy as np
    import torch

    random.seed(seed)
    # Setting this at runtime does not re-seed hashing in the current
    # interpreter; it is inherited by child processes started afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one; a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes convolution algorithms and may select different
    # kernels from run to run, which defeats the deterministic setting above.
    torch.backends.cudnn.benchmark = False


if __name__ == "__main__":
    # Script entry point: parse the command line and run the experiment.
    main(get_args())
