import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

import difficulty


def target_function(domain_batch):
    """Return sin(2 * pi * x^2), evaluated elementwise on ``domain_batch``."""
    angular_scale = 2 * np.pi
    squared_inputs = domain_batch ** 2
    return torch.sin(angular_scale * squared_inputs)


def train(model, optimizer, domain_distributions, batch_size_per_distribution=1):
    """Run one optimizer update per entry of ``domain_distributions``.

    Each entry is called with ``batch_size_per_distribution`` to draw a batch;
    the model is fitted to ``target_function`` on that batch with a mean
    squared error loss.

    Returns a pair ``(losses, domain_samples)``: the scalar loss of every
    update, and the flattened batches concatenated in the order they were
    drawn.
    """
    batch = batch_size_per_distribution
    update_count = len(domain_distributions)

    # Buffers are preallocated so the result layout is fixed up front.
    losses = torch.zeros(update_count)
    domain_samples = torch.zeros(update_count * batch)

    for index, draw in enumerate(domain_distributions):
        inputs = draw(batch)
        expected = target_function(inputs)

        optimizer.zero_grad()
        objective = F.mse_loss(model(inputs), expected, reduction='mean')
        objective.backward()
        optimizer.step()

        losses[index] = objective.item()

        # Each update owns a contiguous run of `batch` slots in the sample log.
        # NOTE(review): assumes the flattened batch has exactly `batch`
        # elements (i.e. a one-dimensional domain) -- confirm with callers.
        begin = index * batch
        end = begin + batch
        domain_samples[begin:end] = inputs.flatten()

    return losses, domain_samples


def test(model, test_sample):
    """Return the mean squared error of ``model`` against ``target_function``.

    The evaluation runs with gradient tracking disabled; the loss is returned
    as a tensor, so callers needing a Python number should call ``.item()``.
    """
    with torch.no_grad():
        prediction = model(test_sample)
        reference = target_function(test_sample)
        return F.mse_loss(prediction, reference, reduction='mean')


def train_measure_segments(model, optimizer, distributions, measurement_interval, test_size, test_grid=True, batch_size_per_distribution=1):
    """Train segment by segment, measuring the test loss after each segment.

    ``distributions`` is split by ``difficulty.segmenter`` using
    ``measurement_interval``. Every resulting segment is trained with
    ``train``; afterwards the model is evaluated with ``test`` on a fresh
    test sample of size ``test_size``.

    Parameters:
        model, optimizer: forwarded to ``train`` / ``test``.
        distributions: training distributions, handed to the segmenter.
        measurement_interval: forwarded to ``difficulty.segmenter``.
        test_size: size argument for the test-sample helper.
        test_grid: if true the test sample comes from
            ``difficulty.grid_covering_stationary_support``, otherwise from
            ``difficulty.sample_overall_stationary``.
        batch_size_per_distribution: batch size forwarded to ``train``.

    Returns:
        A tensor with three rows: the number of training steps completed
        before each segment started, the test loss measured after each
        segment, and the mean training loss within each segment.
    """
    steps = []
    mean_train_losses = []
    mean_test_losses = []

    # Running count of training steps. The previous `segment * len(segment)`
    # formula was only right when every segment had the same length; a shorter
    # trailing segment produced a wrong (even decreasing) step value. For
    # equal-length segments the recorded values are unchanged.
    steps_before_segment = 0

    measured_segments = difficulty.segmenter(distributions, measurement_interval)
    for segment_distributions in tqdm(measured_segments):
        losses, _samples = train(model, optimizer, segment_distributions, batch_size_per_distribution)

        if test_grid:
            test_sample = difficulty.grid_covering_stationary_support(test_size, 1)
        else:
            test_sample = difficulty.sample_overall_stationary(test_size, 1)
        test_loss = test(model, test_sample)

        mean_train_losses.append(torch.mean(losses).item())
        mean_test_losses.append(test_loss.item())
        steps.append(steps_before_segment)
        steps_before_segment += len(segment_distributions)

    return torch.tensor([steps, mean_test_losses, mean_train_losses])


def integration_test(model_init, learning_rate):
    """Smoke-run the segmented training loop on a freshly built model.

    ``model_init`` is a zero-argument factory for the model; the model is
    optimized with Adam at ``learning_rate``. Nothing is returned.
    """
    network = model_init()
    adam = torch.optim.Adam(network.parameters(), lr=learning_rate)
    training_distributions = difficulty.get_distributions(20, 10, 0.8, 1)

    measurements = train_measure_segments(
        network,
        adam,
        training_distributions,
        measurement_interval=25,
        test_size=100,
        batch_size_per_distribution=50,
    )

    # Unpacking doubles as a check that exactly three rows came back.
    steps, segment_test_losses, segment_train_losses = measurements


if __name__ == '__main__':
    from networks import FixedWidthPureLTA, FixedWidthPureRelu

    # (model factory, Adam learning rate) pairs, exercised in this order.
    smoke_runs = (
        (lambda: FixedWidthPureLTA([100], 100, 1 / 100, -1, 1), 1e-4),
        (lambda: FixedWidthPureRelu([100, 100, 100, 100]), 1e-2),
    )

    for make_model, rate in smoke_runs:
        integration_test(model_init=make_model, learning_rate=rate)
