import numpy as np
import random
import os
import hashlib
import itertools
from collections import defaultdict
import torch.multiprocessing as mp

from sympy.logic import SOPform
from sympy import symbols as Symbols
import boolean

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable

import gym
from envs.env import *

# Run on the GPU when torch can see one; otherwise everything stays on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
# Legacy tensor-type aliases: ``tensor.type(FloatTensor)`` casts the dtype and
# places the tensor on the selected device in a single call.
if use_cuda:
    FloatTensor, LongTensor, ByteTensor = (
        torch.cuda.FloatTensor, torch.cuda.LongTensor, torch.cuda.ByteTensor)
else:
    FloatTensor, LongTensor, ByteTensor = (
        torch.FloatTensor, torch.LongTensor, torch.ByteTensor)


class LinearSchedule(object):
    """Linearly interpolate from ``initial_p`` to ``final_p``, then hold at ``final_p``.

    Parameters
    ----------
    schedule_timesteps: int
        Number of timesteps over which the value moves from ``initial_p`` to
        ``final_p``. A non-positive value means "already finished": the schedule
        returns ``final_p`` for every ``t``.
    final_p: float
        Value reached at (and kept after) ``schedule_timesteps``.
    initial_p: float
        Value at ``t == 0``.
    """

    def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
        self.schedule_timesteps = schedule_timesteps
        self.final_p = final_p
        self.initial_p = initial_p

    def __call__(self, t):
        """Return the scheduled value at timestep ``t``."""
        if self.schedule_timesteps <= 0:
            # Degenerate schedule: jump straight to the final value instead of
            # dividing by zero.
            fraction = 1.0
        else:
            fraction = min(float(t) / self.schedule_timesteps, 1.0)
        return self.initial_p + fraction * (self.final_p - self.initial_p)

def to_hash(x):
    """Return the hex MD5 digest of the raw bytes of numpy array ``x``.

    Used as a dictionary key for observations (e.g. terminal states as goals).
    ``ndarray.tobytes`` emits the same C-order bytes as the deprecated
    ``tostring`` (removed in NumPy 2.0), so existing hashes are unchanged.
    """
    return hashlib.md5(x.tobytes()).hexdigest()

class ReplayBuffer(object):
    """Fixed-capacity circular buffer of ``(obs, action, reward, next_obs, done)`` transitions.

    Storage is allocated lazily on the first :meth:`add`. Once full, new
    transitions overwrite the oldest ones.
    """

    def __init__(self, maxsize, batch_size=32, rmin=0, rmax=1, gamma=0.95):
        """Create the replay buffer.

        Parameters
        ----------
        maxsize: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        batch_size: int
            Number of transitions returned by :meth:`sample`.
        rmin: float
            Reward substituted (via ``rbarmin``) when a goal-conditioned sample
            terminates at a state other than the goal.
        rmax, gamma:
            Stored / accepted for compatibility; not used by the sampling code.
        """
        self._storage = []
        self.size = 0
        self._maxsize = maxsize
        self.batch_size = batch_size
        self._next_idx = 0
        self.rmin = rmin
        self.rmax = rmax
        # A discounted variant, rmin / (1 - gamma), was considered; plain rmin is used.
        self.rbarmin = rmin

    def add(self, obs_t, action, reward, obs_tp1, done):
        """Insert one transition, overwriting the oldest once the buffer is full."""
        data = (obs_t, action, reward, obs_tp1, done)
        if self.size == 0:
            # Allocate the ring on first use. Every slot initially refers to this
            # first transition; only the first ``self.size`` slots are ever sampled.
            self._storage = [data] * self._maxsize
            print("Replay buffer initialised")

        self._storage[self._next_idx] = data
        self._next_idx = (self._next_idx + 1) % self._maxsize
        self.size = min(self.size + 1, self._maxsize)

    def sample_transitions(self):
        """Return ``batch_size`` indices drawn uniformly (with replacement) from the filled slots.

        Raises ``ValueError`` (from numpy) if the buffer is still empty.
        """
        # Sample only slots that hold real data. Using len(self._storage) (always
        # maxsize after the first insert) would mostly return placeholder copies of
        # the first transition until the buffer fills.
        return np.random.randint(0, self.size, self.batch_size)

    def sample(self, goal=None, features=None, n_step=10, gamma=0.95):
        """Sample a batch of experiences.

        Parameters
        ----------
        goal: str or None
            If truthy, a goal hash (as produced by ``to_hash``): any sampled
            terminal transition whose next observation hashes to something else
            gets reward ``self.rbarmin``.
        features: dict or None
            If truthy, terminal rewards are replaced by
            ``features[goal][to_hash(next_obs)]``.
        n_step, gamma:
            Accepted for compatibility; currently unused (one-step transitions only).

        Returns
        -------
        obs_batch: np.array
            batch of observations
        act_batch: np.array
            batch of actions executed given obs_batch
        rew_batch: np.array
            rewards received (possibly relabelled as described above)
        next_obs_batch: np.array
            next set of observations seen after executing act_batch
        done_mask: np.array
            done_mask[i] is True if executing act_batch[i] ended the episode.
        """
        obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], []

        for idx in self.sample_transitions():
            obs_t, action, reward, obs_tp1, done = self._storage[idx]
            if goal:
                reward = self.rbarmin if (done and to_hash(obs_tp1) != goal) else reward
            if features and done:
                reward = features[goal][to_hash(obs_tp1)]

            # np.asarray avoids a copy when possible; np.array(copy=False) raises on NumPy 2.
            obses_t.append(np.asarray(obs_t))
            actions.append(np.asarray(action))
            rewards.append(reward)
            obses_tp1.append(np.asarray(obs_tp1))
            dones.append(done)
        return np.array(obses_t), np.array(actions), np.array(rewards), np.array(obses_tp1), np.array(dones)



class DQN(nn.Module):
    """Convolutional Q-network for channel-last image observations.

    ``env.observation_space.shape`` is read as ``(height, width, channels)`` and
    ``env.action_space.n`` sets the number of outputs. Three unpadded conv layers
    (kernels 8/4/3, strides 4/2/1, ReLU) feed a 512-unit hidden layer and a
    linear head with one value per action.
    """

    def __init__(self, env):
        super(DQN, self).__init__()
        self.action_space = env.action_space
        self.observation_space = env.observation_space

        # Height, width and channel count all come from the observation space, so
        # non-square and non-RGB inputs work. For an (l, l, 3) space the network is
        # identical to a square, 3-channel design.
        height, width, channels = self.observation_space.shape
        k1, s1 = (8, 4)
        k2, s2 = (4, 2)
        k3, s3 = (3, 1)
        c_out = 64
        self.image_conv = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=k1, stride=s1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=k2, stride=s2),
            nn.ReLU(),
            nn.Conv2d(64, c_out, kernel_size=k3, stride=s3),
            nn.ReLU()
        )

        def conv_out(size, k, s):
            # Output length of an unpadded convolution along one spatial axis.
            return (size - k) // s + 1

        out_h = conv_out(conv_out(conv_out(height, k1, s1), k2, s2), k3, s3)
        out_w = conv_out(conv_out(conv_out(width, k1, s1), k2, s2), k3, s3)
        self.embedding_size = out_h * out_w * c_out

        self.linear1 = nn.Linear(self.embedding_size, 512)
        self.head = nn.Linear(512, self.action_space.n)

    def forward(self, obs):
        """Map a ``(batch, height, width, channels)`` float tensor to ``(batch, n_actions)`` values."""
        obs = obs.permute(0, 3, 1, 2)  # channel-last -> channel-first for Conv2d
        x = self.image_conv(obs)
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.linear1(x))
        x = self.head(x)
        return x


