from typing import Union

import pandas as pd 
import numpy as np
import networkx as nx

import src.utils as utils

class Client:
    """Holds one local dataset and scores parent sets proposed by a server.

    On construction the client learns a local graph with a causal-discovery
    routine, scores every variable's local parent set, and records the edges
    present in both directions of that graph. :meth:`score` then returns, for
    a candidate parent set of a child variable, the candidate's score minus
    the score of the child's parents in the local graph.

    ``self.graph[p, c] == 1`` marks ``p`` as a parent of ``c``.
    """

    def __init__(
        self,
        name: str,
        data: Union[pd.DataFrame, str],
        cd_function: str = 'pc',
        scoring_function: str = 'bic',
        masked: bool = True,  # if False standard PERI
        linear: bool = True,  # if False non-linear causal discovery
    ):
        """Load the data, learn the local graph and precompute local scores.

        Args:
            name: Identifier of this client.
            data: A DataFrame, a path to a CSV file (read with
                ``pd.read_csv``), or any array-like accepted by
                ``np.asarray``. Columns are treated as variables.
            cd_function: Name of the causal-discovery routine, resolved via
                ``utils.get_cd_function``.
            scoring_function: Name of the scoring class, resolved via
                ``utils.get_scoring_class``.
            masked: If True, directed queries only score the intersection of
                the proposed parents with the local parents. If False, the
                proposed parents are scored as given (standard PERI).
            linear: Forwarded to ``utils.get_cd_function``. False selects
                non-linear causal discovery.
        """
        self.name = name
        if isinstance(data, str):
            data = pd.read_csv(data)
        # Everything downstream works on a plain NumPy matrix.
        if isinstance(data, pd.DataFrame):
            self.data = data.to_numpy()
        else:
            self.data = np.asarray(data)

        self.masked = masked
        self.scoring_function = scoring_function
        self.variables = list(range(self.data.shape[1]))
        # Per-child memo of directed-query regrets, keyed by parent tuple.
        self.clear_cache()
        # Score of each variable's parent set in the locally learned graph.
        self.local_score = {}
        # Penalised scorer for undirected queries, built lazily on first use.
        self._undirected_scoring_class = None

        # Discover the local graph and score it.
        self.cd_function = utils.get_cd_function(cd_function, linear=linear)
        self.scoring_class = utils.get_scoring_class(scoring_function)(self.data)
        self.graph = self.cd_function(self.data)
        self._compute_local_scores()

        # Mapping of child -> parents in the local graph.
        self.child_to_parents = {
            i: self._graph_parents(i) for i in range(len(self.variables))
        }

        # Edges present in both directions, stored once as (low, high).
        n_rows, n_cols = self.graph.shape
        self.undirected = {
            (i, j)
            for i in range(n_rows)
            for j in range(i + 1, n_cols)
            if self.graph[i, j] == 1 and self.graph[j, i] == 1
        }

    def _graph_parents(self, child: int) -> list:
        """Return the parents of ``child`` in the local graph as Python ints."""
        return [int(p) for p in np.flatnonzero(self.graph[:, child] == 1)]

    def _compute_local_scores(self):
        """Score every variable's parent set in the local graph."""
        for i, var in enumerate(self.variables):
            self.local_score[var] = self.scoring_class.local_score(
                i, self._graph_parents(i)
            )

    def _get_undirected_scoring_class(self):
        """Return the 'bic_pen' scorer used for undirected queries.

        The scorer is built once and reused. It is deliberately kept separate
        from ``self.scoring_class``: directed queries must keep using the
        scorer that ``self.local_score`` and ``self.cache`` were computed with.
        """
        # getattr keeps instances created without __init__ working.
        scorer = getattr(self, '_undirected_scoring_class', None)
        if scorer is None:
            scorer = utils.get_scoring_class(
                'bic_pen',
                lmbda=100 * np.log(self.data.shape[0]),
            )(self.data)
            self._undirected_scoring_class = scorer
        return scorer

    def _undirected_score(self, server_parents, server_child: int) -> float:
        """Score the locally adjacent subset of ``server_parents``.

        Proposed parents with no edge to ``server_child`` in either direction
        of the local graph are dropped; the caller's sequence is not modified.
        """
        adjacent = [
            parent for parent in server_parents
            if self.graph[int(parent), server_child] != 0
            or self.graph[server_child, int(parent)] != 0
        ]
        scorer = self._get_undirected_scoring_class()
        score = scorer.local_score(server_child, adjacent)
        return score - scorer.local_score(
            server_child, self.child_to_parents[server_child]
        )

    def score(
        self,
        server_parents: Union[np.ndarray, list],
        server_child: int,
        undirected: bool = False,
    ) -> float:
        """Score a server-proposed parent set relative to the local graph.

        Directed query (``undirected=False``): if ``self.masked``, the
        proposed parents are first intersected with the local parents of
        ``server_child`` (and sorted). The result is the scoring class's
        ``local_score`` for that set minus ``self.local_score[server_child]``.
        Results are memoised in ``self.cache``.

        Undirected query (``undirected=True``): proposed parents not adjacent
        to ``server_child`` in the local graph are dropped. The remainder is
        scored with a 'bic_pen' scorer (lambda = 100 * log(number of rows)),
        minus that same scorer's score of the local parent set. Not memoised.

        The caller's ``server_parents`` is never modified, and
        ``self.scoring_class`` is never replaced.

        Args:
            server_parents: Candidate parent indices proposed by the server.
            server_child: Index of the child variable.
            undirected: Selects the undirected query mode described above.

        Returns:
            The score difference described above.
        """
        if undirected:
            return self._undirected_score(server_parents, server_child)

        if self.masked:
            server_parents = sorted(
                set(server_parents) & set(self.child_to_parents[server_child])
            )

        key = tuple(server_parents)
        child_cache = self.cache[server_child]
        if key not in child_cache:
            score = self.scoring_class.local_score(server_child, server_parents)
            child_cache[key] = score - self.local_score[server_child]
        return child_cache[key]

    def clear_cache(self):
        """Drop all memoised directed-query regrets."""
        self.cache = {i: {} for i in range(len(self.variables))}