import json
import logging
import os
import pickle

from torch import autograd
from einops import rearrange, reduce, repeat
import torch
import torch.nn as nn
import torch.distributed as dist
import numpy as np

import psutil
import os
import time

# -------------------------------------------- utils -------------------------------------------
def log_memory_usage():
    """Log the resident set size (RSS) of the current process in megabytes.

    The record is emitted at INFO level on this module's logger and reads
    ``Memory usage: <YYYY-mm-dd HH:MM:SS>,<rss in MB>`` followed by a newline.
    """
    # This module never binds a global ``logger``; resolving it here keeps the
    # call from failing with a NameError.
    logger = logging.getLogger(__name__)

    process = psutil.Process(os.getpid())
    memory_usage = process.memory_info().rss / 1024 / 1024  # bytes -> MB

    # Wall-clock timestamp so successive records can be lined up.
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S")

    logger.info("Memory usage: %s,%s\n", timestamp, memory_usage)

def compute_grad2(d_out, x_in, batch_size):
    """Return the per-sample squared gradient norm of ``d_out`` w.r.t. ``x_in``.

    ``d_out`` is summed to a scalar and differentiated with respect to
    ``x_in`` with ``create_graph=True``, so the result can itself be
    back-propagated (e.g. as a gradient penalty).

    The gradient's first two dimensions are swapped before it is flattened to
    ``(batch_size, -1)`` - presumably ``x_in`` is laid out as
    ``(window, batch, ...)``; confirm against the caller.

    Returns:
        Tensor of shape ``(batch_size,)`` holding the sum of squared gradient
        entries for each sample.
    """
    grad = autograd.grad(
        outputs=d_out.sum(),
        inputs=x_in,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    assert (grad.size() == x_in.size()), (grad.size(), x_in.size())

    squared = grad.pow(2)
    per_sample = squared.transpose(0, 1).reshape(batch_size, -1)
    return per_sample.sum(1)

def extend_trajectory(terminals, elements, extend_traj_length, fill_value = None):
    """Pad every trajectory segment up to at least ``extend_traj_length`` steps.

    The step axis (dim 0) of each entry in ``elements`` is cut into segments
    that end wherever ``terminals[t]`` is entirely truthy, plus a final
    segment ending at the last step. A segment shorter than
    ``extend_traj_length`` is followed by constant rows so that it reaches
    that length; longer segments are left untouched.

    Args:
        terminals: per-step terminal flags (tensor or ``np.ndarray``). If it
            is shorter than the elements it is zero-padded at the end.
        elements: list of equally long sequences to split and pad. They are
            converted with ``torch.from_numpy`` when ``terminals`` is a NumPy
            array.
        extend_traj_length: minimum length of each output segment.
        fill_value: one padding value per element; defaults to 0 for all.

    Returns:
        A list with one concatenated sequence per element, as NumPy arrays if
        ``terminals`` was a NumPy array and as tensors otherwise.
    """
    from_numpy = isinstance(terminals, np.ndarray)
    if from_numpy:
        terminals = torch.from_numpy(terminals)
        elements = [torch.from_numpy(arr) for arr in elements]

    num_steps = len(elements[0])
    assert all([len(e) == num_steps for e in elements]), ([len(e) for e in elements])

    if fill_value is None:
        fill_value = [0] * len(elements)

    # Steps without a terminal flag count as non-terminal.
    missing = num_steps - len(terminals)
    if missing > 0:
        tail = torch.zeros(missing).to(terminals)
        terminals = torch.cat([terminals, tail], dim=0)

    chunks = [[] for _ in elements]
    start = 0
    for t in range(num_steps):
        if not (terminals[t].all() or t == num_steps - 1):
            continue
        end = t + 1
        shortfall = extend_traj_length - (end - start)
        for i, e in enumerate(elements):
            chunks[i].append(e[start:end])
            if shortfall > 0:
                # Padding rows take dtype/device from the element itself.
                pad = torch.ones(shortfall, *e.shape[1:]).to(e)
                chunks[i].append(fill_value[i] * pad)
        start = end

    out = [torch.cat(parts, dim=0) for parts in chunks]
    if from_numpy:
        out = [tensor.cpu().numpy() for tensor in out]
    return out

def sample_from_rollouts(obs, actions, on_reset, batch_size, window_size, extend_traj=False, extend_traj_length=200):
    """Sample ``batch_size`` windows of ``window_size`` consecutive steps from rollouts.

    The inputs are indexed ``(T, B, ...)``. They are first collapsed into one
    long batch-major sequence of shape ``(T * B, 1, ...)``. With
    ``extend_traj=True`` each environment column is instead passed through
    ``extend_trajectory`` (zero padding short episodes) before the columns are
    concatenated, so the sequence may be longer than ``T * B``.

    A position can be the last step of a window unless:
      * it is one of the first ``window_size - 1`` positions of the sequence,
      * a reset occurs at that position or at any of the ``window_size - 2``
        positions before it (a reset is only allowed on a window's first step),
      * it is one of two adjacent positions that are both resets.
    Window end positions are drawn uniformly, with replacement, from the
    remaining positions.

    Args:
        obs, actions: tensors indexed ``(T, B, ...)``.
        on_reset: reset flags indexed ``(T, B, ...)``; the extend path hands
            ``on_reset[1:, b]`` to ``extend_trajectory`` as terminal flags.
        batch_size: number of windows to draw.
        window_size: number of steps per window.
        extend_traj: pad episodes via ``extend_trajectory`` before sampling.
        extend_traj_length: minimum episode length when ``extend_traj`` is set.

    Returns:
        ``(obs, actions, on_reset)`` windows, each indexed
        ``(window_size, batch_size, ...)``.
    """
    on_reset = on_reset.clone()

    T, B = on_reset.shape[:2]

    if extend_traj:
        _obs, _actions, _on_reset = [], [], []
        for b in range(B):
            obs_lst, actions_lst, on_reset_lst = extend_trajectory(
                on_reset[1:, b],
                [obs[:, b], actions[:, b], on_reset[:, b]],
                extend_traj_length,
            )
            _obs.append(obs_lst)
            _actions.append(actions_lst)
            _on_reset.append(on_reset_lst)
        obs = torch.cat(_obs, dim=0).unsqueeze(1)
        actions = torch.cat(_actions, dim=0).unsqueeze(1)
        on_reset = torch.cat(_on_reset, dim=0).unsqueeze(1)
    else:
        obs = obs.transpose(0, 1).reshape(T * B, 1, *obs.shape[2:])
        actions = actions.transpose(0, 1).reshape(T * B, 1, *actions.shape[2:])
        on_reset = on_reset.transpose(0, 1).reshape(T * B, 1, *on_reset.shape[2:])

    # Both branches leave one flattened column. Re-derive the length from the
    # data: padding in the extend path can make it longer than the original
    # T * B, and the stale T/B previously broke the reshape and index stride
    # below.
    T, B = on_reset.shape[0], 1

    W = window_size
    assert T >= W
    invalid_mask = torch.zeros_like(on_reset).to(dtype=torch.float32)
    for _ in range(window_size - 1):
        # Spread invalidity one step forward per iteration.
        invalid_mask[1:] = (invalid_mask[1:] + invalid_mask[:-1] > 0).float()
        invalid_mask[on_reset == 1] = 1.
        invalid_mask[0] = 1.
    # Two adjacent resets invalidate both positions.
    invalid_mask[:-1] = (invalid_mask[:-1] + (on_reset[1:] * on_reset[:-1] > 0).float() > 0).float()
    invalid_mask[1:] = (invalid_mask[1:] + (on_reset[1:] * on_reset[:-1] > 0).float() > 0).float()
    valid_mask = (1 - invalid_mask).reshape(-1)

    sampling_prob = valid_mask / valid_mask.sum()
    sampled_indices = torch.multinomial(sampling_prob, num_samples=batch_size, replacement=True)
    # Each sampled index is the LAST step of its window; step back W - 1 steps.
    selected_indices = torch.stack(
        [sampled_indices + (- window_size + i + 1) * B for i in range(window_size)], dim=-1
    ).reshape(-1)

    def trans_fn(x):
        flat = x.reshape(T * B, *x.shape[2:])
        windows = flat[selected_indices].reshape(batch_size, window_size, *x.shape[2:])
        return windows.transpose(0, 1)

    return trans_fn(obs), trans_fn(actions), trans_fn(on_reset)


# -------------------------------------- Info Decoder -----------------------------------------

class InformationDecoder:
    """MLP that scores a latent code ``z`` against a window of observations.

    The network is three ``Linear -> LayerNorm -> ReLU`` stages of width 256
    followed by a ``DiagGaussian`` head over ``z_latent_dim`` dimensions. It
    owns an Adam optimizer with learning rate 1e-4.
    """

    def __init__(self, z_latent_dim, obs_dim, window_size, use_orthogonal=True, gain=0.01, device=torch.device('cpu')):
        from onpolicy.algorithms.utils.distributions import Bernoulli, Categorical, DiagGaussian
        self._obs_dim = obs_dim
        self._window_size = window_size

        hidden = 256
        layers = []
        in_features = window_size * obs_dim
        for _ in range(3):
            layers += [nn.Linear(in_features, hidden), nn.LayerNorm(hidden), nn.ReLU()]
            in_features = hidden
        layers.append(DiagGaussian(hidden, z_latent_dim, use_orthogonal, gain))

        self.net = nn.Sequential(*layers).to(device)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=1e-4)

    def compute(self, x, z):
        """Return the log-probability of ``z`` under the head's output for ``x``.

        ``x`` is indexed ``(window, batch, obs_dim)`` and is flattened to
        ``(batch, window * obs_dim)`` before entering the network. The head's
        output is expected to expose ``log_probs`` (presumably a distribution
        object from ``DiagGaussian``; confirm in ``onpolicy``).
        """
        flat = rearrange(x, "w b d -> b (w d)", w=self._window_size, d=self._obs_dim)
        return self.net(flat).log_probs(z)

    def parameters(self):
        """Return an iterator over the network's parameters."""
        return self.net.parameters()