class SFDQN(nn.Module):
    """Weighted sum of several Q / successor-feature networks.

    ``dqns`` is a list of callables on observations and ``weights`` a parallel
    list of scalar coefficients; ``weights[i]`` scales ``dqns[i]``. The networks
    are kept in a plain list, so they are not registered as submodules.
    """

    def __init__(self, dqns, weights=None):
        super(SFDQN, self).__init__()
        self.dqns = dqns
        self.weights = weights

    def forward(self, obs_goal):
        """Return ``sum_i weights[i] * dqns[i](obs_goal)``, keeping only the first batch row.

        NOTE(review): the trailing ``[0]`` drops every batch row but the first;
        the callers visible here pass one observation at a time. Confirm before
        using with larger batches.
        """
        weighted = []
        for idx, net in enumerate(self.dqns):
            weighted.append(self.weights[idx] * net(obs_goal))
        total = torch.stack(tuple(weighted), 0).sum(0)
        return total[0]
    
class ComposedDQN(nn.Module):
    """Boolean composition of Q-networks.

    ``compose`` selects the operator:

    * ``"or"``  - element-wise maximum over all networks in ``dqns``;
    * ``"and"`` - element-wise minimum over all networks in ``dqns``;
    * anything else ("not") - ``max_evf(x) + min_evf(x) - dqns[0](x)``.
    """

    def __init__(self, dqns, max_evf=None, min_evf=None, compose="or"):
        super(ComposedDQN, self).__init__()
        self.compose = compose
        self.max_evf = max_evf
        self.min_evf = min_evf
        self.dqns = dqns

    def forward(self, obs_goal):
        # Every member network is evaluated regardless of the operator.
        stacked = torch.stack(tuple(net(obs_goal) for net in self.dqns), 0)

        if self.compose == "or":
            return stacked.max(dim=0).values
        if self.compose == "and":
            return stacked.min(dim=0).values
        # Negation: reflect the first network between the upper and lower bounds.
        upper_plus_lower = self.max_evf(obs_goal) + self.min_evf(obs_goal)
        return upper_plus_lower - stacked[0]

def GPISF(models, weights):
    """For every goal key of ``models[0]``, build an ``SFDQN`` over all models.

    ``models`` and ``weights`` are parallel lists of dicts keyed by goal; the
    resulting network for ``goal`` sums ``weights[k][goal] * models[k][goal]``.
    """
    combined = {}
    for goal in models[0].keys():
        nets = [per_model[goal] for per_model in models]
        coeffs = [per_weight[goal] for per_weight in weights]
        combined[goal] = SFDQN(nets, coeffs)
    return combined

def OR(models):
    """Disjunction: per goal of ``models[0]``, the element-wise max over all models' networks."""
    composed = {}
    for goal in models[0].keys():
        members = [per_model[goal] for per_model in models]
        composed[goal] = ComposedDQN(members, compose="or")
    return composed

def AND(models):
    """Conjunction: per goal of ``models[0]``, the element-wise min over all models' networks."""
    composed = {}
    for goal in models[0].keys():
        members = [per_model[goal] for per_model in models]
        composed[goal] = ComposedDQN(members, compose="and")
    return composed

def NOT(model, max_evf, min_evf):
    """Negation: per goal of ``model``, ``max_evf[goal] + min_evf[goal] - model[goal]``.

    ``max_evf`` and ``min_evf`` must contain every goal key of ``model``.
    """
    negated = {}
    for goal in model.keys():
        negated[goal] = ComposedDQN([model[goal]],
                                    max_evf=max_evf[goal],
                                    min_evf=min_evf[goal],
                                    compose="not")
    return negated



def save(path, model):
    """Serialise ``model`` (any picklable object, e.g. a dict of state_dicts) to ``path`` with ``torch.save``."""
    torch.save(obj=model, f=path)
    
def load(path, env):
    """Rebuild ``{goal: DQN}`` from a ``{goal: state_dict}`` checkpoint at ``path``.

    Each state dict is loaded into a fresh ``DQN(env)``, which is moved to the
    GPU when ``use_cuda`` is set. The loaded dict is updated in place and returned.

    NOTE: ``torch.load`` unpickles its input; only load trusted checkpoints.
    """
    checkpoint = torch.load(path, map_location=device)
    for goal, state in list(checkpoint.items()):
        net = DQN(env)
        net.load_state_dict(state)
        if use_cuda:
            net.cuda()
        checkpoint[goal] = net
    return checkpoint

def select_goal(model, obs, training=False):
    """Pick a goal key of ``model`` for observation ``obs``.

    With ``training=True`` the goal is drawn uniformly at random. Otherwise it is
    the goal whose network attains the largest value (over all actions) at
    ``obs``. A random index is always drawn first, so numpy's RNG advances
    identically in both modes.
    """
    goals = list(model.keys())
    choice = np.random.randint(len(goals))
    if training:
        return goals[choice]

    batch = torch.from_numpy(obs).type(FloatTensor).unsqueeze(0)
    with torch.no_grad():
        per_goal = torch.stack([model[g](batch).squeeze(0) for g in goals], 0)
        best_value_per_goal = per_goal.max(1)[0]
        choice = best_value_per_goal.max(0)[1].item()
    return goals[choice]

def select_action(model, obs, goal=None):
    """Return the greedy action index for ``obs`` under ``model[goal]``.

    If ``goal`` is None it is chosen with ``select_goal`` (greedy over goals).
    The forward pass runs under ``torch.no_grad()`` because action selection
    never backpropagates; without it an autograd graph is built every step.
    """
    if goal is None:
        goal = select_goal(model, obs)
    obs_tensor = torch.from_numpy(obs).type(FloatTensor).unsqueeze(0)
    with torch.no_grad():
        q_values = model[goal](obs_tensor).squeeze(0)
    return q_values.max(0)[1].item()

