import os
import pickle
# import pickle5 as pickle
import random
import warnings
from distutils.util import strtobool
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from environments.parallel_envs import make_vec_envs

class DeviceManager:
    """Keeps track of the torch device selected for this process.

    A device is chosen through :meth:`set_device`. If :meth:`get_device` is
    called before any device was chosen, GPU index 0 is requested first.
    """

    def __init__(self):
        # ``None`` marks "no device selected yet".
        self._device = None

    def set_device(self, gpu_id):
        """Select ``cuda:<gpu_id>`` if CUDA is available and ``gpu_id >= 0``, else the CPU.

        The chosen device is stored on the instance and printed.
        """
        cuda_requested = torch.cuda.is_available() and gpu_id >= 0
        if not cuda_requested:
            self._device = torch.device("cpu")
        else:
            # Also make the GPU the current CUDA device for torch.
            torch.cuda.set_device(gpu_id)
            self._device = torch.device(f"cuda:{gpu_id}")
        print(f"Device set to: {self._device}")

    def get_device(self):
        """Return the selected device, selecting GPU 0 (or CPU fallback) on first use."""
        if self._device is not None:
            return self._device
        self.set_device(0)
        return self._device

# Module-level DeviceManager instance used by the set_device()/get_device() helpers below.
device_manager = DeviceManager()

# Use these functions in place of the global device variable
def set_device(gpu_id):
    """Select the device on the module-level ``device_manager``.

    Thin wrapper that forwards ``gpu_id`` to ``device_manager.set_device``.
    """
    device_manager.set_device(gpu_id)

def get_device():
    """Return the device held by the module-level ``device_manager``.

    Thin wrapper around ``device_manager.get_device``.
    """
    return device_manager.get_device()

# def save_models(args, logger, policy, vae, envs, iter_idx):
#     # TODO: save parameters, not entire model
#
#     save_path = os.path.join(logger.full_output_folder, 'models')
#     if not os.path.exists(save_path):
#         os.mkdir(save_path)
#     try:
#         torch.save(policy.actor_critic, os.path.join(save_path, "policy{0}.pt".format(iter_idx)))
#     except AttributeError:
#         torch.save(policy.policy, os.path.join(save_path, "policy{0}.pt".format(iter_idx)))
#     torch.save(vae.encoder, os.path.join(save_path, "encoder{0}.pt".format(iter_idx)))
#     if vae.state_decoder is not None:
#         torch.save(vae.state_decoder, os.path.join(save_path, "state_decoder{0}.pt".format(iter_idx)))
#     if vae.reward_decoder is not None:
#         torch.save(vae.reward_decoder,
#                    os.path.join(save_path, "reward_decoder{0}.pt".format(iter_idx)))
#     if vae.task_decoder is not None:
#         torch.save(vae.task_decoder, os.path.join(save_path, "task_decoder{0}.pt".format(iter_idx)))
#
#     # save normalisation params of envs
#     if args.norm_rew_for_policy:
#         rew_rms = envs.venv.ret_rms
#         save_obj(rew_rms, save_path, "env_rew_rms{0}.pkl".format(iter_idx))
#     if args.norm_obs_for_policy:
#         obs_rms = envs.venv.obs_rms
#         save_obj(obs_rms, save_path, "env_obs_rms{0}.pkl".format(iter_idx))


def reset_env(env, args, indices=None, state=None, pretrain_reward=None):
    """Reset all environments, or only those listed in ``indices``.

    ``env`` can be many environments or just one.

    Args:
        env: environment exposing ``reset()``, ``reset(index=i)``,
            ``get_belief()`` and ``get_task()``.
        args: config providing ``num_processes``, ``pass_belief_to_policy``
            and ``pass_task_to_policy``.
        indices: environment indices to reset; ``None`` (or one index per
            process) resets everything.
        state: current state; required for a partial reset and updated
            in place at the given indices.
        pretrain_reward: optional object whose reward function is refreshed
            alongside each reset.

    Returns:
        ``(state, belief, task)``; ``belief`` / ``task`` are ``None`` unless
        the corresponding ``args`` flag is set.
    """
    reset_everything = (indices is None) or (len(indices) == args.num_processes)

    if reset_everything:
        state = env.reset().float().to(get_device())
        if pretrain_reward is not None:
            pretrain_reward.update_reward_fn()
    else:
        # A partial reset writes into the caller's state, so it must exist.
        assert state is not None
        for env_idx in indices:
            state[env_idx] = env.reset(index=env_idx)
            if pretrain_reward is not None:
                pretrain_reward.update_reward_fn_index(env_idx)

    belief = None
    if args.pass_belief_to_policy:
        belief = torch.from_numpy(env.get_belief()).float().to(get_device())

    task = None
    if args.pass_task_to_policy:
        task = torch.from_numpy(env.get_task()).float().to(get_device())

    return state, belief, task


