import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
import pickle
from . import roc_mc as roc
import multiprocessing as mp

# Module-level shorthand for the unbalanced Sinkhorn-Knopp solver in roc_mc.
sk = roc.sinkhorn_knopp_unbalanced
# Large finite number used as a stand-in for infinity (not referenced in the
# code shown here; the classes below spell the literal 1e9 directly).
infty = 1.0e9
# Presumably the default (width, height) for matplotlib figures made
# elsewhere in this module -- confirm against the plotting code.
figsize = np.array((2.0, 1.5))

def bayesian(*args):
    """Plain Bayesian re-weighting with the same call signature as ``sk``.

    Callers invoke it as ``bayesian(rr, prior, cost, *extra)``. Only
    ``args[1]`` and ``args[2]`` are used:

    * ``args[2]`` (``cost``) is turned into weights ``exp(-cost)``. These are
      normalised so that every column (axis 0) sums to one.
    * ``args[1]`` (``prior``) then scales those normalised weights,
      broadcasting along the last axis, i.e. one factor per column.

    Returns the resulting array, which has the same shape as ``cost``.
    """
    prior, cost = args[1], args[2]
    weights = np.exp(-cost)
    column_totals = np.sum(weights, axis=0)
    return weights / column_totals * prior
    
class BayesianDiscr:
    """Repeated sequential discrimination between hypotheses.

    ``M`` is indexed ``[datum, hypothesis]``: data are drawn over its rows
    (``sampler`` uses ``rr``), and the belief ``post`` has one entry per
    column. At every step ``matmet`` re-weights ``M`` from the cost
    ``-log(M)``, given ``rr`` and the current belief. One datum ``d`` is then
    drawn, and the normalised row ``d`` of the re-weighted matrix becomes the
    new belief.

    Parameters
    ----------
    M : ndarray
        Data x hypothesis matrix. ``-np.log(M)`` is used as the cost, so
        entries are presumably positive (TODO confirm).
    rr : ndarray
        Distribution over data (rows) from which observations are sampled.
    cc : ndarray
        Initial belief over hypotheses (columns).
    T : int
        Number of observations per run.
    **args
        ``path`` is stored as ``self.path`` if given (unused by this class).
    """

    # ``roc.sinkhorn_knopp_unbalanced``; the meaning of the trailing
    # (1, 0, 1e9) arguments passed in ``single`` is defined in roc_mc.
    sk = roc.sinkhorn_knopp_unbalanced
    # Kept as a lambda (not staticmethod) so the module-level ``bayesian``
    # is looked up at call time.
    bi = lambda *x: bayesian(*x)

    def __init__(self, M, rr, cc, T, **args):
        self.M = M
        self.rr = rr
        self.cc = cc
        self.T = T
        if "path" in args:
            self.path = args["path"]

        # Per-run lists ("data", "post", "result") grow by one entry per
        # call to ``single``; the inputs are snapshotted for reproducibility.
        self.log = {"data": [], "post": [],
                    "result": [],
                    "M": self.M.copy(),
                    "rr": self.rr.copy(),
                    "cc": self.cc.copy()}

    def sampler(self):
        """Draw one datum index according to the distribution ``self.rr``."""
        # The sample space is the index set of ``rr``.
        return np.random.choice(len(self.rr), p=self.rr)

    def single(self, log=True, method='sk'):
        """Run one episode of ``self.T`` observations.

        Parameters
        ----------
        log : bool
            If true, record each intermediate belief and datum in
            ``self.log['post']`` / ``self.log['data']``.
        method : str
            ``'sk'`` selects ``BayesianDiscr.sk``; any other value selects
            the plain Bayesian update ``BayesianDiscr.bi``.

        Returns
        -------
        The index of the largest entry of the final belief (``np.argmax``).
        It is also stored in ``self.log['result']``.
        """
        index = len(self.log["result"])
        self.log["data"] += [[]]
        self.log['post'] += [[]]
        self.log['result'] += [0]
        matmet = BayesianDiscr.sk if method == 'sk' else BayesianDiscr.bi
        post = self.cc.copy()
        # The cost never changes between steps: compute it once. A fresh copy
        # is still passed each step, mirroring the defensive copies of
        # ``rr`` and ``post`` in case ``matmet`` modifies its arguments.
        cost = -np.log(self.M)
        for _ in range(self.T):
            M = matmet(self.rr.copy(), post.copy(), cost.copy(), 1, 0, 1e9)
            d = self.sampler()
            post = M[d] / np.sum(M[d])
            if log:
                self.log["post"][index] += [post.copy()]
                self.log['data'][index] += [d]
        res = np.argmax(post)
        self.log['result'][index] = res
        return res

    def calc(self, repeats=100, log=True, method='sk'):
        """Run ``single`` ``repeats`` times and return ``self.log['result']``."""
        for _ in range(repeats):
            self.single(log=log, method=method)

        return self.log['result']
    
