from numpy import linalg
from numpy.lib.function_base import extract
import gym
import random
import numpy as np
from copy import deepcopy
from numpy.core.numeric import Inf
import torch
import torch.nn as nn
from torch import Tensor
import torch.optim as opt
from torch.types import Storage

from replay_buffer_new_environment import ReplayBuffer_imitation
import envs
import time
# Passed as the ``flush`` argument of the print() calls in this module, so
# progress output is flushed as soon as it is printed.
FLUSH = True
def reset_all_rollouts(venv):
    """Reset ``venv`` and unpack the dictionary it returns.

    Args:
        venv: object whose ``reset()`` returns a mapping with the keys
            ``'observation'``, ``'achieved_goal'`` and ``'desired_goal'``.

    Returns:
        The tuple ``(observation, achieved_goal, desired_goal)`` taken from
        the mapping returned by ``venv.reset()``.
    """
    fresh = venv.reset()
    return (
        fresh['observation'],
        fresh['achieved_goal'],
        fresh['desired_goal'],
    )

def splid(args, network,env,goal_dim,state_dim, device,TestNetworking = False,writer = None):
    """Train ``network`` by imitating hindsight-relabelled transitions.

    Every outer iteration:

    1. rolls out the current policy in ``env`` until at least
       ``args.batch_size`` environment steps have been taken;
    2. for every visited state and every horizon ``t_`` in
       ``1..args.Horizon_max`` relabels the state with the goal achieved
       ``t_`` steps later and, if the sample is accepted, stores
       (observation, action, goal) in the replay buffer;
    3. regresses the first network output onto the stored actions (MSE);
    4. evaluates the policy over 50 roll-outs with a zero noise factor.

    Args:
        args: configuration namespace. Attributes read here: ``env_name``,
            ``max_step_per_round``, ``Horizon_max``, ``factor``, ``lr_hid``,
            ``replay_buffer_size_IER``, ``num_episode``, ``batch_size``,
            ``reward_pos``, ``NOTTESTFUNCTION``, ``num_epoch``,
            ``minibatch_size`` and ``log_num_episode``.
        network: torch module. ``network(x)`` must return a 3-tuple
            ``(action_mean, action_logstd, value)`` and
            ``network.select_action(mean, logstd, factor=...)`` must return
            ``(action, logproba)``.
        env: environment whose ``reset()``/``step()`` return dictionaries
            with ``'observation'``, ``'achieved_goal'`` and
            ``'desired_goal'`` entries (concatenated along ``axis=1``), and
            whose ``info`` is an iterable of dicts. It must be
            deep-copyable when the validity test is enabled.
        goal_dim: only printed.
        state_dim: only printed.
        device: torch device the network inputs are moved to.
        TestNetworking: unused in this function.
        writer: unused in this function.

    Returns:
        ``[sr1, reward1]``: the evaluation success rates and the evaluation
        mean rewards recorded on the iterations where
        ``i_episode % args.log_num_episode == 0``.
    """
    # Distance under which an achieved goal counts as matching a goal.
    goal_tolerance = 0.05
    eval_rollouts = 50

    sr1 = []
    reward1 = []
    print("DIMS",goal_dim,state_dim,flush = FLUSH)
    print(f"Environment is {env}",flush = FLUSH)

    optimizer_imitation = opt.RMSprop(network.parameters(), lr=args.lr_hid)
    ier_buffer = ReplayBuffer_imitation(args.replay_buffer_size_IER)
    loss_func = nn.MSELoss()

    def policy_input(observation, goal):
        """Build the batched network input from an observation and a goal."""
        extended = np.concatenate((observation, goal), axis=1)
        return Tensor(extended).unsqueeze(0).to(device)

    def select_action(action_mean, action_logstd, fctr):
        """Sample an action from Normal(mean, exp(logstd) * fctr)."""
        return torch.normal(action_mean, torch.exp(action_logstd) * fctr)

    def eval_policy_50(fctr_used, args, network, device,plot = False):
        """Run ``eval_rollouts`` evaluation episodes in ``env``.

        Returns ``(success_rate, mean_reward)``, plus the goal/achieved-goal
        buffer as a third element when ``plot`` is true.
        """
        if plot:
            buffer = {'g': [], 'ag': []}
        reward_sum = 0
        succ_game = 0
        # NOTE(review): the network is left in eval mode after this call, so
        # all later roll-outs and gradient steps also run in eval mode.
        # Kept as in the original; confirm this is intended.
        network.eval()
        for _ in range(eval_rollouts):
            observation, initial_ag, goal = reset_all_rollouts(env)
            succ_in_env = 0
            if plot:
                buffer['g'] = goal
                buffer['ag'].append(initial_ag)
            for _step in range(args.max_step_per_round):
                # Pure inference: no autograd graph is needed.
                with torch.no_grad():
                    action_mean, action_logstd, _ = network(
                        policy_input(observation, goal))
                    action = select_action(action_mean, action_logstd,
                                           fctr_used)
                    action = torch.clamp(action, -1, 1)
                action = action.detach().cpu().numpy()[0]

                state_dict_new, reward, done, info = env.step(action)
                observation = state_dict_new["observation"]
                goal = state_dict_new['desired_goal']
                if plot:
                    buffer['ag'].append(state_dict_new['achieved_goal'])
                succ_in_env = np.sum(
                    np.array([i.get('is_success', 0.0) for i in info]))
                # The reward of the successful step is not accumulated
                # (same ordering as the original code).
                if succ_in_env == 1:
                    break
                reward_sum += reward
                if done:
                    break
            succ_game += succ_in_env
        if plot:
            return succ_game / eval_rollouts, reward_sum / eval_rollouts, buffer
        return succ_game / eval_rollouts, reward_sum / eval_rollouts

    def compute_cross_ent_error(batch_size, step_num):
        """Take one MSE gradient step on a batch sampled from the buffer.

        Returns the loss value, or ``None`` when the buffer partition
        ``step_num`` holds fewer than ``batch_size`` samples (or is empty).
        """
        available = ier_buffer.lenth(step_num)
        if available == 0 or batch_size > available:
            return None
        state, action, _, goal = ier_buffer.sample(batch_size, step_num)

        extended = np.concatenate((np.array(state), np.array(goal)), axis=-1)
        extended = torch.FloatTensor(extended).to(device)
        action_target = torch.FloatTensor(action).to(device)
        action_pred = network(extended)[0]

        loss = loss_func(action_pred, action_target)
        optimizer_imitation.zero_grad()
        loss.backward()
        optimizer_imitation.step()
        return loss.item()

    def test_isvalid_multistep(step_lenth, state_start_dict, env,args):
        """Check whether the current policy already reaches the goal.

        Steps ``env`` (mutating it) with the clamped mean action for at most
        ``step_lenth`` steps.

        Returns:
            1 if the achieved goal comes within ``goal_tolerance`` of
            ``state_start_dict['desired_goal']`` during those steps,
            2 otherwise (the sample is learnable).
        """
        goal = state_start_dict['desired_goal']
        if args.env_name == 'Reacher-v2':
            env.envs[0].env.env.goal = goal.reshape((-1))
        observation = state_start_dict['observation']
        for _ in range(step_lenth):
            with torch.no_grad():
                action_mean, _, _ = network(policy_input(observation, goal))
                action_mean = torch.clamp(action_mean, -1, 1)
            next_state_dict, _, _, _ = env.step(
                action_mean.detach().cpu().numpy()[0])

            reached = np.linalg.norm(
                next_state_dict['achieved_goal'] - goal) <= goal_tolerance
            if reached:
                # NOTE(review): the original tested
                # `step_i <= step_lenth - 1` here, which is always true
                # inside this loop, so its `return 0  # ok to learn` branch
                # was unreachable. Behaviour is kept; the caller still
                # counts a 0 result separately should it be reintroduced.
                return 1  # should not learn
            observation = next_state_dict['observation']
        return 2  # learnable

    Horizon_list = list(range(1, int(args.Horizon_max) + 1))
    n_horizons = len(Horizon_list)
    # When the validity test is disabled every candidate is accepted and no
    # environment snapshots are needed.
    skip_test = args.NOTTESTFUNCTION is True

    Acceptance_rate = []
    FACTOR = args.factor

    reward_record = []
    global_steps = 0
    for i_episode in range(args.num_episode):
        episodic_pass_test_num = 0
        num_steps = 0
        reward_list = []
        len_list = []
        # Per-horizon counters of test outcomes 2 / 1 / anything else.
        Ret_2 = [0] * n_horizons
        Ret_1 = [0] * n_horizons
        Ret_0 = [0] * n_horizons

        while num_steps < args.batch_size:
            '''interactions'''
            state_dict = env.reset()
            reward_sum = 0
            episode = []   # (state dict before the step, action taken)
            env_list = []  # environment snapshot taken before each step

            for t in range(args.max_step_per_round):
                # Data collection only; gradients come from the buffer.
                with torch.no_grad():
                    action_mean, action_logstd, _ = network(
                        policy_input(state_dict['observation'],
                                     state_dict['desired_goal']))
                    action, _ = network.select_action(action_mean,
                                                      action_logstd,
                                                      factor=FACTOR)
                    action = torch.clamp(action, -1., 1.)
                action = action.detach().cpu().numpy()[0]

                if Horizon_list and not skip_test:
                    env_list.append(deepcopy(env))
                state_dict_new, reward, done, info = env.step(action)
                if reward == 0:
                    reward = args.reward_pos
                reward_sum += reward

                episode.append((state_dict, action))
                if done:
                    break
                state_dict = state_dict_new

            '''start learning'''
            for ind, (start_dict, action) in enumerate(episode):
                '''supervised learning'''
                for t_ in Horizon_list:
                    # Horizons are ascending: once one runs past the end of
                    # the episode, all longer ones do too.
                    if ind + t_ >= len(episode):
                        break
                    relabelled = deepcopy(start_dict)
                    relabelled['desired_goal'] = deepcopy(
                        episode[ind + t_][0]['achieved_goal'])
                    displacement = np.linalg.norm(
                        relabelled['desired_goal']
                        - relabelled['achieved_goal'])

                    if displacement > goal_tolerance:
                        if skip_test:
                            verdict = 2
                        else:
                            snapshot = env_list[ind]
                            if n_horizons > 1:
                                # The test steps the environment it is
                                # given; copy it so that the other horizons
                                # of this index start from the same state.
                                snapshot = deepcopy(snapshot)
                            verdict = test_isvalid_multistep(
                                t_, relabelled, snapshot, args)

                        if verdict == 2:
                            ier_buffer.push(relabelled['observation'], action,
                                            '1step', t_,
                                            relabelled['desired_goal'])
                            episodic_pass_test_num += 1
                            Ret_2[t_ - 1] += 1
                        elif verdict == 1:
                            Ret_1[t_ - 1] += 1
                        else:
                            Ret_0[t_ - 1] += 1

            num_steps += (t + 1)
            global_steps += (t + 1)
            reward_list.append(reward_sum)
            len_list.append(t + 1)

        print('Return This Episode:', Ret_0, Ret_1, Ret_2,flush = FLUSH)
        # Fraction of tested candidates that were accepted, per horizon,
        # rounded to two decimals; 1e-6 avoids a division by zero.
        Acceptance_rate.append([
            round(accepted / (accepted + early + other + 1e-6) * 100.0) / 100.0
            for accepted, early, other in zip(Ret_2, Ret_1, Ret_0)
        ])

        # Bookkeeping only; not part of the return value.
        reward_record.append({
            'episode': i_episode,
            'steps': global_steps,
            'meanepreward': np.mean(reward_list),
            'meaneplen': np.mean(len_list)
        })

        '''learning'''
        n_updates = int(args.num_epoch * episodic_pass_test_num
                        / args.minibatch_size)
        for _ in range(n_updates):
            compute_cross_ent_error(args.minibatch_size, '1step')

        print('ier lenth', ier_buffer.lenth('1step'),
              ier_buffer.lenth('2step'), ier_buffer.lenth('3step'),
              ier_buffer.lenth('4step'), ier_buffer.lenth('5step'),
              ier_buffer.lenth('6step'), ier_buffer.lenth('7step'),flush = FLUSH)

        eval_0_temp, r1 = eval_policy_50(0.0, args, network, device)
        print('Eval_SucessRate:', eval_0_temp, flush = FLUSH)
        print('Eval_Reward:', r1,flush = FLUSH)
        print('Acceptance Rate ', Acceptance_rate[-1],flush = FLUSH)
        print('Traj length in this episode', Ret_2,flush = FLUSH)

        if i_episode % args.log_num_episode == 0:
            print('Finished episode: {} Eval_sr: {:.2f} ;' \
                .format(i_episode, eval_0_temp),flush = FLUSH)
            print('-----------------')
            sr1.append(eval_0_temp)
            reward1.append(r1)

    # NOTE(review): the result of this final evaluation is discarded, as in
    # the original code; the call is kept so the environment and RNG state
    # seen by the caller do not change.
    eval_policy_50(0.0, args, network, device)
    return [sr1, reward1]