def update_td_loss(args):
    """Take one optimiser step regressing a single goal's network.

    ``args`` is the tuple
    ``(goal, sample, edqn, target_edqn, sup_model, optimizer, gamma, maxiter)``:

    * ``sample(goal)`` returns numpy arrays ``(obs, actions, rewards, next_obs, dones)``
      (e.g. ``ReplayBuffer.sample``);
    * if ``sup_model`` is given, ``edqn``'s values for all actions are regressed
      onto ``sup_model``'s values at the sampled observations;
    * otherwise the taken action's value is regressed onto the one-step target
      ``r + gamma * (1 - done) * max_a target_edqn(s', a)``;
    * ``maxiter`` batches are drawn and their smooth-L1 losses summed into a
      single gradient step. Gradients are clipped element-wise to [-1, 1].

    Returns the summed loss as a float. Returns 0.0 without updating when
    ``maxiter < 1`` (previously this crashed calling ``backward`` on an int).
    """
    goal, sample, edqn, target_edqn, sup_model, optimizer, gamma, maxiter = args
    if maxiter < 1:
        return 0.0

    losses = []
    for _ in range(maxiter):
        obs_batch, act_batch, rew_batch, next_obs_batch, done_mask = sample(goal)
        obs_batch = torch.from_numpy(obs_batch).type(FloatTensor)
        act_batch = torch.from_numpy(act_batch).type(LongTensor)
        rew_batch = torch.from_numpy(rew_batch).type(FloatTensor)
        next_obs_batch = torch.from_numpy(next_obs_batch).type(FloatTensor)
        not_done_mask = torch.from_numpy(1 - done_mask).type(FloatTensor)

        if sup_model is not None:
            current_q_values = edqn(obs_batch).squeeze()
            target_q_values = sup_model(obs_batch).detach().squeeze()
        else:
            # squeeze(1), not squeeze(): keep the batch axis when batch_size == 1.
            current_q_values = edqn(obs_batch).gather(1, act_batch.unsqueeze(1)).squeeze(1)
            next_max_q = target_edqn(next_obs_batch).detach().max(1)[0]
            next_q_values = not_done_mask * next_max_q
            target_q_values = rew_batch + (gamma * next_q_values)

        losses.append(F.smooth_l1_loss(current_q_values, target_q_values))
    loss = sum(losses)

    optimizer.zero_grad()
    loss.backward()
    # Same clamp to [-1, 1] as before, but skips parameters whose grad is None.
    torch.nn.utils.clip_grad_value_(edqn.parameters(), 1)
    optimizer.step()

    return loss.item()

class Agent(object):
    """Goal-conditioned DQN learner with one network slot per goal.

    ``ngoals`` slots are pre-allocated at construction. Each slot holds an online
    ``DQN``, a target ``DQN`` initialised to the same weights, an Adam optimiser
    and a ``ReplayBuffer`` cloned from ``replay_buffer``'s settings. :meth:`add`
    assigns the next free slot to a new goal key. The per-goal dicts
    (``edqn``, ``target_edqn``, ``optimizer``, ``replay_buffer``) expose the
    assigned slots.
    """

    def __init__(self,
                 env,
                 path='',
                 goal=None,
                 ngoals=16,
                 gamma=0.95,
                 learning_rate=1e-4,
                 replay_buffer=None):

        self.env = env
        self.path = path
        self.ngoals = ngoals
        self.gamma = gamma
        self.learning_rate = learning_rate

        # goal -> slot index, and goal -> the components of that slot
        self.goals = {}
        self.edqn = {}
        self.target_edqn = {}
        self.optimizer = {}
        self.replay_buffer = {}
        # Pre-allocated slot pool, handed out in order by ``add``.
        # ``replay_buffer`` is only used as a template (capacity, batch size, rmin).
        self.dqns = []
        self.target_dqns = []
        self.optimizers = []
        self.replay_buffers = []
        for _ in range(ngoals):
            dqn = DQN(self.env)
            target_dqn = DQN(self.env)
            target_dqn.load_state_dict(dqn.state_dict())
            if use_cuda:
                dqn.cuda()
                target_dqn.cuda()
            self.dqns.append(dqn)
            self.target_dqns.append(target_dqn)
            self.optimizers.append(optim.Adam(dqn.parameters(), lr=self.learning_rate))
            self.replay_buffers.append(ReplayBuffer(replay_buffer._maxsize,
                                                   replay_buffer.batch_size, rmin=replay_buffer.rmin))

        self.add(goal)

    def add(self, goal):
        """Assign the next free slot to ``goal``; no-op if ``goal`` is already known.

        Raises IndexError when all ``ngoals`` slots are in use.
        """
        if goal in self.edqn:
            return
        slot = len(self.edqn)
        if slot >= len(self.dqns):
            raise IndexError("Agent has no free goal slot (ngoals=%d) for goal %r"
                             % (len(self.dqns), goal))
        self.goals[goal] = slot
        self.edqn[goal] = self.dqns[slot]
        self.target_edqn[goal] = self.target_dqns[slot]
        self.optimizer[goal] = self.optimizers[slot]
        self.replay_buffer[goal] = self.replay_buffers[slot]

    def update_td_loss(self, supervised=None):
        """Run one update (module-level ``update_td_loss``) for every known goal.

        ``supervised``, when given, is ``(task, task_, model_, max_evf[, min_evf])``.
        For goals where ``task[goal] == task_[goal]`` the network is regressed onto
        ``model_[goal]``. Otherwise it is regressed onto the negation
        ``max_evf[goal] + min_evf[goal] - model_[goal]``, which needs the optional
        ``min_evf`` element. Previously that case always crashed because ``NOT``
        was called without ``min_evf``.
        """
        maxiter = 1
        if supervised:
            task, task_, model_, max_evf, *extra = supervised
            min_evf = extra[0] if extra else None

        for goal in list(self.edqn.keys()):
            sup_model = None
            if supervised:
                if task[goal] == task_[goal]:
                    sup_model = model_[goal]
                elif min_evf is None:
                    raise TypeError("supervised negation for goal %r requires min_evf "
                                    "as the 5th element of `supervised`" % (goal,))
                else:
                    # Equivalent to NOT(model_, max_evf, min_evf)[goal] without
                    # building the composition for every goal.
                    sup_model = ComposedDQN([model_[goal]], max_evf=max_evf[goal],
                                            min_evf=min_evf[goal], compose="not")
            update_td_loss((goal, self.replay_buffer[goal].sample, self.edqn[goal],
                            self.target_edqn[goal], sup_model, self.optimizer[goal],
                            self.gamma, maxiter))

    def update_target_network(self):
        """Copy every online network's weights into its target network."""
        for goal in self.edqn.keys():
            self.target_edqn[goal].load_state_dict(self.edqn[goal].state_dict())

    def save(self):
        """Save ``{goal: state_dict}`` of the online networks to ``self.path``."""
        data = {goal: self.edqn[goal].state_dict() for goal in self.edqn.keys()}
        save(self.path, data)


