import torch.nn as nn
import torch
from itertools import filterfalse
from abl.reasoning import ReasonerBase, KBBase, WeaklySupervisedReasoner, MixedReasoner
from pathlib import Path
from abl.learning import BasicNN, ABLModel, WeaklySupervisedABLModel, WeaklySupervisedNN
from abl.bridge import SimpleBridge, WeaklySupervisedBridge, MixedWSBridge
from abl.evaluation import SymbolMetric, ABLMetric
from abl.utils import ABLLogger, solutionOfEnsemble, solutionOfMod, print_log
from collections import defaultdict
from models.nn import LeNet5, ResNet50
import argparse
import wandb
from datasets.addition import digits_to_number, get_ensemble_add

# Directory containing this script; main() builds the KB cache path from it.
_ROOT = Path(__file__).parent
def split_list(lst):
    """Cut ``lst`` at its midpoint and return ``(front, back)``.

    For an odd number of elements the extra element ends up in ``back``.
    """
    half = len(lst) // 2
    return lst[:half], lst[half:]

def parse_nums_and_ops(lst):
    """Break ``lst`` into ``(first_half, second_half, op)``.

    The last element is returned as ``op``; everything before it is divided
    in two by ``split_list``.
    """
    operator = lst[-1]
    front, back = split_list(lst[:-1])
    return front, back, operator


class ensemble_KB(KBBase):
    """Knowledge base for addition of two digit sequences under a modulus.

    A candidate is a flat sequence whose last element is the modulus (it must
    equal ``mod1`` or ``mod2``) and whose remaining elements split evenly into
    the digits of the two operands.
    """

    def __init__(self, pseudo_label_list=None, prebuild_GKB=False, GKB_len_list=None, max_err=0, use_cache=True, kb_file_path='', mod1: int = 10, mod2: int = 10):
        """Store both moduli and initialise the base KB.

        Args:
            pseudo_label_list: Digit labels; ``None`` means ``list(range(10))``.
                ``mod1`` and ``mod2`` are appended before it is handed to
                ``KBBase``.
            prebuild_GKB, max_err, use_cache, kb_file_path: Forwarded
                unchanged to ``KBBase.__init__``.
            GKB_len_list: Candidate lengths forwarded to ``KBBase``;
                ``None`` means ``[3]`` (one digit per operand plus the modulus).
            mod1: First accepted modulus.
            mod2: Second accepted modulus.
        """
        # Fresh lists per call: the previous list-valued defaults were created
        # once at definition time and shared by every instance.
        if pseudo_label_list is None:
            pseudo_label_list = list(range(10))
        if GKB_len_list is None:
            GKB_len_list = [1 * 2 + 1]
        self.mod1, self.mod2 = mod1, mod2
        super().__init__(pseudo_label_list + [mod1, mod2], prebuild_GKB, GKB_len_list, max_err, use_cache, kb_file_path)

    def _valid_candidate(self, lst):
        """Return True when ``lst`` has an even number of digits followed by
        one of the two configured moduli."""
        if (len(lst) - 1) % 2 != 0:
            return False
        return lst[-1] in (self.mod1, self.mod2)

    def _find_candidate_GKB(self, pred_res, y):
        """Return the entries of ``self.base[len(pred_res)][y]`` whose last
        element equals the last element of ``pred_res``."""
        wanted_op = pred_res[-1]
        stored = self.base[len(pred_res)][y]
        return [cand for cand in stored if cand[-1] == wanted_op]

    def logic_forward(self, lsts):
        """Evaluate a candidate: ``(num1 + num2) % op``.

        Returns ``None`` when the candidate is malformed or when either
        operand reaches the size bound below.
        """
        if not self._valid_candidate(lsts):
            return None
        nums1, nums2, op = parse_nums_and_ops(lsts)
        nums1, nums2 = digits_to_number(nums1), digits_to_number(nums2)
        # NOTE(review): the bound is (digits per operand) * 10, which equals
        # 10 ** digits only for single-digit operands. If digits_to_number is
        # base-10 concatenation this rejects valid multi-digit operands --
        # confirm the intended bound before using digit_size > 1.
        limit = (len(lsts) - 1) / 2 * 10
        if nums1 >= limit or nums2 >= limit:
            return None
        return (nums1 + nums2) % op

