import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
import random
from network import q_network, Priorobservation, generate_triple, EFE_network, Policy_network
from torch.distributions import Normal
from utils import *
from itertools import count
import shutil
import numpy as np
import argparse
import os
import pickle as pkl

# Default compute device used by the functions below; the __main__ block
# rebinds this name from the --cuda_device option before any model is built.
device = torch.device('cuda:0')

def parser_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace with the attributes ``name``, ``suffix``,
        ``cuda_device``, ``iter``, ``use_kl``, ``scaler``, ``gamma``,
        ``expert_prior``, ``expert_batch`` and ``num_sample``.

    Raises:
        SystemExit: raised by argparse when ``--name`` is missing or an
            option value is invalid.
    """
    parser = argparse.ArgumentParser(description="Active IRL on classic environment in Open AI Gym")

    # The environment name is needed to build the gym environment and every
    # input/output path, so reject a missing value here instead of failing
    # later with ``None`` baked into a path.
    parser.add_argument("-n", "--name", type=str, required=True,
                        help="name of environment")
    parser.add_argument("--suffix", type=str, 
                        help="suffix for saveing files")
    parser.add_argument("-c", "--cuda_device", type=int, default=0,
                        help="cuda device number, default 0")
    parser.add_argument("-i", "--iter", type=int, default=100, 
                        help="number of iterations")
    parser.add_argument("-u", "--use_kl",  action='store_true',
                        help="whether using kl reward or not")
    # Restricting the choices makes argparse report a bad scaler name directly.
    parser.add_argument("-s", "--scaler", type=str, default='minmax',
                        choices=['minmax', 'standard', 'identity'],
                        help="minmax or standard or identity, default minmax")
    # A real float default instead of the string '1' (argparse converted it,
    # but the literal now matches the declared type).
    parser.add_argument("-g", "--gamma", type=float, default=1.0,
                        help='discount factor')
    parser.add_argument("-e", "--expert_prior", action='store_true', 
                        help="use expert prior or global prior preference")
    parser.add_argument("-eb", "--expert_batch", action='store_true', 
                        help='using expert batch when training generative model and RL network')
    parser.add_argument("-num", "--num_sample", type=int, default=50)
    return parser.parse_args()

def optimize_expert_prior_preference(iterations):
    """Fit a fresh ``Priorobservation`` network on expert transitions.

    Every step draws 128 transitions from the global ``expert_memory`` and
    regresses element 3 of each transition from element 0, minimising
    ``mean((mu - target)**2 / out**2 + log(out**2)) / 2`` where ``mu`` and
    ``out`` are the two outputs of the network.

    Args:
        iterations: number of Adam updates to perform.

    Returns:
        The trained network, after a final call to ``.train()``.
    """
    batch_size = 128
    model = Priorobservation(OBS_DIM, 512).to(device)
    adam = torch.optim.Adam(list(model.parameters()), lr=0.001, weight_decay=1e-4)
    print("Learning Expert Prior Preference")
    for step in range(1, iterations + 1):
        transitions = expert_memory.sample(batch_size)

        obs_batch = torch.cat([t[0] for t in transitions]).to(device).float()
        target_batch = torch.cat([t[3] for t in transitions]).to(device).float()

        pred_mu, pred_scale = model(obs_batch)
        # The second output is squared wherever it is used, so its sign
        # does not affect the loss.
        nll = ((pred_mu - target_batch)**2)/(pred_scale**2) + torch.log(pred_scale**2)
        nll = nll.mean()/2

        # Progress line is rewritten in place via the carriage return.
        print("###################Iteration : {}, Loss : {}####################".format(step,
                                                                                        nll.item()), end='\r')

        adam.zero_grad()
        nll.backward()
        adam.step()
    print('')
    print("Done Expert Prior Preference")
    # NOTE(review): the network is only used for inference after this point;
    # ``.eval()`` may have been intended here - confirm before changing.
    return model.train()

def save_model(env_name, path='models/'):
    """Save the state dicts of the policy net and the generative model.

    Writes ``<env_name>_policy_net.bin``, ``<env_name>_decoder.bin``,
    ``<env_name>_encoder.bin`` and ``<env_name>_transition.bin`` into
    ``path``, reading the modules from the module-level globals.

    Args:
        env_name: prefix used for the four file names.
        path: output directory; created if it does not exist yet.
    """
    # Without this the first checkpoint of a fresh checkout fails on a
    # missing directory. An empty path means "current directory".
    if path:
        os.makedirs(path, exist_ok=True)
    modules = (
        ('policy_net', policy_net),
        ('decoder', decoder),
        ('encoder', encoder),
        ('transition', transition),
    )
    for tag, module in modules:
        torch.save(module.state_dict(), os.path.join(path, '{}_{}.bin'.format(env_name, tag)))
    
def load_best_model(env_name, path='models/'):
    """Restore the four state dicts written by ``save_model``.

    Loads ``<env_name>_policy_net.bin``, ``<env_name>_decoder.bin``,
    ``<env_name>_encoder.bin`` and ``<env_name>_transition.bin`` from
    ``path`` into the corresponding module-level networks.

    Args:
        env_name: prefix of the checkpoint file names.
        path: directory containing the checkpoints.
    """
    modules = (
        ('policy_net', policy_net),
        ('decoder', decoder),
        ('encoder', encoder),
        ('transition', transition),
    )
    for tag, module in modules:
        checkpoint = os.path.join(path, '{}_{}.bin'.format(env_name, tag))
        # map_location keeps loading working when the checkpoint was saved
        # from a different CUDA index than the one currently selected.
        module.load_state_dict(torch.load(checkpoint, map_location=device))
    
def update_target_net():
    """Copy the current weights of ``policy_net`` into ``target_net``."""
    online_weights = policy_net.state_dict()
    target_net.load_state_dict(online_weights)
    
def select_action(state, eps=0.05):
    """Pick an action epsilon-greedily with respect to ``policy_net``.

    Args:
        state: tensor fed to ``policy_net`` after being moved to ``device``
            and cast to float.
        eps: probability of sampling from ``action_space`` instead of
            taking the arg-max action.

    Returns:
        A sample from ``action_space`` or the index of the largest Q-value
        as a Python int.
    """
    state = state.to(device).float()
    explore = random.random() < eps
    if explore:
        return action_space.sample()
    with torch.no_grad():
        best = torch.argmax(policy_net(state).squeeze())
    return best.item()
        
def mean_reward(trials=100):
    """Average environment return of the greedy policy.

    Runs ``trials`` episodes on the global ``env``; every observation is
    scaled with ``scaler``, encoded with ``encoder`` and passed to
    ``select_action`` with ``eps=0`` (no exploration).

    Args:
        trials: number of evaluation episodes.

    Returns:
        Sum of the per-episode returns divided by ``trials``.
    """
    returns = []
    for _ in range(trials):
        obs = torch.from_numpy(env.reset()).reshape(1, -1)
        obs = scaler.transform(obs).float()
        episode_return = 0
        done = False
        while not done:
            with torch.no_grad():
                s, _ = encoder(obs.to(device).float())

            action = select_action(s, eps=0)
            next_obs, reward, done, _ = env.step(action)
            episode_return += reward
            # Only the latest observation is needed; the previous version
            # also appended every observation to lists that were never read.
            next_obs = torch.from_numpy(next_obs).reshape(1, -1)
            obs = scaler.transform(next_obs).float()
        returns.append(episode_return)
    return sum(returns)/trials

def generate_replay_memory(size=100000):
    """Build a ``ReplayMemory`` pre-filled with random-policy transitions.

    Steps the global ``env`` for a tenth of the buffer capacity, choosing
    actions through ``select_action`` with ``eps=1``, and stores
    ``[state, action, reward, next_state, not done]`` for each step.

    Args:
        size: capacity passed to ``ReplayMemory``.

    Returns:
        The partially filled replay memory.
    """
    def _scaled(raw_obs):
        # Raw numpy observation -> scaled float tensor with one leading row.
        as_row = torch.from_numpy(raw_obs).reshape(1, -1)
        return scaler.transform(as_row).float()

    buffer = ReplayMemory(size)
    state = _scaled(env.reset())
    warmup_steps = int(buffer.capacity/10)
    for _ in range(warmup_steps):
        action = select_action(state, eps=1)
        raw_next, reward, done, info = env.step(action)
        next_state = _scaled(raw_next)

        # Index 4 holds the "non-final" flag, i.e. the negation of ``done``.
        buffer.push([state, action, reward, next_state, not done])

        # Start a new episode after a terminal step, otherwise keep going.
        state = _scaled(env.reset()) if done else next_state
    return buffer

# Starts at 1 and is incremented once per call to active_reward.
count_active_reward = 1


def active_reward(o_prev, o, actions, scale=1, use_kl=True):
    """Compute the intrinsic reward fed to the Q-learning update.

    Everything runs under ``torch.no_grad``. Two terms are combined:

    * ``likelihood``: Gaussian log-density (up to an additive constant) of
      ``o`` under ``prior_pref(o_prev)``, summed over dimension 1;
    * ``kl``: ``KL_normal`` between the transition prediction for the taken
      action and the encoder output for ``o``, summed over dimension 1.

    The result is ``clamp(scale * (-likelihood [- kl]), -50, 50)``.

    Args:
        o_prev: batch of previous observations
            (assumed ``(batch, obs_dim)`` - TODO confirm).
        o: batch of next observations, same layout as ``o_prev``.
        actions: integer action indices
            (assumed a 1-D tensor of length ``batch`` - TODO confirm).
        scale: multiplier applied before clamping.
        use_kl: when False the KL term is left out of the reward.

    Returns:
        The reward tensor with a trailing singleton dimension added.
    """
    global count_active_reward
    with torch.no_grad():
        s_prev, _ = encoder(o_prev)
        s_mu, s_sigma = transition(s_prev)
        # Broadcast the action index over the first two axes so gather can
        # select the taken action along axis 2.
        index = actions.unsqueeze(1).unsqueeze(1).expand(*s_mu.size()[:2], 1)

        # squeeze(2) removes only the gathered axis; a bare squeeze() would
        # also drop the batch axis for a batch of one and break sum(1) below.
        s_mu = s_mu.gather(2, index).squeeze(2)
        s_sigma = s_sigma.gather(2, index).squeeze(2)

        s_vari_mu, s_vari_sigma = encoder(o)
        expert_mu, expert_sigma = prior_pref(o_prev)

        # log N(o; mu, sigma) up to a constant. The log term uses sigma**2 so
        # it matches the loss that trains the expert prior and stays finite
        # when the scale output is negative (log(sigma) would be NaN there).
        likelihood = -(o-expert_mu)**2/(2*(expert_sigma**2)) - 0.5*torch.log(expert_sigma**2)
        likelihood = likelihood.sum(1)

        kl = KL_normal(s_mu, s_vari_mu, s_sigma, s_vari_sigma).sum(1)
        # NOTE(review): the reward is the *negative* log-likelihood while the
        # Q-learner maximises it - confirm the intended sign convention.
        if use_kl:
            reward = -likelihood - kl
        else:
            reward = -likelihood

        reward = reward*scale
        # Clamp to keep extreme values out of the TD target.
        reward = torch.clamp(reward, -50, 50)

        count_active_reward += 1

    return reward.unsqueeze(1)

def optimize_rl_model(batch_size, discount_factor=1, final_mask=True, use_kl=True, expert_batch=True):
    """Run one Q-learning step on ``policy_net``.

    Samples transitions, replaces the environment reward with
    ``active_reward``, builds the target
    ``discount_factor * max_a target_net(s') + reward`` and minimises the MSE
    between it and ``policy_net(s)[action]``. The encoder is only used for
    inference here (its outputs are computed under ``torch.no_grad``).

    Args:
        batch_size: number of transitions; with ``expert_batch`` half come
            from ``memory`` and half from ``expert_memory``.
        discount_factor: multiplier applied to the bootstrapped value.
        final_mask: when True the bootstrapped value is multiplied by element
            4 of each transition (the "non-final" flag).
        use_kl: forwarded to ``active_reward``.
        expert_batch: mix expert transitions into the batch.

    Returns:
        The loss of this step as a Python float.
    """
    if expert_batch:
        half = int(batch_size/2)
        batch = memory.sample(half) + expert_memory.sample(half)
    else:
        batch = memory.sample(batch_size)

    o_prev = torch.cat([x[0] for x in batch]).to(device).float()
    actions = torch.tensor([x[1] for x in batch]).to(device)
    o = torch.cat([x[3] for x in batch]).to(device).float()
    non_final_mask = torch.tensor([x[4] for x in batch]).to(device)

    reward = active_reward(o_prev, o, actions, use_kl=use_kl)

    with torch.no_grad():
        s_prev, _ = encoder(o_prev)
        s, _ = encoder(o)

    state_action_values = policy_net(s_prev).gather(1, actions.unsqueeze(1))
    target_values = target_net(s).max(1)[0].detach()
    if final_mask:
        # Zero the bootstrapped value where the flag at index 4 is False.
        target_values = (target_values*non_final_mask).unsqueeze(1)
    else:
        # This branch used the misspelled name ``target_alues`` and raised
        # NameError whenever final_mask was False.
        target_values = target_values.unsqueeze(1)
    target_values = (discount_factor*target_values) + reward

    loss = F.mse_loss(state_action_values, target_values)

    optimizer_rl.zero_grad()
    loss.backward()
    optimizer_rl.step()
    return loss.item()

def optimize_ae_model(batch_size, expert_batch=True):
    """Run one training step of the encoder/transition/decoder model.

    The previous observation is encoded, the transition model predicts a
    Gaussian over the next latent state for the taken action, a sample from
    it is decoded, and the loss is the squared reconstruction error of the
    next observation plus ``1e-4`` times the ``KL_normal`` term of the
    predicted latent against a zero-mean, unit-scale Gaussian.

    Args:
        batch_size: number of transitions; with ``expert_batch`` half come
            from ``memory`` and half from ``expert_memory``.
        expert_batch: mix expert transitions into the batch.

    Returns:
        The loss of this step as a Python float.
    """
    if expert_batch:
        half = int(batch_size/2)
        batch = memory.sample(half) + expert_memory.sample(half)
    else:
        batch = memory.sample(batch_size)

    o_prev = torch.cat([x[0] for x in batch]).to(device).float()
    actions = torch.tensor([x[1] for x in batch]).to(device)
    o = torch.cat([x[3] for x in batch]).to(device).float()

    s_prev, _ = encoder(o_prev)

    s_mu, s_sigma = transition(s_prev)
    # Broadcast the action index over the first two axes so gather can select
    # the taken action along axis 2.
    index = actions.unsqueeze(1).unsqueeze(1).expand(*s_mu.size()[:2], 1)
    # Prior distribution of s for the taken action. squeeze(2) removes only
    # the gathered axis; a bare squeeze() would also drop a batch axis of
    # size one and break the sum(1) reductions below.
    s_mu = s_mu.gather(2, index).squeeze(2)
    s_sigma = s_sigma.gather(2, index).squeeze(2)

    s = sample_normal(s_mu, s_sigma)

    o_pred, _ = decoder(s)

    recon_loss = (o_pred-o).pow(2).sum(1)
    kl_loss = KL_normal(s_mu, torch.zeros_like(s_mu), s_sigma, torch.ones_like(s_sigma)).sum(1)

    loss = (recon_loss + (1e-4*kl_loss)).mean()

    optimizer_ae.zero_grad()
    loss.backward()
    optimizer_ae.step()
    return loss.item()

if __name__== '__main__':
    args = parser_args()
    print(args)
    # Drop a stale run directory for this environment; a missing one is fine.
    try:
        shutil.rmtree('runs/{}'.format(args.name))
    except FileNotFoundError:
        pass

    # Create the output directories up front so a missing folder cannot make
    # a checkpoint or the final result dump fail after a long training run.
    os.makedirs('models', exist_ok=True)
    os.makedirs('results', exist_ok=True)

    device = torch.device('cuda:{}'.format(args.cuda_device))
    env = gym.make(args.name)
    action_space = env.action_space
    obs_space = env.observation_space

    OBS_DIM = obs_space.shape[0]
    NUM_ACTION = action_space.n
    S_DIM = 32
    HIDDEN_DIM=128
    # Load expert demonstrations.
    # NOTE: unpickling executes arbitrary code; only load trusted demo files.
    with open('./rl_baselines_zoo/experts/{}_expert_demo.pkl'.format(args.name), 'rb') as f:
        obs_expert, actions_expert, rewards_expert, next_obs_expert, dones_expert = pkl.load(f)
    obs_expert = [torch.from_numpy(obs).float() for obs in obs_expert]
    next_obs_expert = [torch.from_numpy(obs).float() for obs in next_obs_expert]

    # Explicit check instead of ``assert`` so it also runs under ``python -O``.
    if args.scaler=='minmax':
        scaler = MinMax_Normalizer()
    elif args.scaler=='standard':
        scaler = Standardization()
    elif args.scaler=='identity':
        scaler = Identity()
    else:
        raise ValueError("unknown scaler '{}', expected minmax, standard or identity".format(args.scaler))
    scaler.fit(torch.cat(obs_expert))

    # unsqueeze(1) makes every stored observation keep a leading axis of size
    # one, like the observations pushed to the online replay memory.
    obs_expert = scaler.transform(torch.cat(obs_expert)).unsqueeze(1)
    next_obs_expert = scaler.transform(torch.cat(next_obs_expert)).unsqueeze(1)

    # The replay convention used everywhere else stores ``not done`` at index
    # 4 and optimize_rl_model reads that slot as the non-final mask, so the
    # expert done flags are inverted before being stored.
    # (assumes the fifth pickled field holds per-step done flags - TODO confirm)
    non_final_expert = [not done_flag for done_flag in dones_expert]
    expert_memory = ReplayMemory(capacity=obs_expert.size(0))
    expert_memory.memory = list(zip(obs_expert, actions_expert, rewards_expert, next_obs_expert, non_final_expert))

    class global_prior():
        """Fixed prior preference: the same (mu, sigma) for every observation."""

        def __init__(self, mu, sigma):
            self.mu = mu
            self.sigma = sigma

        def __call__(self, obs):
            # One copy of mu/sigma per row of the (2-D) observation batch.
            batch, _ = obs.shape
            mu = torch.stack([self.mu]*batch)
            sigma = torch.stack([self.sigma]*batch)
            return mu, sigma

    HISTORY = []
    for sample_idx in range(args.num_sample):
        # Learning the expert prior preference
        if args.expert_prior:
            prior_pref = optimize_expert_prior_preference(iterations=3000)
            # Sanity print: compare one expert next observation with the mean
            # predicted by the freshly trained prior.
            batch = expert_memory.sample(1)
            obs = [x[0] for x in batch]
            obs = torch.cat(obs).to(device).float()

            next_obs = [x[3] for x in batch]
            next_obs = torch.cat(next_obs).to(device).float()

            with torch.no_grad():
                mu, sigma = prior_pref(obs)

            print("next_obs", scaler.transform_inv(next_obs).cpu().numpy())
            print("expert", scaler.transform_inv(mu).cpu().numpy())
        else:
            if 'MountainCar' in args.name:
                mu = torch.tensor([0.5, 0]).to(device)
                mu = scaler.transform(mu)
                sigma = torch.tensor([0.1, 0.1]).to(device)
            elif 'CartPole' in args.name:
                # Float literals so the prior mean is never an integer tensor.
                mu = torch.tensor([0., 0., 0., 0.]).to(device)
                mu = scaler.transform(mu)
                sigma = torch.tensor([0.1, 0.1, 0.1, 0.1]).to(device)
            else:
                # Previously this fell through to a NameError on ``mu``.
                raise ValueError(
                    "no global prior preference is defined for environment '{}'; "
                    "use --expert_prior or add one".format(args.name))
            prior_pref = global_prior(mu, sigma)
        # Build the generative model and the deep Q networks
        decoder, encoder, transition = generate_triple(obs_dim=OBS_DIM, hidden_dim=HIDDEN_DIM, state_dim=S_DIM, action_num=NUM_ACTION, device=device)
        policy_net = q_network(S_DIM, HIDDEN_DIM, NUM_ACTION).to(device)
        target_net = q_network(S_DIM, HIDDEN_DIM, NUM_ACTION).to(device)

        parameters_ae = list(decoder.parameters()) + list(encoder.parameters()) + list(transition.parameters())
        optimizer_ae = torch.optim.RMSprop(parameters_ae, lr=0.001, momentum=0.95, alpha=0.95, eps=0.01)

        parameters_rl = list(policy_net.parameters())
        optimizer_rl = torch.optim.RMSprop(parameters_rl, lr=0.001, momentum=0.95, alpha=0.95, eps=0.01)

        target_net.load_state_dict(policy_net.state_dict())
        target_net.eval()
        # Pre-fill the replay memory with random-policy transitions
        memory = generate_replay_memory(size=100000)

        # Epsilon decays linearly from 1 to epsilon_end over the first 1000
        # environment steps and then stays constant.
        epsilon_end = 0.05
        epsilons = np.linspace(1, epsilon_end, 1000)
        update_count = 1
        eps_count = 0
        TARGET_UPDATE = 100
        BEST_REWARD = -99999999

        num_episodes = args.iter
        reward_record = []
        for i in range(1, num_episodes+1):
            done = False
            obs = env.reset()
            obs = torch.from_numpy(obs).reshape(1, -1)
            obs = scaler.transform(obs).float()
            while not done:
                if eps_count<len(epsilons):
                    epsilon = epsilons[eps_count]
                    eps_count+=1
                else:
                    epsilon = epsilon_end

                with torch.no_grad():
                    s, _ = encoder(obs.to(device).float())

                action = select_action(s, eps=epsilon)
                next_obs, reward, done, _  = env.step(action)
                next_obs = torch.from_numpy(next_obs).reshape(1, -1)
                next_obs = scaler.transform(next_obs).float()

                memory.push([obs,
                             action,
                             reward,
                             next_obs,
                             not done])

                # Move to the next state
                obs = next_obs

                # One optimisation step of the Q network and of the generative model
                loss_rl = optimize_rl_model(128, discount_factor=args.gamma, final_mask=True, use_kl=args.use_kl, expert_batch=args.expert_batch)
                loss_ae = optimize_ae_model(128, expert_batch=args.expert_batch)
                update_count+=1
                # Sync the target network every TARGET_UPDATE steps. The check
                # also runs on terminal steps now; the earlier ``break`` on
                # ``done`` skipped a sync that landed on the last step.
                if update_count % TARGET_UPDATE == 0:
                    update_target_net()
            # Evaluate after the first episode and then every fifth one; a
            # checkpoint is written only when the evaluation improves.
            if i%5==0 or i==1:
                mean = mean_reward(trials=5)
                reward_record.append(mean)
                if mean>BEST_REWARD:
                    BEST_REWARD = mean
                    save_model(args.name)

            print('{}th eposide Reward : {}, Best Reward : {}, loss_rl : {}, loss_ae : {}'.format(i, 
                                                                                                  mean,
                                                                                                  BEST_REWARD,
                                                                                                  loss_rl, 
                                                                                                  loss_ae), end='\r')
        HISTORY.append(reward_record)
        print('')

    # One list of evaluation means per sample run.
    with open('results/{}_{}.pkl'.format(args.name, args.suffix), 'wb') as f:
        pkl.dump(HISTORY, f)
    