#AIRL for mdenv
import numpy as np
import os
import torch
from reward_fns import within_radius, neg_distance
from mdtoyenv import MultiDiscreteToyEnv
from gym.wrappers.time_limit import TimeLimit
from util import flatten_buffer, get_hypersphere_locations
from PPO import PPO
from PPO import RolloutBuffer
import copy
from torch import optim, nn
from mdenv_generate_trajectories_airl import add_trajectories_to_replay, flatten_buffer
import argparse
import random
from ilr_utils import combine_batches, get_rewards
import wandb
import time


@torch.no_grad()
def simulate_trajectory(rew_fn, replay, n_trajs=1, cat_state_weights=True, *, environment=None, agent=None):
    """Roll out ``n_trajs`` episodes with the current policy and append them to ``replay``.

    Rewards for the rollouts come from ``rew_fn``. After the rollout,
    ``flatten_buffer`` is called on ``replay`` with the buffer length recorded
    before the new trajectories were added, i.e. on the freshly appended part.
    Runs under ``torch.no_grad()``.

    Args:
        rew_fn: reward callable forwarded to ``add_trajectories_to_replay``.
        replay: buffer the trajectories are appended to (must expose ``states``).
        n_trajs: number of trajectories to simulate.
        cat_state_weights: forwarded unchanged to ``add_trajectories_to_replay``.
        environment: environment to roll out in. Defaults to the module-level
            ``env`` created by this file's ``__main__`` section (looked up at
            call time, as before).
        agent: agent whose policy acts. Defaults to the module-level
            ``ppo_agent`` created by this file's ``__main__`` section.

    Returns:
        Whatever ``add_trajectories_to_replay`` returns; the training loop
        unpacks it as (states, actions, log_probs, rewards, is_terminals,
        next_states).
    """
    # Fall back to the script globals so existing call sites keep working.
    if environment is None:
        environment = env
    if agent is None:
        agent = ppo_agent
    start_ind = len(replay.states)
    sim_trajs = add_trajectories_to_replay(n_trajs, rew_fn, environment, agent, replay, verbose=False, cat_state_weights=cat_state_weights)
    flatten_buffer(replay, start_ind)
    return sim_trajs


