from abl.evaluation import BaseMetric
from abl.learning import ABLModel
from abl.reasoning import ReasonerBase, WeaklySupervisedReasoner
from ..learning import ABLModel, WeaklySupervisedABLModel
from ..reasoning import ReasonerBase
from ..evaluation import BaseMetric
from .base_bridge import BaseBridge
from typing import List, Union, Any, Tuple, Dict, Optional
from numpy import ndarray
import wandb
from torch.utils.data import DataLoader
from ..dataset import BridgeDataset
from ..utils.logger import print_log
import numpy as np
import pandas as pd
from itertools import count


class SimpleBridge(BaseBridge):
    """Bridge that couples a learning model with an abductive reasoner.

    Each training iteration predicts label indices for a batch, converts them to
    pseudo-labels, asks ``self.abducer`` to revise them so that they agree with the
    targets ``Y``, converts the revised pseudo-labels back to indices and trains the
    model on them.
    """

    def __init__(
        self,
        model: ABLModel,
        abducer_list: List[ReasonerBase],
        metric_list: List[BaseMetric],
    ) -> None:
        # NOTE(review): ``self.abducer`` is read throughout this class but never set
        # here; presumably BaseBridge.__init__ sets it from ``abducer_list`` — confirm.
        super().__init__(model, abducer_list)
        # Metrics receive every evaluation batch via ``process`` and are summarised
        # by ``evaluate`` in ``_valid``.
        self.metric_list = metric_list

    @staticmethod
    def _build_data_loader(data, batch_size: Union[int, float]) -> DataLoader:
        """Wrap an ``(X, Z, Y)`` triple in a DataLoader that yields ``[X, Z, Y]`` column lists.

        ``batch_size`` is interpreted as:
          * a positive int (or a float > 1, truncated to int): used as-is;
          * a float in ``(0, 1]``: that fraction of the dataset size (at least 1);
          * any non-positive value (e.g. the default ``-1``): the whole dataset.

        ``DataLoader`` itself rejects non-positive or non-integer batch sizes, so the
        value must be resolved before it is passed on.
        """
        dataset = BridgeDataset(*data)
        n_samples = len(dataset)
        if isinstance(batch_size, float) and 0 < batch_size <= 1:
            batch_size = max(1, int(n_samples * batch_size))
        batch_size = int(batch_size)
        if batch_size <= 0:
            # "Whole dataset"; keep >= 1 so an empty dataset still yields an empty loader.
            batch_size = max(n_samples, 1)
        return DataLoader(
            dataset,
            batch_size=batch_size,
            # Transpose a list of (x, z, y) samples into [xs, zs, ys].
            collate_fn=lambda samples: [list(column) for column in zip(*samples)],
        )

    def predict(self, X) -> Tuple[List[List[Any]], ndarray]:
        """Run the model on ``X``.

        Returns:
            ``(pred_idx, pred_prob)`` taken from the ``"label"`` and ``"prob"`` entries
            of ``self.model.predict(X)``.
        """
        pred_res = self.model.predict(X)
        pred_idx, pred_prob = pred_res["label"], pred_res["prob"]
        return pred_idx, pred_prob

    def abduce_pseudo_label(
        self,
        pred_prob: ndarray,
        pred_pseudo_label: List[List[Any]],
        Y: List[Any],
        max_revision: int = -1,
        require_more_revision: int = 0,
    ) -> List[List[Any]]:
        """Revise predicted pseudo-labels with ``self.abducer.batch_abduce``.

        All arguments are forwarded positionally to the abducer unchanged.
        """
        return self.abducer.batch_abduce(pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision)

    def idx_to_pseudo_label(self, idx: List[List[Any]], mapping: Dict = None) -> List[List[Any]]:
        """Map a list of index lists to pseudo-labels.

        Args:
            idx: Two-level nested list of label indices.
            mapping: Index -> pseudo-label lookup; defaults to ``self.abducer.mapping``.
        """
        if mapping is None:
            mapping = self.abducer.mapping
        return [[mapping[_idx] for _idx in sub_list] for sub_list in idx]

    def pseudo_label_to_idx(self, pseudo_label: List[List[Any]], mapping: Dict = None) -> List[List[Any]]:
        """Map pseudo-labels back to indices, recursing through arbitrarily nested lists/tuples.

        Args:
            pseudo_label: Nested lists/tuples of pseudo-labels (tuples come back as lists).
            mapping: Pseudo-label -> index lookup; defaults to ``self.abducer.remapping``.
        """
        if mapping is None:
            mapping = self.abducer.remapping

        def recursive_map(func, nested_list):
            if isinstance(nested_list, (list, tuple)):
                return [recursive_map(func, x) for x in nested_list]
            return func(nested_list)

        return recursive_map(lambda x: mapping[x], pseudo_label)

    def train(
        self,
        train_data: Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]],
        max_iters: int = 500,
        batch_size: Union[int, float] = -1,
        eval_interval: int = 10,
        test_data: Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]] = None,
    ):
        """Run the predict -> abduce -> train loop for ``max_iters`` batches.

        Args:
            train_data: ``(X, Z, Y)`` training triple; passes over it repeat until
                ``max_iters`` batches have been processed.
            max_iters: Total number of training batches.
            batch_size: See ``_build_data_loader``; non-positive means the whole dataset.
            eval_interval: Evaluate every ``eval_interval`` iterations; a non-positive
                value disables periodic evaluation.
            test_data: Optional ``(X, Z, Y)`` triple; skipped during evaluation when None.

        Raises:
            ValueError: if ``train_data`` is empty (the loop could never make progress).
        """
        data_loader = self._build_data_loader(train_data, batch_size)
        if len(data_loader) == 0:
            raise ValueError("train_data is empty; training would never terminate.")

        iter_cnt = 0
        while iter_cnt <= max_iters:
            for X, Z, Y in data_loader:
                iter_cnt += 1
                if iter_cnt > max_iters:
                    break
                pred_idx, pred_prob = self.predict(X)
                pred_pseudo_label = self.idx_to_pseudo_label(pred_idx)
                abduced_pseudo_label = self.abduce_pseudo_label(pred_prob, pred_pseudo_label, Y)
                abduced_label = self.pseudo_label_to_idx(abduced_pseudo_label)
                self.model.train(X, abduced_label)

                if eval_interval > 0 and iter_cnt % eval_interval == 0:
                    print_log(f"Evaluation start: Iter [{iter_cnt}]", logger="current")
                    self.test(train_data, tag="train")
                    if test_data is not None:
                        self.test(test_data, tag="test")

    def _valid(self, data_loader, tag="", verbose=1):
        """Feed every batch of ``data_loader`` to the metrics and summarise them.

        Returns:
            The values of the merged metric result's ``"confusion_matrix"`` entry (keyed
            ``"0"``, ``"1"``, ...), rounded to 4 decimals — presumably per-class accuracy.
            Raises ``KeyError`` if no metric reports ``"confusion_matrix"``.
        """
        for X, Z, Y in data_loader:
            pred_idx, pred_prob = self.predict(X)
            pred_pseudo_label = self.idx_to_pseudo_label(pred_idx)
            data_samples = dict(
                pred_idx=pred_idx,
                pred_prob=pred_prob,
                pred_pseudo_label=pred_pseudo_label,
                gt_pseudo_label=Z,
                Y=Y,
                logic_forward=self.abducer.kb.logic_forward,
            )
            for metric in self.metric_list:
                metric.process(data_samples)

        res = dict()
        for metric in self.metric_list:
            res.update(metric.evaluate())
        pred_acc = [round(res['confusion_matrix'][str(i)], 4) for i in range(len(res['confusion_matrix']))]
        if verbose:
            msg = f"({tag}): "
            msg += ", ".join([f"{k}: {v:.4f}" for k, v in res.items() if 'confusion' not in k])
            print_log(msg, logger="current")
            print_log(f"pred_acc: {pred_acc}", logger="current")
        return pred_acc

    def valid(self, valid_data, batch_size=1000, tag="valid", verbose=1):
        """Evaluate on an ``(X, Z, Y)`` triple; returns the list produced by ``_valid``."""
        data_loader = self._build_data_loader(valid_data, batch_size)
        return self._valid(data_loader, tag, verbose)

    def test(self, test_data, batch_size=1000, tag="test", verbose=1):
        """Same as ``valid`` with the tag defaulting to ``"test"``."""
        return self.valid(test_data, batch_size, tag=tag, verbose=verbose)

