import torch
import torch.nn as nn
import torch.nn.functional as F

from nn.cola_nn import dense_init

# Compute f(X) as a sum of waves, each a product of per-dimension sines over the leading input dimensions
def function_with_freq(X, X_max, X_min, num_waves, frequencies, phase_shifts, amplitudes, num_prod_dims=5):
    """Evaluate a sum of products of sine waves at every row of ``X``.

    For each wave ``i`` the per-dimension sines
    ``sin(2*pi * frequencies[i] * X / (X_max - X_min) + phase_shifts[i])`` are
    multiplied over the first ``num_prod_dims`` input dimensions.  The product
    is scaled by ``amplitudes[i]`` and all waves are summed.

    Args:
        X: 2-D tensor of shape ``(num_points, num_input_dims)``.
        X_max, X_min: Range used to scale ``X``. Only ``X_max - X_min`` is used.
            ``X_min`` is not subtracted, which is equivalent to a phase shift.
        num_waves: Number of waves to sum; indexes the first axis of
            ``frequencies``, ``phase_shifts`` and ``amplitudes``.
        frequencies, phase_shifts: Per-wave rows broadcast against ``X``
            (``get_random_f`` builds them with shape
            ``(num_waves, num_input_dims)``).
        amplitudes: Per-wave scale factors.
        num_prod_dims: Number of leading input dimensions whose sines are
            multiplied in each wave. It defaults to 5, the previously
            hard-coded value. Slicing clamps, so narrower ``X`` uses all its
            dimensions.

    Returns:
        1-D tensor of shape ``(num_points,)`` with ``X``'s dtype and device.
    """
    num_points, _ = X.shape  # also enforces that X is 2-D
    f_X = torch.zeros(num_points, device=X.device, dtype=X.dtype)
    scaled_X = X / (X_max - X_min)  # loop-invariant; same values as inline

    for i in range(num_waves):
        sines = torch.sin(2 * torch.pi * frequencies[i] * scaled_X + phase_shifts[i])
        wave = torch.prod(sines[..., :num_prod_dims], dim=-1)
        f_X += amplitudes[i] * wave
    return f_X

def get_random_f(freq_min, freq_max, num_waves, num_input_dims, X_max, X_min, y_max, device='cpu', rng=torch.Generator(), dtype=torch.float32):
    """Draw random sine-wave parameters and return a closure evaluating them.

    Frequencies are integer-valued draws from ``[freq_min, freq_max)``, stored
    with ``dtype``. Phase shifts are uniform in ``[0, 2*pi)``. Amplitudes are
    all ones. Frequencies are drawn from ``rng`` first, then phase shifts.

    NOTE(review): the default ``rng`` is created once, when the function is
    defined, so calls that omit ``rng`` share one generator whose state
    carries over between calls. ``y_max`` is accepted but unused here.

    Returns:
        A callable ``f(X)`` that forwards to ``function_with_freq`` with the
        drawn parameters.
    """
    param_shape = (num_waves, num_input_dims)

    frequencies = torch.randint(freq_min, freq_max, param_shape, generator=rng, device=device, dtype=dtype)
    phase_shifts = torch.rand(*param_shape, device=device, dtype=dtype, generator=rng) * 2 * torch.pi
    amplitudes = torch.ones(num_waves, device=device, dtype=dtype)

    def f(X):
        return function_with_freq(X, X_max, X_min, num_waves, frequencies, phase_shifts, amplitudes)

    return f

class SineWaves(nn.Module):
    """Single hidden layer of sine activations over the first ``feature_num`` inputs.

    Computes ``output_layer(sin(template(x[..., :feature_num])))`` and returns
    one scalar per input row.
    """

    def __init__(self, feature_num, num_waves):
        super().__init__()
        # Creation order (template, then output layer) fixes the global-RNG
        # draws made by nn.Linear's default initialisation.
        self.template = nn.Linear(feature_num, num_waves)
        self.output_layer = nn.Linear(num_waves, 1)

        self.feature_num = feature_num

    def forward(self, x):
        # Ignore any trailing input dimensions beyond feature_num.
        leading = x[..., :self.feature_num]
        phases = self.template(leading)
        return self.output_layer(torch.sin(phases))

    def init(self, rng, freq_std):
        """Re-draw parameters from ``rng``.

        Template weights are drawn from ``N(0, (freq_std / sqrt(feature_num))^2)``.
        Template biases (phases) are uniform in ``[0, 2*pi)``. Output weights
        are uniform in ``[1, 10)`` and the output bias is zero. The draws
        happen in that order.
        """
        template, head = self.template, self.output_layer
        weight_std = freq_std * self.feature_num ** -0.5

        template.weight.data.normal_(mean=0, std=weight_std, generator=rng)
        template.bias.data.uniform_(0, 2 * torch.pi, generator=rng)
        head.weight.data.uniform_(1, 10, generator=rng)
        head.bias.data.zero_()


