import numpy as np
from matplotlib import pyplot as plt
import deepdish as dd
from GridWorld import GridWorld
from library import *
from collections import defaultdict


# Default GridWorld instance; `env` is rebound to task-specific worlds further down.
env = GridWorld()
# Candidate goal positions. A task is a 0/1 vector over this list, so the
# order here fixes which entry of a task vector refers to which position.
# NOTE(review): the groups with a coordinate equal to 1 or 11 skip the value 6
# in three cases ((1,6), (11,6), (6,1) are absent), but the last group lists
# (6,11) and skips (7,11) instead -- confirm whether that is intended.
T_states=[(3,3),(3,9),(9,3),(9,9),
          (1,1),(1,2),(1,3),(1,4),(1,5),(1,7),(1,8),(1,9),(1,10),(1,11),
          (11,1),(11,2),(11,3),(11,4),(11,5),(11,7),(11,8),(11,9),(11,10),
          (2,1),(3,1),(4,1),(5,1),(7,1),(8,1),(9,1),(10,1),
          (2,11),(3,11),(4,11),(5,11),(6,11),(8,11),(9,11),(10,11),(11,11)]
# Number of candidate goals, i.e. the length of every task vector.
n_goals = len(T_states)

######################################
# Boolean-algebra basis: ceil(log2(#goals)) task vectors. Basis task `bit`
# marks goal number `goal` with bit (n-1-bit) of `goal`, so the first basis
# task carries the most significant bit and the last one alternates 0,1,0,1...
n = int(np.ceil(np.log2(len(T_states))))
bases = [
    [(goal >> (n - 1 - bit)) & 1 for goal in range(len(T_states))]  # 1=True=rmax, 0=False=rmin
    for bit in range(n)
]
bases_boolean_algebra = bases
bases_boolean_algebra_n = len(bases)
######################################
# Disjunction basis: row i of the identity matrix is the task whose only
# goal is goal i.
bases_disjunction = np.eye(n_goals)
######################################
    
def task_exp(tasks, task):
    """Write ``task`` as an expression string over the known ``tasks``.

    For every goal index ``i`` below ``n_goals`` with a truthy ``task[i]`` one
    clause is produced. A clause names every known task ``j`` in order,
    joined by '.', as ``j`` when ``tasks[j][i]`` is truthy and as ``-j``
    otherwise. The clauses are joined by '+' behind a leading 'm'.

    Returns just 'm' when no goal is selected or when ``tasks`` is empty
    (with no known tasks every clause would be empty).
    """
    clauses = []
    for goal in range(n_goals):
        if not task[goal]:
            continue
        literals = [
            str(idx) if known[goal] else '-' + str(idx)
            for idx, known in enumerate(tasks)
        ]
        clauses.append('.'.join(literals))
    # Empty clauses only arise when there are no known tasks; dropping them
    # leaves the bare 'm' marker in that case.
    return '+'.join(['m'] + [clause for clause in clauses if clause])

def exp_task(tasks, exp):
    """Evaluate an expression string back into a 0/1 task vector.

    ``exp`` is split on '+'; the first field is ignored and every remaining
    field is a clause of '.'-separated literals. A literal ``j`` stands for
    ``tasks[j][i]`` and a literal containing '-' for ``1 - tasks[j][i]``.
    Entry ``i`` of the result is 1 when, for at least one clause, every
    literal is non-zero at goal ``i``; otherwise it is 0.

    Returns a list of ``n_goals`` ints.
    """
    clauses = [clause.split('.') for clause in exp.split('+')[1:]]

    def literal_value(literal, goal):
        # abs() strips the sign; the '-' test (not the numeric sign) decides
        # negation, so '-0' negates task 0.
        idx = abs(int(literal))
        if '-' in literal:
            return 1 - tasks[idx][goal]
        return tasks[idx][goal]

    def holds(goal):
        # any/all are lazy, so literals and clauses are visited in the same
        # order, and abandoned at the same point, as an explicit loop would.
        return any(
            all(literal_value(literal, goal) for literal in clause)
            for clause in clauses
        )

    return [1 if holds(goal) else 0 for goal in range(n_goals)]
######################################
import random    
def sample_random():
    """Draw a random task vector of length ``n_goals``.

    First the number of goals is drawn uniformly from 0..n_goals (inclusive),
    then that many distinct goal indices are sampled and set to 1.

    Returns a list of numpy integers (0 or 1).
    """
    how_many = random.randint(0, n_goals)
    chosen = random.sample(range(n_goals), how_many)
    task = np.zeros(n_goals, dtype=int)
    task[chosen] = 1
    return list(task)
    # return [np.random.randint(2) for _ in range(n_goals)]