def squash_action(action, args):
    """Return ``tanh(action)`` when ``args.norm_actions_post_sampling`` is set, else ``action`` unchanged."""
    return torch.tanh(action) if args.norm_actions_post_sampling else action


def env_step(env, action, args):
    """Step the environment with the (optionally squashed) action.

    The action is detached and passed through ``squash_action`` before being
    handed to ``env.step``. Observations and rewards, which may each be a
    tensor or a list of tensors, are moved to the device from ``get_device()``.

    Returns:
        ``([next_obs, belief, task], reward, done, infos)``. ``belief`` is
        ``None`` unless ``args.pass_belief_to_policy`` is set; ``task`` is
        ``None`` unless ``args.pass_task_to_policy`` or ``args.decode_task``
        is set.
    """

    def _on_device(item):
        # Handles both a single tensor and a list of tensors.
        if isinstance(item, list):
            return [part.to(get_device()) for part in item]
        return item.to(get_device())

    squashed = squash_action(action.detach(), args)
    next_obs, reward, done, infos = env.step(squashed)

    next_obs = _on_device(next_obs)
    reward = _on_device(reward)

    belief = None
    if args.pass_belief_to_policy:
        belief = torch.from_numpy(env.get_belief()).float().to(get_device())

    task = None
    if args.pass_task_to_policy or args.decode_task:
        task = torch.from_numpy(env.get_task()).float().to(get_device())

    return [next_obs, belief, task], reward, done, infos

def select_action(args,
                  policy,
                  deterministic,
                  state=None,
                  belief=None,
                  task=None,
                  prob=None,
                  latent_pol=None,
                  latent_sample=None, latent_mean=None, latent_logvar=None, w=None):
    """Query the policy for an action.

    The latent fed to the policy is built by ``get_latent_for_policy``. A
    ``latent_pol`` whose first dimension is 1 has that dimension indexed away
    before it is passed on.

    Returns:
        ``(value, action)``. ``value`` is ``None`` when ``policy.act`` returns
        a bare action instead of a ``(value, action)`` list/tuple. The action
        is moved to the device from ``get_device()``.
    """
    latent = get_latent_for_policy(args=args,
                                   latent_sample=latent_sample,
                                   latent_mean=latent_mean,
                                   latent_logvar=latent_logvar)

    if latent_pol is not None and latent_pol.shape[0] == 1:
        latent_pol = latent_pol[0]

    policy_out = policy.act(state=state, latent=latent, belief=belief, task=task, prob=prob,
                            latent_pol=latent_pol, w=w, deterministic=deterministic)

    if isinstance(policy_out, (list, tuple)):
        value, action = policy_out
    else:
        value, action = None, policy_out

    return value, action.to(get_device())


def get_latent_for_policy(args, latent_sample=None, latent_mean=None, latent_logvar=None):
    """Build the latent tensor that is passed to the policy.

    Args:
        args: config providing ``add_nonlinearity_to_latent`` (apply ReLU to
            the latents) and ``sample_embeddings`` (use the sample instead of
            the concatenated mean/log-variance).
        latent_sample: sampled latent; required when ``args.sample_embeddings``.
        latent_mean: latent mean; required otherwise.
        latent_logvar: latent log-variance; required otherwise.

    Returns:
        ``None`` if no latent was supplied at all; otherwise the selected
        latent, with a leading dimension of size 1 squeezed away.

    Raises:
        ValueError: if the latent(s) required by ``args.sample_embeddings``
            were not supplied.
    """
    if latent_sample is None and latent_mean is None and latent_logvar is None:
        return None

    if args.add_nonlinearity_to_latent:
        # Only rectify what was actually supplied: callers may pass just the
        # sample (or just mean/logvar), and F.relu(None) would raise.
        latent_sample, latent_mean, latent_logvar = (
            None if part is None else F.relu(part)
            for part in (latent_sample, latent_mean, latent_logvar)
        )

    if args.sample_embeddings:
        if latent_sample is None:
            raise ValueError('latent_sample is required when args.sample_embeddings is set')
        latent = latent_sample
    else:
        if latent_mean is None or latent_logvar is None:
            raise ValueError('latent_mean and latent_logvar are required '
                             'when args.sample_embeddings is not set')
        latent = torch.cat((latent_mean, latent_logvar), dim=-1)

    if latent.shape[0] == 1:
        latent = latent.squeeze(0)

    return latent


