# Standard imports
import argparse
import gc
import os
from pathlib import Path
from tqdm import tqdm
from sys import exit
import wandb
from matplotlib import pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn.utils
from torch.nn.utils import weight_norm
from functools import partial

'''
MODELS
******

 - Introduced
     + TODO: WaveNeuralField
     + Non-Periodic WaveRNN

 - Baselines
    + iRNN
    + WaveRNN

'''

from torch.nn import functional as F, init
import math
# Module-wide default device: CUDA when available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ==========
# INTRODUCED
# ==========

# Non-Periodic WaveRNN
# --------------------

class SoftmaxNormalizedConv1dCircular(nn.Module):
    """1-D convolution that can add an external input at the right boundary.

    NOTE: despite the class name, the weights are NOT softmax-normalized and the
    padding applied in ``forward`` is zero (``mode='constant'``), not circular.
    This matches the "Non-Periodic WaveRNN" model it is used in. The name is
    kept so existing callers and saved checkpoints keep working.

    The kernel is initialized to all zeros except a single 1.0 at
    ``weight[0, 0, (kernel_size // 2 + 1) % kernel_size]``. For an odd kernel
    with ``padding == kernel_size // 2``, output position ``i`` of channel 0
    initially copies input position ``i + 1``, i.e. a one-step shift toward
    index 0. All other channel pairs start at zero.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=False,
                 stride=1, padding=0, dilation=1, groups=1):
        """
        Parameters:
            in_channels (int): Number of channels in the input.
            out_channels (int): Number of channels produced by the convolution.
            kernel_size (int): Size of the convolution kernel.
            bias (bool): If truthy, adds a learnable bias (initialized to zero).
            stride (int): Stride of the convolution.
            padding (int): Amount of zero padding applied to both sides in forward().
            dilation (int): Spacing between kernel elements.
            groups (int): Number of blocked connections from input channels to output channels.

        Raises:
            ValueError: if ``in_channels`` is not divisible by ``groups``.
        """
        super().__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        # F.conv1d expects a weight of shape (out, in // groups, k).
        self.weight = nn.Parameter(
            torch.zeros(out_channels, in_channels // groups, kernel_size))
        self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None

        # Single unit tap one step right of centre. The modulo keeps the index
        # in range for kernel_size <= 2 (where k // 2 + 1 would overflow).
        with torch.no_grad():
            self.weight[0, 0, (kernel_size // 2 + 1) % kernel_size] = 1.0

    def forward(self, hy, x=None):
        """Zero-pad ``hy``, optionally add ``x`` at its last position, then convolve.

        Args:
            hy: signal to convolve. Channels on dim -2, positions on the last dim.
            x: optional value added to ``hy[..., -1]`` after padding, i.e. to
                the right-most padded cell (or the last real cell when
                ``padding == 0``). It must broadcast to ``hy[..., -1]``.

        Returns:
            The result of ``F.conv1d`` with this module's weight/bias and no
            extra padding.
        """
        if self.padding > 0:
            hy = F.pad(hy, pad=(self.padding, self.padding), mode='constant')

        if x is not None:
            # Build a new tensor instead of writing in place. An in-place add
            # would mutate the caller's tensor when padding == 0, and autograd
            # rejects that for leaves or tensors saved for backward.
            boundary = (hy[..., -1] + x).to(hy.dtype)
            hy = torch.cat([hy[..., :-1], boundary.unsqueeze(-1)], dim=-1)

        return F.conv1d(hy, self.weight, bias=self.bias, stride=self.stride,
                        padding=0, dilation=self.dilation, groups=self.groups)

# Activation name (the ``act`` argument of RNN_Cell) -> nn.Module class,
# instantiated with no arguments by the cell.
act_dict = dict(ident=nn.Identity, relu=nn.ReLU, tanh=nn.Tanh)
class RNN_Cell(nn.Module):
    """Single recurrence step of the non-periodic wave RNN.

    The input is projected to one value per channel by ``Wx``. The previous
    hidden state is viewed as ``(batch, n_ch, n_hid)`` and passed through the
    boundary-injecting convolution ``Wy`` together with that projection. The
    flattened result then goes through the activation.
    """

    def __init__(self, n_inp, n_ch=1, n_hid=10, act='ident', ksize=3):
        """
        Parameters:
            n_inp (int): Size of the per-step input vector.
            n_ch (int): Number of hidden channels.
            n_hid (int): Number of spatial positions per channel.
            act (str): Key into ``act_dict`` ('ident', 'relu' or 'tanh').
            ksize (int): Kernel size of the recurrent convolution; padding is
                ksize // 2, so odd sizes keep the length at n_hid.
        """
        super().__init__()
        self.n_hid = n_hid
        self.n_ch = n_ch

        self.act = act_dict[act]()

        # Input encoder: n_inp -> one scalar per channel, identity-initialized.
        self.Wx = nn.Linear(n_inp, self.n_ch, bias=False)
        nn.init.eye_(self.Wx.weight)

        # Recurrent spatial operator; it also receives the encoded input.
        self.Wy = SoftmaxNormalizedConv1dCircular(
            n_ch, n_ch, ksize, padding=ksize // 2, bias=None)

    def forward(self, x, hy):
        """Advance the hidden state by one step.

        Args:
            x: input for this step; its last dim is n_inp.
            hy: previous hidden state, reshaped here to (-1, n_ch, n_hid).

        Returns:
            The new hidden state flattened from dim 1 onward, after ``act``.
        """
        encoded = self.Wx(x)
        grid = hy.view(-1, self.n_ch, self.n_hid)
        mixed = self.Wy(grid, encoded)
        return self.act(mixed.flatten(start_dim=1))

class coRNN(nn.Module):
    """Non-periodic wave RNN: an ``RNN_Cell`` unrolled over time plus a linear readout."""

    def __init__(self, n_inp, n_out, n_hid, n_ch=1, act='ident', ksize=3):
        """
        Parameters:
            n_inp (int): Size of the per-step input vector.
            n_out (int): Size of the per-step output vector.
            n_hid (int): Number of spatial positions per hidden channel.
            n_ch (int): Number of hidden channels.
            act (str): Activation name passed to ``RNN_Cell``.
            ksize (int): Kernel size of the recurrent convolution.
        """
        super(coRNN, self).__init__()
        self.n_hid = n_hid
        self.n_ch = n_ch
        self.n_out = n_out
        self.cell = RNN_Cell(n_inp, n_ch, n_hid, act, ksize)

        # Bias-free, trainable readout from the flattened hidden state,
        # initialized to the (possibly rectangular) identity.
        self.readout = nn.Linear(self.n_hid * self.n_ch, n_out, bias=False)
        torch.nn.init.eye_(self.readout.weight)

    def forward(self, x, get_seq=False, n_ch_seq=3):
        """Run the recurrence over the whole sequence.

        Args:
            x: input sequence indexed time-first. ``x[t]`` is fed to the cell at
                step ``t`` and ``x.size(1)`` is the batch size (presumably
                shape (seq_len, batch, n_inp)).
            get_seq: if True, also record the hidden state after every step.
            n_ch_seq: currently unused; accepted for signature compatibility.

        Returns:
            ``(outputs, y_seq)``:
            - ``outputs`` stacks the per-step readouts along dim 0, giving
              shape (seq_len, batch, n_out).
            - ``y_seq`` is a detached CPU tensor of shape
              (seq_len, batch, n_ch, n_hid) when ``get_seq`` is True, and an
              empty list otherwise.
        """
        batch_size = x.size(1)
        # Zero initial state allocated on x's device and dtype. The old code
        # used the module-global `device` and float32, which breaks when the
        # model/input live elsewhere or use another dtype.
        hy = x.new_zeros(batch_size, self.n_hid * self.n_ch)

        y_seq = []
        outputs = []

        for x_t in x:
            hy = self.cell(x_t, hy)
            if get_seq:
                y_seq.append(hy.view(batch_size, self.n_ch, -1).detach().cpu())
            outputs.append(self.readout(hy))

        if get_seq:
            y_seq = torch.stack(y_seq, dim=0)
        outputs = torch.stack(outputs, dim=0)

        return outputs, y_seq


# =========
# BASELINES
# =========

# iRNN MODEL
# ----------

# from torch import nn
# import torch
# import numpy as np
# from torch.autograd import Variable
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# class RNN_Cell(nn.Module):
#     def __init__(self, n_inp, n_hid, n_ch=1, act='tanh', ksize=3, init='eye', freeze_rnn='no', freeze_encoder='no', solo_init='no'):
#         super(RNN_Cell, self).__init__()
#         self.n_hid = n_hid
#         self.Wx = nn.Linear(n_inp, n_hid * n_ch)
#         self.Wy = nn.Linear(n_hid * n_ch, n_hid * n_ch)

#         if solo_init == 'yes':
#             nn.init.zeros_(self.Wx.weight)
#             nn.init.zeros_(self.Wx.bias)
#             with torch.no_grad():
#                 w = self.Wx.weight.view(n_ch, n_hid, n_inp)
#                 w[:, 0] = 1.0
#         elif solo_init == 'no':
#             nn.init.normal_(self.Wx.weight, mean=0.0, std=0.001)
#         else:
#             raise NotImplementedError

#         if init == 'eye':
#             nn.init.eye_(self.Wy.weight)
#             nn.init.zeros_(self.Wy.bias)
#         elif init == 'fwd':
#             nn.init.eye_(self.Wy.weight)
#             nn.init.zeros_(self.Wy.bias)
#             with torch.no_grad():
#                 self.Wy.weight.data = torch.roll(self.Wy.weight, 1, -1).data
#         elif init =='rand':
#             pass
#         else:
#             raise NotImplementedError

#         if act == 'tanh':
#             self.act = nn.Tanh()
#         elif act == 'relu':
#             self.act = nn.ReLU()
#         elif act == 'ident':
#             self.act = nn.Identity()
#         else:
#             raise NotImplementedError

#         if freeze_encoder == 'yes':
#             for param in self.Wx.parameters():
#                 param.requires_grad = False
#         else:
#             assert freeze_encoder == 'no'

#         if freeze_rnn == 'yes':
#             for param in self.Wy.parameters():
#                 param.requires_grad = False
#         else:
#             assert freeze_rnn == 'no'

#     def forward(self,x,hy):
#         hy = self.act(self.Wx(x) + self.Wy(hy))
#         return hy


# class coRNN(nn.Module):
#     def __init__(self, n_inp, n_hid, n_out, n_ch=1, act='relu', ksize=3, init='eye', freeze_rnn='no', freeze_encoder='no', solo_init='no'):
#         super(coRNN, self).__init__()
#         self.n_hid = n_hid
#         self.n_ch = n_ch
#         self.spatial = int(np.sqrt(n_hid))
#         self.cell = RNN_Cell(n_inp, n_hid, n_ch, act, ksize,  init, freeze_rnn, freeze_encoder, solo_init)
#         self.readout = nn.Linear(self.n_hid * self.n_ch, n_out)

#     def forward(self, x, get_seq=False):
#         # print(x.shape)
#         ## initialize hidden states
#         hy = Variable(torch.zeros(x.size(1), self.n_hid * self.n_ch)).to(device)
#         y_seq = []
#         outputs = []

#         for t in range(x.size(0)):
#             hy = self.cell(x[t], hy)
#             if get_seq:
#                 y_seq.append(hy.squeeze().view(x.size(1), self.n_ch, -1).detach().cpu())
#             output = self.readout(hy)
#             outputs.append(output)

#         if get_seq:
#             y_seq = torch.stack(y_seq, dim=0)
#         outputs = torch.stack(outputs, dim=0)

#         return outputs, y_seq


# =============
# WaveRNN MODEL
# =============

# from torch import nn
# import torch
# import numpy as np
# from torch.autograd import Variable
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# class RNN_Cell(nn.Module):
#     def __init__(self, n_inp, n_hid, n_ch=1, act='tanh', ksize=3, init='eye', freeze_rnn='no', freeze_encoder='no', solo_init='no'):
#         super(RNN_Cell, self).__init__()
#         self.n_hid = n_hid
#         self.n_ch = n_ch
#         self.Wx = nn.Linear(n_inp, n_hid * n_ch)
#         self.Wy = nn.Conv1d(n_ch, n_ch, ksize, padding=ksize//2, padding_mode='circular')

#         if solo_init == 'yes':
#             nn.init.zeros_(self.Wx.weight)
#             nn.init.zeros_(self.Wx.bias)
#             with torch.no_grad():
#                 w = self.Wx.weight.view(n_ch, n_hid, n_inp)
#                 w[:, 0] = 1.0
#         elif solo_init == 'no':
#             nn.init.normal_(self.Wx.weight, mean=0.0, std=0.001)
#         else:
#             raise NotImplementedError

#         if act == 'tanh':
#             self.act = nn.Tanh()
#         elif act == 'relu':
#             self.act = nn.ReLU()
#         elif act == 'ident':
#             self.act = nn.Identity()
#         else:
#             raise NotImplementedError

#         assert init in ['eye', 'fwd', 'rand']

#         if init == 'eye' or init == 'fwd':
#             wts = torch.zeros(n_ch, n_ch, ksize)
#             nn.init.dirac_(wts)
#         if init == 'fwd':
#             wts = torch.roll(wts, 1, -1)

#         if init == 'eye' or init == 'fwd':
#             with torch.no_grad():
#                 self.Wy.weight.copy_(wts)

#         if freeze_encoder == 'yes':
#             for param in self.Wx.parameters():
#                 param.requires_grad = False
#         else:
#             assert freeze_encoder == 'no'

#         if freeze_rnn == 'yes':
#             for param in self.Wy.parameters():
#                 param.requires_grad = False
#         else:
#             assert freeze_rnn == 'no'

#     def forward(self,x,hy):
#         hy = self.act(self.Wx(x) + self.Wy(hy.view(-1, self.n_ch, self.n_hid)).flatten(start_dim=1))
#         return hy

# class coRNN(nn.Module):
#     def __init__(self, n_inp, n_hid, n_out, n_ch=1, act='relu', ksize=3, init='eye', freeze_rnn='no', freeze_encoder='no', solo_init='no'):
#         super(coRNN, self).__init__()
#         self.n_hid = n_hid
#         self.n_ch = n_ch
#         self.spatial = int(np.sqrt(n_hid))
#         self.cell = RNN_Cell(n_inp, n_hid, n_ch, act, ksize, init, freeze_rnn, freeze_encoder, solo_init)
#         self.readout = nn.Linear(self.n_hid * self.n_ch, n_out)

#     def forward(self, x, get_seq=False, n_ch_seq=3):
#         ## initialize hidden states
#         hy = Variable(torch.zeros(x.size(1), self.n_hid * self.n_ch)).to(device)
#         y_seq = []
#         outputs = []

#         for t in range(x.size(0)):
#             hy = self.cell(x[t], hy)
#             if get_seq:
#                 y_seq.append(hy.view(x.size(1), self.n_ch, -1)[:, :n_ch_seq].detach().cpu())
#             output = self.readout(hy)
#             outputs.append(output)

#         if get_seq:
#             y_seq = torch.stack(y_seq, dim=0)
#         outputs = torch.stack(outputs, dim=0)

#         return outputs, y_seq
