# Train diffusion model on D4RL transitions.
import argparse
import pathlib
from typing import Optional, List

# import d4rl
import gin
import gym
import math
import numpy as np
import torch
import torch.nn as nn
# import wandb

# from diffusion.trainer import Trainer
from diffusion.norm import MinMaxNormalizer
from diffusion.utils import make_inputs, split_diffusion_samples, split_diffusion_samples_no_sa, \
    construct_diffusion_model

from diffusion.elucidated_diffusion import default
from torch.nn.utils import vector_to_parameters, parameters_to_vector
from tqdm import tqdm

def gaussian_entropy(mu_array: np.ndarray, sigma_squared: float) -> np.ndarray:
    """
    Calculate the entropy of multivariate Gaussian distributions with covariance
    Diag(1/M * Σ(μₘ²) - μ̄²) + σ²I in batch mode.

    Args:
        mu_array: Either a single set of means with shape (M, D) or a batch
            with shape (N, M, D); the statistics are taken over the M axis.
        sigma_squared: Constant added to every diagonal variance term.

    Returns:
        A scalar entropy for a 2-D input, or an array of shape (N,) for a
        3-D input.
    """
    # Remember the caller's rank *before* promoting to 3-D; checking the shape
    # afterwards can never detect the single-distribution case.
    single_input = mu_array.ndim == 2
    if single_input:
        mu_array = mu_array[np.newaxis, ...]

    _, _, D = mu_array.shape

    diagonal_terms = np.mean(mu_array**2, axis=1) - np.mean(mu_array, axis=1) ** 2
    # E[x^2] - E[x]^2 can come out slightly negative with few samples; clamp it.
    diagonal_terms = np.clip(diagonal_terms, 0.0, None)
    eigenvalues = diagonal_terms + sigma_squared  # Shape: (N, D)
    log_det = np.sum(np.log(eigenvalues), axis=1)  # Shape: (N,)

    entropy = 0.5 * log_det + 0.5 * D * (np.log(2 * np.pi) + 1)

    return entropy[0] if single_input else entropy