def update_encoding(encoder, next_obs, action, reward, done, hidden_state, vae_mixture_num=1, y_intercept=None):
    """Advance the encoder by one step without tracking gradients.

    If ``done`` is given, the hidden state is first passed through
    ``encoder.reset_hidden`` (reset of the recurrent net when the task resets).

    Returns:
        * ``vae_mixture_num <= 1``: ``(latent_sample, latent_mean,
          latent_logvar, hidden_state)``.
        * ``vae_mixture_num > 1`` and the encoder returns 7 values (GMVAE):
          ``(latent_sample, latent_mean, latent_logvar, hidden_state, y, prob, w)``.
        * ``vae_mixture_num > 1`` otherwise (regular mixture VAE, 10 values):
          ``(latent_sample, latent_mean, latent_logvar, hidden_state, y, z,
          mu, var, logits, prob)``.
    """
    if done is not None:
        hidden_state = encoder.reset_hidden(hidden_state, done)

    # TODO: move the sampling out of the encoder!
    if not (vae_mixture_num > 1):
        with torch.no_grad():
            latent_sample, latent_mean, latent_logvar, hidden_state = encoder(
                actions=action.float(),
                states=next_obs,
                rewards=reward,
                hidden_state=hidden_state,
                return_prior=False,
            )
        return latent_sample, latent_mean, latent_logvar, hidden_state

    with torch.no_grad():
        outputs = encoder(
            actions=action.float(),
            states=next_obs,
            rewards=reward,
            hidden_state=hidden_state,
            return_prior=False,
            y_intercept=y_intercept,
        )

    if len(outputs) == 7:
        # GMVAE
        latent_sample, latent_mean, latent_logvar, hidden_state, y, prob, w = outputs
        return latent_sample, latent_mean, latent_logvar, hidden_state, y, prob, w

    # Regular mixture VAE
    (latent_sample, latent_mean, latent_logvar, hidden_state,
     y, z, mu, var, logits, prob) = outputs
    return latent_sample, latent_mean, latent_logvar, hidden_state, y, z, mu, var, logits, prob


def update_encoding_pol(encoder, next_obs, action, reward, done, hidden_state):
    """Advance the policy encoder by one step without tracking gradients.

    If ``done`` is given, the hidden state is first passed through
    ``encoder.reset_hidden`` (reset of the recurrent net when the task resets).
    The encoder is called with ``sample=False`` and ``return_prior=False``.

    Returns:
        ``(latent_mean, hidden_state)`` as produced by the encoder.
    """
    if done is not None:
        hidden_state = encoder.reset_hidden(hidden_state, done)

    # TODO: move the sampling out of the encoder!
    with torch.no_grad():
        encoded = encoder(
            actions=action.float(),
            states=next_obs,
            rewards=reward,
            hidden_state=hidden_state,
            return_prior=False,
            sample=False,
        )
        latent_mean, hidden_state = encoded
        return latent_mean, hidden_state

def seed(seed, deterministic_execution=False):
    """Seed Python's ``random``, torch and numpy with the same value.

    Args:
        seed: seed handed to every random number generator.
        deterministic_execution: when true, cuDNN is switched to deterministic
            mode with benchmarking disabled; otherwise a note about
            non-identical results is printed.
    """
    print('Seeding random, torch, numpy.')
    for seed_fn in (random.seed, torch.manual_seed, torch.random.manual_seed, np.random.seed):
        seed_fn(seed)

    if not deterministic_execution:
        print('Note that due to parallel processing results will be similar but not identical. '
              'Use only one process and set --deterministic_execution to True if you want identical results '
              '(only recommended for debugging).')
        return

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
    """Set every param group's learning rate to a linearly decayed value.

    The rate starts at ``initial_lr`` for ``epoch == 0`` and reaches zero at
    ``epoch == total_num_epochs``.
    """
    progress = epoch / float(total_num_epochs)
    decayed_lr = initial_lr - (initial_lr * progress)
    for group in optimizer.param_groups:
        group['lr'] = decayed_lr

