import argparse
from itertools import count
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import scipy.optimize
from math import *
import torch
from models import *
from replay_memory import Memory
from running_state import ZFilter
from collections import deque, namedtuple
from torch.autograd import Variable
from RL import RL_step
from utils import *
from env import dynamics
import time
# Fix the NumPy and PyTorch RNG seeds. `np` is not imported explicitly in this
# file; presumably it is provided by one of the star imports above — confirm.
# NOTE(review): the stdlib `random` module used by saved_buffer.sample() is not
# seeded here.
np.random.seed(0)
torch.manual_seed(0)


# Turn on PyTorch's backward-compatibility warnings for broadcasting and keepdim.
torch.utils.backcompat.broadcast_warning.enabled = True
torch.utils.backcompat.keepdim_warning.enabled = True
# Everything runs on the CPU; the CUDA selection is commented out.
DEVICE = torch.device("cpu") #if not torch.cuda.is_available() else torch.device("cuda")
device=DEVICE
# Make double precision the default tensor type; the estimator networks below
# cast their inputs to float64 to match.
torch.set_default_tensor_type('torch.DoubleTensor')

class Parser:
    """Plain attribute bag holding the hyper-parameters used by this script.

    Attributes read by code in this file:
        gamma: discount factor used for returns and advantages.
        tau: extra decay applied (together with gamma) to the advantage recursion.
        l2_reg: weight of the squared-parameter penalty in the value loss.
        max_kl, damping: forwarded unchanged to ``RL_step``.
        batch_size: step budget collected before each policy update.
        terminal_time_onetrack: maximum number of steps in one training rollout.

    ``log_interval``, ``seed`` and ``test_reward_iteration`` are stored but not
    read by any code visible in this file.
    """

    def __init__(
        self,
        gamma=0.995,
        tau=0.97,
        l2_reg=1e-3,
        max_kl=1e-2,
        damping=1e-1,
        batch_size=10000,
        terminal_time_onetrack=1000,
        log_interval=1,
        seed=543,
        test_reward_iteration=1,
    ):
        # (attribute name, value) pairs, listed in the order they are stored.
        settings = (
            ('gamma', gamma),
            ('tau', tau),
            ('l2_reg', l2_reg),
            ('max_kl', max_kl),
            ('damping', damping),
            ('batch_size', batch_size),
            ('log_interval', log_interval),
            ('seed', seed),
            ('test_reward_iteration', test_reward_iteration),
            ('terminal_time_onetrack', terminal_time_onetrack),
        )
        for attribute, value in settings:
            setattr(self, attribute, value)


# Module-wide configuration instance built from the defaults above.
args = Parser()

# Placeholder goal/obstacle at the origin, used only to build an environment
# from which the dimensions and workspace limits below are read. The per-task
# loop further down rebuilds `env` with the real task geometry.
goal = np.zeros(2)
obstacle = np.zeros(2)
env = dynamics(goal, obstacle)

# State/action sizes and plotting extents as reported by the environment.
num_inputs, num_actions = env.dim_states, env.dim_actions
x_lim, y_lim = env.x_lim, env.y_lim
#torch.manual_seed(args.seed)

