from src.predictor import Predictor
from torch import nn
import torch


class NetworkPredictor(Predictor):
    """MLP predictor mapping observations to a linear term ``w_lin`` and a matrix ``W_sq``.

    A three-layer MLP maps each observation to ``n_variables + n_variables ** 2`` outputs.
    It uses float64 weights and two hidden layers with LeakyReLU activations.
    The first ``n_variables`` outputs are one of:
      * the linear weights ``w_lin`` directly (``predict_x_u=False``), or
      * a point ``x_u`` from which ``w_lin`` is derived (``predict_x_u=True``).
    The remaining outputs are reshaped into one ``(n_variables, n_variables)`` matrix
    ``W_sq`` per sample.

    Args:
        n_variables: Number of variables; determines the network's output size.
        observation_size: Length of each observation vector (network input size).
        n_hidden: Width of both hidden layers.
        init_seed: If not None, ``torch.manual_seed(init_seed)`` is called before the layers
            are built. Note this reseeds torch's *global* RNG.
        predict_x_u: If True, the network predicts ``x_u`` and
            ``w_lin = 2 * x_u @ (W_sq @ W_sq^T)`` is computed per sample.
        train_W_sq: If False, the network's matrix outputs are ignored and an identity
            matrix is used as ``W_sq`` for every sample.
        x_u_shift: Constant added to the predicted ``x_u``. In the direct-``w_lin`` mode,
            ``2 * x_u_shift`` is added to ``w_lin`` instead.
        **kwargs: Forwarded to ``Predictor.__init__`` (and re-forwarded by ``create_copy``).
    """

    def __init__(self, n_variables, observation_size, n_hidden=64, init_seed=None,
                 predict_x_u=False, train_W_sq=True, x_u_shift=7, **kwargs):
        # NOTE(review): argument order (predict_x_u before init_seed) follows the original
        # call; confirm it matches Predictor.__init__'s signature.
        super().__init__(n_variables, observation_size, predict_x_u, init_seed, **kwargs)
        self.train_W_sq = train_W_sq
        self.n_hidden = n_hidden
        self.x_u_shift = x_u_shift
        # Kept so create_copy can rebuild an identically configured instance.
        self._init_kwargs = dict(kwargs)
        n_inputs = observation_size
        n_outputs = n_variables + n_variables ** 2
        if init_seed is not None:
            torch.manual_seed(init_seed)
        self.model = nn.Sequential(
            nn.Linear(n_inputs, n_hidden, dtype=torch.float64), nn.LeakyReLU(),
            nn.Linear(n_hidden, n_hidden, dtype=torch.float64), nn.LeakyReLU(),
            nn.Linear(n_hidden, n_outputs, dtype=torch.float64),
        )
        self.parameters = list(self.model.parameters())

    def _split_prediction(self, prediction):
        """Split raw network output into the leading vector part and the per-sample matrix.

        Returns ``(head, W_sq)``:
          * ``head`` is ``prediction[:, :n_variables]``.
          * ``W_sq`` has shape ``(batch, n_variables, n_variables)``.
        """
        n = self.n_variables
        head = prediction[:, :n]
        if self.train_W_sq:
            W_sq = prediction[:, n:].reshape((-1, n, n))
        else:
            # Ignore the network's matrix outputs and use a fresh identity per sample.
            # It is built on the prediction's device/dtype so GPU models also work.
            eye = torch.eye(n, dtype=prediction.dtype, device=prediction.device)
            W_sq = eye.repeat(len(prediction), 1, 1)
        return head, W_sq

    def predict(self, obs_batch):
        """Run the network on a batch of observations and build the prediction dict.

        Args:
            obs_batch: Batch of observations, presumably a float64 tensor of shape
                ``(batch, observation_size)``. TODO confirm against callers.

        Returns:
            A dict with these entries:
              * ``'w_lin'``: shape ``(batch, n_variables)``.
              * ``'W_sq'``: shape ``(batch, n_variables, n_variables)``.
              * ``'x_u'``: shape ``(batch, n_variables)``; present only when
                ``predict_x_u`` is True.
        """
        prediction = self.model(obs_batch)
        head, W_sq = self._split_prediction(prediction)
        if self.predict_x_u:
            x_u = head + self.x_u_shift
            Q = W_sq @ W_sq.swapdims(1, 2)
            # Make x_u (batch, 1, n) so each sample is multiplied by its own Q.
            # A plain (batch, n) @ (batch, n, n) broadcasts to (batch, batch, n), and [:, 0]
            # would then use sample 0's x_u for every row.
            w_lin = (2 * x_u.unsqueeze(1) @ Q)[:, 0]
            return {'w_lin': w_lin,
                    'x_u': x_u,
                    'W_sq': W_sq}
        return {'w_lin': head + self.x_u_shift * 2,
                'W_sq': W_sq}

    def create_copy(self):
        """Return a new NetworkPredictor with the same configuration and a copy of the weights.

        Constructing the copy calls ``torch.manual_seed`` again when ``init_seed`` is set.
        """
        # NOTE(review): n_variables / observation_size / init_seed / predict_x_u are assumed
        # to be stored by Predictor.__init__.
        predictor_copy = NetworkPredictor(
            self.n_variables, self.observation_size, n_hidden=self.n_hidden,
            init_seed=self.init_seed, predict_x_u=self.predict_x_u,
            train_W_sq=self.train_W_sq, x_u_shift=self.x_u_shift,
            **getattr(self, '_init_kwargs', {}),
        )
        predictor_copy.model.load_state_dict(self.model.state_dict())
        return predictor_copy

    def save(self, path):
        """Save the network's state_dict to ``<path>/model``."""
        torch.save(self.model.state_dict(), path + "/model")

    def load(self, path):
        """Load the network's state_dict from ``<path>/model``.

        Tensors are mapped onto the model's current device. This lets a checkpoint saved
        on GPU load on a CPU-only machine.
        """
        device = next(self.model.parameters()).device
        self.model.load_state_dict(torch.load(path + "/model", map_location=device))
