import random

import torch
import torch.nn as nn

from sklearn import tree

import gurobipy as gp
from gurobipy import GRB


dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

"""
    Decision Tree
"""
class Tree:
    """Decision tree predicting one trajectory position from a set of others.

    The fitted scikit-learn classifier is kept for ``predict`` and is also
    exported as nested tuples (``decision_rules``) so it can be encoded as
    Gurobi constraints by ``add_constraint_to_model``.
    """

    def __init__(self, d_in, p_in, p_out):
        """Create an untrained tree.

        d_in: number of inputs (0 means the tree has no inputs).
        p_in: trajectory positions read as inputs, in feature order.
        p_out: trajectory position this tree predicts.
        """
        self.d_in = d_in
        self.p_in = p_in
        self.p_out = p_out

        # Without inputs, learn() fits on a constant dummy feature, so the
        # fitted tree cannot split and reduces to a single leaf.
        self.empty = d_in == 0

        self.clf = None
        # Nested tuples: (1, label) for a leaf, or
        # (0, (feature_index, left_subtree, right_subtree)) for a split node.
        self.decision_rules = ()

    def learn(self, x, y):
        """Fit the classifier on rows ``x`` / targets ``y`` and rebuild ``decision_rules``."""
        if self.empty:
            x = [[0] for _ in y]
        clf = tree.DecisionTreeClassifier()
        clf = clf.fit(x, y)
        self.clf = clf
        left_child = clf.tree_.children_left
        right_child = clf.tree_.children_right
        feature = clf.tree_.feature
        value = clf.tree_.value
        # Label every node exactly as clf.predict would (classes_[argmax],
        # first class on ties). This keeps the MILP encoding and predict()
        # consistent, and also covers the single-class case.
        output = [
            clf.classes_[value[i, 0, :].argmax()].item()
            for i in range(clf.tree_.node_count)
        ]

        def recurse(node):
            if left_child[node] == -1:  # sklearn marks leaves with -1 children
                return (1, output[node])
            return (0, (feature[node], recurse(left_child[node]), recurse(right_child[node])))

        self.decision_rules = recurse(0)

    def add_constraint_to_model(self, model, u):
        """Add constraints forcing ``u[self.p_out]`` to equal this tree's prediction.

        ``u`` holds the Gurobi variables of the trajectory (binary in the caller).
        Each visited node gets a binary "active" indicator. The root is always
        active. A child is active iff its parent is active and the parent's
        split variable selects that child's branch. An active leaf fixes
        ``u[self.p_out]`` to the leaf label.
        """
        if self.empty:
            # Trained on a constant feature, so decision_rules is a single leaf.
            model.addConstr(u[self.p_out] == self.decision_rules[1])
            return

        def recurse(dt, root, parent_b, side):
            # b == 1 iff this node lies on the decision path.
            b = model.addVar(vtype=GRB.BINARY)
            if root:
                model.addConstr(b == 1)
            else:
                # b = parent_b AND side
                model.addGenConstrAnd(b, [parent_b, side])

            terminal, decision_rule = dt
            if terminal:
                model.addConstr((b == 1) >> (u[self.p_out] == decision_rule))
            else:
                input_var, dt0, dt1 = decision_rule
                input_pos = self.p_in[input_var]

                # sklearn sends "feature <= threshold" left; with 0/1 inputs
                # (threshold 0.5 — assumes binary features) left means u == 0.
                left = model.addVar(vtype=GRB.BINARY)
                model.addConstr(left == 1 - u[input_pos])

                recurse(dt0, False, b, left)
                recurse(dt1, False, b, u[input_pos])

        recurse(self.decision_rules, True, None, None)

    def predict(self, x):
        """Predict the value at ``p_out`` from a (partial) trajectory list ``x``."""
        if self.empty:
            return self.clf.predict([[0]]).tolist()[0]
        return self.clf.predict([[x[i] for i in self.p_in]]).tolist()[0]