class reward_estimator(nn.Module):
    """MLP that maps a state vector to a scalar reward estimate.

    Architecture: ``num_inputs -> 64 -> 64 -> 1`` with ReLU activations. The
    module owns its Adam optimizer (``self.opt``, learning rate 1e-3).
    """

    def __init__(self):
        super(reward_estimator, self).__init__()

        self.fc = nn.Sequential(
            nn.Linear(num_inputs, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        # MSE criterion; not used by any code visible in this file.
        self.mls = nn.MSELoss()
        self.opt = torch.optim.Adam(self.parameters(), lr=0.001)

    def forward(self, x):
        """Return the reward estimate for ``x``.

        The input is cast to float64 so that float32 tensors (e.g. built with
        ``torch.FloatTensor``) match the double-precision default set at module
        level.
        """
        x = x.to(torch.float64)
        return self.fc(x)

    def plot(self, goal, obstacle, trajectory_list, x_lim, y_lim, iteration, loss, SR, CVR):
        """Draw trajectories, goal and obstacle and save them as ``reward<iteration>.pdf``.

        Args:
            goal: indexable pair; a 1x1 square centred on it is drawn.
            obstacle: indexable pair; a 3x1 rectangle anchored at it is drawn.
            trajectory_list: iterable of trajectories, each a sequence of points
                whose first two entries are used as the x and y coordinates.
            x_lim, y_lim: upper axis limits (truncated to int).
            iteration: suffix used in the output file name.
            loss, SR, CVR: values rendered as text in the lower-left corner.
        """
        fig, ax = plt.subplots()
        try:
            ax.axis('scaled')
            ax.axis([0, int(x_lim), 0, int(y_lim)])
            # Disabled: reward heat-map background.
            #reward_map=np.zeros((int(10*y_lim),int(10*x_lim)))
            #for x in range(int(10*x_lim)):
            #    for y in range(int(10*y_lim)):
            #        reward_map[int(10*y_lim)-1-y,x]=self.forward(torch.FloatTensor([x,y])).item()
            #max_value=np.max(reward_map)
            #min_value=np.min(reward_map)
            #print('max_value',max_value)
            #print('min_value',min_value)
            #print('mean_value',np.mean(reward_map))
            #print('reward map', reward_map)
            #reward_map=(reward_map-min_value)/(max_value-min_value)
            #plt.imshow(reward_map, cmap='viridis',extent=[0,10*int(x_lim),0,10*int(y_lim)])
            #plt.colorbar()
            goal_block = plt.Rectangle((goal[0] - 0.5, goal[1] - 0.5), 1, 1,
                                       edgecolor='tab:brown', facecolor='tab:brown')
            obstacle_block = plt.Rectangle((obstacle[0], obstacle[1]), 3, 1,
                                           edgecolor='tab:orange', facecolor='tab:orange')
            for trajectory in trajectory_list:
                print('trajectory length', len(trajectory))
                xs = [point[0] for point in trajectory]
                ys = [point[1] for point in trajectory]
                # Draw on the explicit axes rather than pyplot's "current" axes.
                ax.plot(xs, ys, 'tab:blue', linewidth=2)
            ax.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)
            ax.text(0.5, 2.5, str(loss), color='black', fontsize=10)
            ax.text(0.5, 1.5, str(SR), color='black', fontsize=10)
            ax.text(0.5, 0.5, str(CVR), color='black', fontsize=10)
            ax.add_patch(goal_block)
            ax.add_patch(obstacle_block)
            fig.savefig('reward' + str(iteration) + '.pdf', bbox_inches='tight', pad_inches=0)
        finally:
            # pyplot keeps every figure alive until it is closed explicitly;
            # without this, each call leaks one figure.
            plt.close(fig)

class cost_estimator(nn.Module):
    """MLP that maps a one-dimensional cost feature to a scalar cost estimate.

    Architecture: ``1 -> 64 -> 64 -> 1`` with ReLU activations; the module
    carries its own Adam optimizer in ``self.opt`` (learning rate 1e-3).
    """

    def __init__(self):
        super().__init__()

        hidden_width = 64
        # Layers are created in network order and kept inside one Sequential
        # named `fc`, so parameter names stay `fc.0.*`, `fc.2.*`, `fc.4.*`.
        layers = [
            nn.Linear(1, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, 1),
        ]
        self.fc = nn.Sequential(*layers)
        # MSE criterion; not used by any code visible in this file.
        self.mls = nn.MSELoss()
        self.opt = torch.optim.Adam(self.parameters(), lr=0.001)

    def forward(self, x):
        """Cast ``x`` to float64 and return the network output."""
        return self.fc(x.to(torch.float64))

class saved_buffer(object):
    """Fixed-capacity FIFO store of experience records backed by a ``deque``.

    Once ``memory_size`` records are held, adding another one discards the
    oldest. Records are opaque to this class; callers in this file store tuples.
    """

    def __init__(self, memory_size: int) -> None:
        self.memory_size = memory_size
        self.buffer = deque(maxlen=self.memory_size)

    def add(self, experience) -> None:
        """Append one record, evicting the oldest when the buffer is full."""
        self.buffer.append(experience)

    def size(self):
        """Return the number of records currently stored."""
        return len(self.buffer)

    def sample(self, batch_size: int, continuous: bool = True):
        """Return ``batch_size`` records (clamped to the current size).

        Args:
            batch_size: requested number of records.
            continuous: when True, return a contiguous window starting at a
                random offset; otherwise draw distinct records uniformly
                without replacement.
        """
        # `random` is not imported at the top of this file, so import it here
        # instead of depending on a star import to provide it.
        import random

        # Copy once: indexing a deque away from its ends is O(n) per access.
        items = list(self.buffer)
        batch_size = min(batch_size, len(items))
        if continuous:
            start = random.randint(0, len(items) - batch_size)
            return items[start:start + batch_size]
        indexes = np.random.choice(np.arange(len(items)), size=batch_size, replace=False)
        return [items[i] for i in indexes]

    def _head(self, steps):
        """Return the oldest ``steps`` records, or all of them for ``'NULL'``/``None``.

        ``steps`` is clamped to the buffer length, so asking for more records
        than are stored returns everything rather than raising ``IndexError``.
        """
        items = list(self.buffer)
        if steps is None or steps == 'NULL':
            return items
        return items[:max(steps, 0)]

    def calculate_reward(self, steps='NULL'):
        """Return the first ``steps`` records (all records by default)."""
        return self._head(steps)

    def calculate_cost(self, steps='NULL'):
        """Return the first ``steps`` records (all records by default)."""
        return self._head(steps)

    def clear(self):
        """Remove every record."""
        self.buffer.clear()

    def save(self, path):
        """Write the buffer to ``path`` with ``np.save`` (which appends ``.npy`` if missing)."""
        items = list(self.buffer)
        try:
            b = np.asarray(items)
        except ValueError:
            # NumPy >= 1.24 refuses to build an array implicitly from ragged
            # records (e.g. tuples mixing arrays of different lengths with
            # scalars). Store one record per slot of a 1-D object array
            # instead; load() only relies on b.shape[0] and b[i].
            b = np.empty(len(items), dtype=object)
            for position, item in enumerate(items):
                b[position] = item
        #print(b.shape)
        np.save(path, b)

    def load(self, path):
        """Append every record stored in ``path + '.npy'`` to this buffer."""
        # SECURITY: allow_pickle=True unpickles arbitrary objects; only load
        # files from a trusted source.
        b = np.load(path+'.npy', allow_pickle=True)
        #assert(b.shape[0] == self.memory_size)

        for i in range(b.shape[0]):
            self.add(b[i])

def select_action(state):
    """Sample one action from the current policy for a single state.

    ``state`` is a NumPy array; it is wrapped into a batch of size one and fed
    to the module-level ``policy_net``, whose first and third outputs are used
    as the mean and standard deviation of the normal distribution that the
    action is drawn from. The returned tensor keeps the leading batch axis.
    """
    batched_state = Variable(torch.from_numpy(state).unsqueeze(0))
    mean, _unused_log_std, std = policy_net(batched_state)
    return torch.normal(mean, std)

def update_params(batch):
    """Run one value-network fit and one policy step on a batch of transitions.

    Reads ``batch.reward``, ``batch.mask``, ``batch.action`` and ``batch.state``
    and uses the module-level ``value_net``, ``policy_net`` and ``args``.
    Steps, in order:
      1. compute discounted returns and advantage estimates with a backward pass;
      2. refit ``value_net`` to the returns with SciPy's L-BFGS-B (25 iterations max);
      3. standardize the advantages and hand the surrogate-loss and KL closures
         to ``RL_step`` together with ``args.max_kl`` and ``args.damping``.
    """
    rewards = torch.Tensor(batch.reward)
    masks = torch.Tensor(batch.mask)
    actions = torch.Tensor(np.concatenate(batch.action, 0))
    states = torch.Tensor(batch.state)
    values = value_net(Variable(states))

    # Uninitialized (N, 1) buffers; every row is written in the loop below.
    returns = torch.Tensor(actions.size(0),1)
    deltas = torch.Tensor(actions.size(0),1)
    advantages = torch.Tensor(actions.size(0),1)

    prev_return = 0
    prev_value = 0
    prev_advantage = 0
    # Walk the batch backwards; masks[i] multiplies the carried-over terms, so a
    # zero mask (pushed by the rollout code on `done`) stops them from crossing
    # an episode boundary.
    for i in reversed(range(rewards.size(0))):
        returns[i] = rewards[i] + args.gamma * prev_return * masks[i]
        deltas[i] = rewards[i] + args.gamma * prev_value * masks[i] - values.data[i]
        advantages[i] = deltas[i] + args.gamma * args.tau * prev_advantage * masks[i]

        prev_return = returns[i, 0]
        prev_value = values.data[i, 0]
        prev_advantage = advantages[i, 0]

    targets = Variable(returns)

    # Original code uses the same LBFGS to optimize the value loss
    def get_value_loss(flat_params):
        """Objective for L-BFGS-B: return (loss, flat gradient) as NumPy doubles.

        Loads ``flat_params`` into ``value_net``, then evaluates the mean squared
        error against ``targets`` plus an ``args.l2_reg``-weighted sum of squared
        parameters.
        """
        set_flat_params_to(value_net, torch.Tensor(flat_params))
        # Clear stale gradients so backward() below does not accumulate into them.
        for param in value_net.parameters():
            if param.grad is not None:
                param.grad.data.fill_(0)

        values_ = value_net(Variable(states))

        value_loss = (values_ - targets).pow(2).mean()

        # weight decay
        for param in value_net.parameters():
            value_loss += param.pow(2).sum() * args.l2_reg
        value_loss.backward()
        return (value_loss.data.double().numpy(), get_flat_grad_from(value_net).data.double().numpy())

    flat_params, _, opt_info = scipy.optimize.fmin_l_bfgs_b(get_value_loss, get_flat_params_from(value_net).double().numpy(), maxiter=25)
    set_flat_params_to(value_net, torch.Tensor(flat_params))

    # NOTE(review): divides by advantages.std(); a batch whose advantages are all
    # equal would produce NaN/inf here — confirm this cannot occur.
    advantages = (advantages - advantages.mean()) / advantages.std()

    action_means, action_log_stds, action_stds = policy_net(Variable(states))
    # Log-probabilities under the pre-update policy, detached via .data.clone().
    fixed_log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds).data.clone()

    def get_loss(volatile=False):
        """Return the mean surrogate loss; with ``volatile`` the policy runs under no_grad."""
        if volatile:
            with torch.no_grad():
                action_means, action_log_stds, action_stds = policy_net(Variable(states))
        else:
            action_means, action_log_stds, action_stds = policy_net(Variable(states))
                
        log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds)
        # Advantage-weighted probability ratio between current and pre-update policy, negated.
        action_loss = -Variable(advantages) * torch.exp(log_prob - Variable(fixed_log_prob))
        return action_loss.mean()


    def get_kl():
        """Return the per-sample KL term between a detached copy of the policy output and the live output."""
        mean1, log_std1, std1 = policy_net(Variable(states))

        # The "0" quantities are rebuilt from .data, so gradients flow only
        # through mean1 / log_std1 / std1.
        mean0 = Variable(mean1.data)
        log_std0 = Variable(log_std1.data)
        std0 = Variable(std1.data)
        kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5
        return kl.sum(1, keepdim=True)

    # RL_step is imported from the RL module and not visible here; presumably it
    # performs the KL-constrained policy update — confirm against that module.
    RL_step(policy_net, get_loss, get_kl, args.max_kl, args.damping)

