import numpy as np
from matplotlib import pyplot as plt
import deepdish as dd
from GridWorld import GridWorld
from library import *
from collections import defaultdict


# Default-constructed environment. The training loop further down rebinds
# `env` to a task-specific GridWorld before handing it to the learner.
env = GridWorld()
# Candidate goal positions. A task is a 0/1 vector with one entry per
# position listed here, in this order.
T_states = [
    (3, 3), (3, 9), (9, 3), (9, 9),
    (1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
    (1, 7), (1, 8), (1, 9), (1, 10), (1, 11),
    (11, 1), (11, 2), (11, 3), (11, 4), (11, 5),
    (11, 7), (11, 8), (11, 9), (11, 10),
    (2, 1), (3, 1), (4, 1), (5, 1), (7, 1), (8, 1), (9, 1), (10, 1),
    (2, 11), (3, 11), (4, 11), (5, 11), (6, 11),
    (8, 11), (9, 11), (10, 11), (11, 11),
]
n_goals = len(T_states)

######################################
# Boolean-algebra bases: n = ceil(log2(number of goals)) tasks. Row i holds
# bit (n - 1 - i) of every goal index, so reading a goal's column top to
# bottom spells its index in binary and every goal gets a distinct code.
# 1 = True = rmax, 0 = False = rmin.
n = int(np.ceil(np.log2(len(T_states))))
bases = [
    [(goal >> bit) & 1 for goal in range(len(T_states))]
    for bit in reversed(range(n))  # most significant bit first
]
bases_boolean_algebra = bases
bases_boolean_algebra_n = len(bases)
######################################
# Disjunction bases: one single-goal task per goal (rows of the identity).
bases_disjunction = np.eye(n_goals)
######################################
    
def task_exp(tasks, task, num_goals=None):
    """Write ``task`` as a sum-of-products expression over known ``tasks``.

    One product term (clause) is emitted for every goal ``i`` whose entry
    ``task[i]`` is truthy. A clause lists every known task index ``j``,
    joined by ``'.'``: plain ``j`` when ``tasks[j][i]`` is truthy, ``-j``
    otherwise. Clauses follow a leading ``'m'`` term and are joined by
    ``'+'``, e.g. ``'m+0.-1+0.1'``. Goals that look identical across all
    known tasks produce identical (repeated) clauses.

    Args:
        tasks: Sequence of known task vectors, each indexable by goal.
        task: Task vector to express, indexable by goal.
        num_goals: Number of goals to scan. Defaults to the module-level
            ``n_goals``.

    Returns:
        The expression string. It is just ``'m'`` when ``tasks`` is empty
        or when no scanned entry of ``task`` is set.
    """
    if num_goals is None:
        num_goals = n_goals
    if len(tasks) == 0:
        # No known tasks means there are no literals to write a clause with.
        return 'm'

    clauses = []
    for i in range(num_goals):
        if task[i]:
            # Negation is a '-' prefix on the index; for index 0 this gives
            # '-0', so a reader must look for the '-' character rather than
            # the sign of the parsed integer.
            clauses.append('.'.join(
                str(j) if known[i] else '-' + str(j)
                for j, known in enumerate(tasks)
            ))
    return '+'.join(['m'] + clauses)

def exp_task(tasks, exp, num_goals=None):
    """Evaluate a sum-of-products expression back into a 0/1 task vector.

    ``exp`` is split on ``'+'`` and its first term is skipped. Every
    remaining term is a clause of ``'.'``-separated literals: ``j`` refers
    to ``tasks[j]`` and a literal containing ``'-'`` refers to its
    complement (``1 - tasks[j][i]``). Goal ``i`` is set to 1 as soon as
    one clause has all of its literals non-zero at that goal.

    Args:
        tasks: Sequence of known task vectors, each indexable by goal.
        exp: Expression string, e.g. ``'m+0.-1+0.1'``.
        num_goals: Length of the returned vector. Defaults to the
            module-level ``n_goals``.

    Returns:
        A list of ``num_goals`` ints, each 0 or 1.

    Raises:
        ValueError: If a literal in ``exp`` is not an integer.
    """
    if num_goals is None:
        num_goals = n_goals

    # Parse the expression once instead of once per goal. The sign of the
    # parsed integer cannot mark negation because '-0' parses to 0, so the
    # '-' character itself is recorded.
    clauses = [
        [(abs(int(literal)), '-' in literal) for literal in clause.split('.')]
        for clause in exp.split('+')[1:]
    ]

    task = [0] * num_goals
    for i in range(num_goals):
        for clause in clauses:
            if all(
                (1 - tasks[j][i]) if negated else tasks[j][i]
                for j, negated in clause
            ):
                task[i] = 1
                break
    return task
######################################
import random    
def sample_random(num_goals=None):
    """Return a random 0/1 task vector as a list.

    The number of goals set to 1 is drawn uniformly from ``0..num_goals``
    first, then that many distinct goal indices are picked. Every subset
    size is therefore equally likely, unlike an independent coin flip per
    goal, which would concentrate sizes around ``num_goals / 2``.

    Args:
        num_goals: Length of the vector. Defaults to the module-level
            ``n_goals``.

    Returns:
        A list of ``num_goals`` NumPy integers, each 0 or 1.
    """
    if num_goals is None:
        num_goals = n_goals
    task = np.zeros(num_goals, dtype=int)
    subset_size = random.randint(0, num_goals)
    chosen = random.sample(range(num_goals), subset_size)
    task[chosen] = 1
    return list(task)

def sample_best(i):
    """Return the i-th Boolean-algebra base task, or a random task.

    Indices below ``bases_boolean_algebra_n`` yield a fresh list copy of the
    corresponding row of ``bases_boolean_algebra``; any other index falls
    back to ``sample_random()``.
    """
    if i >= bases_boolean_algebra_n:
        return sample_random()
    return list(bases_boolean_algebra[i])

def sample_worst(i):
    """Return the i-th single-goal (disjunction) base task, or a random task.

    Indices below ``n_goals`` yield row ``i`` of ``bases_disjunction`` as a
    list; any other index falls back to ``sample_random()``.
    """
    if i >= n_goals:
        return sample_random()
    return list(bases_disjunction[i])

# tasks = [sample_random() for _ in range(10)]
# task = sample_random()
# exp = task_exp(tasks, task)
# task_ = exp_task(tasks, exp)
# print(str(task))
# print(str(task_))
# print(task==task_)
######################################
# Each goal is handed to GridWorld as a [pos, pos] pair.
T_states_ = [[state, state] for state in T_states]

# Settings forwarded to GridWorld.
goal_reward = 1
step_reward = 0
slip_prob = 0.1

# Settings forwarded to the learner.
gamma = 0.9
epsilon = 0.1
alpha = 0.1
maxiter = 50000

# Experiment size; one vector of `maxiter` values is stored per (run, task).
num_runs = 1
num_tasks = 50
data_best = np.zeros((num_runs, num_tasks, maxiter))

learner = GOAL  # alternative noted by the author: Goal_Oriented_Q_learning
for i in range(num_runs):
    print('run', i)
    tasks_best = []   # tasks that had to be learned in this run
    values_best = []  # the learner's result for each of those tasks
    for j in range(num_tasks):
        print('task', j)
        task = sample_best(j)

        # Round trip: write the task over the tasks learned so far, then
        # evaluate that expression back. If it reproduces the task, the task
        # is already covered and no learning is run for it.
        exp = task_exp(tasks_best, task)
        task_ = exp_task(tasks_best, exp)
        if task == task_:
            data_best[i, j] = 0
            continue

        goals = [T_states_[g] for g in range(n_goals) if task[g] == 1]
        env = GridWorld(
            goals=goals,
            goal_reward=goal_reward,
            step_reward=step_reward,
            slip_prob=slip_prob,
        )
        Q, stats1 = learner(
            env, gamma=gamma, epsilon=epsilon, alpha=alpha, maxiter=maxiter
        )
        data_best[i, j] = stats1['R']
        tasks_best.append(task)
        values_best.append(Q)

    # Written after every run, so the file always holds the results so far.
    dd.io.save('data/returns_base.h5', data_best)
