import pickle
from torch import nn
import time
import datetime
import pytz
import torch

def save_model(model, path):
    """Write ``model.state_dict()`` to ``path`` with ``torch.save``.

    Only the state dict is stored, not the module object itself.
    """
    state = model.state_dict()
    torch.save(state, path)

def save_variable(v, filename):
    """Pickle ``v`` into ``filename``, overwriting any existing file.

    Uses the default pickle protocol.
    """
    with open(filename, 'wb') as out_file:
        pickle.Pickler(out_file).dump(v)

def load_variable(filename):
    """Unpickle and return the object stored in ``filename``.

    Warning: unpickling can execute arbitrary code, so only load files
    produced by a trusted source (e.g. ``save_variable``).
    """
    with open(filename, 'rb') as in_file:
        obj = pickle.Unpickler(in_file).load()
    return obj

def init_all_params(net, hidden_size, gamma):
    """Re-draw every parameter of ``net`` in place from N(0, std^2).

    The standard deviation is ``hidden_size ** gamma``. Parameters are
    visited in ``net.parameters()`` order.

    Returns:
        The same ``net`` object, for chaining.
    """
    params = net.parameters()
    for weight in params:
        nn.init.normal_(weight, 0, hidden_size ** gamma)
    return net

def get_local_time_str(tz=None):
    """Return the current time formatted as ``YYYYmmdd_HHMMSS``.

    Args:
        tz: optional ``datetime.tzinfo`` to render the time in (for example
            ``datetime.timezone.utc`` or a ``pytz.timezone(...)`` object).
            Defaults to UTC, which matches the previous hard-coded behaviour.

    Returns:
        A string such as ``"20240131_235959"``.
    """
    if tz is None:
        tz = datetime.timezone.utc
    # now(tz) already yields an aware datetime in ``tz``; no further
    # astimezone() conversion is needed.
    return datetime.datetime.now(tz).strftime('%Y%m%d_%H%M%S')

def load_parameters_from_vector(model, param_vector):
    """Copy a flat vector of values into ``model``'s parameters, in place.

    Parameters are filled in ``model.parameters()`` order. Each one consumes
    ``param.numel()`` consecutive entries, reshaped to the parameter's shape.

    Args:
        model: module whose parameters are overwritten.
        param_vector: tensor containing exactly as many elements as the
            model has parameter entries in total. Any shape is accepted; it
            is flattened first and moved to the device of the first parameter.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: if ``param_vector`` does not contain exactly the model's
            total number of parameter elements.
    """
    params = list(model.parameters())
    total_params = sum(p.numel() for p in params)

    # Flatten so the slicing below walks elements rather than the leading
    # dimension of a multi-dimensional input.
    flat = param_vector.reshape(-1)
    if flat.numel() != total_params:
        raise ValueError(
            f"param_vector has {flat.numel()} elements but the model has "
            f"{total_params} parameter elements"
        )
    if not params:
        # Nothing to fill (and no device to move to).
        return model

    flat = flat.to(params[0].device)

    offset = 0
    with torch.no_grad():
        for param in params:
            count = param.numel()
            # copy_ writes into the existing parameter tensors in place,
            # casting dtype if the vector's differs.
            param.copy_(flat[offset:offset + count].reshape_as(param))
            offset += count

    return model