"""
    Implement an autoregressive tree-based learning method for trajectories
"""
class AR_TREE_TRAJ:
    """Autoregressive tree model of a flat binary trajectory.

    Trajectory layout: for t = 0..H-1, ``d_state`` state entries, a reward
    (only when t != 0), then an action; finally the ``d_state`` state entries
    of step H and its reward. ``self.trees[k]`` is a list of ``Tree`` objects
    predicting position ``k`` from earlier positions.
    """

    def __init__(self, d_state, H):
        d_in = d_state*(H+1) + H + H  # states + actions + rewards
        self.d_in = d_in
        self.d_state = d_state
        self.H = H

        # All trees start flagged so that a retrain_only_falsified pass fits them.
        self.falsified_trees = [True]*d_in

        unit_to_pos = {}
        p = 0
        for t in range(H):
            for i in range(d_state):
                unit_to_pos[f't{t}state{i}'] = p
                p += 1
            if t != 0:
                unit_to_pos[f't{t}reward'] = p
                p += 1
            unit_to_pos[f't{t}action'] = p
            p += 1

        for i in range(d_state):
            unit_to_pos[f't{H}state{i}'] = p
            p += 1
        unit_to_pos[f't{H}reward'] = p
        p += 1

        if p != d_in:
            print("error p!=d_in")

        # idx[k]: input positions of the tree predicting position k.
        # The append order defines the feature order of each tree.
        idx = [[] for _ in range(d_in)]
        for t in range(H):
            # state bit i depends on the earlier bits of the same step
            for i in range(d_state):
                for j in range(i):
                    idx[unit_to_pos[f't{t}state{i}']].append(unit_to_pos[f't{t}state{j}'])
            if t != 0:
                # reward depends on the full state of the same step
                for j in range(d_state):
                    idx[unit_to_pos[f't{t}reward']].append(unit_to_pos[f't{t}state{j}'])
            # actions have no inputs

            # next-state dynamics: previous full state plus previous action
            for i in range(d_state):
                for j in range(d_state):
                    idx[unit_to_pos[f't{t+1}state{i}']].append(unit_to_pos[f't{t}state{j}'])
                idx[unit_to_pos[f't{t+1}state{i}']].append(unit_to_pos[f't{t}action'])
        for i in range(d_state):
            for j in range(i):
                idx[unit_to_pos[f't{H}state{i}']].append(unit_to_pos[f't{H}state{j}'])
        for j in range(d_state):
            idx[unit_to_pos[f't{H}reward']].append(unit_to_pos[f't{H}state{j}'])

        self.trees = [[Tree(len(l), l, i)] for i, l in enumerate(idx)]

    def train(self, x, retrain_only_falsified=False, first_pass=False):
        """Fit the trees on ``x``, a list of flat trajectories.

        With ``retrain_only_falsified`` only positions flagged in
        ``falsified_trees`` are refit. For reward positions outside the first
        pass, a new tree is trained and appended instead of refitting the
        first one.
        """
        for i, tree_list in enumerate(self.trees):
            if retrain_only_falsified and not self.falsified_trees[i]:
                continue
            tree0 = tree_list[0]

            X = [[traj[k] for k in tree0.p_in] for traj in x]
            Y = [traj[tree0.p_out] for traj in x]

            if self.pos_to_semantic(i) == 'reward' and not first_pass:
                tree_new = Tree(tree0.d_in, tree0.p_in, tree0.p_out)
                tree_new.learn(X, Y)
                tree_list.append(tree_new)
            else:
                tree0.learn(X, Y)

    def decoding_step(self, x):
        """Append the prediction for position ``len(x)`` to partial trajectory ``x``.

        The most recently trained tree for that position is used, as in
        ``falsified_predictor``. ``x`` is modified in place and returned.
        """
        x.append(self.trees[len(x)][-1].predict(x))
        return x

    def decode(self):
        """Decode a full trajectory greedily, position by position."""
        x = []
        for _ in range(self.d_in):
            x = self.decoding_step(x)
        return x

    def print_trajectory(self, env=None):
        """Decode a trajectory and print it step by step.

        ``env`` is only used to print ``env.b``. When it is omitted, the
        module-level ``env`` (defined by the ``__main__`` script) is used if it
        exists; otherwise that line is skipped.
        """
        if env is None:
            env = globals().get('env')
        print()
        traj = self.decode()
        p = 0
        for t in range(self.H):
            print(traj[p:(p+self.d_state)])
            p += self.d_state
            if t != 0:
                print(traj[p])  # reward
                p += 1
            print(traj[p])  # action
            p += 1

        if env is not None:
            print(env.b)
        print(traj[p:(p+self.d_state)])
        p += self.d_state
        print(traj[p])
        p += 1
        print()

    def pos_to_semantic(self, p):
        """Return 'state', 'reward' or 'action' for trajectory position ``p``."""
        d_state = self.d_state
        if p < d_state:
            return 'state'
        if p == d_state:
            return 'action'

        # After step 0 the layout repeats with period d_state+2:
        # d_state state entries, one reward, one action.
        p_ = (p-d_state-1) % (d_state+2)
        if p_ < d_state:
            return 'state'
        if p_ == d_state:
            return 'reward'
        return 'action'

    def falsified_predictor(self, trajectory):
        """Flag the state/reward trees whose latest prediction disagrees with ``trajectory``."""
        self.falsified_trees = [False]*len(self.trees)
        for i in range(len(self.trees)):
            role = self.pos_to_semantic(i)
            if role == 'state' or role == 'reward':
                out = self.trees[i][-1].predict(trajectory[:i])
                self.falsified_trees[i] = out != trajectory[i]



