import numpy as np
import torch
import torch.nn as nn
from diff_nn import diff_CSDI
import math
import torch.nn.functional as F

def clip_by_norm(tensor, max_norm: float = 1.0, eps: float = 1e-8):
    """Rescale each sample of a batch so its L2 norm is at most ``max_norm``.

    The first axis is treated as the batch axis; all remaining axes of one
    sample are flattened and scaled together. The scale factor is capped at
    1, so samples already within the limit are left unchanged.

    Args:
        tensor: tensor whose first dimension indexes samples.
        max_norm: maximum allowed per-sample L2 norm.
        eps: added to the norm to avoid division by zero.

    Returns:
        A new tensor with the same shape as ``tensor``.
    """
    orig_shape = tensor.shape
    if tensor.numel() == 0:
        # Nothing to scale; reshaping an empty tensor with -1 can be ambiguous.
        return tensor.clone()

    # reshape (not view) so non-contiguous inputs, e.g. transposed grads, work.
    flat = tensor.reshape(orig_shape[0], -1)
    norms = flat.norm(p=2, dim=1, keepdim=True)
    scale = torch.clamp(max_norm / (norms + eps), max=1.0)
    return (flat * scale).reshape(orig_shape)

class CSDI_base(nn.Module):
    """CSDI-style conditional diffusion model for masked time-series imputation.

    Data and mask tensors use the layout (B, K, L): batch, features
    (``target_dim``, indexed by the feature embedding) and time points
    (embedded by ``time_embedding``). Subclasses must provide
    ``process_data(batch)`` returning
    ``(original_data, observed_data, observed_mask, observed_tp,
    indicating_mask, gt_mask)``.

    Schedule naming: ``self.alpha_hat`` is the per-step ``1 - beta`` and
    ``self.alpha`` is its cumulative product. This is the reverse of the
    common DDPM ``alpha`` / ``alpha_bar`` naming.
    """

    # Parameterizations implemented by calc_loss / get_x0 / get_eps / p_sample.
    _SUPPORTED_TARGETS = ("epsilon", "x0")

    def __init__(self, target_dim, config, device, logger=None):
        """Build the embeddings, the denoiser and the noise schedule.

        Args:
            target_dim: number of features K; also the size of the feature
                embedding table.
            config: nested dict with "model", "diffusion" and "pr" sections.
                NOTE: ``config["diffusion"]`` is mutated in place
                (``side_dim`` and ``target`` are written into it) before it
                is handed to ``diff_CSDI``.
            device: device the schedule tensor and embeddings are moved to.
            logger: optional logger; when given, the trainable parameter
                count of the denoiser plus embedding is logged.

        Raises:
            ValueError: if ``config["model"]["target"]`` or
                ``config["diffusion"]["schedule"]`` is not supported.
        """
        super().__init__()
        self.device = device
        self.target_dim = target_dim
        self.logger = logger

        self.emb_time_dim = config["model"]["timeemb"]
        self.emb_feature_dim = config["model"]["featureemb"]
        self.is_unconditional = config["model"]["is_unconditional"]
        self.target_strategy = config["model"]["target_strategy"]
        self.mitratio = config['model']['mitratio']
        self.is_ort = config['model']['is_ort']
        self.is_maskinfo = config['model']['is_maskinfo']
        self.target = config['model']['target']
        self.is_pr = config['pr']['is_pr']
        if self.target not in self._SUPPORTED_TARGETS:
            # Any other value would make calc_loss / p_sample silently return None.
            raise ValueError(
                f"Unsupported model target {self.target!r}; "
                f"expected one of {self._SUPPORTED_TARGETS}"
            )

        # Side info = time embedding + feature embedding (+1 channel for the cond mask).
        self.emb_total_dim = self.emb_time_dim + self.emb_feature_dim
        if self.is_maskinfo == True:
            self.emb_total_dim += 1
        self.embed_layer = nn.Embedding(
            num_embeddings=self.target_dim, embedding_dim=self.emb_feature_dim
        )

        config_diff = config["diffusion"]
        config_diff["side_dim"] = self.emb_total_dim
        config_diff["target"] = self.target

        # Unconditional: the noisy sample only; otherwise [conditional observations, noisy sample].
        input_dim = 1 if self.is_unconditional == True else 2

        self.diffmodel = diff_CSDI(config_diff, input_dim)
        if self.logger is not None:
            n_trainable = sum(
                p.numel() for p in self.diffmodel.parameters() if p.requires_grad
            ) + sum(
                p.numel() for p in self.embed_layer.parameters() if p.requires_grad
            )
            self.logger.info(f"Num of total trainable diffusion params is: {n_trainable}")

        self.num_steps = config_diff["num_steps"]
        self.beta = self._beta_schedule(config_diff)

        self.alpha_hat = 1 - self.beta
        self.alpha = np.cumprod(self.alpha_hat)
        # (num_steps, 1, 1) so that indexing with a (B,) step tensor broadcasts over (B, K, L).
        self.alpha_torch = torch.tensor(self.alpha).float().to(self.device).unsqueeze(1).unsqueeze(1)
        alpha = self.alpha
        alpha_prev = np.append(1.0, alpha[:-1])
        # DDIM-style sigma with eta = 1; not referenced by the samplers in this class.
        self.sigma = np.sqrt((1 - alpha_prev) / (1 - alpha)) * np.sqrt(1 - (alpha / alpha_prev))

    def _beta_schedule(self, config_diff):
        """Return the per-step beta array selected by ``config_diff["schedule"]``.

        Supported schedules are "quad" (linear in sqrt(beta)), "linear" and
        "cosine".

        Raises:
            ValueError: for any other schedule name.
        """
        schedule = config_diff["schedule"]
        num_steps = config_diff["num_steps"]
        if schedule == "quad":
            return np.linspace(
                config_diff["beta_start"] ** 0.5, config_diff["beta_end"] ** 0.5, num_steps
            ) ** 2
        if schedule == "linear":
            return np.linspace(config_diff["beta_start"], config_diff["beta_end"], num_steps)
        if schedule == "cosine":
            return self._cosine_variance_schedule(num_steps)
        raise ValueError(
            f"Unknown diffusion schedule {schedule!r}; expected 'quad', 'linear' or 'cosine'"
        )

    def _cosine_variance_schedule(self, timesteps, epsilon=0.008):
        """Cosine beta schedule: betas from ratios of squared cosines, clipped to [0, 0.999]."""
        steps = np.linspace(0, timesteps, timesteps + 1)
        f_t = np.cos(((steps / timesteps + epsilon) / (1.0 + epsilon)) * math.pi * 0.5) ** 2
        betas = np.clip(1.0 - f_t[1:] / f_t[:timesteps], 0.0, 0.999)
        return betas

    def time_embedding(self, pos, d_model=128):
        """Sinusoidal embedding of time points.

        Args:
            pos: (B, L) tensor of time points.
            d_model: embedding width; assumed even (sin/cos channels interleave).

        Returns:
            (B, L, d_model) tensor with sin on even channels and cos on odd ones.
        """
        pe = torch.zeros(pos.shape[0], pos.shape[1], d_model).to(self.device)
        position = pos.unsqueeze(2)
        div_term = 1 / torch.pow(
            10000.0, torch.arange(0, d_model, 2).to(self.device) / d_model
        )
        pe[:, :, 0::2] = torch.sin(position * div_term)
        pe[:, :, 1::2] = torch.cos(position * div_term)
        return pe

    def get_randmask(self, observed_mask):
        """Randomly hide a per-sample fraction of the observed entries.

        For each sample a ratio is drawn from U(0, 1), and that fraction of
        the observed entries (those with the highest random scores) is removed
        from the conditioning mask.

        Returns:
            Float conditioning mask with the same shape as ``observed_mask``.
        """
        rand_for_mask = torch.rand_like(observed_mask) * observed_mask
        rand_for_mask = rand_for_mask.reshape(len(rand_for_mask), -1)
        for i in range(len(observed_mask)):
            sample_ratio = np.random.rand()  # missing ratio
            num_observed = observed_mask[i].sum().item()
            num_masked = round(num_observed * sample_ratio)
            rand_for_mask[i][rand_for_mask[i].topk(num_masked).indices] = -1
        cond_mask = (rand_for_mask > 0).reshape(observed_mask.shape).float()
        return cond_mask

    def get_nearmask(self, mask, kernel_size=3):
        """Randomly hide observed entries that lie next to a missing entry.

        Missing positions are dilated along the time axis with a window of
        ``kernel_size``. Observed entries inside that window are dropped from
        the conditioning mask with a per-sample probability drawn from U(0, 1).

        Args:
            mask: (B, K, L) float mask, 1 = observed.
            kernel_size: odd dilation window length along L.

        Returns:
            (B, K, L) float conditioning mask.

        Raises:
            ValueError: if ``kernel_size`` is even (the padded convolution
                would not preserve length L).
        """
        if kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be odd, got {kernel_size}")
        B, K, L = mask.shape
        inverse_mask = 1.0 - mask

        input_for_dilation = inverse_mask.reshape(B * K, 1, L)

        kernel = torch.ones(1, 1, kernel_size, device=self.device)
        padding = kernel_size // 2

        dilated = F.conv1d(input_for_dilation, kernel, padding=padding)
        dilated = (dilated > 0).float()

        dilated = dilated.view(B, K, L)
        # Observed entries within the window of some missing entry.
        edge_mask = dilated * mask

        sample_ratios = torch.rand(B, 1, 1, device=self.device)
        random_mask = (torch.rand(B, K, L, device=self.device) < sample_ratios).float()

        new_missing = edge_mask * random_mask
        cond_mask = mask - new_missing

        return cond_mask

    def get_side_info(self, observed_tp, cond_mask):
        """Build the (B, side_dim, K, L) conditioning features.

        Concatenates the time embedding of ``observed_tp`` (B, L) with the
        learned per-feature embedding, plus the conditioning mask as one extra
        channel when ``is_maskinfo`` is True.
        """
        B, K, L = cond_mask.shape

        time_embed = self.time_embedding(observed_tp, self.emb_time_dim)  # (B,L,emb)
        time_embed = time_embed.unsqueeze(2).expand(-1, -1, K, -1)
        feature_embed = self.embed_layer(
            torch.arange(self.target_dim).to(self.device)
        )  # (K,emb)
        feature_embed = feature_embed.unsqueeze(0).unsqueeze(0).expand(B, L, -1, -1)

        side_info = torch.cat([time_embed, feature_embed], dim=-1)  # (B,L,K,*)
        side_info = side_info.permute(0, 3, 2, 1)  # (B,*,K,L)

        if self.is_maskinfo == True:
            side_mask = cond_mask.unsqueeze(1)  # (B,1,K,L)
            side_info = torch.cat([side_info, side_mask], dim=1)

        return side_info

    def calc_loss_valid(
        self, observed_data, cond_mask, observed_mask, side_info, is_train
    ):
        """Average ``calc_loss`` over every diffusion step (detached).

        Returns a single mean loss, or ``(mean_mit_loss, mean_ort_loss)`` when
        ``is_ort`` is set.
        """
        loss_mit_sum = 0
        loss_ort_sum = 0
        for t in range(self.num_steps):
            if self.is_ort:
                loss_mit, loss_ort = self.calc_loss(
                    observed_data, cond_mask, observed_mask, side_info, is_train, set_t=t
                )
                loss_mit_sum += loss_mit.detach()
                loss_ort_sum += loss_ort.detach()
            else:
                loss_mit = self.calc_loss(
                    observed_data, cond_mask, observed_mask, side_info, is_train, set_t=t
                )
                loss_mit_sum += loss_mit.detach()
        if self.is_ort:
            return loss_mit_sum / self.num_steps, loss_ort_sum / self.num_steps
        else:
            return loss_mit_sum / self.num_steps

    @staticmethod
    def _masked_mse(residual, mask):
        """Sum of squared ``residual * mask`` divided by ``mask.sum()`` (or 1 if that is 0)."""
        masked = residual * mask
        num_eval = mask.sum()
        return (masked ** 2).sum() / (num_eval if num_eval > 0 else 1)

    def calc_loss(
        self, observed_data, cond_mask, observed_mask, side_info, is_train, set_t=-1
    ):
        """Denoising loss for one diffusion step.

        When ``is_train == 1`` a random step is drawn per sample; otherwise
        every sample uses step ``set_t``. The regression target is the added
        noise (``target='epsilon'``) or the clean data (``target='x0'``).

        Returns:
            The MIT loss over ``observed_mask - cond_mask``, or
            ``(mit_loss, ort_loss)`` when ``is_ort`` is set. The ORT loss is
            taken over ``cond_mask``.
        """
        B, K, L = observed_data.shape

        if is_train != 1:
            t = (torch.ones(B) * set_t).long().to(self.device)
        else:
            t = torch.randint(0, self.num_steps, [B]).to(self.device)

        current_alpha = self.alpha_torch[t]  # (B,1,1)
        noise = torch.randn_like(observed_data)
        noisy_data = (current_alpha ** 0.5) * observed_data + (1.0 - current_alpha) ** 0.5 * noise

        total_input = self.set_input_to_diffmodel(noisy_data, observed_data, cond_mask)

        predicted = self.diffmodel(total_input, side_info, t)  # (B,K,L)

        target_mask = observed_mask - cond_mask  # indicating mask (MIT)

        if self.target == 'epsilon':
            reference = noise
        elif self.target == 'x0':
            reference = observed_data
        else:
            raise ValueError(f"Unsupported model target {self.target!r}")

        residual = reference - predicted
        loss_mit = self._masked_mse(residual, target_mask)
        if self.is_ort:
            loss_ort = self._masked_mse(residual, cond_mask)
            return loss_mit, loss_ort
        return loss_mit

    def get_x0(
        self, predicted, current_sample = None, t = None
    ):
        """Return the x0 estimate implied by the network output at step ``t``.

        For ``target='epsilon'`` this inverts
        ``x_t = sqrt(alpha_t) x0 + sqrt(1 - alpha_t) eps``.
        """
        if self.target == 'x0':
            return predicted
        elif self.target == 'epsilon':
            return (current_sample - (1-self.alpha[t]) ** 0.5 * predicted) / self.alpha[t]**0.5

    def get_eps(self, predicted, current_sample = None, t = None):
        """Return the noise estimate implied by the network output at step ``t``.

        For ``target='x0'`` this inverts
        ``x_t = sqrt(alpha_t) x0 + sqrt(1 - alpha_t) eps``.
        """
        if self.target == 'epsilon':
            return predicted
        elif self.target == 'x0':
            return (current_sample - self.alpha[t] ** 0.5 * predicted) / (1-self.alpha[t])**0.5

    def set_input_to_diffmodel(self, noisy_data, observed_data, cond_mask):
        """Assemble the denoiser input used during training.

        Returns (B, 1, K, L) when unconditional, else (B, 2, K, L) made of
        [conditional observations, noisy data]. Without ORT, the noisy channel
        is zeroed on conditioning entries.
        """
        if self.is_unconditional == True:
            total_input = noisy_data.unsqueeze(1)  # (B,1,K,L)
        elif self.is_ort == False:
            cond_obs = cond_mask * observed_data
            cond_obs = cond_obs.unsqueeze(1)
            noisy_target = ((1-cond_mask)*noisy_data).unsqueeze(1)
            total_input = torch.cat([cond_obs, noisy_target], dim=1)  # (B,2,K,L)
        else:
            cond_obs = cond_mask * observed_data
            cond_obs = cond_obs.unsqueeze(1)
            noisy_target = noisy_data.unsqueeze(1)
            total_input = torch.cat([cond_obs, noisy_target], dim=1)  # (B,2,K,L)

        return total_input

    def p_sample(self, current_sample, observed_data, cond_mask, side_info, t):
        """One unguided reverse-diffusion step from ``t`` to ``t - 1``.

        ``target='epsilon'`` uses the ancestral DDPM update (noise added for
        t > 0). ``target='x0'`` uses a deterministic DDIM update. At t == 0,
        conditioning entries are replaced with ``observed_data``.
        """
        B, *_ = current_sample.shape
        if self.is_unconditional == True:
            noise = torch.randn_like(observed_data)
            noisy_obs = (self.alpha[t] ** 0.5) * observed_data + (1-self.alpha[t]) ** 0.5 * noise
            noisy_target = (cond_mask * noisy_obs + (1.0 - cond_mask) * current_sample).unsqueeze(1)
            diff_input = noisy_target  # (B,1,K,L)
        elif self.is_ort == False:
            cond_obs = cond_mask * observed_data
            cond_obs = cond_obs.unsqueeze(1)
            noisy_target = ((1-cond_mask) * current_sample).unsqueeze(1)
            diff_input = torch.cat([cond_obs, noisy_target], dim=1)  # (B,2,K,L)
        else:
            noise = torch.randn_like(observed_data)
            noisy_obs = (self.alpha[t] ** 0.5) * observed_data + (1-self.alpha[t]) ** 0.5 * noise
            cond_obs = cond_mask * observed_data
            cond_obs = cond_obs.unsqueeze(1)
            noisy_target = ((1-cond_mask) * current_sample + cond_mask * noisy_obs).unsqueeze(1)
            diff_input = torch.cat([cond_obs, noisy_target], dim=1)  # (B,2,K,L)

        predicted = self.diffmodel(diff_input, side_info, (torch.ones(B) * t).to(self.device).long())

        if self.target == 'epsilon':
            coeff1 = 1 / self.alpha_hat[t] ** 0.5
            coeff2 = (1 - self.alpha_hat[t]) / (1 - self.alpha[t]) ** 0.5
            current_sample = coeff1 * (current_sample - coeff2 * predicted)
            if t > 0:
                noise = torch.randn_like(current_sample)
                sigma = (
                    (1.0 - self.alpha[t - 1]) / (1.0 - self.alpha[t]) * self.beta[t]
                ) ** 0.5
                current_sample += sigma * noise
            elif t == 0:
                current_sample = cond_mask * observed_data + (1-cond_mask) * current_sample

        elif self.target == 'x0':
            if t > 0:
                eps_t = self.get_eps(predicted, current_sample=current_sample, t=t)
                # DDIM (eta = 0): x_{t-1} = sqrt(a_{t-1}) x0_hat + sqrt(1 - a_{t-1}) eps_hat
                current_sample = (self.alpha[t-1] ** 0.5) * predicted + (1-self.alpha[t-1]) ** 0.5 * eps_t
            elif t == 0:
                current_sample = cond_mask * observed_data + (1-cond_mask) * predicted

        return current_sample

    def p_sample_pr(self, current_sample, observed_data, cond_mask, side_info, t, target_mask, model_pr = None, scale = 1):
        """One reverse step guided by the gradient of a mask-prediction model.

        The x0 prediction is merged with the observed values and passed to
        ``model_pr.diffmodel``. The result ``mask_hat`` is scored by the mean
        of ``log(1 - mask_hat)`` over ``target_mask`` entries. The gradient of
        that score w.r.t. the noisy denoiser input is L2-clipped per sample to
        1 and used to shift the x0 prediction before a deterministic DDIM step.

        Args:
            current_sample: (B, K, L) sample at step ``t``.
            observed_data: (B, K, L) observed values.
            cond_mask: (B, K, L) conditioning mask.
            side_info: conditioning features from ``get_side_info``.
            t: integer diffusion step.
            target_mask: (B, K, L) entries the guidance score averages over.
            model_pr: object exposing a ``diffmodel`` callable. When ``None``,
                this falls back to the unguided :meth:`p_sample`.
            scale: guidance strength.

        Returns:
            Detached (B, K, L) sample for step ``t - 1``.

        Raises:
            NotImplementedError: if guidance is requested with ``target != 'x0'``.
        """
        if model_pr is None:
            return self.p_sample(current_sample, observed_data, cond_mask, side_info, t).detach()
        if self.target != 'x0':
            # The guided update below is written for an x0-predicting denoiser.
            raise NotImplementedError("p_sample_pr guidance is only implemented for target='x0'")

        B, *_ = current_sample.shape
        # Guidance needs autograd even when sampling runs under torch.no_grad().
        with torch.enable_grad():
            if self.is_unconditional == True:
                noise = torch.randn_like(observed_data)
                noisy_obs = (self.alpha[t] ** 0.5) * observed_data + (1-self.alpha[t]) ** 0.5 * noise
                noisy_target = (cond_mask * noisy_obs + (1.0 - cond_mask) * current_sample).unsqueeze(1)
                noisy_target = noisy_target.requires_grad_()
                diff_input = noisy_target  # (B,1,K,L)
            elif self.is_ort == False:
                cond_obs = cond_mask * observed_data
                cond_obs = cond_obs.unsqueeze(1)
                noisy_target = ((1-cond_mask) * current_sample).unsqueeze(1)
                noisy_target = noisy_target.requires_grad_()
                diff_input = torch.cat([cond_obs, noisy_target], dim=1)  # (B,2,K,L)
            else:
                noise = torch.randn_like(observed_data)
                noisy_obs = (self.alpha[t] ** 0.5) * observed_data + (1-self.alpha[t]) ** 0.5 * noise
                cond_obs = cond_mask * observed_data
                cond_obs = cond_obs.unsqueeze(1)
                noisy_target = ((1-cond_mask) * current_sample + cond_mask * noisy_obs).unsqueeze(1)
                noisy_target = noisy_target.requires_grad_()
                diff_input = torch.cat([cond_obs, noisy_target], dim=1)  # (B,2,K,L)

            predicted = self.diffmodel(diff_input, side_info, (torch.ones(B) * t).to(self.device).long())

            imputed = ((1-cond_mask) * predicted + cond_mask * observed_data).unsqueeze(1)
            mask_hat = model_pr.diffmodel(imputed)

            fake_num = target_mask.sum(dim=(1, 2), keepdim=True)
            loglikelihood = (
                torch.sum(target_mask * torch.log(1. - mask_hat + 1e-10), dim=(1, 2), keepdim=True)
                / torch.clamp(fake_num, min=1.0)
            )

            pr_guidance = torch.autograd.grad(
                outputs=loglikelihood,
                inputs=noisy_target,
                grad_outputs=torch.ones_like(loglikelihood),
            )[0]

        # (B,1,K,L) -> (B,K,L); squeeze(1) keeps the batch axis when B == 1 or K == 1.
        pr_guidance = clip_by_norm(pr_guidance.squeeze(1), max_norm=1.0)

        if t > 0:
            eps_t = self.get_eps(predicted, current_sample=current_sample, t=t)
            # DDIM step with the x0 estimate shifted along the guidance direction.
            guided_x0 = predicted + scale * pr_guidance * (1-self.alpha[t]) / self.alpha[t]**0.5
            current_sample = (self.alpha[t-1] ** 0.5) * guided_x0 + (1-self.alpha[t-1]) ** 0.5 * eps_t
        elif t == 0:
            current_sample = cond_mask * observed_data + (1-cond_mask) * predicted

        return current_sample.detach()

    def impute(self, observed_data, cond_mask, side_info, n_samples):
        """Run the full unguided reverse chain ``n_samples`` times.

        Returns:
            (B, n_samples, K, L) tensor of samples.
        """
        B, K, L = observed_data.shape

        imputed_samples = torch.zeros(B, n_samples, K, L).to(self.device)

        for i in range(n_samples):
            current_sample = torch.randn_like(observed_data)
            for t in range(self.num_steps - 1, -1, -1):
                current_sample = self.p_sample(current_sample, observed_data, cond_mask, side_info, t)

            imputed_samples[:, i] = current_sample.detach()
        return imputed_samples

    def impute_pr(self, observed_data, cond_mask, side_info, n_samples, target_mask, model_pr, scale):
        """Run the full guided reverse chain (see ``p_sample_pr``) ``n_samples`` times.

        Returns:
            (B, n_samples, K, L) tensor of samples.
        """
        B, K, L = observed_data.shape

        imputed_samples = torch.zeros(B, n_samples, K, L).to(self.device)

        for i in range(n_samples):
            current_sample = torch.randn_like(observed_data)
            for t in range(self.num_steps - 1, -1, -1):
                current_sample = self.p_sample_pr(current_sample, observed_data, cond_mask, side_info, t, target_mask, model_pr, scale)

            imputed_samples[:, i] = current_sample.detach()
        return imputed_samples

    def forward(self, batch, is_train=1, is_dp=0):
        """Compute the training (``is_train == 1``) or all-steps validation loss.

        Conditioning mask selection:
        - ``is_train == 0``: ``gt_mask``.
        - ``is_dp == 1``: ``observed_mask``, and every entry is then treated
          as observed.
        - ``target_strategy == "near"``: ``get_nearmask``.
        - otherwise: ``get_randmask``.
        """
        (
            original_data,
            observed_data,
            observed_mask,
            observed_tp,
            indicating_mask,
            gt_mask,
        ) = self.process_data(batch)
        if is_train == 0:
            cond_mask = gt_mask
        elif is_dp == 1:
            cond_mask = observed_mask
            observed_mask = torch.ones_like(observed_mask)
        elif self.target_strategy == "near":
            cond_mask = self.get_nearmask(observed_mask)
        else:
            cond_mask = self.get_randmask(observed_mask)

        side_info = self.get_side_info(observed_tp, cond_mask)

        loss_func = self.calc_loss if is_train == 1 else self.calc_loss_valid

        return loss_func(observed_data, cond_mask, observed_mask, side_info, is_train)

    def evaluate(self, batch, n_samples, is_impute = False, model_pr = None, scale = 1):
        """Draw ``n_samples`` imputations for a batch.

        With ``is_impute`` the model conditions on all observed entries and
        targets the unobserved ones. Otherwise it conditions on ``gt_mask``
        and targets ``observed_mask - gt_mask``. Guided sampling is used when
        ``model_pr`` is given.

        Returns:
            ``(samples, original_data, target_mask, observed_mask, observed_tp)``.
        """
        (
            original_data,
            observed_data,
            observed_mask,
            observed_tp,
            indicating_mask,
            gt_mask,
        ) = self.process_data(batch)

        if is_impute:
            cond_mask = observed_mask
            target_mask = 1. - observed_mask
        else:
            cond_mask = gt_mask
            target_mask = observed_mask - cond_mask

        side_info = self.get_side_info(observed_tp, cond_mask)

        if model_pr is not None:
            samples = self.impute_pr(observed_data, cond_mask, side_info, n_samples, target_mask, model_pr, scale)
        else:
            samples = self.impute(observed_data, cond_mask, side_info, n_samples)

        return samples, original_data, target_mask, observed_mask, observed_tp
    
