import torch
import torch.nn.functional as F
from torch.distributions import Categorical
import time
from expression_tree import ExpressionTree
from enums import *
from diffusion_utils import *
from position_encodings import *


class BTSTransformerModel(nn.Module):
    """Common base for the expression-tree generators defined in this module.

    The token library is ``two_children_funcs + one_children_funcs + variables``.
    This class builds the pieces every subclass shares:
    - linear embeddings for the parent/sibling features and for the target tokens;
    - an optional DCT projection of the embedding space;
    - an optional transformer encoder and an optional transformer decoder;
    - an output projection back onto the library;
    - causal attention masks.
    Subclasses supply ``forward`` and ``sample``.

    Args:
        two_children_funcs: Library entries that take two children.
        one_children_funcs: Library entries that take one child.
        variables: Leaf entries of the library.
        **kwargs: Overrides for the keys of ``default_model_parameters``.
            ``encode_layers`` is still accepted as an alias of ``encoder_layers``.
    """

    def __init__(self, two_children_funcs, one_children_funcs, variables, **kwargs):
        super(BTSTransformerModel, self).__init__()
        default_model_parameters = {
            "max_depth": 32,
            "num_heads": 1,
            "dim_feedforward": 2048,
            "encoder_layers": 0,
            "decoder_layers": 1,
            "oversampling": 3,
            "opt_const": True,
            "use_dct": True,
            "embedding_dim": 10,
            "dct_dim": 8,
            # The old default was the ``torch.cuda`` module, which is not a valid
            # device, so every tensor allocation below failed on the default path.
            "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
            "max_num_const": 10,
            "diff_steps": 1,
            "max_layers": 1,
        }
        # The default used to be spelled "encode_layers" while the code reads
        # "encoder_layers", so omitting the option raised KeyError. Keep
        # honouring the old spelling for callers that passed it explicitly.
        if "encoder_layers" not in kwargs and "encode_layers" in kwargs:
            kwargs["encoder_layers"] = kwargs["encode_layers"]
        for key, value in default_model_parameters.items():
            kwargs.setdefault(key, value)

        self.device = kwargs["device"]
        self.opt_const = kwargs["opt_const"]
        self.max_depth = kwargs["max_depth"]
        self.oversampling_scalar = kwargs["oversampling"]
        self.two_children_funcs = two_children_funcs
        self.two_children_num = len(two_children_funcs)
        self.one_children_funcs = one_children_funcs
        # Cumulative count (binary + unary operators), not the number of unary
        # operators alone; presumably an upper token-index bound -- confirm.
        self.one_children_num = len(one_children_funcs) + len(two_children_funcs)
        self.variables = variables
        self.max_num_const = kwargs["max_num_const"]
        self.max_layers = kwargs["max_layers"]

        self.library_size = len(self.two_children_funcs) + len(self.one_children_funcs) + len(self.variables)
        # Two slots of width (library_size + 1); presumably the parent and sibling
        # one-hots, each with an extra "no node" entry -- confirm in ExpressionTree.
        self.input_size = 2 * (self.library_size + 1)
        self.label_size = self.library_size
        self.embedding_dim = kwargs["embedding_dim"]
        self.dct_dim = kwargs["dct_dim"]

        self.diff_helper = DiffusionHelper(kwargs["diff_steps"], self.library_size, s=0.008)

        if kwargs["use_dct"]:
            # Subclasses map embeddings into the transformer's dct_dim-wide space
            # with ``@ dct_matrix.T`` and back with ``@ dct_matrix``.
            self.dct_matrix = create_dct(self.embedding_dim, self.dct_dim).to(self.device)
        else:
            self.dct_matrix = None
            self.dct_dim = self.embedding_dim

        self.ps_embedding = nn.Linear(in_features=self.input_size, out_features=self.embedding_dim)
        # One extra input column compared with the library; presumably for the
        # start token introduced by right-shifting the targets -- confirm.
        self.target_embedding = nn.Linear(in_features=self.library_size + 1, out_features=self.embedding_dim)

        # Same causal mask as ``tgt_mask``; kept for backward compatibility.
        self.mask = self.generate_square_subsequent_mask(self.max_depth)

        if kwargs["encoder_layers"] != 0:
            encoder_layer = nn.TransformerEncoderLayer(d_model=self.dct_dim, dim_feedforward=kwargs["dim_feedforward"], nhead=kwargs["num_heads"], dropout=0,
                                                       batch_first=True, norm_first=True)

            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=kwargs["encoder_layers"])
        else:
            self.encoder = None

        if kwargs["decoder_layers"] != 0:
            decoder_layer = nn.TransformerDecoderLayer(d_model=self.dct_dim, dim_feedforward=kwargs["dim_feedforward"], nhead=kwargs["num_heads"], dropout=0,
                                                       batch_first=True, norm_first=True)
            self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=kwargs["decoder_layers"])
        else:
            self.decoder = None

        self.linear = nn.Linear(in_features=self.embedding_dim, out_features=self.library_size)

        self.softmax = nn.Softmax(dim=2)

        self.scr_mask = self.generate_square_subsequent_mask(self.max_depth)

        self.tgt_mask = self.generate_square_subsequent_mask(self.max_depth)

        # Xavier-initialise every 2-D weight matrix (linear and attention projections).
        for name, param in self.named_parameters():
            if 'weight' in name and param.data.dim() == 2:
                nn.init.xavier_uniform_(param)

    def generate_square_subsequent_mask(self, size):
        """Return a ``(size, size)`` additive causal mask on ``self.device``.

        Entries strictly above the diagonal are ``-inf`` (blocked) and all
        others are 0, so position ``i`` may attend to positions ``<= i``.
        """
        mask = torch.triu(torch.full((size, size), float('-inf'), device=self.device), diagonal=1)
        return mask