def sample_best(i):
    """Return the ``i``-th Boolean-algebra basis task, or a random task.

    While ``i`` is below ``bases_boolean_algebra_n`` a copy of
    ``bases_boolean_algebra[i]`` is returned; after that the result comes
    from ``sample_random()``.
    """
    if i >= bases_boolean_algebra_n:
        return sample_random()
    return list(bases_boolean_algebra[i])

def sample_worst(i):
    """Return the ``i``-th single-goal task, or a random task.

    While ``i`` is below ``n_goals`` row ``i`` of ``bases_disjunction`` is
    returned as a list; after that the result comes from ``sample_random()``.
    """
    if i >= n_goals:
        return sample_random()
    return list(bases_disjunction[i])

# tasks = [sample_random() for _ in range(10)]
# task = sample_random()
# exp = task_exp(tasks, task)
# task_ = exp_task(tasks, exp)
# print(str(task))
# print(str(task_))
# print(task==task_)
######################################
# Every goal is handed to GridWorld as a [position, position] pair.
T_states_ = [[pos, pos] for pos in T_states]

# Environment and learner settings for this experiment.
goal_reward = 1
step_reward = 0
gamma = 0.9
slip_prob = 0.1
maxiter = 500
epsilon = 0.1
alpha = 0.1

# Result tables: one row per run, one column per task. An entry holds the
# learner's stats['T'] for that task, or 0 when the task was not learned.
num_runs = 1
num_tasks = 50
data_best = np.zeros((num_runs, num_tasks))
data_worst = np.zeros((num_runs, num_tasks))
data_random = np.zeros((num_runs, num_tasks))

learner = Goal_Oriented_Q_learning #GOAL #
for i in range(num_runs):
    print('run', i)
    tasks_best, values_best = [], []
    tasks_worst, values_worst = [], []
    tasks_random, values_random = [], []
    # (sampler, known tasks, learned values, result table) per strategy.
    # The order best -> worst -> random is kept so random draws and learner
    # calls happen in the same sequence for every task index.
    strategies = (
        (sample_best, tasks_best, values_best, data_best),
        (sample_worst, tasks_worst, values_worst, data_worst),
        (lambda _: sample_random(), tasks_random, values_random, data_random),
    )
    for j in range(num_tasks):
        print('task', j)
        for draw, known_tasks, known_values, table in strategies:
            task = draw(j)
            # Round-trip the task through its expression over the known
            # tasks; if it comes back unchanged it is not learned again.
            exp = task_exp(known_tasks, task)
            task_ = exp_task(known_tasks, exp)
            if task == task_:
                table[i, j] = 0
                continue
            goals = [goal for goal, flag in zip(T_states_, task) if flag == 1]
            env = GridWorld(goals=goals, goal_reward=goal_reward,
                            step_reward=step_reward, slip_prob=slip_prob)
            Q, stats1 = learner(env, gamma=gamma, epsilon=epsilon,
                                alpha=alpha, maxiter=maxiter)
            table[i, j] = stats1['T']
            known_tasks.append(task)
            known_values.append(Q)

    # Written after every run, so the file always holds the runs finished so far.
    dd.io.save('data/exp1.h5', (data_best,data_worst,data_random))

# Replace the in-memory task/value lists with the pair stored on disk.
# NOTE(review): nothing visible in this file writes 'data/exp1_.h5' (the loop
# above saves 'data/exp1.h5') -- confirm where this file comes from.
tasks_best, values_best = dd.io.load('data/exp1_.h5')

# Render the first six loaded value functions, one PDF each. `env` here is
# whichever environment the code above bound last.
for i in range(6):
    Q = values_best[i]
    fig = env.render( P=EQ_P(Q), V = EQ_V(Q))
    fig.savefig("plots/values/"+str(i)+".pdf", bbox_inches='tight')

# Two-level table whose missing entries default to a zero vector with one
# slot per action. The lambdas read `env` only when an entry is first
# created, so they use whatever `env` is bound at that moment.
EQ_min = defaultdict(lambda: defaultdict(lambda: np.zeros(env.action_space.n)))
# Disabled alternative that built EQ_max by OR-ing the learned values:
# EQ_max = values_best[0]
# for Q in values_best:
#     EQ_max = OR(EQ_max,Q)
# EQ_max = OR(EQ_max,values_worst[0])

# EQ_max is learned on an environment in which every candidate goal is a goal.
env = GridWorld(goals=T_states_, goal_reward=goal_reward, 
                step_reward=step_reward, slip_prob=slip_prob)
EQ_max,_ = learner(env, gamma=gamma, epsilon=epsilon, alpha=alpha, maxiter=maxiter)
    
