import numpy as np
from matplotlib import pyplot as plt
import deepdish as dd
from GridWorld import GridWorld
from library import *
from collections import defaultdict


# Default GridWorld instance. It is used further down for rendering value
# functions and for sizing EQ_min's default arrays (env.action_space.n),
# until `env` is rebound to an all-goals GridWorld.
env = GridWorld()
# Candidate goal positions. Goal g is T_states[g]; every task vector in this
# file is a 0/1 list indexed by the same goal order.
# NOTE(review): coordinate axis order (row, col) vs (x, y) is defined by
# GridWorld, which is not visible here -- confirm before editing entries.
T_states=[(3,3),(3,9),(9,3),(9,9),
          (1,1),(1,2),(1,3),(1,4),(1,5),(1,7),(1,8),(1,9),(1,10),(1,11),
          (11,1),(11,2),(11,3),(11,4),(11,5),(11,7),(11,8),(11,9),(11,10),
          (2,1),(3,1),(4,1),(5,1),(7,1),(8,1),(9,1),(10,1),
          (2,11),(3,11),(4,11),(5,11),(6,11),(8,11),(9,11),(10,11),(11,11)]
# Number of goals, i.e. the length of every task vector.
n_goals = len(T_states)

######################################
# Binary-encoding base tasks ("boolean algebra" basis): n = ceil(log2(#goals))
# tasks, where base task `bit` assigns goal `goal` the value of bit
# (n - 1 - bit) of `goal`, most significant bit first. Every goal therefore
# gets a distinct n-bit pattern, so any task is a disjunction of minterms over
# these bases. Entry 1 stands for rmax, 0 for rmin.
n = int(np.ceil(np.log2(len(T_states))))
bases = [
    [(goal >> (n - 1 - bit)) & 1 for goal in range(len(T_states))]
    for bit in range(n)
]
bases_boolean_algebra = bases
bases_boolean_algebra_n = len(bases)
######################################
# Disjunction basis: one one-hot task per goal (row g of the identity matrix).
bases_disjunction = np.eye(n_goals)
######################################
    
def task_exp(tasks, task):
    """Encode ``task`` as a disjunctive-normal-form string over base ``tasks``.

    Each goal selected by ``task`` yields one minterm. For every base task
    ``k``, the minterm holds ``'k'`` when that base task includes the goal and
    ``'-k'`` otherwise. Literals are joined with ``'.'``, and minterms are
    appended to an ``'m'`` prefix with ``'+'``, e.g. ``'m+0.-1+-0.1'``.
    ``exp_task`` and ``exp_value`` parse this syntax.

    Returns ``'m'`` when no goal is selected or ``tasks`` is empty.
    """
    minterms = []
    for goal in range(n_goals):
        if not task[goal]:
            continue
        literals = [str(k) if base[goal] else '-' + str(k)
                    for k, base in enumerate(tasks)]
        # With no base tasks the minterm is empty and contributes nothing.
        if literals:
            minterms.append('.'.join(literals))
    return '+'.join(['m'] + minterms)

def exp_task(tasks, exp):
    """Evaluate a DNF expression string back into a 0/1 goal vector.

    ``exp`` uses the numeric syntax ``'m+0.-1+...'``: the leading ``'m'``
    chunk is ignored, each remaining ``'+'``-separated chunk is a conjunction
    of ``'.'``-separated literals, and literal ``'k'`` / ``'-k'`` reads base
    task ``tasks[k]`` directly or complemented (``1 - value``).

    Returns a list of ``n_goals`` ints: 1 where some minterm holds for that
    goal, 0 otherwise.
    """
    minterms = [chunk.split('.') for chunk in exp.split('+')[1:]]

    def literal_holds(literal, goal):
        value = tasks[abs(int(literal))][goal]
        return 1 - value if '-' in literal else value

    return [
        1 if any(all(literal_holds(lit, goal) for lit in term)
                 for term in minterms) else 0
        for goal in range(n_goals)
    ]
######################################
    
def sample_random():
    """Return a random task: an independent 0/1 draw for each of the n_goals goals."""
    task = []
    for _ in range(n_goals):
        task.append(np.random.randint(2))
    return task

def sample_best(i):
    """Return the i-th binary-encoding base task, or a random task once exhausted.

    The base task is returned as a fresh list, so callers may mutate it.
    """
    if i >= bases_boolean_algebra_n:
        return sample_random()
    return bases_boolean_algebra[i][:]

def sample_worst(i):
    """Return the i-th one-hot (disjunction-basis) task, or a random task once exhausted.

    Rows of ``bases_disjunction`` come from ``np.eye`` and are therefore
    floats. They are converted to Python ints so that every sampler
    (``sample_best``, ``sample_random``, ``sample_worst``) yields the same
    0/1 int task vectors.
    """
    if i < n_goals:
        return [int(v) for v in bases_disjunction[i]]
    return sample_random()