@gin.configurable
class SimpleDiffusionGenerator:
    """Generate synthetic transitions from a single trained diffusion model.

    Every public ``sample*`` method draws ``sample_batch_size`` samples at a
    time from the diffusion model, converts them to NumPy, splits them into
    transition fields with the helpers from ``diffusion.utils`` and
    concatenates the per-batch results along axis 0.
    """

    # Number of weight vectors drawn from ``model_la`` per ``sample_laplace``
    # call (the Laplace mean and the current weights are sampled on top).
    _NUM_LAPLACE_SAMPLES = 4
    # ``sample_laplace`` generates this multiple of the requested batches and
    # keeps the ``num_samples`` lowest-entropy samples.
    _OVERSAMPLE_FACTOR = 1.1
    # Constant variance added to each dimension when scoring entropy.
    _ENTROPY_SIGMA_SQUARED = 1e-3

    def __init__(
            self,
            env: gym.Env,
            ema_model,
            model_la=None,
            rew_model=None,
            num_sample_steps: int = 128,
            sample_batch_size: int = 10000,
    ):
        """Store the model and sampling configuration.

        Args:
            env: Environment forwarded to the sample-splitting helpers.
            ema_model: Diffusion model used for sampling; switched to eval mode.
            model_la: Optional object exposing ``sample(n)`` and ``mean``,
                required only by ``sample_laplace``.
            rew_model: Optional callable mapping concatenated (obs, action)
                tensors to rewards; when set it replaces sampled rewards in
                ``sample`` and ``sample_laplace``.
            num_sample_steps: Number of sampling steps passed to the model.
            sample_batch_size: Number of samples generated per model call.
        """
        self.env = env
        self.diffusion = ema_model
        self.model_la = model_la
        self.diffusion.eval()
        self.rew_model = rew_model
        # Clamp samples if normalizer is MinMaxNormalizer
        self.clamp_samples = isinstance(self.diffusion.normalizer, MinMaxNormalizer)
        self.num_sample_steps = num_sample_steps
        self.sample_batch_size = sample_batch_size
        print(f'Sampling using: {self.num_sample_steps} steps, {self.sample_batch_size} batch size.')

    def _num_batches(self, num_samples: int) -> int:
        """Return how many full batches are needed for ``num_samples``."""
        assert num_samples % self.sample_batch_size == 0, 'num_samples must be a multiple of sample_batch_size'
        return num_samples // self.sample_batch_size

    def _split_transitions(self, sampled_outputs: np.ndarray, relabel_device: Optional[torch.device] = None):
        """Split one batch of samples into (obs, act, rew, next_obs, terminal).

        When the split has no terminal field, terminals are filled with zeros.
        If ``relabel_device`` is given and a reward model is configured, the
        rewards of such terminal-free samples are replaced by the reward
        model's prediction computed on that device.
        """
        transitions = split_diffusion_samples(sampled_outputs, self.env)
        if len(transitions) == 4:
            obs, act, rew, next_obs = transitions
            if relabel_device is not None and self.rew_model is not None:
                data = torch.cat([torch.from_numpy(obs), torch.from_numpy(act)], dim=1).to(relabel_device)
                rew = self.rew_model(data).squeeze(-1).detach().cpu().numpy()
            terminal = np.zeros_like(next_obs[:, 0])
        else:
            obs, act, rew, next_obs, terminal = transitions
        return obs, act, rew, next_obs, terminal

    @staticmethod
    def _concat_fields(batches):
        """Concatenate a list of per-batch field tuples field-by-field along axis 0."""
        if not batches:
            # Mirrors np.concatenate's behaviour on an empty list.
            raise ValueError('need at least one batch to concatenate')
        return tuple(np.concatenate(parts, axis=0) for parts in zip(*batches))

    def sample(
            self,
            clip,
            num_samples: int,
            state_energy,
            transition_energy,
            policy_energy
    ) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray):
        """Draw ``num_samples`` guided samples.

        ``clip`` and the three energy arguments are forwarded unchanged to
        ``self.diffusion.sample``.

        Returns:
            (observations, actions, rewards, next_observations, terminals).
        """
        batches = []
        for _ in range(self._num_batches(num_samples)):
            sampled_outputs = self.diffusion.sample(
                clip=clip,
                env=self.env,
                batch_size=self.sample_batch_size,
                num_sample_steps=self.num_sample_steps,
                clamp=self.clamp_samples,
                state_energy=state_energy,
                transition_energy=transition_energy,
                policy_energy=policy_energy
            )
            device = sampled_outputs.device
            batches.append(self._split_transitions(sampled_outputs.cpu().numpy(), relabel_device=device))

        return self._concat_fields(batches)

    def sample_laplace(
            self,
            clip,
            num_samples: int,
            state_energy,
            transition_energy,
            policy_energy
    ):
        """Draw samples and keep the ``num_samples`` with the lowest entropy score.

        For every batch the same initial noise and per-step noise are denoised
        under several weight settings: the mean of ``model_la``, weight vectors
        drawn from ``model_la`` (written over the last ``D`` entries of the
        flattened parameters) and finally the original weights. The spread of
        those outputs is scored with ``gaussian_entropy``; the returned
        transitions come from the original-weight outputs.

        Returns:
            (observations, actions, rewards, next_observations, terminals),
            ordered by increasing entropy.

        Raises:
            ValueError: If no ``model_la`` was supplied.
        """
        num_batches = self._num_batches(num_samples)
        if self.model_la is None:
            raise ValueError('sample_laplace requires model_la to be set')

        list_weights = self.model_la.sample(self._NUM_LAPLACE_SAMPLES)
        # add MAP model
        list_weights = torch.concat([self.model_la.mean[None, :], list_weights], dim=0)

        S, D = list_weights.shape
        # parameters_to_vector concatenates into a fresh tensor, so this is a
        # snapshot that later parameter overwrites do not touch.
        init_weight = parameters_to_vector(self.diffusion.parameters())

        sample_kwargs = dict(
            clip=clip,
            env=self.env,
            batch_size=self.sample_batch_size,
            num_sample_steps=self.num_sample_steps,
            clamp=self.clamp_samples,
            state_energy=state_energy,
            transition_energy=transition_energy,
            policy_energy=policy_energy
        )

        batches = []
        list_features = []
        for _ in range(int(num_batches * self._OVERSAMPLE_FACTOR)):
            num_sample_steps = default(self.num_sample_steps, self.diffusion.num_sample_steps)
            shape = (self.sample_batch_size, *self.diffusion.event_shape)

            # get the schedule, which is returned as (sigma, gamma) tuple, and pair up with the next sigma and gamma
            sigmas = self.diffusion.sample_schedule(num_sample_steps)
            gammas = torch.where(
                (sigmas >= self.diffusion.S_tmin) & (sigmas <= self.diffusion.S_tmax),
                min(self.diffusion.S_churn / num_sample_steps, math.sqrt(2) - 1),
                0.
            )
            sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1]))

            # inputs are noise at the beginning
            inputs = sigmas[0] * torch.randn(shape, device=self.diffusion.device)
            # Pre-draw the per-step noise so every weight setting sees the same noise.
            list_eps = [
                self.diffusion.S_noise * torch.randn(shape, device=self.diffusion.device)
                for _ in tqdm(sigmas_and_gammas, desc='sampling time step', mininterval=1, disable=False)
            ]

            list_laplace = []
            try:
                for s in range(S):
                    model_params = parameters_to_vector(self.diffusion.parameters())
                    model_params[-D:] = list_weights[s]
                    vector_to_parameters(model_params, self.diffusion.parameters())
                    list_laplace.append(self.diffusion.sample_with_inputs(
                        inputs, sigmas_and_gammas, shape, list_eps, **sample_kwargs
                    ))
            finally:
                # Always put the original weights back, even if sampling failed,
                # so the shared model is never left in a perturbed state.
                vector_to_parameters(init_weight, self.diffusion.parameters())

            sampled_outputs = self.diffusion.sample_with_inputs(
                inputs, sigmas_and_gammas, shape, list_eps, **sample_kwargs
            )
            list_laplace.append(sampled_outputs)

            # (members, batch, dim) -> (batch, members, dim) for gaussian_entropy.
            features = torch.stack(list_laplace, dim=0)
            list_features.append(np.transpose(features.cpu().numpy(), (1, 0, 2)))

            device = sampled_outputs.device
            batches.append(self._split_transitions(sampled_outputs.cpu().numpy(), relabel_device=device))

        fields = self._concat_fields(batches)
        list_features = np.concatenate(list_features, axis=0)
        entropies = gaussian_entropy(list_features, sigma_squared=self._ENTROPY_SIGMA_SQUARED)
        keep = np.argsort(entropies)[:num_samples]

        return tuple(field[keep] for field in fields)

    # @torch.no_grad
    def sample_wo_guidance_cond(self, num_samples: int, cond: torch.Tensor) -> (np.ndarray, np.ndarray, np.ndarray):
        """Draw unguided samples conditioned on ``cond``.

        ``cond`` is forwarded unchanged to ``self.diffusion.sample_wo_guidance``
        for every batch.

        Returns:
            (rewards, next_observations, terminals).

        Raises:
            NotImplementedError: If the split yields an unexpected number of fields.
        """
        batches = []
        for _ in range(self._num_batches(num_samples)):
            sampled_outputs = self.diffusion.sample_wo_guidance(
                batch_size=self.sample_batch_size,
                num_sample_steps=self.num_sample_steps,
                clamp=self.clamp_samples,
                cond=cond
            )
            sampled_outputs = sampled_outputs.detach().cpu().numpy()

            transitions = split_diffusion_samples_no_sa(sampled_outputs, self.env)
            if len(transitions) == 4:
                _, _, rew, next_obs = transitions
                terminal = np.zeros_like(next_obs[:, 0])
            elif len(transitions) == 3:
                rew, next_obs, terminal = transitions
            elif len(transitions) == 2:
                rew, next_obs = transitions
                terminal = np.zeros_like(rew)
            else:
                raise NotImplementedError
            batches.append((rew, next_obs, terminal))

        return self._concat_fields(batches)

    def sample_wo_guidance(
            self,
            num_samples: int,
            denoise_step: int,
    ) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray):
        """Draw ``num_samples`` unguided samples.

        ``denoise_step`` is forwarded to ``self.diffusion.sample_wo_guidance``.
        The reward model is not applied here.

        Returns:
            (observations, actions, rewards, next_observations, terminals).
        """
        batches = []
        for _ in range(self._num_batches(num_samples)):
            sampled_outputs = self.diffusion.sample_wo_guidance(
                batch_size=self.sample_batch_size,
                denoise_step=denoise_step,
                clamp=self.clamp_samples,
            )
            batches.append(self._split_transitions(sampled_outputs.cpu().numpy()))

        return self._concat_fields(batches)