def exp_value(values, exp):
    """Compose the value functions in ``values`` as described by ``exp``.

    ``exp`` uses the same format that ``exp_task`` reads: '+' separates
    clauses (the first field is skipped) and '.' separates literals. Each
    clause starts from ``EQ_max`` and is combined with ``AND`` against
    ``values[j]`` for a literal ``j``, or against ``NOTD(values[j], EQ_max)``
    for a literal containing '-'. The clauses are then folded with ``OR``,
    starting from ``EQ_min``.

    Returns ``EQ_min`` itself when ``exp`` contains no clause.
    """
    composed = EQ_min
    for clause in exp.split('+')[1:]:
        conjunction = EQ_max
        for literal in clause.split('.'):
            idx = abs(int(literal))
            # The '-' test (not the numeric sign) marks negation, so '-0' works.
            if '-' in literal:
                operand = NOTD(values[idx], EQ_max)
            else:
                operand = values[idx]
            conjunction = AND(operand, conjunction)
        composed = OR(conjunction, composed)
    return composed

# Commented-out copy of the T_states literal from the top of the file; the
# hand-written task vectors below are indexed in this order.
# T_states=[(3,3),(3,9),(9,3),(9,9),
#           (1,1),(1,2),(1,3),(1,4),(1,5),(1,7),(1,8),(1,9),(1,10),(1,11),
#           (11,1),(11,2),(11,3),(11,4),(11,5),(11,7),(11,8),(11,9),(11,10),
#           (2,1),(3,1),(4,1),(5,1),(7,1),(8,1),(9,1),(10,1),
#           (2,11),(3,11),(4,11),(5,11),(6,11),(8,11),(9,11),(10,11),(11,11)]
# Hand-written task vectors; only the first one is active, the rest are
# disabled alternatives.
# NOTE(review): the lengths are hard-coded for 40 goals (1 + 39); they must be
# kept in step with n_goals if T_states changes.
tasks = [
         [1]+[0]*39,
        # [0]*4+[1]*36,
        # [1,1,0,0]+[1]*10+[0]*9+[1]*4+[0]*4+[1]*5+[0]*4,
        # [1,0,1,0]+[1]*5+[0]*5+[1]*5+[0]*4+[1]*8+[0]*9,
        # [1,0,0,0]+[1]*5+[0]*5+[0]*9+[1]*4+[0]*4+[0]*9,
        ]
# Compose a value function for each selected task from the loaded
# values_best and render it.
for i in [0]:
    task = tasks[i]
    exp = task_exp(tasks_best, task)
    Q = exp_value(values_best, exp)
    fig = env.render( P=EQ_P(Q), V = EQ_V(Q))
    # NOTE(review): the output name is fixed at 12 and does not depend on i,
    # so selecting several tasks would overwrite the same file.
    fig.savefig("plots/values/"+str(12)+".pdf", bbox_inches='tight')


##############################################################################

import random

# Settings for the second experiment; these rebind the module-level names
# used by the first one. evaluateQ/evaluateEQ read goal_reward, step_reward
# and gamma from these globals at the time they are called.
goal_reward=1
step_reward=0
gamma=0.9
# Passed to the GridWorld used for learning; the evaluation calls further
# down pass slip_prob=0.1 explicitly instead.
slip_prob=0
maxiter=200
epsilon=1
alpha=1

def evaluateQ(goals,Q,slip_prob=0):
    """Run one episode with the policy ``Q_P(Q)`` and return its discounted return.

    A fresh GridWorld is built from ``goals`` and ``slip_prob`` together with
    the module-level ``goal_reward`` and ``step_reward``. The episode stops
    when the environment reports done or after 100 steps, whichever comes
    first. The reward of step ``t`` (counted from 0) is weighted by
    ``gamma ** t`` using the module-level ``gamma``.
    """
    world = GridWorld(goals=goals, goal_reward=goal_reward,
                      step_reward=step_reward, slip_prob=slip_prob)
    policy = Q_P(Q)
    state = world.reset()
    G = 0
    for t in range(100):
        state, reward, done, _ = world.step(policy[state])
        G += reward * (gamma ** t)
        if done:
            break
    return G

def evaluateEQ(goals,EQ,slip_prob=0):
    """Run one episode with the policy ``EQ_P(EQ)`` and return its discounted return.

    A fresh GridWorld is built from ``goals`` and ``slip_prob`` together with
    the module-level ``goal_reward`` and ``step_reward``. The episode stops
    when the environment reports done or after 100 steps, whichever comes
    first. The reward of step ``t`` (counted from 0) is weighted by
    ``gamma ** t`` using the module-level ``gamma``.
    """
    world = GridWorld(goals=goals, goal_reward=goal_reward,
                      step_reward=step_reward, slip_prob=slip_prob)
    policy = EQ_P(EQ)
    state = world.reset()
    G = 0
    for t in range(100):
        state, reward, done, _ = world.step(policy[state])
        G += reward * (gamma ** t)
        if done:
            break
    return G