# tasks = [sample_random() for _ in range(10)]
# task = sample_random()
# exp = task_exp(tasks, task)
# task_ = exp_task(tasks, exp)
# print(str(task))
# print(str(task_))
# print(task==task_)
######################################
# One [pos, pos] entry per goal in T_states. These are used as GridWorld
# `goals` (see the all-goals GridWorld below).
# NOTE(review): the meaning of the duplicated position is defined by
# GridWorld -- confirm there.
T_states_ = [[cell, cell] for cell in T_states]

# Environment and learner hyperparameters.
goal_reward = 1
step_reward = 0
gamma = 0.9
slip_prob = 0.1
maxiter = 500
epsilon = 0.3
alpha = 0.1

# Experiment size plus one (num_runs, num_tasks) result table per
# task-sampling strategy (best / worst / random).
num_runs = 1
num_tasks = 50
data_best, data_worst, data_random = (
    np.zeros((num_runs, num_tasks)) for _ in range(3)
)

learner = GOAL  # alternative learner: Goal_Oriented_Q_learning
# for i in range(num_runs):
#     print('run',i)
#     tasks_best = []
#     values_best = []
#     tasks_worst = []
#     values_worst = []
#     tasks_random = []
#     values_random = []
#     for j in range(num_tasks):
#         print('task',j)
#         task = sample_best(j)
#         exp = task_exp(tasks_best, task)
#         task_ = exp_task(tasks_best, exp)
#         if task==task_:        
#             data_best[i,j] = 0
#         else:
#             goals = []
#             for g in range(n_goals):
#                 if task[g]==1:
#                     goals.append(T_states_[g])
#             # print('*')
#             env = GridWorld(goals=goals, goal_reward=goal_reward, 
#                             step_reward=step_reward, slip_prob=slip_prob)
#             Q,stats1 = learner(env, gamma=gamma, epsilon=epsilon, alpha=alpha, maxiter=maxiter)
#             data_best[i,j] = stats1['T']
#             tasks_best.append(task)
#             values_best.append(Q)

        # task = sample_worst(j)
        # exp = task_exp(tasks_worst, task)
        # task_ = exp_task(tasks_worst, exp)
        # if task==task_:        
        #     data_worst[i,j] = 0
        # else:
        #     goals = []
        #     for g in range(n_goals):
        #         if task[g]==1:
        #             goals.append(T_states_[g])
        #     # print('**')
        #     env = GridWorld(goals=goals, goal_reward=goal_reward, 
        #                     step_reward=step_reward, slip_prob=slip_prob)
        #     Q,stats1 = learner(env, gamma=gamma, epsilon=epsilon, alpha=alpha, maxiter=maxiter)
        #     data_worst[i,j] = stats1['T']
        #     tasks_worst.append(task)
        #     values_worst.append(Q)
            
        # task = sample_random()
        # exp = task_exp(tasks_random, task)
        # task_ = exp_task(tasks_random, exp)
        # if task==task_:        
        #     data_random[i,j] = 0
        # else:
        #     goals = []
        #     for g in range(n_goals):
        #         if task[g]==1:
        #             goals.append(T_states_[g])
        #     # print('***')
        #     env = GridWorld(goals=goals, goal_reward=goal_reward, 
        #                     step_reward=step_reward, slip_prob=slip_prob)
        #     Q,stats1 = learner(env, gamma=gamma, epsilon=epsilon, alpha=alpha, maxiter=maxiter)
        #     data_random[i,j] = stats1['T']
        #     tasks_random.append(task)
        #     values_random.append(Q)

# dd.io.save('data/exp1.h5', (data_best,data_worst,data_random))

# Load previously learned base tasks and their value functions.
# NOTE(review): presumably parallel lists (task vectors, learned values) as
# produced by an earlier run -- confirm against the script that wrote the file.
tasks_best, values_best = dd.io.load('data/exp1_.h5')

# Render and save the first six learned value functions. This requires at
# least six entries in values_best and an existing plots/values/ directory
# (savefig does not create it).
for i in range(6):
    Q = values_best[i]
    fig = env.render( P=EQ_P(Q), V = EQ_V(Q))
    fig.savefig("plots/values/"+str(i)+".pdf", bbox_inches='tight')

# Minimum value function used as the starting point for OR in exp_value:
# a nested defaultdict whose leaves default to zero arrays.
# NOTE(review): the inner lambda looks up the global `env` lazily, so entries
# created after `env` is rebound below are sized from the new environment's
# action space, not this one's.
EQ_min = defaultdict(lambda: defaultdict(lambda: np.zeros(env.action_space.n)))
# EQ_max = values_best[0]
# for Q in values_best:
#     EQ_max = OR(EQ_max,Q)
# EQ_max = OR(EQ_max,values_worst[0])