def weight_init(m):
    """Orthogonal weight init for Linear, Conv2d and ConvTranspose2d layers.

    Linear layers use the default gain, the convolutional layers use the ReLU
    gain. A bias, if the layer has one, is filled with zeros. Any other module
    is left untouched.
    """
    if isinstance(m, nn.Linear):
        gain_args = ()
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        gain_args = (nn.init.calculate_gain('relu'),)
    else:
        return

    nn.init.orthogonal_(m.weight.data, *gain_args)
    # ``bias`` is None for layers built without one; None has no ``data``.
    if hasattr(m.bias, 'data'):
        m.bias.data.fill_(0.0)

def _embeddings_differ(stored, recomputed):
    """Return True when two lists of tensors, concatenated, differ in any element."""
    # Absolute differences are summed so that positive and negative deviations
    # cannot cancel each other out (a signed sum can be zero for unequal tensors).
    return bool((torch.cat(stored) - torch.cat(recomputed)).abs().sum() != 0)


def recompute_embeddings(
        policy_storage,
        encoder,
        sample,
        update_idx,
        detach_every,
        mixture=False,
        policy_separate_gru = False
):
    """Re-run the encoder over the stored rollout and replace the stored latents.

    The first stored latent (the prior) is detached, cloned and marked as
    requiring gradients. The encoder is then applied step by step, resetting
    its hidden state with ``policy_storage.done[i + 1]`` before each step.
    A loop is needed because the hidden state sometimes has to be reset.

    Args:
        policy_storage: rollout storage providing ``actions``, ``next_state``,
            ``rewards_raw``, ``done`` and the stored latents / hidden states.
        encoder: recurrent encoder with ``reset_hidden`` that is callable as
            ``encoder(actions, states, rewards, hidden, sample=...,
            return_prior=False, detach_every=...)``.
        sample: forwarded to the encoder.
        update_idx: on index 0 the recomputed latents are compared against
            the stored ones as a sanity check.
        detach_every: forwarded to the encoder.
        mixture: if true the encoder must return 10 values; the last one is
            stored as ``prob`` and is what the sanity check compares.
        policy_separate_gru: if true the encoder returns ``(latent, hidden)``
            and only ``policy_storage.latent_pol`` is recomputed (not
            supported together with ``mixture``).

    Side effects:
        Overwrites ``policy_storage.latent_pol`` or
        ``latent_samples`` / ``latent_mean`` / ``latent_logvar`` (and ``prob``
        for mixtures). On a sanity-check mismatch a warning is issued and the
        debugger is entered.
    """
    # Loop-invariant: convert the actions once instead of once per step.
    actions = policy_storage.actions.float()
    num_steps = policy_storage.actions.shape[0]

    if policy_separate_gru:
        assert not mixture
        # get the prior
        latent_pol = [policy_storage.latent_pol[0].detach().clone()]
        latent_pol[0].requires_grad = True

        h = policy_storage.hidden_states_pol[0].detach()

        for i in range(num_steps):
            # reset hidden state of the GRU when we reset the task
            h = encoder.reset_hidden(h, policy_storage.done[i + 1])

            tm, h = encoder(actions[i:i + 1],
                            policy_storage.next_state[i:i + 1],
                            policy_storage.rewards_raw[i:i + 1],
                            h,
                            sample=sample,
                            return_prior=False,
                            detach_every=detach_every
                            )
            latent_pol.append(tm)

        if update_idx == 0 and _embeddings_differ(policy_storage.latent_pol, latent_pol):
            warnings.warn('You are not recomputing the embeddings correctly!')
            import pdb
            pdb.set_trace()

        policy_storage.latent_pol = latent_pol  # use latent mean as the latent for policy gru encoder
        return

    # get the prior
    latent_sample = [policy_storage.latent_samples[0].detach().clone()]
    latent_mean = [policy_storage.latent_mean[0].detach().clone()]
    latent_logvar = [policy_storage.latent_logvar[0].detach().clone()]
    if mixture:
        prob = [policy_storage.prob[0].detach().clone()]
        prob[0].requires_grad = True

    latent_sample[0].requires_grad = True
    latent_mean[0].requires_grad = True
    latent_logvar[0].requires_grad = True

    h = policy_storage.hidden_states[0].detach()

    for i in range(num_steps):
        # reset hidden state of the GRU when we reset the task
        h = encoder.reset_hidden(h, policy_storage.done[i + 1])

        outputs = encoder(actions[i:i + 1],
                          policy_storage.next_state[i:i + 1],
                          policy_storage.rewards_raw[i:i + 1],
                          h,
                          sample=sample,
                          return_prior=False,
                          detach_every=detach_every
                          )
        if mixture:
            # Strict 10-way unpack: only ``prob`` (last value) is kept besides
            # the latents and the hidden state.
            ts, tm, tl, h, _y, _z, _mu, _var, _logits, tprob = outputs
            prob.append(tprob)
        else:
            ts, tm, tl, h = outputs

        latent_sample.append(ts)
        latent_mean.append(tm)
        latent_logvar.append(tl)

    if update_idx == 0:
        if mixture:
            mismatch = _embeddings_differ(policy_storage.prob, prob)
        else:
            mismatch = (_embeddings_differ(policy_storage.latent_mean, latent_mean)
                        or _embeddings_differ(policy_storage.latent_logvar, latent_logvar))
        if mismatch:
            warnings.warn('You are not recomputing the embeddings correctly!')
            import pdb
            pdb.set_trace()

    policy_storage.latent_samples = latent_sample
    policy_storage.latent_mean = latent_mean
    policy_storage.latent_logvar = latent_logvar

    if mixture:
        policy_storage.prob = prob


