from typing import List
import itertools

import numpy as np
from scipy.optimize import linear_sum_assignment

from scorers.fscorer import FScorer
from scorers.base_scorer import BaseScorer


class GroupFScorer(BaseScorer):
    """F-score computed after optimally aligning output groups to reference groups.

    Every annotation value of a document is expected to carry a ``children``
    list; each such value is treated as one "group" and wrapped into its own
    pseudo-document. Output and reference groups are paired one-to-one by a
    minimum-cost assignment (``scipy.optimize.linear_sum_assignment``) whose
    cost for a pair is ``1 - f_score`` of that pair alone. The matched pairs
    are then accumulated into a single inner ``FScorer``, whose F-score is the
    reported metric.
    """

    def __init__(self):
        # Accumulates every aligned (output, reference) pseudo-document pair
        # across all calls to ``add``.
        self.__inner_scorer = FScorer()

    def pseudo_documents(self, doc: dict) -> List[dict]:
        """Wrap every annotation value of *doc* into its own pseudo-document.

        Args:
            doc: a document with an ``'annotations'`` list; each annotation has
                a ``'values'`` list whose entries must contain ``'children'``.

        Returns:
            One ``{'name': '', 'annotations': val['children']}`` dict per value,
            in iteration order.

        Raises:
            AssertionError: if a value lacks ``'children'`` (checked with
                ``assert``, so the check is skipped under ``python -O``).
        """
        docs = []
        for ann in doc['annotations']:
            for val in ann['values']:
                assert 'children' in val
                docs.append({
                    'name': '',
                    'annotations': val['children']
                })
        return docs

    @staticmethod
    def _pair_cost(out_doc: dict, ref_doc: dict) -> float:
        """Return ``1 - F`` for a single output/reference pseudo-document pair."""
        scorer = FScorer()
        scorer.add(out_doc, ref_doc)
        return 1 - scorer.f_score()

    def best_permutation(self, out_items: dict, ref_items: dict):
        """Pair output groups with reference groups at minimum total cost.

        Both documents are split into pseudo-documents (one per group) and
        the shorter list is padded with empty pseudo-documents, so that every
        group, including unmatched ones, ends up in some pair.

        Args:
            out_items: the system-output document.
            ref_items: the reference document.

        Returns:
            A ``(best_out, best_ref)`` tuple of equal-length lists where
            ``best_out[k]`` is paired with ``best_ref[k]``. Both lists are empty
            when neither document contains any group.
        """
        out_docs = self.pseudo_documents(out_items)
        ref_docs = self.pseudo_documents(ref_items)
        target_length = max(len(out_docs), len(ref_docs))
        if target_length == 0:
            # Nothing to align; an empty cost list would become a 1-D array,
            # which linear_sum_assignment rejects with a ValueError.
            return [], []
        out_docs = self.pad(out_docs, target_length)
        ref_docs = self.pad(ref_docs, target_length)
        cost = np.array(
            [[self._pair_cost(o, r) for r in ref_docs] for o in out_docs]
        )
        row_ind, col_ind = linear_sum_assignment(cost)
        best_out = [out_docs[i] for i in row_ind]
        best_ref = [ref_docs[j] for j in col_ind]
        return best_out, best_ref

    def pad(self, items: List[dict], target_length: int) -> List[dict]:
        """Append empty pseudo-documents to *items* until it has *target_length*.

        Note: *items* is modified in place and also returned. If it is already
        at least *target_length* long, it is left unchanged.
        """
        items.extend(
            {'name': '', 'annotations': []}
            for _ in range(target_length - len(items))
        )
        return items

    def add(self, out_items: dict, ref_items: dict):
        """Align the groups of one output/reference document pair and accumulate them.

        Args:
            out_items: the system-output document.
            ref_items: the reference document.
        """
        out_perm, ref_perm = self.best_permutation(out_items, ref_items)
        for o, r in zip(out_perm, ref_perm):
            self.__inner_scorer.add(o, r)

    @classmethod
    def support_feature_scores(cls) -> bool:
        """This scorer does not provide per-feature scores."""
        return False

    @classmethod
    def metric_name(cls) -> str:
        """Return the display name of this metric."""
        return "GROUP-F1"

    def score(self) -> float:
        """Return the F-score over all aligned pairs accumulated so far.

        NOTE(review): the result before any ``add`` call depends on how
        ``FScorer.f_score`` handles an empty scorer; confirm.
        """
        return self.__inner_scorer.f_score()