class Agent_Planning(nn.Module):
    """Planning agent built on an autoregressive tree model of trajectories.

    The agent fits ``AR_TREE_TRAJ`` trees on sampled trajectories. At each
    decision step it encodes those trees as a Gurobi MILP over the whole binary
    trajectory and maximizes the sum of predicted rewards.
    """

    def __init__(self, env):
        super().__init__()
        self.H = env.H
        self.env = env
        # Per-step state size. NOTE(review): hard-coded as 3H-5, while the
        # dynamics model is sized with env.d_state-env.H-1; the two are
        # assumed to coincide for this environment — confirm.
        self.d_state = self.H-2+2*(self.H-2)+1
        self.dynamics_model = AR_TREE_TRAJ(env.d_state-env.H-1, env.H)  # autoregressive model for a trajectory

        # Plan cache: states predicted by the last plan and the actions it
        # chose, used to skip re-planning while execution follows the plan.
        self.predicted_states = [[] for _ in range(self.H)]
        self.planned_actions = []

    def _state_offset(self, t):
        """Position of the first state entry of step ``t >= 1`` in the flat trajectory."""
        return self.d_state+1+(self.d_state+2)*(t-1)

    def optimize_plan(self, t, x):
        """Return the action to take at step ``t`` given the trajectory prefix ``x``.

        For t >= 1, ``x`` is zero-padded history, the current state, and a
        trailing 0 at the reward slot of step t. Every entry of ``x`` is fixed
        in the MILP. The returned value is the solved entry right after ``x``,
        i.e. the action of step t. Step 0 always returns action 0 without
        planning. The plan is cached and reused while the observed state
        matches the state it predicted.

        Raises RuntimeError if Gurobi finds no feasible solution.
        """
        if t == 0:
            return 0

        H = self.H
        d_state = self.d_state

        state = x[self._state_offset(t):-1]
        if self.predicted_states[t] == state:
            return self.planned_actions[t]

        traj_len = d_state*(H+1) + H + H

        def pos_to_semantic(p):
            # Same layout as AR_TREE_TRAJ: period d_state+2 after step 0.
            if p < d_state:
                return 'state'
            if p == d_state:
                return 'action'
            p_ = (p-d_state-1) % (d_state+2)
            if p_ < d_state:
                return 'state'
            if p_ == d_state:
                return 'reward'
            return 'action'

        model = gp.Model("NN")
        # One binary variable per trajectory position.
        u = [model.addVar(name="u"+str(i), vtype=GRB.BINARY) for i in range(traj_len)]

        # The known prefix is fixed.
        for i, xi in enumerate(x):
            model.addConstr(u[i] == xi)

        y = []
        for j in range(traj_len):
            role = pos_to_semantic(j)
            if j >= len(x) and role == 'state':
                self.dynamics_model.trees[j][0].add_constraint_to_model(model, u)

            if role == 'reward':
                # Every reward tree learnt for this position constrains it.
                for reward_tree in self.dynamics_model.trees[j]:
                    reward_tree.add_constraint_to_model(model, u)

                y.append(model.addVar(vtype=GRB.BINARY))
                model.addConstr(u[j] == y[-1])

        model.params.LogToConsole = 0

        model.setObjective(gp.quicksum(y), GRB.MAXIMIZE)
        model.optimize()

        if model.SolCount == 0:
            raise RuntimeError(
                f"planning MILP has no feasible solution (Gurobi status {model.Status})")

        # Round rather than truncate: binaries may be reported as e.g.
        # 0.9999999 within the solver's integrality tolerance.
        u_ = [round(v.X) for v in u]

        self.predicted_states = [[] for _ in range(H+1)]
        self.planned_actions = [[] for _ in range(H)]
        p = 0
        for step in range(H):
            self.predicted_states[step] = u_[p:(p+d_state)]
            p += d_state
            if step != 0:
                p += 1  # skip the reward
            self.planned_actions[step] = u_[p]
            p += 1

        self.predicted_states[H] = u_[p:(p+d_state)]

        return u_[len(x)]

    def _plan(self, t, state):
        """Build the padded prefix for step ``t`` from tensor ``state`` and plan an action."""
        with torch.no_grad():
            x = state.tolist()
            if t > 0:
                x = [0]*self._state_offset(t) + x + [0]
            self.eval()
            a = self.optimize_plan(t, x)
            self.train()
        return a

    def policy(self, x):
        """Act from an observation tensor whose first H+1 entries one-hot encode the step.

        The step is the index of the first 1 (H+1 if there is none); the rest of
        ``x`` is the state passed to the planner.
        """
        H = self.H
        t = 0
        for xti in x[:H+1].tolist():
            if xti == 1:
                break
            t += 1
        return self._plan(t, x[H+1:])

    def policy_with_timestep(self, x):
        """Act from a ``(t, state_tensor)`` pair."""
        t, state = x
        return self._plan(t, state)

    def train_agent(self, N=1000, N_batch=100, lr=1e0, n_iter_max=15000, weight_decay=1e-8):
        """Alternate tree fitting and planning rollouts until the rollouts succeed.

        First samples ``N`` trajectories with a uniformly random policy. Each
        iteration then refits the falsified trees and runs up to 50 planning
        rollouts, stopping at the first one reported as not ok. The last
        rollout flags falsified trees and is added to the dataset. There is no
        iteration cap.

        ``N_batch``, ``lr``, ``n_iter_max`` and ``weight_decay`` are accepted
        but unused.

        Returns ``(ok, n_iter)``, where ``ok`` is the result of a final
        ``env.eval`` with ``deterministic_right`` enabled.
        """
        uniform_policy = lambda state: int(random.random() > 0.5)

        D = self.env.sample_full_trajectories(uniform_policy, N)
        print(D.size())
        D = D.tolist()

        ok = False
        n_iter = 0
        while not ok:
            print("\n")
            print("iteration ", n_iter, flush=True)
            n_iter += 1
            self.dynamics_model.train(D, retrain_only_falsified=True, first_pass=n_iter == 1)

            for _ in range(50):
                trajectory, ok = self.env.sample_full_trajectory(self.policy_with_timestep, return_reward=True)
                # Reset the plan cache so the next rollout re-plans from scratch.
                self.predicted_states = [[] for _ in range(self.H)]
                if not ok:
                    break

            print(tree.export_text(self.dynamics_model.trees[-1][-1].clf, max_depth=100, show_weights=True))
            print("maximum depth", self.dynamics_model.trees[-1][-1].clf.tree_.max_depth)

            self.dynamics_model.falsified_predictor(trajectory)
            print("number of predictors to retrain ", sum(self.dynamics_model.falsified_trees))
            D.append(trajectory)

        self.env.deterministic_right = True
        ok = self.env.eval(self, N=1)
        self.env.deterministic_right = False
        print()

        return ok, n_iter



