import numpy as np
import torch

import gpytorch
from gpytorch.mlls import VariationalELBO
from gpytorch.likelihoods import GaussianLikelihood

from due.dkl import DKL_GP, GP, initial_values_for_GP
from due.fc_resnet import FCResNet

import d4rl
import gym
from utils import check_or_make_folder
import sklearn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class DUEEstimator:
    """DUE-style deep-kernel GP regressor trained on a D4RL offline dataset.

    Inputs are the concatenation of ``observations`` and ``actions`` from
    ``d4rl.qlearning_dataset``; the regression target is ``dataset[target]``.
    The model is ``DKL_GP`` = ``FCResNet`` feature extractor followed by a
    sparse variational ``GP`` head, with a ``GaussianLikelihood``.
    """

    def __init__(self,
                 env_name='hopper-medium-expert-v0',
                 target='rewards',  # 'next_observations'
                 batch_size=1024,
                 n_inducing_points=200,
                 features=256,
                 depth=6,  # 2, 4, 6, 8
                 kernel="RBF",  # RQ, Matern12, Matern32, Matern52
                 spectral_normalization=True,
                 coeff=2,  # [1, 2, 3]
                 n_power_iterations=1,
                 dropout_rate=0.0,
                 lr=1e-3,
                 train_epochs=500,
                 ):
        """Load the dataset and build an untrained model and likelihood.

        Args:
            env_name: environment id passed to ``gym.make`` / ``d4rl``.
            target: dataset key to regress onto. ``'rewards'`` gives one
                output; any other value uses the width of
                ``next_observations`` as the number of outputs.
            batch_size: mini-batch size of ``self.dl_train`` (last partial
                batch is dropped).
            n_inducing_points: number of GP inducing points.
            features, depth, spectral_normalization, coeff,
            n_power_iterations, dropout_rate: forwarded to ``FCResNet``.
            kernel: kernel name forwarded to ``GP``.
            lr: learning rate used by ``train``.
            train_epochs: currently unused; call ``train`` explicitly.
        """
        env = gym.make(env_name)
        dataset = d4rl.qlearning_dataset(env)
        X_train, y_train = np.concatenate((dataset['observations'], dataset['actions']), axis=1), dataset[target]
        input_dim = dataset['observations'].shape[1] + dataset['actions'].shape[1]
        num_outputs = 1 if target == 'rewards' else dataset['next_observations'].shape[1]

        self.ds_train = torch.utils.data.TensorDataset(torch.from_numpy(X_train).float(),
                                                       torch.from_numpy(y_train).float())
        self.dl_train = torch.utils.data.DataLoader(self.ds_train, batch_size=batch_size, shuffle=True,
                                                    drop_last=True)

        feature_extractor = FCResNet(
            input_dim=input_dim,
            features=features,
            depth=depth,
            spectral_normalization=spectral_normalization,
            coeff=coeff,
            n_power_iterations=n_power_iterations,
            dropout_rate=dropout_rate
        )
        # Inducing points / lengthscale are initialised from the (untrained)
        # feature extractor applied to the training data.
        initial_inducing_points, initial_lengthscale = initial_values_for_GP(
            self.ds_train, feature_extractor, n_inducing_points
        )
        gp = GP(
            num_outputs=num_outputs,
            initial_lengthscale=initial_lengthscale,
            initial_inducing_points=initial_inducing_points,
            kernel=kernel,
        )

        self.model = DKL_GP(feature_extractor, gp).to(device)
        # NOTE(review): a single shared-noise GaussianLikelihood is used even
        # when num_outputs > 1; consider MultitaskGaussianLikelihood there.
        self.likelihood = GaussianLikelihood().to(device)

        self.lr = lr

    def train(self, train_epochs):
        """Fit feature extractor, GP and likelihood by maximising the ELBO.

        After every epoch the mean mini-batch loss is printed and a checkpoint
        is written via ``save`` with suffix ``"<epoch>_<loss>"``.

        Args:
            train_epochs: number of full passes over ``self.dl_train``.
        """
        loss_fn = VariationalELBO(self.likelihood, self.model.gp, num_data=len(self.ds_train))
        parameters = [
            {"params": self.model.feature_extractor.parameters(), "lr": self.lr},
            {"params": self.model.gp.parameters(), "lr": self.lr},
            {"params": self.likelihood.parameters(), "lr": self.lr},
        ]

        # TODO: include scheduler
        optimizer = torch.optim.Adam(parameters, weight_decay=5e-4)

        self.model.train()
        self.likelihood.train()
        for i in range(train_epochs):
            batch_losses = []
            for x, y in self.dl_train:
                x, y = x.to(device), y.to(device)
                y_pred = self.model(x)

                optimizer.zero_grad()
                loss = -loss_fn(y_pred, y)
                loss.backward()
                optimizer.step()
                batch_losses.append(loss.item())

            # The logged value is the mean negative ELBO (a loss), despite the
            # "Likelihood" label kept for log compatibility.
            epoch_loss = np.mean(batch_losses)
            print(f"Test Results - Epoch: {i}, Train Likelihood: {epoch_loss:.2f}")
            self.save(f"{i}_{epoch_loss:.2f}")

    def predict_var(self, obs):
        """Return the predictive variance of the likelihood at ``obs``.

        Switches model and likelihood to eval mode (``train`` switches them
        back).

        Args:
            obs: concatenated observation-action inputs, as a tensor or
                array-like; converted to float32 on the model's device.

        Returns:
            ``variance`` of the predictive distribution, as a tensor on the
            model's device.
        """
        self.model.eval()
        self.likelihood.eval()

        # The model lives on `device` in float32; accept CPU tensors or numpy.
        obs = torch.as_tensor(obs, dtype=torch.float32, device=device)

        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            observed_pred = self.likelihood(self.model(obs))

        return observed_pred.variance

    def save(self, suffix=None):
        """
        Method to save model after training is completed

        Writes the model and likelihood ``state_dict``s to
        ``./checkpoints/due_saved_weights_duezm/due_model_weights[_<suffix>].pt``.
        """
        print("Saving model checkpoint...")
        check_or_make_folder("./checkpoints")
        check_or_make_folder("./checkpoints/due_saved_weights_duezm")
        torch_state_dict = {
            'model': self.model.state_dict(),
            'likelihood': self.likelihood.state_dict(),
        }
        fn = 'due_model_weights'
        if suffix is not None:
            fn = fn + '_' + suffix
        torch.save(torch_state_dict, f"./checkpoints/due_saved_weights_duezm/{fn}.pt")

    def load(self, model_file):
        """
        Method to load model from checkpoint folder

        Restores the model and likelihood ``state_dict``s written by ``save``,
        mapping tensors onto the current ``device``.
        """
        print("Loading model from checkpoint...")
        torch_state_dict = torch.load(model_file, map_location=device)
        self.model.load_state_dict(torch_state_dict['model'])
        self.likelihood.load_state_dict(torch_state_dict['likelihood'])


