import torch
import torch.nn.functional as functional
from os.path import join, isfile
from datetime import datetime
import os
import numpy as np
import torch.nn.functional as F


class Unfolding_RNN(torch.nn.Module):
    """Deep-unfolded sparse-coding networks plus recurrent baselines.

    ``config["model"]`` (matched case-insensitively) selects the architecture:

    * ``"l1l1"``, ``"reweighted"``, ``"sista"``: recurrences over time in which
      every time step runs ``num_layers`` proximal (soft-thresholding) steps.
    * ``"dust"`` / ``"dust_vec"``: layer-by-layer updates of the whole
      ``(batch, num_hidden, time_steps)`` code matrix, mixing time steps with a
      softmax attention; ``"dust_vec"`` is the loop-free form of ``"dust"``.
    * ``"unrolled_transformer"``: a ReLU variant of the attention update.
    * ``"gru"``, ``"rnn"``, ``"lstm"``: ``torch.nn`` recurrent baselines wrapped
      by bias-free linear compression / reconstruction layers.

    Args:
        A_initializer: initial measurement matrix. It is applied as
            ``x @ A.T`` to ``num_features``-dimensional inputs, so presumably
            of shape ``(n_input, num_features)``.
        D_initializer: initial dictionary, either a single
            ``(num_features, num_hidden)`` matrix or a stack
            ``(num_dicts, num_features, num_hidden)``.
        config: hyper-parameter dict, stored by reference. Keys read:
            ``num_hidden``, ``num_layers``, ``num_features``,
            ``compression_factor``, ``time_steps``, ``alpha``, ``lambda0``,
            ``lambda1``, ``lambda2``, ``h0_init``, ``learn_lambda0/1/2``,
            ``model`` and, for the attention models, ``diff_D``.
    """

    def __init__(self, A_initializer, D_initializer, config):
        super(Unfolding_RNN, self).__init__()
        self.config = config

        self.num_hidden = self.config["num_hidden"]
        self.num_layers = self.config["num_layers"]

        # Number of compressed measurements per time step.
        self.n_input = int(
            self.config["num_features"] / self.config["compression_factor"]
        )
        self.dtype = torch.float32
        # The preferred device is only recorded; the module is NOT moved to it
        # here. Helper tensors are therefore created next to the parameters
        # they are combined with rather than on ``self.device``.
        if torch.cuda.is_available():
            print("Running on CUDA")
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            print("Running on MPS")
            self.device = torch.device("mps")
        else:
            print("Running on CPU")
            self.device = torch.device("cpu")

        model = self.config["model"].lower()

        self.D = torch.nn.Parameter(torch.Tensor(D_initializer))
        # Per-layer (num_hidden x num_hidden) mixing used by "unrolled_transformer".
        self.D2 = torch.nn.Parameter(
            torch.empty(self.num_layers, self.num_hidden, self.num_hidden)
        )
        torch.nn.init.xavier_uniform_(self.D2)

        # Normalised over (num_features, time_steps); the shape is fixed here.
        self.layer_norms = torch.nn.ModuleList(
            [
                torch.nn.LayerNorm(
                    [self.config["num_features"], self.config["time_steps"]],
                    elementwise_affine=False,
                )
                for _ in range(self.num_layers)
            ]
        )
        self.A = torch.nn.Parameter(torch.Tensor(A_initializer))
        self.alpha = torch.nn.Parameter(torch.Tensor([float(self.config["alpha"])]))
        self.lambda0 = torch.nn.Parameter(torch.Tensor([float(self.config["lambda0"])]))
        self.lambda1 = torch.nn.Parameter(torch.Tensor([float(self.config["lambda1"])]))
        self.lambda2 = torch.nn.Parameter(torch.Tensor([float(self.config["lambda2"])]))

        h0_init = self.config["h0_init"]
        if h0_init == "zeros":
            self.h_0 = torch.nn.Parameter(torch.zeros(self.num_hidden))
        elif h0_init == "ones":
            self.h_0 = torch.nn.Parameter(torch.ones(self.num_hidden))
        elif h0_init == "normal_0":
            self.h_0 = torch.nn.Parameter(torch.randn(self.num_hidden))
        elif h0_init == "normal_d":
            # Match the dictionary's statistics; no autograd needed for an init.
            with torch.no_grad():
                D_mean = torch.mean(self.D)
                D_std = torch.std(self.D)
            self.h_0 = torch.nn.Parameter(torch.randn(self.num_hidden) * D_std + D_mean)
        # Any other value leaves ``h_0`` undefined: only the generic RNN
        # baselines (which never read it) can run in that case.

        for flag, param in (
            ("learn_lambda0", self.lambda0),
            ("learn_lambda1", self.lambda1),
            ("learn_lambda2", self.lambda2),
        ):
            if not self.config[flag]:
                param.requires_grad = False

        if model in ("l1l1", "reweighted"):
            self.G = torch.nn.Parameter(torch.eye(self.num_hidden))

            if model == "reweighted":
                self.g = torch.nn.Parameter(
                    torch.ones(self.num_layers, self.num_hidden, dtype=self.dtype)
                )
                # One identity matrix per layer, stacked to (num_layers, H, H).
                self.Z = torch.nn.Parameter(
                    torch.eye(self.num_hidden).repeat(self.num_layers, 1, 1)
                )
        elif model in ("gru", "rnn", "lstm"):
            recurrent_cls = {
                "gru": torch.nn.GRU,
                "rnn": torch.nn.RNN,
                "lstm": torch.nn.LSTM,
            }[model]
            self.compression = torch.nn.Linear(
                self.config["num_features"], self.n_input, bias=False
            )
            self.reconstruction = torch.nn.Linear(
                self.num_hidden, self.config["num_features"], bias=False
            )
            self.generic = recurrent_cls(
                self.n_input,
                self.num_hidden,
                self.num_layers,
                batch_first=False,
            )
        elif model == "sista":
            self.F = torch.nn.Parameter(
                torch.eye(self.config["num_features"], dtype=self.dtype)
            )

        self.reconstructions = None

    # ------------------------------------------------------------------ helpers

    def _dictionary(self, index):
        """Return dictionary ``index`` of ``self.D``.

        A 2-D ``self.D`` is a single shared dictionary and is returned as-is;
        a 3-D ``self.D`` is indexed along its first axis.
        """
        return self.D if self.D.dim() == 2 else self.D[index]

    def _eye(self, size):
        """Identity matrix on the same device/dtype as the parameters."""
        return torch.eye(size, device=self.A.device, dtype=self.A.dtype)

    def _initial_code_matrix(self):
        """Return ``h_0`` copied into every (batch, time) column.

        Shape ``(batch_size, num_hidden, time_steps)`` with
        ``result[b, :, t] == h_0`` for all ``b`` and ``t``.
        """
        return self.h_0.reshape(1, -1, 1).repeat(
            self.batch_size, 1, self.config["time_steps"]
        )

    def _attention_mix(self, k, H, D_k):
        """Mix the columns (time steps) of ``H`` with a softmax attention.

        ``H`` is ``(batch, num_hidden, time_steps)``. The queries are
        ``q = LayerNorm_k(D_k @ H)`` and the result is
        ``lambda2 * H @ softmax(q^T q, dim=-1)`` with the input's shape.
        """
        q_hat = self.layer_norms[k](D_k @ H)
        weights = torch.softmax(q_hat.transpose(-2, -1) @ q_hat, dim=-1)
        return self.lambda2 * (H @ weights)

    def _l1_l1_operators(self, D):
        """Return ``(V, S, W_1)`` shared by the l1-l1 and reweighted recurrences."""
        At = self.A.t()
        Dt = D.t()
        AtA = torch.mm(At, self.A)

        V = 1.0 / self.alpha * torch.mm(Dt, At)
        temp = 1.0 / self.alpha * torch.mm(torch.mm(Dt, AtA), D)
        S = self._eye(self.num_hidden) - temp
        W_1 = self.G - torch.mm(temp, self.G)
        return V, S, W_1

    def _unroll_l1_l1(self, input, prox):
        """Run the l1-l1 style recurrence and store ``self.sparse_code``.

        ``prox(k, x, w0, w1, prior)`` is the layer-``k`` proximal step.
        ``prior`` is the previous time step's final code mapped through ``G``;
        it is the same for every layer of a time step.
        """
        V, S, W_1 = self._l1_l1_operators(self._dictionary(0))
        w0 = self.lambda0 / self.alpha
        w1 = self.lambda1 / self.alpha

        h = self.h_0.repeat(self.batch_size, 1)
        steps = []
        for t in range(self.config["time_steps"]):
            h_prev = h
            # Both terms are identical for every layer of this time step.
            drive = torch.mm(input[t + 1], V.t())
            prior = torch.mm(h_prev, self.G)

            h = prox(0, torch.mm(h_prev, W_1.t()) + drive, w0, w1, prior)
            layers = [h]
            for k in range(1, self.num_layers):
                h = prox(k, drive + torch.mm(h, S.t()), w0, w1, prior)
                layers.append(h)
            steps.append(torch.stack(layers))

        self.sparse_code = torch.stack(steps)

    # --------------------------------------------------------------- public API

    def set_sequence_length(self, time_steps):
        """Set ``config["time_steps"]``, used by the graph builders and ``forward``.

        NOTE(review): ``self.layer_norms`` keep the ``time_steps`` from
        construction. A different length makes the attention models
        (dust / dust_vec / unrolled_transformer) fail in ``LayerNorm``.
        """
        self.config["time_steps"] = time_steps

    def soft_l1(self, x, b):
        """Element-wise soft-thresholding: ``sign(x) * max(|x| - b, 0)``."""
        return torch.sign(x) * functional.relu(torch.abs(x) - b)

    def soft_l1_l1(self, x, w0, w1, alpha1):
        """Element-wise proximal step of ``w0*|u| + w1*|u - alpha1|`` at ``x``.

        The two anchors ``0`` and ``alpha1`` are sorted per element, with the
        weights swapped alongside. The closed-form piecewise solution
        (shrink / clamp to anchor / shrink / clamp / shrink) is then
        selected by thresholding ``x``.
        """
        alpha0 = torch.zeros_like(alpha1)
        condition = alpha0 <= alpha1
        alpha0_sorted = torch.where(condition, alpha0, alpha1)
        alpha1_sorted = torch.where(condition, alpha1, alpha0)

        w0_sorted = torch.where(condition, w0, w1)
        w1_sorted = torch.where(condition, w1, w0)

        cond1 = x >= alpha1_sorted + w0_sorted + w1_sorted
        cond2 = x >= alpha1_sorted + w0_sorted - w1_sorted
        cond3 = x >= alpha0_sorted + w0_sorted - w1_sorted
        cond4 = x >= alpha0_sorted - w0_sorted - w1_sorted

        res1 = x - w0_sorted - w1_sorted
        res2 = alpha1_sorted
        res3 = x - w0_sorted + w1_sorted
        res4 = alpha0_sorted
        res5 = x + w0_sorted + w1_sorted

        return torch.where(
            cond1,
            res1,
            torch.where(
                cond2, res2, torch.where(cond3, res3, torch.where(cond4, res4, res5))
            ),
        )

    def soft_l1_l1_reweighted(self, x, w0, w1, alpha1, g, c):
        """Reweighted ``soft_l1_l1``: weights scaled by ``g``, input mixed by ``c``.

        Equivalent to ``soft_l1_l1(x @ c, w0 * g, w1 * g, alpha1)``.
        """
        return self.soft_l1_l1(torch.mm(x, c), w0 * g, w1 * g, alpha1)

    def _normalize_compression_matrix(self):
        """Affinely rescale ``A`` in place so its entries span ``self.A_range``.

        NOTE(review): ``A_range`` is never set in this class; a caller must
        assign a ``(low, high)`` pair first, otherwise this raises
        ``AttributeError``. Operates on ``.data``, so autograd does not see it.
        """
        old_range = self.A.data.max() - self.A.data.min() + 1e-6
        new_range = self.A_range[1] - self.A_range[0]
        self.A.data -= self.A.data.min()
        self.A.data *= new_range / old_range
        self.A.data += self.A_range[0]

    def build_graph_reweighted(self, input):
        """Unroll the reweighted l1-l1 recurrence into ``self.sparse_code``.

        Args:
            input: ``(time_steps + 1, batch, n_input)`` measurements whose
                row 0 is a zero pad; time step ``t`` reads ``input[t + 1]``.

        ``self.sparse_code`` gets shape ``(time_steps, num_layers, batch, num_hidden)``.
        """
        self._unroll_l1_l1(
            input,
            lambda k, x, w0, w1, prior: self.soft_l1_l1_reweighted(
                x, w0, w1, prior, self.g[k], self.Z[k]
            ),
        )

    def build_graph_l1_l1(self, input):
        """Unroll the l1-l1 recurrence into ``self.sparse_code``.

        Same input/output layout as ``build_graph_reweighted``.
        """
        self._unroll_l1_l1(
            input,
            lambda k, x, w0, w1, prior: self.soft_l1_l1(x, w0, w1, prior),
        )

    def build_graph_sista(self, input):
        """Unroll the SISTA recurrence into ``self.sparse_code``.

        Same input/output layout as ``build_graph_reweighted``. Uses the
        learned ``F`` (``num_features`` x ``num_features``) through
        ``P = D^T F D``.
        """
        D = self._dictionary(0)
        At = self.A.t()
        Dt = D.t()
        AtA = torch.mm(At, self.A)
        I = self._eye(self.config["num_features"])
        P = torch.mm(torch.mm(Dt, self.F), D)

        V = 1.0 / self.alpha * torch.mm(Dt, At)
        temp = 1.0 / self.alpha * torch.mm(torch.mm(Dt, AtA + self.lambda2 * I), D)
        S = self._eye(self.num_hidden) - temp

        W_1 = (self.alpha + self.lambda2) / self.alpha * P - torch.mm(temp, P)
        W_k = self.lambda2 / self.alpha * P
        threshold = self.lambda0 / self.alpha

        h = self.h_0.repeat(self.batch_size, 1)
        steps = []
        for t in range(self.config["time_steps"]):
            h_prev = h
            drive = torch.mm(input[t + 1], V.t())
            # Contribution of the previous time step to layers k >= 1.
            carry = torch.mm(h_prev, W_k.t())

            h = self.soft_l1(torch.mm(h_prev, W_1.t()) + drive, threshold)
            layers = [h]
            for _ in range(1, self.num_layers):
                h = self.soft_l1(carry + drive + torch.mm(h, S.t()), threshold)
                layers.append(h)
            steps.append(torch.stack(layers))

        self.sparse_code = torch.stack(steps)

    def build_graph_dust(self, input):
        """Attention-based unrolling, solved one time step at a time.

        Args:
            input: ``(time_steps + 1, batch, n_input)``; row 0 is a zero pad.

        ``self.sparse_code`` gets shape ``(time_steps, num_layers, batch, num_hidden)``.
        """
        c = self.alpha
        At = self.A.t()
        AtA = torch.mm(At, self.A)
        I = self._eye(self.num_hidden)

        H = self._initial_code_matrix()
        layers = []
        for k in range(self.num_layers):
            D_k = self._dictionary(k if self.config["diff_D"] else 0)
            Dt_k = D_k.t()

            U = I - 1 / c * Dt_k @ AtA @ D_k
            V = 1 / c * Dt_k @ At
            Z = self._attention_mix(k, H, D_k)

            columns = [
                self.soft_l1(Z[:, :, t] @ U + input[t + 1, :, :] @ V.t(), self.lambda1 / c)
                for t in range(self.config["time_steps"])
            ]
            H = torch.stack(columns, dim=-1)
            layers.append(H)

        # (layers, batch, hidden, time) -> (time, layers, batch, hidden)
        self.sparse_code = torch.stack(layers).permute(3, 0, 1, 2)

    def build_graph_dust_vectorized(self, input):
        """Loop-free equivalent of ``build_graph_dust`` (same layouts)."""
        c = self.alpha
        At = self.A.t()
        AtA = torch.mm(At, self.A)
        I = self._eye(self.num_hidden)
        # (batch, n_input, time): every step's measurement at once.
        measurements = input[1:].permute(1, 2, 0)

        H = self._initial_code_matrix()
        layers = []
        for k in range(self.num_layers):
            D_k = self._dictionary(k if self.config["diff_D"] else 0)
            Dt_k = D_k.t()

            U = I - 1 / c * Dt_k @ AtA @ D_k
            V = 1 / c * Dt_k @ At
            Z = self._attention_mix(k, H, D_k)

            # Column t equals Z[:, :, t] @ U, i.e. U^T @ Z over all columns.
            H = self.soft_l1(U.t() @ Z + V @ measurements, self.lambda1 / c)
            layers.append(H)

        self.sparse_code = torch.stack(layers).permute(3, 0, 1, 2)

    def build_graph_unrolled_transformer(self, input):
        """ReLU attention unrolling with per-layer mixing ``D2`` (same layouts)."""
        measurements = input[1:]  # (time, batch, n_input)

        H = self._initial_code_matrix()
        layers = []
        for k in range(self.num_layers):
            dict_idx = k if self.config["diff_D"] else 0
            D_k = self._dictionary(dict_idx)
            D_k2 = self.D2[dict_idx]

            Z = self._attention_mix(k, H, D_k)
            V = self.A @ D_k
            H = F.relu((measurements @ V).permute(1, 2, 0) + D_k2 @ Z)
            layers.append(H)

        self.sparse_code = torch.stack(layers).permute(3, 0, 1, 2)

    def forward(self, raw_input, test_mode=False):
        """Encode ``raw_input`` and return its reconstruction.

        Args:
            raw_input: ``(time, batch, num_features)``. The unrolled models
                reshape their output with ``config["time_steps"]``, so
                ``time`` is presumably equal to it.
            test_mode: unused; kept for interface compatibility.

        Returns:
            Reconstruction of shape ``(time_steps, batch, num_features)``.

        Side effects (unrolled models only): sets ``batch_size``,
        ``now_input`` (compressed measurements), ``sparse_code`` and
        ``sparsity`` (fraction of exact zeros per layer).
        """
        model = self.config["model"].lower()
        builders = {
            "reweighted": self.build_graph_reweighted,
            "sista": self.build_graph_sista,
            "l1l1": self.build_graph_l1_l1,
            "dust": self.build_graph_dust,
            "dust_vec": self.build_graph_dust_vectorized,
            "unrolled_transformer": self.build_graph_unrolled_transformer,
        }

        if model not in builders:
            compressed = self.compression(raw_input)
            sparse_code = self.generic(compressed)[0]
            return self.reconstruction(sparse_code)

        self.batch_size = raw_input.shape[1]
        compressed = raw_input.reshape(-1, raw_input.shape[-1]).mm(self.A.t())
        self.now_input = compressed.view(raw_input.shape[0], raw_input.shape[1], -1)

        # Leading zero step so that time step t reads input[t + 1].
        pre_input = self.now_input.new_zeros(
            1, self.now_input.shape[1], self.now_input.shape[-1]
        )
        builders[model](torch.cat([pre_input, self.now_input]))

        # sparse_code: (time, layers, batch, hidden) -> zero count per layer.
        zeros_count = (
            (self.sparse_code == 0).sum(dim=(-2, -1)).sum(dim=0).detach().float()
        )
        self.sparsity = zeros_count / (self.sparse_code.numel() / self.num_layers)

        last_layer = self.sparse_code[:, -1, ...].reshape(-1, self.num_hidden)
        # NOTE(review): reconstruction always uses the last dictionary, even
        # when the coding layers used dictionary 0 (diff_D false / l1l1 /
        # sista / reweighted) -- confirm this is intended.
        D = self._dictionary(-1)
        z_hat_flattened = torch.mm(last_layer, D.t())
        return z_hat_flattened.view(self.config["time_steps"], self.batch_size, -1)