class FeatureExtractor(nn.Module):
    """Single linear layer plus activation, used to extract features from states/actions/rewards.

    With ``output_size == 0`` no layer is created and ``forward`` returns an
    empty tensor instead.
    """

    def __init__(self, input_size, output_size, activation_function):
        super().__init__()
        self.output_size = output_size
        self.activation_function = activation_function
        # No layer at all when zero features are requested.
        self.fc = nn.Linear(input_size, output_size) if output_size != 0 else None

    def forward(self, inputs):
        """Return ``activation(fc(inputs))``, or an empty tensor when ``output_size`` is 0."""
        if self.output_size == 0:
            return torch.zeros(0, ).to(get_device())
        features = self.fc(inputs)
        return self.activation_function(features)


def sample_gaussian(mu, logvar, num=None):
    """Draw a reparameterised sample ``mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, 1)``.

    When ``num`` is given, ``mu`` and the standard deviation are first tiled
    ``num`` times along the first dimension via ``repeat(num, 1)``.
    """
    std = (0.5 * logvar).exp()
    if num is not None:
        std, mu = std.repeat(num, 1), mu.repeat(num, 1)
    noise = torch.randn_like(std)
    return mu + std * noise


def save_obj(obj, folder, name):
    """Pickle ``obj`` to ``<folder>/<name>.pkl`` using the highest pickle protocol."""
    target = os.path.join(folder, name + '.pkl')
    with open(target, 'wb') as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)


