import time
import torch
import functools
import tqdm
from torch import nn
from torch_geometric.data import Data
from torch_scatter import scatter_mean, scatter_sum
from .obj_enc import PointNet, LabelEncoder
from .gnn import Score_GNN
from .model_tools import GaussianFourierProjection, marginal_prob_std, diffusion_coeff

class AssembleModel_Room(nn.Module):
    """Score-matching model over graphs of objects (one node per object).

    Objects are embedded with ``PointNet`` (plus an optional ``LabelEncoder``
    term), the text embedding is projected by a small MLP, and ``Score_GNN``
    predicts a per-node score for the perturbed targets ``data.y`` at a
    diffusion time ``t``. ``forward`` returns per-graph denoising
    score-matching losses; ``get_score`` is the entry point used at sampling time.
    """

    def __init__(self, model_args) -> None:
        super().__init__()
        # Keep the raw dict: it is saved during training and reloaded for testing.
        self.model_args = model_args
        self.label_len = model_args["label_len"]
        self.obj_feat_len = model_args["obj_feat_len"]
        self.text_input_len = model_args["text_input_len"]
        self.text_feat_len = model_args["text_feat_len"]
        self.time_feat_len = model_args["time_feat_len"]
        self.mid_lay_input_len = model_args["mid_lay_input_len"]
        self.target_len = model_args["target_len"]
        self.n_layers = model_args["n_layers"]
        self.sigma = model_args["sigma"]

        # Object encoder; the label encoder is only built when label_len > 0.
        self.obj_encoder = PointNet(self.obj_feat_len)
        if self.label_len > 0:
            self.label_encoder = LabelEncoder(self.label_len, self.obj_feat_len)
            print("Using label encoder")
        else:
            self.label_encoder = None
            print("Not using label encoder")

        # Score GNN
        self.score_gnn = Score_GNN(self.target_len, self.mid_lay_input_len, self.obj_feat_len,
                                   self.text_feat_len, self.time_feat_len, self.n_layers)

        # Small network that reduces the text embedding from text_input_len
        # to text_feat_len before it is fed to the GNN.
        self.text_emb_process = nn.Sequential(
            nn.Linear(self.text_input_len, self.text_feat_len),
            nn.ReLU(True),
            nn.Linear(self.text_feat_len, self.text_feat_len),
            nn.ReLU(True)
        )

        # Score-based diffusion helpers, bound to this model's sigma.
        self.margin_fn = functools.partial(marginal_prob_std, sigma=self.sigma)
        self.diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=self.sigma)
        self.t_embed = nn.Sequential(GaussianFourierProjection(embed_dim=self.time_feat_len),
                                     nn.Linear(self.time_feat_len, self.time_feat_len))
        # Swish / SiLU activation: x * sigmoid(x). A parameter-free module (instead
        # of a lambda) keeps the model picklable and adds no state_dict keys.
        self.act = nn.SiLU()

    def get_obj_emb(self, obj, label=None):
        """Embed objects, optionally adding a class-label embedding.

        Args:
            obj: object input passed to ``self.obj_encoder`` (PointNet).
            label: optional class labels passed to ``self.label_encoder``;
                its output is added to the object embedding.

        Returns:
            The object embedding (sum of both encoders when ``label`` is given).

        Raises:
            ValueError: if ``label`` is given but the model was built without
                a label encoder (``label_len <= 0``).
        """
        if label is None:
            return self.obj_encoder(obj)
        if self.label_encoder is None:
            raise ValueError(
                "get_obj_emb received labels but the model has no label encoder "
                "(label_len <= 0)"
            )
        return self.obj_encoder(obj) + self.label_encoder(label)

    def get_score(self, pos, obj_emb, text_emb, t, edge_index):
        """Return the model score for positions ``pos`` at diffusion time ``t``.

        ``t`` carries a trailing singleton dim that is squeezed before the time
        embedding (``forward`` passes it as ``[num_nodes, 1]``). The GNN output is
        divided by the marginal std at ``t`` (plus 1e-7 to avoid division by zero).
        """
        t_embed = self.act(self.t_embed(t.squeeze(-1)))
        text_emb = self.text_emb_process(text_emb)
        obj_emb = self.score_gnn(pos, obj_emb, text_emb, t_embed, edge_index)
        obj_emb = obj_emb / (self.margin_fn(t) + 1e-7)
        return obj_emb

    def forward(self, data: "Data"):
        """Compute per-graph denoising score-matching losses (training only).

        Inference should call ``get_score`` directly; ``forward`` is kept for
        training so that multi-GPU wrappers can split the batch.

        Returns:
            Tensor with one mean-over-nodes L2 loss per graph in the local batch.
            Reduce it with ``get_global_loss``.
        """
        eps = 1e-5
        batch = data.batch
        # Number of graphs on this device ("local": with multi-GPU training the
        # global batch is the sum of the local ones). ``batch.max()`` is a single
        # tensor reduction; the builtin ``max`` would loop over nodes in Python.
        local_batch_size = int(batch.max()) + 1
        random_t = torch.rand(local_batch_size, device=data.x.device) * (1. - eps) + eps
        # [bs] -> [bs, 1] -> [num_nodes, 1]
        random_t = random_t.unsqueeze(-1)[batch]

        # Perturb the targets with Gaussian noise scaled by the marginal std.
        z = torch.randn_like(data.y)
        std = self.margin_fn(random_t)
        pert_pos = data.y + z * std

        # Object embedding (with labels only when a label encoder exists).
        if self.label_len > 0:
            obj_emb = self.get_obj_emb(data.x, data.one_hot_class_labels)
        else:
            obj_emb = self.get_obj_emb(data.x)

        score = self.get_score(pert_pos, obj_emb, data.text_emb, random_t, data.edge_index)
        # score * std should approximate -z; penalize the residual per node,
        # then average nodes within each graph.
        node_l2 = torch.mean((score * std + z) ** 2, dim=-1)
        local_batch_l2 = scatter_mean(node_l2, batch, dim=0)
        return local_batch_l2

    def get_global_loss(self, local_batch_l2):
        """Average the per-graph losses returned by ``forward`` into a scalar."""
        loss = torch.mean(local_batch_l2)
        return loss

    
        
        