class CSDI_ETT(CSDI_base):
    """CSDI_base specialisation whose feature count defaults to 7.

    ``process_data`` swaps the last two axes of every data/mask entry, so
    batches are presumably stored as (B, L, K) and leave this method in the
    (B, K, L) layout that ``CSDI_base`` unpacks. TODO confirm against the
    dataset loader.
    """

    def __init__(self, config, device, target_dim=7, logger=None):
        super().__init__(target_dim, config, device, logger)

    def process_data(self, batch):
        """Move a batch onto ``self.device`` as float tensors and reorder axes.

        All entries except ``timepoints`` get their last two axes swapped.

        Returns:
            ``(original_data, observed_data, observed_mask, observed_tp,
            indicating_mask, gt_mask)``.
        """
        keys = (
            "original_data",
            "observed_data",
            "observed_mask",
            "indicating_mask",
            "timepoints",
            "gt_mask",
        )
        loaded = {key: batch[key].to(self.device).float() for key in keys}
        observed_tp = loaded.pop("timepoints")
        swapped = {key: value.permute(0, 2, 1) for key, value in loaded.items()}

        return (
            swapped["original_data"],
            swapped["observed_data"],
            swapped["observed_mask"],
            observed_tp,
            swapped["indicating_mask"],
            swapped["gt_mask"],
        )