#IRL for mdenv
import numpy as np
import os
import torch
from reward_fns import within_radius, neg_distance
from mdtoyenv import MultiDiscreteToyEnv
from gym.wrappers.time_limit import TimeLimit
from util import get_goals, flatten_buffer, get_hypersphere_locations
from PPO import PPO
from PPO import RolloutBuffer
import copy
import matplotlib.pyplot as plt
from torch import optim, nn
from mdenv_generate_trajectories import add_trajectories_to_replay, flatten_buffer
import argparse
import random
import matplotlib.colors as mcolors
from ilr_utils import combine_batches
from ilr_utils import get_rewards
import wandb


def expected_return(batch, w, importance_weights = None):
    """Return the w-weighted, per-episode average return of ``batch``.

    Per-goal rewards come from ``get_rewards`` using the goal geometry of the
    module-level ``env``. They are summed within each episode (episodes are
    delimited by the batch's done indices), optionally scaled by
    ``importance_weights``, then summed over episodes. Finally they are
    combined with ``w`` via a dot product and divided by the number of done
    indices.

    Args:
        batch: 6-tuple whose last element holds the episode end indices.
        w: weight vector, presumably of length n_goals (TODO confirm).
        importance_weights: optional per-episode weights; ``None`` means
            all ones.

    Returns:
        ``(w @ summed_returns) / len(done_inds)``.
    """
    *_, episode_ends = batch
    if len(batch) != 6:
        # Same arity check as the original 6-way tuple unpacking.
        _, _, _, _, _, episode_ends = batch
    episode_ends = episode_ends.to("cpu")
    per_goal_rewards = get_rewards(batch, env.goal_locs, env.goal_rads)

    # One row per reward function, one column per episode.
    episode_sums = torch.cat(
        [calc_subsums_by_end_index(reward_row, episode_ends) for reward_row in per_goal_rewards],
        dim=0,
    )

    if importance_weights is None:
        importance_weights = torch.ones_like(episode_sums)
    weighted_sums = importance_weights.squeeze(0) * episode_sums

    total_per_goal = weighted_sums.sum(1)
    return (w @ total_per_goal) / len(episode_ends)

def calc_subsums_by_end_index(sum_tensor, end_inds):
    """Sum consecutive segments of ``sum_tensor`` that end at ``end_inds``.

    Segment ``k`` covers positions ``end_inds[k-1] + 1 .. end_inds[k]``
    (inclusive). The first segment starts at position 0.

    Adapted from
    https://discuss.pytorch.org/t/sum-over-various-subsets-of-a-tensor/31881/4

    Args:
        sum_tensor: 1-D tensor of length T, or 2-D tensor (B, T) whose rows
            are segmented independently. The input is NOT modified.
        end_inds: tensor of M inclusive end positions (flattened). Assumed
            ascending and < T.

    Returns:
        Tensor of shape (1, M) for 1-D input, or (B, M) for 2-D input, with
        the same dtype and device as ``sum_tensor``.
    """
    # Out-of-place unsqueeze: the in-place variant changed the caller's
    # tensor shape as a side effect.
    if sum_tensor.dim() == 1:
        sum_tensor = sum_tensor.unsqueeze(0)

    cumulative = sum_tensor.cumsum(1)
    # Expand the indices to every row so 2-D inputs are handled per row,
    # not just row 0.
    index = (
        end_inds.reshape(1, -1)
        .long()
        .to(cumulative.device)
        .expand(cumulative.shape[0], -1)
    )
    at_ends = cumulative.gather(1, index)  # cumulative sum up to each end

    # Prepend a zero column so the first segment starts from 0.
    # Match dtype/device and row count of the data.
    zeros = at_ends.new_zeros(at_ends.shape[0], 1)
    at_ends = torch.cat([zeros, at_ends], dim=1)

    # Each segment sum is the difference between adjacent cumulative values.
    # This is O(M), instead of building an M x M difference matrix.
    return at_ends[:, 1:] - at_ends[:, :-1]