class PhaseBridge(SimpleBridge):
    """Bridge that trains through a sequence of phases (a curriculum).

    Phase ``k`` trains on ``train_data_list[k]`` using ``abducer_list[k]``. While a
    later phase exists, the model is evaluated on ``test_data`` after every iteration,
    and training advances to the next phase once every pseudo-label of the current
    abducer's knowledge base exceeds ``phase_acc_threshold``.
    """

    def __init__(
        self,
        model: ABLModel,
        abducer_list: List[ReasonerBase],
        metric_list: List[BaseMetric],
    ) -> None:
        super().__init__(model, abducer_list, metric_list)

    def train(
        self,
        train_data_list: List[Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]]],
        test_data: Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]],
        max_iters: int = 500,
        batch_size: Union[int, float] = -1,
        eval_interval: int = 10,
        phase_acc_threshold: float = 0.1,
    ):
        """Run phased training for ``max_iters`` batches in total (across all phases).

        Args:
            train_data_list: One ``(X, Z, Y)`` training triple per phase.
            test_data: ``(X, Z, Y)`` triple used for evaluation and phase switching.
            max_iters: Total number of training batches over all phases.
            batch_size: Positive int used as-is; float in ``(0, 1]`` is a fraction of
                the phase's dataset; non-positive (default ``-1``) means the whole dataset.
            eval_interval: Log an evaluation every ``eval_interval`` iterations; a
                non-positive value disables periodic evaluation.
            phase_acc_threshold: Every entry checked for the current phase must be
                strictly greater than this to advance (previously hard-coded ``0.1``).

        Raises:
            ValueError: if ``train_data_list`` is empty or a phase's data is empty.
        """
        if not train_data_list:
            raise ValueError("train_data_list must contain at least one phase.")

        def make_loader(pool, phase):
            # DataLoader rejects non-positive / float batch sizes, so resolve them first.
            dataset = BridgeDataset(*pool)
            n_samples = len(dataset)
            resolved = batch_size
            if isinstance(resolved, float) and 0 < resolved <= 1:
                resolved = max(1, int(n_samples * resolved))
            resolved = int(resolved)
            if resolved <= 0:
                resolved = max(n_samples, 1)
            loader = DataLoader(
                dataset,
                batch_size=resolved,
                # Transpose a list of (x, z, y) samples into [xs, zs, ys].
                collate_fn=lambda samples: [list(column) for column in zip(*samples)],
            )
            if len(loader) == 0:
                # An empty pool would make the outer while-loop spin forever.
                raise ValueError(f"Training data for phase {phase} is empty.")
            return loader

        phase_idx = 0
        train_pool = train_data_list[phase_idx]
        self.abducer = self.abducer_list[phase_idx]
        data_loader = make_loader(train_pool, phase_idx)

        iter_cnt = 0
        while iter_cnt <= max_iters:
            next_flag = False
            for X, Z, Y in data_loader:
                iter_cnt += 1
                if iter_cnt > max_iters:
                    break
                pred_idx, pred_prob = self.predict(X)
                pred_pseudo_label = self.idx_to_pseudo_label(pred_idx)
                abduced_pseudo_label = self.abduce_pseudo_label(pred_prob, pred_pseudo_label, Y)
                abduced_label = self.pseudo_label_to_idx(abduced_pseudo_label)
                self.model.train(X, abduced_label)

                if eval_interval > 0 and iter_cnt % eval_interval == 0:
                    print_log(f"Evaluation start: Iter [{iter_cnt}]", logger="current")
                    self.test(train_pool, tag="train")
                    self.test(test_data, tag="test")

                if phase_idx + 1 < len(train_data_list):
                    char_acc = self.test(test_data, tag="test", verbose=0)
                    # NOTE(review): ``char_acc`` is indexed by pseudo-label *value*; this
                    # only lines up with the index-ordered list returned by ``test`` if the
                    # pseudo-labels are the integers 0..k-1 — confirm, otherwise map them
                    # through ``self.abducer.remapping`` first.
                    if all(char_acc[i] > phase_acc_threshold for i in self.abducer.kb.pseudo_label_list):
                        next_flag = True
                        break

            if next_flag:
                phase_idx += 1
                train_pool = train_data_list[phase_idx]
                self.abducer = self.abducer_list[phase_idx]
                data_loader = make_loader(train_pool, phase_idx)


