import numpy as np
from matplotlib import pyplot as plt
import deepdish as dd
from GridWorld import GridWorld
from library import *
from collections import defaultdict


# Environment built with GridWorld's default constructor arguments. It is not
# referenced again in the visible part of this script.
env = GridWorld()
# Hand-enumerated goal positions (40 entries). The order is significant: the
# task vectors used elsewhere in this file have one entry per element of this
# list, in this order.
# NOTE(review): the tuples are presumably grid coordinates understood by
# GridWorld -- confirm the axis convention against that class.
T_states=[(3,3),(3,9),(9,3),(9,9),
          (1,1),(1,2),(1,3),(1,4),(1,5),(1,7),(1,8),(1,9),(1,10),(1,11),
          (11,1),(11,2),(11,3),(11,4),(11,5),(11,7),(11,8),(11,9),(11,10),
          (2,1),(3,1),(4,1),(5,1),(7,1),(8,1),(9,1),(10,1),
          (2,11),(3,11),(4,11),(5,11),(6,11),(8,11),(9,11),(10,11),(11,11)]
# Number of goals == length of every task vector.
n_goals = len(T_states)

######################################
# Boolean-algebra basis.
# bases[i][j] is bit (n-1-i) of the goal index j, i.e. the n rows together
# spell out j in binary, most significant bit first. Every goal therefore gets
# a distinct column, and n = ceil(log2(number of goals)) rows are enough.
# Entry meaning: 1 = True = rmax, 0 = False = rmin.
#
# n is computed with integer arithmetic: (k - 1).bit_length() equals
# ceil(log2(k)) for k >= 1, and the max(..., 0) guard makes an empty goal list
# yield n == 0 instead of failing on log2(0).
n = max(len(T_states) - 1, 0).bit_length()
bases = [
    [(goal >> bit) & 1 for goal in range(len(T_states))]
    for bit in range(n - 1, -1, -1)
]
bases_boolean_algebra = bases
bases_boolean_algebra_n = len(bases)
######################################
# Disjunction basis: one one-hot row per goal (rows of the identity matrix).
bases_disjunction = np.eye(n_goals)
######################################
    
def task_exp(tasks, task):
    """Encode ``task`` as a sum-of-products string over the base ``tasks``.

    For every goal index ``g`` in ``range(n_goals)`` with a truthy ``task[g]``
    one product is emitted. A product lists every base-task index ``j`` in
    order, written as ``j`` when ``tasks[j][g]`` is truthy and as ``-j``
    otherwise, joined with ``'.'``. Products are joined with ``'+'`` and the
    whole string is prefixed with ``'m+'``. Identical products are not
    de-duplicated. When no goal is selected the result is just ``'m'``.

    Example (with ``n_goals == 3``):
    ``task_exp([[1, 0, 1], [0, 0, 1]], [1, 0, 1])`` gives ``'m+0.-1+0.1'``.

    Args:
        tasks: indexable collection of base task vectors (``tasks[j][g]``).
        task: the task vector to express, indexable by goal.

    Returns:
        The expression string.
    """
    def literal(j, goal):
        # A negated base task is marked with a leading '-'.
        return str(j) if tasks[j][goal] else '-' + str(j)

    products = []
    for goal in range(n_goals):
        if not task[goal]:
            continue
        products.append(
            '.'.join(literal(j, goal) for j in range(len(tasks)))
        )

    # With an empty `tasks` every product is the empty string; those are
    # dropped, which leaves the bare 'm' prefix.
    return '+'.join(['m'] + [product for product in products if product])

def exp_task(tasks, exp):
    """Evaluate a sum-of-products expression string into a 0/1 task vector.

    ``exp`` uses the format written by ``task_exp``: everything before the
    first ``'+'`` is ignored, the remaining ``'+'``-separated parts are
    products, and each product is a ``'.'``-separated list of literals. A
    literal ``j`` stands for ``tasks[j][g]`` and ``-j`` for
    ``1 - tasks[j][g]``. Goal ``g`` is set to 1 when at least one product has
    all of its literals truthy for that goal.

    Example (with ``n_goals == 3``):
    ``exp_task([[1, 0, 1], [0, 0, 1]], 'm+0.-1+0.1')`` gives ``[1, 0, 1]``.

    Args:
        tasks: indexable collection of base task vectors (``tasks[j][g]``).
        exp: expression string, e.g. ``'m+0.-1+0.1'``; a string without any
            ``'+'`` (such as ``'m'``) yields an all-zero vector.

    Returns:
        A list of ``n_goals`` ints, each 0 or 1.

    Raises:
        ValueError: if a literal is not an integer. The expression is parsed
            in full before evaluation, so this is raised even for products
            that evaluation would not have needed.
    """
    # Parse once: the expression is the same for every goal, so splitting it
    # and converting literals inside the goal loop is repeated work.
    # The sign is read from the text because int('-0') == 0 would lose it.
    products = [
        [(abs(int(lit)), '-' in lit) for lit in product.split('.')]
        for product in exp.split('+')[1:]
    ]

    def holds(goal):
        # any/all short-circuit exactly like the original nested breaks.
        return any(
            all(
                (1 - tasks[j][goal]) if negated else tasks[j][goal]
                for j, negated in product
            )
            for product in products
        )

    return [1 if holds(goal) else 0 for goal in range(n_goals)]
