from src.portfolio_optimization.util import REG, compute_covariance
from src.predictor import Predictor
import torch.nn as nn
import torch


def linear_block(in_channels, out_channels, activation='ReLU'):
    """Build a float32 ``nn.Linear`` layer optionally followed by an activation.

    Args:
        in_channels: input feature count of the linear layer.
        out_channels: output feature count of the linear layer.
        activation: ``'ReLU'`` (implemented with ``nn.LeakyReLU``), ``'Sigmoid'``,
            or ``None`` for a bare linear layer.

    Returns:
        nn.Sequential: the linear layer followed by the selected activation, if any.

    Raises:
        ValueError: if ``activation`` is not one of the supported options
            (previously this silently returned ``None``).
    """
    layers = [nn.Linear(in_channels, out_channels, dtype=torch.float32)]
    if activation == 'ReLU':
        # NOTE: despite the option name a LeakyReLU is used; kept as-is to preserve
        # existing training behavior.
        layers.append(nn.LeakyReLU())
    elif activation == 'Sigmoid':
        layers.append(nn.Sigmoid())
    elif activation is not None:
        raise ValueError(
            f"Unsupported activation {activation!r}; expected 'ReLU', 'Sigmoid' or None")
    return nn.Sequential(*layers)


class PortfolioModel(nn.Module):
    """Three-layer MLP mapping an input feature vector to a bounded output vector.

    The final layer is sigmoid-activated, so its raw output lies in (0, 1); ``forward``
    then returns ``(raw + shift_output) * scale_output``.

    Args:
        input_size: number of input features.
        output_size: number of output units.
        scale_output: multiplicative factor applied after the shift.
        shift_output: additive offset applied to the sigmoid output.
    """

    def __init__(self, input_size=20, output_size=1, scale_output=.1, shift_output=0.1):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.scale_output = scale_output
        self.shift_output = shift_output
        hidden_width = 100
        self.model = nn.Sequential(
            linear_block(input_size, hidden_width),
            linear_block(hidden_width, hidden_width),
            linear_block(hidden_width, output_size, activation='Sigmoid'),
        )

    def forward(self, x):
        """Apply the MLP to ``x`` and rescale the sigmoid output."""
        shifted = self.model(x) + self.shift_output
        return shifted * self.scale_output



class CovarianceModel(nn.Module):
    """Learned covariance built from one trainable embedding vector per variable.

    Args:
        n: number of variables (rows of the embedding table).
        latent_dim: embedding dimension per variable.
    """

    def __init__(self, n, latent_dim):
        super(CovarianceModel, self).__init__()
        self.n = n
        self.latent_dim = latent_dim
        self.embedding = nn.Embedding(num_embeddings=self.n, embedding_dim=self.latent_dim,
                                      dtype=torch.float32)

    def forward(self, batch_size):
        """Return ``compute_covariance`` of the embedding table repeated ``batch_size`` times.

        The ``(n, latent_dim)`` embedding table is tiled to ``(batch_size, n, latent_dim)``
        before being passed to ``compute_covariance``.
        """
        # Create the index tensor on the same device as the embedding weights;
        # torch.LongTensor(...) always allocated on the CPU, breaking GPU use.
        indices = torch.arange(self.n, device=self.embedding.weight.device)
        security_embeddings = self.embedding(indices)
        # Same values as stacking batch_size copies, without building a Python list.
        security_embeddings = security_embeddings.unsqueeze(0).repeat(batch_size, 1, 1)
        return compute_covariance(security_embeddings)