def build_info_decoder(**kwargs):
    """Create an ``InformationDecoder``, forwarding every keyword argument."""
    decoder = InformationDecoder(**kwargs)
    return decoder

# -------------------------------------- discriminator -----------------------------------------
class Discriminator:
    """MLP that maps a window of observations to a score in (0, 1).

    The network is three ``Linear -> LayerNorm -> ReLU`` stages of width 256
    followed by ``Linear(256, 1)`` and a sigmoid. It owns an Adam optimizer
    with learning rate 5e-5.
    """

    def __init__(self, obs_dim, window_size, device):
        self._obs_dim = obs_dim
        self._window_size = window_size

        hidden = 256
        layers = []
        in_features = window_size * obs_dim
        for _ in range(3):
            layers += [nn.Linear(in_features, hidden), nn.LayerNorm(hidden), nn.ReLU()]
            in_features = hidden
        layers += [nn.Linear(hidden, 1), nn.Sigmoid()]

        self.net = nn.Sequential(*layers).to(device)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=5e-5)

    def compute(self, x):
        """Score ``x``, indexed ``(window, batch, obs_dim)``; returns ``(batch, 1)``."""
        flat = rearrange(x, "w b d -> b (w d)", w=self._window_size, d=self._obs_dim)
        return self.net(flat)

    def parameters(self):
        """Return an iterator over the network's parameters."""
        return self.net.parameters()