class AutoregressiveModel(BTSTransformerModel):
    """Generates expression trees one token position at a time.

    At step ``j`` the model runs on the partially built trees. The distribution
    at position ``j`` is masked by ``trees.fetch_rules(j)`` and sampled.
    ``dpo_sample`` produces trees whose first ``dpo_split`` positions are
    sampled before ``trees.duplicate(2)`` and continued independently after it.

    Extra kwargs (on top of ``BTSTransformerModel``):
        dpo_split: Number of positions sampled before duplication in
            ``dpo_sample`` (default 3).
        pe: ``PositionalEncodings`` member selecting the positional encoding
            (default ``TwoDPE``).
    """

    def __init__(self, two_children_funcs, one_children_funcs, variables, **kwargs):
        super(AutoregressiveModel, self).__init__(two_children_funcs, one_children_funcs, variables, **kwargs)

        kwargs.setdefault("dpo_split", 3)
        kwargs.setdefault("pe", PositionalEncodings.TwoDPE)

        # Use self.max_depth: the base class applies its defaults to its own copy
        # of kwargs, so kwargs["max_depth"] raised KeyError when the caller relied
        # on the default.
        if PositionalEncodings.TwoDPE == kwargs["pe"]:
            self.position = TwoDimensionalPositionalEncoding(d_model=self.embedding_dim, max_len=self.max_depth, device=self.device)
        elif PositionalEncodings.OneDPE == kwargs["pe"]:
            self.position = OneDimensionalPositionalEncoding(d_model=self.embedding_dim, max_len=self.max_depth, device=self.device)
        else:
            self.position = NoPositionalEncoding()

        self.dpo_split = kwargs["dpo_split"]

    def forward(self, targets, ps_information, p, temp=1):
        """Return token probabilities for every position.

        Args:
            targets: Token encodings of the trees so far. They are passed through
                ``right_shift`` before embedding and decoded under the causal
                ``tgt_mask``.
            ps_information: Parent/sibling features of width ``self.input_size``.
                They only reach the output through the encoder; without an
                encoder, the decoder memory is the target embedding.
            p: Node positions forwarded to the positional encoding.
            temp: Softmax temperature.

        Returns:
            Softmax over the library (dim 2) of the projected decoder output.
            With ``decoder_layers == 0``, the memory is projected directly.
        """
        ps_information = self.position(self.ps_embedding(ps_information), p)

        targets = right_shift(targets)
        targets = self.position(self.target_embedding(targets), p)

        if self.dct_matrix is not None:
            ps_information = ps_information @ self.dct_matrix.T
            targets = targets @ self.dct_matrix.T

        if self.encoder is not None:
            encoder_info = self.encoder(ps_information, mask=self.scr_mask)
        else:
            encoder_info = targets

        if self.decoder is not None:
            x = self.decoder(tgt=targets, memory=encoder_info, tgt_mask=self.tgt_mask, memory_mask=self.scr_mask)
        else:
            # Without a decoder, the old code crashed calling None.
            x = encoder_info

        if self.dct_matrix is not None:
            x = x @ self.dct_matrix

        x = self.linear(x)

        labels = self.softmax(x / temp)
        return labels

    @staticmethod
    def _synchronize():
        """Wait for queued CUDA work so wall-clock timings are meaningful; no-op without CUDA."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    @staticmethod
    def _empty_timings():
        """Return a fresh timing dictionary with every counter at 0."""
        return {"Fetch PS Time": 0, "Prediction Time": 0, "Apply Rules Time": 0, "Rand_Cat Time": 0,
                "Add Node Time": 0, "Build Time": 0, "Equation Build Time": 0, "Comparison Time": 0}

    @staticmethod
    def _select_unique(equations, n, total):
        """Pick up to ``n`` indices among the first ``total`` equations, preferring distinct ones.

        Candidates are scanned in order. Each not-yet-seen equation is kept. A
        repeated equation is kept only when the candidates left (including
        this one) are no more than the slots still to fill. This way ``n``
        indices are returned whenever ``total >= n``.
        """
        seen = set()
        chosen = []
        for index in range(total):
            equ = equations[index]
            if equ not in seen:
                seen.add(equ)
                chosen.append(index)
                if len(chosen) == n:
                    break
            elif n - len(chosen) >= total - index:
                chosen.append(index)
        return chosen

    def _sample_position(self, trees, j, device, timings, block_variables=False):
        """Sample the token at position ``j`` for every tree and add it to ``trees``.

        Args:
            trees: ExpressionTree batch, modified in place via ``trees.add``.
            j: Position to fill.
            device: Device on which rules are applied and tokens sampled.
            timings: Timing dictionary, updated in place.
            block_variables: Also multiply by ``~trees.vars_rule``. Presumably
                this forbids variable tokens; confirm in ExpressionTree.
        """
        self._synchronize()
        start = time.time()
        ps_info = trees.get_inputs().float().to(self.device)
        targets = trees.get_labels().float().to(self.device)
        positions = trees.get_positions().float().to(self.device)
        self._synchronize()
        timings["Fetch PS Time"] += time.time() - start

        self._synchronize()
        start = time.time()
        # Small floor so no rule-allowed token has exactly zero probability.
        x = self.forward(targets, ps_info, positions, temp=1)[:, j, :] + 1E-5
        x = x.to(device)
        self._synchronize()
        timings["Prediction Time"] += time.time() - start

        self._synchronize()
        start = time.time()
        rules = trees.fetch_rules(j)
        x = x * rules
        if block_variables:
            x = x * (~trees.vars_rule)
        self._synchronize()
        timings["Apply Rules Time"] += time.time() - start

        self._synchronize()
        start = time.time()
        predicted_vals = categorical_sample(x)
        self._synchronize()
        timings["Rand_Cat Time"] += time.time() - start

        self._synchronize()
        start = time.time()
        trees.add(predicted_vals.int(), j)
        self._synchronize()
        timings["Add Node Time"] += time.time() - start

    def sample(self, n, device):
        """Sample ``n`` trees, preferring distinct equations.

        Builds ``int(oversampling * n)`` candidates, fills positions
        ``0 .. max_depth - 1`` one at a time, then keeps up to ``n`` of them via
        ``trees.reduce``.

        Returns:
            ``(trees, timings)``, where ``timings`` maps phase names to seconds.
        """
        total = int(self.oversampling_scalar * n)
        with torch.no_grad():
            dictionary = self._empty_timings()

            start = time.time()
            trees = ExpressionTree(n=total, two_children_funcs=self.two_children_funcs,
                                   one_children_funcs=self.one_children_funcs, variables=self.variables,
                                   max_depth=self.max_depth, max_num_const=self.max_num_const, opt_const=self.opt_const, device=device)
            self._synchronize()
            dictionary["Build Time"] += time.time() - start

            for j in range(self.max_depth):
                self._sample_position(trees, j, device, dictionary)

            self._synchronize()
            start = time.time()
            equations = trees.equation_string()
            self._synchronize()
            dictionary["Equation Build Time"] += time.time() - start

            self._synchronize()
            start = time.time()
            trees.reduce(self._select_unique(equations, n, total))
            self._synchronize()
            dictionary["Comparison Time"] += time.time() - start

            return trees, dictionary

    def dpo_sample(self, n, device):
        """Sample trees for preference training.

        Positions ``0 .. dpo_split - 1`` are sampled for ``n`` trees, with
        ``~trees.vars_rule`` applied in addition to the grammar rules. Then
        ``trees.duplicate(2)`` is called, and the remaining positions are sampled
        for the duplicated batch. Presumably this yields pairs sharing a common
        prefix; confirm against ExpressionTree.duplicate.

        Returns:
            ``(trees, timings)``.
        """
        with torch.no_grad():
            dictionary = self._empty_timings()

            start = time.time()
            trees = ExpressionTree(n=int(n), two_children_funcs=self.two_children_funcs,
                                   one_children_funcs=self.one_children_funcs, variables=self.variables,
                                   max_depth=self.max_depth, max_num_const=self.max_num_const, opt_const=self.opt_const, device=device)
            self._synchronize()
            dictionary["Build Time"] += time.time() - start

            for j in range(self.dpo_split):
                self._sample_position(trees, j, device, dictionary, block_variables=True)

            trees.duplicate(2)

            for j in range(self.dpo_split, self.max_depth):
                self._sample_position(trees, j, device, dictionary)

            self._synchronize()
            start = time.time()
            # Result unused here; called for whatever state equation_string builds on trees.
            trees.equation_string()
            self._synchronize()
            dictionary["Equation Build Time"] += time.time() - start
            self._synchronize()

            return trees, dictionary


class D3PM(BTSTransformerModel):
    """Discrete-diffusion sampler that predicts every tree position at once.

    Each reverse step does three things:
    1. Runs the transformer on the current sample ``x_t``.
    2. Combines the prediction with the transition matrices from
       ``DiffusionHelper`` via ``compute_batched_over0_posterior_distribution``.
    3. Draws the next sample with ``ExpressionTree.sample_full_trees``.

    Extra kwargs:
        pe: Positional-encoding choice (default ``TwoDPE``).
        diff_steps: Number of diffusion steps (default 25).
        parent_sibling_info_in_encoder: If true (default), the encoder reads the
            parent/sibling features; otherwise it reads the targets.
    """

    def __init__(self, two_children_funcs, one_children_funcs, variables, **kwargs):
        super(D3PM, self).__init__(two_children_funcs, one_children_funcs, variables, **kwargs)
        default_parameters = {
            "pe": PositionalEncodings.TwoDPE,
            "diff_steps": 25,
            "parent_sibling_info_in_encoder": True,
        }
        for key, value in default_parameters.items():
            kwargs.setdefault(key, value)

        self.diff_helper = DiffusionHelper(kwargs["diff_steps"], self.library_size, s=0.008)
        self.ps_encoding = kwargs["parent_sibling_info_in_encoder"]

        # self.max_depth, not kwargs["max_depth"]: the base class fills defaults
        # into its own copy of kwargs, so the latter raised KeyError by default.
        if PositionalEncodings.TwoDPE == kwargs["pe"]:
            self.diff_pe = DiffusionTwoDimensionalPositionalEncoding(d_model=self.embedding_dim, max_len=self.max_depth, device=self.device)
        elif PositionalEncodings.OneDPE == kwargs["pe"]:
            self.diff_pe = DiffusionSequentialPositionalEncoding(d_model=self.embedding_dim, max_len=self.max_depth, device=self.device)
        else:
            self.diff_pe = NoPositionalEncoding()

        # Targets here are library_size wide (no extra start-token column).
        self.target_embedding = nn.Linear(in_features=self.library_size, out_features=self.embedding_dim)

        # All-zero additive masks, i.e. no attention masking.
        self.scr_mask = torch.zeros((self.max_depth, self.max_depth), device=self.device)
        self.tgt_mask = torch.zeros((self.max_depth, self.max_depth), device=self.device)

    def forward(self, targets, ps_information, p, t, temp=1):
        """Return token probabilities (softmax over dim 2) for every position at step ``t``."""
        ps_information = self.diff_pe(self.ps_embedding(ps_information), p, t)
        targets = self.diff_pe(self.target_embedding(targets), p, t)

        if self.dct_matrix is not None:
            ps_information = ps_information @ self.dct_matrix.T
            targets = targets @ self.dct_matrix.T

        if self.encoder is not None:
            encoder_info = self.encoder(ps_information if self.ps_encoding else targets)
        else:
            encoder_info = targets

        if self.decoder is not None:
            x = self.decoder(tgt=targets, memory=encoder_info)
        else:
            x = encoder_info

        if self.dct_matrix is not None:
            x = x @ self.dct_matrix

        x = self.linear(x)

        labels = self.softmax(x / temp)
        return labels

    @staticmethod
    def _synchronize():
        """Wait for queued CUDA work so wall-clock timings are meaningful; no-op without CUDA."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    @staticmethod
    def _select_unique(equations, n, total):
        """Pick up to ``n`` indices among the first ``total`` equations, preferring distinct ones.

        A repeated equation is kept only when the candidates left (including
        this one) are no more than the slots still to fill.
        """
        seen = set()
        chosen = []
        for index in range(total):
            equ = equations[index]
            if equ not in seen:
                seen.add(equ)
                chosen.append(index)
                if len(chosen) == n:
                    break
            elif n - len(chosen) >= total - index:
                chosen.append(index)
        return chosen

    def _new_trees(self, count, device):
        """Return a fresh ExpressionTree batch of ``count`` trees configured like this model."""
        return ExpressionTree(n=count, two_children_funcs=self.two_children_funcs,
                              one_children_funcs=self.one_children_funcs, variables=self.variables,
                              max_depth=self.max_depth, max_num_const=self.max_num_const,
                              time_steps=self.diff_helper.timesteps, opt_const=self.opt_const, device=device)

    def sample(self, n, device):
        """Run the reverse process and return up to ``n`` trees, preferring distinct equations.

        Returns:
            ``(trees, timings)``. The ``"Prediction Time"`` entry covers the
            whole reverse loop.
        """
        total = int(self.oversampling_scalar * n)
        with torch.no_grad():
            dictionary = {"Fetch PS Time": 0, "Prediction Time": 0, "Apply Rules Time": 0, "Rand_Cat Time": 0,
                          "Add Node Time": 0, "Build Time": 0, "Equation Build Time": 0, "Comparison Time": 0}

            start = time.time()
            trees = self._new_trees(total, device)
            self._synchronize()
            dictionary["Build Time"] += time.time() - start

            start = time.time()
            # Starting sample: all-ones weights for every position (presumably uniform).
            initial = torch.ones((trees.n, trees.max_depth, trees.library_size), device=trees.device, dtype=torch.float32)
            current = self._new_trees(total, device)
            current.sample_full_trees(initial)

            for t in reversed(range(1, self.diff_helper.timesteps)):
                t_int = torch.tensor(t)
                beta_t = self.diff_helper.noise_schedule(t_int=t_int).repeat(trees.n).unsqueeze(1)  # (bs, 1)
                alpha_s_bar = self.diff_helper.get_alpha_bar(t_normalized=self.diff_helper.s).repeat(trees.n).unsqueeze(1)
                alpha_t_bar = self.diff_helper.get_alpha_bar(t_int=t_int).repeat(trees.n).unsqueeze(1)

                Qtb = self.diff_helper.get_Qt_bar(alpha_t_bar, self.device)
                Qsb = self.diff_helper.get_Qt_bar(alpha_s_bar, self.device)
                Qt = self.diff_helper.get_Qt(beta_t, self.device)

                # Condition on the current sample. This used to read from `trees`,
                # which is never modified, so every step saw the same input and
                # only the last step influenced the result.
                inputs = current.get_inputs().float().to(self.device)
                labels = current.get_labels().float().to(self.device)
                positions = current.get_positions().float().to(self.device)
                x = self.forward(labels, inputs, positions, torch.ones(trees.n, device=self.device) * t)
                prior = compute_batched_over0_posterior_distribution(x, Qt, Qsb, Qtb)

                unnormalized_X = (x.unsqueeze(-1) * prior).sum(dim=2)
                # Replace all-zero rows so the normalisation below cannot divide by zero.
                unnormalized_X[torch.sum(unnormalized_X, dim=-1) == 0] = 1e-5
                alpha_t = self.diff_helper.get_alpha_bar(t_int=t_int)
                unnormalized_X = torch.sqrt(1 - alpha_t) * unnormalized_X + torch.sqrt(alpha_t) * labels
                probs = unnormalized_X / torch.sum(unnormalized_X, dim=-1, keepdim=True)

                current = self._new_trees(total, device)
                current.sample_full_trees(probs + 1E-8)
            self._synchronize()
            dictionary["Prediction Time"] += time.time() - start

            # Carry the bookkeeping of the original batch over to the final sample.
            current.positions_history = trees.positions_history
            current.diffusion_inputs = trees.diffusion_inputs
            current.diffusion_labels = trees.diffusion_labels
            trees = current

            start = time.time()
            equations = trees.equation_string()
            self._synchronize()
            dictionary["Equation Build Time"] += time.time() - start

            # Reset the timer: this entry previously measured the whole call.
            start = time.time()
            trees.reduce(self._select_unique(equations, n, total))
            self._synchronize()
            dictionary["Comparison Time"] += time.time() - start
            return trees, dictionary


