import numpy as np
import matplotlib.pyplot as plt
import torch

import IPython

class FeedforwardRepresentation(torch.nn.Module):
    """Network whose forward pass returns only the representation.

    With ``MLP=True`` the representation is the ReLU-activated output of the
    second fully connected layer. With ``MLP=False`` the representation is
    the input itself.

    Args:
        input_size: Number of input features.
        hidden_size: Width of both hidden layers (only used when ``MLP``).
        depth: Unused; kept so existing call sites keep working.
        MLP: Build the two-hidden-layer network when true, otherwise a single
            bias-free linear layer.
    """

    def __init__(self, input_size, hidden_size, depth=1, MLP=True):
        super().__init__()
        self.MLP = MLP
        self.input_size = input_size
        self.sigmoid = torch.nn.Sigmoid()

        if self.MLP:
            self.hidden_size = hidden_size

            self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
            self.relu = torch.nn.ReLU()
            self.fc2 = torch.nn.Linear(self.hidden_size, self.hidden_size)
            # fc3 does not contribute to the returned representation. It is
            # kept so the module's parameter set (and hence its state dict
            # and random-initialisation order) stays unchanged.
            self.fc3 = torch.nn.Linear(self.hidden_size, 1)
        else:
            # Same reasoning as fc3 above: unused by forward(), kept for
            # parameter-set compatibility.
            self.fc1 = torch.nn.Linear(self.input_size, 1, bias=False)

    def forward(self, x, inverse_data_covariance=None, alpha=0):
        """Return the representation of ``x``.

        Args:
            x: Input tensor.
            inverse_data_covariance: Ignored. The return value never depended
                on it; it is accepted only for call compatibility.
            alpha: Ignored, for the same reason.

        Returns:
            ``x`` itself when ``MLP`` is false, otherwise
            ``relu(fc2(relu(fc1(x))))``.
        """
        if not self.MLP:
            return x

        # Only the representation is returned, so the scalar head and the
        # covariance bonus are deliberately not evaluated here.
        first_hidden = self.relu(self.fc1(x))
        return self.relu(self.fc2(first_hidden))




class Feedforward(torch.nn.Module):
    """Scalar-output network that also exposes its representation.

    With ``MLP=True`` the network is ``fc1 -> ReLU -> fc2 -> ReLU -> fc3``
    and the representation is the activation fed into ``fc3``. With
    ``MLP=False`` it is a single bias-free linear layer and the
    representation is the input itself.

    Args:
        input_size: Number of input features.
        hidden_size: Width of both hidden layers (only used when ``MLP``).
        MLP: Selects the two-hidden-layer network over the linear model.
    """

    def __init__(self, input_size, hidden_size, MLP=True):
        super().__init__()
        self.MLP = MLP
        self.input_size = input_size
        self.sigmoid = torch.nn.Sigmoid()

        if self.MLP:
            self.hidden_size = hidden_size

            self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
            self.relu = torch.nn.ReLU()
            self.fc2 = torch.nn.Linear(self.hidden_size, self.hidden_size)
            self.fc3 = torch.nn.Linear(self.hidden_size, 1)
        else:
            self.fc1 = torch.nn.Linear(self.input_size, 1, bias=False)

    def reset_weights(self):
        """Re-initialise every linear layer that exists in this configuration."""
        # fc2 and fc3 are only created in MLP mode; touching them in linear
        # mode would raise AttributeError.
        layers = (self.fc1, self.fc2, self.fc3) if self.MLP else (self.fc1,)
        for layer in layers:
            layer.reset_parameters()

    def forward(self, x, inverse_data_covariance=None, alpha=0):
        """Compute the output and the representation for ``x``.

        Args:
            x: Input tensor; rows are treated as samples.
            inverse_data_covariance: Optional square tensor ``A``. When it is
                given and non-empty, ``alpha * sqrt(r_i^T A r_i)`` is added to
                the squeezed output, where ``r_i`` is the i-th representation
                row. ``None`` or an empty sequence disables the bonus.
            alpha: Multiplier of the bonus term.

        Returns:
            Tuple ``(output, representation)``. ``output`` is the raw layer
            output when no bonus is applied, and the squeezed output plus the
            bonus otherwise.
        """
        if self.MLP:
            first_hidden = self.relu(self.fc1(x))
            representation = self.relu(self.fc2(first_hidden))
            output = self.fc3(representation)
        else:
            representation = x
            output = self.fc1(x)

        has_covariance = (
            inverse_data_covariance is not None
            and len(inverse_data_covariance) != 0
        )
        if has_covariance:
            inverse_covariance = inverse_data_covariance.float()
            # Per-row quadratic form r_i^T A r_i. This equals the diagonal of
            # R A R^T without materialising the full sample-by-sample matrix.
            quadratic_form = (
                torch.matmul(representation, inverse_covariance) * representation
            ).sum(dim=-1)
            output = torch.squeeze(output) + alpha * torch.sqrt(quadratic_form)

        return output, representation