class PathDiscr(BayesianDiscr):
    """``BayesianDiscr`` whose per-step ``matmet`` parameters follow a path.

    Step ``i`` calls ``matmet(rr, post, -log(M), *self.path[i])``. Each path
    entry is therefore the tuple of trailing positional arguments for that
    step. ``self.log['result']`` stores the final belief vector of each run,
    not only its argmax.
    """

    def __init__(self, M, rr, cc, T, **args):
        '''
        args:
            append: bool
                If true, the tuple ``(0.01, 1e9, 0)`` is appended to the path.
                ``single`` then applies one extra step with
                ``self.path[T]`` after the ``T`` observations.
            path: list
                Per-step parameter tuples; defaults to ``[(1, 0, 1e9)] * T``.
                The caller's sequence is copied, never modified.
            append_direct: bool
                Only used together with ``append``. Instead of the extra
                ``matmet`` step, the final belief becomes one-hot on the
                argmax of the last drawn row.
        '''
        super().__init__(M, rr, cc, T)
        self.append = bool(args.get("append", False))
        # Always defined so ``single`` can read it whenever ``append`` is on.
        self.append_direct = bool(args.get("append_direct", False))

        if "path" in args:
            # Copy: appending the extra step must not mutate the caller's list.
            self.path = list(args["path"])
        else:
            self.path = [(1, 0, 1e9)] * int(T)

        if self.append:
            self.path += [(0.01, 1e9, 0)]

    def set_sampler(self, sampler):
        """Replace the datum sampler (a zero-argument callable returning an index)."""
        self.sampler = sampler

    def single(self, log=True, method='sk'):
        """Run one episode along ``self.path``; return the argmax of the final belief.

        Parameters
        ----------
        log : bool
            If true, record each intermediate belief and datum.
        method : str
            ``'sk'`` selects ``BayesianDiscr.sk``; anything else selects
            ``BayesianDiscr.bi``.

        Raises
        ------
        ValueError
            If ``append`` is enabled but ``T < 1``. The extra step reuses
            the last observed datum, so there must be one.
        """
        if self.append and self.T < 1:
            raise ValueError("append=True requires at least one observation (T >= 1)")

        index = len(self.log["result"])
        self.log["data"] += [[]]
        self.log['post'] += [[]]
        self.log['result'] += [None]
        matmet = BayesianDiscr.sk if method == 'sk' else BayesianDiscr.bi
        post = self.cc.copy()
        # Constant across steps; a fresh copy is passed each call in case
        # ``matmet`` modifies its arguments.
        cost = -np.log(self.M)

        for i in range(self.T):
            M = matmet(self.rr.copy(), post.copy(), cost.copy(), *self.path[i])
            d = self.sampler()
            post = M[d] / np.sum(M[d])
            if log:
                self.log["post"][index] += [post.copy()]
                self.log['data'][index] += [d]

        if self.append:
            # The extra step reuses the last scaled matrix ``M`` / datum ``d``.
            if self.append_direct:
                post = np.zeros_like(post)
                post[np.argmax(M[d])] = 1.
            else:
                M = matmet(self.rr.copy(), post.copy(), cost.copy(),
                           *self.path[self.T])
                post = M[d] / np.sum(M[d])

            if log:
                self.log["post"][index] += [post.copy()]
                self.log['data'][index] += [d]
        res = np.argmax(post)
        self.log['result'][index] = post.copy()
        return res

    
    