class AgentSF(object):
    """Successor-feature learner with one network slot per feature dimension.

    Slot management mirrors ``Agent``: ``ngoals`` slots of (online DQN, target
    DQN, Adam optimiser, ReplayBuffer) are pre-allocated, and :meth:`add`
    assigns the next free slot to a dimension key. Only the TD update differs:
    each dimension bootstraps at the greedy action of the ``weights``-weighted
    sum of all dimensions' target networks.
    """

    def __init__(self,
                 env,
                 path='',
                 goal=None,
                 ngoals=16,
                 gamma=0.95,
                 learning_rate=1e-4,
                 replay_buffer=None):

        self.env = env
        self.path = path
        self.ngoals = ngoals
        self.gamma = gamma
        self.learning_rate = learning_rate

        # dimension -> slot index, and dimension -> the components of that slot
        self.goals = {}
        self.edqn = {}
        self.target_edqn = {}
        self.optimizer = {}
        self.replay_buffer = {}
        # Pre-allocated slot pool; ``replay_buffer`` is only used as a template.
        self.dqns = []
        self.target_dqns = []
        self.optimizers = []
        self.replay_buffers = []
        for _ in range(ngoals):
            dqn = DQN(self.env)
            target_dqn = DQN(self.env)
            target_dqn.load_state_dict(dqn.state_dict())
            if use_cuda:
                dqn.cuda()
                target_dqn.cuda()
            self.dqns.append(dqn)
            self.target_dqns.append(target_dqn)
            self.optimizers.append(optim.Adam(dqn.parameters(), lr=self.learning_rate))
            self.replay_buffers.append(ReplayBuffer(replay_buffer._maxsize,
                                                   replay_buffer.batch_size, rmin=replay_buffer.rmin))

        self.add(goal)

    def add(self, goal):
        """Assign the next free slot to ``goal``; no-op if already known.

        Raises IndexError when all ``ngoals`` slots are in use.
        """
        if goal in self.edqn:
            return
        slot = len(self.edqn)
        if slot >= len(self.dqns):
            raise IndexError("AgentSF has no free slot (ngoals=%d) for dimension %r"
                             % (len(self.dqns), goal))
        self.goals[goal] = slot
        self.edqn[goal] = self.dqns[slot]
        self.target_edqn[goal] = self.target_dqns[slot]
        self.optimizer[goal] = self.optimizers[slot]
        self.replay_buffer[goal] = self.replay_buffers[slot]

    def update_td_loss(self, sfeatures, features, weights):
        """Take one TD step for every dimension ``d``.

        ``features`` is forwarded to ``ReplayBuffer.sample``, which uses it for
        terminal rewards. ``weights`` maps each dimension key to a scalar weight;
        every key in ``weights`` must also be a key of ``target_edqn``.
        ``sfeatures`` is accepted for interface compatibility and is not used.

        The target for dimension ``d`` is
        ``r + gamma * (1 - done) * psi_d(s', a*)`` with
        ``a* = argmax_a sum_i weights[i] * psi_i(s', a)`` (target networks).
        Previously ``a*`` was computed but the replayed action was used instead.
        """
        for d in list(self.edqn.keys()):
            obs_batch, act_batch, rew_batch, next_obs_batch, done_mask = \
                self.replay_buffer[d].sample(d, features=features)
            obs_batch = torch.from_numpy(obs_batch).type(FloatTensor)
            act_batch = torch.from_numpy(act_batch).type(LongTensor)
            rew_batch = torch.from_numpy(rew_batch).type(FloatTensor)
            next_obs_batch = torch.from_numpy(next_obs_batch).type(FloatTensor)
            not_done_mask = torch.from_numpy(1 - done_mask).type(FloatTensor)

            with torch.no_grad():
                target_values = {i: self.target_edqn[i](next_obs_batch) for i in weights}
                # Weighted sum over dimensions -> task values at s', one row per sample.
                task_values = torch.stack([target_values[i] * weights[i] for i in weights], 0).sum(0)
                target_actions = task_values.max(1)[1]
                next_q = target_values[d].gather(1, target_actions.unsqueeze(1)).squeeze(1)

            current_q_values = self.edqn[d](obs_batch).gather(1, act_batch.unsqueeze(1)).squeeze(1)
            next_q_values = not_done_mask * next_q
            target_q_values = rew_batch + (self.gamma * next_q_values)

            loss = F.smooth_l1_loss(current_q_values, target_q_values)

            self.optimizer[d].zero_grad()
            loss.backward()
            # Element-wise clamp to [-1, 1]; skips parameters whose grad is None.
            torch.nn.utils.clip_grad_value_(self.edqn[d].parameters(), 1)
            self.optimizer[d].step()

    def update_target_network(self):
        """Copy every online network's weights into its target network."""
        for goal in self.edqn.keys():
            self.target_edqn[goal].load_state_dict(self.edqn[goal].state_dict())

    def save(self):
        """Save ``{dimension: state_dict}`` of the online networks to ``self.path``."""
        data = {goal: self.edqn[goal].state_dict() for goal in self.edqn.keys()}
        save(self.path, data)


def evaluate_sop(args):
    """Roll out a greedy policy and accumulate rewards and positive-reward steps.

    ``args`` is ``(env, model, task, learned, max_task, min_task, max_evf, min_evf,
    eval_type, num_episodes, max_episode_timesteps)`` with ``learned = (tasks, values)``.

    The policy is:

    * ``eval_type == "transfer"``: ``exp_value(values, max_evf, min_evf, exp)``;
    * ``eval_type == "continual"``: ``model``;
    * otherwise: the ``exp_value`` result if ``exp_task(...)`` equals ``task``,
      else ``model``.

    A goal is chosen greedily once per episode. Returns
    ``[total_return, total_successes]`` summed over all episodes, where each step
    with ``reward > 0`` counts as one success.
    """
    (env, model, task, learned, max_task, min_task,
     max_evf, min_evf, eval_type, num_episodes, max_episode_timesteps) = args

    tasks, values = learned
    exp = task_exp(tasks, task)
    task_ = exp_task(tasks, max_task, min_task, exp)
    composed = exp_value(values, max_evf, min_evf, exp)

    if eval_type == "transfer":
        policy = exp_value(values, max_evf, min_evf, exp)
    elif eval_type == "continual" or task != task_:
        policy = model
    else:
        policy = composed

    total_return, total_success = 0, 0
    for _episode in range(num_episodes):
        obs = env.reset()
        goal = select_goal(policy, obs)
        for _step in range(max_episode_timesteps):
            action = select_action(policy, obs, goal)
            obs, reward, done, _info = env.step(action)
            total_return += reward
            total_success += (reward > 0) + 0
            if done:
                break
    return [total_return, total_success]

def evaluate_sf(args):
    """Roll out the weighted successor-feature policy and accumulate rewards.

    ``args`` is ``(env, model, weights, learned, eval_type, num_episodes,
    max_episode_timesteps)`` with ``learned = (features, sfeatures)``. At every
    step, for each entry ``d`` in ``dims`` the score is
    ``sum_i sfeatures[d][i](obs) * weights[i]``. The action maximising the
    per-action max over entries is taken. In ``"continual"`` mode only ``model``
    itself is used.

    NOTE(review): outside continual mode ``dims`` are ``model``'s keys but are
    used to index ``sfeatures``; confirm the two share keys.

    Returns ``[total_return, total_successes]`` (steps with ``reward > 0``).
    """
    env, model, weights, learned, eval_type, num_episodes, max_episode_timesteps = args

    _features, sfeatures = learned
    dims = list(model.keys())
    if eval_type == "continual":
        sfeatures, dims = [model], [0]

    total_return, total_success = 0, 0
    for _episode in range(num_episodes):
        obs = env.reset()
        for _step in range(max_episode_timesteps):
            obs_tensor = torch.from_numpy(obs).type(FloatTensor).unsqueeze(0)
            with torch.no_grad():
                per_dim = []
                for d in dims:
                    terms = [sfeatures[d][i](obs_tensor).squeeze(0) * weights[i] for i in weights]
                    per_dim.append(torch.stack(terms, 0).sum(0))
                stacked = torch.stack(per_dim, 0)
                action = stacked.data.max(0)[0].max(0)[1].item()

            obs, reward, done, _info = env.step(action)
            total_return += reward
            total_success += (reward > 0) + 0
            if done:
                break
    return [total_return, total_success]

