import torch
import torch.nn as nn
import gpytorch
from gpytorch import kernels as gp_kernels
from gpytorch.means import Mean
from gpytorch.kernels import Kernel
from gpytorch.distributions import MultitaskMultivariateNormal
from truncnorm.TruncatedNormal import TruncatedNormal
from src.utils.env_dataset import EnvDataset


class BaseMeanModel(nn.Module):
    """Two-layer MLP used as the learned mean function.

    Inputs are flattened from dimension 1 onward (``nn.Flatten`` default),
    then mapped through ``Linear -> Mish -> Linear`` to ``output_dim`` values
    per sample.
    """

    def __init__(self, input_dim, hidden_layer_width, output_dim):
        super().__init__()
        layers = [
            nn.Flatten(),
            nn.Linear(input_dim, hidden_layer_width),
            nn.Mish(inplace=True),
            nn.Linear(hidden_layer_width, output_dim),
        ]
        # Kept as a Sequential named ``seq`` so state_dict keys remain
        # ``seq.1.*`` / ``seq.3.*``.
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        """Run the MLP on ``x`` and return its predictions."""
        network = self.seq
        return network(x)


class CustomKernel(Kernel):
    """Sum of three Matern-5/2 kernels over agent-to-center distances.

    For each input row, the Euclidean distance from the agent position
    (columns 0:2) to each of three centers (columns 13:15, 15:17, 17:19) is
    computed. Each distance feeds its own ``MaternKernel(nu=2.5)`` and the
    three resulting covariances are summed.
    """

    # Feature columns holding the agent position and the three centers.
    _AGENT_POS_COLS = slice(0, 2)
    _CENTER_COLS = (slice(13, 15), slice(15, 17), slice(17, 19))

    def __init__(self, ard_num_dims=None, **kwargs):
        # ``ard_num_dims`` is accepted for signature compatibility with other
        # kernels but is not used by this kernel.
        super(CustomKernel, self).__init__(**kwargs)
        self.matern_loc1 = gp_kernels.MaternKernel(nu=2.5)
        self.matern_loc2 = gp_kernels.MaternKernel(nu=2.5)
        self.matern_loc3 = gp_kernels.MaternKernel(nu=2.5)

    def _center_distances(self, x):
        """Return one ``(..., n, 1)`` distance tensor per center for inputs ``x``."""
        agent_pos = x[..., self._AGENT_POS_COLS]
        return [
            torch.norm(agent_pos - x[..., cols], dim=-1).unsqueeze(-1)
            for cols in self._CENTER_COLS
        ]

    def forward(self, x1, x2, diag=False, **params):
        """Return k(x1, x2), or only its diagonal when ``diag`` is True.

        ``diag`` and any extra kernel params are forwarded to each Matern
        sub-kernel, so the diagonal is the true per-point prior variance (the
        sum of the three Matern diagonals) rather than zeros.
        """
        sub_kernels = (self.matern_loc1, self.matern_loc2, self.matern_loc3)
        dists1 = self._center_distances(x1)
        dists2 = self._center_distances(x2)
        result = None
        for kernel, d1, d2 in zip(sub_kernels, dists1, dists2):
            term = kernel(d1, d2, diag=diag, **params)
            result = term if result is None else result + term
        return result
    

