import math
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange

# helpers

def exists(val):
    """Return True when `val` is anything other than None."""
    if val is None:
        return False
    return True

def cast_tuple(val, repeat = 1):
    """Return `val` as-is if it is a tuple, else a tuple containing `val` `repeat` times."""
    if isinstance(val, tuple):
        return val
    return (val,) * repeat

class Sine(nn.Module):
    """Sine activation: scales the input by `w0`, then applies `torch.sin`."""

    def __init__(self, w0 = 1.):
        super().__init__()
        # Multiplier applied to the input before the sine.
        self.w0 = w0

    def forward(self, x):
        scaled = self.w0 * x
        return torch.sin(scaled)

class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a batch of scalar positions.

    `dim // 2` frequencies are spaced geometrically from 1 down to `1 / theta`;
    the output is the sines of `x * freq` followed by the cosines.

    Args:
        dim: width of the returned embedding.
        theta: ratio between the largest and the smallest frequency.
    """

    def __init__(self, dim, theta = 10000):
        super().__init__()
        self.dim = dim
        self.theta = theta

    def forward(self, x):
        """Embed `x`, which is indexed as `x[:, None]` (presumably a 1-D tensor of shape (batch,)).

        Returns a tensor whose last dimension has size `self.dim`.
        """
        half_dim = self.dim // 2
        # Clamp the divisor: with a single frequency (dim 2 or 3) `half_dim - 1`
        # is zero and the plain division raises ZeroDivisionError.
        log_step = math.log(self.theta) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -log_step)
        angles = x[:, None] * freqs[None, :]
        emb = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves only fill `dim - 1` columns for odd `dim`; pad one
            # zero column so the width matches what downstream layers sized for `dim` expect.
            emb = F.pad(emb, (0, 1))
        return emb

class Siren(nn.Module):
    """Single SIREN layer: linear map, dropout on the pre-activation, then the activation.

    Args:
        dim_in: input feature size.
        dim_out: output feature size.
        w0: factor passed to the default `Sine` activation and used in the init bound.
        c: constant in the init bound `sqrt(c / dim_in) / w0`.
        is_first: when True the init bound is `1 / dim_in` instead.
        use_bias: whether a bias parameter is created.
        activation: replacement for the default `Sine(w0)`.
        dropout: probability handed to `F.dropout` in `forward`.
    """

    def __init__(self, dim_in, dim_out, w0 = 1., c = 6., is_first = False, use_bias = True, activation = None, dropout = False):
        super().__init__()
        self.dim_in = dim_in
        self.is_first = is_first
        self.dim_out = dim_out
        self.dropout = dropout

        # Fill plain tensors first, then wrap them as parameters.
        w = torch.zeros(dim_out, dim_in)
        b = None
        if use_bias:
            b = torch.zeros(dim_out)
        self.init_(w, b, c = c, w0 = w0)

        self.weight = nn.Parameter(w)
        if use_bias:
            self.bias = nn.Parameter(b)
        else:
            self.bias = None

        if activation is None:
            self.activation = Sine(w0)
        else:
            self.activation = activation

    def init_(self, weight, bias, c, w0):
        """Fill `weight` (and `bias`, if given) in place from U(-bound, bound)."""
        fan_in = self.dim_in

        if self.is_first:
            bound = 1 / fan_in
        else:
            bound = math.sqrt(c / fan_in) / w0

        weight.uniform_(-bound, bound)
        if bias is not None:
            bias.uniform_(-bound, bound)

    def forward(self, x):
        pre_act = F.linear(x, self.weight, self.bias)
        # Dropout follows the module's train/eval mode.
        pre_act = F.dropout(pre_act, p=self.dropout, training=self.training)
        return self.activation(pre_act)

class ConditionalSirenModuleOld(nn.Module):
    """SIREN-style layer modulated by a timestep and a conditioning embedding.

    `forward` scales and shifts the input with values predicted from `time`,
    adds a projection of `cond_emb`, then applies the linear map, the
    activation and an output dropout.

    Args:
        dim_in: input feature size; also sizes the time and condition MLPs.
        dim_out: output feature size of the linear map.
        dim_cond: feature size of `cond_emb` expected by the condition MLP.
        w0: factor for the default `Sine` activation and for the init bound.
        c: constant in the init bound `sqrt(c / dim_in) / w0`.
        is_first: when True the init bound is `1 / dim_in` instead.
        use_bias: whether a bias parameter is created.
        activation: replacement for the default `Sine(w0)`.
        dropout_rate: probability of the dropout applied to the layer output.
        emb_dropout_rate: probability of the dropout applied to the condition
            projection and to the modulated input.
    """

    def __init__(self, dim_in, dim_out, dim_cond, w0=1., c=6., is_first=False, use_bias=True, activation=None,
                 dropout_rate=False, emb_dropout_rate=False):
        super().__init__()
        self.dim_in = dim_in
        self.is_first = is_first
        self.dim_out = dim_out
        self.emb_dropout_rate = emb_dropout_rate
        self.dropout_rate = dropout_rate

        # Set by `toggle_condition`; NOTE(review): `forward` does not read this flag.
        self.condition_on = True

        weight = torch.zeros(dim_out, dim_in)
        bias = torch.zeros(dim_out) if use_bias else None
        self.init_(weight, bias, c=c, w0=w0)

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias) if use_bias else None
        self.activation = Sine(w0) if activation is None else activation

        # The time MLP emits 2 * dim_in features: one half is the scale, the other the shift.
        time_dim = dim_in * 2
        cond_dim = dim_in

        sinu_pos_emb = SinusoidalPosEmb(dim_in, theta=100)
        fourier_dim = dim_in

        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        self.cond_mlp = nn.Sequential(
            nn.Linear(dim_cond, cond_dim),
            nn.GELU(),
            nn.Linear(cond_dim, cond_dim)
        )

        self.emb_dropout = nn.Dropout(self.emb_dropout_rate)

        self.dropout = nn.Dropout(self.dropout_rate)

    def init_(self, weight, bias, c, w0):
        """Fill `weight` (and `bias`, if given) in place from U(-w_std, w_std)."""
        dim = self.dim_in

        w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
        weight.uniform_(-w_std, w_std)

        if exists(bias):
            bias.uniform_(-w_std, w_std)

    def toggle_condition(self, on=True):
        """Store the conditioning flag in `self.condition_on`."""
        self.condition_on = on

    def forward(self, x, time, cond_emb):
        """Apply the modulated layer.

        Presumably `x` is (batch, dim_in), `time` is (batch,) and `cond_emb`
        is (batch, dim_cond) -- TODO confirm against callers.
        """
        t = self.time_mlp(time)
        # Use the registered dropout module so the configured `emb_dropout_rate`
        # is honoured and dropout is switched off by `.eval()`; a forced
        # `training=True` keeps inference stochastic.
        c = self.emb_dropout(self.cond_mlp(cond_emb))

        scale, shift = t.chunk(2, dim=1)

        x = self.emb_dropout(x * (scale + 1) + shift) + c

        x = F.linear(x, self.weight, self.bias)

        x = self.activation(x)

        return self.dropout(x)

class ConditionalSirenModuleStandard(nn.Module):
    """Conditional SIREN-style layer with a fuse MLP and batch norm before the activation.

    `forward` scales and shifts the input with values predicted from `time`,
    passes it through `fuse_mlp`, adds a projection of `cond_emb`, then applies
    the linear map, `BatchNorm1d`, the activation and an output dropout.

    Args:
        dim_in: input feature size; also sizes the time, condition and fuse MLPs.
        dim_out: output feature size of the linear map and of the batch norm.
        dim_cond: feature size of `cond_emb` expected by the condition MLP.
        w0: factor for the default `Sine` activation and for the init bound.
        c: constant in the init bound `sqrt(c / dim_in) / w0`.
        is_first: when True the init bound is `1 / dim_in` instead.
        use_bias: whether a bias parameter is created.
        activation: replacement for the default `Sine(w0)`.
        dropout_rate: probability of the dropout applied to the layer output.
        emb_dropout_rate: probability of the dropouts inside `cond_mlp` and `fuse_mlp`.
    """

    def __init__(self, dim_in, dim_out, dim_cond, w0=1., c=6., is_first=False, use_bias=True, activation=None,
                 dropout_rate=0.1, emb_dropout_rate=0.1):
        super().__init__()
        self.dim_in = dim_in
        self.is_first = is_first
        self.dim_out = dim_out
        self.emb_dropout_rate = emb_dropout_rate
        self.dropout_rate = dropout_rate

        # Set by `toggle_condition`; `forward` does not read it.
        self.condition_on = True

        # Initialise plain tensors, then register them as parameters.
        w = torch.zeros(dim_out, dim_in)
        b = None
        if use_bias:
            b = torch.zeros(dim_out)
        self.init_(w, b, c=c, w0=w0)

        self.weight = nn.Parameter(w)
        if use_bias:
            self.bias = nn.Parameter(b)
        else:
            self.bias = None

        if activation is None:
            self.activation = Sine(w0)
        else:
            self.activation = activation

        # Twice dim_in so the time features can be split into a scale and a shift.
        modulation_width = 2 * dim_in

        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim_in, theta=100),
            nn.Linear(dim_in, modulation_width),
            nn.GELU(),
            nn.Linear(modulation_width, modulation_width)
        )

        self.cond_mlp = nn.Sequential(
            nn.Linear(dim_cond, dim_in),
            nn.GELU(),
            nn.Dropout(self.emb_dropout_rate),
            nn.Linear(dim_in, dim_in)
        )

        self.fuse_mlp = nn.Sequential(
            nn.GELU(),
            nn.Dropout(self.emb_dropout_rate),
            nn.Linear(dim_in, dim_in),
        )

        self.norm = nn.BatchNorm1d(dim_out)

        self.dropout = nn.Dropout(self.dropout_rate)

    def init_(self, weight, bias, c, w0):
        """Fill `weight` (and `bias`, if given) in place from U(-bound, bound)."""
        fan_in = self.dim_in

        if self.is_first:
            bound = 1 / fan_in
        else:
            bound = math.sqrt(c / fan_in) / w0

        weight.uniform_(-bound, bound)
        if bias is not None:
            bias.uniform_(-bound, bound)

    def toggle_condition(self, on=True):
        """Store the conditioning flag in `self.condition_on`."""
        self.condition_on = on

    def forward(self, x, time, cond_emb):
        cond = self.cond_mlp(cond_emb)

        # First half of the time features scales the input, second half shifts it.
        scale, shift = self.time_mlp(time).chunk(2, dim=1)

        modulated = x * (scale + 1) + shift
        fused = self.fuse_mlp(modulated) + cond

        projected = F.linear(fused, self.weight, self.bias)
        normed = self.norm(projected)

        return self.dropout(self.activation(normed))

class ConditionalSirenModuleNew(nn.Module):
    """Conditional SIREN-style layer with a GELU applied to the modulated input.

    `forward` scales and shifts the input with values predicted from `time`,
    applies GELU, adds a projection of `cond_emb`, then applies the linear
    map, the activation and an output dropout.

    Args:
        dim_in: input feature size; also sizes the time and condition MLPs.
        dim_out: output feature size of the linear map.
        dim_cond: feature size of `cond_emb` expected by the condition MLP.
        w0: factor for the default `Sine` activation and for the init bound.
        c: constant in the init bound `sqrt(c / dim_in) / w0`.
        is_first: when True the init bound is `1 / dim_in` instead.
        use_bias: whether a bias parameter is created.
        activation: replacement for the default `Sine(w0)`.
        dropout_rate: probability of the dropout applied to the layer output.
        emb_dropout_rate: probability of the dropout applied to the condition
            projection and to the fused hidden features.
    """

    def __init__(self, dim_in, dim_out, dim_cond, w0=1., c=6., is_first=False, use_bias=True, activation=None,
                 dropout_rate=False, emb_dropout_rate=False):
        super().__init__()
        self.dim_in = dim_in
        self.is_first = is_first
        self.dim_out = dim_out
        self.emb_dropout_rate = emb_dropout_rate
        self.dropout_rate = dropout_rate

        # Set by `toggle_condition`; NOTE(review): `forward` does not read this flag.
        self.condition_on = True

        weight = torch.zeros(dim_out, dim_in)
        bias = torch.zeros(dim_out) if use_bias else None
        self.init_(weight, bias, c=c, w0=w0)

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias) if use_bias else None
        self.mid_activation = nn.GELU()
        self.activation = Sine(w0) if activation is None else activation

        # The time MLP emits 2 * dim_in features: one half is the scale, the other the shift.
        time_dim = dim_in * 2
        cond_dim = dim_in

        sinu_pos_emb = SinusoidalPosEmb(dim_in, theta=100)
        fourier_dim = dim_in

        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        self.cond_mlp = nn.Sequential(
            nn.Linear(dim_cond, cond_dim),
            nn.GELU(),
            nn.Linear(cond_dim, cond_dim)
        )

        self.emb_dropout = nn.Dropout(self.emb_dropout_rate)

        self.dropout = nn.Dropout(self.dropout_rate)

    def init_(self, weight, bias, c, w0):
        """Fill `weight` (and `bias`, if given) in place from U(-w_std, w_std)."""
        dim = self.dim_in

        w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
        weight.uniform_(-w_std, w_std)

        if exists(bias):
            bias.uniform_(-w_std, w_std)

    def toggle_condition(self, on=True):
        """Store the conditioning flag in `self.condition_on`."""
        self.condition_on = on

    def forward(self, x, time, cond_emb):
        """Apply the modulated layer.

        Presumably `x` is (batch, dim_in), `time` is (batch,) and `cond_emb`
        is (batch, dim_cond) -- TODO confirm against callers.
        """
        t = self.time_mlp(time)
        # Use the registered dropout module so the configured `emb_dropout_rate`
        # is honoured and dropout is switched off by `.eval()`; a forced
        # `training=True` keeps inference stochastic.
        c = self.emb_dropout(self.cond_mlp(cond_emb))

        scale, shift = t.chunk(2, dim=1)

        h = self.emb_dropout(self.mid_activation(x * (scale + 1) + shift) + c)

        h = F.linear(h, self.weight, self.bias)

        h = self.activation(h)

        return self.dropout(h)

# siren network

class ConditionalSirenNet(nn.Module):
    """Stack of six conditional SIREN layers with additive skip connections.

    Layer widths shrink from `dim_hidden` to `dim_hidden // 8` and grow back to
    `dim_hidden`. Each layer is followed by `BatchNorm1d`, and a final
    `nn.Linear(dim_hidden, dim_hidden)` produces the output.

    Args:
        dim_hidden: feature size of the network input and output.
        dim_cond: feature size of the conditioning embedding.
        w0: `w0` for every layer except the first.
        w0_initial: `w0` for the first layer.
        use_bias: forwarded to every layer.
        dropout_rate: forwarded to every layer.
        emb_dropout_rate: forwarded to every layer.
    """

    def __init__(self, dim_hidden, dim_cond, w0 = 1., w0_initial = 30., use_bias = True, dropout_rate = 0.1, emb_dropout_rate = 0.3):
        super().__init__()
        self.dim_hidden = dim_hidden
        self.dim_cond_input = dim_cond

        self.d_ins = [dim_hidden, dim_hidden // 2, dim_hidden// 4, dim_hidden// 8, dim_hidden// 4, dim_hidden // 2]
        self.d_outs = [dim_hidden // 2, dim_hidden // 4, dim_hidden// 8, dim_hidden// 4, dim_hidden // 2, dim_hidden]
        # Per layer: index into the saved layer outputs to add to that layer's
        # input, or -1 for no skip connection.
        self.short_cuts = [-1, -1, -1, 2, 1, 0]

        self.layers = nn.ModuleList([])
        self.batch_norms = nn.ModuleList([])
        for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(self.d_ins, self.d_outs)):
            is_first = ind == 0
            layer_w0 = w0_initial if is_first else w0

            self.layers.append(ConditionalSirenModuleOld(
                dim_in = layer_dim_in,
                dim_out = layer_dim_out,
                dim_cond = dim_cond,
                w0 = layer_w0,
                use_bias = use_bias,
                is_first = is_first,
                dropout_rate = dropout_rate,
                emb_dropout_rate = emb_dropout_rate
            ))

            self.batch_norms.append(nn.BatchNorm1d(layer_dim_out))

        self.last_layer = nn.Linear(dim_hidden, dim_hidden)

    def forward(self, x, time, cond_emb):
        """Run the stack; `time` and `cond_emb` are passed unchanged to every layer."""
        # Normalised output of every layer, indexed by `short_cuts`.
        xs = []

        for short_cut_id, layer, norm in zip(self.short_cuts, self.layers, self.batch_norms):

            if short_cut_id != -1:
                # Out-of-place add: `x` is the tensor saved in `xs` by the
                # previous iteration, so `+=` would overwrite that saved
                # activation (and a tensor tracked by autograd) in place.
                # NOTE(review): for the first shortcut (index 2) `xs[2]` is `x`
                # itself, so this doubles the input -- confirm that is intended.
                x = x + xs[short_cut_id]

            x = layer(x, time, cond_emb)

            x = norm(x)

            xs.append(x)

        return self.last_layer(x)