import numpy as np
import bnlearn as bn
import networkx as nx
import neal
import dimod

class EfficientConversion:
    """Turn a table of parent-set scores into a QUBO and sample it.

    Intended call order is ``encode`` -> ``Hamiltonian`` -> ``solver``; every
    step stores its results on the instance for the following one.

    NOTE(review): the table presumably holds decomposed local scores of a
    Bayesian-network structure-learning problem -- confirm with the producer
    of the ``CPS`` table.
    """

    @staticmethod
    def _max_or_zero(values):
        """Return ``max(values)``, or 0 when the list ``values`` is empty."""
        return max(values) if values else 0

    @staticmethod
    def _unique_nonempty(sets):
        """Return the distinct non-empty sets of ``sets`` in first-seen order."""
        unique = []
        for candidate in sets:
            if candidate and candidate not in unique:
                unique.append(candidate)
        return unique

    @staticmethod
    def _sets_from_flags(flags, names, offset):
        """Build one set of names per row from ``len(names)`` indicator columns.

        ``flags`` is a 2-D boolean array; column ``offset + j`` tells whether
        ``names[j]`` belongs to the set of that row.
        """
        width = len(names)
        return [{names[j] for j in range(width) if row[offset + j]} for row in flags]

    def encode(self, CPS):
        """Read the score table and derive QUBO coefficients and penalty bounds.

        The table is read positionally, with ``n = (number of columns - 2) // 3``:

        * column ``"i"``: 1-based index of the variable a row belongs to;
        * positions ``1 .. n``: 0/1 indicators of the row's "U" subset;
        * positions ``n+1 .. 2n``: 0/1 indicators of the row's "V" subset;
        * positions ``2n+1 .. 3n``: 0/1 indicators of the row's full set; the
          labels of these columns become the variable names ``self.X``;
        * last column: the score of the row.

        For every variable the score of the empty full set (``s0``), the
        score offsets of each distinct U / V subset taken alone (``s1``,
        ``s2``) and the remaining interaction of each U-V pair (``t``) are
        stored, and the running bounds ``delta_lower`` / ``xi_lower`` are
        raised to the largest value required so far.

        Args:
            CPS: pandas DataFrame laid out as described above.

        Returns:
            ``[X, U_list, V_list, s0_list, s1_list, s2_list, t_list,
            delta_lower, xi_lower]`` (the same objects are kept on ``self``).

        Raises:
            IndexError: if a variable has no row whose full set is empty.
        """
        n = (len(CPS.columns) - 2) // 3
        self.X = CPS.columns[(2 * n + 1):(3 * n + 1)]
        names = [self.X[j] for j in range(n)]
        self.U_list = []
        self.V_list = []
        self.s0_list = []
        self.s1_list = []
        self.s2_list = []
        self.t_list = []
        self.delta_lower = 0.0
        self.xi_lower = 0.0
        max_ = self._max_or_zero

        for i in range(len(self.X)):
            rows = CPS[CPS["i"] == i + 1]
            # Pull the indicator block and the score column out once instead
            # of issuing one scalar ``.iloc`` lookup per cell.
            flags = rows.iloc[:, 1:3 * n + 1].to_numpy() == 1
            scores = rows.iloc[:, -1].to_numpy()

            U = self._unique_nonempty(self._sets_from_flags(flags, names, 0))
            V = self._unique_nonempty(self._sets_from_flags(flags, names, n))
            UV = self._sets_from_flags(flags, names, 2 * n)

            empty_rows = [h for h, members in enumerate(UV) if not members]
            if not empty_rows:
                raise IndexError(
                    f"no row with an empty parent set for variable {i + 1} "
                    f"({self.X[i]!r}); the base score s0 cannot be determined"
                )
            s0 = scores[empty_rows[0]]

            # Offsets of the rows whose full set is exactly one U (or V) subset.
            s1 = [scores[h] - s0
                  for U_ in U
                  for h, members in enumerate(UV) if members == U_]
            s2 = [scores[h] - s0
                  for V_ in V
                  for h, members in enumerate(UV) if members == V_]
            # What is left of the score of U | V after removing s0, s1 and s2.
            t = [[scores[h] - s1[h1] - s2[h2] - s0
                  for h2, V_ in enumerate(V)
                  for h, members in enumerate(UV) if members == (U_ | V_)]
                 for h1, U_ in enumerate(U)]

            self.U_list += [U]
            self.V_list += [V]
            self.s0_list += [s0]
            self.s1_list += [s1]
            self.s2_list += [s2]
            self.t_list += [t]

            t = np.array(t)
            n1, n2 = len(s1), len(s2)

            delta_U = max_([-s1[h1] + sum([max(0, -t[h1, h2]) for h2 in range(n2)])
                            for h1 in range(n1)])
            delta_V = max_([-s2[h2] + sum([max(0, -t[h1, h2]) for h1 in range(n1)])
                            for h2 in range(n2)])
            self.delta_lower = max_([self.delta_lower, delta_U, delta_V])

            # Every subset in U / V is non-empty and built from X, so the
            # former "for j in range(n) ... if X[j] in subset" filter was
            # always satisfied and only repeated identical candidates.
            xi_V_alone = max_([-s2[h2] - min(0, t[h1, h2])
                               for h2 in range(len(V))
                               for h1 in range(n1)])
            xi_V_nested = max_([-s2[h2] - t[h1, h2] + s2[h2_] + t[h1, h2_]
                                for h2, V_ in enumerate(V)
                                for h2_, V__ in enumerate(V)
                                for h1 in range(n1)
                                if V__ < V_])
            xi_U_alone = max_([-s1[h1] - min(0, t[h1, h2])
                               for h1 in range(len(U))
                               for h2 in range(n2)])
            xi_U_nested = max_([-s1[h1] - t[h1, h2] + s1[h1_] + t[h1_, h2]
                                for h1, U_ in enumerate(U)
                                for h1_, U__ in enumerate(U)
                                for h2 in range(n2)
                                if U__ < U_])
            self.xi_lower = max_([self.xi_lower, xi_V_alone, xi_V_nested,
                                  xi_U_alone, xi_U_nested])

        return [self.X, self.U_list, self.V_list, self.s0_list, self.s1_list,
                self.s2_list, self.t_list, self.delta_lower, self.xi_lower]

    def Hamiltonian(self, delta1_ratio=1.0, delta2_ratio=1.0, xi_ratio=1.0):
        """Assemble the QUBO coefficient dictionary from the encoded scores.

        Variable labels used as keys:

        * ``(1, h, i)`` / ``(2, h, i)``: the ``h``-th U / V subset of variable
          ``i`` (``solver`` treats value 1 as "selected");
        * ``(a, b)`` with ``a < b``: a pairwise variable. A selected subset of
          variable ``b`` that contains ``X[a]`` costs ``delta2`` unless
          ``(a, b)`` is 1, and a selected subset of variable ``a`` that
          contains ``X[b]`` costs ``delta2`` when ``(a, b)`` is 1.

        Args:
            delta1_ratio: multiplier applied to ``delta_lower`` to get the
                weight of the terms over index triples ``a < b < c``.
            delta2_ratio: multiplier for ``delta2 = delta1 * (n - 2)``.
            xi_ratio: multiplier applied to ``xi_lower``; the result penalises
                selecting two U subsets (or two V subsets) of one variable.

        Returns:
            The dictionary ``{(label, label): coefficient}``, also stored as
            ``self.H``. ``self.s0_sum`` is set to the sum of the base scores.
        """
        n = len(self.X)
        delta1 = self.delta_lower * delta1_ratio
        delta2 = delta1 * (n - 2) * delta2_ratio
        delta3 = self.xi_lower * xi_ratio
        self.s0_sum = sum(self.s0_list)
        H = self.H = {}

        for i in range(n):
            U_sets, V_sets = self.U_list[i], self.V_list[i]
            lower = [self.X[k] for k in range(i)]

            # Linear terms: score offset plus delta2 for every member with a
            # smaller index (the matching -delta2 couplings are added below).
            for kind, subsets, weights in ((1, U_sets, self.s1_list[i]),
                                           (2, V_sets, self.s2_list[i])):
                for h, members in enumerate(subsets):
                    c = sum(1 for name in lower if name in members)
                    H[((kind, h, i), (kind, h, i))] = weights[h] + delta2 * c

            # Interaction between a U subset and a V subset of the same variable.
            for h1 in range(len(U_sets)):
                for h2 in range(len(V_sets)):
                    H[((1, h1, i), (2, h2, i))] = self.t_list[i][h1][h2]

            # Penalty for selecting two subsets of the same kind.
            for kind, subsets in ((1, U_sets), (2, V_sets)):
                for h in range(len(subsets) - 1):
                    for h_ in range(h + 1, len(subsets)):
                        H[((kind, h, i), (kind, h_, i))] = delta3

        # Terms over every index triple a < b < c of pairwise variables.
        penalty = np.zeros((n, n))
        for a in range(n - 2):
            for b in range(a + 1, n - 1):
                for c in range(b + 1, n):
                    penalty[a, c] += delta1
                    H[((a, c), (b, c))] = -delta1
                    H[((a, c), (a, b))] = -delta1
                    H[((a, b), (b, c))] = delta1
        for a in range(n - 2):
            for c in range(a + 2, n):
                H[((a, c), (a, c))] = penalty[a, c]

        # Couplings between subset variables and pairwise variables.
        for later in range(1, n):
            for earlier in range(later):
                pair = (earlier, later)
                for kind, family in ((1, self.U_list), (2, self.V_list)):
                    for h, members in enumerate(family[earlier]):
                        if self.X[later] in members:
                            H[((kind, h, earlier), pair)] = delta2
                    for h, members in enumerate(family[later]):
                        if self.X[earlier] in members:
                            H[((kind, h, later), pair)] = -delta2
        return H

    def solver(self, num_reads=1, num_sweeps=1000, beta_schedule_type='geometric'):
        """Sample the QUBO with simulated annealing and score the result.

        Builds a ``dimod`` binary quadratic model from ``self.H`` with offset
        ``self.s0_sum`` (so ``Hamiltonian`` must have been called), samples it
        with ``neal.SimulatedAnnealingSampler`` and uses the first sample of
        the response. Each selected subset contributes edges
        ``member -> variable``. A message is printed when a variable has more
        than one selected U subset or more than one selected V subset, and
        when the resulting directed graph is not acyclic.

        Args:
            num_reads: forwarded to the sampler.
            num_sweeps: forwarded to the sampler.
            beta_schedule_type: forwarded to the sampler.

        Returns:
            The score recomputed from the edges: ``s0_sum`` plus the ``s1`` /
            ``s2`` / ``t`` entries matching each variable's incoming edges.
        """
        bqm = dimod.BinaryQuadraticModel.from_qubo(self.H, self.s0_sum)
        response = neal.SimulatedAnnealingSampler().sample(
            bqm, num_reads=num_reads, num_sweeps=num_sweeps,
            beta_schedule_type=beta_schedule_type)
        sample = next(iter(response), None)

        edges = []
        violations = 0
        for i, child in enumerate(self.X):
            for kind, subsets in ((1, self.U_list[i]), (2, self.V_list[i])):
                chosen = 0
                for h, members in enumerate(subsets):
                    if sample[(kind, h, i)] == 1:
                        edges.extend((parent, child) for parent in members)
                        chosen += 1
                if chosen > 1:
                    violations += 1

        G = nx.DiGraph(edges)
        G.add_nodes_from(self.X)
        if violations > 0:
            print("Constraints Error: p is more than 2.")
        if not nx.is_directed_acyclic_graph(G):
            print("Constraints Error: Not DAG")

        score = self.s0_sum
        for i, child in enumerate(self.X):
            parents = {parent for parent, target in edges if target == child}
            U_sets, V_sets = self.U_list[i], self.V_list[i]
            # Restrict the parents to the names that occur in any U / V subset.
            U_domain = set().union(*U_sets)
            V_domain = set().union(*V_sets)
            for h, U in enumerate(U_sets):
                if U == parents & U_domain:
                    score += self.s1_list[i][h]
            for h, V in enumerate(V_sets):
                if V == parents & V_domain:
                    score += self.s2_list[i][h]
            for h1, U in enumerate(U_sets):
                for h2, V in enumerate(V_sets):
                    if U | V == parents:
                        score += self.t_list[i][h1][h2]
        return score