class AIRL(nn.Module):
    """AIRL discriminator potential f(s, s') = g_theta(s) + gamma * h_phi(s') - h_phi(s).

    g_theta and h_phi are as described in Fu et al. (Adversarial Inverse
    Reinforcement Learning). ``h_phi`` is always a small ReLU MLP over the full
    input. ``g_theta`` is either

    * linear (``linear_g=True``): a softmax-weighted sum of per-goal rewards
      ``neg_distance(s[:, :n_env_dims], goal, rad)``, with one learnable logit
      per goal stored in ``self.weights`` (shape ``(1, n_goals)``); or
    * deep (``linear_g=False``): a ReLU MLP over the full input.

    Args:
        input_shape: number of input features fed to the MLPs.
        linear_g: choose the linear (goal-weighted) or the deep g_theta.
        goal_locs: per-goal centres, each reshaped to ``(1, n_env_dims)``
            (linear g only).
        goal_rads: per-goal radii, each reshaped to ``(1, 1)`` (linear g only).
        n_goals: number of goals; must be given if and only if ``linear_g``.
        n_env_dims: number of leading state columns compared against the goal
            locations (linear g only).
        gamma: discount used in the shaping term. Defaults to 0.99.

    Raises:
        ValueError: if ``n_goals`` is given without ``linear_g`` or vice versa.
    """

    @staticmethod
    def _mlp(in_features, hidden=32):
        """Build the 3-hidden-layer ReLU MLP (scalar output) used for h_phi and the deep g_theta."""
        return nn.Sequential(nn.Linear(in_features, hidden),
                             nn.ReLU(),
                             nn.Linear(hidden, hidden),
                             nn.ReLU(),
                             nn.Linear(hidden, hidden),
                             nn.ReLU(),
                             nn.Linear(hidden, 1))

    def __init__(self, input_shape, linear_g, goal_locs=None, goal_rads=None, n_goals=None, n_env_dims=None, gamma=0.99):
        super().__init__()
        # n_goals is required exactly when the linear g_theta is used. Raised
        # explicitly (not asserted) so the check survives `python -O`.
        if bool(linear_g) != (n_goals is not None):
            raise ValueError("n_goals must be provided if and only if linear_g is true")
        self.input_shape = input_shape
        self.gamma = gamma
        self.linear_g = linear_g

        # g_theta and h_phi are as described in Fu et al.
        if self.linear_g:
            self.goal_rads = goal_rads
            self.goal_locs = goal_locs
            self.n_env_dims = n_env_dims
            print("LINEAR G")
            # One learnable logit per goal. Registered before h_phi so it stays the
            # first entry of named_parameters() (code may index it positionally).
            self.weights = nn.Parameter(torch.randn((1, n_goals)))
            self.rew_fn = neg_distance

            def get_linear_rewards(s):
                env_part = s[:, :self.n_env_dims]
                sub_reward = [
                    self.rew_fn(env_part, goal.reshape((1, self.n_env_dims)), rad.reshape((1, 1)))
                    for goal, rad in zip(self.goal_locs, self.goal_rads)
                ]
                # np.column_stack puts one goal per column.
                sub_reward = torch.from_numpy(np.column_stack(sub_reward)).float()
                return (self.weights.softmax(-1) * sub_reward).sum(-1, keepdim=True)

            self.g_theta = get_linear_rewards
        else:
            print("Deep G")
            self.g_theta = self._mlp(self.input_shape)

        self.h_phi = self._mlp(self.input_shape)

    def forward(self, s, sdot):
        """Return f(s, s') = g_theta(s) + gamma * h_phi(s') - h_phi(s).

        For 2-D inputs of shape (batch, input_shape) the result has shape (batch, 1).
        """
        return self.g_theta(s) + self.gamma * self.h_phi(sdot) - self.h_phi(s)


def read_trajs(dir):
    """Return the paths of all regular files directly inside ``dir``, sorted by name.

    Sub-directories (and anything else that is not a regular file) are skipped.
    Paths are built with ``os.path.join``, so ``dir`` may be given with or
    without a trailing separator.

    Args:
        dir: directory containing the saved trajectory buffers. The name
            shadows the ``dir`` builtin but is kept so keyword callers still work.

    Returns:
        list[str]: full paths, in ``sorted(os.listdir(dir))`` order.
    """
    candidates = (os.path.join(dir, name) for name in sorted(os.listdir(dir)))
    return [path for path in candidates if os.path.isfile(path)]

def get_reward_f(disc, policy):
    """Build the AIRL reward callable ``(s, a, sdot) -> reward``.

    The reward is the discriminator logit ``f(s, s') - log pi(a | s)``. With
    ``D = exp(f) / (exp(f) + pi(a | s))`` this equals ``log D - log(1 - D)``
    (the original AIRL formulation) but avoids the exp/log round trip, which
    is numerically more stable.

    Args:
        disc: callable ``disc(s, sdot)``; a trailing singleton dim of its
            output is squeezed away.
        policy: object whose ``evaluate(s, a)`` returns a 3-tuple whose first
            element is the log-probability of ``a`` under the policy.

    Returns:
        A function ``reward(s, a, sdot)``.
    """
    def reward(s, a, sdot):
        logit_f = disc(s, sdot).squeeze(-1)
        log_pi, _, _ = policy.evaluate(s, a)
        return logit_f - log_pi

    return reward

