import torch
from torch import nn
import torch.nn.functional as F
from torch import tensor as Tensor
from typing import List, Any
import math


class MLPAE_1(nn.Module):
    """Auto-encoder that splits an input into two codes, conditioned on a reference.

    The reference ``r_inputs`` is linearly projected, the input ``s`` is
    projected onto that vector (row by row), and the concatenation of ``s``
    with the projection is fed to two independent linear+GELU encoders.
    A single linear decoder maps the concatenated codes back.

    Args:
        hidden_dims: Only the first and last entries are read. As written the
            layers require ``hidden_dims[0] == hidden_dims[-1]``: ``s`` is
            multiplied element-wise with the projected reference and the
            encoders expect ``hidden_dims[0] * 2`` input features.
    """

    def __init__(self, hidden_dims: List) -> None:
        super(MLPAE_1, self).__init__()

        in_dim, code_dim = hidden_dims[0], hidden_dims[-1]

        self.act = nn.GELU()
        self.r_proj = nn.Linear(in_dim, code_dim)
        self.s_enc1 = nn.Linear(in_dim * 2, code_dim)
        self.s_enc2 = nn.Linear(in_dim * 2, code_dim)
        self.dec = nn.Linear(code_dim * 2, in_dim)

        # for edit: learnable offset added to one of the codes in forward()
        self.delta = nn.Parameter(torch.zeros(code_dim,))

    def encode(self, s: Tensor, r_inputs: Tensor):
        """Encode ``s`` into the code pair ``(s1, s2)``.

        1-D inputs are treated as a batch of one.

        Returns:
            Tuple ``(s1, s2)`` of GELU-activated codes.
        """
        if s.ndim == 1:
            s = s.unsqueeze(0)
        if r_inputs.ndim == 1:
            r_inputs = r_inputs.unsqueeze(0)

        r = self.r_proj(r_inputs)
        # Row-wise projection of s onto r: (s_i . r_i) / (r_i . r_i) * r_i.
        # Computed directly instead of taking the diagonal of two
        # batch x batch Gram matrices, which costs O(B^2 * D) and lets
        # discarded off-diagonal divisions leak NaNs into the gradients.
        coeff = (s * r).sum(dim=-1, keepdim=True) / (r * r).sum(dim=-1, keepdim=True)
        r = coeff * r

        # Both encoders see the same features; build them once.
        feats = torch.cat([s, r], dim=-1)
        s1 = self.act(self.s_enc1(feats))
        s2 = self.act(self.s_enc2(feats))

        return s1, s2

    def decode(self, s1: Tensor, s2: Tensor):
        """Decode the concatenated code pair with the linear decoder."""
        s3 = self.dec(torch.cat((s1, s2), dim=-1))

        return s3

    def forward(self, inputs: Tensor, r_inputs: Tensor, add_delta=False, delta_on_s1=True):
        """Encode, optionally shift one code by ``self.delta``, then decode.

        Args:
            inputs: Input passed to :meth:`encode` as ``s``.
            r_inputs: Reference passed to :meth:`encode`.
            add_delta: When true, ``self.delta`` is added to one code.
            delta_on_s1: Selects the shifted code: ``s1`` if true, else ``s2``.

        Returns:
            Tuple ``(s1, s2, s3)``: the (possibly shifted) codes and the
            decoder output.
        """
        # encode
        s1, s2 = self.encode(inputs, r_inputs)

        # edit (out-of-place so the activation outputs are never mutated)
        if add_delta:
            if delta_on_s1:
                s1 = s1 + self.delta
            else:
                s2 = s2 + self.delta

        # decode
        s3 = self.decode(s1, s2)

        return s1, s2, s3