def build_discriminator(**kwargs):
    """Create a ``Discriminator``, forwarding every keyword argument."""
    discriminator = Discriminator(**kwargs)
    return discriminator

# --------------------------------------- obs func --------------------------------------------

class GAIL_ObsFunction():
    """Turn expert and policy data into windowed ``[obs, action]`` features.

    Subclasses customise the features by overriding ``_process_window_data``,
    which receives the concatenated window tensor and is the identity here.
    """

    def __init__(self, task, window_size, algorithm_name, scenario_name):
        self.task, self.window_size = task, window_size
        self.algorithm_name, self.scenario_name = algorithm_name, scenario_name

    def _process_window_data(self, x):
        """Hook applied to every window tensor; the base class returns it as is."""
        return x

    def process_real_data(self, data, device=torch.device("cpu")):
        """Concatenate expert observations and actions along the last axis.

        ``data`` is an ``(observations, actions)`` pair whose leading dimension
        must equal ``window_size``. The result is moved to ``device`` before
        being passed through ``_process_window_data``.
        """
        observations, actions = data

        assert observations.shape[0] == self.window_size
        assert actions.shape[0] == self.window_size

        joined = torch.cat([observations, actions], dim=-1)
        return self._process_window_data(joined.to(device))

    def process_fake_data(self, obs, actions, on_reset):
        """Build every length-``window_size`` window from policy rollouts.

        Args:
            obs, actions, on_reset: tensors indexed ``(T, B, ...)``.

        Returns:
            ``(windows, valid_mask)``. ``windows`` is indexed
            ``(W, (T - W + 1) * B, obs_dim + act_dim)``, where entry
            ``[w, t * B + b]`` is step ``t + w`` of column ``b``.
            ``valid_mask`` keeps the shape of ``on_reset``: a position is 0
            if it is among the first ``W - 1`` steps or if a reset occurred at
            it or within the ``W - 2`` steps before it, and 1 otherwise.
        """
        if self.algorithm_name in ["diff-gail", "diff-infogail"]:
            # actions in the buffer is reference + action + use_reference
            act_dim = actions.shape[-1] // 2
            reference_actions = actions[..., :act_dim]
            policy_actions = actions[..., act_dim:act_dim * 2]
            use_reference = actions[..., act_dim * 2:]
            actions = use_reference * reference_actions + (1 - use_reference) * policy_actions

        T, B = on_reset.shape[:2]
        W = self.window_size
        assert T >= W

        invalid_mask = torch.zeros_like(on_reset).to(dtype=torch.float32)
        for _ in range(W - 1):
            # Spread invalidity one step forward per iteration.
            invalid_mask[1:] = (invalid_mask[1:] + invalid_mask[:-1] > 0).float()
            invalid_mask[on_reset == 1] = 1.
            invalid_mask[0] = 1.
        valid_mask = 1 - invalid_mask

        num_windows = T - W + 1

        def to_windows(seq):
            # Slice k holds, for each window start, the step at offset k.
            offsets = [seq[k:k + num_windows] for k in range(W)]
            return torch.stack(offsets, dim=0).reshape(W, num_windows * B, *seq.shape[2:])

        x_fake = torch.cat([to_windows(obs), to_windows(actions)], dim=-1)
        return self._process_window_data(x_fake), valid_mask

