import argparse
import os
import gym
import scipy.optimize
from env.navigation import navigation_env
import torch
from models import *
from replay_memory import Memory
from running_state import ZFilter
from torch.autograd import Variable
from utils import *
from trpo import one_step_trpo,conjugate_gradients,trpo_step

import pickle

from copy import deepcopy

from train_trpo import model_lower,args,env,index,num_inputs, num_actions,select_action_test,select_action,compute_adavatage,task_specific_adaptation,sample_data_for_task_specific

# Observation filter restored from the training checkpoint. It is presumably the
# ZFilter pickled by the training run (see the ZFilter import); the rollout
# helpers below call it on every raw state, and __main__ reads `running_state.rs.n`.
# SECURITY: unpickling can execute arbitrary code -- only load checkpoints you
# produced yourself.
with open(f"./check_point/{args.env_name}_running_state_{model_lower}.pkl", "rb") as fh:
    running_state = pickle.load(fh)


# Task constants passed positionally to navigation_env(x_lim, y_lim, initial_state, ...);
# presumably the arena extent along x / y and the agent's start position -- confirm
# against env.navigation.
x_lim = 7.0
y_lim = 12.0
initial_state = [3.5,1.0]

def sample_data_for_task_specific_test(goal,obstacle,policy_net,batch_size):
    """Roll out ``policy_net`` for ``batch_size`` episodes on one navigation task.

    Actions are chosen with ``select_action_test``. Raw states are passed through
    the module-level ``running_state`` filter before they are used or stored.
    Every transition is pushed into a fresh ``Memory``, using the episode index
    as its path number.

    Args:
        goal: goal specification forwarded to ``navigation_env``.
        obstacle: obstacle specification forwarded to ``navigation_env``.
        policy_net: policy handed to ``select_action_test``.
        batch_size: number of episodes to roll out; each is capped at
            ``args.max_length`` steps.

    Returns:
        ``(batch, mean_return)``. ``batch`` is ``Memory.sample()`` over all the
        pushed transitions. ``mean_return`` is the undiscounted per-episode reward
        sum averaged over the ``batch_size`` episodes.

    Raises:
        ZeroDivisionError: if ``batch_size`` is 0.
    """
    task_env = navigation_env(x_lim, y_lim, initial_state, goal, obstacle)
    transitions = Memory()
    episode_returns = []

    for episode in range(batch_size):
        obs = running_state(task_env.reset()[0])
        episode_return = 0

        for _step in range(args.max_length):
            act = select_action_test(obs, policy_net).data[0].numpy()
            new_obs, rew, terminated, truncated, _info = task_env.step(act)
            episode_return += rew
            new_obs = running_state(new_obs)
            transitions.push(obs, np.array([act]), episode, new_obs, rew)
            if args.render:
                task_env.render()
            obs = new_obs
            if terminated or truncated:
                break

        episode_returns.append(episode_return)

    # Accumulate in episode order, then divide in place, mirroring a running total.
    mean_return = sum(episode_returns)
    mean_return /= len(episode_returns)
    return transitions.sample(), mean_return

def success_rate(goal,obstacle,initial_state,policy,num_rollouts=50):
    """Estimate how often ``policy`` reaches the goal of one navigation task.

    Runs ``num_rollouts`` episodes from ``initial_state``, choosing actions with
    ``select_action``. This is not the ``select_action_test`` used by
    ``sample_data_for_task_specific_test``; it is presumably stochastic, so
    confirm in train_trpo.

    Each episode scores 1 if the agent enters the goal region and 0 otherwise.
    An episode counts as a failure if any of these happens first:

    - the next state lies in an obstacle;
    - the env reports ``done`` or ``truncated``;
    - ``args.max_length`` steps pass without reaching the goal.

    Obstacle and goal are checked on the raw (unfiltered) next state.

    Args:
        goal: goal specification forwarded to ``navigation_env``.
        obstacle: obstacle specification forwarded to ``navigation_env``.
        initial_state: start position, used to build the env and passed to every reset.
        policy: policy handed to ``select_action``.
        num_rollouts: number of episodes to average over (default 50).

    Returns:
        The fraction of successful episodes, a value in [0, 1]. It is ``nan``
        (``np.mean`` of an empty list) only when ``num_rollouts`` is 0.
    """
    task_env = navigation_env(x_lim,y_lim,initial_state,goal,obstacle)
    outcomes = []
    for _episode in range(num_rollouts):
        state = running_state(task_env.reset(initial_state)[0])
        for _step in range(args.max_length):
            action = select_action(state,policy).data[0].numpy()
            next_state, _reward, done, truncated, _info = task_env.step(action)
            if task_env.is_in_obstacle(next_state):
                outcomes.append(0)
                break
            if task_env.is_in_goal(next_state):
                outcomes.append(1)
                break
            if done or truncated:
                # Truncation ends the episode just like termination. Per the gym
                # API the env must be reset before stepping again, and this matches
                # sample_data_for_task_specific_test, which stops on done or truncated.
                outcomes.append(0)
                break
            state = running_state(next_state)
        else:
            # Step budget exhausted without reaching the goal. This also covers
            # args.max_length == 0, so every episode contributes exactly one outcome.
            outcomes.append(0)
    return np.mean(outcomes)