class DiffusionModel(BTSTransformerModel):
    """Diffusion sampler that predicts every tree position at once.

    At each reverse step ``t`` the model reads ``x_t`` back with
    ``trees.fetch_diffusion_info(0, t)``. It then mixes its prediction with
    ``x_t`` using weights ``sqrt(1 - alpha_bar_t)`` and ``sqrt(alpha_bar_t)``,
    samples a new batch, and stores it with ``trees.update_diffusion_info(0, t - 1, ...)``.

    Extra kwargs:
        pe: Positional-encoding choice (default ``TwoDPE``).
        diff_steps: Number of diffusion steps (default 25).
        parent_sibling_info_in_encoder: If true (default), the encoder reads the
            parent/sibling features; otherwise it reads the targets.
    """

    def __init__(self, two_children_funcs, one_children_funcs, variables, **kwargs):
        super(DiffusionModel, self).__init__(two_children_funcs, one_children_funcs, variables, **kwargs)
        default_parameters = {
            "pe": PositionalEncodings.TwoDPE,
            "diff_steps": 25,
            "parent_sibling_info_in_encoder": True,
        }
        for key, value in default_parameters.items():
            kwargs.setdefault(key, value)

        self.diff_helper = DiffusionHelper(kwargs["diff_steps"], self.library_size, s=0.008)
        self.ps_encoding = kwargs["parent_sibling_info_in_encoder"]

        # self.max_depth, not kwargs["max_depth"]: the base class fills defaults
        # into its own copy of kwargs, so the latter raised KeyError by default.
        if PositionalEncodings.TwoDPE == kwargs["pe"]:
            self.diff_pe = DiffusionTwoDimensionalPositionalEncoding(d_model=self.embedding_dim, max_len=self.max_depth, device=self.device)
        elif PositionalEncodings.OneDPE == kwargs["pe"]:
            self.diff_pe = DiffusionSequentialPositionalEncoding(d_model=self.embedding_dim, max_len=self.max_depth, device=self.device)
        else:
            self.diff_pe = NoPositionalEncoding()

        # Targets here are library_size wide (no extra start-token column).
        self.target_embedding = nn.Linear(in_features=self.library_size, out_features=self.embedding_dim)

        # All-zero additive masks, i.e. no attention masking.
        self.scr_mask = torch.zeros((self.max_depth, self.max_depth), device=self.device)
        self.tgt_mask = torch.zeros((self.max_depth, self.max_depth), device=self.device)

    def forward(self, targets, ps_information, p, t, temp=1):
        """Return token probabilities (softmax over dim 2) for every position at step ``t``."""
        ps_information = self.diff_pe(self.ps_embedding(ps_information), p, t)
        targets = self.diff_pe(self.target_embedding(targets), p, t)

        if self.dct_matrix is not None:
            ps_information = ps_information @ self.dct_matrix.T
            targets = targets @ self.dct_matrix.T

        if self.encoder is not None:
            encoder_info = self.encoder(ps_information if self.ps_encoding else targets)
        else:
            encoder_info = targets

        if self.decoder is not None:
            x = self.decoder(tgt=targets, memory=encoder_info)
        else:
            x = encoder_info

        if self.dct_matrix is not None:
            x = x @ self.dct_matrix

        x = self.linear(x)

        labels = self.softmax(x / temp)
        return labels

    @staticmethod
    def _synchronize():
        """Wait for queued CUDA work so wall-clock timings are meaningful; no-op without CUDA."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    @staticmethod
    def _select_unique(equations, n, total):
        """Pick up to ``n`` indices among the first ``total`` equations, preferring distinct ones.

        A repeated equation is kept only when the candidates left (including
        this one) are no more than the slots still to fill.
        """
        seen = set()
        chosen = []
        for index in range(total):
            equ = equations[index]
            if equ not in seen:
                seen.add(equ)
                chosen.append(index)
                if len(chosen) == n:
                    break
            elif n - len(chosen) >= total - index:
                chosen.append(index)
        return chosen

    def _new_trees(self, count, device):
        """Return a fresh ExpressionTree batch of ``count`` trees configured like this model."""
        return ExpressionTree(n=count, two_children_funcs=self.two_children_funcs,
                              one_children_funcs=self.one_children_funcs, variables=self.variables,
                              max_depth=self.max_depth, max_num_const=self.max_num_const,
                              time_steps=self.diff_helper.timesteps, opt_const=self.opt_const, device=device)

    def sample(self, n, device):
        """Run the reverse process and return up to ``n`` trees, preferring distinct equations.

        Returns:
            ``(trees, timings)``. The ``"Prediction Time"`` entry covers the
            whole reverse loop.
        """
        total = int(self.oversampling_scalar * n)
        with torch.no_grad():
            dictionary = {"Fetch PS Time": 0, "Prediction Time": 0, "Apply Rules Time": 0, "Rand_Cat Time": 0,
                          "Add Node Time": 0, "Build Time": 0, "Equation Build Time": 0, "Comparison Time": 0}

            start = time.time()
            trees = self._new_trees(total, device)
            self._synchronize()
            dictionary["Build Time"] += time.time() - start

            start = time.time()
            num_steps = self.diff_helper.timesteps
            # Starting sample: all-ones weights for every position (presumably uniform).
            initial = torch.ones((trees.n, trees.max_depth, trees.library_size), device=trees.device, dtype=torch.float32)
            current = self._new_trees(total, device)
            current.sample_full_trees(initial)
            trees.update_diffusion_info(0, num_steps - 1,
                                        current.get_positions(), current.get_inputs(), current.get_labels())

            for t in reversed(range(1, num_steps)):
                inputs, labels, positions = trees.fetch_diffusion_info(0, t)
                x = self.forward(labels, inputs, positions, torch.ones(trees.n, device=self.device) * t)

                # Replace all-zero rows so the normalisation below cannot divide by zero.
                x[torch.sum(x, dim=-1) == 0] = 1e-5
                alpha_t = self.diff_helper.get_alpha_bar(t_int=torch.tensor(t))
                unnormalized_X = torch.sqrt(1 - alpha_t) * x + torch.sqrt(alpha_t) * labels
                probs = unnormalized_X / torch.sum(unnormalized_X, dim=-1, keepdim=True)

                current = self._new_trees(total, device)
                current.sample_full_trees(probs)
                trees.update_diffusion_info(0, t - 1, current.get_positions(),
                                            current.get_inputs(), current.get_labels())
            self._synchronize()
            dictionary["Prediction Time"] += time.time() - start

            # Carry the recorded diffusion history over to the final sample.
            current.positions_history = trees.positions_history
            current.diffusion_inputs = trees.diffusion_inputs
            current.diffusion_labels = trees.diffusion_labels
            trees = current

            start = time.time()
            equations = trees.equation_string()
            self._synchronize()
            dictionary["Equation Build Time"] += time.time() - start

            # Reset the timer: this entry previously measured the whole call.
            start = time.time()
            trees.reduce(self._select_unique(equations, n, total))
            self._synchronize()
            dictionary["Comparison Time"] += time.time() - start
            return trees, dictionary


class LLDiffusionModel(BTSTransformerModel):
    """Sampler that fills one tree position per reverse step.

    Labels start as all zeros. At each step the transformer is run, positions
    that already hold a token are masked out, and one (position, token) pair
    per tree is drawn from the flattened ``(max_depth * library_size)``
    distribution. ``DiffusionHelper`` is built with ``max_depth`` as its step
    count.

    Extra kwargs:
        parent_sibling_info_in_encoder: If true (default), the encoder reads
            the parent/sibling features; otherwise it reads the targets.
    """

    def __init__(self, two_children_funcs, one_children_funcs, variables, **kwargs):
        super(LLDiffusionModel, self).__init__(two_children_funcs, one_children_funcs, variables, **kwargs)
        kwargs.setdefault("parent_sibling_info_in_encoder", True)

        # self.max_depth, not kwargs["max_depth"]: the base class fills defaults
        # into its own copy of kwargs, so the latter raised KeyError by default.
        self.diff_helper = DiffusionHelper(self.max_depth, self.library_size, s=0.008)
        self.ps_encoding = kwargs["parent_sibling_info_in_encoder"]

        self.diff_pe = DiffusionSequentialPositionalEncoding(d_model=self.embedding_dim, max_len=self.max_depth, device=self.device)

        # Targets here are library_size wide (no extra start-token column).
        self.target_embedding = nn.Linear(in_features=self.library_size, out_features=self.embedding_dim)

        # All-zero additive masks, i.e. no attention masking.
        self.scr_mask = torch.zeros((self.max_depth, self.max_depth), device=self.device)
        self.tgt_mask = torch.zeros((self.max_depth, self.max_depth), device=self.device)

    def forward(self, targets, ps_information, p, t, temp=1):
        """Return token probabilities (softmax over dim 2) for every position at step ``t``.

        ``ps_information`` is only embedded when an encoder exists and
        ``ps_encoding`` is set, because it is otherwise unused. ``sample``
        passes ``None`` for it, which previously crashed in ``ps_embedding``
        under the default configuration.
        """
        use_ps = self.ps_encoding and self.encoder is not None
        if use_ps:
            ps_information = self.diff_pe(self.ps_embedding(ps_information), p, t)

        targets = self.diff_pe(self.target_embedding(targets), p, t)

        if self.dct_matrix is not None:
            if use_ps:
                ps_information = ps_information @ self.dct_matrix.T
            targets = targets @ self.dct_matrix.T

        if self.encoder is not None:
            encoder_info = self.encoder(ps_information if self.ps_encoding else targets)
        else:
            encoder_info = targets

        if self.decoder is not None:
            x = self.decoder(tgt=targets, memory=encoder_info)
        else:
            x = encoder_info

        if self.dct_matrix is not None:
            x = x @ self.dct_matrix

        x = self.linear(x)

        labels = self.softmax(x / temp)
        return labels

    @staticmethod
    def _synchronize():
        """Wait for queued CUDA work so wall-clock timings are meaningful; no-op without CUDA."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    @staticmethod
    def _select_unique(equations, n, total):
        """Pick up to ``n`` indices among the first ``total`` equations, preferring distinct ones.

        A repeated equation is kept only when the candidates left (including
        this one) are no more than the slots still to fill.
        """
        seen = set()
        chosen = []
        for index in range(total):
            equ = equations[index]
            if equ not in seen:
                seen.add(equ)
                chosen.append(index)
                if len(chosen) == n:
                    break
            elif n - len(chosen) >= total - index:
                chosen.append(index)
        return chosen

    def sample(self, n, device):
        """Fill positions one at a time, then return up to ``n`` trees, preferring distinct equations.

        NOTE(review): the loop runs ``timesteps - 1`` times and fills one
        position per step. If ``timesteps`` equals ``max_depth``, one position
        per tree stays empty and is left to ``sample_full_trees`` via the 1e-8
        floor. Confirm this is intended.

        Returns:
            ``(trees, timings)``.
        """
        total = int(self.oversampling_scalar * n)
        with torch.no_grad():
            dictionary = {"Fetch PS Time": 0, "Prediction Time": 0, "Apply Rules Time": 0, "Rand_Cat Time": 0,
                          "Add Node Time": 0, "Build Time": 0, "Equation Build Time": 0, "Comparison Time": 0}

            start = time.time()
            prev_labels = torch.zeros((total, self.max_depth, self.library_size), device=device)
            # 1 where a position is still empty, 0 once it holds a token; broadcast over the library axis.
            mask = torch.ones((total, self.max_depth, 1), device=device)
            rows = torch.arange(0, total, device=device)
            for t in reversed(range(1, self.diff_helper.timesteps)):
                step = torch.ones(total, device=self.device) * t
                priors = self.forward(prev_labels, None, None, step)
                samples = categorical_sample((priors * mask).flatten(1))
                # Flattened index = depth * library_size + token.
                depths = (samples // self.library_size).long()
                token_id = (samples % self.library_size).long()
                prev_labels[rows, depths, token_id] = 1.0
                mask = (prev_labels.sum(dim=2, keepdim=True) == 0).float()
            self._synchronize()
            dictionary["Prediction Time"] += time.time() - start

            start = time.time()
            trees = ExpressionTree(n=total, two_children_funcs=self.two_children_funcs,
                                   one_children_funcs=self.one_children_funcs, variables=self.variables,
                                   max_depth=self.max_depth, max_num_const=self.max_num_const, time_steps=self.diff_helper.timesteps,
                                   opt_const=self.opt_const, device=device)
            trees.sample_full_trees(prev_labels + 1e-8)
            self._synchronize()
            dictionary["Build Time"] += time.time() - start

            start = time.time()
            equations = trees.equation_string()
            self._synchronize()
            dictionary["Equation Build Time"] += time.time() - start

            # Reset the timer: this entry previously measured the whole call.
            start = time.time()
            trees.reduce(self._select_unique(equations, n, total))
            self._synchronize()
            dictionary["Comparison Time"] += time.time() - start
            return trees, dictionary


class AutoregressiveDiffusionModel(BTSTransformerModel):
    """Transformer that samples expression trees by discrete diffusion, one depth layer at a time.

    Builds on ``BTSTransformerModel`` and adds: its own ``DiffusionHelper`` schedule, a
    positional encoding selected with the ``pe`` keyword, a linear embedding for the
    target label distributions, and a layer-by-layer sampler (``sample``).
    """

    def __init__(self, two_children_funcs, one_children_funcs, variables, **kwargs):
        """Build the model.

        Keyword arguments, in addition to those of ``BTSTransformerModel``:
            pe: member of ``PositionalEncodings`` (default ``PositionalEncodings.TwoDPE``).
            diff_steps: diffusion time steps per layer (default 10).
            max_layers: number of depth layers denoised by ``sample`` (default 20).

        Raises:
            TypeError: if ``pe`` is ``PositionalEncodings.OneDPE`` (not implemented).
        """
        super(AutoregressiveDiffusionModel, self).__init__(two_children_funcs, one_children_funcs, variables, **kwargs)
        default_parameters = {
            "pe": PositionalEncodings.TwoDPE,
            "diff_steps": 10,
            "max_layers": 20
        }
        for key, value in default_parameters.items():
            kwargs.setdefault(key, value)

        # Replace the parent's diffusion helper / layer count with this model's defaults.
        self.diff_helper = DiffusionHelper(kwargs["diff_steps"], self.library_size, s=0.008)
        self.max_layers = kwargs["max_layers"]

        # Use self.max_depth (resolved by the parent, default included) rather than
        # kwargs["max_depth"]: the parent fills its defaults into its own copy of
        # kwargs, so the key is missing here when the caller relies on the default.
        if PositionalEncodings.TwoDPE == kwargs["pe"]:
            self.diff_pe = AutoDiffusionTwoDimensionalPositionalEncoding(d_model=self.embedding_dim, max_len=self.max_depth, device=self.device)
        elif PositionalEncodings.OneDPE == kwargs["pe"]:
            # A sequential (1-D) encoding has not been implemented for this model.
            raise TypeError("No yet implement and might not work with Autoregressive Diffusion")
        else:
            self.diff_pe = NoPositionalEncoding()

        self.target_embedding = nn.Linear(in_features=self.library_size, out_features=self.embedding_dim)

        # Constant all-ones (max_depth x max_depth) masks. scr_mask is passed to the encoder
        # in forward(); tgt_mask is kept as an attribute only.
        # NOTE(review): if self.encoder is a torch Transformer encoder, a float mask is
        # *added* to the attention scores, so a constant mask has no effect - confirm intent.
        self.scr_mask = torch.ones((self.max_depth, self.max_depth), device=self.device)
        self.tgt_mask = torch.ones((self.max_depth, self.max_depth), device=self.device)

    def forward(self, targets, ps_information, p, t, temp=1):
        """Predict label distributions for every node.

        Args:
            targets: current label distributions; the last dim must be ``library_size``
                (input size of ``target_embedding``).
            ps_information: inputs embedded with ``self.ps_embedding`` (set up by the parent).
            p: node positions; ``p[:, :, 0]`` is used as the layer index by ``create_mask``.
            t: diffusion time steps handed to the positional encoding.
            temp: softmax temperature.

        Returns:
            ``self.softmax(self.linear(x) / temp)`` where ``x`` is the decoder output.
        """
        ps_information = self.ps_embedding(ps_information)
        ps_information = self.diff_pe(ps_information, p, t)

        targets = self.target_embedding(targets)
        targets = self.diff_pe(targets, p, t)

        # The same layer-based mask is used for self-attention and cross-attention.
        src_mask = self.create_mask(p)
        tgt_mask = src_mask

        # Optionally project the embeddings onto a truncated DCT basis.
        if self.dct_matrix is not None:
            ps_information = ps_information @ self.dct_matrix.T
            targets = targets @ self.dct_matrix.T

        if self.encoder is not None:
            encoder_info = self.encoder(ps_information, mask=self.scr_mask)
        else:
            # Without an encoder the decoder attends to the target embeddings themselves.
            encoder_info = targets

        x = self.decoder(tgt=targets, memory=encoder_info, tgt_mask=tgt_mask, memory_mask=src_mask)

        # Map back from the DCT basis to embedding space.
        if self.dct_matrix is not None:
            x = x @ self.dct_matrix

        x = self.linear(x)

        labels = self.softmax(x / temp)
        return labels

    def create_mask(self, position):
        """Build a per-sample attention mask from node layer indices.

        With ``layers = position[:, :, 0] - 1``, entry ``[b, i, j]`` is True when
        ``j < layers[b, i]``. Returns a bool tensor of shape
        ``(position.shape[0], position.shape[1], position.shape[1])``.
        """
        # TODO(review): this compares a sequence index j against a layer number;
        # confirm it is the intended attention pattern.
        layers = position[:, :, 0] - 1
        seq_len = layers.shape[1]
        ranges = torch.arange(0, seq_len, device=self.device).view(1, 1, seq_len)
        return ranges < layers.unsqueeze(2)

    def create_diff_matrix(self, positions, depth, t):
        """Return the diffusion time step to use for every node.

        Nodes whose layer (``positions[:, :, 0]``) is below ``depth`` get 0, nodes on
        ``depth`` get ``t``, and deeper nodes get ``self.diff_helper.timesteps``.
        """
        layers = positions[:, :, 0]
        matrix = (layers < depth) * 0 + (layers == depth) * t + (layers > depth) * self.diff_helper.timesteps
        return matrix

    @staticmethod
    def _synchronize():
        """Wait for queued CUDA work (keeps timings honest); no-op when CUDA is unavailable."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def _build_trees(self, num_trees, device, **extra):
        """Construct an ``ExpressionTree`` batch configured like this model (``extra`` is forwarded)."""
        return ExpressionTree(n=num_trees, two_children_funcs=self.two_children_funcs,
                              one_children_funcs=self.one_children_funcs, variables=self.variables,
                              max_depth=self.max_depth, max_num_const=self.max_num_const,
                              time_steps=self.diff_helper.timesteps, opt_const=self.opt_const,
                              device=device, **extra)

    def _select_unique(self, equations, n, seen):
        """Return indices of up to ``n`` candidates, preferring equation strings not in ``seen``.

        Candidates are scanned in order. An unseen equation is taken and recorded in
        ``seen`` until ``n`` have been chosen. Once the remaining candidates are only
        just enough to reach ``n``, duplicates are taken as well to fill the quota.
        """
        total = self.oversampling_scalar * n
        chosen = []
        for index in range(int(total)):
            equ = equations[index]
            if equ not in seen:
                chosen.append(index)
                seen[equ] = -torch.inf
                if len(chosen) == n:
                    break
            elif n - len(chosen) >= total - index:
                chosen.append(index)
        return chosen

    def sample(self, n, device):
        """Sample ``n`` expression trees by denoising one depth layer at a time.

        ``int(self.oversampling_scalar * n)`` candidates are generated and then reduced
        to ``n``, preferring candidates with distinct equation strings.

        Args:
            n: number of trees to return.
            device: device passed to every ``ExpressionTree`` built here.

        Returns:
            ``(trees, dictionary)``: the reduced ``ExpressionTree`` and a timing dictionary.
            Only "Build Time" and "Comparison Time" are updated here, and the latter is
            measured from the start of sampling.

        Raises:
            ValueError: if the predicted probabilities contain NaNs.
        """
        sample_equs = {}
        num_candidates = int(self.oversampling_scalar * n)
        with torch.no_grad():
            dictionary = {"Fetch PS Time": 0, "Prediction Time": 0, "Apply Rules Time": 0, "Rand_Cat Time": 0,
                          "Add Node Time": 0, "Build Time": 0, "Equation Build Time": 0, "Comparison Time": 0}

            a = time.time()
            trees = self._build_trees(num_candidates, device, max_layers_steps=self.max_layers)
            self._synchronize()
            dictionary["Build Time"] += time.time() - a

            # Seed layer 0 at the last time step with trees sampled from all-ones weights.
            initial = torch.ones((trees.n, trees.max_depth, trees.library_size), device=trees.device, dtype=torch.float32)
            temp_trees = self._build_trees(num_candidates, device, max_layers_steps=self.max_layers)
            temp_trees.sample_full_trees(initial)
            trees.update_diffusion_info(0, self.diff_helper.timesteps, temp_trees.get_positions(),
                                        temp_trees.get_inputs(),
                                        temp_trees.get_labels())

            for depth in range(self.max_layers):

                if depth != 0:
                    # Start this layer from the t = 0 state of the previous layer.
                    inputs, labels, positions = trees.fetch_diffusion_info(depth - 1, 0)
                    trees.update_diffusion_info(depth, self.diff_helper.timesteps, positions, inputs, labels)

                for t in reversed(range(1, self.diff_helper.timesteps + 1)):
                    inputs, labels, positions = trees.fetch_diffusion_info(depth, t)
                    diff_steps = self.create_diff_matrix(positions, depth, t)
                    x = self.forward(labels, inputs, positions, diff_steps)

                    # Rows with no probability mass get a small constant so they can be normalised.
                    x[torch.sum(x, dim=-1) == 0] = 1e-5
                    probs = x / torch.sum(x, dim=-1, keepdim=True)

                    temp_trees = self._build_trees(num_candidates, device)

                    # Clone: the in-place fix-up below must not write into the positions
                    # tensor, which may also be stored in the trees' diffusion history.
                    layers = positions[:, :, 0].clone()
                    # Layer 0 (presumably empty/padding nodes) is treated as the deepest layer.
                    layers[layers == 0] = self.max_depth
                    mask = (layers >= depth).float().unsqueeze(2)
                    # Nodes below the current layer keep their labels; the rest use the prediction.
                    probs = (1 - mask) * labels + mask * probs
                    if probs.isnan().any():
                        raise ValueError("Nans in prob")
                    temp_trees.sample_full_trees(probs)
                    trees.update_diffusion_info(depth, t - 1, temp_trees.get_positions(), temp_trees.get_inputs(), temp_trees.get_labels())

                # Carry the accumulated diffusion history over to the latest sample.
                temp_trees.positions_history = trees.positions_history
                temp_trees.diffusion_inputs = trees.diffusion_inputs
                temp_trees.diffusion_labels = trees.diffusion_labels
                trees = temp_trees

            equations = trees.equation_string()
            self._synchronize()

            trees.reduce(self._select_unique(equations, n, sample_equs))
            self._synchronize()
            dictionary["Comparison Time"] += time.time() - a
            return trees, dictionary


def categorical_sample(x):
    """Draw one index per distribution from non-negative, possibly unnormalised weights.

    Weights are normalised along the last dimension, which is the dimension
    ``Categorical`` samples over. Normalising along dim 1 (as before) mixed axes for
    inputs of rank > 2; for 2-D inputs the result is unchanged.

    Args:
        x: non-negative weights of shape ``(..., num_categories)``.

    Returns:
        Integer tensor of shape ``x.shape[:-1]`` with the sampled category indices.

    Note:
        A slice summing to zero gives NaN probabilities, which ``Categorical``
        rejects when argument validation is enabled (the default).
    """
    probs = x / torch.sum(x, dim=-1, keepdim=True)
    return Categorical(probs=probs).sample()


def right_shift(targets):
    """Shift one-hot targets one step along dim 1 and prepend a start token.

    A new leading channel (last dim) and a new leading step (dim 1) are added as
    zeros. The start token is set at step 0, channel 0, and the final step is
    dropped, so dim 1 keeps its original length while the last dim grows by one.
    """
    # F.pad takes (left, right) pairs starting from the last dim: one zero in front of
    # the last dim, one zero row in front of dim -2, nothing for dim -3.
    pad_spec = (1, 0, 1, 0, 0, 0)
    shifted = F.pad(targets, pad_spec, mode="constant", value=0)
    shifted[:, 0, 0] = 1
    keep = shifted.shape[1] - 1
    return shifted[:, :keep]


def dct(src, dim=-1, norm='ortho'):
    """Type-II discrete cosine transform of ``src`` along ``dim``.

    Uses one complex FFT of length N: the samples are reordered (evens ascending,
    odds descending), transformed, and rotated by a twiddle factor. ``src`` itself
    is not modified.

    Args:
        src: real tensor to transform.
        dim: dimension to transform along.
        norm: ``'ortho'`` for the orthonormal transform; any other value gives the
            unscaled ``2 * sum_n x[n] * cos(pi * k * (2n + 1) / (2N))``.

    Returns:
        Real tensor with the same shape as ``src``.
    """
    length = src.shape[dim]

    moved = src.clone().transpose(dim, -1)
    moved_shape = moved.shape
    rows = moved.contiguous().view(-1, length)

    # Reorder each row: even-indexed samples ascending, then odd-indexed samples descending.
    split = (length - 1) // 2 + 1
    odd_start = 1 if length % 2 else 0
    reordered = torch.empty_like(rows, device=rows.device)
    reordered[..., :split] = rows[..., ::2]
    reordered[..., split:] = rows.flip(-1)[..., odd_start::2]

    freq = torch.arange(length, device=rows.device)
    twiddle = torch.exp(-1j * np.pi * freq / (2 * length))
    coeffs = 2 * torch.fft.fft(reordered, dim=-1) * twiddle

    if norm == 'ortho':
        coeffs[..., 0] *= math.sqrt(1 / (4 * length))
        coeffs[..., 1:] *= math.sqrt(1 / (2 * length))

    return coeffs.real.view(*moved_shape).transpose(-1, dim)


def idct(src, dim=-1, norm='ortho'):
    """Inverse of ``dct`` (a type-III DCT) along ``dim``.

    Args:
        src: real tensor of type-II DCT coefficients (not modified).
        dim: dimension to transform along.
        norm: ``'ortho'`` to invert the orthonormal transform, or ``None`` to invert
            the unscaled transform produced by ``dct(..., norm=None)``.

    Returns:
        Real tensor with the same shape as ``src``.

    Raises:
        ValueError: for any other ``norm`` (a subclass of the ``Exception`` raised before).
    """
    X = src.clone()
    N = X.shape[dim]

    X = X.transpose(dim, -1)
    X_shape = X.shape
    X = X.contiguous().view(-1, N)

    # The ifft below contributes a factor 1/N; pre-scale the coefficients so the
    # result is the exact inverse of the matching forward dct.
    if norm == 'ortho':
        X[..., 0] *= 1 / math.sqrt(2)
        X *= N * math.sqrt((2 / N))
    elif norm is None:
        # Unscaled dct yields 2 * sum(x * cos(...)); only the DC term needs halving.
        X[..., 0] *= 0.5
    else:
        raise ValueError(f"idct: unsupported norm {norm!r}; expected 'ortho' or None")

    k = torch.arange(N, device=X.device)

    X = X * torch.exp(1j * np.pi * k / (2 * N))
    X = torch.fft.ifft(X, dim=-1).real
    v = torch.empty_like(X, device=X.device)

    # Undo dct's reordering: the first half holds the even-indexed samples, the
    # reversed second half holds the odd-indexed ones.
    v[..., ::2] = X[..., :(N - 1) // 2 + 1]
    v[..., 1::2] = X[..., (N - 1) // 2 + 1:].flip(-1)

    v = v.view(*X_shape).transpose(-1, dim)

    return v


def create_dct(n, m=None):
    """Return the n x n DCT-II matrix (``dct`` with its default norm), optionally truncated.

    Column ``j`` is ``dct`` applied to the j-th unit vector, so by linearity
    ``matrix @ x`` equals ``dct(x)`` for a length-``n`` vector ``x``. When ``m`` is
    given, only the first ``m`` rows (coefficients) are kept.
    """
    basis = dct(torch.eye(n), dim=0)
    return basis if m is None else basis[:m, :]