######################################
import random    
def sample_random():
    """Return a random 0/1 task vector of length ``n_goals``.

    The number of ones is drawn uniformly from ``0..n_goals`` (both ends
    included), then that many distinct goal indices are chosen and set to 1.

    Returns:
        A list of ``n_goals`` NumPy integer scalars, each 0 or 1.
    """
    ones_count = random.randint(0, n_goals)
    chosen = random.sample(range(n_goals), ones_count)
    task = np.zeros(n_goals, dtype=int)
    task[chosen] = 1
    return list(task)
    # Alternative kept for reference (draws every entry independently):
    # return [np.random.randint(2) for _ in range(n_goals)]

def sample_best(i):
    """Return the ``i``-th boolean-algebra basis task, or a random task.

    While ``i`` is below ``bases_boolean_algebra_n`` a fresh list copy of
    ``bases_boolean_algebra[i]`` is returned; from then on the result comes
    from ``sample_random()``.
    """
    if i >= bases_boolean_algebra_n:
        return sample_random()
    return list(bases_boolean_algebra[i])

def sample_worst(i):
    """Return the ``i``-th disjunction basis task, or a random task.

    While ``i`` is below ``n_goals`` row ``i`` of ``bases_disjunction`` (the
    identity matrix, so a one-hot vector of floats) is returned as a list;
    from then on the result comes from ``sample_random()``.
    """
    if i >= n_goals:
        return sample_random()
    return list(bases_disjunction[i])

# tasks = [sample_random() for _ in range(10)]
# task = sample_random()
# exp = task_exp(tasks, task)
# task_ = exp_task(tasks, exp)
# print(str(task))
# print(str(task_))
# print(task==task_)
######################################
# Every goal position duplicated into a [pos, pos] pair.
# NOTE(review): presumably the format GridWorld's `goals` argument expects
# (the commented-out driver below passes entries of this list as `goals`) --
# confirm against GridWorld.
T_states_ = [[pos,pos] for pos in T_states]

# Tasks to run, one 0/1 vector per task with one entry per goal
# (1 + 39 = 40 entries, matching the current length of T_states).
# Only the first task is active; the others are kept commented out.
tasks = [
         [1]+[0]*39,
        # [0]*4+[1]*36,
        # [1,1,0,0]+[1]*10+[0]*9+[1]*4+[0]*4+[1]*5+[0]*4,
        # [1,0,1,0]+[1]*5+[0]*5+[1]*5+[0]*4+[1]*8+[0]*9,
        # [1,0,0,0]+[1]*5+[0]*5+[0]*9+[1]*4+[0]*4+[0]*9,
        ]

# Experiment settings. In the commented-out driver at the bottom of the file,
# goal_reward / step_reward / slip_prob are passed to GridWorld(...) and
# gamma / epsilon / alpha / maxiter are passed to the learner.
goal_reward=1
step_reward=0
gamma=0.9
slip_prob=0.1
maxiter=20000
epsilon=0.1
alpha=0.1

num_runs = 1
# NOTE(review): num_tasks is not read anywhere in the visible part of the file.
num_tasks = 50
# Result buffer, zero-initialised, indexed as [task, run, iteration].
data_best = np.zeros((len(tasks),num_runs,maxiter)) 

# Previously saved (tasks, values) pairs, read from HDF5 files via deepdish.
# The paths are relative, so the script must be started from the directory
# that contains `data/`.
tasks_SOP, values_SOP = dd.io.load('data/SOP_base.h5')
tasks_SFGPI, values_SFGPI = dd.io.load('data/SFGPI_base.h5')


# learner = SFGPI #Goal_Oriented_Q_learning #
# for i in range(len(tasks)):
#     print('task',i)
#     goals = []
#     for g in range(n_goals):
#         if task[g]==1:
#             goals.append(T_states_[g])
    
#     for j in range(num_runs):
#         print('run',j)
        
#         envs.append(GridWorld(goals=goals, goal_reward=goal_reward, 
#                         step_reward=step_reward, slip_prob=slip_prob))
        
        
        
        
#         rs, Qs, stats = learner(envs, learned=(tasks_SOP, values_SOP), gamma=gamma, epsilon=epsilon, alpha=alpha, maxiter=maxiter)
#         data_best[i,j] = stats['R'][:-2]
#         tasks_best.append(rs)
#         values_best.append(Qs)            
    
#         dd.io.save('data/exp_3_t1.h5', (data_SOP,data_SFGPI))
