import sys
sys.path.append('.')
from adv_imitation import DiscountALAgent
from envs.CliffWalking.CliffWalking import DisCliffWalking,CliffWalking
from envs.bandit.bandit_env import DisBandit,Bandit
import numpy as np
from utils.flags import FLAGS
from utils.Logger import logger
from utils.utils import sample_dataset_from_distribution, get_optimal_policy, estimate_occupancy_measure_from_data,get_next_state,sample_dataset
from utils.envs.env_utils import set_init_state_dis
from typing import List
import os
import yaml
import cvxpy as cp
# Small positive constant intended as a numerical floor (e.g. a lower clipping bound).
EPS = 1e-8

class OfflineIQLearn(object):
    """Tabular IQ-Learn agent with one Q-table and one policy per time step.

    ``q_function`` and ``policy`` both have shape
    ``[dim_state, dim_action, max_episode]``; the last axis indexes the time
    step. The policy at every step is the softmax of the Q-values over the
    action axis.
    """

    def __init__(self, dim_state: int, dim_action: int, max_episode: int) -> None:
        """Randomly initialise the Q-table and derive the matching softmax policy.

        Args:
            dim_state: Number of states.
            dim_action: Number of actions.
            max_episode: Number of time steps; used as the size of the last
                axis of the tables and as a ``range`` bound, so it must be an
                integer.
        """
        self.dim_state = dim_state
        self.dim_action = dim_action
        self.max_episode = max_episode
        self.q_function = np.random.random(size=[self.dim_state, self.dim_action, self.max_episode])
        self.policy = self._softmax(self.q_function)

    @staticmethod
    def _softmax(q: np.ndarray) -> np.ndarray:
        """Return the softmax of ``q`` over axis 1 (the action axis).

        The per-state maximum is subtracted first so ``np.exp`` cannot overflow.
        """
        shifted = q - np.max(q, axis=1, keepdims=True)
        exp_q = np.exp(shifted)
        return exp_q / np.sum(exp_q, axis=1, keepdims=True)

    def _solve_and_store(self, h: int, q_variable, target) -> None:
        """Maximise ``target`` over ``q_variable`` and store the result for step ``h``.

        Raises:
            RuntimeError: If the solver leaves ``q_variable`` without a value.
        """
        # Box constraints keep the problem bounded when the optimum is at infinity.
        constraints = [q_variable <= 1000, q_variable >= -1000]
        problem = cp.Problem(cp.Maximize(target), constraints)
        problem.solve(solver="ECOS")

        q = q_variable.value
        if q is None:
            raise RuntimeError(
                'Solver returned no solution for step {} (status: {}).'.format(h, problem.status)
            )
        self.q_function[:, :, h] = q
        self.policy[:, :, h] = self._softmax(q)

    def train_with_solver(self, expert_occupancy_measure: np.ndarray, init_state_dis: np.ndarray):
        """Fit every time step by solving its concave problem with cvxpy.

        Args:
            expert_occupancy_measure: Array indexed as ``[:, :, h]``; each
                slice is multiplied elementwise with a
                ``(dim_state, dim_action)`` variable.
            init_state_dis: Vector over states that weights the per-state
                log-sum-exp value at step 0.
        """
        # Steps max_episode-1 .. 1: maximise sum_{s,a} rho_h(s,a) * (Q(s,a) - V(s)).
        for h in range(self.max_episode - 1, 0, -1):
            q_variable = cp.Variable((self.dim_state, self.dim_action))
            v_variable = cp.log_sum_exp(q_variable, axis=1, keepdims=True)
            target = cp.sum(cp.multiply(expert_occupancy_measure[:, :, h], q_variable - v_variable))
            self._solve_and_store(h, q_variable, target)

        # Step 0: the value term is weighted by the initial state distribution.
        q_variable = cp.Variable((self.dim_state, self.dim_action))
        v_variable = cp.log_sum_exp(q_variable, axis=1, keepdims=True)
        target = cp.sum(cp.multiply(expert_occupancy_measure[:, :, 0], q_variable)) - init_state_dis @ v_variable
        self._solve_and_store(0, q_variable, target)

    def _train_with_gradient(self, expert_occupancy_measure: np.ndarray, init_state_dis, max_iterations: int,
                             init_q: np.ndarray, step_size: float = 100):
        """Run projected gradient ascent on a single time step.

        The ascent direction is ``rho(s, a) - w(s) * pi(a | s)`` where ``w`` is
        ``init_state_dis`` when given, and otherwise the state marginal of the
        occupancy measure, ``sum_a rho(s, a)``. The latter is the gradient of
        ``sum_{s,a} rho(s,a) * (Q(s,a) - logsumexp_a Q(s,a))``, the objective
        used by ``train_with_solver``.

        Args:
            expert_occupancy_measure: ``(dim_state, dim_action)`` array.
            init_state_dis: State weights for the value term, or ``None``.
            max_iterations: Number of ascent steps.
            init_q: Starting Q-table; it is not modified in place.
            step_size: Ascent step size. NOTE(review): the default is large;
                with a non-deterministic expert the iterates may oscillate, so
                a smaller value may be needed.

        Returns:
            Tuple ``(q, policy)`` with the final Q-table and its softmax.
        """
        if init_state_dis is None:
            state_weight = np.sum(expert_occupancy_measure, axis=1, keepdims=True)
        else:
            state_weight = init_state_dis.reshape((-1, 1))

        q = init_q
        for _ in range(max_iterations):
            grad = expert_occupancy_measure - state_weight * self._softmax(q)
            # Clipping keeps Q finite when the optimum lies at infinity.
            q = np.clip(q + step_size * grad, a_min=-10000, a_max=10000)
        return q, self._softmax(q)

    def train_with_gradient(self, expert_occupancy_measure: np.ndarray, init_state_dis: np.ndarray,
                            max_iterations: int, step_size: float = 100):
        """Fit every time step with gradient ascent, from the last step down to step 0.

        Steps ``max_episode-1 .. 1`` use the occupancy measure's own state
        marginal; step 0 uses ``init_state_dis``.
        """
        for h in range(self.max_episode - 1, 0, -1):
            self.q_function[:, :, h], self.policy[:, :, h] = self._train_with_gradient(
                expert_occupancy_measure[:, :, h], None, max_iterations, self.q_function[:, :, h], step_size)

        # Step 0 is weighted by the initial state distribution.
        self.q_function[:, :, 0], self.policy[:, :, 0] = self._train_with_gradient(
            expert_occupancy_measure[:, :, 0], init_state_dis, max_iterations, self.q_function[:, :, 0], step_size)

    def train_with_gradient_paper(self, expert_occupancy_measure: np.ndarray, dataset: np.ndarray,
                                  transition_prob: np.ndarray, init_state_dis: np.ndarray, max_iterations: int):
        """Discounted gradient update over the whole Q-table.

        ``transition_prob`` is of shape ``[dim_state, dim_action, dim_state]``,
        where the last dim stands for ``s_next``. Columns 0 and 1 of
        ``dataset`` are counted with ``np.bincount`` to form two empirical
        state distributions. ``init_state_dis`` is accepted but not used.

        NOTE(review): this method reads ``self.gamma``, which ``__init__`` does
        not set, and reshapes ``expert_occupancy_measure`` to
        ``(dim_state, dim_action, 1)``. It looks like a leftover from a
        stationary discounted variant; confirm before use.
        """
        dist_s = np.bincount(dataset[:, 0], minlength=self.dim_state)
        dist_s = dist_s / np.sum(dist_s)
        dist_s_next = np.bincount(dataset[:, 1], minlength=self.dim_state)
        dist_s_next = dist_s_next / np.sum(dist_s_next)

        # Fixed step size; the original computed a sqrt(2*S*A/T) schedule and
        # then overwrote it with this value, so only this value ever applied.
        step_size = 2e8
        for _ in range(max_iterations):
            q_old = self.q_function.copy()
            pi = self._softmax(q_old)
            occ_mul_tran = np.sum(
                expert_occupancy_measure.reshape((self.dim_state, self.dim_action, 1)) * transition_prob,
                axis=(0, 1)).reshape((-1, 1))

            grad = expert_occupancy_measure - self.gamma * occ_mul_tran * pi - (
                dist_s.reshape((-1, 1)) * pi - self.gamma * dist_s_next.reshape((-1, 1)) * pi)

            self.q_function = np.clip(q_old + step_size * grad, a_min=-100000, a_max=100000)

        self.policy = self._softmax(self.q_function)
    