class MultitaskGPModel(gpytorch.models.ApproximateGP):
    """Sparse variational GP with one independent latent GP per output task.

    Each task has its own variational distribution and its own kernel
    hyperparameters (everything is batched over ``num_tasks``). Outputs are a
    ``MultitaskMultivariateNormal`` with event shape ``(n, num_tasks)``.
    """

    def __init__(self,
                 num_tasks=1,
                 num_inducing_points=200,
                 num_inputs=1,
                 base_kernel_type=gpytorch.kernels.RQKernel,
                 inducing_points=None,
                 ):
        """Build the variational strategy, mean and covariance modules.

        Args:
            num_tasks: number of independent outputs.
            num_inducing_points: inducing points per task when
                ``inducing_points`` is not given.
            num_inputs: input dimensionality; also the number of ARD
                lengthscales of the base kernel.
            base_kernel_type: kernel class constructed with ``batch_shape``
                and ``ard_num_dims`` keyword arguments.
            inducing_points: optional initial inducing locations; when None
                they are drawn uniformly from [0, 1) with shape
                ``(num_tasks, num_inducing_points, num_inputs)``.
                NOTE(review): inputs are not normalised upstream, so a unit-cube
                initialisation may sit far from the data.
        """
        if inducing_points is None:
            inducing_points = torch.rand(num_tasks, num_inducing_points, num_inputs)

        task_batch = torch.Size([num_tasks])

        # Batched CholeskyVariationalDistribution: one variational
        # distribution learned per task.
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=task_batch
        )

        # Wrapping in IndependentMultitaskVariationalStrategy turns the batch
        # output into a MultitaskMultivariateNormal.
        variational_strategy = gpytorch.variational.IndependentMultitaskVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ), num_tasks=num_tasks
        )

        super().__init__(variational_strategy)

        # Mean/covariance are batched so each task has its own hyperparameters.
        self.mean_module = gpytorch.means.ZeroMean(batch_shape=task_batch)
        # ARD lengthscales belong on the *base* kernel: ScaleKernel has no
        # lengthscale, so `ard_num_dims` passed to it would be ignored.
        self.covar_module = gpytorch.kernels.ScaleKernel(
            base_kernel_type(batch_shape=task_batch, ard_num_dims=num_inputs),
            batch_shape=task_batch,
        )

    def forward(self, x):
        """Prior over latent functions at ``x``, computed per task in batch."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


class SparseGPEstimator:
    """Sparse variational multitask GP regressor trained on a D4RL dataset.

    Inputs are the concatenation of ``observations`` and ``actions`` from
    ``d4rl.qlearning_dataset``; the regression target is ``dataset[target]``.
    Uses ``MultitaskGPModel`` with a ``MultitaskGaussianLikelihood``.
    """

    def __init__(self,
                 env_name='hopper-medium-expert-v0',
                 target='rewards',  # 'next_observations'
                 batch_size=1024,
                 n_inducing_points=200,
                 lr=1e-3,
                 train_epochs=500,
                 ):
        """Load the dataset and build an untrained model and likelihood.

        Args:
            env_name: environment id passed to ``gym.make`` / ``d4rl``.
            target: dataset key to regress onto. ``'rewards'`` gives one
                task; any other value uses the width of ``next_observations``
                as the number of tasks.
            batch_size: mini-batch size of ``self.dl_train`` (last partial
                batch is dropped).
            n_inducing_points: inducing points per task.
            lr: learning rate used by ``train``.
            train_epochs: currently unused; call ``train`` explicitly.
        """
        env = gym.make(env_name)
        dataset = d4rl.qlearning_dataset(env)
        X_train, y_train = np.concatenate((dataset['observations'], dataset['actions']), axis=1), dataset[target]
        input_dim = dataset['observations'].shape[1] + dataset['actions'].shape[1]
        num_outputs = 1 if target == 'rewards' else dataset['next_observations'].shape[1]

        # The multitask model predicts shape (batch, num_tasks) even for one
        # task. A 1-D target would broadcast against that to (batch, batch)
        # inside the likelihood and silently corrupt the ELBO.
        if y_train.ndim == 1:
            y_train = y_train.reshape(-1, 1)

        self.ds_train = torch.utils.data.TensorDataset(torch.from_numpy(X_train).float(),
                                                       torch.from_numpy(y_train).float())
        self.dl_train = torch.utils.data.DataLoader(self.ds_train, batch_size=batch_size, shuffle=True,
                                                    drop_last=True)

        # Inducing points use MultitaskGPModel's default initialisation.
        self.model = MultitaskGPModel(num_inducing_points=n_inducing_points,
                                      num_tasks=num_outputs,
                                      num_inputs=input_dim,
                                      ).to(device)
        self.likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=num_outputs).to(device)

        self.lr = lr

    def train(self, train_epochs):
        """Fit the GP and likelihood by maximising the ELBO.

        Every 10th epoch the mean mini-batch loss is printed and a checkpoint
        is written via ``save`` with suffix ``"<epoch>_<loss>"``.

        Args:
            train_epochs: number of full passes over ``self.dl_train``.
        """
        loss_fn = VariationalELBO(self.likelihood, self.model, num_data=len(self.ds_train))

        optimizer = torch.optim.Adam([
            {"params": self.model.parameters()},
            {"params": self.likelihood.parameters()},
        ], lr=self.lr)
        # TODO: a StepLR schedule was prototyped here but never stepped.

        self.model.train()
        self.likelihood.train()

        for i in range(train_epochs):
            batch_losses = []
            for x, y in self.dl_train:
                x, y = x.to(device), y.to(device)
                y_pred = self.model(x)

                optimizer.zero_grad()
                loss = -loss_fn(y_pred, y)
                loss.backward()
                optimizer.step()
                batch_losses.append(loss.item())

            if i % 10 == 0:
                # Logged value is the mean negative ELBO despite the label.
                epoch_loss = np.mean(batch_losses)
                print(f"Test Results - Epoch: {i}, Train Likelihood: {epoch_loss:.2f}")
                self.save(f"{i}_{epoch_loss:.2f}")

    def predict_var(self, obs):
        """Return the per-task predictive variance of the likelihood at ``obs``.

        Switches model and likelihood to eval mode (``train`` switches them
        back).

        Args:
            obs: concatenated observation-action inputs, as a tensor or
                array-like; converted to float32 on the model's device.
        """
        self.model.eval()
        self.likelihood.eval()

        # The model lives on `device` in float32; accept CPU tensors or numpy.
        obs = torch.as_tensor(obs, dtype=torch.float32, device=device)

        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            observed_pred = self.likelihood(self.model(obs))

        return observed_pred.variance

    def save(self, suffix=None):
        """
        Method to save model after training is completed

        Writes the model and likelihood ``state_dict``s to
        ``./checkpoints/sparse_saved_weights_duecm/sparse_model_weights[_<suffix>].pt``.
        """
        print("Saving model checkpoint...")
        check_or_make_folder("./checkpoints")
        check_or_make_folder("./checkpoints/sparse_saved_weights_duecm")
        torch_state_dict = {
            'model': self.model.state_dict(),
            'likelihood': self.likelihood.state_dict(),
        }
        fn = 'sparse_model_weights'
        if suffix is not None:
            fn = fn + '_' + suffix
        torch.save(torch_state_dict, f"./checkpoints/sparse_saved_weights_duecm/{fn}.pt")

    def load(self, model_file):
        """
        Method to load model from checkpoint folder

        Restores the model and likelihood ``state_dict``s written by ``save``,
        mapping tensors onto the current ``device``.
        """
        print("Loading model from checkpoint...")
        torch_state_dict = torch.load(model_file, map_location=device)
        self.model.load_state_dict(torch_state_dict['model'])
        self.likelihood.load_state_dict(torch_state_dict['likelihood'])


if __name__ == '__main__':
    # Script entry point: build an estimator and train it for 2000 epochs.
    # Supported targets: 'rewards' or 'next_observations'.
    # Alternative estimator: SparseGPEstimator(target='rewards')
    estimator = DUEEstimator(target='rewards')
    estimator.train(2000)