def evaluate_dqn(args):
    """Roll out ``model`` greedily and accumulate rewards and positive-reward steps.

    ``args`` is ``(env, model, task, num_episodes, max_episode_timesteps)``;
    ``task`` is accepted but not used. A goal is chosen greedily once per
    episode. Returns ``[total_return, total_successes]`` summed over episodes.
    """
    env, model, _task, num_episodes, max_episode_timesteps = args

    total_return, total_success = 0, 0
    for _episode in range(num_episodes):
        obs = env.reset()
        goal = select_goal(model, obs)
        for _step in range(max_episode_timesteps):
            action = select_action(model, obs, goal)
            obs, reward, done, _info = env.step(action)
            total_return += reward
            total_success += (reward > 0) + 0
            if done:
                break
    return [total_return, total_success]

def train_sop(env,
            learned=None,
            test=False,
            env_key="BabyAI-PickupDistCustom-v0",
            goal = None,
            num_dists=None,
            tile_size=8,
            path='./data/logs',
            agents_models=None,
            load_models=False,
            save_models=False,
            save_logs=False,
            max_episodes=int(1e6),
            learning_starts=int(1e3),
            replay_buffer_size=int(5e5),
            train_freq=4,
            target_update_freq=int(1e3),
            batch_size=32,
            gamma=0.99,
            learning_rate=1e-4,
            eps_initial=1,
            eps_final=0.05,
            eps_success=0.99,
            timesteps_success = 100,
            eval_interval = 100,
            eval_type = None,
            mean_episodes=100,
            eps_timesteps=int(1e5),
            print_freq=10):
    """Train a goal-conditioned agent on a new task while reusing learned tasks.

    ``learned = (tasks, values)`` are dicts keyed by task name. Entry ``'a'``
    gives ``min_task``/``min_evf``; ``max_task`` is the per-goal maximum over all
    tasks and ``max_evf = OR(values)``. Every episode the current task estimate
    ``task`` is passed through ``task_exp``/``exp_task``/``exp_value`` (defined
    outside this file; presumably a Boolean expression over the learned tasks).
    This yields a reconstructed task ``task_`` and composed value functions.

    Behaviour policy per ``eval_type``:

    * None: the composed values, OR-ed with the agent's own when ``task != task_``.
      The agent is trained only in that case.
    * ``"transfer"``: the agent's own networks, never trained.
    * ``"continual"``: the (possibly OR-ed) composition OR-ed again with the
      agent's own networks, always trained.

    ``learning_rate`` is also the step size of the per-terminal-state task estimate.
    Many parameters (e.g. ``train_freq``, ``test``, ``env_key``) are accepted but unused.

    Returns ``(task, model, stats)`` where ``model`` is ``agent.edqn``.
    """
    # Initialising epsilon schedule
    eps_schedule = LinearSchedule(eps_timesteps, eps_final, eps_initial)

    # Template buffer: Agent copies its capacity, batch size and rmin per goal slot.
    replay_buffer = ReplayBuffer(replay_buffer_size,
                                 batch_size, rmin=env.rmin)

    # Initialise agents/models; the first slot is keyed by the hash of an all-zero observation.
    vgoal = to_hash(env.reset()*0)
    agent = Agent(env, gamma=gamma, learning_rate=learning_rate, path=path, replay_buffer=replay_buffer, goal=vgoal)
    model = agent.edqn
    task = defaultdict(int)

    tasks, values = learned
    min_task, min_evf = tasks['a'], values['a']
    max_task = defaultdict(int)
    max_evf = OR(list(values.values()))
    for goal in tasks['a'].keys():
        agent.add(goal)
        max_task[goal] = max([tasks[i][goal] for i in tasks])

    # Train
    eval_returns = []
    episode_returns = []
    episode_successes = []
    steps = 0
    # Step count at the last target-network sync. Training happens once per
    # episode, so an exact `steps % target_update_freq == 0` test almost never fires.
    last_target_update = 0
    for episode in range(max_episodes):
        obs = env.reset()

        exp = task_exp(tasks, task)
        task_ = exp_task(tasks, max_task, min_task, exp)
        model_ = exp_value(values, max_evf, min_evf, exp)
        if task != task_:
            model_ = OR([model_, model])
        if eval_type == "transfer":
            model_ = model
        elif eval_type == "continual":
            model_ = OR([model_, model])

        goal = select_goal(model_, obs, training=False)

        episode_returns.append(0.0)
        episode_successes.append(0.0)
        done = False
        t = 0
        eps = []
        while not done and t < timesteps_success:
            # Collect experience (epsilon-greedy)
            if random.random() > eps_schedule(steps):
                action = select_action(model_, obs, goal)
            else:
                action = env.action_space.sample()

            new_obs, reward, done, _ = env.step(action)
            eps.append((obs, action, reward, new_obs, done))
            agent.replay_buffer[goal].add(obs, action, reward, new_obs, done)
            if steps < learning_starts:
                # Warm-up: share the transition with every goal buffer.
                # NOTE(review): `goal`'s own buffer receives it a second time here.
                for g in agent.replay_buffer.keys():
                    agent.replay_buffer[g].add(obs, action, reward, new_obs, done)

            episode_returns[-1] += (gamma**t)*reward
            episode_successes[-1] = (t<timesteps_success)*(reward>0)

            # Update goals
            if done:
                new_obs_hash = to_hash(new_obs)
                agent.add(new_obs_hash)
                # One step of size learning_rate toward 1 (rewarding) or 0, then
                # mapped to 1 if positive else 0.
                task[new_obs_hash] = 0 + ((task[new_obs_hash] + learning_rate*(int(reward>0) - task[new_obs_hash])) > 0)

                # Store the whole episode in the buffer of the goal actually reached.
                # Distinct names avoid clobbering the outer obs/new_obs/done.
                for e_obs, e_action, e_reward, e_new_obs, e_done in eps:
                    agent.replay_buffer[new_obs_hash].add(e_obs, e_action, e_reward, e_new_obs, e_done)

            t += 1
            steps += 1
            obs = new_obs

        if not eval_type:
            train_now = steps > learning_starts and task != task_
        else:
            train_now = eval_type == "continual" and steps > learning_starts
        if train_now:
            # Update agent
            agent.update_td_loss()
            if steps - last_target_update >= target_update_freq:
                agent.update_target_network()
                last_target_update = steps

        if steps > learning_starts and episode % eval_interval == 0:
            print("evaluating ...")
            args = [env,
                    model,
                    task,
                    learned, max_task, min_task, max_evf, min_evf,
                    eval_type,
                    mean_episodes,
                    timesteps_success]
            returns, successes = evaluate_sop(args)
            avg_return, success_rate = (returns/mean_episodes, successes/mean_episodes)
            eval_returns.append(avg_return)

        # Print training progress
        if print_freq is not None and episode % print_freq == 0:
            avg_return = round(np.mean(episode_returns[-mean_episodes:]), 2)
            success_rate = round(np.mean(episode_successes[-mean_episodes:]), 2)
            print("--------------------------------------------------------")
            print("steps {}".format(steps))
            print("episodes {}".format(episode))
            print("goals {}".format(len(list(agent.edqn.keys()))))
            print("exp {}".format(exp))
            print("task : {}".format(list(task.values())))
            print("task_: {}".format(list(task_.values())))
            print("task==task_: {}".format(task==task_))
            print("return      : {}, eval {}".format(avg_return, ([0]+eval_returns)[-1]))
            print("success rate: {}".format(success_rate))
            print("exploration {}".format( eps_schedule(steps)))
            print("--------------------------------------------------------")
            if save_logs:
                stats = {'episode_returns':episode_returns, 'episode_successes':episode_successes, 'eval_returns':eval_returns}
                torch.save(stats, save_logs)

    return task, model, {'episode_returns':episode_returns, 'episode_successes':episode_successes, 'eval_returns':eval_returns}

