import warnings
from typing import Optional

import numpy as np
import torch

from opelab.core.baseline import Baseline
from opelab.core.data import DataType
from opelab.core.policy import Policy
from opelab.core.baselines.diffusion.diffusionInpainter import GaussianDiffusionInpainter
from opelab.core.baselines.diffusion.temporal import TemporalUnet


class Inpainter(Baseline):
    """Diffusion-based off-policy evaluator.

    Builds a ``GaussianDiffusionInpainter`` around a ``TemporalUnet``, loads its
    trained weights from ``model_path``, and estimates a policy's value by
    sampling synthetic (state, action) trajectories from the diffusion model and
    averaging their discount-weighted per-step rewards (see ``evaluate``).
    """

    def __init__(
        self,
        T: int,
        D: int,
        num_samples: int,
        state_dim: int,
        action_dim: int,
        device,
        unnormalizer,
        normalizer,
        reward_fn,
        model_path,
        target_model,
        scale,
        behavior_model,
        env,
        start,
        every,
    ) -> None:
        """Build the diffusion model and load its trained weights.

        Args:
            T: trajectory horizon (steps per sampled trajectory).
            D: number of diffusion steps (passed as ``n_timesteps``).
            num_samples: number of trajectories ``evaluate`` samples.
            state_dim: size of the state part of each transition row.
            action_dim: size of the action part of each transition row.
            device: torch device for the model; also used as ``map_location``
                when loading the checkpoint.
            unnormalizer: optional callable applied to sampled trajectories to
                map them out of model space; also forwarded to the diffusion model.
            normalizer: forwarded to the diffusion model.
            reward_fn: optional callable ``reward_fn(state, action) -> reward``.
                If None, ``evaluate`` requires a ``reward_estimator``.
            model_path: path to a state dict for ``GaussianDiffusionInpainter``.
            target_model: forwarded to the diffusion model as ``policy``.
            scale: forwarded to ``p_sample_loop`` as ``scale`` (presumably a
                guidance strength -- confirm in GaussianDiffusionInpainter).
            behavior_model: forwarded to the diffusion model as ``behavior_policy``.
            env: environment whose ``action_space.low``/``high`` bound actions.
            start: forwarded unchanged to the diffusion model.
            every: forwarded unchanged to the diffusion model.
        """
        self.num_samples = num_samples
        self.T = T
        self.D = D
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.device = device
        self.unnormalizer = unnormalizer
        self.reward_fn = reward_fn
        self.scale = scale
        self.env = env
        self.env_min = env.action_space.low
        self.env_max = env.action_space.high
        self.start = start
        self.every = every

        temporal_model = TemporalUnet(
            horizon=T,
            transition_dim=state_dim + action_dim,
        ).to(device)

        diffusion_model = GaussianDiffusionInpainter(
            model=temporal_model,
            horizon=T,
            observation_dim=state_dim,
            action_dim=action_dim,
            n_timesteps=D,
            policy=target_model,
            behavior_policy=behavior_model,
            normalizer=normalizer,
            unnormalizer=unnormalizer,
            start=start,
            every=every,
        )

        # map_location lets a checkpoint saved on GPU load onto a CPU-only device.
        diffusion_model.load_state_dict(torch.load(model_path, map_location=device))
        diffusion_model.to(device)
        self.diffusion_model = diffusion_model
        self.diffusion_model.eval()

    def unsquash_action(self, action):
        """Affinely map ``action`` from [-1, 1] to [env_min, env_max] (elementwise)."""
        return (action + 1) * (self.env_max - self.env_min) / 2 + self.env_min

    def squash_action(self, action):
        """Affinely map ``action`` from [env_min, env_max] to [-1, 1]; inverse of ``unsquash_action``."""
        return 2 * (action - self.env_min) / (self.env_max - self.env_min) - 1

    def _sample_trajectories(self, batch_size: int) -> np.ndarray:
        """Draw ``num_samples`` trajectories from the diffusion model in batches.

        Returns an array of shape ``(num_samples, T, state_dim + action_dim)``.
        Each row is ``[state, action]``, with actions mapped into the env's
        action bounds.
        """
        transition_dim = self.state_dim + self.action_dim
        samples = np.empty((self.num_samples, self.T, transition_dim))
        # No torch.no_grad() here on purpose: p_sample_loop takes a policy-based
        # ``scale`` and may use autograd for guidance.
        # NOTE(review): confirm before wrapping this loop in no_grad.
        for start_idx in range(0, self.num_samples, batch_size):
            end_idx = min(start_idx + batch_size, self.num_samples)
            batch = self.diffusion_model.p_sample_loop(
                (end_idx - start_idx, self.T, transition_dim), scale=self.scale
            )
            batch = batch.trajectories.detach().cpu().numpy()
            if self.unnormalizer is not None:
                batch = self.unnormalizer(batch)
            # Squash raw action outputs to [-1, 1] with tanh, then rescale them to
            # the env's Box bounds. For a [-2, 2] action space this equals
            # 2 * tanh(x); other bounds are now respected instead of hard-coded.
            batch[:, :, self.state_dim:] = self.unsquash_action(np.tanh(batch[:, :, self.state_dim:]))
            samples[start_idx:end_idx] = batch
        return samples

    def _trajectory_value(self, trajectory: np.ndarray, gamma: float, reward_estimator) -> float:
        """Return the discount-weighted average per-step reward of one trajectory.

        Computes sum_t gamma^t * r_t / sum_t gamma^t. Rewards come from
        ``self.reward_fn(state, action)`` when set. Otherwise they come from
        ``reward_estimator.predict`` on the (1, state_dim + action_dim) row,
        averaged with ``.mean()``.
        """
        weighted_sum = 0.0
        weight_total = 0.0
        gamma_t = 1.0
        for step in trajectory:
            if self.reward_fn is not None:
                reward = self.reward_fn(step[:self.state_dim], step[self.state_dim:])
            else:
                # ``step`` is already the [state, action] row the estimator expects.
                reward = reward_estimator.predict(step.reshape(1, -1)).mean()
            weighted_sum += reward * gamma_t
            weight_total += gamma_t
            gamma_t *= gamma
        return weighted_sum / weight_total

    def evaluate(
        self,
        data: DataType,
        target: Policy,
        behavior: Policy,
        gamma: float = 1.0,
        reward_estimator=None,
        *,
        batch_size: int = 100,
        reward_threshold: Optional[float] = -16.0,
        verbose: bool = False,
    ) -> float:
        """Estimate the policy value from diffusion-sampled trajectories.

        ``data``, ``target`` and ``behavior`` are part of the ``Baseline``
        interface but are not read here. The policies were bound to the
        diffusion model at construction time.

        Args:
            gamma: per-step discount used to weight rewards inside a trajectory.
            reward_estimator: object with ``predict(X)``, used only when
                ``reward_fn`` is None.
            batch_size: number of trajectories sampled per ``p_sample_loop`` call.
            reward_threshold: only trajectories whose value is strictly greater
                than this are averaged. This is an ad-hoc outlier filter: it
                biases the estimate upward. The default -16.0 reproduces the
                previous hard-coded behaviour. Pass None to average all
                trajectories.
            verbose: print per-trajectory values and summary statistics.

        Returns:
            Mean of the (optionally filtered) per-trajectory values. If the
            filter removes every trajectory, a warning is issued and the
            unfiltered mean is returned instead of NaN.

        Raises:
            ValueError: if neither ``reward_fn`` nor ``reward_estimator`` is
                available, or ``batch_size`` < 1.
        """
        # Fail before the (expensive) sampling rather than deep inside scoring.
        if self.reward_fn is None and reward_estimator is None:
            raise ValueError("Inpainter.evaluate needs reward_fn or reward_estimator")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        samples = self._sample_trajectories(batch_size)
        all_rewards = [self._trajectory_value(traj, gamma, reward_estimator) for traj in samples]

        if verbose:
            print(all_rewards)
            print(f"Mean reward: {np.mean(all_rewards)}, Std: {np.std(all_rewards)}")

        if reward_threshold is None:
            return float(np.mean(all_rewards))

        kept = [r for r in all_rewards if r > reward_threshold]
        if verbose:
            print(f'kept {len(kept)}')
            if kept:
                print(f"Mean reward: {np.mean(kept)}, Std: {np.std(kept)}")
        if not kept:
            warnings.warn(
                f"No trajectory value exceeded reward_threshold={reward_threshold}; "
                "returning the unfiltered mean."
            )
            return float(np.mean(all_rewards))
        return float(np.mean(kept))