class GivensEncoder(nn.Module):
    """Learnable scaled rotation built from layers of 2x2 Givens rotations.

    The constructor pairs coordinates level by level: at a level with stride
    ``step_size`` coordinate ``i`` is paired with ``i + step_size`` for
    ``i = 0, 2*step_size, 4*step_size, ...``; the stride doubles each level
    until a single group remains. Every pair gets its own angle, for
    ``input_dim - 1`` angles overall, and every level becomes one
    ``input_dim x input_dim`` matrix that is the identity except for the
    rotation blocks. ``forward`` right-multiplies the input by each level
    matrix in turn and scales the result by ``alpha``.

    Args:
        input_dim: Size of the last dimension of the inputs; must be >= 1.

    Raises:
        ValueError: If ``input_dim`` is smaller than 1.
    """

    def __init__(self, input_dim) -> None:
        super(GivensEncoder, self).__init__()

        if input_dim < 1:
            # The pairing loop below only stops at cur_cnt == 1, which is
            # unreachable from a non-positive start.
            raise ValueError(f"input_dim must be a positive integer, got {input_dim}")

        cur_cnt = input_dim
        step_size = 1
        # Index triples (matrix, row, col) for the four entries of every
        # 2x2 rotation block: [0]=top-left, [1]=top-right,
        # [2]=bottom-left, [3]=bottom-right.
        idx_list = [[], [], [], []]
        mat_cnt = 0
        while cur_cnt != 1:
            left = cur_cnt % 2  # unpaired group carried to the next level
            cur_cnt = cur_cnt // 2
            idx = torch.arange(cur_cnt) * step_size * 2
            mat_idx = torch.full_like(idx, mat_cnt)

            idx_list[0].append((mat_idx, idx, idx))
            idx_list[1].append((mat_idx, idx, idx + step_size))
            idx_list[2].append((mat_idx, idx + step_size, idx))
            idx_list[3].append((mat_idx, idx + step_size, idx + step_size))

            step_size = step_size * 2
            cur_cnt = cur_cnt + left
            mat_cnt += 1

        self.final_idx = []
        for sub_list in idx_list:
            if sub_list:
                mat_idx = torch.cat([ss[0] for ss in sub_list])
                idx_0 = torch.cat([ss[1] for ss in sub_list])
                idx_1 = torch.cat([ss[2] for ss in sub_list])
            else:
                # input_dim == 1: no rotations at all (torch.cat rejects []).
                mat_idx = torch.empty(0, dtype=torch.long)
                idx_0 = torch.empty(0, dtype=torch.long)
                idx_1 = torch.empty(0, dtype=torch.long)
            self.final_idx.append([mat_idx, idx_0, idx_1])

        self.theta = nn.Parameter(torch.rand((input_dim - 1,)))
        self.alpha = nn.Parameter(torch.ones((1,)))
        # self.bias = nn.Parameter(torch.rand((input_dim,)))

        self.mat_cnt = mat_cnt
        self.input_dim = input_dim
        # Cached product of all level matrices, filled by set_eval().
        # Registered as a non-persistent buffer so it follows .to()/.cuda()
        # with the module without adding a key to state_dict().
        self.register_buffer("mat", None, persistent=False)

    def get_mats(self):
        """Build the per-level rotation matrices from the current angles.

        Angles are ``tanh(theta) * pi``. The result lives on the same device
        and has the same dtype as ``theta``.

        Returns:
            Tensor of shape ``(mat_cnt, input_dim, input_dim)``.
        """
        all_theta = torch.tanh(self.theta) * math.pi
        cos_theta, sin_theta = torch.cos(all_theta), torch.sin(all_theta)

        mats = torch.eye(
            self.input_dim, dtype=all_theta.dtype, device=all_theta.device
        ).repeat(self.mat_cnt, 1, 1)
        if self.mat_cnt == 0:
            return mats

        # Index with tuples: indexing with a list of tensors depends on
        # legacy "sequence as tuple" behaviour.
        top_left, top_right, bottom_left, bottom_right = (
            tuple(idx) for idx in self.final_idx
        )
        mats[top_left] = cos_theta
        mats[top_right] = -sin_theta
        mats[bottom_left] = sin_theta
        mats[bottom_right] = cos_theta

        return mats

    def forward(self, x):
        """Apply the rotations to ``x`` (last dim ``input_dim``) and scale by ``alpha``.

        Uses the cached matrix when ``set_eval`` has been called; otherwise
        rebuilds the level matrices from ``theta`` on every call.
        """
        if self.mat is None:
            for mat in self.get_mats():
                x = x @ mat
        else:
            x = x @ self.mat
        return x * self.alpha

    def set_eval(self):
        """Collapse all level matrices into one detached matrix and cache it.

        After this call ``forward`` no longer depends on ``theta``; assign
        ``self.mat = None`` to go back to rebuilding the matrices per call.
        """
        mats = self.get_mats().detach()
        if mats.shape[0] == 0:
            # No rotation levels (input_dim == 1): the transform is identity.
            combined = torch.eye(self.input_dim, dtype=mats.dtype, device=mats.device)
        else:
            combined = mats[0]
            for mat in mats[1:]:
                combined = combined @ mat
        self.mat = combined