def batchify(D, n_batch):
    """Shuffle the rows of tensor ``D`` and split them into ``n_batch`` equal batches.

    Each batch holds ``D.size(0) // n_batch`` rows. The remaining
    ``D.size(0) % n_batch`` shuffled rows are not returned.
    """
    perm = torch.randperm(D.size(0))
    size = D.size(0) // n_batch
    return [D[perm[k*size:(k+1)*size]] for k in range(n_batch)]


if __name__ == '__main__':
    import environment

    # Single demonstration run at horizon 10.
    H = 10
    env = environment.ENV(H)
    for label, attr in (("H = ", "H"), ("b = ", "b")):
        print(label)
        print(getattr(env, attr))

    agent = Agent_Planning(env)
    agent.train_agent()

    env.deterministic_right = True
    env.eval(agent, N=1)

    # Benchmark over horizons 5, 10, ..., 50 with n_tests fresh agents each.
    print("Planning solver")
    Hs = [5*k for k in range(1, 11)]
    n_tests = 10
    all_results = []
    iterations_full = []
    for H in Hs:
        print()
        print(f"Testing H = {H}")

        results = []
        iterations = []
        for _ in range(n_tests):
            env = environment.ENV(H)
            agent = Agent_Planning(env)
            _, n_iter = agent.train_agent()
            env.deterministic_right = True
            results.append(env.eval(agent, N=1))
            iterations.append(n_iter)

        mean_result = sum(results)/n_tests
        print("results : ")
        print(results)
        print(mean_result)
        all_results.append(mean_result)
        iterations_full.append(iterations)
    print(all_results)
    print(iterations_full)