import numpy as np
import os
import pickle
from PPO import RolloutBuffer
import torch
from reward_fns import RewardCalculator
from gym_derk.envs import DerkEnv
from gym.wrappers.time_limit import TimeLimit
from util import SingleAgentWrapper, DiscreteActionEnv, BotPolicy
from util import flatten_buffer
from PPO import PPO
from PPO import RolloutBuffer
import copy
import matplotlib.pyplot as plt
from torch import optim, nn
from derk_generate_trajectories import add_trajectories_to_replay, flatten_buffer
import argparse

def get_rewards(batch):
    """Recompute the per-step (damage, heal) reward components for a batch.

    Kept as a separate offline pass for compatibility with an earlier
    implementation; these values could instead be recorded during simulation.

    Args:
        batch: tuple ``(states, actions, _, _, next_states, _)``.

    Returns:
        list with two float32 tensors of length N (N = number of steps):
        the damage-to-enemy reward scaled by 1/10 and the teammate-heal reward.
    """
    states, actions, _, _, next_states, _ = batch
    states_cpu = states.to("cpu")
    actions_cpu = actions.to("cpu")
    next_cpu = next_states.to("cpu")
    calculator = RewardCalculator(1)
    n_steps = states_cpu.shape[0]
    out = np.zeros((n_steps, 2))
    for step in range(n_steps):
        s = states_cpu[[step]].numpy()
        a = [tuple(actions_cpu[step].tolist())]
        s_next = next_cpu[[step]].numpy()
        calculator.update_focus(s, a, s_next)
        # Scale damage down so its magnitude is comparable to healing.
        dmg = calculator.damage_enemy(s, a, s_next) / 10.0
        heal = calculator.heal_teammate(s, a, s_next)
        out[step] = np.column_stack([dmg, heal])
    return list(torch.from_numpy(out).float().T)

def expected_return(batch, w, importance_weights = None):
    """Estimate the average per-episode return of ``batch`` under reward weights ``w``.

    Each reward component from ``get_rewards`` is summed per episode (episode
    ends come from the batch's ``done_inds``), optionally multiplied by
    per-episode importance weights, summed over episodes and projected onto
    ``w``; the result is divided by the number of episodes.

    Args:
        batch: tuple whose last element is ``done_inds``, the end index of
            each episode.
        w: reward weights, one per reward component.
        importance_weights: optional per-episode weights, e.g. from ``iw_for``
            (presumably shape (1, n_episodes)); ``None`` means all ones.

    Returns:
        Tensor; a scalar when ``w`` is 1-D.
    """
    _, _, _, _, _, done_inds = batch
    ends = done_inds.to("cpu")
    component_rewards = get_rewards(batch)
    # One row per reward component, one column per episode.
    per_episode = torch.cat(
        [calc_subsums_by_end_index(rew, ends) for rew in component_rewards], dim=0
    )
    if importance_weights is None:
        scale = torch.ones_like(per_episode)
    else:
        scale = importance_weights
    weighted = scale.squeeze(0) * per_episode
    return (w @ weighted.sum(1)) / len(ends)

def calc_subsums_by_end_index(sum_tensor, end_inds):
    """Sum consecutive segments of ``sum_tensor`` that end at ``end_inds``.

    Segment ``k`` covers the elements after ``end_inds[k-1]`` up to and
    including ``end_inds[k]``; the first segment starts at index 0. Adapted from
    https://discuss.pytorch.org/t/sum-over-various-subsets-of-a-tensor/31881/4

    Args:
        sum_tensor: 1-D tensor of length L, or 2-D tensor of shape (R, L) whose
            rows are segmented independently. It is not modified.
        end_inds: inclusive end index of each segment (presumably ascending);
            flattened to K entries.

    Returns:
        Tensor of shape (R, K), with R = 1 for 1-D input, holding the segment
        sums. It has the dtype and device of ``sum_tensor``'s cumulative sum.
    """
    # Out-of-place: an in-place unsqueeze_ would reshape the caller's tensor.
    rows = sum_tensor.unsqueeze(0) if sum_tensor.dim() == 1 else sum_tensor
    cumulative = rows.cumsum(1)
    # Use the same end indices for every row so each row is processed, not just row 0.
    index = (
        end_inds.reshape(1, -1)
        .long()
        .to(cumulative.device)
        .expand(cumulative.shape[0], -1)
    )
    ends = cumulative.gather(1, index)
    # Prepend a zero column (matching dtype/device) so the first segment starts at 0.
    ends = torch.cat([ends.new_zeros(ends.shape[0], 1), ends], dim=1)
    # Consecutive differences of the cumulative sums at the segment ends are
    # the segment sums; this avoids building a K x K difference matrix.
    return ends[:, 1:] - ends[:, :-1]

