import sys
sys.path.append('.')
from adv_imitation import DiscountALAgent
from envs.CliffWalking.CliffWalking import DisCliffWalking,CliffWalking
from envs.bandit.bandit_env import DisBandit,Bandit
import numpy as np
from utils.flags import FLAGS
from utils.Logger import logger
from utils.utils import sample_dataset_from_distribution, get_optimal_policy, estimate_occupancy_measure_from_data,get_next_state,sample_dataset,sample_dataset_per_traj
from utils.envs.env_utils import set_init_state_dis
from typing import List
import os
import yaml
import cvxpy as cp
import torch 
from tqdm import tqdm 
import argparse
# Lower bound applied when clipping adaptive step sizes (keeps them strictly positive).
EPS = 1e-8
# Upper bound applied when clipping adaptive step sizes.
INF=1e8
class OfflineIQLearn(object):
    """Tabular learner that fits a per-step Q table to an expert occupancy measure.

    The Q table and the softmax policy derived from it both have shape
    ``[dim_state, dim_action, max_episode]``; the last axis indexes the time
    step ``h``.

    Args:
        dim_state: Number of states.
        dim_action: Number of actions.
        max_episode: Number of time steps (size of the last table axis).
        gamma: Discount factor read by ``train_with_gradient_paper``. When
            omitted it defaults to ``1 - 1 / max_episode``, the same
            convention ``ValueDice`` in this file uses.
    """

    def __init__(self, dim_state:int, dim_action:int,max_episode:int, gamma=None) -> None:
        self.dim_state = dim_state
        self.dim_action = dim_action
        self.max_episode = max_episode
        if gamma is None:
            # Guard the division so an empty horizon does not raise.
            gamma = 1.0 - 1.0 / max_episode if max_episode > 0 else 0.0
        self.gamma = gamma
        # Random initial Q table; the initial policy is its softmax over actions.
        self.q_function = np.random.random(size=[self.dim_state, self.dim_action, self.max_episode])
        self.policy = self._softmax(self.q_function)

    @staticmethod
    def _softmax(q:np.ndarray) -> np.ndarray:
        """Numerically stable softmax over the action axis (axis 1)."""
        shifted = q - np.max(q, axis=1, keepdims=True)
        weights = np.exp(shifted)
        return weights / np.sum(weights, axis=1, keepdims=True)

    def _solve_for_q(self, q_variable, target) -> np.ndarray:
        """Maximize ``target`` over the box-constrained ``q_variable`` and return its value.

        Raises:
            RuntimeError: If the solver leaves ``q_variable`` without a value.
        """
        constraints = [q_variable <= 1000, q_variable >= -1000]
        problem = cp.Problem(cp.Maximize(target), constraints)
        problem.solve(solver="ECOS")
        q = q_variable.value
        if q is None:
            # A missing solution would otherwise be written into the Q table unnoticed.
            raise RuntimeError(
                "ECOS did not return a solution (status: {})".format(problem.status))
        return q

    def train_with_solver(self,expert_occupancy_measure:np.ndarray,init_state_dis:np.ndarray):
        """Fit every time step's Q slice with a convex solver, then refresh the policy.

        For ``h >= 1`` the objective is ``sum(occupancy_h * (Q - logsumexp_a Q))``,
        i.e. the expert log-likelihood under the softmax policy of ``Q``. For
        ``h == 0`` it is ``sum(occupancy_0 * Q) - init_state_dis @ logsumexp_a Q``.

        Args:
            expert_occupancy_measure: Array indexed as ``[state, action, h]``.
            init_state_dis: Initial state distribution, one entry per state.

        Raises:
            RuntimeError: If the solver fails to produce a solution for any step.
        """
        for h in range(self.max_episode-1,0,-1):
            q_variable = cp.Variable((self.dim_state,self.dim_action))
            v_variable = cp.log_sum_exp(q_variable,axis=1,keepdims=True)
            target = cp.sum(cp.multiply(expert_occupancy_measure[:,:,h],q_variable-v_variable))
            self.q_function[:,:,h] = self._solve_for_q(q_variable,target)

        # The first step weights the soft value by the initial state distribution.
        q_variable = cp.Variable((self.dim_state,self.dim_action))
        v_variable = cp.log_sum_exp(q_variable,axis=1,keepdims=True)
        target = cp.sum(cp.multiply(expert_occupancy_measure[:,:,0],q_variable)) - init_state_dis@v_variable
        self.q_function[:,:,0] = self._solve_for_q(q_variable,target)

        self.policy = get_policy_fromq(self.q_function)

    def _train_with_gradient(self,expert_occupancy_measure:np.ndarray,init_state_dis,max_iterations:int,init_q:np.ndarray):
        """Run fixed-step gradient ascent on one ``[state, action]`` Q slice.

        Args:
            expert_occupancy_measure: Expert occupancy for this step, ``[state, action]``.
            init_state_dis: Initial state distribution, or ``None`` to weight
                the softmax term by the expert occupancy itself.
            max_iterations: Number of ascent steps.
            init_q: Starting Q slice; it is not modified in place.

        Returns:
            Tuple ``(q, policy)`` with the final Q slice and its softmax policy.
        """
        step_size = 100
        q = init_q
        for _ in range(max_iterations):
            q_old = q.copy()
            pi = self._softmax(q_old)

            if init_state_dis is None:
                grad = expert_occupancy_measure - expert_occupancy_measure*pi
            else:
                grad = expert_occupancy_measure - init_state_dis.reshape((-1,1))*pi

            q = np.clip(q_old + step_size*grad,a_min=-10000,a_max=10000)
        return q,self._softmax(q)

    def train_with_gradient(self,expert_occupancy_measure:np.ndarray,init_state_dis:np.ndarray,max_iterations:int):
        """Fit every time step by gradient ascent, updating the Q table and policy in place.

        Steps ``h >= 1`` are trained without an initial distribution; step 0
        uses ``init_state_dis``.
        """
        for h in range(self.max_episode-1,0,-1):
            self.q_function[:,:,h],self.policy[:,:,h] = self._train_with_gradient(expert_occupancy_measure[:,:,h],None,max_iterations,self.q_function[:,:,h])

        # Process with Q1
        self.q_function[:,:,0],self.policy[:,:,0] = self._train_with_gradient(expert_occupancy_measure[:,:,0],init_state_dis,max_iterations,self.q_function[:,:,0])

    def train_with_gradient_paper(self,expert_occupancy_measure:np.ndarray,dataset: np.ndarray,transition_prob:np.ndarray,init_state_dis:np.ndarray,max_iterations:int):
        """Discounted gradient update driven by empirical state / next-state frequencies.

        Args:
            expert_occupancy_measure: Reshaped to ``(dim_state, dim_action, 1)``
                below, so it must hold ``dim_state * dim_action`` entries.
            dataset: Integer array; column 0 and column 1 are counted with
                ``np.bincount`` (presumably state and next-state indices —
                verify against the caller).
            transition_prob: Shape ``[dim_state, dim_action, dim_state]``, the
                last axis standing for the next state.
            init_state_dis: Unused by this method.
            max_iterations: Number of gradient steps.

        NOTE(review): the broadcasting below looks like it expects a 2-D
        ``[dim_state, dim_action]`` Q table, while ``__init__`` allocates a
        3-D one — confirm whether this legacy path is still needed.
        """
        dist_s = np.bincount(dataset[:,0],minlength=self.dim_state)
        dist_s = dist_s/np.sum(dist_s)
        dist_s_next = np.bincount(dataset[:,1],minlength=self.dim_state)
        dist_s_next = dist_s_next/np.sum(dist_s_next)
        # Fixed step size; the gradient step is clipped right after.
        step_size=2e8
        for _ in range(max_iterations):
            q_old = self.q_function.copy()
            pi = self._softmax(q_old)
            # Expected next-state mass under the expert occupancy, as a column vector.
            occ_mul_tran = np.sum(expert_occupancy_measure.reshape((self.dim_state,self.dim_action,1))*transition_prob,axis=(0,1)).reshape((-1,1))

            grad = expert_occupancy_measure - self.gamma*occ_mul_tran*pi-(dist_s.reshape((-1,1))*pi-self.gamma*dist_s_next.reshape((-1,1))*pi)

            self.q_function = np.clip(q_old + step_size*grad,a_min=-100000,a_max=100000)

        self.policy = self._softmax(self.q_function)