# Filter applied to every environment state before it is passed to the policy
# (see the rollout loops below). ZFilter is imported from running_state;
# presumably a running mean/std normalizer with clipping — confirm there.
running_state = ZFilter((num_inputs,), clip=5)
# Scalar filter built with demean=False; not referenced by any code visible in this file.
running_reward = ZFilter((1,), demean=False, clip=10) 
# Learner rollout storage, each capped at 30000 records: transitions in
# Learner_buff, (cost_feature, 0) pairs in Learner_cost_buff.
Learner_buff=saved_buffer(30000)
Learner_cost_buff=saved_buffer(30000)



def SR_CVR(goal, obstacle, trajectory_list):
    """Return ``(success_rate, violation_rate)`` over a list of trajectories.

    Each trajectory is scanned point by point and classified by the first point
    that lands in a region of interest:

    * obstacle region: ``x`` in ``[obstacle[0], obstacle[0] + 3]`` and ``y`` in
      ``[obstacle[1], obstacle[1] + 1]`` -> counted as a violation;
    * goal region: ``x`` in ``[goal[0] - 2.5, goal[0] + 2.5]`` and ``y`` in
      ``[goal[1] - 2.5, goal[1] + 0.5]`` -> counted as a success.

    The obstacle test runs first, so a point inside both regions is a violation.
    Trajectories that touch neither region count towards neither rate. Both
    counts are divided by the number of trajectories.
    """
    obstacle_x_min, obstacle_x_max = obstacle[0], obstacle[0] + 3.0
    obstacle_y_min, obstacle_y_max = obstacle[1], obstacle[1] + 1.0
    goal_x_min, goal_x_max = goal[0] - 2.5, goal[0] + 2.5
    goal_y_min, goal_y_max = goal[1] - 2.5, goal[1] + 0.5

    successes = 0.0
    violations = 0.0
    for path in trajectory_list:
        for point in path:
            px, py = point[0], point[1]
            in_obstacle = (obstacle_x_min <= px <= obstacle_x_max
                           and obstacle_y_min <= py <= obstacle_y_max)
            if in_obstacle:
                violations += 1.0
                break
            in_goal = (goal_x_min <= px <= goal_x_max
                       and goal_y_min <= py <= goal_y_max)
            if in_goal:
                successes += 1.0
                break

    total = len(trajectory_list)
    return (successes / total), (violations / total)
          