def get_sine_f(feature_num, num_waves, freq_std, rng, device):
    """Build a random ``SineWaves`` target rescaled to roughly zero mean and unit std.

    The output layer is rescaled using statistics from 1024 probe points drawn
    uniformly from ``[-0.5, 0.5)^feature_num``. A second probe batch is then
    used only to print diagnostics. Both batches are drawn from ``rng``.

    Args:
        feature_num: Number of input features the module reads.
        num_waves: Width of the sine hidden layer.
        freq_std: Scale passed to ``SineWaves.init`` for the template weights.
        rng: ``torch.Generator`` used for every draw; it must live on ``device``.
        device: Device for the module and the probe batches.

    Returns:
        The normalised ``SineWaves`` module.
    """
    sine_waves = SineWaves(feature_num, num_waves).to(device)
    sine_waves.init(rng, freq_std)

    # Only parameter surgery and diagnostics follow, so no autograd graph is needed.
    with torch.no_grad():
        # Normalize the output on a probe batch.
        Z = torch.rand(1024, feature_num, device=device, generator=rng) - 0.5
        out = sine_waves(Z)
        mean = out.mean()
        std = out.std() + 1e-8  # epsilon guards against a constant output
        # Rescale so that out -> (out - mean) / std.  Subtracting from the
        # existing bias keeps this correct even if it is nonzero.
        sine_waves.output_layer.weight.div_(std)
        sine_waves.output_layer.bias.sub_(mean).div_(std)

        # Diagnostics on a fresh probe batch.
        Z = torch.rand(1024, feature_num, device=device, generator=rng) - 0.5
        Z = sine_waves.template(Z)
        print(f'Template µ: {Z.mean().item():.1f}, σ: {Z.std().item():.1f}')
        Z = torch.sin(Z)
        print(f'Sine µ: {Z.mean().item():.1f}, σ: {Z.std().item():.1f}')
        Z = sine_waves.output_layer(Z)
        print(f'output µ: {Z.mean().item():.1f}, σ: {Z.std().item():.1f}')

        # L2 norm of each template row, i.e. the per-wave frequency magnitude.
        frequencies = sine_waves.template.weight.norm(dim=1)
        print(f'frequencies µ: {frequencies.mean().item():.1f}, σ: {frequencies.std().item():.1f}')

    return sine_waves

class MLP(nn.Module):
    """Fully connected ReLU network applied to the first ``feature_num`` inputs.

    Architecture: ``[Linear -> ReLU] * len(hidden_sizes)`` followed by a linear
    output layer. With an empty ``hidden_sizes`` the network is a single
    linear map from ``feature_num`` to ``output_size``.
    """

    def __init__(self, feature_num, hidden_sizes, output_size):
        super(MLP, self).__init__()
        layers = []
        in_features = feature_num
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(in_features, hidden_size))
            layers.append(nn.ReLU())
            in_features = hidden_size
        self.network = nn.Sequential(*layers)
        # in_features is the last hidden width, or feature_num if there are no
        # hidden layers (previously an IndexError on hidden_sizes[-1]).
        self.output_layer = nn.Linear(in_features, output_size)

        self.feature_num = feature_num

    def forward(self, x):
        # Only use the first feature_num dimensions of the input.
        x = x[..., :self.feature_num]
        x = self.network(x)
        x = self.output_layer(x)

        return x

    def init(self, rng):
        """Re-draw all Linear weights from ``N(0, 1/in_features)`` using ``rng``; zero the biases.

        Layers are initialised in forward order (hidden layers, then the output
        layer), so the ``rng`` draw sequence is deterministic.
        """
        linears = [layer for layer in self.network if isinstance(layer, nn.Linear)]
        linears.append(self.output_layer)
        for layer in linears:
            std = layer.in_features ** -0.5
            layer.weight.data.normal_(mean=0, std=std, generator=rng)
            if layer.bias is not None:
                layer.bias.data.zero_()


def get_teacher_f(feature_num, hidden_sizes, rng, device):
    """Build a random MLP teacher with scalar output rescaled to roughly zero mean and unit std.

    The output layer is rescaled using statistics from 1024 probe points drawn
    uniformly from ``[-0.5, 0.5)^feature_num``. A second probe batch is used
    only to print diagnostics. Both batches are drawn from ``rng``.

    Args:
        feature_num: Number of input features the teacher reads.
        hidden_sizes: Hidden-layer widths passed to ``MLP``.
        rng: ``torch.Generator`` used for every draw; it must live on ``device``.
        device: Device for the model and the probe batches.

    Returns:
        A callable ``f(x)`` that evaluates the normalised teacher.
    """
    teacher_model = MLP(feature_num, hidden_sizes, 1).to(device)
    teacher_model.init(rng)

    # Only parameter surgery and diagnostics follow, so no autograd graph is needed.
    with torch.no_grad():
        # Normalize the output on a probe batch.
        Z = torch.rand(1024, feature_num, device=device, generator=rng) - 0.5
        out = teacher_model(Z)
        mean = out.mean()
        std = out.std() + 1e-8  # epsilon guards against a constant output
        print(f"mean {mean:.2f} std {std:.2f}")
        # Rescale so that out -> (out - mean) / std.  Subtracting from the
        # existing bias keeps this correct even if it is nonzero.
        teacher_model.output_layer.weight.div_(std)
        teacher_model.output_layer.bias.sub_(mean).div_(std)

        # Diagnostics on a fresh probe batch.
        Z = torch.rand(1024, feature_num, device=device, generator=rng) - 0.5
        Z = teacher_model(Z)
        print(f'output µ: {Z.mean().item():.1f}, σ: {Z.std().item():.1f}')

    return lambda x: teacher_model(x)