class BayesianDiscrConclusiveness:
    """Discrimination runs where data come from a randomly drawn "teacher".

    Each run draws a teacher distribution over hypotheses (``get_teacher``).
    Then, for ``T`` steps, it samples a hypothesis ``h`` from the teacher and
    a datum from column ``h`` of ``M``. The row marginal passed to ``matmet``
    is the add-one smoothed empirical frequency of the data seen so far.

    Parameters
    ----------
    M : ndarray
        ``(n, m)`` matrix indexed ``[datum, hypothesis]``. Its columns are
        used as sampling distributions, so each presumably sums to one
        (TODO confirm).
    rr : ndarray
        Used for its length (the number of data) and by ``sampler()`` when
        no teacher is given.
    cc : ndarray
        Initial belief over hypotheses.
    T : int
        Number of observations per run.
    **args
        ``path``: per-step trailing arguments for ``matmet``; defaults to
        ``[(1, 0, 1e9)] * T``.
    """

    sk = roc.sinkhorn_knopp_unbalanced
    # Lambda (not staticmethod) so ``bayesian`` is resolved at call time.
    bi = lambda *x: bayesian(*x)

    def __init__(self, M, rr, cc, T, **args):
        self.M = M
        self.rr = rr
        self.cc = cc
        self.T = T
        self.n, self.m = M.shape
        if "path" in args:
            self.path = args["path"]
        else:
            # One (1, 0, 1e9) tuple per step. ``[(1, 0, 1e9) * T]`` would
            # build a single tuple of length 3T and break ``self.path[i]``.
            self.path = [(1, 0, 1e9)] * int(T)

        self.log = {"data": [], "post": [],
                    "result": [],
                    "teacher": [],
                    "hypos": [],
                    "M": self.M.copy(),
                    "rr": self.rr.copy(),
                    "cc": self.cc.copy()}

    def sampler(self, prob=None):
        """Return ``(hypothesis, datum)``.

        Without ``prob`` the datum is drawn from ``self.rr`` and the
        hypothesis is reported as ``-1``. Otherwise a hypothesis is drawn
        from ``prob``, then a datum from that hypothesis' column of ``M``.
        """
        if prob is None:
            return -1, np.random.choice(self.n, p=self.rr)
        else:
            h = np.random.choice(self.m, p=prob)
            return h, np.random.choice(self.n, p=self.M[:, h])

    def get_teacher(self):
        """Draw a teacher distribution over hypotheses from Dirichlet(1, ..., 1)."""
        return np.random.dirichlet([1, ] * self.m)

    def single(self, log=True, method='sk'):
        """Run one teacher-driven episode and return the final belief vector.

        The teacher, and (if ``log``) each belief, datum and sampled
        hypothesis, are recorded in ``self.log``.
        """
        # RNG order matters for reproducibility: teacher first, then data.
        teacher = self.get_teacher()
        # Add-one smoothed counts over data, updated after every draw.
        datadist = np.ones(self.rr.shape[0])
        index = len(self.log["result"])
        self.log["data"] += [[]]
        self.log['post'] += [[]]
        self.log['hypos'] += [[]]
        self.log["teacher"] += [teacher]
        self.log['result'] += [0]
        if method == 'sk':
            matmet = BayesianDiscrConclusiveness.sk
        else:
            matmet = BayesianDiscrConclusiveness.bi
        post = self.cc.copy()
        # Constant across steps; copied per call in case ``matmet`` mutates it.
        cost = -np.log(self.M)
        for i in range(self.T):
            M = matmet(datadist / np.sum(datadist),
                       post.copy(),
                       cost.copy(),
                       *self.path[i])
            h, d = self.sampler(teacher)
            datadist[d] += 1
            post = M[d] / np.sum(M[d])
            if log:
                self.log["post"][index] += [post.copy()]
                self.log['data'][index] += [d]
                self.log['hypos'][index] += [h]
        self.log['result'][index] = post
        return post

    def calc(self, repeats=100, log=True, method='sk'):
        """Run ``single`` ``repeats`` times and return ``self.log['result']``."""
        for _ in range(repeats):
            self.single(log=log, method=method)

        return self.log['result']
    
    
    
    
def single_stat(pack):
    """Worker: run ``cls`` with hypothesis ``h`` as the data source and return its log.

    ``pack`` is ``(h, seed, M, rr, cc, cls, param, T, repeats)``. The global
    numpy RNG is seeded with ``seed``. The instance is built with column
    ``h`` of ``M`` as the data distribution and a uniform initial belief
    shaped like ``cc``. The supplied ``rr`` is ignored.
    """
    hypo, seed, M, _unused_rr, cc, cls, param, T, repeats = pack
    _, n_hypos = M.shape
    np.random.seed(seed)
    data_dist = M[:, hypo]
    uniform_belief = np.ones_like(cc) / n_hypos
    runner = cls(M, data_dist, uniform_belief, T, **param)
    runner.calc(repeats)
    return runner.log
    
    
def statistics(M, rr, cc, cls, param: dict, filename: str, T: int, repeats: int = 1000, prefix="./data/app_theta_1_"):
    """Run ``cls`` once per hypothesis in parallel and pickle the logs.

    For every column ``h`` of ``M``, a worker process runs ``single_stat``
    with its own seed. That seed is drawn here from the global numpy RNG.

    Parameters
    ----------
    M : ndarray
        2-D data x hypothesis matrix; one job is launched per column.
    rr, cc :
        Forwarded to ``single_stat`` (which ignores ``rr`` and uses ``cc``
        only for the shape of the uniform initial belief).
    cls : type
        Discriminator class, e.g. ``BayesianDiscr``; must be picklable.
    param : dict
        Extra keyword arguments for ``cls``.
    filename, prefix : str
        Output is written to ``prefix + filename``. The two are plain string
        concatenation, so ``prefix`` may end in a file-name prefix rather
        than a directory.
    T, repeats : int
        Observations per run and runs per hypothesis.

    Returns
    -------
    list
        One ``log`` dict per hypothesis, in column order. The file contains
        ``[logs, seeds]``.
    """
    _, m = M.shape
    seeds = np.random.randint(int(1e9), size=m)
    packs = [(h, seeds[h], M, rr, cc, cls, param, T, repeats) for h in range(m)]

    # ``map`` blocks until every result is back; leaving the ``with`` block
    # then shuts the workers down instead of leaking the processes.
    with mp.Pool() as pool:
        logs = pool.map(single_stat, packs)

    with open(prefix + filename, "wb") as fp:
        pickle.dump([logs, seeds], fp)
    return logs


