import numpy as np
from matplotlib import pyplot as plt
import deepdish as dd
from GridWorld import GridWorld
from library import *
from collections import defaultdict


# Module-level default environment. It is rebound to a task-specific
# GridWorld inside the experiment loop at the bottom of this script.
env = GridWorld()
# Candidate terminal (goal) states as grid coordinates -- presumably
# (row, col); TODO confirm against GridWorld. The ORDER matters: every task
# vector in this script is a 0/1 mask indexed by position in this list.
T_states=[(3,3),(3,9),(9,3),(9,9),
          (1,1),(1,2),(1,3),(1,4),(1,5),(1,7),(1,8),(1,9),(1,10),(1,11),
          (11,1),(11,2),(11,3),(11,4),(11,5),(11,7),(11,8),(11,9),(11,10),
          (2,1),(3,1),(4,1),(5,1),(7,1),(8,1),(9,1),(10,1),
          (2,11),(3,11),(4,11),(5,11),(6,11),(8,11),(9,11),(10,11),(11,11)]
# Number of candidate goals (40 for the list above); every task mask must
# have exactly this length.
n_goals = len(T_states)

# ---------------------------------------------------------------------------
# Boolean-algebra basis tasks.
#
# n is the number of bits needed to index every goal. Row i of `bases` holds
# bit (n-1-i) of each goal index j, so reading goal j's column top-to-bottom
# gives the n-bit binary encoding of j (most significant bit first).
# 0 marks rmin (False) and 1 marks rmax (True).
n = int(np.ceil(np.log2(len(T_states))))
bases = [
    [(goal_idx >> (n - 1 - bit)) & 1 for goal_idx in range(len(T_states))]
    for bit in range(n)
]
bases_boolean_algebra = bases
bases_boolean_algebra_n = len(bases)

# Disjunction basis: one single-goal task per goal (rows of the identity).
bases_disjunction = np.eye(n_goals)

############################################################################
import random    
def sample_random(n_goals):
    """Sample a random 0/1 task mask over ``n_goals`` goals.

    The number of active goals is drawn uniformly from ``0..n_goals``
    (inclusive). That many distinct goal indices are then chosen uniformly.
    Every mask *size* is therefore equally likely, which is not the same as
    a uniform draw over all ``2**n_goals`` masks.

    Returns:
        A list of length ``n_goals`` whose entries (numpy ints) are 0 or 1.
    """
    n_active = random.randint(0, n_goals)
    active = random.sample(range(n_goals), n_active)
    mask = np.zeros(n_goals, dtype=int)
    mask[active] = 1
    return list(mask)

def sample_best(i, n_goals):
    """Return the i-th boolean-algebra basis task, else a random task.

    For ``i < bases_boolean_algebra_n`` this is a list copy of
    ``bases_boolean_algebra[i]``. Past the last basis it falls back to
    ``sample_random(n_goals)``.
    """
    if i >= bases_boolean_algebra_n:
        return sample_random(n_goals)
    return list(bases_boolean_algebra[i])

def sample_worst(i, n_goals):
    """Return the i-th single-goal (disjunction basis) task, else a random task.

    For ``i < n_goals`` this is row ``i`` of ``bases_disjunction`` (built with
    ``np.eye``), so only goal ``i`` is active. Those entries are numpy floats,
    ``np.eye``'s default dtype. For larger ``i`` it falls back to
    ``sample_random(n_goals)``.
    """
    if i < n_goals:
        return list(bases_disjunction[i])
    # sample_random requires the goal count; calling it without an argument
    # raised TypeError for every i >= n_goals.
    return sample_random(n_goals)

############################################################################
        
# tasks = [sample_random(n_goals) for _ in range(10)]
# task = sample_random(n_goals)
# exp = task_exp(tasks, task, n_goals)
# task_ = exp_task(tasks, exp, n_goals)
# print(str(task))
# print(str(task_))
# print(task==task_)
############################################################################
# Terminal states in the paired [pos, pos] form handed to GridWorld below.
# Why each position is duplicated is not visible here -- presumably
# GridWorld's state format; TODO confirm.
T_states_ = [[cell] * 2 for cell in T_states]

# Environment / learning hyperparameters, passed to GridWorld and to the
# learners below.
goal_reward, step_reward = 1, 0
gamma = 0.9
slip_prob = 0.1
maxiter = 500
epsilon, alpha = 0.1, 0.1


# maxiter=1000
# Ts = []
# Qs = []
# tasks = [[1]*n_goals] + [sample_best(j, n_goals) for j in range(bases_boolean_algebra_n)]
# for i in range(len(tasks)):
#     print('task',i)
#     task = tasks[i]
#     goals = []
#     for g in range(n_goals):
#         if task[g]==1:
#             goals.append(T_states_[g])
#     env = GridWorld(T_states=T_states_, goals=goals, goal_reward=goal_reward, 
#                     step_reward=step_reward, slip_prob=slip_prob)
#     T, Q, _ = GOAL(env, gamma=gamma, epsilon=epsilon, alpha=alpha, maxiter=maxiter)
#     Ts.append(T)
#     Qs.append(Q)
# # dd.io.save('data/SFGPI_base.h5', (Ts,Qs))
# tasks_SOP, values_SOP = Ts, Qs
# env.render( P=EQ_P(Qs[0]), V = EQ_V(Qs[0]))
# env.render( P=EQ_P(Qs[0],goal='[(11, 11), (11, 11)]'), V = EQ_V(Qs[0],goal='[(11, 11), (11, 11)]'))