class GAIL_XY_ObsFunction(GAIL_ObsFunction):
    # only for antmaze & maze2d
    def _process_window_data(self, x):
        """Return a copy of ``x`` with every feature after the first two zeroed."""
        assert self.scenario_name.startswith(('antmaze', 'maze2d'))
        masked = x.clone()
        # Basic slicing yields a view, so this zeroes the clone in place.
        masked[..., 2:].zero_()
        return masked

class GAIL_XY_Pose_ObsFunction(GAIL_ObsFunction):
    # only for antmaze
    def _process_window_data(self, x):
        """Return a copy of ``x`` with every feature from index 15 on zeroed."""
        assert self.scenario_name.startswith('antmaze')
        masked = x.clone()
        # Basic slicing yields a view, so this zeroes the clone in place.
        masked[..., 15:].zero_()
        return masked

class GAIL_XY_Pose_Act_ObsFunction(GAIL_ObsFunction):
    # only for antmaze
    def _process_window_data(self, x):
        """Return a copy of ``x`` with features 15..28 (inclusive) zeroed."""
        assert self.scenario_name.startswith('antmaze')
        masked = x.clone()
        # Basic slicing yields a view, so this zeroes the clone in place.
        masked[..., 15:29].zero_()
        return masked

class GAIL_NO_ACT_ObsFunction(GAIL_ObsFunction):
    """Variant that replaces the actions with zeros before the base processing."""

    def process_real_data(self, data, device=torch.device("cpu")):
        """Process expert data with the actions replaced by zeros."""
        observations, actions = data
        blank_actions = torch.zeros_like(actions)
        return super().process_real_data([observations, blank_actions], device=device)

    def process_fake_data(self, obs, actions, on_reset):
        """Process rollout data with the actions replaced by zeros."""
        blank_actions = torch.zeros_like(actions)
        return super().process_fake_data(obs, blank_actions, on_reset)