@torch.no_grad()
def simulate_trajectory(w, replay, n_trajs=1):
    """Append ``n_trajs`` rollouts conditioned on ``w`` to ``replay``.

    The rollouts use the module-level ``env`` and ``ppo_agent``. Only the
    entries added by this call are flattened: flattening starts from the
    length ``replay.states`` had before the rollouts.
    """
    first_new_entry = len(replay.states)
    weight_array = w.detach().numpy()
    add_trajectories_to_replay(
        n_trajs,
        weight_array,
        env,
        ppo_agent,
        replay,
        verbose=False,
    )
    flatten_buffer(replay, first_new_entry)

@torch.no_grad()
def iw_for(batch, weights):
    """Per-episode importance weights for re-using ``batch`` under ``weights``.

    The batch's states are re-conditioned on ``weights`` by overwriting their
    last ``len(weights)`` entries. The module-level ``ppo_agent`` then
    re-evaluates the log-probabilities of the stored actions. Returns
    ``exp(sum log p_w - sum log p_stored)`` summed within each episode
    (episodes delimited by the batch's done indices).

    The caller's batch is not modified: the states are cloned first.

    Args:
        batch: 6-tuple ``(states, actions, logprobs, _, _, done_inds)``.
        weights: new conditioning weight vector.

    Returns:
        Tensor of importance weights on CPU. Presumably shape (1, n_episodes)
        when the log-probs are 1-D (TODO confirm against PPO.evaluate).
    """
    states, actions, logprobs, _, _, done_inds = batch
    n_rewards = len(weights)
    # Clone so the caller's batch (and any replay storage it may alias)
    # keeps the weights the episodes were actually generated with.
    states = states.clone()
    states[:, -n_rewards:] = weights
    log_ps_w, _, _ = ppo_agent.policy.evaluate(states, actions)

    done_inds = done_inds.to("cpu")
    log_ps_w = log_ps_w.to("cpu")
    logprobs = logprobs.to("cpu")
    log_ps_w_sum = calc_subsums_by_end_index(log_ps_w, done_inds)
    log_ps_wdot_sum = calc_subsums_by_end_index(logprobs, done_inds)

    # Likelihood ratio of each whole episode (product of per-step ratios).
    # Unclamped; this can overflow for long episodes.
    return (log_ps_w_sum - log_ps_wdot_sum).exp()

class WeightLearner(nn.Module):
    """Learnable reward-weight vector constrained to the probability simplex.

    The raw, unconstrained parameter ``unc_weight`` is passed through a
    softmax, so every weight vector produced is positive and sums to one.
    """

    def __init__(self, n_rews):
        super().__init__()
        self.n_weights = n_rews
        # Unconstrained logits, initialised from a standard normal.
        self.unc_weight = nn.Parameter(torch.randn(self.n_weights))

    def weight(self):
        """Return the current simplex-constrained weights."""
        return self.format_weight(self.unc_weight)

    def format_weight(self, weight: torch.Tensor):
        """Map an unconstrained vector onto the simplex (softmax over dim 0)."""
        return torch.softmax(weight, dim=0)

    def sample_weight(self):
        """Draw logits uniformly from [-1, 1) and map them onto the simplex."""
        logits = torch.rand(self.n_weights).mul(2).sub(1.0)
        return self.format_weight(logits)