class TorchRewardsModel:
    """Reward regressor wrapping a ``Feedforward`` network.

    Args:
        random_init: Stored on the instance; not otherwise used by this class.
        fit_intercept: Stored on the instance; not otherwise used by this
            class.
        dim: Number of input features. Required.
        MLP: Forwarded to ``Feedforward`` to pick the MLP or linear network.
        representation_layer_size: Hidden width forwarded to ``Feedforward``.

    Raises:
        ValueError: If ``dim`` is ``None``.
    """

    def __init__(
        self,
        random_init=False,
        fit_intercept=True,
        dim=None,
        MLP=True,
        representation_layer_size=100,
    ):
        if dim is None:
            raise ValueError("dimension is none")

        self.fit_intercept = fit_intercept
        self.random_init = random_init
        self.MLP = MLP
        self.representation_layer_size = representation_layer_size
        self.criterion = torch.nn.MSELoss()
        self.criterion_l1 = torch.nn.L1Loss()

        self.network = Feedforward(dim, representation_layer_size, MLP)

    @staticmethod
    def _as_float_tensor(batch):
        """Convert a NumPy array (or tensor) to a float32 tensor."""
        return torch.as_tensor(batch, dtype=torch.float32)

    def _weight_l1_norm(self):
        """Return the sum of absolute values of all network parameters."""
        return sum(param.abs().sum() for param in self.network.parameters())

    def _penalised_loss(self, criterion, batch_X, batch_y, penalty_weight):
        """Evaluate ``criterion`` on the batch and add the weight penalty."""
        predictions, _ = self.network(self._as_float_tensor(batch_X))
        targets = self._as_float_tensor(batch_y)
        data_term = criterion(torch.squeeze(predictions), torch.squeeze(targets))
        return data_term + penalty_weight * self._weight_l1_norm()

    def get_reward(self, batch_X):
        """Return the squeezed network output for ``batch_X``."""
        rewards, _ = self.network(self._as_float_tensor(batch_X))
        return torch.squeeze(rewards)

    def reset_weights(self):
        """Re-initialise every submodule that defines ``reset_parameters``."""
        # Walking the module tree works for both the MLP and the linear
        # network, whichever layers happen to exist.
        for module in self.network.modules():
            reset = getattr(module, "reset_parameters", None)
            if callable(reset):
                reset()

    def get_loss(self, batch_X, batch_y, l2_regularizer=0, range_regularizer=0):
        """Return the mean-squared error plus a weight penalty.

        Args:
            batch_X: Batch of inputs.
            batch_y: Batch of targets.
            l2_regularizer: Coefficient of the weight penalty. Despite the
                name, the penalty is the L1 norm of the parameters.
            range_regularizer: Unused; kept for call compatibility.
        """
        return self._penalised_loss(
            self.criterion, batch_X, batch_y, l2_regularizer
        )

    def get_loss_l1(self, batch_X, batch_y, l2_regularizer=0, range_regularizer=0):
        """Return the mean absolute error plus a weight penalty.

        Args:
            batch_X: Batch of inputs.
            batch_y: Batch of targets.
            l2_regularizer: Coefficient of the weight penalty. Despite the
                name, the penalty is the L1 norm of the parameters.
            range_regularizer: Unused; kept for call compatibility.
        """
        return self._penalised_loss(
            self.criterion_l1, batch_X, batch_y, l2_regularizer
        )