# Maximum value function (starting point for AND and the argument to NOTD in
# exp_value): learned directly on an environment whose goals are all T_states_.
env = GridWorld(goals=T_states_, goal_reward=goal_reward, 
                step_reward=step_reward, slip_prob=slip_prob)
EQ_max,_ = learner(env, gamma=gamma, epsilon=epsilon, alpha=alpha, maxiter=maxiter)
    
def exp_value(values, exp):
    """Compose a value function from ``values`` following a DNF expression.

    ``exp`` uses the numeric syntax ``'m+0.-1+...'``. Each ``'+'``-separated
    minterm after the leading ``'m'`` is folded with ``AND`` starting from
    ``EQ_max``. A literal ``'k'`` contributes ``values[k]``; ``'-k'``
    contributes ``NOTD(values[k], EQ_max)``. The minterms are then folded with
    ``OR`` starting from ``EQ_min``. An expression with no minterms (``'m'``)
    yields ``EQ_min``.
    """
    composed = EQ_min
    for minterm in exp.split('+')[1:]:
        term_value = EQ_max
        for literal in minterm.split('.'):
            base = values[abs(int(literal))]
            if '-' in literal:
                base = NOTD(base, EQ_max)
            term_value = AND(base, term_value)
        composed = OR(term_value, composed)
    return composed

# T_states=[(3,3),(3,9),(9,3),(9,9),
#           (1,1),(1,2),(1,3),(1,4),(1,5),(1,7),(1,8),(1,9),(1,10),(1,11),
#           (11,1),(11,2),(11,3),(11,4),(11,5),(11,7),(11,8),(11,9),(11,10),
#           (2,1),(3,1),(4,1),(5,1),(7,1),(8,1),(9,1),(10,1),
#           (2,11),(3,11),(4,11),(5,11),(6,11),(8,11),(9,11),(10,11),(11,11)]
def task_exp(tasks, task, symbols=None, num_goals=None):
    """Encode ``task`` as a disjunctive-normal-form expression over ``tasks``.

    Every goal ``i`` with a truthy ``task[i]`` contributes one minterm. For
    each base task ``j``, the minterm holds the literal for ``j`` when
    ``tasks[j][i]`` is truthy and its negation otherwise. Minterms follow an
    ``'m'`` prefix and are joined with ``'+'``.

    By default the result uses the machine-readable syntax parsed by
    ``exp_task`` and ``exp_value``: literals ``'j'`` / ``'-j'`` joined by
    ``'.'``, e.g. ``'m+0.-1+-0.1'``. This definition rebinds the name
    ``task_exp``, and the ``exp_value`` call below it relies on this default
    syntax.

    Args:
        tasks: base tasks, each a 0/1 sequence indexed by goal.
        task: target task, a 0/1 sequence indexed by goal.
        symbols: optional sequence of display names for the base tasks
            (e.g. ``'ABCDEF'``). When given, a human-readable string is
            returned instead, e.g. ``'m+A~B+~AB'``, with literals juxtaposed
            and negation written ``~``. That form is for display only;
            ``exp_value`` cannot parse it.
        num_goals: number of goals to scan; defaults to the module-level
            ``n_goals``.

    Returns:
        The expression string. The result is ``'m'`` when there are no
        minterms, i.e. no goal is selected by ``task`` or ``tasks`` is empty.

    Raises:
        IndexError: if ``task`` or a base task is shorter than ``num_goals``,
            or ``symbols`` names fewer base tasks than ``len(tasks)``.
    """
    if num_goals is None:
        num_goals = n_goals
    if symbols is None:
        names = [str(j) for j in range(len(tasks))]
        negate, joiner = '-', '.'
    else:
        names = [str(symbols[j]) for j in range(len(tasks))]
        negate, joiner = '~', ''

    minterms = []
    for i in range(num_goals):
        if not task[i]:
            continue
        literals = [names[j] if base[i] else negate + names[j]
                    for j, base in enumerate(tasks)]
        # With no base tasks the minterm is empty; skip it so the result
        # stays 'm' instead of gaining stray separators.
        if literals:
            minterms.append(joiner.join(literals))
    return '+'.join(['m'] + minterms)

# Target task(s) to compose from the learned base tasks. Each entry is a 0/1
# vector over the n_goals goals; [1]+[0]*39 selects only goal 0 (T_states[0]).
# The commented entries are alternative targets.
tasks = [
         [1]+[0]*39,
        # [0]*4+[1]*36,
        # [1,1,0,0]+[1]*10+[0]*9+[1]*4+[0]*4+[1]*5+[0]*4,
        # [1,0,1,0]+[1]*5+[0]*5+[1]*5+[0]*4+[1]*8+[0]*9,
        # [1,0,0,0]+[1]*5+[0]*5+[0]*9+[1]*4+[0]*4+[0]*9,
        ]
