import sys
sys.path.append('.')
from adv_imitation import DiscountALAgent
from envs.CliffWalking.CliffWalking import DisCliffWalking
from envs.bandit.bandit_env import DisBandit
import numpy as np
from utils.flags import FLAGS
from utils.Logger import logger
from utils.utils import sample_dataset_from_distribution, get_optimal_policy, estimate_dis_occupancy_measure_from_data,get_next_state
from utils.envs.env_utils import set_init_state_dis
from typing import List
import os
import yaml
import cvxpy as cp
import random
EPS = 1e-8

class DisAILIQLearn(DiscountALAgent):
    """Tabular agent for the infinite-horizon discounted setting.

    The agent keeps a Q table, a policy table and a reward table (all of
    shape ``[dim_state, dim_action]``) and offers several update steps, most
    of which are posed as convex programs and solved with cvxpy.

    NOTE(review): ``transition_prob`` arguments are indexed as
    ``transition_prob[:, :, s]`` and combined element-wise with a
    ``[dim_state, dim_action]`` occupancy measure, so they are assumed to be
    laid out as (state, action, next_state) -- confirm against the envs.
    """

    def __init__(self, dim_state: int, dim_action: int, gamma: float, max_num_iterations: int) -> None:
        """Randomly initialise the tables and then run the parent initialiser.

        Args:
            dim_state: number of discrete states.
            dim_action: number of discrete actions.
            gamma: discount factor.
            max_num_iterations: divisor used by ``train_reward`` when
                accumulating the running occupancy measure.
        """
        self.dim_state = dim_state
        self.dim_action = dim_action
        self.n_state = dim_state
        self.n_action = dim_action

        # The order of the three random draws is kept stable so that seeded
        # runs stay reproducible.
        self.q_function = np.random.random(size=[self.dim_state, self.dim_action])
        unnormalized = np.random.random(size=[self.dim_state, self.dim_action])
        # Each row is normalised to sum to one.
        self.policy = unnormalized / np.sum(unnormalized, axis=1, keepdims=True)
        self._policy = self.policy
        self.reward = np.random.random(size=[self.dim_state, self.dim_action])

        self._average_occupancy_measure = np.zeros(shape=(self.n_state, self.n_action), dtype=np.float32)
        self.max_num_iterations = max_num_iterations
        self.gamma = gamma
        # NOTE(review): the parent initialiser runs last; whether it replaces
        # any attribute assigned above depends on DiscountALAgent.
        super().__init__(self.n_state, self.n_action, gamma, max_num_iterations)

    def _expected_next_state_term(self, expert_occupancy_measure, transition_prob, state_values):
        """Return ``sum_s sum(expert_occ * transition_prob[:, :, s] * state_values[s])``.

        ``state_values`` is a length-``dim_state`` cvxpy expression; the
        result is a scalar cvxpy expression (not yet scaled by ``gamma``).
        """
        total = 0
        for s in range(self.dim_state):
            total += cp.sum(cp.multiply(expert_occupancy_measure,
                                        transition_prob[:, :, s] * state_values[s]))
        return total

    def train_reward(self, expert_occ, policy_occ):
        """Update the running occupancy measure and the indicator reward.

        Adds ``policy_occ / max_num_iterations`` to the accumulated occupancy
        measure, then sets the reward to 1 where the expert occupancy is
        strictly larger than the policy occupancy and to 0 elsewhere.
        """
        self._average_occupancy_measure += policy_occ / self.max_num_iterations

        expert_dominates = expert_occ > policy_occ
        # In-place writes: ``_reward_function`` below aliases the same array.
        self.reward[expert_dominates] = 1
        self.reward[~expert_dominates] = 0
        self._reward_function = self.reward

    def train_occ_policy(self):
        """Solve for an occupancy measure under the current reward and derive a policy.

        Maximises ``<reward, occ> + sum(entr(occ)) + sum(entr(state marginal))``
        over non-negative ``occ`` summing to one. The resulting policy is
        uniform in states whose occupancy row is constant and one-hot on the
        row's argmax otherwise.
        """
        occ = cp.Variable((self.dim_state, self.dim_action), nonneg=True)
        state_marginal = cp.sum(occ, axis=1)

        reward_term = cp.sum(cp.multiply(self.reward, occ))
        entropy_term = cp.sum(cp.entr(occ)) + cp.sum(cp.entr(state_marginal))
        objective = cp.Maximize(reward_term + entropy_term)
        constraints = [occ >= 0, cp.sum(occ) == 1]
        cp.Problem(objective, constraints).solve()

        occupancy = occ.value
        self.policy = np.zeros((self.dim_state, self.dim_action))
        for i in range(self.dim_state):
            if np.all(occupancy[i] == occupancy[i][0]):
                # No action is preferred in this state.
                self.policy[i] = (1.0 / self.dim_action) * np.ones(self.dim_action)
            else:
                self.policy[i][np.argmax(occupancy[i])] = 1.0

    def get_policy_from_occ(self):
        """Return the policy obtained by row-normalising the accumulated occupancy measure.

        States whose accumulated mass is below ``EPS`` get a uniform
        distribution over actions.

        Returns:
            ``float32`` array of shape ``[n_state, n_action]`` whose rows sum to one.
        """
        normalizer = np.sum(self._average_occupancy_measure, axis=1)
        policy = np.zeros(shape=[self.n_state, self.n_action], dtype=np.float32)
        for state in range(self.n_state):
            if normalizer[state] < EPS:
                policy[state] = (1.0 / self.n_action) * np.ones(shape=self.n_action, dtype=np.float32)
            else:
                policy[state] = self._average_occupancy_measure[state] / normalizer[state]

        assert np.allclose(np.sum(policy, axis=1), np.ones(self.n_state))
        return policy

    def train_q_function(self, expert_occupancy_measure, policy, init_state_dis, transition_prob):
        """Fit the Q table for a fixed ``policy`` and store it in ``self.q_function``.

        Maximises, over ``q`` constrained to ``[-10, 10]``::

            <expert_occ, q>
            - gamma   * sum_s sum(expert_occ * P[:, :, s] * (pi . q)[s])
            - (1 - gamma) * <init_state_dis, pi . q>
        """
        q = cp.Variable((self.dim_state, self.dim_action))
        expert_term = cp.sum(cp.multiply(expert_occupancy_measure, q))

        # Per-state value of q under the given policy.
        piq = cp.sum(cp.multiply(policy, q), axis=1)

        next_state_term = self.gamma * self._expected_next_state_term(
            expert_occupancy_measure, transition_prob, piq)
        init_term = (1 - self.gamma) * cp.sum(cp.multiply(init_state_dis, piq))

        target = cp.Maximize(expert_term - next_state_term - init_term)
        constraints = [q <= 10, q >= -10]
        cp.Problem(target, constraints).solve()
        self.q_function = q.value.copy()

    def train_q_wo_pi(self, expert_occupancy_measure, init_state_dis, transition_prob):
        """Fit a Q table without an explicit policy, using log-sum-exp state values.

        Same objective shape as ``train_q_function`` with ``pi . q`` replaced
        by ``log_sum_exp(q, axis=1)`` and ``q`` constrained to ``[-100, 100]``.

        Returns:
            The solver's value for ``q`` (``self.q_function`` is not modified).
        """
        gamma = self.gamma
        q = cp.Variable((self.dim_state, self.dim_action))
        expert_term = cp.sum(cp.multiply(expert_occupancy_measure, q))
        soft_values = cp.log_sum_exp(q, axis=1)

        next_state_term = gamma * self._expected_next_state_term(
            expert_occupancy_measure, transition_prob, soft_values)
        init_term = (1 - gamma) * cp.sum(cp.multiply(init_state_dis, soft_values))

        target = cp.Maximize(expert_term - next_state_term - init_term)
        constraints = [q <= 100, q >= -100]
        cp.Problem(target, constraints).solve()
        return q.value

    def train_iql_policy(self, expert_occupancy_measure, transition_prob, init_state_dis):
        """Fit a policy for the current ``self.q_function`` and store it in ``self._policy``.

        Maximises, over row-stochastic ``policy``, the sum of four terms: the
        policy's Q value and the policy's entropy, each evaluated both at the
        next states reached from the expert occupancy (scaled by ``gamma``)
        and at the initial states (scaled by ``1 - gamma``).

        NOTE(review): if the solver does not return a solution,
        ``policy.value`` is ``None`` and is stored as-is.
        """
        policy = cp.Variable((self.dim_state, self.dim_action))
        q_function = self.q_function

        piq = cp.sum(cp.multiply(policy, q_function), axis=1)
        entpi = cp.sum(cp.entr(policy), axis=1)

        # Q value at next states.
        term1 = self.gamma * self._expected_next_state_term(
            expert_occupancy_measure, transition_prob, piq)
        # Entropy at next states. This must use ``entpi`` (not ``piq``) so it
        # pairs with term1 the same way term4 pairs with term3.
        term2 = self.gamma * self._expected_next_state_term(
            expert_occupancy_measure, transition_prob, entpi)
        # Q value and entropy at initial states.
        term3 = (1 - self.gamma) * cp.sum(cp.multiply(init_state_dis, piq))
        term4 = (1 - self.gamma) * cp.sum(cp.multiply(init_state_dis, entpi))

        target = cp.Minimize(-term1 - term2 - term3 - term4)
        constraints = [policy >= 0, cp.sum(policy, axis=1) == np.ones(shape=(self.dim_state,))]
        problem = cp.Problem(target, constraints)
        problem.solve(solver="ECOS")
        self._policy = policy.value
        
        
        