class OnlineIQLearn(object):
    """Tabular IQ-Learn agent that estimates the initial state distribution from samples.

    ``q_function`` and ``policy`` both have shape
    ``[dim_state, dim_action, max_episode]``; the last axis indexes the time
    step. The policy at every step is the softmax of the Q-values over the
    action axis.
    """

    def __init__(self, dim_state: int, dim_action: int, max_episode: int) -> None:
        """Randomly initialise the Q-table and derive the matching softmax policy.

        Args:
            dim_state: Number of states.
            dim_action: Number of actions.
            max_episode: Number of time steps; used as the size of the last
                axis of the tables and as a ``range`` bound, so it must be an
                integer.
        """
        self.dim_state = dim_state
        self.dim_action = dim_action
        self.max_episode = max_episode
        self.q_function = np.random.random(size=[self.dim_state, self.dim_action, self.max_episode])
        self.policy = self._softmax(self.q_function)

    @staticmethod
    def _softmax(q: np.ndarray) -> np.ndarray:
        """Return the softmax of ``q`` over axis 1 (the action axis).

        The per-state maximum is subtracted first so ``np.exp`` cannot overflow.
        """
        shifted = q - np.max(q, axis=1, keepdims=True)
        exp_q = np.exp(shifted)
        return exp_q / np.sum(exp_q, axis=1, keepdims=True)

    def _train_with_gradient(self, expert_occupancy_measure: np.ndarray, init_state_dis, max_iterations: int,
                             init_q: np.ndarray, step_size: float = 100):
        """Run projected gradient ascent on a single time step.

        The ascent direction is ``rho(s, a) - w(s) * pi(a | s)`` where ``w`` is
        ``init_state_dis`` when given, and otherwise the state marginal of the
        occupancy measure, ``sum_a rho(s, a)`` (the gradient of
        ``sum_{s,a} rho(s,a) * (Q(s,a) - logsumexp_a Q(s,a))``).

        Args:
            expert_occupancy_measure: ``(dim_state, dim_action)`` array.
            init_state_dis: State weights for the value term, or ``None``.
            max_iterations: Number of ascent steps.
            init_q: Starting Q-table; it is not modified in place.
            step_size: Ascent step size. NOTE(review): the default is large;
                with a non-deterministic expert the iterates may oscillate, so
                a smaller value may be needed.

        Returns:
            Tuple ``(q, policy)`` with the final Q-table and its softmax.
        """
        if init_state_dis is None:
            state_weight = np.sum(expert_occupancy_measure, axis=1, keepdims=True)
        else:
            state_weight = init_state_dis.reshape((-1, 1))

        q = init_q
        for _ in range(max_iterations):
            grad = expert_occupancy_measure - state_weight * self._softmax(q)
            # Clipping keeps Q finite when the optimum lies at infinity.
            q = np.clip(q + step_size * grad, a_min=-10000, a_max=10000)
        return q, self._softmax(q)

    def train_with_gradient_paper(self, expert_occupancy_measure: np.ndarray, expert_dataset: np.ndarray,
                                  max_iterations: int, env, step_size: float = 100, num_samples: int = 2000):
        """Fit every time step with gradient ascent, estimating the step-0 state weights from data.

        Steps ``max_episode-1 .. 1`` are fitted first. Then ``num_samples``
        rows are sampled from ``env`` under the current policy and mixed with
        up to ``num_samples`` rows drawn without replacement from
        ``expert_dataset``. Rows whose column 2 equals 0 contribute their
        column-0 value to the empirical initial state distribution
        (presumably column 2 is the time step and column 0 the state; verify
        against ``sample_dataset``).

        Args:
            expert_occupancy_measure: Array indexed as ``[:, :, h]``.
            expert_dataset: 2-D array of expert samples.
            max_iterations: Ascent steps per time step.
            env: Environment handed to ``sample_dataset``.
            step_size: Ascent step size forwarded to ``_train_with_gradient``.
            num_samples: Rows taken from each of the two data sources.
        """
        for h in range(self.max_episode - 1, 0, -1):
            self.q_function[:, :, h], self.policy[:, :, h] = self._train_with_gradient(
                expert_occupancy_measure[:, :, h], None, max_iterations, self.q_function[:, :, h], step_size)

        # Step 0: build the state weights from a mix of policy and expert samples.
        policy_dataset = np.array(sample_dataset(env, self.policy, num_samples, is_deterministic=False))
        # Sampling without replacement cannot draw more rows than exist.
        num_expert = min(num_samples, len(expert_dataset))
        indices = np.random.choice(len(expert_dataset), size=num_expert, replace=False)
        dataset = np.concatenate([expert_dataset[indices], policy_dataset], axis=0)
        initial_states = dataset[dataset[:, 2] == 0][:, 0]
        init_state_dis = np.bincount(initial_states, minlength=self.dim_state)
        init_state_dis = init_state_dis / np.sum(init_state_dis)
        self.q_function[:, :, 0], self.policy[:, :, 0] = self._train_with_gradient(
            expert_occupancy_measure[:, :, 0], init_state_dis, max_iterations, self.q_function[:, :, 0], step_size)
        
    

