import torch
from torch import nn
from torch.optim import Adam

from generalization_study.utils import Tracker


class ReadOutMLP(nn.Sequential):
    """Fully connected regressor used as a readout head.

    Architecture: ``in_features -> number_latent -> number_latent ->
    number_latent -> out_features``, with a ReLU after each of the first
    three linear layers and no activation on the output.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        number_latent: Width of every hidden layer.
    """

    def __init__(self, in_features: int, out_features: int,
                 number_latent=40):
        # The base class is kept as nn.Sequential for backward compatibility,
        # but it is constructed empty: the layers live in the ``net`` child so
        # that state_dict keys stay ``net.*``.
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, number_latent),
            nn.ReLU(),
            nn.Linear(number_latent, number_latent),
            nn.ReLU(),
            nn.Linear(number_latent, number_latent),
            nn.ReLU(),
            nn.Linear(number_latent, out_features),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Explicit delegation: without this, the call only worked because
        # nn.Sequential.forward happens to iterate over the single registered
        # child ``net``.
        return self.net(x)


def train_mlp_on_readout(model, dataloader, number_latents: int,
                         number_factors: int, device='cuda',
                         number_epochs=8,
                         writer=None,
                         current_iteration: int = 0):
    """Train a ReadOutMLP that regresses the dataloader targets from ``model``'s output.

    ``model`` is switched to eval mode and its forward pass runs under
    ``torch.no_grad()``. The Adam optimizer only sees the readout's
    parameters, so ``model`` itself is not updated here.

    Args:
        model: Frozen feature extractor. Its output must have last dimension
            ``number_latents`` (assumed; not checked here).
        dataloader: Iterable yielding ``(inputs, targets)`` pairs. Both are
            moved to ``device``.
        number_latents: Input width of the readout MLP.
        number_factors: Output width of the readout MLP (number of target
            columns).
        device: Device the readout and each batch are moved to.
        number_epochs: Number of passes over ``dataloader``.
        writer: Forwarded to ``Tracker``; presumably a logging/summary writer
            (confirm against ``generalization_study.utils.Tracker``).
        current_iteration: Used only to build the logging tag prefix.

    Returns:
        ``ConcatModels(model, readout)`` in eval mode.
    """
    model.eval()
    readout = ReadOutMLP(in_features=number_latents,
                         out_features=number_factors).to(device)
    opt = Adam(readout.parameters())
    tracker = Tracker(writer)
    log_prefix = f'readout_training/iteration_{current_iteration}_'

    for epoch in range(number_epochs):
        for inputs, factors in dataloader:
            inputs = inputs.to(device)
            factors = factors.to(device)

            # The feature extractor is frozen; only the readout gets gradients.
            with torch.no_grad():
                latents = model(inputs)
            prediction = readout(latents)

            # Squared error summed over dim 1 (the factor dimension), then
            # averaged over the batch: an MSE scaled by the number of factors.
            loss = torch.mean(torch.sum((factors - prediction) ** 2, dim=1))

            opt.zero_grad()
            loss.backward()
            opt.step()

            tracker.track({'train_loss': loss.item()})
        tracker.write(epoch, log_prefix)

    return ConcatModels(model, readout).eval()


class ConcatModels(nn.Module):
    """Sequentially compose two modules: ``second_model(first_model(x))``.

    Both models are stored as attributes (``first_model``, ``second_model``).
    When they are ``nn.Module`` instances they are registered as submodules,
    so their parameters appear in this module's ``parameters()`` and
    ``state_dict()``.
    """

    def __init__(self, first_model, second_model):
        super().__init__()
        self.first_model = first_model
        self.second_model = second_model

    def forward(self, x: torch.Tensor):
        # Annotation fixed: ``torch.tensor`` is the factory function,
        # ``torch.Tensor`` is the type.
        return self.second_model(self.first_model(x))
