import torch
import torch.nn as nn
from MLP import MLP
import yaml


class Sin(nn.Module):
    """Sine activation exposed as an ``nn.Module``.

    The module has no parameters or buffers; it only wraps ``torch.sin`` so
    the function can be placed inside module containers.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the element-wise sine of ``x``."""
        return torch.sin(x)
  

class ZeroCenteredSigmoid(nn.Module):
    """Sigmoid activation shifted down by 0.5.

    Because the logistic sigmoid lies in ``(0, 1)``, the output lies in
    ``(-0.5, 0.5)`` and an input of zero maps to zero.
    """

    def __init__(self):
        super(ZeroCenteredSigmoid, self).__init__()

    def forward(self, x):
        """Return ``sigmoid(x) - 0.5`` computed element-wise."""
        squashed = nn.functional.sigmoid(x)
        return squashed - 0.5


def eval_model_classification(model, dl, device):
    """Compute the classification accuracy of ``model`` over ``dl``.

    The model is switched to evaluation mode for the duration of the loop and
    gradients are disabled. On exit (including when an exception is raised)
    the model is put back into whichever mode it was in when this function
    was called, so evaluating a model that is already in eval mode does not
    silently flip it to training mode.

    Args:
        model: Module called as ``model(x)``; the predicted class is the
            ``argmax`` of its output over the last dimension.
        dl: Iterable yielding ``(x, y)`` pairs of tensors. ``y`` is compared
            element-wise against the predicted class indices.
        device: Device that each ``x`` and ``y`` is moved to before the
            forward pass.

    Returns:
        The number of matching predictions divided by the total number of
        samples seen, where the sample count is ``x.size(0)`` per batch.

    Raises:
        ZeroDivisionError: If ``dl`` yields no samples.
    """
    # Remember the caller's mode so it can be restored rather than forcing
    # training mode unconditionally.
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for x, y in dl:
                x = x.to(device)
                y = y.to(device)
                predicted = torch.argmax(model(x), dim=-1)
                correct += (predicted == y).sum().cpu().numpy()
                total += x.size(0)
    finally:
        model.train(was_training)
    return correct / total


def create_generator_from_config(config_path):
    """Build an ``MLP`` generator from a YAML config and initialise its weights.

    Keys read from the YAML mapping (default in parentheses):

    * ``arch`` (``[750, 512]``): list of hidden layer widths.
    * ``gen_in`` / ``gen_out`` (``None``): input and output widths; they are
      placed at the start and end of the architecture list.
    * ``residual`` (``False``): forwarded to ``MLP`` as ``residual``.
    * ``initialization_type`` (``'uniform'``): ``'gaussian'`` or
      ``'uniform'``. Any other value leaves the weights exactly as ``MLP``
      created them.
    * ``initialization_alpha`` (``1.0``): scale for every ``nn.Linear``
      except the first one.
    * ``first_layer_initialization_alpha`` (``1.0``): scale for the first
      layer.
    * ``use_tanh`` (``True``): final activation is ``nn.Tanh`` when true,
      ``nn.Identity`` otherwise.
    * ``use_bias`` (``False``): forwarded to ``MLP`` as ``bias``.
    * ``activation_fn`` (``'sin'``): one of ``'sin'``, ``'elu'``,
      ``'sigmoid'``, ``'relu'``, ``'leaky_relu'``.

    For both initialisation types the spread of a layer is
    ``sqrt(alpha / in_features)``: it is used as the standard deviation of a
    zero-mean normal distribution for ``'gaussian'`` and as the symmetric
    bound for ``'uniform'``. Only weights are re-initialised; biases are left
    untouched.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The constructed (and possibly re-initialised) ``MLP`` instance.

    Raises:
        ValueError: If ``activation_fn`` is not one of the supported names.
    """
    with open(config_path, "r") as file:
        config = yaml.safe_load(file)

    arch = config.get('arch', [750, 512])
    gen_in = config.get('gen_in')
    gen_out = config.get('gen_out')
    residual = config.get('residual', False)
    initialization_type = config.get('initialization_type', 'uniform')
    initialization_alpha = config.get('initialization_alpha', 1.0)
    first_layer_initialization_alpha = config.get('first_layer_initialization_alpha', 1.0)
    use_tanh = config.get('use_tanh', True)
    use_bias = config.get('use_bias', False)
    activation_fn_config_value = config.get('activation_fn', 'sin')

    activation_classes = {
        'sin': Sin,
        'elu': nn.ELU,
        'sigmoid': ZeroCenteredSigmoid,
        'relu': nn.ReLU,
        'leaky_relu': nn.LeakyReLU,
    }
    try:
        activation = activation_classes[activation_fn_config_value]
    except (KeyError, TypeError):
        # TypeError covers unhashable config values such as a YAML list.
        # Raising a bare string is itself a TypeError in Python 3, so a real
        # exception type is required here.
        raise ValueError(
            f'invalid activation function: {activation_fn_config_value!r}'
        ) from None

    final_activation = nn.Tanh() if use_tanh else nn.Identity()

    model = MLP(architecture=[gen_in] + arch + [gen_out], activation=activation(),
                final_activation=final_activation, residual=residual, bias=use_bias)

    # NOTE(review): presumably `basis` holds the hidden layers and `regressor`
    # the output layer of the project-local MLP, so index 0 of whichever is
    # populated is the first linear layer -- confirm against MLP.
    if len(arch) > 0:
        first_layer = model.basis[0]
    else:
        first_layer = model.regressor[0]

    if initialization_type in ('gaussian', 'uniform'):
        with torch.no_grad():
            for m in model.modules():
                # Exact type match on purpose: subclasses of nn.Linear keep
                # their own initialisation.
                if type(m) is not nn.Linear:
                    continue
                if m is first_layer:
                    alpha = first_layer_initialization_alpha
                else:
                    alpha = initialization_alpha
                init_range = torch.sqrt(torch.as_tensor(alpha / m.in_features))
                if initialization_type == 'gaussian':
                    nn.init.normal_(m.weight.data, 0, init_range)
                else:
                    nn.init.uniform_(m.weight.data, -init_range, init_range)

    return model
