import torch
import gpytorch
import time
import numpy as np

from meta_learn.models import LearnedGPRegressionModel, NeuralNetwork, AffineTransformedDistribution
from meta_learn.util import _handle_input_dimensionality, DummyLRScheduler
from meta_learn.abstract import RegressionModelMetaLearned
from config import device

class GPRegressionEval(RegressionModelMetaLearned):

    def __init__(self, meta_train_data, learning_mode='both', lr_params=1e-3, weight_decay=0.0, feature_dim=2,
                 num_iter_fit=10000, covar_module='NN', mean_module='NN', mean_nn_layers=(32, 32), kernel_nn_layers=(32, 32),
                 task_batch_size=1, normalize_data=True, optimizer='Adam', lr_decay=0.0, random_seed=None,alpha = 0.02, beta = 0.02,
                 train_number = 5, val_number = 45, theorem = 5, get_markov = True, only_delta = True, only_psi = False):
        """
        Meta-Learning GP priors (i.e. mean and kernel function) via PACOH-MAP, with additional
        evaluation switches that are consumed by ``meta_fit``.

        Args:
            meta_train_data: list of tuples of ndarrays[(train_x_1, train_t_1), ..., (train_x_n, train_t_n)]
            learning_mode: (str) specifying which of the GP prior parameters to optimize. Either one of
                    ['learn_mean', 'learn_kernel', 'both', 'vanilla']
            lr_params: (float) learning rate for GP prior parameters
            weight_decay: (float) weight decay multiplier for meta-level regularization
            feature_dim: (int) output dimensionality of NN feature map for kernel function
            num_iter_fit: (int) default number of gradient steps used by meta_fit
            covar_module: 'NN', 'SE' or a gpytorch.kernels.Kernel instance
            mean_module: 'NN', 'constant', 'zero' or a gpytorch.means.Mean instance
            mean_nn_layers: (tuple) hidden layer sizes of mean NN
            kernel_nn_layers: (tuple) hidden layer sizes of kernel NN
            task_batch_size: (int) batch size for meta training, i.e. number of tasks for computing gradients
            normalize_data: forwarded to the base-class constructor and stored on the instance
            optimizer: (str) type of optimizer to use - must be either 'Adam' or 'SGD'
            lr_decay: (float) multiplicative learning rate decay applied every 1000 iterations
                    (a value >= 1.0 selects the dummy scheduler instead)
            random_seed: forwarded to the base-class constructor
            alpha: (float) constant used for the likelihood ``alpha`` argument (theorem 5) and in the
                    meta_fit objective
            beta: (float) constant used for the likelihood ``alpha`` argument (theorem != 5) and in the
                    meta_fit objective
            train_number: (int) number of leading samples of each task the per-task GP model is built on
            val_number: (int) stored on the instance; not read anywhere else in this class
            theorem: (int) 5 selects the Theorem-5 variant; any other value takes the Theorem-4 code path here
            get_markov: (bool) flag read by meta_fit
            only_delta: (bool) flag read by meta_fit
            only_psi: (bool) flag read by meta_fit
        """
        super().__init__(normalize_data, random_seed)

        # argument validation
        assert learning_mode in ('learn_mean', 'learn_kernel', 'both', 'vanilla')
        assert mean_module in ('NN', 'constant', 'zero') or isinstance(mean_module, gpytorch.means.Mean)
        assert covar_module in ('NN', 'SE') or isinstance(covar_module, gpytorch.kernels.Kernel)
        assert optimizer in ('Adam', 'SGD')

        # optimisation hyper-parameters
        self.lr_params = lr_params
        self.weight_decay = weight_decay
        self.feature_dim = feature_dim
        self.num_iter_fit = num_iter_fit
        self.task_batch_size = task_batch_size
        self.normalize_data = normalize_data

        # bound constants and evaluation switches
        self.alpha, self.beta = alpha, beta
        self.train_number, self.val_number = train_number, val_number
        self.theorem = theorem
        self.get_markov, self.only_delta, self.only_psi = get_markov, only_delta, only_psi

        # Check that data all has the same size, then compute the normalization statistics
        self._check_meta_data_shapes(meta_train_data)
        self._compute_normalization_stats(meta_train_data)

        # Components shared across tasks: GP prior (mean / kernel) and the likelihood
        self._setup_gp_prior(mean_module, covar_module, learning_mode, feature_dim, mean_nn_layers, kernel_nn_layers)

        if self.theorem == 5:
            # Theorem 5 uses a subset of data, S' \in S, to train the base learner
            likelihood_alpha = 1/self.alpha
        else:
            # Theorem 4 PACMAML uses S' = S
            likelihood_alpha = self.train_number/self.beta
        # NOTE(review): what the `alpha` keyword does depends on the gpytorch build in use - TODO confirm
        noise_floor = gpytorch.likelihoods.noise_models.GreaterThan(1e-3)
        likelihood = gpytorch.likelihoods.GaussianLikelihood(alpha=likelihood_alpha, noise_constraint=noise_floor)
        self.likelihood = likelihood.to(device)
        self.shared_parameters.append({'params': self.likelihood.parameters(), 'lr': self.lr_params})

        # Components that differ across tasks: data tensors, GP model and marginal log-likelihood
        self.task_dicts = []
        for raw_x, raw_y in meta_train_data:
            # a) prepare data
            x_tensor, y_tensor = self._prepare_data_per_task(raw_x, raw_y)

            # b) prepare model on the first `train_number` samples only
            model = LearnedGPRegressionModel(x_tensor[:self.train_number], y_tensor[:self.train_number],
                                             self.likelihood,
                                             learned_kernel=self.nn_kernel_map, learned_mean=self.nn_mean_fn,
                                             covar_module=self.covar_module, mean_module=self.mean_module)
            mll_fn = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, model).to(device)

            self.task_dicts.append({
                'train_x': x_tensor,
                'train_y': y_tensor,
                'model': model,
                'mll_fn': mll_fn,
            })

        # c) prepare inference
        self._setup_optimizer(optimizer, lr_params, lr_decay)

        self.fitted = False


    def meta_fit(self, valid_tuples=None, verbose=True, log_period=500, n_iter=None):
        """
        meta-learns the GP prior parameters

        Args:
            valid_tuples: list of valid tuples, i.e. [(test_context_x_1, test_context_t_1, test_x_1, test_t_1), ...]
            verbose: (boolean) whether to print training progress
            log_period: (int) number of steps after which to print stats
            n_iter: (int) number of gradient descent iterations
        """


        for task_dict in self.task_dicts: task_dict['model'].train()
        self.likelihood.train()

        assert (valid_tuples is None) or (all([len(valid_tuple) == 4 for valid_tuple in valid_tuples]))
        all_train = []
        all_val = []
        all_l2 = []
        if len(self.shared_parameters) > 0:
            t = time.time()
            cum_loss = 0.0

            if n_iter is None:
                n_iter = self.num_iter_fit
            

            """
            Below we compute l(h,D_i) actually, since here m is large.
            The first m_i can be used to compute l(h,S_i) and then resampling to compute the Markov term
            """
            l_pcal = []
            if self.only_psi:
                for task_dict in self.task_dicts:
                    gp_model = LearnedGPRegressionModel(None, None, self.likelihood,
                                                            learned_kernel=self.nn_kernel_map, learned_mean=self.nn_mean_fn,
                                                            covar_module=self.covar_module, mean_module=self.mean_module)
                    gp_model.eval()
                    self.likelihood.eval()
                    pred_T = gp_model(task_dict['train_x'])
                    m, K = pred_T.mean, pred_T.covariance_matrix
                    l_pcal.append(torch.square(task_dict['train_y'] - m).detach().numpy())
                return l_pcal

            
            rstp_pcal = []
            # for task_dict in self.task_dicts:
            #     ## Build GP_model on m_i
            #     # task_dict['train_x'], indices = torch.sort(task_dict['train_x'],dim=0).values, torch.sort(task_dict['train_x'],dim=0).indices
            #     # task_dict['train_y'] = task_dict['train_y'][indices.reshape(-1)]

            #     n_groups = self.train_number
            #     group_length = int(len(task_dict['train_x'])/self.train_number)

            #     for s in range(10):
            #         starting_pos = np.random.choice(group_length)
            #         # import pdb
            #         # pdb.set_trace()
            #         # chosen_train = np.arange(starting_pos,len(task_dict['train_x']),group_length)
                    
            #         chosen_train = np.random.choice(range(len(task_dict['train_x'])),self.train_number,replace=False)
            #         chosen_test  = np.setdiff1d(range(len(task_dict['train_x'])), chosen_train)
            #         gp_model = LearnedGPRegressionModel(task_dict['train_x'][chosen_train], task_dict['train_y'][chosen_train], self.likelihood,
            #                                             learned_kernel=self.nn_kernel_map, learned_mean=self.nn_mean_fn,
            #                                             covar_module=self.covar_module, mean_module=self.mean_module)
            #         gp_model.eval()
            #         self.likelihood.eval()
            #         # import pdb
            #         # pdb.set_trace()
            #         pred_T = gp_model(task_dict['train_x'][chosen_test])
            #         m, K = pred_T.mean, pred_T.covariance_matrix
            #         # return task_dict['train_x'][chosen_train], task_dict['train_y'][chosen_train], task_dict['train_x'][chosen_test], task_dict['train_y'][chosen_test], m
            #         rstp_pcal.append((torch.square(task_dict['train_y'][chosen_test] - m) + torch.trace(K)/(len(m))).detach().numpy())

            # import pdb
            # pdb.set_trace()

            rtp_pcal = []
            # for i in range(len(valid_tuples)):
            #     T_train_x, T_train_y = self._prepare_data_per_task(valid_tuples[i][0],valid_tuples[i][1])
            #     T_test_x, T_test_y = self._prepare_data_per_task(valid_tuples[i][2],valid_tuples[i][3])
            #     T_x = torch.cat((T_train_x, T_test_x))
            #     T_y = torch.cat((T_train_y, T_test_y))
            #     for s in range(30):
            #         chosen_train = np.random.choice(range(self.train_number),self.train_number,replace=False)
            #         chosen_test  = np.setdiff1d(range(len(T_x)), chosen_train)
            #         gp_model = LearnedGPRegressionModel(T_train_x, T_train_y, self.likelihood,
            #                                             learned_kernel=self.nn_kernel_map, learned_mean=self.nn_mean_fn,
            #                                             covar_module=self.covar_module, mean_module=self.mean_module)
            #         gp_model.eval()
            #         self.likelihood.eval()
                    
            #         pred_T = gp_model(T_test_x)
            #         m, K = pred_T.mean, pred_T.covariance_matrix
            #         return T_test_x, T_test_y, m
            #         rtp_pcal.append((torch.square(T_test_y - m) + torch.trace(K)/(len(m))).detach().numpy())
            # for task_dict in self.task_dicts:
            #     ## Build GP_model on m
            #     gp_model = LearnedGPRegressionModel(task_dict['train_x'][:5], task_dict['train_y'][:5], self.likelihood,
            #                                         learned_kernel=self.nn_kernel_map, learned_mean=self.nn_mean_fn,
            #                                         covar_module=self.covar_module, mean_module=self.mean_module)
            #     gp_model.eval()
            #     self.likelihood.eval()
            #     pred_T = gp_model(task_dict['train_x'][5:])
            #     m, K = pred_T.mean, pred_T.covariance_matrix
            #     rtp_pcal.append(torch.square(task_dict['train_y'][5:] - m) + torch.trace(K)/(len(m)))
            if self.only_delta:
                return rstp_pcal, rtp_pcal

            rstpmarkov = []
            if self.get_markov:
                for task_dict in self.task_dicts:
                    for s in range(20):
                        chosen_train = np.random.choice(range(len(task_dict['train_x'])),self.train_number,replace=False)
                        chosen_test  = np.setdiff1d(range(len(task_dict['train_x'])), chosen_train)
                        gp_model = LearnedGPRegressionModel(task_dict['train_x'][chosen_train], task_dict['train_y'][chosen_train], self.likelihood,
                                                            learned_kernel=self.nn_kernel_map, learned_mean=self.nn_mean_fn,
                                                            covar_module=self.covar_module, mean_module=self.mean_module)
                        gp_model.eval()
                        self.likelihood.eval()
                        pred_T = gp_model(task_dict['train_x'][chosen_test])
                        m, K = pred_T.mean, pred_T.covariance_matrix
                        rstpmarkov.append(torch.square(task_dict['train_y'][chosen_test] - m) + torch.trace(K)/(len(m)))


            # psi = []
            # for i in range(len(valid_tuples)):
            #     valx, valy = self._prepare_data_per_task(valid_tuples[i][0],valid_tuples[i][1])
            #     testx, testy = self._prepare_data_per_task(valid_tuples[i][2],valid_tuples[i][3])
            #     if self.gettest == 'psi':
            #         gp_model = LearnedGPRegressionModel(None, None, self.likelihood,
            #                                             learned_kernel=self.nn_kernel_map, learned_mean=self.nn_mean_fn,
            #                                             covar_module=self.covar_module, mean_module=self.mean_module)
            #     elif self.gettest == 'delta':
            #         gp_model = LearnedGPRegressionModel(valx, valy, self.likelihood,
            #                                             learned_kernel=self.nn_kernel_map, learned_mean=self.nn_mean_fn,
            #                                             covar_module=self.covar_module, mean_module=self.mean_module)
            #     gp_model.eval()
            #     self.likelihood.eval()


            #     pred_T = gp_model(testx)
            #     m, K = pred_T.mean, pred_T.covariance_matrix
            #     psi.append(torch.square(testy - m) + torch.trace(K)/(len(m)))

            self.likelihood.train()

            var1 = torch.tensor([1,1,1,1,1,1])#torch.tensor([0.333,0.333,0.104,0.104,0.104,0.14])#torch.tensor([torch.var(i) for i in self.nn_mean_fn.parameters()])
            var2 = torch.tensor([1,1,1,1,1,1])#torch.tensor([0.333,0.333,0.104,0.104,0.104,0.14])#torch.tensor([torch.var(i) for i in self.nn_kernel_map.parameters()])
            #all_z = []
            for itr in range(1, n_iter + 1):
                loss = 0.0
                rmseloss = 0.0
                self.optimizer.zero_grad()
                for task_dict in self.rds_numpy.choice(self.task_dicts, size=self.task_batch_size):

                    output = task_dict['model'](task_dict['train_x'][:self.train_number])
                    mll = task_dict['mll_fn'](output, task_dict['train_y'][:self.train_number])
                    if self.theorem == 5:
                        
                        ## Obtain Q. Need .eval() to convert to Posterior

                        gp_model = LearnedGPRegressionModel(task_dict['train_x'][:self.train_number], task_dict['train_y'][:self.train_number], self.likelihood,
                                                    learned_kernel=self.nn_kernel_map, learned_mean=self.nn_mean_fn,
                                                    covar_module=self.covar_module, mean_module=self.mean_module)
                        gp_model.eval()
                        self.likelihood.eval()
                        # task_dict['model'].eval()
                        pred_dist = self.likelihood(gp_model(task_dict['train_x']))
                        m, K = pred_dist.mean, pred_dist.covariance_matrix
                        
                        lV = torch.mean(torch.square(task_dict['train_y'] - m)) + torch.trace(K)/(len(m))
                        
                        if self.train_number > len(m):
                            lS = torch.mean(torch.square(task_dict['train_y'][:self.train_number] - m[:self.train_number])) + torch.trace(K[:self.train_number,:self.train_number])/len(m)
                        else:
                            lS = torch.mean(torch.square(task_dict['train_y'][:self.train_number] - m[:self.train_number])) + torch.mean(torch.diag(K)[:self.train_number])
                        self.likelihood.train()
                        # task_dict['model'].train()



                        loss += (-mll-0.5*torch.log(torch.tensor(3.141592653/self.alpha)))*self.train_number/self.beta \
                         + lV - self.alpha*self.train_number/self.beta*lS
                        # import pdb
                        # pdb.set_trace()
                    
                    elif self.theorem == 4:
                        loss += (-mll-0.5*torch.log(torch.tensor(3.141592653*self.train_number/self.beta)))*self.train_number/self.beta
                        # print(loss)
                        #rmseloss += torch.sqrt(torch.mean(torch.square(task_dict['train_y'] - output.mean)))
                # loss = loss/5
                #all_z.append(loss)
                #all_train.append(torch.mean(rmseloss))
                # loss /= self.task_batch_size
                loss.backward()
                self.optimizer.step()
                self.lr_scheduler.step()

                cum_loss += loss
                # print(torch.sum(torch.tensor([torch.norm(i)**2/var1[count] for count, i in enumerate(self.nn_mean_fn.parameters())])[:-1])+torch.sum(torch.tensor([torch.norm(i)**2/var2[count] for count, i in enumerate(self.nn_kernel_map.parameters())])[:-1]))

                # print training stats stats
                if itr == 1 or itr % log_period == 0:
                    # all_z.append(loss/self.task_batch_size)
                    all_l2.append(torch.sum(torch.tensor([torch.norm(i)**2/var1[count] for count, i in enumerate(self.nn_mean_fn.parameters())])[:-1])+torch.sum(torch.tensor([torch.norm(i)**2/var2[count] for count, i in enumerate(self.nn_kernel_map.parameters())])[:-1]))
                    duration = time.time() - t
                    avg_loss = cum_loss / (log_period if itr > 1 else 1.0)
                    cum_loss = 0.0
                    t = time.time()

                    message = 'Iter %d/%d - Loss: %.6f - Time %.2f sec' % (itr, self.num_iter_fit, avg_loss.item(), duration)

                    # if validation data is provided  -> compute the valid log-likelihood
                    if valid_tuples is not None:
                        self.likelihood.eval()
                        valid_ll, valid_rmse, calibr_err = self.eval_datasets(valid_tuples)
                        self.likelihood.train()
                        message += ' - Valid-LL: %.3f - Valid-RMSE: %.3f - Calib-Err %.3f' % (valid_ll, valid_rmse, calibr_err)

                    if verbose:
                        self.logger.info(message)
                    all_val.append(valid_rmse)

        else:
            self.logger.info('Vanilla mode - nothing to fit')

        self.fitted = True
        all_loss_D = []
        all_loss_S = []
        all_z = []
        all_lab = []
        for task_dict in self.task_dicts: 
            task_dict['model'].eval()
        
        for task_dict in self.task_dicts:
            with torch.no_grad():
                gp_model = LearnedGPRegressionModel(task_dict['train_x'][:self.train_number], task_dict['train_y'][:self.train_number], self.likelihood,
                                                    learned_kernel=self.nn_kernel_map, learned_mean=self.nn_mean_fn,
                                                    covar_module=self.covar_module, mean_module=self.mean_module)

                gp_model.eval()
                self.likelihood.eval()
                pred_dist = gp_model(task_dict['train_x'])
                m, K = pred_dist.mean, pred_dist.covariance_matrix
                all_loss_D.append(torch.mean(torch.square(task_dict['train_y'] - m))+ torch.trace(K)/(len(m)))
            
            with torch.no_grad():
                gp_model = LearnedGPRegressionModel(task_dict['train_x'][:self.train_number], task_dict['train_y'][:self.train_number], self.likelihood,
                                                    learned_kernel=self.nn_kernel_map, learned_mean=self.nn_mean_fn,
                                                    covar_module=self.covar_module, mean_module=self.mean_module)
                pred_dist = gp_model(task_dict['train_x'][:self.train_number])
                m, K = pred_dist.mean, pred_dist.covariance_matrix
                all_loss_S.append(torch.mean(torch.square(task_dict['train_y'][:self.train_number] - m[:self.train_number]))+ torch.trace(K[:self.train_number,:self.train_number])/(len(m[:self.train_number])))
                mll_model = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, gp_model)
                logz_ref = -mll_model(gp_model(task_dict['train_x'][:self.train_number]), task_dict['train_y'][:self.train_number])
                
                # import pdb
                # pdb.set_trace()
                if self.theorem == 5:
                    add_alpha = torch.diag(torch.ones(self.train_number)*1/self.alpha)
                else:
                    add_alpha = torch.diag(torch.ones(self.train_number)*self.train_number/self.beta)
                Sigma = torch.inverse(K[:self.train_number,:self.train_number]+add_alpha)
                
                med = -0.5*torch.matmul((task_dict['train_y'][:self.train_number] - m[:self.train_number]).T,Sigma)
                med = torch.matmul(med, (task_dict['train_y'][:self.train_number] - m[:self.train_number]))
                med -= 0.5*torch.logdet(K[:self.train_number,:self.train_number]+add_alpha)
                med -= 0.5*self.train_number*torch.log(torch.tensor(2*3.141592653))
                # import pdb
                # pdb.set_trace()
                if self.theorem == 5:
                    med += 0.5*self.train_number*torch.log(torch.tensor(3.141592653/self.alpha))
                else:
                    med += 0.5*self.train_number*torch.log(torch.tensor(3.141592653*self.train_number/self.beta))
                logz = -med/self.beta#self.train_number
            
                
                all_z.append(logz)
                if self.theorem == 5:
                    self.likelihood.eval()
                    gp_model.eval()
                    pred_dist = self.likelihood(gp_model(task_dict['train_x']))
                    m, K = pred_dist.mean, pred_dist.covariance_matrix

                    lV = torch.mean(torch.square(task_dict['train_y'][:] - m[:])) + torch.trace(K[:,:])/(len(m))
                    if self.train_number > len(m):
                        lS = torch.mean(torch.square(task_dict['train_y'][:self.train_number] - m[:self.train_number])) + torch.trace(K[:self.train_number,:self.train_number])/len(m)
                    else:
                        lS = torch.mean(torch.square(task_dict['train_y'][:self.train_number] - m[:self.train_number])) + torch.trace(K[:self.train_number,:self.train_number])/self.train_number

                    all_lab.append(lV - self.alpha*self.train_number/self.beta*lS)
                else:
                    all_lab.append(0)


        self.likelihood.eval()
        return all_z, all_lab, all_val, all_l2,  all_loss_D, all_loss_S, l_pcal, rstp_pcal, rtp_pcal, rstpmarkov

    def predict(self, context_x, context_y, test_x, return_density=False):
        """
        Performs posterior inference (target training) with (context_x, context_y) as training data and then
        computes the predictive distribution of the targets p(y|test_x, test_context_x, context_y) in the test points

        Args:
            context_x: (ndarray) context input data for which to compute the posterior
            context_y: (ndarray) context targets for which to compute the posterior
            test_x: (ndarray) query input data of shape (n_samples, ndim_x)
            return_density: (bool) whether to return result as mean and std ndarray or as MultivariateNormal pytorch object

        Returns:
            (pred_mean, pred_std) predicted mean and standard deviation corresponding to p(t|test_x, test_context_x, context_y),
            or the transformed predictive distribution object when return_density is True
        """

        context_x, context_y = _handle_input_dimensionality(context_x, context_y)
        test_x = _handle_input_dimensionality(test_x)
        assert test_x.shape[1] == context_x.shape[1]

        # normalize data and convert to tensor
        context_x, context_y = self._prepare_data_per_task(context_x, context_y)

        test_x = self._normalize_data(X=test_x, Y=None)
        test_x = torch.from_numpy(test_x).float().to(device)

        with torch.no_grad():
            # compute posterior given the context data
            gp_model = LearnedGPRegressionModel(context_x, context_y, self.likelihood,
                                                learned_kernel=self.nn_kernel_map, learned_mean=self.nn_mean_fn,
                                                covar_module=self.covar_module, mean_module=self.mean_module)
            gp_model.eval()
            self.likelihood.eval()

            pred_dist = self.likelihood(gp_model(test_x))
            # map the prediction back to the un-normalized target scale
            pred_dist_transformed = AffineTransformedDistribution(pred_dist, normalization_mean=self.y_mean,
                                                                  normalization_std=self.y_std)

        if return_density:
            return pred_dist_transformed

        pred_mean = pred_dist_transformed.mean
        pred_std = pred_dist_transformed.stddev
        return pred_mean.cpu().numpy(), pred_std.cpu().numpy()

    def state_dict(self):
        state_dict = {
            'optimizer': self.optimizer.state_dict(),
            'model': self.task_dicts[0]['model'].state_dict()
        }
        for task_dict in self.task_dicts:
            for key, tensor in task_dict['model'].state_dict().items():
                assert torch.all(state_dict['model'][key] == tensor).item()
        return state_dict

    def load_state_dict(self, state_dict):
        for task_dict in self.task_dicts:
            task_dict['model'].load_state_dict(state_dict['model'])
        self.optimizer.load_state_dict(state_dict['optimizer'])

    def _setup_gp_prior(self, mean_module, covar_module, learning_mode, feature_dim, mean_nn_layers, kernel_nn_layers):
        """
        Builds the components of the GP prior that are shared across tasks and collects the
        parameter groups to optimize in ``self.shared_parameters``.

        Sets ``self.nn_kernel_map`` / ``self.covar_module`` and ``self.nn_mean_fn`` / ``self.mean_module``
        according to ``covar_module`` ('NN', 'SE' or a gpytorch kernel) and ``mean_module``
        ('NN', 'constant', 'zero' or a gpytorch mean).
        """
        self.shared_parameters = []
        learns_kernel = learning_mode in ["learn_kernel", "both"]
        learns_mean = learning_mode in ["learn_mean", "both"]

        # a) determine kernel map & module
        if covar_module == 'NN':
            assert learns_kernel, 'neural network parameters must be learned'
            kernel_net = NeuralNetwork(input_dim=self.input_dim, output_dim=feature_dim,
                                       layer_sizes=kernel_nn_layers)
            self.nn_kernel_map = kernel_net.to(device)
            self.shared_parameters.append({
                'params': self.nn_kernel_map.parameters(),
                'lr': self.lr_params,
                'weight_decay': self.weight_decay,
            })
            # the RBF kernel operates on the feature_dim-dimensional output of the feature map
            rbf_on_features = gpytorch.kernels.RBFKernel(ard_num_dims=feature_dim)
            self.covar_module = gpytorch.kernels.ScaleKernel(rbf_on_features).to(device)
        else:
            self.nn_kernel_map = None
            if covar_module == 'SE':
                rbf_on_inputs = gpytorch.kernels.RBFKernel(ard_num_dims=self.input_dim)
                self.covar_module = gpytorch.kernels.ScaleKernel(rbf_on_inputs).to(device)
            elif isinstance(covar_module, gpytorch.kernels.Kernel):
                self.covar_module = covar_module.to(device)

        # b) determine mean map & module
        if mean_module == 'NN':
            assert learns_mean, 'neural network parameters must be learned'
            mean_net = NeuralNetwork(input_dim=self.input_dim, output_dim=1, layer_sizes=mean_nn_layers)
            self.nn_mean_fn = mean_net.to(device)
            self.shared_parameters.append({
                'params': self.nn_mean_fn.parameters(),
                'lr': self.lr_params,
                'weight_decay': self.weight_decay,
            })
            self.mean_module = None
        else:
            self.nn_mean_fn = None
            if mean_module == 'constant':
                self.mean_module = gpytorch.means.ConstantMean().to(device)
            elif mean_module == 'zero':
                self.mean_module = gpytorch.means.ZeroMean().to(device)
            elif isinstance(mean_module, gpytorch.means.Mean):
                self.mean_module = mean_module.to(device)

        # c) add parameters of covar and mean module if desired
        if learns_kernel:
            self.shared_parameters.append({'params': self.covar_module.hyperparameters(), 'lr': self.lr_params})

        if learns_mean and self.mean_module is not None:
            self.shared_parameters.append({'params': self.mean_module.hyperparameters(), 'lr': self.lr_params})

    def _setup_optimizer(self, optimizer, lr, lr_decay):
        if optimizer == 'Adam':
            self.optimizer = torch.optim.Adam(self.shared_parameters, lr=lr, weight_decay=self.weight_decay)
        elif optimizer == 'SGD':
            self.optimizer = torch.optim.SGD(self.shared_parameters, lr=lr)
        else:
            raise NotImplementedError('Optimizer must be Adam or SGD')

        if lr_decay < 1.0:
            self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 1000, gamma=lr_decay)
        else:
            self.lr_scheduler = DummyLRScheduler()

    def _vectorize_pred_dist(self, pred_dist):
        return torch.distributions.Normal(pred_dist.mean, pred_dist.stddev)