class BaseMultitaskGPModel(gpytorch.models.ApproximateGP):
    """Variational multitask GP with one latent GP per output.

    Each of the ``num_models`` outputs gets ``num_inducing_points`` learnable
    inducing points (variational batch shape ``[1, num_models]``). The latents
    are mixed by an LMC strategy when ``use_coregionalization`` is True, or
    kept independent otherwise.

    With ``separate_reward_cov=True`` each output uses its own kernel from a
    fixed list of five (pos x, pos y, vel x, vel y, sample value), so
    ``num_models`` must be 5. Otherwise a batched ``ScaleKernel(Matern-5/2)``
    with per-output hyperparameters is used.
    """

    # Number of per-output kernels built when ``separate_reward_cov`` is True.
    _NUM_SEPARATE_KERNELS = 5

    def __init__(self, input_dim, num_models, num_inducing_points, use_coregionalization=True, separate_reward_cov=False,):
        # The separate-kernel list below is hard-coded to five outputs; any other
        # num_models would not match the [1, num_models] batch shape used by the
        # mean and variational strategy, and would fail later inside forward().
        if separate_reward_cov and num_models != self._NUM_SEPARATE_KERNELS:
            raise ValueError(
                f"separate_reward_cov=True requires num_models == {self._NUM_SEPARATE_KERNELS}, "
                f"got {num_models}"
            )

        # Plain (non-Module) attribute, so it may be set before ApproximateGP.__init__.
        self.separate_reward_cov = separate_reward_cov
        # Initialize independent inducing points for each task/model
        inducing_points = torch.rand(1, num_models, num_inducing_points, input_dim)

        # Set the batch to learn a different variational distribution for each output dimension
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            num_inducing_points, batch_shape=torch.Size([1, num_models])
        )

        # Wrap independent variational distributions together
        if use_coregionalization:
            variational_strategy = gpytorch.variational.LMCVariationalStrategy(
                gpytorch.variational.VariationalStrategy(
                    self, inducing_points, variational_distribution, learn_inducing_locations=True
                ),
                num_tasks=num_models,
                num_latents=num_models,
                latent_dim=-1,
            )
        else:
            variational_strategy = gpytorch.variational.IndependentMultitaskVariationalStrategy(
                gpytorch.variational.VariationalStrategy(
                    self, inducing_points, variational_distribution, learn_inducing_locations=True
                ),
                num_tasks=num_models
            )

        super(BaseMultitaskGPModel, self).__init__(variational_strategy)

        self.mean_module = gpytorch.means.ZeroMean(batch_shape=torch.Size([1, num_models]))
        if self.separate_reward_cov:
            kernels = [
                gp_kernels.ScaleKernel(gp_kernels.MaternKernel(nu=2.5, ard_num_dims=input_dim)), # pos x
                gp_kernels.ScaleKernel(gp_kernels.MaternKernel(nu=2.5, ard_num_dims=input_dim)), # pos y
                gp_kernels.ScaleKernel(gp_kernels.MaternKernel(nu=2.5, ard_num_dims=input_dim)), # vel x
                gp_kernels.ScaleKernel(gp_kernels.MaternKernel(nu=2.5, ard_num_dims=input_dim)), # vel y
                gp_kernels.ScaleKernel(gp_kernels.MaternKernel(nu=2.5, ard_num_dims=input_dim)) + gp_kernels.ScaleKernel(CustomKernel()), # sample value
            ]
            self.kernels = torch.nn.ModuleList(kernels)
        else:
            self.covar_module = gp_kernels.ScaleKernel( # learn the noise level in the target values
                gp_kernels.MaternKernel(nu=2.5, ard_num_dims=input_dim, batch_shape=torch.Size([1, num_models])),
                batch_shape=torch.Size([1, num_models])
            )

    def forward(self, state):
        """Return the prior ``MultivariateNormal`` evaluated at ``state``.

        Called from the variational strategy with inducing points and data
        concatenated. ``state`` is indexed along dim 1 by output (presumably
        shape ``(1, num_models, n, input_dim)`` -- confirm against the strategy).
        """
        mean = self.mean_module(state)

        if self.separate_reward_cov:
            # Evaluate each output's own kernel on its slice, then restack along
            # the output dimension.
            covars = [kernel(state[:, i, :]).unsqueeze(1) for i, kernel in enumerate(self.kernels)]
            covar = gpytorch.lazy.CatLazyTensor(*covars, dim=1, output_device=mean.device)
        else:
            covar = self.covar_module(state)

        return gpytorch.distributions.MultivariateNormal(mean, covar)