def read_trajs(dir):
    """Return the paths of all regular files directly inside ``dir``.

    Sub-directories are skipped. Paths are ordered by file name. They are
    built with ``os.path.join``, so ``dir`` does not need to end with a path
    separator. Plain string concatenation silently produced nonexistent
    paths, and therefore an empty result, in that case.

    Args:
        dir: directory containing the trajectory buffer files.

    Returns:
        List of file paths, sorted by file name.
    """
    candidates = (os.path.join(dir, name) for name in sorted(os.listdir(dir)))
    return [path for path in candidates if os.path.isfile(path)]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='MDENV Inverse Reinforcement Learning')
    # Both paths are mandatory; without them the script would fail later
    # with an unhelpful TypeError.
    parser.add_argument("--traj_dir", type=str, required=True)
    parser.add_argument("--policy_path", type=str, required=True)
    parser.add_argument("--random_seed", type=int, default=123)
    parser.add_argument("--w_sim", type=int, default=1)
    parser.add_argument("--w_bs", type=int, default=8)
    parser.add_argument("--w_last_n", type=int, default=1)
    parser.add_argument("--e_bs", type=int, default=32)
    parser.add_argument("--combine_batch", type=int, default=0)
    parser.add_argument("--id", type=int, default=999)

    args = parser.parse_args()
    #wandb.config.update(args)

    seed = args.random_seed
    print(f"Setting seed to {seed}")
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Read the expert trajectory buffers (one file per expert weight vector).
    print(f"READING BUFFERS FROM {args.traj_dir}")
    buffer_names = read_trajs(args.traj_dir)
    print(f"BUFFER LEN {len(buffer_names)}")
    if not buffer_names:
        raise SystemExit(f"No trajectory buffer files found in {args.traj_dir!r}")
    N_weight_vecs = len(buffer_names)
    # The goal count is encoded in the buffer file name. After stripping
    # "weights" and ".pkl", there is one "_"-separated field per goal.
    # Use the first *file*; read_trajs skips directories.
    first_buffer_name = os.path.basename(buffer_names[0])
    n_goals = len(first_buffer_name.replace("weights", "").replace(".pkl", "").split("_"))

    # Init env (used here as the reward calculator).
    # Recover the env dimensionality from the input width of the first tensor
    # in the saved policy. That tensor is presumably the first layer's weight;
    # its input width is state_dim + n_goals (see the PPO construction below).
    # Loaded on CPU because only its shape is needed.
    policy_state = torch.load(args.policy_path, map_location="cpu")
    policy_in_dim = list(policy_state.items())[0][1].shape[1]
    n_env_dims = int((policy_in_dim - n_goals) / (1 + n_goals))
    goal_locs, goal_rads = get_hypersphere_locations(n_env_dims, n_goals)
    print(f"GOALS: {goal_locs}, RADS {goal_rads}")
    env = MultiDiscreteToyEnv(goal_locs, goal_rads)
    # state space dimension
    state_dim = env.observation_space.shape[0]

    # action space dimension
    action_space = "multidiscrete"
    action_dim = env.action_space.nvec

    # Initialize a PPO agent whose policy takes n_goals extra weight inputs.
    # The zero hyper-parameters are presumably irrelevant: this script only
    # loads the policy and uses it for rollouts/evaluation.
    ppo_agent = PPO(state_dim + n_goals, action_dim, 0.0, 0.0, 0.0, 0.0, 0.0, action_space, 0.0)
    ppo_agent.load(args.policy_path)

    max_ep_len = 150
    # `env` and `ppo_agent` are module-level globals read by
    # simulate_trajectory, iw_for and expected_return.
    env = TimeLimit(env, max_ep_len)

    N_iter = 1250
    loss = 0
    losses = torch.zeros(N_weight_vecs, N_iter)
    weights_log = torch.zeros(N_weight_vecs, N_iter, n_goals)
    # log the real and recovered weights
    real_weights = torch.zeros(N_weight_vecs, n_goals)
    rec_weights = torch.zeros(N_weight_vecs, n_goals)
    # parameters
    e_bs = args.e_bs
    w_sim = args.w_sim
    w_bs = args.w_bs
    w_last_n = bool(args.w_last_n)
    combine_batch = bool(args.combine_batch)
    print("c", combine_batch, "last_n", w_last_n)

    # Create output directories up front, so a long run does not crash with
    # FileNotFoundError when it saves its first plot or result.
    plot_dir = f"plots/{n_goals}"
    results_dir = "results/scalability"
    os.makedirs(plot_dir, exist_ok=True)
    os.makedirs(results_dir, exist_ok=True)

    # Plot colours. Indexed modulo len(cols) so more goals than colours do
    # not raise IndexError.
    #cols = list(mcolors.TABLEAU_COLORS.keys())
    cols = ['#000000', '#00FF00', '#0000FF', '#FF0000', '#01FFFE', '#FFA6FE', '#FFDB66', '#006401', '#010067', '#95003A', '#007DB5', '#FF00F6', '#FFEEE8', '#774D00', '#90FB92', '#0076FF', '#D5FF00', '#FF937E', '#6A826C', '#FF029D']
    labels = np.arange(n_goals)

    for kk, expert_buffer_path in enumerate(buffer_names):
        expert_buffer = RolloutBuffer()
        expert_buffer.load(expert_buffer_path)
        # The expert's weight vector is the last n_goals entries of its
        # first stored state.
        true_weights = expert_buffer.states[0][0, -n_goals:]
        print("REAL WEIGHTS: ", true_weights)

        wl = WeightLearner(n_goals)
        opt = optim.Adam(wl.parameters(), lr=0.01)
        simulation_buffer = RolloutBuffer()
        print(f'init weight {wl.weight()}')
        for i in range(N_iter):
            weights = wl.weight()
            if i % 10 == 0:
                print(i, weights, loss)
            # Roll out the policy conditioned on the current weight estimate.
            simulate_trajectory(weights, simulation_buffer, w_sim)
            expert_batch = expert_buffer.sample_episodes(e_bs, max_ep_len)

            simulation_batch = simulation_buffer.sample_episodes(w_bs, max_ep_len, last_n = w_last_n)
            if combine_batch:
                simulation_batch = combine_batches(simulation_batch, expert_buffer.sample_episodes(w_bs, max_ep_len))

            #importance_weights = iw_for(simulation_batch, weights)
            importance_weights = None
            batch_losses = []
            for _ in range(2):
                opt.zero_grad()
                weights = wl.weight()
                current_w_expectation = expected_return(simulation_batch, weights, importance_weights)
                expert_expectation = expected_return(expert_batch, weights)

                # Minimise (simulated return - expert return) under the current
                # weights. The clamp gives zero loss/gradient once the expert
                # already scores at least as high.
                loss = - expert_expectation + current_w_expectation
                loss = loss.clamp(0)#try to clamp to prevent overfitting?
                loss.backward()
                #wandb.log({"loss": loss.detach().item(), "expert_expectation": expert_expectation.detach().item(), "current_w_expectation": current_w_expectation.detach().item(), "weights1": weights[0].item(), "weights2": weights[1].item(), "weights3": weights[2].item()}, step=i)
                torch.nn.utils.clip_grad.clip_grad_norm_(wl.parameters(), 1.)
                batch_losses.append(loss.detach().item())
                opt.step()

            losses[kk, i] = np.mean(batch_losses)
            weights_log[kk, i] = wl.weight().detach()

        print("Expert weights:")
        print(true_weights)
        real_weights[kk] = true_weights
        print("Recovered weights:")
        print(wl.weight().detach())
        rec_weights[kk] = wl.weight().detach()

        # plot losses
        plt.plot(losses[kk])
        plt.savefig(f"{plot_dir}/{kk}_loss.pdf")
        plt.close()

        # Plot each recovered weight's trajectory against its true value.
        for ii in range(n_goals):
            colour = cols[ii % len(cols)]
            plt.plot(weights_log[kk, :, ii], label=labels[ii], c=colour, alpha=0.5)
            plt.hlines(true_weights[ii], xmin=0, xmax=N_iter, color=colour)
        plt.legend()
        print(kk)
        plt.savefig(f"{plot_dir}/{kk}.pdf")
        plt.close()

    # Save the real and recovered weights.
    save_tensor = torch.stack((real_weights, rec_weights))
    torch.save(save_tensor, f"{results_dir}/{n_goals}.pth")
    torch.save(weights_log, f"{results_dir}/{n_goals}_weightslog.pth")
    torch.save(real_weights, f"{results_dir}/{n_goals}_realweights.pth")

    print("FINISHED")

    