class FeedforwardMultiLayerRepresentation(torch.nn.Module):
    """Stack of linear layers, each followed by the same activation.

    One linear layer is built per entry of ``hidden_sizes``; the first maps
    from ``input_size``. When ``batch_norm`` is set, a ``BatchNorm1d`` of the
    matching width sits between every linear layer and its activation. The
    activation is applied after every layer, including the last one.

    Args:
        input_size: Number of input features.
        hidden_sizes: Output width of each layer, in order.
        activation_type: One of ``"sigmoid"``, ``"relu"`` or ``"leaky_relu"``.
        batch_norm: Whether to insert batch normalisation layers.
        device: Device the layers are moved to.

    Raises:
        ValueError: If ``activation_type`` is not recognised.
    """

    def __init__(self, input_size, hidden_sizes, activation_type = "sigmoid", batch_norm = False, device = torch.device("cpu")):
        super(FeedforwardMultiLayerRepresentation, self).__init__()
        self.input_size = input_size
        self.batch_norm = batch_norm
        self.sigmoid = torch.nn.Sigmoid()
        self.hidden_sizes = hidden_sizes

        # Pick the activation by name; fall through to the error otherwise.
        known_activations = (
            ("sigmoid", torch.nn.Sigmoid),
            ("relu", torch.nn.ReLU),
            ("leaky_relu", torch.nn.LeakyReLU),
        )
        for label, factory in known_activations:
            if activation_type == label:
                self.activation = factory()
                break
        else:
            raise ValueError("Unrecognized activation type.")

        # First layer consumes the raw input; the rest chain adjacent widths.
        self.layers = torch.nn.ModuleList(
            [torch.nn.Linear(self.input_size, self.hidden_sizes[0])]
        )
        for fan_in, fan_out in zip(self.hidden_sizes, self.hidden_sizes[1:]):
            self.layers.append(torch.nn.Linear(fan_in, fan_out))
        self.layers.to(device)

        if self.batch_norm:
            # One normalisation layer per linear layer, matching its width.
            self.batch_norms = torch.nn.ModuleList(
                [torch.nn.BatchNorm1d(width) for width in self.hidden_sizes]
            )
            self.batch_norms.to(device)

    def forward(self, x):
        """Run ``x`` through every layer and return the final activation."""
        activations = x
        for position, layer in enumerate(self.layers):
            activations = layer(activations)
            if self.batch_norm:
                activations = self.batch_norms[position](activations)
            activations = self.activation(activations)
        return activations





class AutoEncoder:
    """Encoder/decoder pair built from ``FeedforwardMultiLayerRepresentation``.

    The decoder mirrors the encoder: its layer widths are the encoder widths
    without the last one, reversed, followed by ``dim``.

    Args:
        dim: Number of input features. Required.
        encoder_representation_layer_sizes: Output width of each encoder
            layer; the last entry is the size of the code.
        activation_type: Activation name passed to both networks.
        batch_norm: Whether both networks use batch normalisation.
        device: Device for the networks and for converted input batches.

    Raises:
        ValueError: If ``dim`` is ``None``.
    """

    def __init__(
        self,
        dim=None,
        encoder_representation_layer_sizes=(10,),
        activation_type='sigmoid',
        batch_norm=False,
        device=torch.device("cpu"),
    ):
        # Validate first: dim is needed to derive the decoder widths below.
        if dim is None:
            raise ValueError("dimension was set to None")

        self.dim = dim
        self.device = device

        # Work on a private list so tuples are accepted and the caller's
        # sequence is never shared with this instance.
        encoder_sizes = list(encoder_representation_layer_sizes)
        self.encoder_representation_layer_sizes = encoder_sizes
        self.encoder_output_dimension = encoder_sizes[-1]
        self.decoder_representation_layer_sizes = (
            list(reversed(encoder_sizes[:-1])) + [self.dim]
        )

        self.criterion = torch.nn.MSELoss()

        self.encoder = FeedforwardMultiLayerRepresentation(
            self.dim,
            self.encoder_representation_layer_sizes,
            activation_type=activation_type,
            batch_norm=batch_norm,
            device=device,
        )
        self.decoder = FeedforwardMultiLayerRepresentation(
            self.encoder_output_dimension,
            self.decoder_representation_layer_sizes,
            activation_type=activation_type,
            batch_norm=batch_norm,
            device=device,
        )

        self.encoder_decoder = torch.nn.Sequential(self.encoder, self.decoder)

    def _to_tensor(self, batch_X):
        """Return ``batch_X`` as a float32 tensor on ``self.device``.

        Accepts NumPy arrays as well as tensors.
        """
        return torch.as_tensor(batch_X, dtype=torch.float32, device=self.device)

    def get_loss(self, batch_X):
        """Return the mean-squared reconstruction error for ``batch_X``."""
        inputs = self._to_tensor(batch_X)
        decoded_batch = self.encoder_decoder(inputs)
        return self.criterion(torch.squeeze(decoded_batch), torch.squeeze(inputs))

    def reconstruction(self, batch_X):
        """Return the decoder output for the encoded ``batch_X``."""
        return self.encoder_decoder(self._to_tensor(batch_X))

    def encode(self, batch_X):
        """Return the encoder output for ``batch_X``."""
        return self.encoder(self._to_tensor(batch_X))