class GAIL_Pose_ObsFunction(GAIL_ObsFunction):
    # only pose, no velocity, no action
    def _process_window_data(self, x):
        """Keep only the pose features of ``x``; every other feature becomes 0.

        The pose indices and the expected width of the last axis depend on
        ``scenario_name``. Unsupported scenarios raise ``RuntimeError``.
        """
        scenario = self.scenario_name
        ant_pose_ix = list(range(2, 15))
        # (scenario matches, expected feature width, pose indices); first match wins.
        layouts = (
            (scenario.startswith("halfcheetah"), 17 + 6, [0, 1, 2, 3, 8, 9, 10, 11, 12]),
            (scenario.startswith("hopper"), 11 + 3, [0, 1, 2, 3, 4]),
            (scenario.startswith("walker2d"), 17 + 6, [0, 1, 2, 3, 4, 5, 6, 7]),
            (scenario in ['antmaze-neargoal-large-v1',], 31 + 8, ant_pose_ix),
            (scenario.startswith("antmaze-"), 29 + 8, ant_pose_ix),
        )
        for matched, expected_width, pose_ix in layouts:
            if matched:
                break
        else:
            raise RuntimeError(f"GAIL_Pose_ObsFunction doesn't support {self.scenario_name}")

        assert x.shape[-1] == expected_width
        pose_only = torch.zeros_like(x)
        pose_only[..., pose_ix] = x[..., pose_ix]
        return pose_only

class GAIL_Pose_Act_ObsFunction(GAIL_ObsFunction):
    # no velocity
    def _process_window_data(self, x):
        """Return a copy of ``x`` with the velocity features set to zero.

        The velocity indices and the expected width of the last axis depend on
        ``scenario_name``; all other features are kept.

        Raises:
            RuntimeError: if the scenario is not halfcheetah, hopper or walker2d.
        """
        if self.scenario_name.startswith("halfcheetah"):
            assert x.shape[-1] == 17 + 6
            vel_ix = [4, 5, 6, 7, 13, 14, 15, 16]
        elif self.scenario_name.startswith("hopper"):
            assert x.shape[-1] == 11 + 3
            vel_ix = [5, 6, 7, 8, 9, 10]
        elif self.scenario_name.startswith("walker2d"):
            assert x.shape[-1] == 17 + 6
            vel_ix = [8, 9, 10, 11, 12, 13, 14, 15, 16]
        else:
            raise RuntimeError(f"GAIL_Pose_Act_ObsFunction doesn't support {self.scenario_name}")
        x_ = x.clone()
        # Indexing with a list returns a copy, so the former
        # `x_[..., vel_ix].fill_(0.)` zeroed a temporary and left x_ unchanged.
        # Index assignment writes into x_ itself.
        x_[..., vel_ix] = 0.
        return x_

class Football_ObsFunction:
    """Observation-only window features for the football tasks.

    Actions are checked for shape where applicable but never enter the
    returned features.
    """

    def __init__(self, task, window_size, algorithm_name, scenario_name):
        self.task, self.window_size = task, window_size
        self.algorithm_name, self.scenario_name = algorithm_name, scenario_name

    def _process_window_data(self, x):
        """Hook applied to every window tensor; returns it unchanged."""
        return x

    def process_real_data(self, data, device=torch.device("cpu")):
        """Return the expert observations moved to ``device``.

        ``data`` is an ``(observations, actions)`` pair whose leading dimension
        must equal ``window_size``.
        """
        observations, actions = data

        assert observations.shape[0] == self.window_size
        assert actions.shape[0] == self.window_size

        return self._process_window_data(observations.to(device))

    def process_fake_data(self, obs, actions, on_reset):
        """Build every length-``window_size`` observation window from rollouts.

        Returns:
            ``(windows, valid_mask)``. ``windows`` is indexed
            ``(W, (T - W + 1) * B, ...)``, where entry ``[w, t * B + b]`` is
            step ``t + w`` of column ``b``. ``valid_mask`` keeps the shape of
            ``on_reset``: a position is 0 if it is among the first ``W - 1``
            steps or if a reset occurred at it or within the ``W - 2`` steps
            before it, and 1 otherwise.
        """
        T, B = on_reset.shape[:2]
        W = self.window_size
        assert T >= W

        invalid_mask = torch.zeros_like(on_reset).to(dtype=torch.float32)
        for _ in range(W - 1):
            # Spread invalidity one step forward per iteration.
            invalid_mask[1:] = (invalid_mask[1:] + invalid_mask[:-1] > 0).float()
            invalid_mask[on_reset == 1] = 1.
            invalid_mask[0] = 1.
        valid_mask = 1 - invalid_mask

        num_windows = T - W + 1
        # Slice k holds, for each window start, the step at offset k.
        offsets = [obs[k:k + num_windows] for k in range(W)]
        x_fake = torch.stack(offsets, dim=0).reshape(W, num_windows * B, *obs.shape[2:])

        return self._process_window_data(x_fake), valid_mask