@gin.configurable
class EnsemblesDiffusionGenerator:
    """Generate synthetic transitions from an ensemble of diffusion models.

    All ensemble members denoise the same initial and per-step noise; the
    mean of their outputs becomes the sample and the spread across members
    is used as an entropy score for filtering.
    """

    # ``sample_ensemble`` generates this multiple of the requested batches and
    # keeps the ``num_samples`` lowest-entropy samples.
    _OVERSAMPLE_FACTOR = 1.1
    # Constant variance added to each dimension when scoring entropy.
    _ENTROPY_SIGMA_SQUARED = 1e-3

    def __init__(
            self,
            env: gym.Env,
            ema_models,
            model_la=None,
            rew_model=None,
            num_sample_steps: int = 128,
            sample_batch_size: int = 10000,
    ):
        """Store the ensemble and sampling configuration.

        Args:
            env: Environment forwarded to the sample-splitting helper.
            ema_models: Sequence of diffusion models; each is switched to eval mode.
            model_la: Stored on the instance; not used by ``sample_ensemble``.
            rew_model: Optional callable mapping concatenated (obs, action)
                tensors to rewards; when set it replaces sampled rewards.
            num_sample_steps: Number of sampling steps passed to the models.
            sample_batch_size: Number of samples generated per model call.
        """
        self.env = env
        self.list_diffusions = ema_models
        self.model_la = model_la
        for diffusion in self.list_diffusions:
            diffusion.eval()
        self.rew_model = rew_model
        # Clamp samples if normalizer is MinMaxNormalizer
        self.list_clamp_samples = [
            isinstance(diffusion.normalizer, MinMaxNormalizer) for diffusion in self.list_diffusions
        ]
        self.num_sample_steps = num_sample_steps
        self.sample_batch_size = sample_batch_size
        print(f'Sampling using: {self.num_sample_steps} steps, {self.sample_batch_size} batch size.')

    def _split_transitions(self, sampled_outputs: np.ndarray, device):
        """Split one batch of samples into (obs, act, rew, next_obs, terminal).

        When the split has no terminal field, terminals are filled with zeros
        and, if a reward model is configured, rewards are replaced by its
        prediction computed on ``device``.
        """
        transitions = split_diffusion_samples(sampled_outputs, self.env)
        if len(transitions) == 4:
            obs, act, rew, next_obs = transitions
            if self.rew_model is not None:
                data = torch.cat([torch.from_numpy(obs), torch.from_numpy(act)], dim=1).to(device)
                rew = self.rew_model(data).squeeze(-1).detach().cpu().numpy()
            terminal = np.zeros_like(next_obs[:, 0])
        else:
            obs, act, rew, next_obs, terminal = transitions
        return obs, act, rew, next_obs, terminal

    @staticmethod
    def _concat_fields(batches):
        """Concatenate a list of per-batch field tuples field-by-field along axis 0."""
        if not batches:
            # Mirrors np.concatenate's behaviour on an empty list.
            raise ValueError('need at least one batch to concatenate')
        return tuple(np.concatenate(parts, axis=0) for parts in zip(*batches))

    def sample_ensemble(
            self,
            clip,
            num_samples: int,
            state_energy,
            transition_energy,
            policy_energy
    ):
        """Draw samples and keep the ``num_samples`` with the lowest entropy score.

        The noise schedule and noise tensors are built from the first ensemble
        member and shared by all members. ``clip`` and the three energy
        arguments are forwarded unchanged to each member's
        ``sample_with_inputs``.

        Returns:
            (observations, actions, rewards, next_observations, terminals),
            ordered by increasing entropy.
        """
        assert num_samples % self.sample_batch_size == 0, 'num_samples must be a multiple of sample_batch_size'
        num_batches = num_samples // self.sample_batch_size

        batches = []
        list_features = []
        for _ in range(int(num_batches * self._OVERSAMPLE_FACTOR)):
            # The first member defines the schedule, shapes and device for everyone.
            reference = self.list_diffusions[0]
            num_sample_steps = default(self.num_sample_steps, reference.num_sample_steps)
            shape = (self.sample_batch_size, *reference.event_shape)

            # get the schedule, which is returned as (sigma, gamma) tuple, and pair up with the next sigma and gamma
            sigmas = reference.sample_schedule(num_sample_steps)
            gammas = torch.where(
                (sigmas >= reference.S_tmin) & (sigmas <= reference.S_tmax),
                min(reference.S_churn / num_sample_steps, math.sqrt(2) - 1),
                0.
            )
            sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1]))

            # inputs are noise at the beginning
            inputs = sigmas[0] * torch.randn(shape, device=reference.device)
            # Pre-draw the per-step noise so every member sees the same noise.
            list_eps = [
                reference.S_noise * torch.randn(shape, device=reference.device)
                for _ in tqdm(sigmas_and_gammas, desc='sampling time step', mininterval=1, disable=False)
            ]

            member_outputs = [
                diffusion.sample_with_inputs(
                    inputs,
                    sigmas_and_gammas,
                    shape,
                    list_eps,
                    clip=clip,
                    env=self.env,
                    batch_size=self.sample_batch_size,
                    num_sample_steps=self.num_sample_steps,
                    clamp=clamp_samples,
                    state_energy=state_energy,
                    transition_energy=transition_energy,
                    policy_energy=policy_energy
                )
                for diffusion, clamp_samples in zip(self.list_diffusions, self.list_clamp_samples)
            ]

            features = torch.stack(member_outputs, dim=0)
            # The ensemble mean is the sample that gets returned.
            sampled_outputs = torch.mean(features, dim=0)
            # (members, batch, dim) -> (batch, members, dim) for gaussian_entropy.
            list_features.append(np.transpose(features.cpu().numpy(), (1, 0, 2)))

            device = sampled_outputs.device
            batches.append(self._split_transitions(sampled_outputs.cpu().numpy(), device))

        fields = self._concat_fields(batches)
        list_features = np.concatenate(list_features, axis=0)
        entropies = gaussian_entropy(list_features, sigma_squared=self._ENTROPY_SIGMA_SQUARED)
        keep = np.argsort(entropies)[:num_samples]

        return tuple(field[keep] for field in fields)

