import torch.nn.functional as F
import torch.nn as nn

import numpy as np
import torch
import random
import os


def _select_seed_randomly(min_seed_value=0, max_seed_value=255):
    return random.randint(min_seed_value, max_seed_value)


def seed_everything(seed):
    """Seed the Python, NumPy and torch (CPU and all CUDA devices) RNGs.

    Args:
        seed: anything accepted by ``int()``. When None, the value of the
            ``PL_GLOBAL_SEED`` environment variable is used instead.

    Returns:
        The integer seed that was actually applied. If ``seed`` cannot be
        converted to ``int``, or falls outside the uint32 range, a random
        seed from that range is drawn instead and a message is printed.

    Side effects:
        Stores the applied seed (as a string) in ``os.environ["PL_GLOBAL_SEED"]``.
    """
    uint32_info = np.iinfo(np.uint32)
    max_seed_value = uint32_info.max
    min_seed_value = uint32_info.min

    # Fall back to the environment when no explicit seed was given.
    raw_seed = os.environ.get("PL_GLOBAL_SEED") if seed is None else seed
    try:
        seed = int(raw_seed)
    except (TypeError, ValueError):
        seed = _select_seed_randomly(min_seed_value, max_seed_value)
        print(f"No correct seed found, seed set to {seed}")

    in_bounds = min_seed_value <= seed <= max_seed_value
    if not in_bounds:
        print(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
        seed = _select_seed_randomly(min_seed_value, max_seed_value)

    os.environ["PL_GLOBAL_SEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    return seed


def make_parameter(shape, device, debug=False):
    """Allocate an uninitialized trainable parameter directly on ``device``.

    Args:
        shape: iterable of ints giving the parameter's dimensions.
        device: target device (anything accepted as ``device=`` by torch factories).
        debug: if True, allocate in float64; otherwise use torch's default dtype.

    Returns:
        An ``nn.Parameter`` (leaf, ``requires_grad=True``). Its contents are
        uninitialized memory, so callers must initialize it before use.
    """
    # dtype=None lets torch pick its current default dtype, matching the
    # legacy torch.Tensor(...) constructor. Allocating directly on the target
    # device avoids a CPU allocation plus a host->device copy. Choosing the
    # dtype up front avoids patching `.data` with a grad-tracked `.double()` copy.
    dtype = torch.float64 if debug else None
    return nn.Parameter(torch.empty(tuple(shape), device=device, dtype=dtype))


def kd_prox(precond_outputs, fsp_outputs, temp=1.):
    """Temperature-scaled KL-divergence proximity term (distillation style).

    Computes ``temp**2 * KL(softmax(fsp_outputs / temp) || softmax(precond_outputs / temp))``.
    Softmax is taken over dim 1; the KL is summed over all elements and
    divided by the size of dim 0 (``reduction="batchmean"``).

    Args:
        precond_outputs: logits whose log-softmax is the KL "input".
        fsp_outputs: logits whose softmax is the KL "target".
        temp: softmax temperature; the result is rescaled by ``temp**2``.

    Returns:
        Scalar tensor.
    """
    log_probs = F.log_softmax(precond_outputs / temp, dim=1)
    target_probs = F.softmax(fsp_outputs / temp, dim=1)
    # `reduction` must be passed by keyword. The first positional argument of
    # nn.KLDivLoss is the deprecated `size_average`, so the original
    # nn.KLDivLoss("batchmean") silently resolved to reduction="mean"
    # (dividing by batch * classes instead of by batch).
    return F.kl_div(log_probs, target_probs, reduction="batchmean") * (temp * temp)


def euc_prox(precond_outputs, fsp_outputs):
    """Squared Euclidean distance between two output tensors, averaged over dim 0.

    Sums ``(precond_outputs - fsp_outputs) ** 2`` over every element and
    divides by ``precond_outputs.shape[0]``.
    """
    squared_diff = (precond_outputs - fsp_outputs) ** 2.
    return squared_diff.sum() / precond_outputs.shape[0]


def del_attr(obj, names):
    """Delete the attribute reached by following the attribute path ``names``.

    ``names`` is a sequence such as ``"layer.0.weight".split(".")``. Every
    element except the last is resolved with ``getattr``, starting from
    ``obj``, and the last element is removed from the resulting object with
    ``delattr``. An empty ``names`` raises ``IndexError``.
    """
    target = obj
    for part in names[:-1]:
        target = getattr(target, part)
    delattr(target, names[-1])


def set_attr(obj, names, val):
    """Assign ``val`` to the attribute reached by following the attribute path ``names``.

    Every element of ``names`` except the last is resolved with ``getattr``,
    starting from ``obj``. The last element is then set on the resulting
    object with ``setattr``. An empty ``names`` raises ``IndexError``.
    """
    target = obj
    for part in names[:-1]:
        target = getattr(target, part)
    setattr(target, names[-1], val)


def make_functional(mod):
    """Strip every parameter attribute from ``mod`` and return the removed parameters.

    Args:
        mod: module whose parameters are removed in place; each name from
            ``mod.named_parameters()`` is deleted via its dotted attribute path.

    Returns:
        List of ``(qualified_name, parameter)`` tuples, in the order yielded
        by ``mod.named_parameters()``.
    """
    # Snapshot the parameters before deleting anything. Deleting while the
    # named_parameters() generator is still walking the module would mutate
    # what it is iterating over.
    named = list(mod.named_parameters())
    for name, _ in named:
        del_attr(mod, name.split("."))
    return named