def main(args):
    """Run the weakly supervised ABL addition experiment described by ``args``.

    Builds the knowledge base, reasoner, classifier and bridge, then trains
    on the ensemble-addition data while logging to Weights & Biases.

    Args:
        args: Namespace produced by ``get_args``. ``args.mod2`` is overwritten
            with ``args.mod1`` when ``args.no_ensemble`` is set.

    Returns:
        None. Returns early, without training, when ensemble mode is active
        and ``args.mod1 > args.mod2``.
    """
    if args.no_ensemble:
        args.mod2 = args.mod1
    elif args.mod1 > args.mod2:
        # Such configurations are skipped; say so instead of exiting silently.
        print(f"Skipping run: mod1 ({args.mod1}) > mod2 ({args.mod2}) in ensemble mode.")
        return
    # NOTE(review): sols1/sols2 are never read; the calls are kept in case
    # solutionOfMod has side effects -- confirm and remove if it is pure.
    sols1 = solutionOfMod(args.mod1)
    sols2 = solutionOfMod(args.mod2)
    num_sols = solutionOfEnsemble(args.mod1, args.mod2)
    logger = ABLLogger.get_instance("abl")
    wandb.init(
        project="LearnablityOfNeSy",
        group=f"{args.group_hint}: addition {args.dataset}-{args.digit_size}-{args.mod1}-{args.mod2} ws abl",
    )
    seed_everything(args.seed)
    kb = ensemble_KB(prebuild_GKB=True, GKB_len_list=[args.digit_size * 2 + 1], kb_file_path=f"{_ROOT}/kb_cache/addition_{args.digit_size}_mod1{args.mod1}_mod2{args.mod2}_kb", mod1=args.mod1, mod2=args.mod2)
    abducer = MixedReasoner(kb, mapping={i: i for i in range(10)})

    # Build only the network for the requested dataset. Constructing every
    # entry up front loaded ./lenet5_weights.pth and created two ResNet50s
    # even when none of them was used.
    # NOTE: with fewer networks initialised, RNG consumption after seeding
    # differs from runs that built all of them.
    model_factories = {
        "MNIST": lambda: LeNet5(num_classes=10, weight='./lenet5_weights.pth'),
        "KMNIST": lambda: LeNet5(num_classes=10),
        "CIFAR": lambda: ResNet50(num_classes=10, pretrained=args.pretrained),
        "SVHN": lambda: ResNet50(num_classes=10, pretrained=args.pretrained),
    }

    def fallback_factory():
        # Unknown dataset names get a LeNet5 sized to the KB's label set.
        return LeNet5(num_classes=len(kb.pseudo_label_list))

    cls = model_factories.get(args.dataset, fallback_factory)()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss(reduction="none")
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, cls.parameters()), lr=0.0015, betas=(0.9, 0.99))
    # lr scheduler
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=0.9, total_iters=10)
    base_model = WeaklySupervisedNN(
        cls,
        criterion,
        optimizer,
        device,
        save_interval=1,
        save_dir=logger.save_dir,
        batch_size=256,
        num_epochs=1,
        scheduler=scheduler,
    )
    model = WeaklySupervisedABLModel(
        base_model, topK=args.topk, temp=args.temp)
    metric = [
        SymbolMetric(prefix=f"{args.dataset}_add", ignore_end=True),
        ABLMetric(prefix=f"{args.dataset}_add"),
    ]

    wandb.log({"num_sols": num_sols})
    print_log(msg=f"Nums of solutions: {num_sols}", logger='current')

    train_data = get_ensemble_add(args.dataset, train=True, get_pseudo_label=True, n=args.digit_size, mod1=args.mod1, mod2=args.mod2, sample_size=args.sample_size)
    test_data = get_ensemble_add(args.dataset, train=False, get_pseudo_label=True, n=args.digit_size, mod1=args.mod1, mod2=args.mod2)
    bridge = MixedWSBridge(model, abducer, metric)
    bridge.train(
        train_data,
        epochs=args.epoches,
        batch_size=1024,
        test_data=test_data,
        more_revision=2,
    )
    # bridge.test(test_data)
    # bridge.test(test_data)


import argparse


def get_args(argv=None):
    """Parse the command-line options of the addition experiment.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``group_hint``, ``dataset``, ``digit_size``,
        ``epoches``, ``seed``, ``temp``, ``mod1``, ``mod2``, ``no_ensemble``,
        ``pretrained``, ``sample_size`` and ``topk``.
    """
    parser = argparse.ArgumentParser(
        prog="Addition Experiment, Weakly Supervised ABL")
    parser.add_argument("--group_hint", type=str, default="")
    parser.add_argument("--dataset", type=str, default="MNIST")
    parser.add_argument("--digit_size", type=int, default=1)
    parser.add_argument("--epoches", type=int, default=10)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--temp", type=float, default=0.2)
    parser.add_argument("--mod1", type=int, default=5)
    parser.add_argument("--mod2", type=int, default=5)
    parser.add_argument("--no-ensemble", action='store_true', default=False, help="Enable no-ensemble mode")
    # The help text here used to repeat the --no-ensemble description.
    parser.add_argument("--pretrained", action='store_true', default=False, help="Use pretrained weights for the ResNet50 models (CIFAR/SVHN)")
    parser.add_argument("--sample_size", type=int, default=30000)
    parser.add_argument(
        "--topk",
        type=int,
        default=16,
        help="choose only top k candidates, k=-1 means use all of them.",
    )
    return parser.parse_args(argv)



def seed_everything(seed: int = 0):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Value used for every generator.
    """
    import os
    import random

    import numpy as np
    import torch

    random.seed(seed)
    # Hash randomisation is fixed at interpreter start-up, so this only
    # influences child processes that inherit the environment.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every CUDA device, not just the current one; harmless without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes convolution algorithms and may choose
    # different ones between runs, which undermines the determinism requested
    # above -- it has to be off for reproducible results.
    torch.backends.cudnn.benchmark = False


if __name__ == "__main__":
    # Script entry point: parse the CLI options and launch the experiment.
    main(get_args())