@torch.no_grad()
def simulate_trajectory(w, replay):
    """Generate new rollouts with reward weights ``w`` and append them to ``replay``.

    Uses the module-level ``env`` and ``ppo_agent``. The buffer's length before
    the rollout is passed to ``flatten_buffer``; presumably only the entries
    from that index on are flattened.
    """
    start = len(replay.states)
    weights_np = w.detach().numpy()
    add_trajectories_to_replay(1, weights_np, env, ppo_agent, replay, verbose=False)
    flatten_buffer(replay, start)

@torch.no_grad()
def iw_for(batch, weights):
    """Per-episode importance weights for re-using ``batch`` under ``weights``.

    The reward-weight columns at the end of each state are set to ``weights``.
    The policy then re-evaluates the log-probabilities of the recorded actions.
    The result is ``exp(sum log p_w - sum log p_stored)`` per episode, clamped
    to ``[0.01, 1]``.

    Uses the module-level ``ppo_agent``.

    Args:
        batch: tuple ``(states, actions, logprobs, _, _, done_inds)``, where the
            last ``len(weights)`` state columns hold the reward weights. The
            batch's tensors are not modified.
        weights: reward weights to condition the policy on.

    Returns:
        Per-episode importance weights as produced by
        ``calc_subsums_by_end_index``; shape (1, n_episodes) for 1-D log-probs.
    """
    states, actions, logprobs, _, _, done_inds = batch
    n_rewards = len(weights)
    # Copy before overwriting. Writing into the batch's tensor in place would
    # silently change the states seen by later users of the same batch, and
    # possibly the buffer it was sampled from.
    states = states.clone()
    states[:, -n_rewards:] = weights
    log_ps_w, _, _ = ppo_agent.policy.evaluate(states, actions)  # TODO this is now not greedy

    done_inds = done_inds.to("cpu")
    log_ps_w = log_ps_w.to("cpu")
    # calc_subsums_by_end_index may reshape a 1-D input in place, so pass it a
    # private copy rather than the batch's stored log-probabilities.
    logprobs = logprobs.to("cpu").clone()
    log_ps_w_sum = calc_subsums_by_end_index(log_ps_w, done_inds)
    log_ps_wdot_sum = calc_subsums_by_end_index(logprobs, done_inds)

    # Trajectory likelihood ratio, clipped to [0.01, 1].
    importance_weights = (log_ps_w_sum - log_ps_wdot_sum).exp()
    importance_weights.clamp_(0.01, 1.)

    return importance_weights

class WeightLearner(nn.Module):
    """Learnable reward weights constrained to the probability simplex.

    ``n_rews - 1`` unconstrained logits are mapped onto ``n_rews`` non-negative
    weights that sum to one. This is done by appending a fixed zero logit and
    applying a softmax. For two rewards this is mathematically equal to
    ``(sigmoid(w), 1 - sigmoid(w))``.

    Args:
        n_rews: number of reward components.
    """

    def __init__(self, n_rews):
        super().__init__()
        self.n_weights = n_rews - 1
        # Unconstrained logits; the last reward's logit is pinned to 0.
        self.unc_weight = nn.Parameter(torch.randn(self.n_weights))

    def weight(self):
        """Return the current simplex weights, shape ``(n_rews,)``."""
        return self.format_weight(self.unc_weight)

    def format_weight(self, weight):
        """Map logits (last dim ``n_rews - 1``) to simplex weights (last dim ``n_rews``)."""
        reference = weight.new_zeros(tuple(weight.shape[:-1]) + (1,))
        return torch.cat((weight, reference), dim=-1).softmax(dim=-1)

    def sample_weight(self):
        """Draw random simplex weights from logits uniform in [-1, 1)."""
        # 2 * rand - 1 (not 2 * (rand - 1), which gives [-2, 0) and biases
        # the sample towards the last reward).
        logits = 2 * torch.rand(self.n_weights) - 1
        return self.format_weight(logits)