class TorchRewardsModelMultilayer:
    """Reward model backed by ``FeedforwardMultiLayerRepresentation``.

    The network has one layer per entry of ``representation_layer_sizes``
    plus a final layer of width 1. With ``logistic=True`` the network is
    additionally wrapped so that a ``Sigmoid`` follows it.

    NOTE(review): ``FeedforwardMultiLayerRepresentation`` in this file applies
    its activation after every layer, including the width-1 output layer.
    The raw prediction is therefore already activated, and ``logistic=True``
    applies a second squashing on top. Confirm this is intended.

    Args:
        dim: Number of input features. Required.
        representation_layer_sizes: Widths of the layers before the output
            layer.
        activation_type: Activation name passed to the network.
        batch_norm: Whether the network uses batch normalisation.
        device: Device for the network and for converted input batches.
        logistic: Append a ``Sigmoid`` module after the network.

    Raises:
        ValueError: If ``dim`` is ``None``.
    """

    def __init__(
        self,
        dim=None,
        representation_layer_sizes=(10,),
        activation_type='sigmoid',
        batch_norm=False,
        device=torch.device("cpu"),
        logistic=False,
    ):
        if dim is None:
            raise ValueError("dimension is none")

        self.device = device
        self.dim = dim
        self.criterion = torch.nn.MSELoss()
        self.criterion_l1 = torch.nn.L1Loss()
        self.criterion_logistic = torch.nn.BCELoss()
        # A private list lets callers pass tuples and keeps the concatenation
        # below from depending on the argument's concrete type.
        self.representation_layer_sizes = list(representation_layer_sizes)
        self.logistic = logistic

        self.network = FeedforwardMultiLayerRepresentation(
            self.dim,
            self.representation_layer_sizes + [1],
            activation_type=activation_type,
            batch_norm=batch_norm,
            device=device,
        )

        if self.logistic:
            self.network = torch.nn.Sequential(self.network, torch.nn.Sigmoid())

        self.network.to(self.device)

    def _to_tensor(self, batch):
        """Return ``batch`` as a float32 tensor on ``self.device``.

        Accepts NumPy arrays as well as tensors.
        """
        return torch.as_tensor(batch, dtype=torch.float32, device=self.device)

    def _criterion_on_batch(self, criterion, batch_X, batch_y):
        """Apply ``criterion`` to squeezed predictions and squeezed targets."""
        predictions = self.network(self._to_tensor(batch_X))
        targets = self._to_tensor(batch_y)
        return criterion(torch.squeeze(predictions), torch.squeeze(targets))

    def get_reward(self, batch_X):
        """Return the squeezed network output for ``batch_X``."""
        return torch.squeeze(self.network(self._to_tensor(batch_X)))

    def reset_weights(self):
        """Re-initialise every submodule that defines ``reset_parameters``."""
        # The wrapped network (and the Sequential used in logistic mode) has
        # no reset_weights method of its own, so walk the module tree instead.
        for module in self.network.modules():
            reset = getattr(module, "reset_parameters", None)
            if callable(reset):
                reset()

    def get_logistic_loss(self, batch_X, batch_y):
        """Return the binary cross-entropy between predictions and targets.

        ``BCELoss`` expects predictions in ``[0, 1]``, so the network output
        must already be a probability.
        """
        return self._criterion_on_batch(self.criterion_logistic, batch_X, batch_y)

    def get_loss(self, batch_X, batch_y):
        """Return the mean-squared error between predictions and targets."""
        return self._criterion_on_batch(self.criterion, batch_X, batch_y)

    def get_loss_l1(self, batch_X, batch_y):
        """Return the mean absolute error between predictions and targets."""
        return self._criterion_on_batch(self.criterion_l1, batch_X, batch_y)






