import math
import torch
import numpy as np
from mpi4py import MPI
import random

def create_log_gaussian(mean, log_std, t):
    """Return the log-density of ``t`` under a diagonal Gaussian.

    The distribution is N(mean, diag(exp(log_std) ** 2)). The last dimension
    is treated as the event dimension and is summed over, so the result has
    the broadcast shape of the inputs with that final dimension removed.

    Args:
        mean: Tensor of means.
        log_std: Tensor of log standard deviations; must broadcast against
            ``mean`` and ``t`` (a shared scalar or per-dimension vector works).
        t: Tensor of points at which the density is evaluated.

    Returns:
        Tensor of log-probabilities, one per leading (batch) index.
    """
    std = log_std.exp()
    # Gaussian quadratic term: -0.5 * ((t - mean) / std) ** 2. The 0.5 must sit
    # outside the square; placing it inside would give -0.25 * z ** 2.
    quadratic = -0.5 * ((t - mean) / std).pow(2)
    # Expand log_std to the full broadcast shape so its sum counts every event
    # dimension even when a smaller (e.g. scalar) log_std was supplied.
    log_z = log_std.expand_as(quadratic)
    # Event dimensionality comes from the broadcast result, not from ``mean``,
    # which may be the smaller operand.
    event_dim = quadratic.shape[-1]
    z = event_dim * math.log(2 * math.pi)
    log_p = quadratic.sum(dim=-1) - log_z.sum(dim=-1) - 0.5 * z
    return log_p

def logsumexp(inputs, dim=None, keepdim=False):
    """Compute ``log(sum(exp(inputs)))`` in a numerically stable way.

    Args:
        inputs: Input tensor.
        dim: Dimension to reduce over. When ``None`` the tensor is flattened
            and reduced over all elements.
        keepdim: Whether to keep the reduced dimension (size 1) in the output.

    Returns:
        Tensor holding the log-sum-exp along ``dim``.
    """
    if dim is None:
        # reshape (not view) so non-contiguous tensors such as transposes work.
        inputs = inputs.reshape(-1)
        dim = 0
    # torch.logsumexp applies the max-shift trick internally and, unlike a
    # naive shift, returns -inf / inf (not NaN) when a slice's max is infinite.
    return torch.logsumexp(inputs, dim=dim, keepdim=keepdim)

def soft_update(target, source, tau):
    """Blend ``source``'s parameters into ``target`` in place (Polyak averaging).

    Each target parameter is overwritten with
    ``target * (1 - tau) + source * tau``. Parameters are paired positionally
    and pairing stops at the shorter of the two parameter lists.

    Args:
        target: Module whose parameters are updated in place.
        source: Module whose parameters are blended in; left unchanged.
        tau: Interpolation weight given to ``source``.
    """
    keep = 1.0 - tau
    pairs = zip(target.parameters(), source.parameters())
    for dst, src in pairs:
        # Compute the blend out of place first so dst is only written once.
        blended = dst.data * keep + src.data * tau
        dst.data.copy_(blended)

def hard_update(target, source):
    """Overwrite ``target``'s parameters with copies of ``source``'s, in place.

    Parameters are paired positionally; pairing stops at the shorter list.

    Args:
        target: Module whose parameters are overwritten.
        source: Module providing the values; left unchanged.
    """
    pairs = zip(target.parameters(), source.parameters())
    for dst_param, src_param in pairs:
        dst_param.data.copy_(src_param.data)
import os
def set_seed(env, args):
    """Seed all random sources used during training for reproducibility.

    Seeds the environment and its action space, NumPy's global RNG, Python's
    ``random``, and PyTorch (CPU, plus all CUDA devices when ``args.cuda`` is
    truthy). It also switches cuDNN to deterministic, non-benchmark mode.

    Args:
        env: Gym-style environment with an ``action_space`` that has a
            ``seed`` method. Envs that expose ``env.seed(seed)`` (older Gym
            API) are seeded that way. Otherwise the env is seeded via
            ``env.reset(seed=seed)`` (newer Gym/Gymnasium API), which also
            resets it.
        args: Namespace providing ``seed`` (int) and ``cuda`` (truthy flag).
    """
    seed = args.seed

    # NOTE: setting PYTHONHASHSEED at runtime does not change hash
    # randomization of the already-running interpreter; it only affects
    # child processes started after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # NOTE(review): earlier revisions offset the seed by the MPI rank. That
    # offset is disabled, so every rank receives the same seed.
    env_seed = getattr(env, 'seed', None)
    if callable(env_seed):
        env_seed(seed)
    else:
        # Newer Gym/Gymnasium removed Env.seed; seeding goes through reset().
        env.reset(seed=seed)
    env.action_space.seed(seed)

    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    if args.cuda:
        print('cuda seed')
        # Seed every visible GPU, not just the current device.
        torch.cuda.manual_seed_all(seed)
    print('seed is set as: {}'.format(seed))

  


def load_models(load_flag, agent, args):
    """Optionally restore ``agent`` from its experiment checkpoint.

    When ``load_flag`` is truthy, calls ``agent.load_checkpoint`` with the path
    ``checkpoints/<experiment_name>/sac_checkpoint_<env_name>_<experiment_name>``
    and prints a confirmation. Otherwise does nothing.

    Args:
        load_flag: Whether to load the checkpoint.
        agent: Object exposing ``load_checkpoint(path)``.
        args: Namespace providing ``env_name`` and ``experiment_name``.

    Returns:
        None.
    """
    if not load_flag:
        return
    experiment = args.experiment_name
    ckpt_path = (
        f"checkpoints/{experiment}/"
        f"sac_checkpoint_{args.env_name}_{experiment}"
    )
    agent.load_checkpoint(ckpt_path)
    print("load successful")