import sys
sys.path.append('.')
from adv_imitation import DiscountALAgent
from envs.CliffWalking.CliffWalking import DisCliffWalking
from envs.bandit.bandit_env import DisBandit
import numpy as np
from utils.flags import FLAGS
from utils.Logger import logger
from utils.utils import sample_dataset_from_distribution, get_optimal_policy, estimate_dis_occupancy_measure_from_data,get_next_state
from utils.envs.env_utils import set_init_state_dis
from typing import List
import os
import yaml
import cvxpy as cp
import random
EPS = 1e-8

class DisOfflineIQLearn(object):# infinite horizon discounted
    """Tabular offline IQ-Learn style learner (infinite horizon, discounted).

    The learner keeps a Q-table of shape ``[dim_state, dim_action]`` and the
    soft-max policy derived from it. Three training routines are provided:

    * ``train_with_solver``: solves the concave objective with cvxpy.
    * ``train_with_gradient``: clipped gradient ascent using the initial state
      distribution for the value term.
    * ``train_with_gradient_paper``: clipped gradient ascent using empirical
      state / next-state frequencies from ``dataset`` for the value term.

    After any of them returns, ``self.q_function`` holds the learned Q-table
    and ``self.policy`` the row-wise soft-max of it.
    """

    # Step size used by the gradient trainers. It is deliberately huge: combined
    # with the clipping below, every update effectively saturates the Q-values.
    _STEP_SIZE = 1e8
    # Q-values are clipped to [-_Q_CLIP, _Q_CLIP] after every gradient step.
    _Q_CLIP = 100000
    # Box constraint on the Q variable handed to the convex solver.
    _SOLVER_Q_BOUND = 10000

    def __init__(self, dim_state:int, dim_action:int,gamma:float) -> None:
        """Create a learner with a uniformly random Q-table in [0, 1).

        Args:
            dim_state: number of discrete states.
            dim_action: number of discrete actions.
            gamma: discount factor.
        """
        self.dim_state = dim_state
        self.dim_action = dim_action
        self.gamma = gamma
        # initialize the policy and q_function
        self.q_function = np.random.random(size=[self.dim_state,self.dim_action])
        self.policy = self._softmax(self.q_function)

    @staticmethod
    def _softmax(q:np.ndarray) -> np.ndarray:
        """Return the row-wise soft-max of ``q``, shifted by the row max for stability."""
        exp_q = np.exp(q - np.max(q, axis=1, keepdims=True))
        return exp_q / np.sum(exp_q, axis=1, keepdims=True)

    def _next_state_mass(self, expert_occupancy_measure:np.ndarray, transition_prob:np.ndarray) -> np.ndarray:
        """Return, per next state, sum_{s,a} occupancy[s, a] * transition_prob[s, a, s_next]."""
        weighted = expert_occupancy_measure.reshape((self.dim_state, self.dim_action, 1)) * transition_prob
        return np.sum(weighted, axis=(0, 1))

    def _state_frequencies(self, states:np.ndarray) -> np.ndarray:
        """Return the empirical distribution of the integer state indices in ``states``."""
        counts = np.bincount(states, minlength=self.dim_state)
        return counts / np.sum(counts)

    def _ascent_step(self, grad:np.ndarray) -> None:
        """Apply one clipped gradient-ascent step to ``self.q_function``."""
        self.q_function = np.clip(
            self.q_function + self._STEP_SIZE * grad,
            a_min=-self._Q_CLIP,
            a_max=self._Q_CLIP,
        )

    def train_with_solver(self,expert_occupancy_measure:np.ndarray,dataset: np.ndarray,transition_prob:np.ndarray,init_state_dis:np.ndarray): # directly use the solver
        """Fit the Q-table by solving the objective with cvxpy (ECOS).

        With ``V(s) = log_sum_exp_a Q(s, a)`` the maximised objective is::

            sum_{s,a} occ[s,a] * (Q[s,a] - gamma * sum_{s'} P[s,a,s'] * V[s'])
                - (1 - gamma) * init_state_dis . V

        Args:
            expert_occupancy_measure: array of shape ``[dim_state, dim_action]``.
            dataset: unused by this routine; kept for a uniform signature.
            transition_prob: array of shape ``[dim_state, dim_action, dim_state]``
                where the last axis indexes the next state.
            init_state_dis: array of length ``dim_state``.

        Raises:
            RuntimeError: if the solver returns no solution for Q.
        """
        q_variable = cp.Variable((self.dim_state,self.dim_action))
        v_variable = cp.log_sum_exp(q_variable,axis=1)

        # sum_{s,a} occ[s,a] * P[s,a,s'] collapses the expected next-state value
        # into a single non-negative weight per next state, so the whole term is
        # one inner product instead of a Python loop over states.
        next_state_mass = self._next_state_mass(expert_occupancy_measure, transition_prob)
        expert_term = cp.sum(cp.multiply(expert_occupancy_measure, q_variable))
        expert_term -= (self.gamma * next_state_mass) @ v_variable
        initial_term = ((1 - self.gamma) * init_state_dis) @ v_variable

        objective = cp.Maximize(expert_term - initial_term)
        constraints = [q_variable <= self._SOLVER_Q_BOUND, q_variable >= -self._SOLVER_Q_BOUND]
        problem = cp.Problem(objective,constraints)
        problem.solve(solver="ECOS")

        q = q_variable.value
        if q is None:
            # Fail loudly instead of crashing later inside numpy on ``None``.
            raise RuntimeError(
                'cvxpy did not return a solution (status: {}).'.format(problem.status)
            )
        self.q_function = q
        # get a stable policy
        self.policy = self._softmax(q)

    def train_with_gradient(self,expert_occupancy_measure:np.ndarray,dataset: np.ndarray,transition_prob:np.ndarray,init_state_dis:np.ndarray,max_iterations:int):
        """Fit the Q-table by clipped gradient ascent on the solver objective.

        Args:
            expert_occupancy_measure: array of shape ``[dim_state, dim_action]``.
            dataset: unused by this routine; kept for a uniform signature.
            transition_prob: array of shape ``[dim_state, dim_action, dim_state]``.
            init_state_dis: array of length ``dim_state``.
            max_iterations: number of ascent steps; ``0`` only refreshes the policy.
        """
        # Both coefficients are independent of Q, so compute them once.
        discounted_next_mass = self.gamma * self._next_state_mass(
            expert_occupancy_measure, transition_prob).reshape((-1, 1))
        initial_mass = (1 - self.gamma) * init_state_dis.reshape((-1, 1))

        for _ in range(max_iterations):
            pi = self._softmax(self.q_function)
            grad = expert_occupancy_measure - discounted_next_mass * pi - initial_mass * pi
            self._ascent_step(grad)

        # get a stable policy
        self.policy = self._softmax(self.q_function)

    def train_with_gradient_paper(self,expert_occupancy_measure:np.ndarray,dataset: np.ndarray,transition_prob:np.ndarray,init_state_dis:np.ndarray,max_iterations:int):
        """Fit the Q-table by clipped gradient ascent with a data-based value term.

        Instead of the initial state distribution, the value term uses the
        empirical frequencies of column 0 and column 1 of ``dataset``
        (presumably state and next state -- confirm against the caller).

        Args:
            expert_occupancy_measure: array of shape ``[dim_state, dim_action]``.
            dataset: integer array with at least two columns of state indices.
            transition_prob: array of shape ``[dim_state, dim_action, dim_state]``.
            init_state_dis: unused by this routine; kept for a uniform signature.
            max_iterations: number of ascent steps; ``0`` only refreshes the policy.

        Raises:
            ValueError: if ``dataset`` is empty (frequencies would be NaN).
        """
        if len(dataset) == 0:
            raise ValueError('dataset must contain at least one transition.')

        dist_s = self._state_frequencies(dataset[:, 0]).reshape((-1, 1))
        discounted_dist_s_next = self.gamma * self._state_frequencies(dataset[:, 1]).reshape((-1, 1))
        discounted_next_mass = self.gamma * self._next_state_mass(
            expert_occupancy_measure, transition_prob).reshape((-1, 1))

        for _ in range(max_iterations):
            pi = self._softmax(self.q_function)
            grad = (expert_occupancy_measure - discounted_next_mass * pi
                    - (dist_s * pi - discounted_dist_s_next * pi))
            self._ascent_step(grad)

        # get a stable policy
        self.policy = self._softmax(self.q_function)
    