# Registry: task identifier -> observation-function class used by build_obs_func.
_GAIL_ObsFunctionClass = {
    "default": GAIL_ObsFunction,
    "d4rl": GAIL_ObsFunction,
    "d4rl:xy": GAIL_XY_ObsFunction,
    "d4rl:xy_pose": GAIL_XY_Pose_ObsFunction,
    "d4rl:xy_pose_act": GAIL_XY_Pose_Act_ObsFunction,
    "d4rl:no_act": GAIL_NO_ACT_ObsFunction,
    "d4rl:pose": GAIL_Pose_ObsFunction,
    "d4rl:pose_act": GAIL_Pose_Act_ObsFunction,
    "football:default": Football_ObsFunction,
}

def build_obs_func(task, window_size, algorithm_name, scenario_name):
    """Instantiate the observation function registered for ``task``.

    Raises:
        KeyError: if ``task`` is not a key of ``_GAIL_ObsFunctionClass``.
    """
    obs_func_cls = _GAIL_ObsFunctionClass[task]
    return obs_func_cls(
        task=task,
        window_size=window_size,
        algorithm_name=algorithm_name,
        scenario_name=scenario_name,
    )

# ---------------------------------------- dataset ---------------------------------------------


# Football scenario name -> pickle shards that together make up its dataset.
FOOTBALL_DATASETS = {
    "academy_counterattack_easy": [
        "/workspace/sipo/datasets/fb_ca_sipo-rbf_seed6_N5_500k_1.pkl",
        "/workspace/sipo/datasets/fb_ca_sipo-rbf_seed6_N5_500k_2.pkl",
    ],
    "academy_3_vs_1_with_keeper": [
        "/workspace/sipo/datasets/fb_3v1_sipo-wd_seed4_500k_1.pkl",
        "/workspace/sipo/datasets/fb_3v1_sipo-wd_seed4_500k_2.pkl",
    ],
    "academy_corner": [
        "/workspace/sipo/datasets/fb_corner_sipo-rbf_seed5_N4_250k_1.pkl",
        "/workspace/sipo/datasets/fb_corner_sipo-rbf_seed5_N4_250k_2.pkl",
        "/workspace/sipo/datasets/fb_corner_sipo-rbf_seed5_N4_250k_3.pkl",
        "/workspace/sipo/datasets/fb_corner_sipo-rbf_seed5_N4_250k_4.pkl",
    ],
}

class GAIL_D4RL_Dataset:
    """Expert dataset loaded from a D4RL environment, sampled as fixed-length windows."""

    def __init__(self, task, scenario_name, window_size, dataset_path, **kwargs):
        import gym
        import d4rl

        env = gym.make(scenario_name)
        self._env = env
        self._dataset = env.get_dataset(h5path=dataset_path)
        self._size = self._dataset["observations"].shape[0]
        self._window_size = window_size

    def sample(self, batch_size):
        """Draw ``batch_size`` windows of ``window_size`` consecutive steps.

        Start indices are drawn uniformly with ``np.random.randint``. Once a
        window has passed a step flagged in ``"terminals"`` (or in
        ``"timeouts"``, when the dataset has that key), all later steps of
        that window are multiplied by zero.

        Returns:
            ``[observations, actions]`` as float32 tensors indexed
            ``(window_size, batch_size, dim)``.
        """
        data = self._dataset
        num_steps = data["observations"].shape[0]
        starts = np.random.randint(0, num_steps - self._window_size, (batch_size,))

        has_timeouts = "timeouts" in data
        alive = np.ones((batch_size,))  # turns 0 once an episode end was crossed
        obs_windows, act_windows = [], []
        for offset in range(self._window_size):
            rows = starts + offset
            keep = alive[:, np.newaxis]
            obs_windows.append(data["observations"][rows] * keep)
            act_windows.append(data["actions"][rows] * keep)
            not_timeout = (1 - data["timeouts"][rows]) if has_timeouts else 1
            alive = alive * not_timeout * (1 - data["terminals"][rows])

        observations = torch.from_numpy(np.stack(obs_windows, axis=0)).to(dtype=torch.float32) # w, bs, d
        actions = torch.from_numpy(np.stack(act_windows, axis=0)).to(dtype=torch.float32) # w, bs, d
        return [observations, actions]