if __name__ == "__main__":
    # Demo: meta-train on simulated sinusoid tasks and plot the prediction for the first test task.
    from experiments.data_sim import GPFunctionsDataset, SinusoidDataset

    sim_rng = np.random.RandomState(29)
    data_sim = SinusoidDataset(random_state=sim_rng)
    meta_train_data = data_sim.generate_meta_train_data(n_tasks=20, n_samples=10)
    meta_test_data = data_sim.generate_meta_test_data(n_tasks=50, n_samples_context=10, n_samples_test=160)

    NN_LAYERS = (32, 32, 32, 32)

    plot = False
    from matplotlib import pyplot as plt

    # optional scatter plot of all meta-training tasks
    if plot:
        for x_train, y_train in meta_train_data:
            plt.scatter(x_train, y_train)
        plt.title('sample from the GP prior')
        plt.show()

    # 2) Classical mean learning based on mll

    print('\n ---- GPR mll meta-learning ---- ')

    torch.set_num_threads(2)

    # NOTE(review): the constructor default only_delta=True makes meta_fit return before any
    # gradient step is taken - confirm this is what the demo is meant to show.
    for weight_decay in [0.8, 0.5, 0.4, 0.3, 0.2, 0.1]:
        gp_model = GPRegressionEval(meta_train_data, num_iter_fit=20000, weight_decay=weight_decay, task_batch_size=2,
                                    covar_module='NN', mean_module='NN', mean_nn_layers=NN_LAYERS,
                                    kernel_nn_layers=NN_LAYERS)
        itrs = 0
        print("---- weight-decay =  %.4f ----"%weight_decay)
        for _ in range(1):
            gp_model.meta_fit(valid_tuples=meta_test_data, log_period=1000, n_iter=20000)
            itrs += 20000

            # predict on a dense grid using the context points of the first test task
            x_context, t_context, x_test, y_test = meta_test_data[0]
            x_plot = np.linspace(-5, 5, num=150)
            pred_mean, pred_std = gp_model.predict(x_context, t_context, x_plot)
            ucb, lcb = gp_model.confidence_intervals(x_context, t_context, x_plot, confidence=0.9)

            plt.scatter(x_test, y_test)
            plt.scatter(x_context, t_context)

            plt.plot(x_plot, pred_mean)
            plt.fill_between(x_plot, lcb, ucb, alpha=0.2)
            plt.title('GPR meta mll (weight-decay =  %.4f) itrs = %i' % (weight_decay, itrs))
            plt.show()
 