class DisOnlineIQLearn(object):# infinite horizon discounted
    """Tabular online IQ-Learn style learner (infinite horizon, discounted).

    Keeps a Q-table of shape ``[dim_state, dim_action]`` and its soft-max
    policy. Unlike the offline variant, every training iteration draws fresh
    transitions from the current policy and mixes them with expert data when
    estimating the state / next-state frequencies.
    """

    # Step size used by the gradient trainer. It is deliberately huge: combined
    # with the clipping below, every update effectively saturates the Q-values.
    _STEP_SIZE = 1e8
    # Q-values are clipped to [-_Q_CLIP, _Q_CLIP] after every gradient step.
    _Q_CLIP = 100000
    # Number of policy samples drawn, and of expert rows used, per iteration.
    _BATCH_SIZE = 1000

    def __init__(self, dim_state:int, dim_action:int,gamma:float) -> None:
        """Create a learner with a uniformly random Q-table in [0, 1).

        Args:
            dim_state: number of discrete states.
            dim_action: number of discrete actions.
            gamma: discount factor.
        """
        self.dim_state = dim_state
        self.dim_action = dim_action
        self.gamma = gamma
        # initialize the policy and q_function
        self.q_function = np.random.random(size=[self.dim_state,self.dim_action])
        self.policy = self._softmax(self.q_function)

    @staticmethod
    def _softmax(q:np.ndarray) -> np.ndarray:
        """Return the row-wise soft-max of ``q``, shifted by the row max for stability."""
        exp_q = np.exp(q - np.max(q, axis=1, keepdims=True))
        return exp_q / np.sum(exp_q, axis=1, keepdims=True)

    def _state_frequencies(self, states:np.ndarray) -> np.ndarray:
        """Return the empirical distribution of the integer state indices in ``states``."""
        counts = np.bincount(states, minlength=self.dim_state)
        return counts / np.sum(counts)

    def train_with_gradient_paper(self,expert_occupancy_measure:np.ndarray,expert_dataset: np.ndarray,transition_prob:np.ndarray,init_state_dis:np.ndarray,max_iterations:int):
        """Fit the Q-table by clipped gradient ascent with on-policy sampling.

        Each iteration samples ``_BATCH_SIZE`` entries from the current soft-max
        policy via ``sample_dataset_from_distribution`` and ``get_next_state``,
        concatenates them with the first ``_BATCH_SIZE`` rows of
        ``expert_dataset``, and uses the empirical frequencies of column 0 and
        column 1 of the result in the gradient.

        Args:
            expert_occupancy_measure: array of shape ``[dim_state, dim_action]``.
            expert_dataset: integer array whose columns 0 and 1 hold state
                indices (presumably state and next state -- confirm against
                the caller); must be column-compatible with the output of
                ``get_next_state``.
            transition_prob: array of shape ``[dim_state, dim_action, dim_state]``.
            init_state_dis: distribution the policy rollouts start from.
            max_iterations: number of ascent steps; ``0`` only refreshes the policy.
        """
        # Neither of these depends on Q, so compute them once up front.
        weighted = expert_occupancy_measure.reshape((self.dim_state, self.dim_action, 1)) * transition_prob
        discounted_next_mass = self.gamma * np.sum(weighted, axis=(0, 1)).reshape((-1, 1))
        expert_batch = expert_dataset[:self._BATCH_SIZE]

        for _ in range(max_iterations):
            pi = self._softmax(self.q_function)

            sampled, _unused = sample_dataset_from_distribution(init_state_dis, pi, self._BATCH_SIZE)
            policy_dataset = get_next_state(transition_prob, sampled)
            dataset = np.concatenate([expert_batch, policy_dataset], axis=0)

            dist_s = self._state_frequencies(dataset[:, 0]).reshape((-1, 1))
            dist_s_next = self._state_frequencies(dataset[:, 1]).reshape((-1, 1))

            grad = (expert_occupancy_measure - discounted_next_mass * pi
                    - (dist_s * pi - self.gamma * dist_s_next * pi))

            self.q_function = np.clip(
                self.q_function + self._STEP_SIZE * grad,
                a_min=-self._Q_CLIP,
                a_max=self._Q_CLIP,
            )

        # get a stable policy
        self.policy = self._softmax(self.q_function)
    