class GAIL_Football_Dataset(GAIL_D4RL_Dataset):
    """Expert dataset for the football scenarios, built from pickled shards.

    The parent initializer is not invoked; ``_dataset`` and ``_window_size``,
    the attributes the inherited ``sample`` reads, are set here directly.
    """

    def __init__(self, task, scenario_name, window_size, dataset_path, **kwargs):
        """Load and assemble the dataset for ``scenario_name``.

        Args:
            task: accepted for interface compatibility; not used here.
            scenario_name: key into ``FOOTBALL_DATASETS`` (``KeyError`` if absent).
            window_size: window length used by ``sample``.
            dataset_path: accepted for interface compatibility; the shard
                paths come from ``FOOTBALL_DATASETS`` instead.
            **kwargs: optional ``extend_traj`` (default ``False``) and
                ``extend_traj_length`` (default 200); when ``extend_traj`` is
                set, every episode is padded via ``extend_trajectory``.
        """
        self.extend_traj = kwargs.get("extend_traj", False)
        self.extend_traj_length = kwargs.get("extend_traj_length", 200)

        loadpaths = FOOTBALL_DATASETS[scenario_name]
        shards = {}
        for loadpath in loadpaths:
            # SECURITY: unpickling can execute arbitrary code; only load
            # dataset files from a trusted source.
            with open(loadpath, 'rb') as f:
                d = pickle.load(f)
            for k, v in d.items():
                shards.setdefault(k, []).append(v)
        dataset = {k: np.concatenate(v) for k, v in shards.items()}

        observations = np.concatenate(
            [
                dataset["observations/cent_state"],
                dataset["observations/ball_owned_team"],
                dataset["observations/ball_owned_player"],
            ],
            axis=-1,
        )

        # One-hot encode the discrete action ids (19 values per entry along
        # axis 1) and flatten them into a single feature vector per step.
        actions = dataset["actions"]
        num_actions = 19
        onehot_actions = np.zeros((actions.shape[0], actions.shape[1], num_actions))
        for i in range(num_actions):
            # np.float was removed in NumPy 1.24; it aliased the builtin
            # float, i.e. float64.
            onehot_actions[:, :, i] = (actions[:, :, 0] == i).astype(np.float64)
        actions = onehot_actions.reshape(actions.shape[0], actions.shape[1] * num_actions)

        # Next observation is the following row; the last row and every
        # terminal step get zeros.
        next_observations = np.concatenate([observations[1:], np.zeros_like(observations[:1])], axis=0)
        terminals = dataset["terminals"]
        timeouts = np.zeros_like(terminals)
        next_observations[terminals > 0] = 0
        rewards = dataset["rewards"]

        if self.extend_traj:
            observations, actions, next_observations, rewards, terminals, timeouts = extend_trajectory(
                terminals,
                [observations, actions, next_observations, rewards, terminals, timeouts],
                self.extend_traj_length,
            )

        self._dataset = {
            "observations": observations,
            "actions": actions,
            "next_observations": next_observations,
            "rewards": rewards,
            "terminals": terminals,
            "timeouts": timeouts,
        }
        self._window_size = window_size
    


# Registry: task identifier -> dataset class; build_dataset falls back to "default".
_GAIL_TaskDataset = dict([
    ("d4rl", GAIL_D4RL_Dataset),
    ("default", GAIL_D4RL_Dataset),
    ("football:default", GAIL_Football_Dataset),
])

def build_dataset(task, **kwargs):
    """Instantiate the dataset class registered for ``task``.

    Unknown tasks fall back to the ``"default"`` entry of
    ``_GAIL_TaskDataset``. ``task`` and all keyword arguments are forwarded to
    the class constructor.
    """
    fallback_cls = _GAIL_TaskDataset["default"]
    dataset_cls = _GAIL_TaskDataset.get(task, fallback_cls)
    return dataset_cls(task=task, **kwargs)