# maxiter=20000
# envs = []
# tasks = [sample_worst(j, n_goals) for j in range(n_goals)]
# for task in tasks:
#     goals = []
#     for g in range(n_goals):
#         if task[g]==1:
#             goals.append(T_states_[g])
#     envs.append(GridWorld(T_states=T_states_, goals=goals, goal_reward=goal_reward, 
#                     step_reward=step_reward, slip_prob=slip_prob))
# Rs, Qs, stats = SFGPIB(envs, gamma=gamma, epsilon=epsilon, alpha=alpha, maxiter=maxiter)        
# # dd.io.save('data/SFGPI_base.h5', (Rs,Qs))
# rewards_SFGPI, values_SFGPI = Rs, Qs
# env.render( P=EQ_P(Qs[0]), V = EQ_V(Qs[0]))
# env.render( P=EQ_P(Qs[0],goal=0), V = EQ_V(Qs[0],goal=0))

############################################################################

# Evaluation tasks: each row is a 0/1 mask over the n_goals terminal states in
# T_states (1 = that state is a goal for the task), in T_states order.
tasks = [
    [1] + [0] * 39,  # sample_best(6, n_goals), #
    [0] * 4 + [1] * 36,
    [1, 1, 0, 0] + [1] * 10 + [0] * 9 + [1] * 4 + [0] * 4 + [1] * 5 + [0] * 4,
    [1, 0, 1, 0] + [1] * 5 + [0] * 5 + [1] * 5 + [0] * 4 + [1] * 8 + [0] * 9,
    [1, 0, 0, 0] + [1] * 5 + [0] * 5 + [0] * 9 + [1] * 4 + [0] * 4 + [0] * 9,
]
# A wrong-length mask would either raise IndexError mid-experiment (too short)
# or silently ignore trailing entries (too long); reject it up front instead.
_bad_tasks = [k for k, t in enumerate(tasks) if len(t) != n_goals]
if _bad_tasks:
    raise ValueError(
        f"tasks {_bad_tasks} do not have length n_goals={n_goals}")

# GOALT needs the base tasks/values produced by the (commented-out) SOP
# base-training section above. In a fresh interpreter they are otherwise
# undefined, and the script used to crash only after the first Q-learning run.
try:
    learned_SOP = (tasks_SOP, values_SOP)
except NameError as err:
    raise NameError(
        "tasks_SOP / values_SOP are not defined: run the base-task training "
        "section (GOAL over the boolean-algebra bases) before this experiment."
    ) from err

maxiter = 1000
num_runs = 10
# One recorded point per 10 iterations. This assumes stats['E'] has
# maxiter//10 + 1 entries (the last one is dropped below) -- TODO confirm in
# library.
n_points = maxiter // 10
data_Q = np.zeros((len(tasks), num_runs, n_points))
data_SOP = np.zeros((len(tasks), num_runs, n_points))
# Stays all-zero while the SFGPI run below is disabled.
data_SFGPI = np.zeros((len(tasks), num_runs, n_points))

for i, task in enumerate(tasks):
    print('task', i)
    goals = [T_states_[g] for g in range(n_goals) if task[g] == 1]

    env = GridWorld(T_states=T_states_, goals=goals, goal_reward=goal_reward,
                    step_reward=step_reward, slip_prob=slip_prob)

    for j in range(num_runs):
        print('run', j)

        # Baseline: plain Q-learning from scratch on this task.
        Q, stats = Q_learning(env, gamma=gamma, epsilon=epsilon, alpha=alpha,
                              maxiter=maxiter)
        data_Q[i, j] = stats['E'][:-1]

        # Transfer: GOALT given the previously learned base tasks/values.
        R, Q, stats = GOALT(env, learned=learned_SOP, gamma=gamma,
                            epsilon=epsilon, alpha=alpha, maxiter=maxiter)
        data_SOP[i, j] = stats['E'][:-1]

        # R, Q, stats, learned = SFGPIT(envs, learned=(rewards_SFGPI, values_SFGPI), gamma=gamma, epsilon=epsilon, alpha=alpha, maxiter=maxiter)
        # data_SFGPI[i, j] = stats['R'][:-2]

        # Saved after every run, so partial results are on disk even if the
        # job stops early.
        dd.io.save('data/exp_2_1.h5', (data_SOP, data_SFGPI))
        # The Q-learning baseline was previously computed but never saved.
        # It goes to its own file so readers of the existing
        # (data_SOP, data_SFGPI) file keep working unchanged.
        dd.io.save('data/exp_2_1_Q.h5', data_Q)