#  def _train_with_gradient_corrected(self,expert_occupancy_measure:np.ndarray,init_state_dis,max_iterations:int,h:int,transition_prob:np.ndarray):
        
    #     for i in range(max_iterations):
    #         q = self.q_function[:,:,h]
    #         stableQ = q-np.max(q,axis=1,keepdims=True) 
    #         expQ = np.exp(stableQ)
    #         pi = expQ/np.sum(expQ,axis=1,keepdims=True)
            
    #         grad = np.zeros_like(self.q_function)
    #         if init_state_dis is None:
    #             grad[:,:,h] = expert_occupancy_measure - expert_occupancy_measure*pi
    #         else:
    #             grad[:,:,h] = expert_occupancy_measure - init_state_dis.reshape((-1,1))*pi
    #         for t in range(self.max_episode-1,-1,-1):
    #             next_q = np.zeros_like(self.q_function[:,:,t]) if t==self.max_episode-1 else self.q_function[:,:,t+1]
    #             next_v = np.zeros(shape=(self.dim_state,)) if t==self.max_episode-1 else np.log(np.sum(np.exp(next_q),axis=-1))
                
    #             g = self.q_function[:,:,t]-transition_prob@next_v
    #             if t!=0:
    #                 v = np.log(np.sum(np.exp(self.q_function[:,:,t]),axis=-1))
                    
    #                 g += -np.sum(np.expand_dims(self.q_function[:,:,t-1],axis=2)*transition_prob,axis=(0,1)).reshape((-1,1))*pi+np.sum(np.expand_dims(transition_prob@v,axis=2)*transition_prob,axis=(0,1)).reshape((-1,1))*pi
    #             grad[:,:,t] -= 2*self.lamda*g
            
    #         step_size=10
    #         self.q_function = self.q_function + step_size*grad

    # def train_with_gradient_corrected(self,expert_occupancy_measure:np.ndarray,max_iterations:int,init_state:np.ndarray,transition_prob:np.ndarray):
    #     from tqdm import tqdm 
    #     for h in tqdm(range(self.max_episode-1,0,-1)):
    #         self._train_with_gradient_corrected(expert_occupancy_measure[:,:,h],None,max_iterations,h,transition_prob)
    #         # print(self.q_function[8,:,0])
            
            
            
    #     init_state_dis = init_state
    #     self._train_with_gradient_corrected(expert_occupancy_measure[:,:,0],init_state_dis,max_iterations,0,transition_prob)
    #     for h in range(self.max_episode-1,-1,-1):
    #         q = self.q_function[:,:,h]
    #         tmpq = (q - np.max(q, axis=1, keepdims=True))/0.01
    #         expq = np.exp(tmpq)
    #         policy = expq / np.sum(expq, axis=1, keepdims=True)
    #         self.policy[:,:,h] = policy
   