if __name__ == "__main__":

    # Evaluate the meta-trained policy on the held-out tasks listed in
    # test_tasks_goal.txt / test_tasks_obstacle.txt. For each task, start from
    # the meta-policy and run NUM_ADAPTATION_STEPS task-specific adaptation
    # steps, measuring the success rate before each step.
    NUM_ADAPTATION_STEPS = 4

    def _copy_parameters(target, source):
        """Overwrite ``target``'s parameters in place with ``source``'s values, matched by order."""
        target_params = list(target.parameters())
        source_params = list(source.parameters())
        if len(target_params) != len(source_params):
            raise ValueError("parameter count mismatch: {} vs {}".format(
                len(target_params), len(source_params)))
        for target_param, source_param in zip(target_params, source_params):
            # copy_ already copies the values, so no clone()/detach() is needed.
            target_param.data.copy_(source_param.data)

    # NOTE(review): this unpickles a whole model object. PyTorch >= 2.6 defaults
    # torch.load to weights_only=True, which may reject it -- confirm the torch version.
    meta_policy_net = torch.load("./check_point/"+str(args.env_name)+"_meta_policy_net_"+model_lower+".pkl")

    meta_lambda_now=args.meta_lambda
    print(meta_lambda_now)
    print(model_lower, "running_state: ",running_state.rs.n)
    print("index: ", index)

    # All three tables are indexed as [adaptation step][recorded task]. Task 0 is
    # never recorded (see below).
    #   success_rate_set:  measured right after sample_data_for_task_specific_test
    #   success_rate_set2: measured right after sample_data_for_task_specific
    #   success_rate_set3: the larger of the two
    success_rate_set = [[] for _ in range(NUM_ADAPTATION_STEPS)]
    success_rate_set2 = [[] for _ in range(NUM_ADAPTATION_STEPS)]
    success_rate_set3 = [[] for _ in range(NUM_ADAPTATION_STEPS)]

    goal_set = np.loadtxt('test_tasks_goal.txt')
    obstacle_set = np.loadtxt('test_tasks_obstacle.txt')

    for task_number, goal in enumerate(goal_set):
        obstacle = obstacle_set[task_number]
        print("task_number: ",task_number)

        # Every task starts from a fresh copy of the meta-policy.
        previous_policy_net = Policy(num_inputs, num_actions)
        _copy_parameters(previous_policy_net, meta_policy_net)

        for iteration_number in range(NUM_ADAPTATION_STEPS):
            print(torch.exp(previous_policy_net.action_log_std))

            # The two sampling calls bracket two evaluations of the same (not yet
            # adapted) policy. Only the second call's batch is used for adaptation.
            _,accumulated_raward_batch=sample_data_for_task_specific_test(goal,obstacle,previous_policy_net,args.batch_size)
            successrate=success_rate(goal,obstacle,initial_state,previous_policy_net)
            batch,batch_extra,accumulated_raward_batch2=sample_data_for_task_specific(goal,obstacle,previous_policy_net,args.batch_size)
            successrate2=success_rate(goal,obstacle,initial_state,previous_policy_net)
            print("task_number: ",task_number)
            print('(adaptation {}) \tSuccess rate {:.2f}'.format(iteration_number, successrate))
            print('(adaptation {}) \tSuccess rate {:.2f}'.format(iteration_number, successrate2))

            # NOTE(review): task 0 is excluded from the statistics -- presumably a
            # warm-up task; confirm this is intended.
            if task_number > 0:
                success_rate_set[iteration_number].append(successrate)
                success_rate_set2[iteration_number].append(successrate2)
                success_rate_set3[iteration_number].append(max(successrate,successrate2))

            q_values = compute_adavatage(batch,batch_extra,args.batch_size)
            q_values = q_values - q_values.mean()  # centre to zero mean

            # Adapt a copy of the current policy, then make it the starting point
            # for the next adaptation step.
            task_specific_policy = Policy(num_inputs, num_actions)
            _copy_parameters(task_specific_policy, previous_policy_net)
            task_specific_policy = task_specific_adaptation(task_specific_policy,previous_policy_net,batch,q_values,meta_lambda_now,index)
            _copy_parameters(previous_policy_net, task_specific_policy)

    print("-----------------")
    # For each recorded task, take the best success rate reached at any
    # adaptation step, then average over all recorded tasks (however many the
    # task files contain).
    best_per_task = [max(rates) for rates in zip(*success_rate_set3)]
    print(np.mean(best_per_task))