if __name__ == "__main__":
    import time

    # IRL for derk. For every expert trajectory file, fit reward weights by
    # minimising (simulated return - expert return), both under those weights.
    parser = argparse.ArgumentParser(description='MDENV Inverse Reinforcement Learning')
    parser.add_argument("--traj_dir", default="trajectories/derk", type=str)
    parser.add_argument("--policy_path", type=str)
    parser.add_argument("--random_seed", type=int, default=123)

    args = parser.parse_args()

    # Apply the seed to both RNGs the script draws from (numpy for the bot
    # weights, torch for weight initialisation/sampling).
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    # Read the trajectories. os.path.join supplies the separator that plain
    # concatenation dropped (the default "trajectories/derk" has no trailing "/").
    traj_dir = args.traj_dir
    buffer_paths = [os.path.join(traj_dir, ff) for ff in sorted(os.listdir(traj_dir))]

    # Output location for the logs; create it so torch.save cannot fail on it.
    results_dir = os.path.join("results", "derk")
    os.makedirs(results_dir, exist_ok=True)

    # All environment reward terms are zeroed; the IRL losses use get_rewards() instead.
    base_reward_fn = {
        "killEnemyStatue": 0,
        "killEnemyUnit": 0,
        "damageEnemyUnit": 0,
        "healTeammate1": 0,
        "healTeammate2": 0,
        "friendlyFire": 0,
        "statueDamageTaken": 0,
        "fallDamageTaken": 0,
        "healEnemy": 0,
        "timeScaling": 0
    }
    # random seed set to derk_appinstance.py create_session() -function "pWkn91perNetBJQ3ymNc9"
    n_arenas = 32
    env = DerkEnv(n_arenas=n_arenas, turbo_mode=True, reward_function=base_reward_fn, session_args=dict(interleaved=False))

    # Split each of the first three action dimensions into d_steps [low, high)
    # bins. The last two dimensions are passed as None.
    discretization = [
        [env.action_space[0].low.item(), env.action_space[0].high.item()],
        [env.action_space[1].low.item(), env.action_space[1].high.item()],
        [env.action_space[2].low.item(), env.action_space[2].high.item()],
        None,
        None,
    ]
    d_steps = 5
    for di, interval in enumerate(discretization):
        if interval is not None:
            disc = torch.linspace(interval[0], interval[1], d_steps + 1)
            disc = disc.repeat_interleave(2)
            disc = disc[1:-1].reshape(-1, 2)
            discretization[di] = disc

    env = DiscreteActionEnv(env, discretization)

    # state space dimension
    state_dim = env.observation_space.shape[0]

    # action space dimension
    action_space = "multidiscrete"
    action_dim = env.action_space.nvec

    # Two reward components: damage dealt and healing (see get_rewards).
    n_weights = 2
    alphas = np.ones(n_weights)
    bot_weights = np.random.dirichlet(alphas, 1)[0]
    # The policy input is the state with the reward weights appended.
    ppo_agent = PPO(state_dim + n_weights, action_dim, 0.0, 0.0, 0.0, 0.0, 0.0, action_space, 0.0)
    ppo_agent.load(args.policy_path)
    # Bot policy network. Alternatives: a static copy
    # (copy.deepcopy(ppo_agent.policy.actor)), a separately loaded PPO actor,
    # or ppo_agent.policy.actor itself for exactly the same policy.
    bot_net = None
    bp = BotPolicy(env=env, bot_net=bot_net, weights=bot_weights)
    env = SingleAgentWrapper(env, bp)
    max_ep_len = 150
    env = TimeLimit(env, max_ep_len)

    N_replicates = len(buffer_paths)
    N_iter = 80
    bs = 32  # expert episodes sampled per optimisation step
    losses = torch.zeros(N_replicates, N_iter)
    weights_log = torch.zeros(N_replicates, N_iter, n_weights)
    # log the real and recovered weights
    real_weights = torch.zeros(N_replicates, n_weights)
    rec_weights = torch.zeros(N_replicates, n_weights)

    for kk, expert_path in enumerate(buffer_paths):
        expert_buffer = RolloutBuffer()
        expert_buffer.load(expert_path)
        # The expert's reward weights are stored in the last state columns.
        print("REAL WEIGHTS: ", expert_buffer.states[0][0, -n_weights:])
        wl = WeightLearner(n_weights)
        opt = optim.Adam(wl.parameters(), lr=0.1)
        simulation_buffer = RolloutBuffer()
        print(f'{kk} iter, init weight {wl.weight()}')
        startt = time.time()
        for i in range(N_iter):

            def closure():
                opt.zero_grad()

                weights = wl.weight()
                simulate_trajectory(weights, simulation_buffer)
                expert_batch = expert_buffer.sample_episodes(bs, max_ep_len)
                simulation_batch = simulation_buffer.sample_episodes(n_arenas, max_ep_len, last_n=True)

                importance_weights = iw_for(simulation_batch, weights)

                current_w_expectation = expected_return(simulation_batch, weights, importance_weights)
                expert_expectation = expected_return(expert_batch, weights)

                loss = -expert_expectation + current_w_expectation
                loss.backward()
                torch.nn.utils.clip_grad.clip_grad_norm_(wl.parameters(), 1.)
                return loss

            loss = opt.step(closure)
            losses[kk, i] = loss.detach().item()
            weights_log[kk, i] = wl.weight().detach()

        print("Time ", time.time() - startt)
        print("Expert weights:")
        print(expert_buffer.states[0][0, -n_weights:])
        real_weights[kk] = expert_buffer.states[0][0, -n_weights:]
        print("Recovered weights:")
        print(wl.weight().detach())
        rec_weights[kk] = wl.weight().detach()
        expert_buffer.clear()

        # Checkpoint every log after each replicate so partial runs are kept.
        torch.save(weights_log, os.path.join(results_dir, "weights_log.pth"))
        torch.save(real_weights, os.path.join(results_dir, "real_weights.pth"))
        torch.save(rec_weights, os.path.join(results_dir, "rec_weights.pth"))
        torch.save(losses, os.path.join(results_dir, "losses.pth"))