def train_iql():
    """Train and evaluate the online IQ-Learn agent over a sweep of horizons.

    For every horizon in ``range(200, 3400, 200)`` this builds the environment
    selected by ``FLAGS.env.id``, samples an expert dataset, fits an
    ``OnlineIQLearn`` agent, and logs the expert value, the learned policy's
    value and their difference. The three result dicts (keyed by horizon) are
    written as YAML files into ``FLAGS.log_dir``.

    Raises:
        ValueError: If ``FLAGS.env.id`` is neither 'CliffWalking' nor 'Bandit'.
    """
    FLAGS.set_seed()
    FLAGS.freeze()
    ns = FLAGS.env.ns
    na = FLAGS.env.na
    num_data = FLAGS.num_data_dict[FLAGS.env.id]

    init_state_dis = set_init_state_dis(env_id=FLAGS.env.id, num_data=num_data, ns=ns)

    value_errors = dict()
    expert_values = dict()
    iql_values = dict()

    for max_episode_steps in range(200, 3400, 200):
        if FLAGS.env.id == 'CliffWalking':
            env = CliffWalking(ns, na, init_state_dis, max_episode_steps)
        elif FLAGS.env.id == 'Bandit':
            env = Bandit(ns, na, init_state_dis, max_episode_steps)
        else:
            raise ValueError('The env {} is not supported.'.format(FLAGS.env.id))

        expert_policy = env.get_optimal_policy()
        expert_value = env.policy_evaluation(expert_policy)
        dataset = sample_dataset(env, expert_policy, num_data, is_deterministic=False)

        expert_occupancy = estimate_occupancy_measure_from_data(ns, na, max_episode_steps, dataset)

        iql_agent = OnlineIQLearn(ns, na, max_episode_steps)
        iql_agent.train_with_gradient_paper(expert_occupancy, np.array(dataset), 200, env)

        iql_policy = iql_agent.policy
        iql_value = env.policy_evaluation(policy=iql_policy)
        value_error = expert_value - iql_value

        logger.info('Max episode steps: %d, Expert value: %.4f, IQL value: %.4f, Value error: %.4f',
                    max_episode_steps, expert_value, iql_value, value_error)

        expert_values[max_episode_steps] = [expert_value]
        iql_values[max_episode_steps] = [iql_value]
        value_errors[max_episode_steps] = [value_error]

    results = (
        ('expert_evaluate.yml', expert_values),
        ('iql_evaluate.yml', iql_values),
        ('value_error_evaluate.yml', value_errors),
    )
    for file_name, values in results:
        save_path = os.path.join(FLAGS.log_dir, file_name)
        # The context manager closes (and flushes) each file once it is written.
        with open(save_path, 'w') as out_file:
            yaml.dump(values, out_file, default_flow_style=False)
    
# Run the IQ-Learn horizon sweep only when this file is executed as a script.
if __name__ == '__main__':
    train_iql()