def regularization(parameters):
    """Return the sum of ``torch.norm`` over every tensor in ``parameters``.

    ``parameters`` may be any iterable (including a one-shot generator such as
    ``module.parameters()``). An empty iterable yields the integer ``0``.
    """
    total = 0
    for tensor in parameters:
        total = total + torch.norm(tensor)
    return total

# Per-task success rate (SR) and constraint-violation rate (CVR), summarized at the end.
SR_list = []
CVR_list = []

# NOTE: this loop intentionally stays at module level. select_action() and
# update_params() read `policy_net` / `value_net` as module globals, so the
# assignments below must bind module-level names.
for task_number in range(59, 60):
    # Task file: the first two entries are the goal, the next two the obstacle.
    goal_obstacle = np.loadtxt('task' + str(task_number) + 'goal_obstacle.txt')
    goal = goal_obstacle[0:2]
    obstacle = goal_obstacle[2:4]
    env = dynamics(goal, obstacle)

    # Expert demonstrations for this task; only the first field of each record
    # (the state) is used below.
    Expert_buff = saved_buffer(25000)
    Expert_buff.load('task' + str(task_number) + '_evaluation_set')
    a = Expert_buff.calculate_reward()
    expert_states, _, expert_actions, _, _ = zip(*a)
    expert_states = torch.FloatTensor(expert_states).to(device)

    # Both estimators start from the saved meta-prior weights.
    reward_NN = reward_estimator()
    cost_NN = cost_estimator()
    reward_NN.load_state_dict(torch.load('reward_meta_prior.pt'))
    cost_NN.load_state_dict(torch.load('cost_meta_prior.pt'))

    for i in range(5):
        trajectory_list = []
        for _ in range(2):
            # Fresh policy/value networks for every inner run.
            policy_net = Policy(num_inputs, num_actions)
            value_net = Value(num_inputs)

            # ---- Policy optimization against the current reward/cost estimates ----
            for i_episode in range(100):
                memory = Memory()
                num_steps = 0
                reward_batch = 0
                num_episodes = 0
                while num_steps < args.batch_size:
                    state = env.reset()
                    # Keep the raw state for the reward network BEFORE
                    # normalizing, exactly as the evaluation rollout below does.
                    # Previously the first reward query of every rollout was
                    # fed the normalized state while all later ones were raw.
                    saved_state = state
                    state = running_state(state)
                    reward_sum = 0
                    for t in range(args.terminal_time_onetrack):  # Don't infinite loop while learning
                        action = select_action(state)
                        action = action.data[0].numpy()
                        next_state, reward, done, cost_feature = env.step(action)
                        # The environment reward is replaced by the learned
                        # reward minus the learned cost. Inference only, so no
                        # autograd graph is needed.
                        with torch.no_grad():
                            reward = reward_NN(torch.FloatTensor(saved_state)).item()
                            cost = cost_NN(torch.FloatTensor([cost_feature])).item()
                        reward = reward - cost
                        reward_sum += reward
                        saved_state = next_state
                        next_state = running_state(next_state)

                        # mask == 0 marks the last transition of an episode.
                        mask = 0 if done else 1
                        memory.push(state, np.array([action]), mask, next_state, reward)

                        if done:
                            break

                        state = next_state

                    # `t` is the index of the last executed step, so this
                    # rollout contributed t + 1 transitions (was `t - 1`, which
                    # under-counted every rollout by two).
                    num_steps += t + 1
                    num_episodes += 1
                    reward_batch += reward_sum

                reward_batch /= num_episodes
                batch = memory.sample()
                update_params(batch)

                # Reports the learned-reward sum of the last rollout in the batch.
                print('Episode {}\tCumulative reward: {}'.format(i_episode, reward_sum))

            # ---- Roll out the trained policy from a fixed start state ----
            for episodes in range(30):
                trajectory = []
                state = env.reset(np.array([3.5, 1.0]))
                trajectory.append(state)
                saved_state = state
                state = running_state(state)

                reward_sum = 0
                for t in range(300):  # Don't infinite loop while learning
                    action = select_action(state)
                    action = action.data[0].numpy()
                    next_state, reward, done, cost_feature = env.step(action)
                    with torch.no_grad():
                        reward = reward_NN(torch.FloatTensor(saved_state)).item()
                    # Raw (un-normalized) transitions feed the reward update below.
                    Learner_buff.add((saved_state, next_state, action, 0, done))
                    Learner_cost_buff.add((cost_feature, 0))
                    trajectory.append(next_state)
                    saved_state = next_state
                    reward_sum += reward
                    next_state = running_state(next_state)
                    if done:
                        break
                    state = next_state
                trajectory_list.append(trajectory)

        # ---- Reward update: lower learner reward relative to expert reward ----
        b = Learner_buff.calculate_reward()
        learner_states, _, learner_actions, _, _ = zip(*b)
        learner_states = torch.FloatTensor(learner_states).to(device)
        learner_reward = reward_NN(learner_states).mean()

        expert_reward = reward_NN(expert_states).mean()
        loss = learner_reward - expert_reward  #+0.5*regularization(reward_NN.parameters())
        #reward_NN.plot(goal,obstacle,trajectory_list,x_lim,y_lim,i,loss.item(),SR,CVR)

        reward_NN.opt.zero_grad()
        loss.backward()

        b_cost = Learner_cost_buff.calculate_cost()
        learner_cost_features, _ = zip(*b_cost)
        # cost_NN's first layer takes one input feature, so present the batch
        # as (N, 1); this is a no-op when the features already have that shape.
        learner_cost_features = torch.FloatTensor(learner_cost_features).to(device).reshape(-1, 1)
        loss_cost = -cost_NN(learner_cost_features).mean()

        cost_NN.opt.zero_grad()
        loss_cost.backward()
        # NOTE(review): the cost gradient is computed but cost_NN.opt.step() is
        # never called, so cost_NN keeps its meta-prior weights. Also,
        # Learner_cost_buff is never cleared (unlike Learner_buff below), so it
        # accumulates features across iterations up to its capacity. Left
        # unchanged because intent cannot be determined from this file — confirm.

        print('expert reward', expert_reward)
        print('learner reward', learner_reward)
        print("reward_network_update", '--------', loss)
        reward_NN.opt.step()
        print("--------------------------------------------------------------")
        Learner_buff.clear()

    # Metrics are computed on the trajectories of the final iteration only.
    SR, CVR = SR_CVR(goal, obstacle, trajectory_list)
    SR_list.append(SR)
    CVR_list.append(CVR)
    print('SR is', SR)
    print('CVR is', CVR)
    #reward_NN.plot(goal,obstacle,trajectory_list,x_lim,y_lim,task_number,loss.item(),SR,CVR)

print('SR mean:', np.mean(SR_list))
print('SR std:', np.std(SR_list))

print('CVR mean:', np.mean(CVR_list))
print('CVR std:', np.std(CVR_list))