class MLPAE_2(nn.Module):
    """Auto-encoder whose two encoders are ``GivensEncoder`` transforms.

    The reference ``r`` is linearly projected, passed through LeakyReLU and
    added to the input ``s``; the sum is sent through two independent
    ``GivensEncoder`` modules followed by LeakyReLU to obtain the codes
    ``s1`` and ``s2``. A linear decoder maps the concatenated codes back.

    Args:
        hidden_dims: Only the first and last entries are read. As written the
            layers require ``hidden_dims[0] == hidden_dims[-1]``: the
            projected reference is added to ``s`` and the Givens encoders
            keep the feature size ``hidden_dims[0]``.
    """

    def __init__(self, hidden_dims: List) -> None:
        super(MLPAE_2, self).__init__()

        negative_slope = 1e-2
        self.act = nn.LeakyReLU(negative_slope)

        self.r_proj = nn.Linear(hidden_dims[0], hidden_dims[-1])
        self.given_enc_1 = GivensEncoder(hidden_dims[0])
        self.given_enc_2 = GivensEncoder(hidden_dims[0])

        self.dec = nn.Linear(hidden_dims[-1] * 2, hidden_dims[0])

        # for edit: learnable offset added to s1 in forward()
        self.delta = nn.Parameter(torch.zeros(hidden_dims[-1],))

    def act_1(self, x):
        """Inverse of ``self.act``: divide negative entries by the slope.

        Defined as a method rather than a lambda attribute so that the
        module stays picklable.
        """
        return torch.where(x < 0, x / self.act.negative_slope, x)

    def set_eval(self):
        """Cache the collapsed rotation matrix of both Givens encoders."""
        self.given_enc_1.set_eval()
        self.given_enc_2.set_eval()

    def encode(self, s: Tensor, r: Tensor):
        """Encode ``s`` into the code pair ``(s1, s2)``.

        1-D inputs are treated as a batch of one.
        """
        if s.ndim == 1:
            s = s.unsqueeze(0)
        if r.ndim == 1:
            r = r.unsqueeze(0)

        x = s + self.act(self.r_proj(r))
        s1 = self.act(self.given_enc_1(x))
        s2 = self.act(self.given_enc_2(x))

        return s1, s2

    def decode(self, s1: Tensor, s2: Tensor):
        """Decode the concatenated code pair with the linear decoder."""
        s3 = self.dec(torch.cat((s1, s2), dim=-1))

        return s3

    def forward(self, inputs: Tensor, r_inputs: Tensor, add_delta=False):
        """Encode, optionally shift ``s1`` by ``self.delta``, then decode.

        Returns:
            Tuple ``(s1, s2, s3)``: the (possibly shifted) first code, the
            second code and the decoder output.
        """
        # encode
        s1, s2 = self.encode(inputs, r_inputs)

        # edit
        if add_delta:
            s1 = s1 + self.delta

        # decode
        s3 = self.decode(s1, s2)

        return s1, s2, s3


if __name__ == "__main__":
    # Smoke test: one forward pass of MLPAE_2 on random data.
    bs, dim = 8, 1600
    # Use the GPU when one is present instead of requiring CUDA outright.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ae_model = MLPAE_2([dim]).to(device)
    s, r = torch.rand((bs, dim)).to(device), torch.rand((bs, dim)).to(device)
    ae_model(s, r)