@gin.configurable
class MLPGenerator:
    """Generate rewards / next observations from a conditional model.

    The model is queried in chunks of ``sample_batch_size`` rows of the
    conditioning tensor and the results are concatenated along axis 0.
    """

    def __init__(
            self,
            env: gym.Env,
            model: nn.Module,
            rew_model: Optional[nn.Module] = None,
            num_sample_steps: int = 128,
            sample_batch_size: int = 1000,
    ):
        """Store the model and sampling configuration.

        Args:
            env: Environment forwarded to the sample-splitting helper.
            model: Model exposing ``sample(cond, clamp=...)``.
            rew_model: Optional reward model; only switched to eval mode here.
            num_sample_steps: Stored on the instance; not used by ``sample_cond``.
            sample_batch_size: Number of conditioning rows per model call.
        """
        self.env = env
        self.model = model
        self.rew_model = rew_model
        # Samples are never clamped for this generator.
        self.clamp_samples = False
        self.num_sample_steps = num_sample_steps
        self.sample_batch_size = sample_batch_size

    def sample_cond(self, cond: torch.Tensor) -> (np.ndarray, np.ndarray, np.ndarray):
        """Sample one output per row of ``cond``.

        Args:
            cond: Conditioning tensor whose first dimension must be a multiple
                of ``sample_batch_size``.

        Returns:
            (rewards, next_observations, terminals). Terminals are zeros when
            the split does not provide them.
        """
        self.model.eval()
        if self.rew_model is not None:
            self.rew_model.eval()
        num_samples = cond.shape[0]
        assert num_samples % self.sample_batch_size == 0, 'num_samples must be a multiple of sample_batch_size'

        rewards = []
        next_observations = []
        terminals = []
        for cond_batch in torch.split(cond, self.sample_batch_size, dim=0):
            sampled_outputs = self.model.sample(
                cond_batch,
                clamp=self.clamp_samples
            )
            sampled_outputs = sampled_outputs.detach().cpu().numpy()

            transitions = split_diffusion_samples_no_sa(sampled_outputs, self.env)
            if len(transitions) == 4:
                # Four fields are (obs, act, rew, next_obs), as in
                # SimpleDiffusionGenerator.sample_wo_guidance_cond; unpacking
                # them into two names would always raise.
                _, _, rew, next_obs = transitions
                terminal = np.zeros_like(next_obs[:, 0])
            elif len(transitions) == 3:
                rew, next_obs, terminal = transitions
            else:
                rew, next_obs = transitions
                terminal = np.zeros_like(rew)
            rewards.append(rew)
            next_observations.append(next_obs)
            terminals.append(terminal)
        rewards = np.concatenate(rewards, axis=0)
        next_observations = np.concatenate(next_observations, axis=0)
        terminals = np.concatenate(terminals, axis=0)

        return rewards, next_observations, terminals