class MultitaskGPModel:
    """Learned dynamics/reward model: an NN mean function plus a multitask GP.

    ``BaseMeanModel`` is fit with MSE on ``nn_train_dataset``. The
    ``BaseMultitaskGPModel`` is then fit with a variational ELBO on the
    residuals ``targets - mean_function(inputs)`` of ``gp_train_dataset``.
    Predictions add the NN mean back onto the GP predictive distribution.
    The outputs are the ``pos_vel_dim`` state dimensions plus one reward output
    when either reward flag is set (the reward is treated as the last output).
    """

    def __init__(
            self,
            learn_smoothed_reward,
            learn_unsmoothed_reward,
            num_epochs,
            minibatch_size,
            input_dim,
            pos_vel_dim,
            max_nn_dataset_size=10000,
            max_gp_dataset_size=1000,
            num_inducing_points=200,
            hidden_layer_width=100,
            gp_learning_rate=0.0005,
            nn_learning_rate=0.0005,
            use_coregionalization=True,
            use_thompson_sampling=False,
            use_separate_reward_cov=False,
            device='cpu'
        ):
        self.learn_smoothed_reward = learn_smoothed_reward
        self.learn_unsmoothed_reward = learn_unsmoothed_reward
        self.use_separate_reward_cov = use_separate_reward_cov
        self.use_thompson_sampling = use_thompson_sampling
        self.num_epochs = num_epochs
        self.minibatch_size = minibatch_size
        self.max_nn_dataset_size = max_nn_dataset_size
        self.max_gp_dataset_size = max_gp_dataset_size
        self.input_dim = input_dim
        # One extra output for the reward when either reward variant is learned.
        self.output_dim = pos_vel_dim + 1 if (learn_smoothed_reward or learn_unsmoothed_reward) else pos_vel_dim
        self.hidden_layer_width = hidden_layer_width
        self.gp_lr = gp_learning_rate
        self.nn_lr = nn_learning_rate
        self.eps = 1e-12
        self.device = device

        self.nn_priorities = torch.zeros((0)).to(self.device)
        self.gp_priorities = torch.zeros((0)).to(self.device)
        self.nn_train_dataset = EnvDataset(self.input_dim, self.output_dim, self.device)
        self.gp_train_dataset = EnvDataset(self.input_dim, self.output_dim, self.device)

        self.mean_function = BaseMeanModel(input_dim, hidden_layer_width, self.output_dim).to(self.device)
        self.gp_model = BaseMultitaskGPModel(
            self.input_dim,
            self.output_dim,
            num_inducing_points,
            use_coregionalization=use_coregionalization,
            separate_reward_cov=self.use_separate_reward_cov,
        ).to(self.device)
        self.likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(
                num_tasks=self.output_dim,
            ).to(self.device)

    def train(self):
        """Fit the mean function, then fit the GP on the mean function's residuals.

        Returns:
            ``(num_epochs, mean_final_loss, gp_final_loss)``, where each loss is
            the average minibatch loss of the final epoch (``nan`` if no
            minibatch was processed).
        """
        # Train the mean function. ``max(1, ...)`` keeps the DataLoader valid for
        # an empty dataset (it rejects batch_size=0); the loader then yields no batches.
        nn_train_loader = torch.utils.data.DataLoader(
            self.nn_train_dataset,
            batch_size=max(1, min(self.minibatch_size, len(self.nn_train_dataset))),
            pin_memory=False,
            drop_last=True,
        )
        mean_final_loss = self.train_mean_function(nn_train_loader)

        # The GP models what the trained mean function does not explain.
        with torch.no_grad():
            shifted_targets = self.gp_train_dataset.targets - self.mean_function(self.gp_train_dataset.inputs)
        shifted_gp_train_dataset = EnvDataset(self.input_dim, self.output_dim, self.device, inputs=self.gp_train_dataset.inputs, targets=shifted_targets)

        # Train the GP model and likelihood
        gp_train_loader = torch.utils.data.DataLoader(
            shifted_gp_train_dataset,
            batch_size=max(1, min(self.minibatch_size, len(self.gp_train_dataset))),
            pin_memory=False,
            drop_last=True,
        )
        gp_final_loss = self.train_gp(gp_train_loader)

        return self.num_epochs, mean_final_loss, gp_final_loss

    def train_mean_function(self, train_loader):
        """Train the NN mean function with Adam on MSE for ``num_epochs`` epochs.

        Returns the average minibatch loss of the final epoch, or ``nan`` when
        no minibatch was processed (``num_epochs == 0`` or an empty loader).
        The mean function is left in eval mode.
        """
        self.mean_function.train()
        nn_optimizer = torch.optim.Adam([{'params': self.mean_function.parameters()}], lr=self.nn_lr)
        nn_loss_function = nn.MSELoss()
        loss = float('nan')
        for _ in range(self.num_epochs):
            epoch_losses = []
            for input_batch, target_batch in train_loader:
                nn_optimizer.zero_grad()
                input_batch = input_batch.to(self.device)
                target_batch = target_batch.to(self.device)
                output = self.mean_function(input_batch)
                batch_loss = nn_loss_function(output, target_batch)
                batch_loss.backward()
                nn_optimizer.step()
                epoch_losses.append(batch_loss.item())
            if epoch_losses:
                loss = sum(epoch_losses) / len(epoch_losses)
        self.mean_function.eval()
        return loss

    def train_gp(self, train_loader):
        """Train the GP and its likelihood by maximizing the variational ELBO.

        Returns the average minibatch loss (negative ELBO) of the final epoch,
        or ``nan`` when no minibatch was processed.
        """
        self.gp_model.train()
        self.likelihood.train()
        gp_optimizer = torch.optim.Adam([{'params': self.gp_model.parameters()}, {'params': self.likelihood.parameters()}], lr=self.gp_lr)
        # num_data is the full training-set size (used by gpytorch to scale the
        # KL term when training on minibatches).
        gp_loss_function = gpytorch.mlls.VariationalELBO(self.likelihood, self.gp_model, num_data=len(self.gp_train_dataset)).to(self.device)
        loss = float('nan')
        for _ in range(self.num_epochs):
            epoch_losses = []
            for input_batch, target_batch in train_loader:
                gp_optimizer.zero_grad()
                # (B, input_dim) -> (B, output_dim, 1, input_dim): one copy of each
                # input per output, matching the per-output inducing-point batch.
                input_batch = input_batch.unsqueeze(1).expand(-1, self.output_dim, -1).unsqueeze(2).to(self.device)
                target_batch = target_batch.unsqueeze(1).to(self.device)
                output = self.gp_model(input_batch)
                batch_loss = -gp_loss_function(output, target_batch).sum()
                batch_loss.backward()
                gp_optimizer.step()
                epoch_losses.append(batch_loss.item())
            if epoch_losses:
                loss = sum(epoch_losses) / len(epoch_losses)
        return loss

    def posterior_prediction(self, input):
        """Return the predictive distribution for a batch of inputs.

        The GP models residuals of the NN mean function, so the NN mean is added
        back onto the likelihood's predictive mean.

        Args:
            input: 2-D tensor ``(batch, input_dim)``; moved to ``self.device``.

        Returns:
            ``MultitaskMultivariateNormal`` whose mean is the GP predictive mean
            plus the NN mean (shape ``(batch, 1, output_dim)``).
        """
        self.mean_function.eval()
        self.gp_model.eval()
        self.likelihood.eval()
        input = input.to(self.device)

        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            # Expand shape to match the per-output inducing points
            input_expanded = input.unsqueeze(1).expand(-1, self.output_dim, -1).unsqueeze(2)
            output = self.gp_model(input_expanded)  # latent GP posterior
            pred = self.likelihood(output)  # predictive distribution incl. observation noise
            mean = self.mean_function(input).unsqueeze(1)

        return MultitaskMultivariateNormal(pred.mean + mean, pred.covariance_matrix)

    def optimistic_posterior_prediction(self, input, lower_percentile, upper_percentile):
        """Sample an optimistic reward, then the next state conditioned on it.

        The reward r' (last output) is drawn from its posterior marginal
        N(mu_r, sigma_r^2) truncated to the [lower_percentile, upper_percentile]
        quantile range. The next state s' is taken from the Gaussian conditional
        p(s' | s, a, r'): sampled when Thompson sampling is enabled, otherwise
        its mean.

        Args:
            input: 2-D tensor ``(batch, input_dim)``.
            lower_percentile, upper_percentile: probabilities in (0, 1) passed to
                the Normal inverse CDF to obtain the truncation bounds.

        Returns:
            ``(state_sample, reward_sample, posterior)``.

        Raises:
            ValueError: if no reward output is learned.
        """
        if not (self.learn_smoothed_reward or self.learn_unsmoothed_reward):
            raise ValueError("optimistic_posterior_prediction requires a learned reward output")

        posterior = self.posterior_prediction(input)
        r = self.output_dim - 1  # reward index; state outputs are [0, r)
        mean = posterior.mean[:, 0, :]
        cov = posterior.covariance_matrix.detach()  # materialize once: (batch, tasks, tasks)

        # Reward marginal. Normal/TruncatedNormal take a standard deviation, so
        # convert the variance; clamp to avoid a zero scale / zero division.
        reward_mean = mean[:, r]
        cov_rr = cov[:, r, r].clamp_min(self.eps)
        reward_std = cov_rr.sqrt()
        reward_normal = torch.distributions.Normal(reward_mean, reward_std)
        lower_threshold = reward_normal.icdf(
            torch.as_tensor(lower_percentile, dtype=reward_mean.dtype, device=reward_mean.device)
        )
        upper_threshold = reward_normal.icdf(
            torch.as_tensor(upper_percentile, dtype=reward_mean.dtype, device=reward_mean.device)
        )
        optimistic_reward_dist = TruncatedNormal(loc=reward_mean, scale=reward_std, a=lower_threshold, b=upper_threshold)
        reward_sample = optimistic_reward_dist.sample()

        # Condition the state block on the sampled reward:
        #   mu_s|r = mu_s + S_sr / S_rr * (r' - mu_r)
        #   S_s|r  = S_ss - S_sr S_rs / S_rr
        state_mean = mean[:, :r]
        cov_sr = cov[:, r, :r]
        cov_rs = cov[:, :r, r]
        cov_ss = cov[:, :r, :r]
        gain = cov_sr / cov_rr.unsqueeze(-1)
        conditional_mean = state_mean + gain * (reward_sample - reward_mean).unsqueeze(-1)
        conditional_cov = cov_ss - gain.unsqueeze(-1) @ cov_rs.unsqueeze(1)
        conditional_dist = torch.distributions.MultivariateNormal(conditional_mean, conditional_cov)
        if self.use_thompson_sampling:
            state_sample = conditional_dist.sample()
        else:
            state_sample = conditional_dist.mean

        return state_sample, reward_sample, posterior

    def register_new_data(self, inputs, targets, priorities):
        """Add samples to both training sets, keeping the highest-priority ones.

        New samples are prepended. If a set then exceeds its cap
        (``max_nn_dataset_size`` / ``max_gp_dataset_size``), only the top
        priority samples are kept, reordered by descending priority.
        """
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)
        priorities = priorities.to(self.device)
        self.nn_priorities = self._merge_and_trim(
            self.nn_train_dataset, self.nn_priorities, inputs, targets, priorities, self.max_nn_dataset_size
        )
        self.gp_priorities = self._merge_and_trim(
            self.gp_train_dataset, self.gp_priorities, inputs, targets, priorities, self.max_gp_dataset_size
        )

    @staticmethod
    def _merge_and_trim(dataset, old_priorities, inputs, targets, priorities, max_size):
        """Prepend samples to ``dataset`` in place, cap it at ``max_size``; return the new priorities."""
        merged_inputs = torch.cat((inputs, dataset.inputs), dim=0)
        merged_targets = torch.cat((targets, dataset.targets), dim=0)
        merged_priorities = torch.cat((priorities, old_priorities))
        # Compare the merged size so the set never exceeds max_size.
        if len(merged_inputs) > max_size:
            _, order = torch.sort(merged_priorities, descending=True)
            keep = order[:max_size]
            merged_inputs = merged_inputs[keep]
            merged_targets = merged_targets[keep]
            merged_priorities = merged_priorities[keep]
        dataset.inputs = merged_inputs
        dataset.targets = merged_targets
        return merged_priorities

    def save_model(self, path='model_state.pth'):
        """Save the mean function, GP model and likelihood state dicts to ``path``."""
        torch.save(
            {
                'mean_function': self.mean_function.state_dict(),
                'gp_model': self.gp_model.state_dict(),
                'likelihood': self.likelihood.state_dict(),
            },
            path,
        )

    def load_model(self, path='model_state.pth'):
        """Load state dicts written by ``save_model`` onto ``self.device``."""
        state = torch.load(path, map_location=self.device)
        self.mean_function.load_state_dict(state['mean_function'])
        self.gp_model.load_state_dict(state['gp_model'])
        self.likelihood.load_state_dict(state['likelihood'])

    def get_outputscales(self):
        """Return kernel output scales.

        With separate kernels: a list of floats, one per output. A sum kernel
        reports the sum of its two components' scales. Otherwise: the batched
        ``outputscale`` tensor (first batch entry).
        """
        if not self.use_separate_reward_cov:
            return self.gp_model.covar_module.outputscale[0].detach().clone()
        outputscales = []
        for i in range(self.output_dim):
            kernel = self.gp_model.kernels[i]
            if hasattr(kernel, "outputscale"):
                outputscales.append(kernel.outputscale.detach().clone().item())
            else:
                outputscales.append(
                    kernel.kernels[0].outputscale.detach().clone().item()
                    + kernel.kernels[1].outputscale.detach().clone().item()
                )
        return outputscales

    def get_lengthscales(self, model_num):
        """Return the Matern lengthscales used for output ``model_num``.

        For a sum kernel (the reward entry when ``use_separate_reward_cov``),
        the lengthscale of its first (Matern) component is returned.
        """
        if self.use_separate_reward_cov:
            kernel = self.gp_model.kernels[model_num]
            if not hasattr(kernel, "base_kernel"):
                kernel = kernel.kernels[0]
            return kernel.base_kernel.lengthscale.detach().clone().squeeze()
        return self.gp_model.covar_module.base_kernel.lengthscale.detach().clone().squeeze(0, 2)[model_num]