def train_sf(env,
            learned=None,
            test=False,
            env_key="BabyAI-PickupDistCustom-v0",
            goal = None,
            num_dists=None,
            tile_size=8,
            path='./data/logs',
            agents_models=None,
            load_models=False,
            save_models=False,
            save_logs=False,
            max_episodes=int(1e6),
            learning_starts=int(1e3),
            replay_buffer_size=int(5e5),
            train_freq=4,
            target_update_freq=int(1e3),
            batch_size=32,
            gamma=0.99,
            learning_rate=1e-4,
            eps_initial=1,
            eps_final=0.05,
            eps_success=0.99,
            timesteps_success = 100,
            eval_interval = 100,
            eval_type = None,
            mean_episodes=100,
            eps_timesteps=int(1e5),
            print_freq=10):
    """Fit task weights over learned successor features (and optionally new ones).

    ``learned = (features, sfeatures)``:

    * ``features[i][state_hash]`` is feature ``i`` of a terminal state;
    * ``sfeatures`` maps string indices to dicts of per-feature networks.

    Actions are chosen greedily from ``sum_i sfeatures[k][i](obs) * weights[i]``,
    maximised over entries ``k`` and actions. At every terminal state the weights
    take one step of size ``learning_rate`` on the squared error between the
    reward and ``sum_i features[i][s] * weights[i]``.

    In ``"continual"`` mode the agent's own networks are added once to
    ``sfeatures``, which mutates the caller's dict, and are trained with
    ``AgentSF.update_td_loss``.

    Returns ``(weights, model, stats)``. The first element is the learned task
    weights, mirroring ``train_sop``, which returns its task estimate there; the
    previous version referenced an undefined name ``task``.
    """
    # Initialising epsilon schedule
    eps_schedule = LinearSchedule(eps_timesteps, eps_final, eps_initial)

    # Template buffer: AgentSF copies its capacity, batch size and rmin per slot.
    replay_buffer = ReplayBuffer(replay_buffer_size,
                                 batch_size, rmin=env.rmin)

    # Initialise agents/models
    vgoal = '0'
    agent = AgentSF(env, gamma=gamma, learning_rate=learning_rate, path=path, replay_buffer=replay_buffer, goal=vgoal)
    model = agent.edqn
    weights = defaultdict(int)

    features, sfeatures = learned
    for i in features.keys():
        agent.add(i)
        weights.setdefault(i, 0)

    # The agent's keys are fixed from here on (nothing below calls agent.add).
    dims = list(model.keys())

    # Train
    eval_returns = []
    episode_returns = []
    episode_successes = []
    steps = 0
    # Step count at the last target-network sync (see train loop below).
    last_target_update = 0
    for episode in range(max_episodes):
        obs = env.reset()

        if eval_type == "continual" and episode == 0:
            # Register the agent's own networks once; re-inserting every episode
            # grew `sfeatures` without bound with duplicates of the same model.
            sfeatures[str(len(sfeatures))] = model

        episode_returns.append(0.0)
        episode_successes.append(0.0)
        done = False
        t = 0
        eps = []
        while not done and t < timesteps_success:
            # Collect experience (epsilon-greedy)
            if random.random() > eps_schedule(steps):
                obs_ = torch.from_numpy(obs).type(FloatTensor).unsqueeze(0)
                with torch.no_grad():
                    values = []
                    for key in sfeatures.keys():
                        value = [sfeatures[key][i](obs_).squeeze(0)*weights[i] for i in weights]
                        value = torch.stack(value,0).sum(0)
                        values.append(value)
                    values = torch.stack(values,0)
                    action = values.data.max(0)[0].max(0)[1].item()
                    slot = values.data.max(1)[0].max(0)[1].item()
            else:
                action = env.action_space.sample()
                slot = np.random.randint(len(model))
            # NOTE(review): in the greedy branch `slot` indexes `sfeatures`, but here it
            # indexes the agent's keys; confirm both are sized/ordered compatibly.
            d = dims[slot]

            new_obs, reward, done, _ = env.step(action)
            eps.append((obs, action, reward, new_obs, done))
            agent.replay_buffer[d].add(obs, action, reward, new_obs, done)
            if steps < learning_starts:
                # Warm-up: share the transition with every dimension's buffer.
                for buf_key in agent.replay_buffer.keys():
                    agent.replay_buffer[buf_key].add(obs, action, reward, new_obs, done)

            episode_returns[-1] += (gamma**t)*reward
            episode_successes[-1] = (t<timesteps_success)*(reward>0)

            # Update task weights at terminal states
            if done:
                new_obs_hash = to_hash(new_obs)

                fw_ = sum([features[i][new_obs_hash]*weights[i] for i in weights])
                for i in weights:
                    weights[i] = weights[i] + learning_rate*(reward - fw_)*features[i][new_obs_hash]

            t += 1
            steps += 1
            obs = new_obs

        if eval_type == "continual" and steps > learning_starts:
            # Update agent
            agent.update_td_loss(sfeatures, features, weights)
            # Updates run once per episode, so compare against the last sync instead
            # of testing `steps % target_update_freq == 0`, which rarely fires.
            if steps - last_target_update >= target_update_freq:
                agent.update_target_network()
                last_target_update = steps

        # if steps > learning_starts and episode % eval_interval == 0:
        #     print("evaluating ...")
        #     args = [env,
        #             model,
        #             weights,
        #             learned,
        #             eval_type,
        #             mean_episodes,
        #             timesteps_success]
        #     returns, successes = evaluate_sf(args)
        #     avg_return, success_rate = (returns/mean_episodes, successes/mean_episodes)
        #     eval_returns.append(avg_return)

        # Print training progress
        if print_freq is not None and episode % print_freq == 0:
            avg_return = round(np.mean(episode_returns[-mean_episodes:]), 2)
            success_rate = round(np.mean(episode_successes[-mean_episodes:]), 2)
            print("--------------------------------------------------------")
            print("steps {}".format(steps))
            print("episodes {}".format(episode))
            print("dim {}".format(len(list(agent.edqn.keys()))))
            print("weights : {}".format(list(weights.values())))
            print("return      : {}, eval {}".format(avg_return, ([0]+eval_returns)[-1]))
            print("success rate: {}".format(success_rate))
            print("exploration {}".format( eps_schedule(steps)))
            print("--------------------------------------------------------")
            if save_logs:
                stats = {'episode_returns':episode_returns, 'episode_successes':episode_successes, 'eval_returns':eval_returns}
                torch.save(stats, save_logs)

    return weights, model, {'episode_returns':episode_returns, 'episode_successes':episode_successes, 'eval_returns':eval_returns}