def train_dis_iql():
    """Train ``DisAILIQLearn`` over a sweep of effective horizons and save the results.

    For each effective horizon ``H`` in ``range(200, 10000, 100)`` the
    discount factor is ``1 - 1/H``. A fresh environment and agent are built,
    the agent is trained for a fixed number of iterations, and the value of
    the policy recovered from the accumulated occupancy measure is compared
    with the expert's value. Expert values, learned values and their
    differences are written as YAML files under ``FLAGS.log_dir``.

    Raises:
        ValueError: if ``FLAGS.env.id`` is neither ``'CliffWalking'`` nor ``'Bandit'``.
    """
    num_train_iterations = 400
    log_every = 200

    FLAGS.set_seed()
    FLAGS.freeze()

    ns = FLAGS.env.ns
    na = FLAGS.env.na
    num_data = FLAGS.dis_num_data_dict[FLAGS.env.id]
    init_state_dis = set_init_state_dis(env_id=FLAGS.env.id, num_data=num_data, ns=ns)

    # Sample the optimal action index. If the optimal action is determined,
    # we can determine the expert policy and dataset.
    optimal_action = np.random.randint(na)
    expert_policy = get_optimal_policy(ns, na, optimal_action, FLAGS.env.id)
    dataset, uniques_states = sample_dataset_from_distribution(init_state_dis, expert_policy, num_data)
    if FLAGS.env.id == 'CliffWalking':
        tmpenv = DisCliffWalking(ns, na, 0.99, init_state_dis, optimal_action)
        # NOTE(review): the result is not used below; the call is kept in case
        # it draws from the global RNG and therefore affects seeded runs.
        get_next_state(tmpenv.transition_probability, dataset)

    expert_occupancy_measure = estimate_dis_occupancy_measure_from_data(ns, na, dataset)
    print("seen states:", uniques_states)
    sampled_mass = init_state_dis[uniques_states]
    missing_mass = 1.0 - float(np.sum(sampled_mass))
    logger.info('Missing mass: %.8f', missing_mass)

    value_errors = dict()
    expert_values = dict()
    iql_values = dict()

    for effective_horizon in range(200, 10000, 100):
        gamma = 1.0 - float(1.0 / effective_horizon)
        if FLAGS.env.id == 'CliffWalking':
            env = DisCliffWalking(ns, na, gamma, init_state_dis, optimal_action)
        elif FLAGS.env.id == 'Bandit':
            env = DisBandit(ns, na, gamma, init_state_dis, optimal_action)
        else:
            raise ValueError('Do not support the env {}.'.format(FLAGS.env.id))

        expert_value = env.policy_evaluation(expert_policy)
        expert_values[effective_horizon] = [expert_value]

        max_num_iterations = FLAGS.GAIL.max_num_iterations
        dis_iql_agent = DisAILIQLearn(ns, na, gamma, max_num_iterations)
        logger.info('Begin training in effective horizon = %d', effective_horizon)

        transition_prob = env.transition_probability
        for t in range(num_train_iterations):
            policy = dis_iql_agent.get_policy
            policy_occupancy_measure = env.calculate_occupancy_measure_v2(policy)
            dis_iql_agent.train_reward(expert_occupancy_measure, policy_occupancy_measure)
            if t % FLAGS.GAIL.train_policy_freq == 0:
                dis_iql_agent.train_soft_policy_step(transition_prob)
            if t % log_every == 0:
                policy = dis_iql_agent.get_policy
                policy_value = env.policy_evaluation(policy)
                logger.info('The policy value at iterations %d: %.4f', t, policy_value)
                # The distance uses the occupancy measure computed at the top
                # of this iteration, i.e. before the policy step above.
                occupancy_measure_loss = float(np.sum(np.abs(expert_occupancy_measure - policy_occupancy_measure)))
                logger.info('Iteration %d:, Occupancy measure distance: %.3f', t, occupancy_measure_loss)

        # Evaluate the policy induced by the accumulated occupancy measure.
        # No early exit here: the sweep must cover every horizon and reach the
        # result dumps below.
        iql_policy = dis_iql_agent.get_policy_from_occ()
        iql_value = env.policy_evaluation(iql_policy)
        value_error = expert_value - iql_value
        iql_values[effective_horizon] = [iql_value]
        value_errors[effective_horizon] = [value_error]
        logger.info('Effective horizon: %d, Discounted factor: %.6f Expert value: %.4f, IQL value: %.4f,'
                    'Value error: %.4f,', effective_horizon, gamma, expert_value, iql_value, value_error)

    results = (
        ('expert_evaluate.yml', expert_values),
        ('iql_evaluate.yml', iql_values),
        ('value_error_evaluate.yml', value_errors),
    )
    for file_name, values in results:
        save_path = os.path.join(FLAGS.log_dir, file_name)
        # ``with`` guarantees the file is flushed and closed after the dump.
        with open(save_path, 'w') as stream:
            yaml.dump(values, stream, default_flow_style=False)


# Script entry point: run the horizon sweep only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    train_dis_iql()