for i in [0]:
    task = tasks[i]
    # Express the target as a DNF over tasks_best, then compose the matching
    # value function from values_best.
    # NOTE(review): exp_value parses only the integer-literal syntax
    # ('m+0.-1+...'); the task_exp in scope here must produce that syntax.
    exp = task_exp(tasks_best, task)
    Q = exp_value(values_best, exp)
    fig = env.render( P=EQ_P(Q), V = EQ_V(Q))
    fig.savefig("plots/values/"+str(12)+".pdf", bbox_inches='tight')


##############################################################################

# import random

# goal_reward=1
# step_reward=0
# gamma=0.9
# slip_prob=0
# maxiter=200
# epsilon=1
# alpha=1

# def evaluateQ(goals,Q,slip_prob=0):
#     env = GridWorld(goals=goals, goal_reward=goal_reward, 
#                     step_reward=step_reward, slip_prob=slip_prob)
#     policy =  Q_P(Q)
#     state = env.reset()
#     done = False
#     t=0
#     G = 0
#     while not done and t<100:
#         action = policy[state]        
#         state_, reward, done, _ = env.step(action)       
#         state = state_      
#         G += reward*(gamma**t)
#         t += 1
#     return G

# def evaluateEQ(goals,EQ,slip_prob=0):
#     env = GridWorld(goals=goals, goal_reward=goal_reward, 
#                     step_reward=step_reward, slip_prob=slip_prob)
#     policy =  EQ_P(EQ)
#     state = env.reset()
#     done = False
#     t=0
#     G = 0
#     while not done and t<100:
#         action = policy[state]        
#         state_, reward, done, _ = env.step(action)       
#         state = state_            
#         G += reward*(gamma**t)
#         t += 1
#     return G

# def sample_random(n=None):
#     if n==None:
#         return [np.random.randint(2) for _ in range(n_goals)]
#     task = np.zeros(n_goals)
#     i = random.sample(range(n_goals),random.randint(1,n))
#     task[i] = 1
#     return list(task)

# num_runs = 1000
# num_tasks = 50
# data_optimal = np.zeros((num_runs,num_tasks)) 
# data_composed = np.zeros((num_runs,num_tasks)) 
# for i in range(num_tasks):
#     print(i)
#     task = sample_random(n=n_goals)
#     goals = []
#     for g in range(n_goals):
#         if task[g]==1:
#             goals.append(T_states_[g])
#     env = GridWorld(goals=goals, goal_reward=goal_reward, 
#                     step_reward=step_reward, slip_prob=slip_prob)
#     Q,stats1 = Q_learning(env, gamma=gamma, epsilon=epsilon, alpha=alpha, maxiter=maxiter)
    
#     print('*')
#     exp = task_exp(tasks_best, task)
#     EQ = exp_value(values_best, exp)
#     print('**')
    
#     for j in range(num_runs):
#         data_optimal[j,i] = evaluateQ(goals,Q,slip_prob=0.1)
#         data_composed[j,i] = evaluateEQ(goals,EQ,slip_prob=0.1)

# dd.io.save('data/exp2.h5', (data_optimal,data_composed))       

 

# def MAX_(values):
#     Q_max = defaultdict(lambda: np.zeros(env.action_space.n))
#     for Q in values:
#         for s in list(set(list(Q_max.keys())) | set(list(Q.keys()))):
#             Q_max[s] = np.max([Q_max[s],Q[s]],axis=0)
#     return Q_max

# T_states_ = [[pos,pos] for pos in T_states]

# goal_reward=1
# step_reward=0
# gamma=0.9
# slip_prob=0.1
# maxiter=500
# epsilon=0.1
# alpha=0.1

# num_runs = 25
# num_tasks = 50
# data= np.zeros((num_runs,num_tasks)) 
# for i in range(num_runs):
#     print(i)
#     values=[]
#     for j in range(num_tasks):
#         # print(j)
#         task = sample_random()
#         goals = []
#         for g in range(n_goals):
#             if task[g]==1:
#                 goals.append(T_states_[g])
#         env = GridWorld(goals=goals, goal_reward=goal_reward, 
#                         step_reward=step_reward, slip_prob=slip_prob)
        
#         # delta = 0.05
#         # pmin = 1/(2**n_goals)
#         # if j > (np.log(delta)/np.log(1-pmin)):
#         #     Q_init=MAX_(values)
#         # else:
#         #     Q_init=defaultdict(lambda: np.zeros(env.action_space.n)+(goal_reward/(1-gamma)))            
        
#         Q_init=MAX_(values)
#         Q,stats1 = Q_learning(env, Q_init=Q_init, gamma=gamma, epsilon=epsilon, alpha=alpha, maxiter=maxiter)
#         data[i,j] = stats1['T']
#         values.append(Q)
#     dd.io.save('data/exp3.h5', data)    