def train_dqn(env,
            test=False,
            env_key="BabyAI-PickupDistCustom-v0",
            goal = None,
            num_dists=None,
            tile_size=8,
            path='./data/logs',
            agents_models=None,
            load_models=False,
            save_models=False,
            save_logs=False,
            max_episodes=int(1e6),
            learning_starts=int(1e3),
            replay_buffer_size=int(5e5),
            train_freq=4,
            target_update_freq=int(1e3),
            batch_size=32,
            gamma=0.99,
            learning_rate=1e-4,
            eps_initial=1.0,
            eps_final=0.05,
            eps_success=0.99,
            timesteps_success = 100,
            eval_interval = 100,
            mean_episodes=100,
            eps_timesteps=int(1e5),
            print_freq=10):
    """Train an ``Agent`` on ``env`` with epsilon-greedy exploration.

    Each episode resets ``env``, picks a goal with ``select_goal`` and runs for
    at most ``timesteps_success`` steps. Actions are greedy (``select_action``)
    when ``test`` is True or with probability ``1 - eps_schedule(steps)``.
    Otherwise they are sampled from ``env.action_space``. Every transition is
    added to ``agent.replay_buffer[goal]``. Once more than ``learning_starts``
    environment steps have been taken, one ``agent.update_td_loss()`` call is
    made at the end of every episode. The target network is synced whenever at
    least ``target_update_freq`` steps have elapsed since the previous sync.

    Args:
        env: environment exposing ``reset``, ``step``, ``action_space`` and
            ``rmin`` (the latter is passed to the replay buffer).
        test: if True, always act greedily.
        goal: passed to ``Agent``. Note that the local name ``goal`` is then
            reassigned every episode by ``select_goal``.
        path, gamma, learning_rate: forwarded to ``Agent``.
        save_logs: falsy, or a path that ``torch.save`` writes the stats to at
            every progress print.
        max_episodes: number of training episodes.
        learning_starts: number of environment steps before updates begin.
        replay_buffer_size, batch_size: replay buffer configuration.
        target_update_freq: minimum number of steps between target syncs.
        eps_initial, eps_final, eps_timesteps: linear exploration schedule.
        timesteps_success: per-episode step limit.
        eval_interval: evaluate every this many episodes (after learning starts).
        mean_episodes: window for the printed averages; also passed to
            ``evaluate_dqn``.
        print_freq: print progress every this many episodes (``None`` disables).

    NOTE(review): ``env_key``, ``num_dists``, ``tile_size``, ``agents_models``,
    ``load_models``, ``save_models``, ``train_freq`` and ``eps_success`` are
    accepted for interface compatibility but are not used by this function.

    Returns:
        ``(task, model, stats)``. ``task`` is the ``defaultdict(int)`` created
        here and handed to ``evaluate_dqn``. ``model`` is ``agent.edqn``.
        ``stats`` is a dict with keys ``'episode_returns'``,
        ``'episode_successes'`` and ``'eval_returns'``.
    """
    # Initialising epsilon schedule
    eps_schedule = LinearSchedule(eps_timesteps, eps_final, eps_initial)

    # Initialise replay buffer
    replay_buffer = ReplayBuffer(replay_buffer_size,
                                        batch_size, rmin=env.rmin)

    # Initialise agents/models
    agent = Agent(env, gamma=gamma, learning_rate=learning_rate, path=path, replay_buffer=replay_buffer, goal=goal)
    model = agent.edqn
    task = defaultdict(int)

    def _recent_mean(history):
        # Average over the last ``mean_episodes`` episodes *before* the current
        # one (as the original code did). Fall back to the current episode so
        # the very first report does not average an empty slice (NaN + warning).
        window = history[-mean_episodes-1:-1] or history[-1:]
        return round(np.mean(window), 2)

    # Train
    eval_returns = []
    episode_returns = []
    episode_successes = []
    steps = 0
    last_target_update = 0
    for episode in range(max_episodes):
        obs = env.reset()
        goal = select_goal(model, obs, training=False)

        episode_returns.append(0.0)
        episode_successes.append(0.0)
        done = False
        t = 0
        while not done and t < timesteps_success:
            # Collect experience
            if test or random.random() > eps_schedule(steps):
                action = select_action(model, obs, goal)
            else:
                action = env.action_space.sample()

            new_obs, reward, done, _ = env.step(action)
            # NOTE(review): the buffer is indexed by goal here although a single
            # ReplayBuffer was passed to Agent; presumably Agent wraps it per
            # goal — confirm.
            agent.replay_buffer[goal].add(obs, action, reward, new_obs, done)

            episode_returns[-1] += reward
            episode_successes[-1] = (t<timesteps_success)*(reward>0)

            t += 1
            steps += 1
            obs = new_obs

        # Update agent (once per episode)
        if steps > learning_starts:
            agent.update_td_loss()
            # This check only runs at episode boundaries, where ``steps`` rarely
            # equals an exact multiple of ``target_update_freq``. A modulo test
            # would therefore sync the target network only sporadically.
            # Instead, sync once enough steps have elapsed since the last sync.
            if steps - last_target_update >= target_update_freq:
                agent.update_target_network()
                last_target_update = steps

        if steps > learning_starts and episode % eval_interval == 0:
            print("evaluating ...")
            args = [env,
                    model,
                    task,
                    mean_episodes,
                    timesteps_success]
            returns, _successes = evaluate_dqn(args)
            avg_return = returns/mean_episodes
            eval_returns.append(avg_return)

        # Print training progress (only for episodes that ended with ``done``)
        if done and print_freq is not None and episode % print_freq == 0:
            avg_return = _recent_mean(episode_returns)
            success_rate = _recent_mean(episode_successes)
            print("--------------------------------------------------------")
            print("steps {}".format(steps))
            print("episodes {}".format(episode))
            print("return      : {}, eval {}".format(avg_return, ([0]+eval_returns)[-1]))
            print("success rate: {}".format(success_rate))
            print("exploration {}".format( eps_schedule(steps)))
            print("--------------------------------------------------------")
            if save_logs:
                stats = {'episode_returns':episode_returns, 'episode_successes':episode_successes, 'eval_returns':eval_returns}
                torch.save(stats, save_logs)

    return task, model, {'episode_returns':episode_returns, 'episode_successes':episode_successes, 'eval_returns':eval_returns}


######################################
def get_bases(n_goals):
    """Return the binary code of every goal, one bit-plane per row.

    Goal ``k`` (0-indexed) is encoded as the binary representation of
    ``k + 1`` using ``n = ceil(log2(n_goals + 1))`` bits, most significant bit
    first. This way no goal receives the all-zero code. Column ``k`` of the
    result is goal ``k``'s code, and row ``i`` holds bit ``n - 1 - i`` of every
    goal. A 1 stands for True (rmax) and a 0 for False (rmin).

    Args:
        n_goals: number of goals (non-negative integer).

    Returns:
        Integer ``np.ndarray`` of shape ``(n, n_goals)``. For ``n_goals == 0``
        this is an empty ``(0, 0)`` array (the previous loop-based version
        raised ``IndexError`` there).

    Raises:
        ValueError: if ``n_goals`` is negative.
    """
    n_goals = int(n_goals)
    if n_goals < 0:
        raise ValueError("n_goals must be non-negative, got {}".format(n_goals))
    # ceil(log2(n_goals + 1)) == n_goals.bit_length() for n_goals >= 0,
    # without floating-point rounding concerns.
    n_bits = n_goals.bit_length()
    codes = np.arange(1, n_goals + 1)            # goal k -> code k + 1
    shifts = np.arange(n_bits - 1, -1, -1)       # MSB first
    return (codes[np.newaxis, :] >> shifts[:, np.newaxis]) & 1
    