class GradientFieldSampler:
    """Reverse-time sampler driven by a trained score model.

    ``sampler_param`` must provide ``"sampler"`` ("EM" or "PC"), ``"num_steps"``,
    ``"t0"`` and ``"snr"``. "EM" runs an Euler-Maruyama predictor only; "PC"
    adds a Langevin corrector step before each predictor step.
    """

    def __init__(self, sampler_param, gf_model, device) -> None:
        self.change_sampler_param(sampler_param)
        self.gf_model = gf_model
        self.device = device

    def change_sampler_param(self, sampler_param):
        """Replace the sampler configuration (name, step count, t0, snr)."""
        self.sampler_param = sampler_param
        self.sampler = sampler_param["sampler"]
        self.num_steps = sampler_param["num_steps"]
        self.t0 = sampler_param["t0"]
        self.snr = sampler_param["snr"]
        # Final (smallest) diffusion time of the integration grid.
        self.eps = 1e-3

    def sample_one_batch(self, data: "Data", pos_init=None):
        """Sample target positions for every node in ``data``.

        Args:
            data: batched graph with ``x``, ``text_emb``, ``batch``,
                ``edge_index`` (and ``one_hot_class_labels`` when the model
                uses labels).
            pos_init: optional initial noise of shape
                ``[num_nodes, target_len]``; not modified. Drawn from a standard
                normal when omitted.

        Returns:
            ``(pos_states, samp_time)``: the list of per-step mean positions
            and a dict of wall/perf/process timings of the sampling loop.

        Raises:
            ValueError: if the configured sampler is neither "EM" nor "PC".
        """
        # Validate before doing any model work.
        if self.sampler not in ("EM", "PC"):
            raise ValueError(f"Not a valid sampler: {self.sampler!r}")
        # Inference only: no autograd graph is needed for the embeddings.
        with torch.no_grad():
            if self.gf_model.label_len > 0:
                obj_emb = self.gf_model.get_obj_emb(data.x, data.one_hot_class_labels)
            else:
                obj_emb = self.gf_model.get_obj_emb(data.x)
        if pos_init is None:
            pos_init = torch.randn(obj_emb.size(0), self.gf_model.target_len, device=self.device)
        if self.sampler == "EM":
            print('Use EM Sampler')
            return self.EM_samplers(pos_init, obj_emb, data.text_emb, data.batch, data.edge_index)
        return self.PC_samplers(pos_init, obj_emb, data.text_emb, data.batch, data.edge_index)

    def EM_samplers(self, pos, obj_emb, text_emb, batch, edge_index):
        """Euler-Maruyama sampler (predictor only); returns ``(pos_states, samp_time)``."""
        return self._run_reverse_sde(pos, obj_emb, text_emb, batch, edge_index, use_corrector=False)

    def PC_samplers(self, pos, obj_emb, text_emb, batch, edge_index):
        """Predictor-corrector sampler (Langevin + Euler-Maruyama); returns ``(pos_states, samp_time)``."""
        return self._run_reverse_sde(pos, obj_emb, text_emb, batch, edge_index, use_corrector=True)

    def _run_reverse_sde(self, pos, obj_emb, text_emb, batch, edge_index, use_corrector):
        """Integrate from ``t0`` down to ``eps`` in ``num_steps`` uniform steps.

        ``pos`` is scaled by the marginal std at ``t0`` (out of place, so the
        caller's tensor is untouched). Each step optionally applies a Langevin
        corrector, then an Euler-Maruyama predictor; the predictor mean of every
        step is collected in ``pos_states``.
        """
        num_nodes = obj_emb.size(0)
        t = torch.ones(num_nodes, device=self.device) * self.t0
        pos = pos * self.gf_model.margin_fn(t)[:, None]
        time_steps = torch.linspace(self.t0, self.eps, self.num_steps, device=self.device)
        step_size = time_steps[0] - time_steps[1]
        pos_states = []
        with torch.no_grad():
            t0_time = time.time()
            t0_perf_counter = time.perf_counter()
            t0_process_time = time.process_time()
            if use_corrector:
                # sqrt(number of coordinates in each graph), broadcast to nodes;
                # it does not depend on the time step, so compute it once.
                noise_norm = torch.sqrt(
                    scatter_sum(torch.ones(pos.size(0), device=self.device) * pos.size(1), batch, dim=0))
                noise_norm = noise_norm.unsqueeze(-1)[batch]
            for time_step in tqdm.tqdm(time_steps):
                batch_time_step = (torch.ones(num_nodes, device=self.device) * time_step).unsqueeze(-1)

                if use_corrector:
                    # Corrector step (Langevin MCMC) with a per-graph step size
                    # set by the signal-to-noise ratio.
                    grad = self.gf_model.get_score(pos, obj_emb, text_emb, batch_time_step, edge_index)
                    grad_norm = torch.square(grad).sum(dim=-1)
                    grad_norm = torch.sqrt(scatter_sum(grad_norm, batch, dim=0))[batch].unsqueeze(-1)
                    langevin_step_size = 2 * (self.snr * noise_norm / grad_norm) ** 2
                    pos = pos + langevin_step_size * grad + torch.sqrt(2 * langevin_step_size) * torch.randn_like(pos)

                # Predictor step (Euler-Maruyama)
                g = self.gf_model.diffusion_coeff_fn(batch_time_step)
                score = self.gf_model.get_score(pos, obj_emb, text_emb, batch_time_step, edge_index)
                pos_mean = pos + (g ** 2) * score * step_size
                pos = pos_mean + torch.sqrt(g ** 2 * step_size) * torch.randn_like(pos)
                pos_states.append(pos_mean)
            samp_time = {
                "time": time.time() - t0_time,
                "perf_counter": time.perf_counter() - t0_perf_counter,
                "process_time": time.process_time() - t0_process_time,
            }
        return pos_states, samp_time