if __name__ == "__main__":
    # NOTE: this section intentionally stays at module level: simulate_trajectory
    # reads the module globals `env` and `ppo_agent` defined here.

    # ---- command-line configuration ----
    parser = argparse.ArgumentParser(description='MDENV Inverse Reinforcement Learning')
    parser.add_argument("--id", type=int, default=999)
    # Directory holding the saved expert RolloutBuffer files.
    parser.add_argument("--traj_dir", type=str, required=True)
    parser.add_argument("--warmstart", type=int, default=0)
    parser.add_argument("--random_seed", type=int, default=123)
    parser.add_argument("--k_epochs", default=3, type=int)
    parser.add_argument("--eps_clip", default=0.2, type=float)
    parser.add_argument("--lr_actor", default=0.0003, type=float)
    parser.add_argument("--lr_critic", default=0.001, type=float)
    parser.add_argument("--lr_disc", default=0.001, type=float)
    parser.add_argument("--gamma", default=0.99, type=float)
    parser.add_argument("--lamb", default=0.95, type=float)
    parser.add_argument("--bs_policy", default=150, type=int)
    parser.add_argument("--entropy_coef", default=0.01, type=float)
    parser.add_argument("--linear_g", default=1, type=int)
    parser.add_argument("--expert_data_ind", default=0, type=int)
    parser.add_argument("--bs_sim", default=8*5, type=int)
    parser.add_argument("--bs_expert", default=32, type=int)
    parser.add_argument("--n_sim", default=8, type=int)
    parser.add_argument("--disc_weight_decay", default=0.01, type=float)
    parser.add_argument("--wandb_name", default="airl-default", type=str)

    args = parser.parse_args()
    seed = args.random_seed
    print(f"Setting seed to {seed}")
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    linear_g = bool(args.linear_g)
    warmstart = bool(args.warmstart)  # True if a pretrained policy/discriminator is loaded
    # Checkpoint directory; warmstart takes precedence over linear_g.
    airl_modeldir = f"airl_models2/id{args.id}"
    if linear_g:
        airl_modeldir = f"airl_models_linear2/id{args.id}"
    if warmstart:
        airl_modeldir = f"airl_models_warmstart2/id{args.id}"

    # ---- expert demonstrations ----
    buffer_names = read_trajs(args.traj_dir)
    if not buffer_names:
        parser.error(f"no trajectory files found in {args.traj_dir}")
    kk = args.expert_data_ind
    if not -len(buffer_names) <= kk < len(buffer_names):
        parser.error(f"--expert_data_ind {kk} is out of range for {len(buffer_names)} buffers")
    expert_buffer_path = buffer_names[kk]
    print(f"READING BUFFERS FROM {expert_buffer_path}")
    print(f"BUFFER LEN {len(buffer_names)}")
    print("SAVING MODELS TO", airl_modeldir)
    if not warmstart:
        # Checkpoints are written here below; torch.save does not create directories.
        os.makedirs(airl_modeldir, exist_ok=True)
    # The goal count is read from the first buffer's file name: after stripping
    # "weights" and ".pkl", the "_"-separated fields are counted. buffer_names only
    # contains regular files, so a sub-directory name can never be parsed here.
    first_name = os.path.basename(buffer_names[0])
    n_goals = len(first_name.replace("weights", "").replace(".pkl", "").split("_"))
    # Load the chosen expert buffer and derive the env dimensionality from it.
    expert_buffer = RolloutBuffer()
    expert_buffer.load(expert_buffer_path)
    # The last n_goals columns of the first expert state are taken as the ground-truth weights.
    real_weights = expert_buffer.states[0][0, -n_goals:]
    n_env_dims = int( (expert_buffer.states[0].shape[1] - n_goals) / (1 + n_goals) )
    env_name = "mdenv"
    n_arenas = 1

    # ---- environment ----
    goal_locs, goal_rads = get_hypersphere_locations(n_env_dims, n_goals)
    print(f"GOALS: {goal_locs}, RADS {goal_rads}")
    env = MultiDiscreteToyEnv(goal_locs, goal_rads)
    state_dim = env.observation_space.shape[0]
    action_space = "multidiscrete"
    action_dim = env.action_space.nvec

    max_ep_len = 150
    env = TimeLimit(env, max_ep_len)

    N_iter = 10000
    # Per-iteration softmax of the learned goal weights (only filled when linear_g).
    weights_log = torch.zeros(N_iter, n_goals)

    wandb.init(project=f"{args.wandb_name}")
    wandb.config.update(args)
    wandb.config.update({"expert_data_ind": kk})

    print("REAL WEIGHTS: ", real_weights)
    wandb.config.update({"real_weights": real_weights.tolist()})

    simulation_buffer = RolloutBuffer()

    # ---- discriminator ----
    if linear_g:
        disc_model = AIRL(state_dim, linear_g = True, n_goals = n_goals, goal_locs=goal_locs, goal_rads=goal_rads, n_env_dims=n_env_dims)
    else:
        disc_model = AIRL(state_dim, linear_g = False)

    disc_loss = nn.BCEWithLogitsLoss()
    opt = optim.AdamW(disc_model.parameters(), lr=args.lr_disc, weight_decay=args.disc_weight_decay, betas=[0.5, 0.999])
    # ---- policy ----
    ppo_agent = PPO(state_dim, action_dim, args.lr_actor, args.lr_critic, args.gamma, args.k_epochs, args.eps_clip, action_space, 0.0, args.lamb, args.bs_policy, args.entropy_coef)

    if warmstart:
        # NOTE(review): checkpoint id/index are hard-coded; adjust before warm-starting other runs.
        airl_ind = 250
        ppo_path = f"airl_models_linear2/id10/9_ppo_agent{airl_ind}.pth"
        disc_path = f"airl_models_linear2/id10/9_disc{airl_ind}.pth"
        print(f"loading airl policy (ppo) from {ppo_path} \n .. and disc from {disc_path}")
        # Parameter-norm sums before/after loading confirm the weights actually changed.
        print("ppo before load", sum(pp.norm().detach() for pp in ppo_agent.policy.parameters()))
        ppo_agent.load(ppo_path)
        print("ppo after load", sum(pp.norm().detach() for pp in ppo_agent.policy.parameters()))
        print("disc before load", sum(pp.norm().detach() for pp in disc_model.parameters()))
        disc_model.load_state_dict(torch.load(disc_path))
        print("disc after load", sum(pp.norm().detach() for pp in disc_model.parameters()))

    # ---- adversarial training loop ----
    for i in range(N_iter):
        if i % 10 == 0:
            print(i)

        # Roll out the current policy under the current AIRL reward.
        reward_function = get_reward_f(disc_model, ppo_agent.policy)
        policy_states, policy_actions, policy_log_probs, policy_rewards, policy_is_terminals, policy_next_states = simulate_trajectory(reward_function, simulation_buffer, args.n_sim, cat_state_weights = False)

        # Give the PPO agent its own copy of the newest rollouts.
        policy_buffer = RolloutBuffer()
        policy_buffer.states = copy.deepcopy(policy_states)
        policy_buffer.actions = copy.deepcopy(policy_actions)
        policy_buffer.logprobs = copy.deepcopy(policy_log_probs)
        policy_buffer.rewards = copy.deepcopy(policy_rewards)
        policy_buffer.is_terminals = copy.deepcopy(policy_is_terminals)
        policy_buffer.next_states = copy.deepcopy(policy_next_states)
        ppo_agent.buffer = policy_buffer

        # Discriminator batches: simulated episodes (sampled with last_n=True) and
        # at most as many expert episodes as there are simulated ones.
        simulation_batch = simulation_buffer.sample_episodes(args.bs_sim, max_ep_len, last_n = True)
        n_simulation = simulation_batch[0].shape[0]
        expert_batch = expert_buffer.sample_episodes(min(int(n_simulation/max_ep_len), args.bs_expert), max_ep_len)
        n_expert = expert_batch[0].shape[0]

        if len(simulation_buffer.states) > 3000 * max_ep_len:  # keep the simulation buffer from growing without bound
            simulation_buffer.remove_n_last(1500 * max_ep_len)

        # Discriminator step on expert data (label 1). Expert states carry the
        # n_goals weight columns at the end; strip them so the discriminator sees the
        # same features as the simulated states (rolled out with cat_state_weights=False).
        disc_expert_states = expert_batch[0][:,:-n_goals]
        disc_expert_actions = expert_batch[1]
        disc_expert_statesdot = expert_batch[4][:, :-n_goals]
        disc_expert_y = torch.ones(n_expert)
        opt.zero_grad()
        f_X = disc_model(disc_expert_states, disc_expert_statesdot).squeeze(1)
        log_ps, _, _ = ppo_agent.policy.evaluate(disc_expert_states, disc_expert_actions)
        log_ps = log_ps.detach()  # only the discriminator is trained here
        disc_logits = f_X - log_ps
        loss = disc_loss(disc_logits, disc_expert_y)
        loss.backward()
        opt.step()
        disc_expert_acc = (disc_logits.sigmoid() > 0.5).float().mean()

        # Discriminator step on simulated data (label 0).
        disc_sim_states = simulation_batch[0]
        disc_sim_actions = simulation_batch[1]
        disc_sim_statesdot = simulation_batch[4]
        disc_sim_y = torch.zeros(n_simulation)
        opt.zero_grad()
        f_X = disc_model(disc_sim_states, disc_sim_statesdot).squeeze(1)
        log_ps, _, _ = ppo_agent.policy.evaluate(disc_sim_states, disc_sim_actions)
        log_ps = log_ps.detach()
        disc_logits = f_X - log_ps
        loss = disc_loss(disc_logits, disc_sim_y)
        loss.backward()
        opt.step()
        disc_sim_acc = (disc_logits.sigmoid() < 0.5).float().mean()

        if linear_g:
            # Detach: copying a grad-requiring tensor into weights_log would otherwise
            # pull weights_log into the autograd graph and grow it every iteration.
            weights_softmax = disc_model.weights.detach().softmax(-1)
            weights_log[i] = weights_softmax
            # One key per goal ("weights1", "weights2", ...), for any n_goals.
            wandb.log({f"weights{j + 1}": weights_softmax[0, j].item() for j in range(n_goals)}, step=i)

        if (i % 250 == 0) and (not warmstart):
            ppo_agent.save(f"{airl_modeldir}/{kk}_ppo_agent{i}.pth")
            torch.save(disc_model.state_dict(), f"{airl_modeldir}/{kk}_disc{i}.pth")

        # Policy step on the rollouts handed over above.
        ppo_stats = ppo_agent.update(empty_buffer=True)
        # "dl" is the loss of the simulated-data discriminator step.
        wandb.log({"dl": loss.detach().item(), "dacc_expert": disc_expert_acc, "dacc_sim": disc_sim_acc, "ppo_target": ppo_stats[0], "ppo_clip_frac":ppo_stats[5]}, step=i)

    # ---- save the real and recovered goal weights ----
    resdir = "airl_results_warmstart" if warmstart else "airl_results"
    os.makedirs(f"results/{resdir}", exist_ok=True)
    rec_weights = None
    if linear_g:
        # Only the linear g_theta has goal weights; with the deep g_theta the first
        # parameter is an MLP weight matrix and cannot be compared to real_weights.
        rec_weights = disc_model.weights.detach().softmax(-1)
        save_tensor = torch.stack((real_weights, rec_weights.squeeze(0)))
        torch.save(save_tensor, f"results/{resdir}/{kk}_{n_goals}.pth")
        torch.save(weights_log, f"results/{resdir}/{kk}_{n_goals}_weightslog.pth")
    torch.save(real_weights, f"results/{resdir}/{kk}_{n_goals}_realweights.pth")

    print("FINISHED")
    print("real_weights", real_weights)
    if rec_weights is not None:
        print("rec_weights", rec_weights)

    