def load_obj(folder, name):
    """Unpickle and return the object stored at ``<folder>/<name>.pkl``.

    NOTE: this uses ``pickle.load``, so it should only be pointed at trusted files.
    """
    source = os.path.join(folder, name + '.pkl')
    with open(source, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded


class RunningMeanStd(object):
    """Running mean and variance, updated batch by batch (PyTorch version).

    Uses the parallel variance algorithm:
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

    Attributes are ``mean`` and ``var`` (tensors of the given ``shape`` on the
    device from ``get_device()``) and ``count`` (starts at ``epsilon``).
    """

    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = torch.zeros(shape).float().to(get_device())
        self.var = torch.ones(shape).float().to(get_device())
        # A small non-zero start count keeps the first merge well-defined.
        self.count = epsilon

    def update(self, x):
        """Fold a batch into the running statistics.

        ``x`` is flattened to ``(-1, x.shape[-1])``; statistics are taken over
        the first dimension of the flattened batch.
        """
        x = x.view((-1, x.shape[-1]))
        batch_mean = x.mean(dim=0)
        # The merge formula treats ``var * count`` as a sum of squared
        # deviations, so the population variance is needed here. The
        # Bessel-corrected default would bias the result and is NaN for a
        # single-row batch.
        batch_var = x.var(dim=0, unbiased=False)
        batch_count = x.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Merge precomputed batch moments into the running statistics."""
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)


def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    """Merge running moments with the moments of a new batch.

    Returns:
        ``(new_mean, new_var, new_count)`` where ``new_count`` is
        ``count + batch_count`` and ``new_var`` is the merged sum of squared
        deviations divided by ``new_count``.
    """
    total = count + batch_count
    shift = batch_mean - mean

    merged_mean = mean + shift * batch_count / total
    # Sum of squared deviations of both parts plus the between-part correction.
    sq_dev = var * count + batch_var * batch_count + torch.pow(shift, 2) * count * batch_count / total

    return merged_mean, sq_dev / total, total


def boolean_argument(value):
    """Convert a string value to boolean.

    Accepts the same tokens as ``distutils.util.strtobool`` (case-insensitive):
    ``y, yes, t, true, on, 1`` are true and ``n, no, f, false, off, 0`` are
    false. The parsing is done locally so that this function does not depend
    on ``distutils``, which was removed from the standard library in
    Python 3.12.

    Args:
        value: a ``bool`` (returned unchanged) or a string token.

    Returns:
        The parsed boolean.

    Raises:
        ValueError: if ``value`` is not one of the accepted tokens.
    """
    if isinstance(value, bool):
        return value

    token = value.lower()
    if token in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if token in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    raise ValueError("invalid truth value %r" % (token,))


def get_task_dim(args):
    """Build the vectorised environments from ``args`` and return their ``task_dim``."""
    env_kwargs = dict(
        env_name=args.env_name,
        seed=args.seed,
        num_processes=args.num_processes,
        gamma=args.policy_gamma,
        device=get_device(),
        episodes_per_task=args.max_rollouts_per_task,
        normalise_rew=args.norm_rew_for_policy,
        ret_rms=None,
        tasks=None,
    )
    return make_vec_envs(**env_kwargs).task_dim


def get_num_tasks(args):
    """Build the vectorised environments from ``args`` and return their ``num_tasks``.

    Returns ``None`` when the environment has no ``num_tasks`` attribute.
    """
    env_kwargs = dict(
        env_name=args.env_name,
        seed=args.seed,
        num_processes=args.num_processes,
        gamma=args.policy_gamma,
        device=get_device(),
        episodes_per_task=args.max_rollouts_per_task,
        normalise_rew=args.norm_rew_for_policy,
        ret_rms=None,
        tasks=None,
    )
    env = make_vec_envs(**env_kwargs)
    return getattr(env, 'num_tasks', None)


def clip(value, low, high):
    """Imitates `{np,tf}.clip`.

    `torch.clamp` doesn't support tensor valued low/high so this provides the
    clip functionality.

    TODO(hartikainen): The broadcasting hasn't been extensively tested yet,
        but works for the regular cases where
        `value.shape == low.shape == high.shape` or when `{low,high}.shape == ()`.

    Args:
        value: tensor to clip.
        low: lower bound, a scalar or a tensor.
        high: upper bound, a scalar or a tensor.

    Returns:
        ``max(min(value, high), low)`` computed element-wise.

    Raises:
        AssertionError: if any element of ``low`` exceeds ``high``.
    """

    def _as_bound(bound):
        # ``torch.tensor(tensor)`` warns on every call and copies; detaching
        # gives the same gradient-free bound without either. Bounds are
        # placed on ``value``'s device so scalar bounds also work on GPU.
        if isinstance(bound, torch.Tensor):
            return bound.detach().to(value.device)
        return torch.as_tensor(bound, device=value.device)

    low, high = _as_bound(low), _as_bound(high)

    assert torch.all(low <= high), (low, high)

    clipped_value = torch.max(torch.min(value, high), low)
    return clipped_value


def update_encoding_dme(encoder, next_obs, action, reward, done, hidden_state, w_intercept=None, y_intercept=None):
    """Advance a DME encoder by one step without tracking gradients.

    If ``done`` is given, the hidden state is first passed through
    ``encoder.reset_hidden`` (reset of the recurrent net when the task resets).
    The encoder is called with ``return_prior=False`` and
    ``return_w_params=False``; ``y_intercept`` is forwarded to it.

    Args:
        w_intercept: optional value that replaces the ``w`` computed by the
            encoder in the returned tuple.
        y_intercept: optional value forwarded to the encoder.

    Returns:
        ``(latent_sample, latent_mean, latent_logvar, hidden_state, y, prob, w)``.
    """
    if done is not None:
        hidden_state = encoder.reset_hidden(hidden_state, done)

    with torch.no_grad():
        outputs = encoder(
            actions=action.float(),
            states=next_obs,
            rewards=reward,
            hidden_state=hidden_state,
            return_prior=False,
            return_w_params=False,  # w_mu / w_logvar are not needed here
            y_intercept=y_intercept,
        )

    # The encoder is expected to return exactly these seven values.
    latent_sample, latent_mean, latent_logvar, hidden_state, y, prob, computed_w = outputs

    # A supplied intercept takes precedence over the encoder's own w.
    w = computed_w if w_intercept is None else w_intercept

    return latent_sample, latent_mean, latent_logvar, hidden_state, y, prob, w
