import torch.nn as nn
import torch

import sys
import os

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

from reasoning import UnsupKBBase, UnsupReasonerBase, PhaseUnsupBridge, ValChessMetric

from abl.learning import BasicNN, ABLModel
from abl.utils import ABLLogger, print_log
from collections import defaultdict
from models.nn import LeNet5, ResNet18
import argparse

# import wandb
from datasets.chess import get_chess
import random, os
import numpy as np
from torch.utils.data import Dataset
import torch
import time


class sort_KB(UnsupKBBase):
    """Knowledge base for the chess "attack" task.

    ``logic_forward`` answers whether, in a sequence of pieces, some piece
    attacks a piece that appears later in the sequence.  Blocking pieces are
    not modelled: every rule below looks only at the two squares involved.

    Piece-type codes understood by ``attack``:
    0 pawn, 1 rook, 2 bishop, 3 knight, 4 king, 5 queen.
    """

    def __init__(
        self,
        pseudo_label_list=None,
        num_digits=1,
        num_classes=2,
        ind=False,
        prebuild_GKB=False,
        GKB_len_list=None,
    ):
        """Set up the knowledge base.

        Args:
            pseudo_label_list: Labels forwarded to the base-class constructor.
                ``None`` (the default) means ``[0, 1]``.  Note that
                ``self.pseudo_label_list`` is afterwards set to
                ``list(range(num_classes))`` regardless of this argument.
            num_digits: Forwarded to the base class; also stored as
                ``self.num_digits`` and ``self.require_more_revision``.
            num_classes: Number of piece classes; determines
                ``self.pseudo_label_list`` and ``self.max_times``.
            ind: Forwarded to the base-class constructor.
            prebuild_GKB: Forwarded to the base-class constructor and stored.
            GKB_len_list: Forwarded to the base-class constructor.
        """
        # A fresh list per instance: a list literal in the signature would be
        # created once and shared by every instance that relies on the default.
        if pseudo_label_list is None:
            pseudo_label_list = list(range(2))
        super().__init__(
            pseudo_label_list,
            ind=ind,
            prebuild_GKB=prebuild_GKB,
            GKB_len_list=GKB_len_list,
            num_digits=num_digits,
        )
        self.num_classes = num_classes
        self.max_times = num_classes ** (num_digits)
        self.pseudo_label_list = list(range(num_classes))
        self.num_digits = num_digits
        self.require_more_revision = num_digits
        # Number of logic_forward calls made on this instance.
        self.count = 0
        self.prebuild_GKB = prebuild_GKB
        # Not defined in this class; presumably provided by UnsupKBBase.
        self.prebuild_kb()

    def attack(self, type, x1, y1, x2, y2):
        """Return whether a piece of kind ``type`` on (x1, y1) attacks (x2, y2).

        ``type`` shadows the builtin, but the name is kept so that keyword
        callers keep working.  Unknown codes yield ``False``.
        """
        # Tuple position == piece-type code.
        rules = (
            self.pawn_attack,
            self.rook_attack,
            self.bishop_attack,
            self.knight_attack,
            self.king_attack,
            self.queen_attack,
        )
        # Compare with ``==`` (not a dict lookup) so that any value that
        # compares equal to the integer code is accepted, hashable or not.
        for code, rule in enumerate(rules):
            if type == code:
                return rule(x1, y1, x2, y2)
        return False

    def king_attack(self, x1, y1, x2, y2):
        """King: target is at most one step away on each axis."""
        # Also true when both squares coincide.
        return abs(x1 - x2) <= 1 and abs(y1 - y2) <= 1

    def queen_attack(self, x1, y1, x2, y2):
        """Queen: union of the straight and diagonal rules."""
        return self.straight_attack(x1, y1, x2, y2) or self.diagonal_attack(
            x1, y1, x2, y2
        )

    def rook_attack(self, x1, y1, x2, y2):
        """Rook: straight rule only."""
        return self.straight_attack(x1, y1, x2, y2)

    def bishop_attack(self, x1, y1, x2, y2):
        """Bishop: diagonal rule only."""
        return self.diagonal_attack(x1, y1, x2, y2)

    def knight_attack(self, x1, y1, x2, y2):
        """Knight: "L" shape, two steps on one axis and one on the other."""
        dx = abs(x1 - x2)
        dy = abs(y1 - y2)
        return (dx == 2 and dy == 1) or (dx == 1 and dy == 2)

    def pawn_attack(self, x1, y1, x2, y2):
        """Pawn: one column over and exactly one step in +y.

        The rule is directional (the original comment assumes a white pawn).
        """
        return abs(x1 - x2) == 1 and y2 - y1 == 1

    def straight_attack(self, x1, y1, x2, y2):
        """Straight line: the squares share an x or a y coordinate."""
        return x1 == x2 or y1 == y2

    def diagonal_attack(self, x1, y1, x2, y2):
        """Diagonal: the x and y offsets have equal magnitude."""
        return abs(x1 - x2) == abs(y1 - y2)

    def logic_forward(self, type, pos):
        """Return ``True`` if some piece attacks a later piece in the sequence.

        Args:
            type: Sequence of piece-type codes, one per piece.
            pos: Sequence indexed like ``type``; ``pos[i][0]`` and
                ``pos[i][1]`` are used as the x and y of piece ``i``.

        Returns:
            ``True`` as soon as a pair ``i < j`` is found where piece ``i``
            attacks the square of piece ``j``; otherwise ``False``.
        """
        self.count += 1
        n_pieces = len(type)
        # NOTE(review): only "earlier attacks later" is tested.  All rules but
        # pawn_attack are symmetric, so a later pawn attacking an earlier
        # piece is not detected -- confirm this matches the dataset labelling.
        return any(
            self.attack(type[i], pos[i][0], pos[i][1], pos[j][0], pos[j][1])
            for i in range(n_pieces)
            for j in range(i + 1, n_pieces)
        )


