import os

import numpy as np
import torch
import torch.multiprocessing as mp
from torch import nn
from torch.nn import functional
# Use the 'spawn' start method for torch.multiprocessing. force=True replaces any
# start method that was already set, and this runs as a side effect of importing
# this module.
mp.set_start_method('spawn', force=True)


class LinearMapModel(nn.Module):
    """Bias-free float64 linear map from a window of past tokens to a target vector.

    All dimensions are derived from ``solver``.

    Args:
        solver: object exposing ``output_dimension``, ``grid_dimension`` and
            ``seed``. The seed also determines where checkpoints are written.
        history_length: number of past token sets concatenated into one input.
        resolution: if True, the output has ``grid_dimension ** 2`` entries
            (original data) instead of ``output_dimension ** 2`` (next token set).
    """

    def __init__(self, solver, history_length, resolution=False):
        super().__init__()

        self.in_dim = history_length * solver.output_dimension ** 2  # Token history goes in
        if resolution:
            self.out_dim = solver.grid_dimension ** 2  # Original data comes out
            # The dimension of the state space of the governing eq.
            # Not used by this class itself; presumably read by callers.
            self.state_dim = 2 * solver.grid_dimension ** 2
        else:
            self.out_dim = solver.output_dimension ** 2  # Next token set comes out
            # The dimension of the state space of the governing eq.
            self.state_dim = solver.grid_dimension ** 2
        self.seed = solver.seed
        self.model = nn.Linear(self.in_dim, self.out_dim, bias=False, dtype=torch.float64)

    def forward(self, x):
        """Apply the linear map to the last dimension of ``x``.

        ``x`` must be a float64 tensor whose last dimension equals ``in_dim``.
        """
        return self.model(x)

    @staticmethod
    def _checkpoint_path(seed):
        """Return the checkpoint file path for ``seed``: ``output/<seed>/<seed>.pt``."""
        return os.path.join('output', str(seed), f'{seed}.pt')

    def save(self, epoch, optimizer):
        """Write the epoch, model state and optimizer state to this model's checkpoint.

        The file goes to ``output/<self.seed>/<self.seed>.pt``. Missing parent
        directories are created first, because ``torch.save`` does not create them.
        """
        path = self._checkpoint_path(self.seed)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save({
            'epoch': epoch,
            'model': self.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, path)

    def load(self, seed, optimizer=None, map_location=None):
        """Restore state from the checkpoint written by ``save`` for ``seed``.

        Args:
            seed: seed whose checkpoint (``output/<seed>/<seed>.pt``) is read.
                ``self.seed`` is left unchanged, so a later ``save`` still
                writes to this model's own seed directory.
            optimizer: if given, its state is restored from the checkpoint as well.
            map_location: forwarded to ``torch.load``. For example, pass
                ``'cpu'`` to load a checkpoint saved on a GPU.

        Returns:
            The epoch stored in the checkpoint.
        """
        checkpoint = torch.load(self._checkpoint_path(seed), map_location=map_location)
        self.load_state_dict(checkpoint['model'])
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer'])
        epoch = checkpoint['epoch']
        print(f"Loaded {seed} with epoch {epoch}")
        return epoch

    @staticmethod
    def loss_function(x, y):
        """Mean squared error between ``x`` and ``y`` (``functional.mse_loss``, mean reduction)."""
        return functional.mse_loss(x, y)


def get_error(arrow, target):
    """Return the base-10 logarithm of the mean absolute error between two arrays.

    Args:
        arrow: predicted values; anything ``np.asarray`` accepts.
        target: reference values, broadcast-compatible with ``arrow``.

    Returns:
        ``log10(mean(|arrow - target|))``. Identical inputs give ``-inf``
        (NumPy emits a divide-by-zero warning); empty inputs give ``nan``.
    """
    arrow = np.asarray(arrow)
    target = np.asarray(target)
    # Promote to at least float64 so integer (especially unsigned) inputs cannot
    # wrap around on subtraction; complex inputs keep a complex dtype.
    dtype = np.result_type(arrow, target, np.float64)
    err = np.abs(arrow.astype(dtype, copy=False) - target.astype(dtype, copy=False))
    return np.log10(err.mean())