def train_dis_iql():
    """Run the discounted IQ-Learn experiment over a sweep of effective horizons.

    Reads the environment configuration from ``FLAGS``, samples an expert
    dataset, and for every effective horizon in ``range(200, 10000, 100)``
    trains a ``DisOfflineIQLearn`` agent with the convex solver and evaluates
    it against the expert. The expert values, learned-policy values and their
    differences are written as YAML files into ``FLAGS.log_dir``.

    Raises:
        ValueError: if ``FLAGS.env.id`` is neither 'CliffWalking' nor 'Bandit'.
    """
    FLAGS.set_seed()
    FLAGS.freeze()

    ns = FLAGS.env.ns
    na = FLAGS.env.na
    num_data = FLAGS.dis_num_data_dict[FLAGS.env.id]
    init_state_dis = set_init_state_dis(env_id=FLAGS.env.id, num_data=num_data, ns=ns)

    # Sample the optimal action index. If the optimal action is determined,
    # we can determine the expert policy and dataset.
    optimal_action = np.random.randint(na)
    expert_policy = get_optimal_policy(ns, na, optimal_action, FLAGS.env.id)
    dataset, uniques_states = sample_dataset_from_distribution(init_state_dis, expert_policy, num_data)
    if FLAGS.env.id == 'CliffWalking':
        # Only consumed by the gradient-based trainer (see the commented call
        # in the loop below). Kept so that switching trainers stays a one-line
        # change and the random stream seen by later code is unchanged.
        tmpenv = DisCliffWalking(ns, na, 0.99, init_state_dis, optimal_action)
        transition_prob = tmpenv.transition_probability
        next_state_dataset = get_next_state(transition_prob,dataset)

    expert_occupancy_measure = estimate_dis_occupancy_measure_from_data(ns, na, dataset)
    # Probability mass of the initial-state distribution not covered by the
    # states indexed by ``uniques_states``.
    sampled_mass = init_state_dis[uniques_states]
    missing_mass = 1.0 - float(np.sum(sampled_mass))
    logger.info('Missing mass: %.8f', missing_mass)

    value_errors = dict()
    expert_values = dict()
    iql_values = dict()

    for effective_horizon in range(200, 10000, 100):
        gamma = 1.0 - float(1.0 / effective_horizon)
        if FLAGS.env.id == 'CliffWalking':
            env = DisCliffWalking(ns, na, gamma, init_state_dis, optimal_action)
        elif FLAGS.env.id == 'Bandit':
            env = DisBandit(ns, na, gamma, init_state_dis, optimal_action)
        else:
            raise ValueError('Do not support the env {}.'.format(FLAGS.env.id))

        expert_value = env.policy_evaluation(expert_policy)
        expert_values[effective_horizon] = [expert_value]

        dis_iql_agent = DisOfflineIQLearn(ns,na,gamma)
        logger.info('Begin training in effective horizon = %d', effective_horizon)

        transition_prob = env.transition_probability
        # Gradient-based alternative:
        # dis_iql_agent.train_with_gradient_paper(expert_occupancy_measure,next_state_dataset,transition_prob,init_state_dis,1000)
        dis_iql_agent.train_with_solver(expert_occupancy_measure,dataset,transition_prob,init_state_dis)

        iql_policy = dis_iql_agent.policy
        iql_value = env.policy_evaluation(iql_policy)
        value_error = expert_value - iql_value
        iql_values[effective_horizon] = [iql_value]
        value_errors[effective_horizon] = [value_error]
        logger.info('Effective horizon: %d, Discounted factor: %.6f Expert value: %.4f, IQL value: %.4f,'
                    'Value error: %.4f,', effective_horizon, gamma, expert_value, iql_value, value_error)

    # Write each result table through a context manager so the file handle is
    # flushed and closed deterministically.
    outputs = (
        ('expert_evaluate.yml', expert_values),
        ('iql_evaluate.yml', iql_values),
        ('value_error_evaluate.yml', value_errors),
    )
    for file_name, values in outputs:
        save_path = os.path.join(FLAGS.log_dir, file_name)
        with open(save_path, 'w') as stream:
            yaml.dump(values, stream, default_flow_style=False)


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    train_dis_iql()