def get_policy_fromq(q, temperature=0.5):
    """Return the softmax policy over axis 1 (the action axis) of ``q``.

    Args:
        q: Array of Q values with the action dimension on axis 1.
        temperature: Positive softmax temperature; Q values are divided by it
            before exponentiation. Defaults to 0.5.

    Returns:
        Array with the same shape as ``q`` whose entries sum to 1 along axis 1.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive, got {}".format(temperature))
    # Subtract the per-row maximum so the exponent never overflows.
    shifted = q - np.max(q, axis=1, keepdims=True)
    weights = np.exp(shifted / temperature)
    return weights / np.sum(weights, axis=1, keepdims=True)

  
class OnlineIQLFix():
    """Torch-based tabular learner whose implied reward is softly kept inside ``[left, right]``.

    The Q table is a ``torch.nn.Parameter`` of shape
    ``[dim_state, dim_action, max_episode]`` optimized with SGD; ``policy``
    is the NumPy policy derived from it via ``get_policy_fromq``.
    """

    def __init__(self, dim_state:int, dim_action:int,max_episode:float, left=-1,right=1,penalty=10) -> None:
        self.dim_state = dim_state
        self.dim_action = dim_action
        self.max_episode = max_episode
        # Bounds of the reward band and the weight of the squared violation.
        self.left = left
        self.right = right
        self.penalty = penalty
        table_shape = [self.dim_state, self.dim_action, self.max_episode]
        self.q_function = torch.nn.Parameter(torch.randn(size=table_shape))
        self.optimizer = torch.optim.SGD([self.q_function],lr=200)

        self.policy = get_policy_fromq(self.q_function.detach().numpy())

    def train_with_torch(self,expert_occupancy_measure:np.ndarray,init_state_dis:np.ndarray,transition_prob:np.ndarray):
        """Run 400 SGD steps on the penalized objective and refresh ``self.policy``.

        Per step ``h`` the loss adds ``-sum(occupancy_h * Q_h)``, a soft-value
        term (weighted by ``init_state_dis`` at ``h == 0`` and by the expert
        occupancy otherwise), and a penalty on the implied reward
        ``Q_h - P @ V_{h+1}`` (just ``Q_h`` at the last step) leaving
        ``[left, right]``.
        """
        occupancy = torch.tensor(expert_occupancy_measure,dtype=torch.float32)
        rho0 = torch.tensor(init_state_dis,dtype=torch.float32)
        dynamics = torch.tensor(transition_prob,dtype=torch.float32)
        final_step = self.max_episode-1

        def q_at(step):
            # A fresh slice of the parameter for every use.
            return self.q_function[:,:,step]

        for _ in tqdm(range(400)):
            loss = 0
            for h in range(self.max_episode):
                occ_h = occupancy[:,:,h]
                loss -= torch.sum(occ_h*q_at(h))

                soft_value = torch.logsumexp(q_at(h),dim=1)
                if h==0:
                    loss += rho0@soft_value
                else:
                    loss += torch.sum(soft_value.unsqueeze(1)*occ_h)

                if h==final_step:
                    reward = q_at(h)
                else:
                    next_value = torch.logsumexp(q_at(h+1),dim=1)
                    reward = q_at(h)-dynamics@next_value

                below = torch.relu(self.left-reward)**2
                above = torch.relu(reward-self.right)**2
                loss += self.penalty*(below+above).mean()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        self.policy = get_policy_fromq(self.q_function.detach().numpy())



class OnlineIQLearn(object):
    """NumPy tabular learner trained step by step with clipped gradient ascent.

    Holds a Q table and a policy, both of shape
    ``[dim_state, dim_action, max_episode]``.
    """

    def __init__(self, dim_state:int, dim_action:int,max_episode:float, reward_scale = 1.0, ent_reg_coef=1.0) -> None:
        self.dim_state = dim_state
        self.dim_action = dim_action
        self.max_episode = max_episode
        self.reward_scale = reward_scale
        self.ent_reg_coef = ent_reg_coef
        table_shape = [self.dim_state, self.dim_action, self.max_episode]
        self.q_function = np.random.random(size=table_shape)
        # Initial policy: softmax of the random table, tempered by ent_reg_coef.
        centered = self.q_function - np.max(self.q_function, axis=1, keepdims=True)
        weights = np.exp(centered / self.ent_reg_coef)
        self.policy = weights / np.sum(weights, axis=1, keepdims=True)

    def _train_with_gradient(self,expert_occupancy_measure:np.ndarray,init_state_dis,max_iterations:int,init_q:np.ndarray):
        """Ascend one ``[state, action]`` Q slice and return ``(q, policy)``.

        When ``init_state_dis`` is ``None`` the softmax term is weighted by the
        expert occupancy; otherwise by the initial state distribution. After
        every step the slice is clipped to ``[-1, 10]``.
        """
        step_size = 100
        current = init_q
        for _ in range(max_iterations):
            previous = current.copy()
            weights = np.exp(previous-np.max(previous,axis=1,keepdims=True))
            pi = weights/np.sum(weights,axis=1,keepdims=True)

            state_weight = (expert_occupancy_measure if init_state_dis is None
                            else init_state_dis.reshape((-1,1)))
            grad = expert_occupancy_measure - state_weight*pi

            current = np.clip(previous + step_size*grad,a_min=-1,a_max=10)
        return current,get_policy_fromq(current)

    def train_with_gradient_paper(self,expert_occupancy_measure:np.ndarray,expert_dataset: np.ndarray,max_iterations:int,env,init_state:np.ndarray):
        """Train every time step; ``expert_dataset`` and ``env`` are not used.

        Steps ``max_episode - 1`` down to 1 are fitted without an initial
        distribution, then step 0 is fitted with ``init_state``.
        """
        for step in reversed(range(1, self.max_episode)):
            q_step, pi_step = self._train_with_gradient(
                expert_occupancy_measure[:,:,step],None,max_iterations,self.q_function[:,:,step])
            self.q_function[:,:,step],self.policy[:,:,step] = q_step, pi_step

        q_first, pi_first = self._train_with_gradient(
            expert_occupancy_measure[:,:,0],init_state,max_iterations,self.q_function[:,:,0])
        self.q_function[:,:,0],self.policy[:,:,0] = q_first, pi_first

    @property
    def get_policy(self):
        """Copy of the current policy array."""
        return self.policy.copy()

    @property
    def get_q_functions(self):
        """Copy of the current Q table."""
        return self.q_function.copy()
    
class OnlineIQLearnChi(object):
    """Torch-based tabular learner with a squared penalty on the implied reward.

    The Q table is a float32 ``torch.nn.Parameter`` of shape
    ``[dim_state, dim_action, max_episode]`` optimized with SGD; ``policy``
    is the NumPy policy derived from it via ``get_policy_fromq``.
    """

    def __init__(self, dim_state:int, dim_action:int,max_episode:int) -> None:
        self.dim_state = dim_state
        self.dim_action = dim_action
        self.max_episode = max_episode
        self.q_function = torch.nn.Parameter(torch.randn(size=[self.dim_state, self.dim_action, self.max_episode]))
        self.optimizer = torch.optim.SGD([self.q_function],lr=1)
        self.policy = get_policy_fromq(self.q_function.detach().numpy())

    def train_with_torch(self,expert_occupancy_measure:np.ndarray,init_state_dis:np.ndarray,transition_prob:np.ndarray):
        """Run 200 SGD steps on the regularized objective and refresh ``self.policy``.

        Per step ``h`` the loss adds ``-sum(occupancy_h * Q_h)``, a soft-value
        term (weighted by ``init_state_dis`` at ``h == 0`` and by the expert
        occupancy otherwise), and ``0.01 * sum(occupancy_h * r_h ** 2)`` where
        ``r_h = Q_h - P @ V_{h+1}`` (just ``Q_h`` at the last step).

        Args:
            expert_occupancy_measure: Array indexed as ``[state, action, h]``.
            init_state_dis: Initial state distribution, one entry per state.
            transition_prob: Array multiplied against the next soft value;
                its last axis is contracted with the state axis.
        """
        # Cast to float32 so the ``@`` products match the parameter's dtype;
        # torch refuses to multiply float64 inputs with a float32 parameter.
        expert_occupancy_measure = torch.tensor(expert_occupancy_measure,dtype=torch.float32)
        init_state_dis = torch.tensor(init_state_dis,dtype=torch.float32)
        transition_prob = torch.tensor(transition_prob,dtype=torch.float32)
        for _ in tqdm(range(200)):
            loss = 0
            for h in range(self.max_episode):
                occupancy_h = expert_occupancy_measure[:,:,h]
                loss -= torch.sum(occupancy_h*self.q_function[:,:,h])

                soft_value = torch.logsumexp(self.q_function[:,:,h],dim=1)
                if h!=0:
                    loss += torch.sum(soft_value.unsqueeze(1)*occupancy_h)
                else:
                    loss += init_state_dis@soft_value

                if h!=self.max_episode-1:
                    nextv = torch.logsumexp(self.q_function[:,:,h+1],dim=1)
                    implied_reward = self.q_function[:,:,h]-transition_prob@nextv
                else:
                    # No successor at the last step: the reward is Q itself.
                    implied_reward = self.q_function[:,:,h]
                loss += torch.sum(0.01*occupancy_h*implied_reward**2)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        self.policy = get_policy_fromq(self.q_function.detach().numpy())

    @property
    def get_policy(self):
        """Copy of the current policy array."""
        return self.policy.copy()

    @property
    def get_q_functions(self):
        """Detached NumPy copy of the Q table (a torch Parameter has no ``copy``)."""
        return self.q_function.detach().cpu().numpy().copy()


def get_maxpolicy(q):
    """Build a deterministic policy that puts all mass on the arg-min action.

    Despite the name, the selected action is the one with the *smallest*
    value along axis 1, for every (state, step) pair.

    Args:
        q: Array of shape ``(ns, na, H)``.

    Returns:
        One-hot array of shape ``(ns, na, H)``.
    """
    ns,na,H = q.shape
    chosen = np.argmin(q,axis=1)  # shape (ns, H)
    policy = np.zeros(shape=(ns,na,H))
    state_idx = np.arange(ns).reshape(-1,1)
    step_idx = np.arange(H).reshape(1,-1)
    # Broadcasted fancy indexing sets policy[i, chosen[i, j], j] for all i, j.
    policy[state_idx,chosen,step_idx] = 1
    return policy

class ValueDice(object):
    """Torch-based tabular ValueDice-style learner.

    ``v_function`` is a float32 ``torch.nn.Parameter`` of shape
    ``[dim_state, dim_action, max_episode]``. The policy is produced by
    ``get_maxpolicy``, i.e. it selects the action with the smallest value.
    The discount factor is fixed to ``1 - 1 / max_episode``.
    """

    def __init__(self, dim_state:int, dim_action:int,max_episode:int) -> None:
        self.dim_state = dim_state
        self.dim_action = dim_action
        self.max_episode = max_episode
        self.v_function = torch.nn.Parameter(torch.randn(size=[self.dim_state, self.dim_action, self.max_episode]))
        self.optimizer = torch.optim.SGD([self.v_function],lr=100)
        self.policy = get_maxpolicy(self.v_function.detach().numpy())
        self.gamma = 1.0-1.0/max_episode

    def train_with_torch(self,expert_occupancy_measure:np.ndarray,init_state_dis:np.ndarray,transition_prob:np.ndarray):
        """Run 200 SGD steps on ``log(item1) - item2`` and refresh ``self.policy``.

        ``item1`` accumulates ``sum(occupancy_h * exp(v_h - 0.1 * P @ min_a v_{h+1}))``
        over the steps (the successor term is dropped at the last step);
        ``item2`` is ``(1 - gamma) * init_state_dis @ min_a v_0``.

        Args:
            expert_occupancy_measure: Array indexed as ``[state, action, h]``.
            init_state_dis: Initial state distribution, one entry per state.
            transition_prob: Array multiplied against the next-step minimum;
                its last axis is contracted with the state axis.
        """
        # Cast to float32 so the ``@`` products match the parameter's dtype;
        # torch refuses to multiply float64 inputs with a float32 parameter.
        expert_occupancy_measure = torch.tensor(expert_occupancy_measure,dtype=torch.float32)
        init_state_dis = torch.tensor(init_state_dis,dtype=torch.float32)
        transition_prob = torch.tensor(transition_prob,dtype=torch.float32)

        for _ in tqdm(range(200)):
            item1 = 0
            for h in range(self.max_episode):
                v = self.v_function[:,:,h]
                if h==self.max_episode-1:
                    item1 += torch.sum(expert_occupancy_measure[:,:,h]*torch.exp(v))
                else:
                    nextv = self.v_function[:,:,h+1]
                    item1 += torch.sum(expert_occupancy_measure[:,:,h]*torch.exp(v-0.1*transition_prob@(torch.min(nextv,dim=1)[0])))

            item2 = (1-self.gamma)*init_state_dis@(torch.min(self.v_function[:,:,0],dim=1)[0])
            loss = torch.log(item1)-item2
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        self.policy = get_maxpolicy(self.v_function.detach().numpy())

    @property
    def get_policy(self):
        """Copy of the current policy array."""
        return self.policy.copy()

    @property
    def get_q_functions(self):
        """Detached NumPy copy of ``v_function``, the only value table this class holds."""
        return self.v_function.detach().cpu().numpy().copy()
                
class AIL:
    """Tabular adversarial imitation learner alternating reward and policy updates.

    The reward function is updated by projected gradient descent on the
    occupancy-measure gap; the policy is recomputed by a backward soft value
    iteration under the current reward.
    """

    def __init__(self, n_state: int, n_action: int, max_episode_steps: int,  diameter=10.0)\
            -> None:
        self.n_state = n_state
        self.n_action = n_action
        self.max_episode_steps = max_episode_steps

        table_shape = (self.n_state, self.n_action, self.max_episode_steps)

        # diameter is the infinity norm of the reward function. It seems that a larger diameter makes the optimization
        # easier.
        self._diameter = diameter
        # Start from the uniform policy.
        self._policy = np.full(shape=table_shape,
                               fill_value=float(1.0 / self.n_action),
                               dtype=np.float64)

        # Random initial reward, normalized to sum to one over actions.
        raw_reward = np.random.uniform(low=0.0, high=1.0, size=table_shape)
        self._reward_function = raw_reward / np.sum(raw_reward, axis=1, keepdims=True)

        self._average_occupancy_measure = np.zeros(shape=table_shape, dtype=np.float64)
        self._total_occupancy_measure = np.zeros(shape=table_shape, dtype=np.float64)

        # Record the grad norm of OGD so far to determine the step size.
        self._grad_norm = 0.0

        self.q_functions = np.random.random(size=table_shape)

    @property
    def get_policy(self):
        """Copy of the current policy array."""
        return self._policy.copy()

    @property
    def policy(self):
        """Copy of the current policy array (same as ``get_policy``)."""
        return self._policy.copy()

    def train_policy_step(self,transition_prob:np.ndarray):
        """Recompute Q values and the policy by backward soft value iteration.

        For each step from last to first: ``Q_h = r_h + P @ V_{h+1}`` and
        ``V_h = logsumexp_a Q_h``, with ``V`` equal to zero past the horizon.
        """
        n_states, n_actions, horizon = self.n_state, self.n_action, self.max_episode_steps
        values = np.zeros((n_states, horizon+1))
        q_table = np.zeros((n_states, n_actions, horizon))
        for h in reversed(range(self.max_episode_steps)):
            q = self._reward_function[:,:,h]+transition_prob@values[:,h+1]
            # Stable log-sum-exp over actions.
            row_max = np.max(q,axis=1)
            values[:,h] = row_max+np.log(np.sum(np.exp(q-row_max.reshape((-1,1))),axis=-1))
            q_table[:,:,h] = q

        self.q_functions = q_table
        self._policy = get_policy_fromq(q_table)

    def train_reward_function(self,expert_occupancy_measure:np.ndarray,policy_occupancy_measure:np.ndarray):
        """Replace the reward function with one projected-gradient step."""
        self._reward_function = self._projected_gradient_descent(expert_occupancy_measure, policy_occupancy_measure)

    def _projected_gradient_descent(self, expert_occupancy_measure: np.ndarray,
                                    policy_occupancy_measure: np.ndarray):
        """Take one descent step on the reward and clip it to ``[-diameter, diameter]``.

        The gradient is ``policy_occupancy - expert_occupancy``. The step size
        is ``sqrt(H * S * A / G^2)`` where ``G`` is the accumulated gradient
        norm, clipped to ``[EPS, INF]``. The accumulated norm is updated as a
        side effect.
        """
        grad = policy_occupancy_measure - expert_occupancy_measure
        # Accumulate the squared gradient norm across calls.
        squared_norm = float(np.sum(np.square(grad)))
        self._grad_norm = np.sqrt(self._grad_norm**2 + squared_norm)

        table_size = self.max_episode_steps * self.n_state * self.n_action
        step_size = np.sqrt(np.divide(table_size, self._grad_norm ** 2))
        step_size = np.clip(step_size, a_max=INF, a_min=EPS)

        updated = self._reward_function - step_size * grad
        return np.clip(updated, a_max=self._diameter, a_min=-self._diameter)
    


       
    

def _dump_yaml(path, payload):
    """Write ``payload`` to ``path`` as block-style YAML and close the file."""
    with open(path, 'w') as stream:
        yaml.dump(payload, stream, default_flow_style=False)


def train_iql(seed,algorithm):
    """Train one imitation algorithm over several horizons and save the results.

    For each horizon in ``range(2800, 3400, 200)`` this builds the environment
    selected by ``FLAGS.env.id``, samples an expert dataset, trains the agent
    named by ``algorithm`` and logs the gap between the expert value and the
    learned policy's value. Expert values, learned values and value errors
    are written as YAML files into ``FLAGS.log_dir``.

    Args:
        seed: Random seed stored in ``FLAGS`` before seeding.
        algorithm: One of ``"BC"``, ``"IQL"``, ``"IQLChi"``, ``"TVAIL"``,
            ``"ValueDice"`` or ``"IQLFix"``.

    Raises:
        ValueError: If the environment id or the algorithm is not supported.
    """
    FLAGS.seed = seed
    FLAGS.algorithm = algorithm
    FLAGS.set_seed()
    FLAGS.freeze()
    ns = FLAGS.env.ns
    na = FLAGS.env.na
    num_data = FLAGS.num_data_dict[FLAGS.env.id]

    init_state_dis = set_init_state_dis(env_id=FLAGS.env.id, num_data=num_data, ns=ns)

    value_errors = dict()
    expert_values = dict()
    iql_values = dict()

    for max_episode_steps in range(2800,3400,200):
        if FLAGS.env.id == 'CliffWalking':
            env = CliffWalking(ns, na, init_state_dis, max_episode_steps)
        elif FLAGS.env.id == 'Bandit':
            env = Bandit(ns, na, init_state_dis, max_episode_steps)
        else:
            raise ValueError('The env {} is not supported.'.format(FLAGS.env.id))

        expert_policy = env.get_optimal_policy()
        expert_value = env.policy_evaluation(expert_policy)
        dataset = sample_dataset(env, expert_policy, num_data, is_deterministic=False)

        unique_states = np.unique([i[0] for i in dataset])
        print("Seen states are:",unique_states)

        expert_occupancy = estimate_occupancy_measure_from_data(ns,na,max_episode_steps,dataset)

        transition_prob = env.transition_probability

        if algorithm=="BC":
            iql_agent = OfflineIQLearn(ns,na,max_episode_steps)
            iql_agent.train_with_solver(expert_occupancy,init_state_dis)
        elif algorithm=="IQL":
            iql_agent = OnlineIQLearn(ns,na,max_episode_steps)
            iql_agent.train_with_gradient_paper(expert_occupancy,None,200,None,init_state_dis)
        elif algorithm=="IQLChi":
            iql_agent = OnlineIQLearnChi(ns,na,max_episode_steps)
            iql_agent.train_with_torch(expert_occupancy,init_state_dis,transition_prob)
        elif algorithm=="TVAIL":
            iql_agent = AIL(ns,na,max_episode_steps,5)
            max_iterations = 200
            for _ in tqdm(range(max_iterations)):
                # Alternate: reward step against the current policy, then re-plan.
                policy = iql_agent.get_policy
                policy_occupancy = env.calculate_occupancy_measure(policy)
                iql_agent.train_reward_function(expert_occupancy,policy_occupancy)

                iql_agent.train_policy_step(transition_prob)
        elif algorithm=="ValueDice":
            iql_agent = ValueDice(ns,na,max_episode_steps)
            iql_agent.train_with_torch(expert_occupancy,init_state_dis,transition_prob)
        elif algorithm=="IQLFix":
            iql_agent = OnlineIQLFix(ns,na,max_episode_steps,-0.5,0.5,0.01)
            iql_agent.train_with_torch(expert_occupancy,init_state_dis,transition_prob)
        else:
            raise ValueError('The algorithm {} is not supported.'.format(algorithm))

        iql_policy = iql_agent.policy
        iql_value = env.policy_evaluation(policy=iql_policy)
        value_error = expert_value - iql_value

        # Use the ``algorithm`` argument (not a module-level global) so the
        # function also works when imported and called from another module.
        logger.info('Max episode steps: %d, Expert value: %.4f, %s value: %.4f, Value error: %.4f',
                    max_episode_steps, expert_value, algorithm, iql_value, value_error)

        expert_values[max_episode_steps] = [expert_value]
        iql_values[max_episode_steps] = [iql_value]
        value_errors[max_episode_steps] = [value_error]

    _dump_yaml(os.path.join(FLAGS.log_dir, 'expert_evaluate.yml'), expert_values)
    _dump_yaml(os.path.join(FLAGS.log_dir, 'iql_evaluate.yml'), iql_values)
    _dump_yaml(os.path.join(FLAGS.log_dir, 'value_error_evaluate.yml'), value_errors)


def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv``, as before.

    Returns:
        Namespace with ``seed`` (int, default 42) and ``algo`` (str, default
        ``"BC"``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility"
    )
    parser.add_argument(
        "--algo",
        type=str,
        default="BC",
        help="Algorithm to run"
    )
    return parser.parse_args(argv)

# Script entry point: read the CLI options, echo them, and run the training sweep.
if __name__ == '__main__':
    args = parse_args()
    # Module-level names holding the parsed options.
    seed = args.seed 
    algo = args.algo
    print(f"train: {seed}, {algo}")
    train_iql(seed,algo)