# Script entry point: parse CLI/gin configuration, build the dataset and the
# diffusion model, train it (or load a checkpoint), then optionally generate
# samples and save them as a compressed .npz file.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='halfcheetah-medium-replay-v2')
    parser.add_argument('--gin_config_files', nargs='*', type=str, default=['config/resmlp_denoiser.gin'])
    parser.add_argument('--gin_params', nargs='*', type=str, default=[])
    # wandb config
    # NOTE(review): these three options are parsed but only referenced by the
    # commented-out wandb.init call below.
    parser.add_argument('--wandb-project', type=str, default="offline-rl-diffusion")
    parser.add_argument('--wandb-entity', type=str, default="")
    parser.add_argument('--wandb-group', type=str, default="diffusion_training")
    #
    parser.add_argument('--results_folder', type=str, default='./results')
    # NOTE(review): store_true combined with default=True means --use_gpu and
    # --save_samples are always True and cannot be disabled from the command line.
    parser.add_argument('--use_gpu', action='store_true', default=True)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--save_samples', action='store_true', default=True)
    parser.add_argument('--save_num_samples', type=int, default=int(5e6))
    parser.add_argument('--save_file_name', type=str, default='5m_samples.npz')
    parser.add_argument('--load_checkpoint', action='store_true')
    args = parser.parse_args()

    gin.parse_config_files_and_bindings(args.gin_config_files, args.gin_params)

    # Set seed.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.use_gpu:
        torch.cuda.manual_seed(args.seed)

    # Create the environment and dataset.
    env = gym.make(args.dataset)
    inputs = make_inputs(env)
    inputs = torch.from_numpy(inputs).float()
    dataset = torch.utils.data.TensorDataset(inputs)

    # Persist the effective gin configuration next to the results.
    results_folder = pathlib.Path(args.results_folder)
    results_folder.mkdir(parents=True, exist_ok=True)
    with open(results_folder / 'config.gin', 'w') as f:
        f.write(gin.config_str())

    # Create the diffusion model and trainer.
    diffusion = construct_diffusion_model(inputs=inputs)
    # NOTE(review): `Trainer` is not defined in this file -- the
    # `from diffusion.trainer import Trainer` import at the top is commented
    # out -- so this line raises NameError as written.
    trainer = Trainer(
        diffusion,
        dataset,
        results_folder=args.results_folder,
    )

    if not args.load_checkpoint:
        # Initialize logging.
        # wandb.init(
        #     project=args.wandb_project,
        #     entity=args.wandb_entity,
        #     config=args,
        #     group=args.wandb_group,
        #     name=args.results_folder.split('/')[-1],
        # )
        # Train model.
        trainer.train()
    else:
        trainer.ema.to(trainer.accelerator.device)
        # Load the last checkpoint.
        trainer.load(milestone=trainer.train_num_steps)

    # Generate samples and save them.
    if args.save_samples:
        generator = SimpleDiffusionGenerator(
            env=env,
            ema_model=trainer.ema.ema_model,
        )
        # NOTE(review): SimpleDiffusionGenerator.sample also requires `clip`,
        # `state_energy`, `transition_energy` and `policy_energy`; calling it
        # with only `num_samples` raises TypeError. Confirm the intended
        # values (or use `sample_wo_guidance`) before relying on this path.
        observations, actions, rewards, next_observations, terminals = generator.sample(
            num_samples=args.save_num_samples,
        )
        np.savez_compressed(
            results_folder / args.save_file_name,
            observations=observations,
            actions=actions,
            rewards=rewards,
            next_observations=next_observations,
            terminals=terminals,
        )