class WeaklySupervisedBridge(SimpleBridge):
    """Bridge that trains on abduced *candidate sets* rather than a single revision.

    Each iteration asks the abducer for a set of candidate pseudo-label revisions per
    sample, maps them to indices and passes them to ``self.model.train`` together with
    the ground-truth pseudo-labels ``Z`` (also mapped to indices).
    """

    def __init__(self, model: WeaklySupervisedABLModel, abducer: WeaklySupervisedReasoner, metric_list: List[BaseMetric]) -> None:
        # NOTE(review): a single abducer is forwarded where SimpleBridge expects
        # ``abducer_list``; presumably BaseBridge accepts this — confirm.
        super().__init__(model, abducer, metric_list)

    def abduce_candidates_set(self, pred_prob: ndarray, pred_pseudo_label: List[List[Any]], Y: List[Any], max_revision: int = -1, require_more_revision: int = 1) -> List[List[Any]]:
        """Return candidate revisions from ``self.abducer.batch_abduce_candidates_set``.

        All arguments are forwarded positionally to the abducer unchanged.
        """
        return self.abducer.batch_abduce_candidates_set(pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision)

    def train(
        self,
        train_data: Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]],
        max_iters: int = 500,
        batch_size: Union[int, float] = -1,
        eval_interval: int = 10,
        more_revision: int = 3,
        test_data=None,
    ):
        """Run the predict -> abduce-candidates -> train loop for ``max_iters`` batches.

        Args:
            train_data: ``(X, Z, Y)`` training triple; ``Z`` is mapped to indices and
                passed to the model every iteration.
            max_iters: Total number of training batches.
            batch_size: Positive int used as-is; float in ``(0, 1]`` is a fraction of
                the dataset; non-positive (default ``-1``) means the whole dataset.
            eval_interval: Evaluate every ``eval_interval`` iterations; a non-positive
                value disables periodic evaluation.
            more_revision: Forwarded as ``require_more_revision`` to the abducer.
            test_data: Optional ``(X, Z, Y)`` triple; skipped during evaluation when None.

        Raises:
            ValueError: if ``train_data`` is empty (the loop could never make progress).
        """
        dataset = BridgeDataset(*train_data)
        n_samples = len(dataset)
        # DataLoader rejects non-positive / float batch sizes, so resolve them first.
        if isinstance(batch_size, float) and 0 < batch_size <= 1:
            batch_size = max(1, int(n_samples * batch_size))
        batch_size = int(batch_size)
        if batch_size <= 0:
            batch_size = max(n_samples, 1)
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            # Transpose a list of (x, z, y) samples into [xs, zs, ys].
            collate_fn=lambda samples: [list(column) for column in zip(*samples)],
        )
        if len(data_loader) == 0:
            raise ValueError("train_data is empty; training would never terminate.")

        iter_cnt = 0
        while iter_cnt <= max_iters:
            for X, Z, Y in data_loader:
                iter_cnt += 1
                if iter_cnt > max_iters:
                    break
                pred_idx, pred_prob = self.predict(X)
                pred_pseudo_label = self.idx_to_pseudo_label(pred_idx)
                abduced_candidates_set = self.abduce_candidates_set(pred_prob, pred_pseudo_label, Y, require_more_revision=more_revision)
                abduced_candidates_set = self.pseudo_label_to_idx(abduced_candidates_set)
                # The model returns (loss, confidence, abduce_acc); none are used here.
                self.model.train(X, abduced_candidates_set, Z=self.pseudo_label_to_idx(Z))

                if eval_interval > 0 and iter_cnt % eval_interval == 0:
                    print_log(f"Evaluation start: Iter [{iter_cnt}]", logger="current")
                    self.test(train_data, tag="train")
                    if test_data is not None:
                        self.test(test_data, tag="test")