def main(args):
    """Run the phased chess experiment: build reasoners, train, then test.

    One knowledge base / reasoner pair and one training split are created for
    each phase class count in ``range(1, args.num_classes + 1, 2)``; a single
    LeNet5 classifier with ``args.num_classes`` outputs is shared by all
    phases.  Elapsed training time is printed before the final test call.

    Args:
        args: Parsed command-line namespace (see ``get_args``).
    """
    log_name = "cabl_chess"
    logger = ABLLogger.get_instance("abl", logger_name="cabl", log_name=log_name)
    print_log(f"args: {args}", logger="current")
    seed_everything(args.seed)

    # NOTE(review): with step 2 and the default num_classes=6 the phases use
    # 1, 3 and 5 classes and never reach 6 -- confirm this is intended.
    phase_class_counts = range(1, args.num_classes + 1, 2)

    def build_reasoner(n_cls):
        # Knowledge base restricted to the first n_cls piece classes.
        knowledge_base = sort_KB(
            num_classes=n_cls,
            num_digits=args.digit_size,
            ind=args.ind,
            prebuild_GKB=args.prebuild,
            GKB_len_list=[args.digit_size * 3 + 1],
        )
        # Identity index -> label mapping over the full class range; a new
        # dict is built for every reasoner.
        label_map = dict(enumerate(range(args.num_classes)))
        return UnsupReasonerBase(
            knowledge_base,
            dist_func=args.dist_func,
            mapping=label_map,
            use_zoopt=False,
        )

    abducer_list = [build_reasoner(n_cls) for n_cls in phase_class_counts]

    net = LeNet5(num_classes=args.num_classes)
    device_name = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    criterion = nn.CrossEntropyLoss(reduction="none")
    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr, betas=(0.9, 0.99))
    model = ABLModel(
        BasicNN(
            net,
            criterion,
            optimizer,
            device,
            save_interval=1,
            save_dir=logger.save_dir,
            batch_size=256,
            num_epochs=1,
        )
    )
    metric = [ValChessMetric(prefix="chess")]

    train_data = []
    for n_cls in phase_class_counts:
        # The prior returned with each training split is not used.
        segment, _train_prior = get_chess(
            train=True,
            n=args.digit_size,
            num_classes=n_cls,
            sequence_num=args.train_sequence_num,
            seed=args.seed,
        )
        train_data.append(segment)

    # The prior handed to training is the one returned with the test split.
    test_data, prior = get_chess(
        train=False,
        n=args.digit_size,
        num_classes=args.num_classes,
        sequence_num=args.test_sequence_num,
        seed=args.seed,
    )

    bridge = PhaseUnsupBridge(
        model,
        abducer_list,
        metric,
        val=args.val,
        require_more_revision=args.require_more_revision,
    )

    started = time.time()
    bridge.train(
        train_data,
        test_data=test_data,
        max_iter=args.max_iter,
        batch_size=256,
        prior=prior,
    )
    finished = time.time()
    print("used_time:", finished - started)
    bridge.test(test_data)


def get_args():
    """Parse the experiment's command-line options from ``sys.argv``.

    Boolean options take an explicit value (``--val True`` / ``--val false``
    / ``--val 0`` ...).  They are parsed with a real string-to-bool converter
    because ``type=bool`` would turn every non-empty string, including
    ``"False"``, into ``True``.

    Returns:
        argparse.Namespace: The parsed options.
    """

    def _str_to_bool(value):
        # argparse hands over the raw string; map the common spellings.
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ("true", "t", "yes", "y", "1"):
            return True
        # The empty string is kept falsy, as it was under ``bool("")``.
        if text in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(
            f"expected a boolean value, got {value!r}"
        )

    parser = argparse.ArgumentParser(prog="Chess Experiment, Phase ABL")
    parser.add_argument("--max_iter", type=int, default=5000)
    parser.add_argument("--digit_size", type=int, default=2)
    parser.add_argument(
        "--dist_func", type=str, choices=["hamming", "confidence"], default="confidence"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num_classes", type=int, default=6)
    parser.add_argument("--train_sequence_num", type=int, default=10000)
    parser.add_argument("--test_sequence_num", type=int, default=1000)
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--require_more_revision", type=int, default=2)
    parser.add_argument("--T", type=float, default=1)
    parser.add_argument("--lr", type=float, default=0.0001)
    parser.add_argument("--top_k", type=int, default=1)
    # Same flags, order within the namespace is irrelevant; all default False.
    for flag in ("--use_prob", "--use_weight", "--val", "--ind", "--prebuild"):
        parser.add_argument(flag, type=_str_to_bool, default=False)
    return parser.parse_args()


def seed_everything(seed: int = 0):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and configures cuDNN for repeatable runs.

    Args:
        seed: Seed value applied to all generators.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. subprocesses);
    # the hash seed of the running interpreter is already fixed.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU, not just the current one; the call is a
    # deferred no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune and pick different kernels from run
    # to run, which defeats the purpose of seeding -- keep it off.
    torch.backends.cudnn.benchmark = False


# Script entry point: parse the command-line options and run the experiment.
# Nothing is executed when this module is imported.
if __name__ == "__main__":
    args = get_args()
    main(args)