# def task_exp(tasks, task, symbols=None): # POS
#     symbols = symbols if symbols else list(tasks.keys())
#     exp = ''
#     goals = tasks['a'].keys()
#     for goal in goals:
#         if not task[goal]:
#             exp += '('
#             for j in symbols:
#                 if j != 'a':
#                     if not tasks[j][goal]:
#                         exp += j
#                     else:
#                         exp += '~' + j
#                     exp += '|'
#             exp = exp[:-1]
#             exp += ')'
#             exp += '&'
#     exp = exp[:-1]
#     exp = exp if exp else ('a')
#     return exp

# def exp_task(tasks, exp): # POS
#     exp = exp.replace("(", "").replace(")", "")
    
#     task = defaultdict(int)
#     goals = tasks['a'].keys()
#     for goal in goals:
#         task[goal] = 1
#         for e1 in exp.split('&'):
#             b = 0
#             for e2 in e1.split('|'):
#                 j = e2[-1]
#                 if '~' not in e2:
#                     b += (tasks[j][goal]>0)
#                 else:
#                     b += 1-(tasks[j][goal]>0)
#                 if b:
#                     break
#             if not b:
#                 task[goal] = b
#                 break
#     return task


# def task_exp(tasks, task, symbols=None): # SOP
#     symbols = symbols if symbols else list(tasks.keys())
#     symbols.remove('a')
#     exp = ''
#     goals = tasks['a'].keys()
#     for goal in goals:
#         if task[goal]:
#             exp += '('
#             for j in symbols:
#                 if tasks[j][goal]:
#                     exp += j
#                 else:
#                     exp += '~' + j
#                 exp += '&'
#             exp = exp[:-1]
#             exp += ')'
#             exp += '|'
#     exp = exp[:-1]
#     exp = exp if exp else ('a')
#     return boolean.BooleanAlgebra().parse(exp).simplify()

def task_exp(tasks, task, symbols=None): # SOP
    """Build a simplified sum-of-products expression for ``task``.

    Every goal for which ``task[goal]`` is truthy contributes one minterm. The
    minterm holds, for each symbol ``j``, whether ``tasks[j][goal] > 0`` (the
    same truth test ``exp_task`` applies). The minterms are minimised with
    sympy's ``SOPform``.

    Args:
        tasks: mapping from symbol name to a per-goal mapping of values. The
            key ``'a'`` is used only to enumerate the goals and is never a
            variable of the expression.
        task: per-goal mapping. Goals with a truthy value become minterms.
        symbols: optional ordered list of symbol names to use as variables
            (defaults to the keys of ``tasks``). ``'a'`` is excluded
            automatically, and the caller's list is not modified.

    Returns:
        ``str`` of the ``SOPform`` result. When that result is falsy (e.g. no
        goal of ``task`` is set), ``'a'`` is returned instead.
    """
    # Build a new list instead of ``symbols.remove('a')`` so that a
    # caller-supplied list is never mutated (and a missing 'a' is not an error).
    symbols = [s for s in (symbols if symbols else tasks.keys()) if s != 'a']
    goals = tasks['a'].keys()
    dontcares = []
    # Normalise values to 0/1: SOPform expects binary minterms.
    minterms = [
        [int(tasks[j][goal] > 0) for j in symbols]
        for goal in goals
        if task[goal]
    ]
    exp = SOPform(Symbols(symbols), minterms, dontcares)
    exp = exp if exp else ('a')
    return str(exp)

def exp_task(tasks, max_task, min_task, exp): # SOP
    """Evaluate a sum-of-products expression string goal by goal.

    ``exp`` looks like ``"(b & ~c) | d"``: parentheses and spaces are
    ignored, ``|`` separates product terms, ``&`` separates literals and a
    leading ``~`` negates a literal.

    Args:
        tasks: mapping from symbol name to a per-goal mapping of values. A
            positive literal ``j`` contributes ``tasks[j][goal] > 0``. The set
            of goals is taken from ``tasks['a']``.
        max_task, min_task: per-goal mappings. A negated literal ``~j``
            contributes ``max_task[goal] + min_task[goal] - (tasks[j][goal] > 0)``.
        exp: the expression string.

    Returns:
        ``defaultdict(int)`` mapping every goal to 1 if the product of some
        term's literals is truthy for that goal, else 0.
    """
    exp = exp.replace("(", "").replace(")", "").replace(" ", "")
    # The term/literal structure does not depend on the goal; split it once.
    terms = [term.split('&') for term in exp.split('|')]

    task = defaultdict(int)
    goals = tasks['a'].keys()
    for goal in goals:
        task[goal] = 0
        for literals in terms:
            b = 1
            for literal in literals:
                # Strip the negation marker to get the whole symbol name, so
                # multi-character symbols (e.g. "b12") are looked up correctly
                # rather than by their last character only.
                j = literal.lstrip('~')
                if '~' not in literal:
                    b *= (tasks[j][goal]>0)
                else:
                    b *= (max_task[goal]+min_task[goal])-(tasks[j][goal]>0)
                if not b:
                    break
            if b:
                task[goal] = 1
                break
    return task

def exp_value(values, max_evf, min_evf, exp):
    """Compose a value model from a boolean expression over named symbols.

    ``exp`` is parsed and simplified with ``boolean.BooleanAlgebra``. The
    resulting tree is then rebuilt bottom-up:

    * a symbol node becomes ``values[str(symbol)]``;
    * an OR / AND node folds its operands left to right, calling the
      module-level ``OR`` / ``AND`` helper on a two-element list at each step;
    * any other node is treated as a negation: ``NOT`` is applied to its
      first operand together with ``max_evf`` and ``min_evf``.

    NOTE(review): constant nodes (TRUE/FALSE, e.g. after simplification of a
    tautology) have no operands and would presumably fail in the negation
    branch — confirm callers never pass such expressions.
    """
    tree = boolean.BooleanAlgebra().parse(exp).simplify()

    def build(node):
        kind = type(node)
        if kind == boolean.Symbol:
            return values[str(node)]
        if kind == boolean.OR or kind == boolean.AND:
            combine = OR if kind == boolean.OR else AND
            acc = build(node.args[0])
            for operand in node.args[1:]:
                acc = combine([acc, build(operand)])
            return acc
        # Remaining case (expected to be NOT): negate the single operand.
        return NOT(build(node.args[0]), max_evf, min_evf)

    return build(tree)

def sample_random(n_goals):
    """Return a random 0/1 indicator list of length ``n_goals``.

    The number of ones is drawn uniformly from ``0..n_goals`` with
    ``random.randint``, and their positions are chosen with ``random.sample``.
    The draw is therefore reproducible under ``random.seed``. The list is
    built from an ``int`` NumPy array, so its elements are NumPy integers.
    """
    how_many = random.randint(0, n_goals)
    chosen = random.sample(range(n_goals), how_many)
    indicator = np.zeros(n_goals, dtype=int)
    indicator[chosen] = 1
    return list(indicator)
#############################################################################################