def sample_random(n=None):
    """Draw a random task vector of length ``n_goals``.

    Args:
        n: Upper bound on the number of goals. When omitted (``None``) every
            entry is drawn independently with ``np.random.randint(2)``.
            Otherwise the number of goals is drawn uniformly from 1..n
            (inclusive) and that many distinct goal indices are set to 1.

    Returns:
        A list of ``n_goals`` entries: ints in the ``n is None`` case, numpy
        floats (0.0 / 1.0) otherwise.

    Raises:
        ValueError: if ``n`` is below 1, or if the drawn count exceeds
            ``n_goals`` (raised by ``random.randint`` / ``random.sample``).
    """
    # Identity test: `n == None` would go through n's __eq__, which is
    # element-wise (and ambiguous in an `if`) for array-like arguments.
    if n is None:
        return [np.random.randint(2) for _ in range(n_goals)]
    task = np.zeros(n_goals)
    chosen = random.sample(range(n_goals), random.randint(1, n))
    task[chosen] = 1
    return list(task)

# Result tables: one row per evaluation episode, one column per task.
num_runs = 1000
num_tasks = 50
data_optimal = np.zeros((num_runs, num_tasks))
data_composed = np.zeros((num_runs, num_tasks))
for i in range(num_tasks):
    print(i)
    # Random task with at least one goal.
    task = sample_random(n=n_goals)
    goals = [goal for goal, flag in zip(T_states_, task) if flag == 1]
    env = GridWorld(goals=goals, goal_reward=goal_reward,
                    step_reward=step_reward, slip_prob=slip_prob)
    # Value function learned directly on this task.
    Q, stats1 = Q_learning(env, gamma=gamma, epsilon=epsilon,
                           alpha=alpha, maxiter=maxiter)

    print('*')
    # Value function composed from the loaded basis values, without learning.
    exp = task_exp(tasks_best, task)
    EQ = exp_value(values_best, exp)
    print('**')

    # The two evaluations stay interleaved so each episode pair consumes
    # random numbers in the same order as before.
    for j in range(num_runs):
        data_optimal[j, i] = evaluateQ(goals, Q, slip_prob=0.1)
        data_composed[j, i] = evaluateEQ(goals, EQ, slip_prob=0.1)

dd.io.save('data/exp2.h5', (data_optimal,data_composed))

 

def MAX_(values):
    """Return the element-wise maximum over several Q tables.

    The result is a defaultdict whose missing entries default to a zero
    vector of length ``env.action_space.n`` (``env`` is read from module
    scope when an entry is first created). For every table in ``values`` and
    every state present in either the running result or that table, the
    stored vector becomes the element-wise maximum of both.

    NOTE: each table is indexed with states it may not contain, so a table
    has to supply a value for missing keys (a defaultdict does, and gains
    that entry as a side effect); a plain dict would raise KeyError.
    """
    combined = defaultdict(lambda: np.zeros(env.action_space.n))
    for table in values:
        states = set(combined) | set(table)
        for state in states:
            combined[state] = np.max([combined[state], table[state]], axis=0)
    return combined

# Every goal is handed to GridWorld as a [position, position] pair.
T_states_ = [[pos, pos] for pos in T_states]

# Environment and learner settings for the third experiment.
goal_reward = 1
step_reward = 0
gamma = 0.9
slip_prob = 0.1
maxiter = 500
epsilon = 0.1
alpha = 0.1

# One row per run, one column per task; entries hold the learner's stats['T'].
num_runs = 25
num_tasks = 50
data = np.zeros((num_runs, num_tasks))
for i in range(num_runs):
    print(i)
    values = []
    for j in range(num_tasks):
        # print(j)
        task = sample_random()
        goals = [goal for goal, flag in zip(T_states_, task) if flag == 1]
        # `env` must be rebound before MAX_ runs: MAX_'s default entries
        # read env.action_space.n from module scope.
        env = GridWorld(goals=goals, goal_reward=goal_reward,
                        step_reward=step_reward, slip_prob=slip_prob)

        # Disabled alternative initialisation:
        # delta = 0.05
        # pmin = 1/(2**n_goals)
        # if j > (np.log(delta)/np.log(1-pmin)):
        #     Q_init=MAX_(values)
        # else:
        #     Q_init=defaultdict(lambda: np.zeros(env.action_space.n)+(goal_reward/(1-gamma)))

        # Start each task from the element-wise maximum of everything
        # learned so far in this run.
        Q_init = MAX_(values)
        Q, stats1 = Q_learning(env, Q_init=Q_init, gamma=gamma,
                               epsilon=epsilon, alpha=alpha, maxiter=maxiter)
        data[i, j] = stats1['T']
        values.append(Q)
    # Written after every run, so the file always holds the runs finished so far.
    dd.io.save('data/exp3.h5', data)