class NetworkPredictor(Predictor):
    """ Predictor that learns to predict 'p' and learns the covariance matrix using an embedding model

    Behavior of :meth:`predict` depends on two flags:

    * ``use_covariance=True``: a :class:`PortfolioModel` predicts ``n_variables`` linear
      terms. If ``train_W_sq`` is set, ``W_sq`` is the Cholesky factor of a regularized
      covariance from :class:`CovarianceModel`; otherwise it is the identity.
    * ``use_covariance=False``: if ``train_W_sq`` is set, the MLP predicts the linear terms
      followed by a flattened ``n_variables x n_variables`` ``W_sq``. Otherwise it predicts
      only the linear terms and ``W_sq`` is the identity.

    Args:
        n_variables: number of variables.
        observation_size: per-variable observation length; the network input size is
            ``observation_size * n_variables``.
        init_seed: optional seed given to ``torch.manual_seed`` before building the networks.
        predict_x_u: if True, the linear output is treated as ``x_u`` and ``w_lin`` is
            derived per sample as ``x_u @ Q``.
        train_W_sq: whether ``W_sq`` is learned (see above) or fixed to the identity.
        latent_dim: embedding dimension of the covariance model.
        x_u_shift: output shift forwarded to :class:`PortfolioModel`.
        scale_x_u: output scale forwarded to :class:`PortfolioModel`.
        use_covariance: select the covariance-embedding variant (see above).
        **kwargs: forwarded to :class:`Predictor`.
    """

    def __init__(self, n_variables, observation_size, init_seed=None,
                 predict_x_u=False, train_W_sq=True, latent_dim=32,
                 x_u_shift=.1, scale_x_u=.1,
                 use_covariance=True, **kwargs):
        base_observation_size = observation_size
        observation_size = observation_size * n_variables
        super().__init__(n_variables, observation_size, predict_x_u, init_seed, **kwargs)
        if init_seed is not None:
            torch.manual_seed(init_seed)
        # Raw constructor arguments that create_copy needs to rebuild the same architecture.
        self._base_observation_size = base_observation_size
        self._init_kwargs = dict(kwargs)
        self.latent_dim = latent_dim
        self.scale_x_u = scale_x_u
        self.x_u_shift = x_u_shift
        self.use_covariance = use_covariance
        self.train_W_sq = train_W_sq

        # Without the covariance model, a learned W_sq comes from the MLP itself as
        # n_variables ** 2 extra outputs appended after the n_variables linear terms.
        if not use_covariance and train_W_sq:
            output_size = n_variables + n_variables ** 2
        else:
            output_size = n_variables
        self.model = PortfolioModel(input_size=observation_size, output_size=output_size,
                                    scale_output=scale_x_u, shift_output=x_u_shift)
        params = list(self.model.parameters())
        if use_covariance:
            self.covariance_model = CovarianceModel(n=n_variables, latent_dim=latent_dim)
            params += list(self.covariance_model.parameters())
        self.parameters = params

    def _identity_batch(self, batch_size, device=None):
        """Return a new float32 tensor of ``batch_size`` stacked ``n_variables`` identities."""
        eye = torch.eye(self.n_variables, dtype=torch.float32, device=device)
        return eye.repeat(batch_size, 1, 1)

    def predict(self, obs_batch):
        """Run the networks on a batch of observations.

        Args:
            obs_batch: 3-D tensor ``(batch, d1, d2)`` whose ``d1 * d2`` equals the network
                input size (presumably ``(batch, n_variables, observation_size)`` — TODO confirm).

        Returns:
            dict with ``'w_lin'`` of shape ``(batch, n_variables)`` and ``'W_sq'`` of shape
            ``(batch, n_variables, n_variables)``; additionally ``'x_u'`` when
            ``predict_x_u`` is set.
        """
        batch_size = obs_batch.shape[0]
        n = self.n_variables
        obs_flat = obs_batch.reshape((batch_size, obs_batch.shape[1] * obs_batch.shape[2]))
        prediction = self.model(obs_flat)
        device = prediction.device

        if self.use_covariance:
            # PortfolioModel was built with output_size == n_variables here.
            linear_term = prediction
            if self.train_W_sq:
                Q = (self.covariance_model(batch_size) * (1 - REG)
                     + torch.eye(n, device=device) * REG)
                W_sq = torch.linalg.cholesky(Q).to(dtype=torch.float32)
                W_sq = W_sq.reshape((-1, n, n))
            else:
                W_sq = self._identity_batch(batch_size, device)
                Q = self._identity_batch(batch_size, device)
        else:
            linear_term = prediction[:, :n]
            if self.train_W_sq:
                W_sq = prediction[:, n:].reshape((batch_size, n, n))
                # Mix in REG * identity as a regularizer (identity broadcasts over the batch).
                W_sq = W_sq * (1 - REG) + torch.eye(n, dtype=torch.float32, device=device) * REG
                Q = W_sq @ W_sq.swapdims(1, 2)
            else:
                W_sq = self._identity_batch(batch_size, device)
                Q = self._identity_batch(batch_size, device)

        if self.predict_x_u:
            x_u = linear_term
            # Per-sample product w_lin[b] = x_u[b] @ Q[b]. A plain ``x_u @ Q`` broadcasts the
            # 2-D x_u against every Q[b], giving (batch, batch, n); indexing [:, 0] then used
            # x_u[0] for every sample.
            w_lin = (x_u.unsqueeze(1) @ Q)[:, 0]
            return {'w_lin': w_lin, 'W_sq': W_sq, 'x_u': x_u}
        return {'w_lin': linear_term, 'W_sq': W_sq}

    def create_copy(self):
        """Return a new NetworkPredictor with the same configuration and copied weights."""
        # Keyword arguments: positional passing previously shifted scale_x_u into x_u_shift
        # and use_covariance into scale_x_u.
        predictor_copy = NetworkPredictor(
            self.n_variables, self._base_observation_size, init_seed=self.init_seed,
            predict_x_u=self.predict_x_u, train_W_sq=self.train_W_sq,
            latent_dim=self.latent_dim, x_u_shift=self.x_u_shift, scale_x_u=self.scale_x_u,
            use_covariance=self.use_covariance, **self._init_kwargs)
        predictor_copy.model.load_state_dict(self.model.state_dict())
        if self.use_covariance:
            predictor_copy.covariance_model.load_state_dict(self.covariance_model.state_dict())
        return predictor_copy

    def save(self, path):
        """Save weights to ``path + "/portfolio_model"`` (and ``"/covariance_model"`` if used)."""
        torch.save(self.model.state_dict(), path + "/portfolio_model")
        if self.use_covariance:
            torch.save(self.covariance_model.state_dict(), path + "/covariance_model")

    def load(self, path):
        """Load weights previously written by :meth:`save` under ``path``."""
        self.model.load_state_dict(torch.load(path + "/portfolio_model"))
        if self.use_covariance:
            self.covariance_model.load_state_dict(torch.load(path + "/covariance_model"))
