import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
torch.manual_seed(123)

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

def construct_w_poly(w,n_root):
  '''
  Build the power table [w**0, w**1, ..., w**n_root] for a set of points.

  Input:  (1) w: tensor of evaluation points (1-D in this file, from np.linspace)
          (2) n_root: highest power to include
  Output: (1) tensor of size num_freq * (n_root+1); column n holds w**n
          (2) a second tensor holding the same values (independent copy)

  NOTE(review): the second table is named as if it held conjugated powers,
  but its values are plain w**n. For real w (the only use in this file) the
  two coincide; confirm intent before relying on it for complex w.
  '''
  powers = torch.stack([w ** n for n in range(n_root + 1)], dim=1)
  # Computed once; cloned so the two returned tensors do not share storage.
  return powers, powers.clone()

def eval_poly(real,imag,w,num_spectra,conj=False):
  '''
  Evaluate a batch of monic polynomials on a precomputed power table.

  Coefficients are a_k = real[:, k] + 1j*imag[:, k] (complex-conjugated when
  conj=True) for k = 0..n_root-1. An implicit leading coefficient of 1 is
  appended for the highest power, and the result is sum_k a_k * w[..., k].

  Input:  (1) real, imag: (num_batch, n_root) tensors of coefficient parts
          (2) w: power table with n_root+1 columns, e.g. the first output
              of construct_w_poly (rows = evaluation points)
          (3) num_spectra: number of evaluation points (rows of w)
          (4) conj: if True, use the conjugated coefficients
  Output: (num_batch, num_spectra) complex tensor of polynomial values

  The computation runs on the coefficients' device; w is moved there.
  '''
  a = real - 1j*imag if conj else real + 1j*imag
  num_batch, n_root = a.shape

  # Follow the coefficients' device rather than a module-level `device`
  # global, which this function would otherwise require to exist.
  w = w.to(a.device)
  a = a.unsqueeze(1).expand(num_batch, num_spectra, n_root)
  # Leading coefficient 1 for the w**n_root term (monic polynomial).
  leading = torch.ones(num_batch, num_spectra, 1, dtype=a.dtype, device=a.device)
  a = torch.cat([a, leading], dim=-1)
  # w broadcasts over the batch dimension; no explicit expand needed.
  return (a * w).sum(dim=-1)

def break_size_list(length,num_seq):
  '''
  Split `length` items into `num_seq` approximately equal chunk sizes.

  Every chunk gets length // num_seq items and the remainder
  (length % num_seq) is added to the last chunk, so the sizes always sum to
  `length`. The result can be passed directly to torch.split.

  Input:  (1) length: total number of items
          (2) num_seq: number of chunks desired (must be >= 1)
  Output: (1) list of `num_seq` chunk sizes (not cumulative positions)
  Raises: ValueError if num_seq < 1
  Example:  Input: (2001,5)
            Output: [400,400,400,400,401]
  '''
  if num_seq < 1:
    raise ValueError(f"num_seq must be >= 1, got {num_seq}")
  base, remainder = divmod(length, num_seq)
  sizes = [base] * num_seq
  sizes[-1] += remainder
  return sizes

class PBNN(nn.Module):
  '''
  MLP mapping an input vector to a product of per-sub-band complex factors.

  For each of `num_seq` sub-bands the network predicts n_root real parts,
  n_root imaginary parts and one phase `theta`, and returns
  exp(1j*theta) * p(w) / p_conj(w) for every w in that sub-band of the
  frequency grid, where p is the monic polynomial evaluated by eval_poly.
  Because the grid is real, p_conj(w) equals conj(p(w)), so every output has
  modulus 1 up to rounding (NaN/inf where p(w) == 0).

  Args:
    dim: hidden layer width.
    n_root: number of free polynomial coefficients per sub-band.
    num_seq: number of sub-bands the frequency grid is split into.
    input_dim: size of the input vector fed to the first Linear layer.
    dropout: dropout probability applied after every ReLU.
    num_hidden: number of extra Linear(dim, dim) hidden layers.
    num_freq, w_min, w_max: frequency grid np.linspace(w_min, w_max, num_freq).
      Defaults reproduce the previously hard-coded grid (2001 points on [1, 5]).
  '''
  def __init__(self,dim=512,n_root=10,num_seq=10,input_dim=2001,dropout=0.,num_hidden=2,
               num_freq=2001,w_min=1.,w_max=5.):
    super().__init__()
    self.nr = n_root
    self.num_seq = num_seq

    # Frequency grid (float64), split into num_seq contiguous sub-bands.
    grid = torch.tensor(np.linspace(w_min, w_max, num_freq))
    w_list = torch.split(grid, break_size_list(num_freq, num_seq))
    # Power tables are non-persistent buffers: model.to(device) moves them
    # with the parameters, yet they stay out of state_dict so existing
    # checkpoints still load.
    # NOTE(review): dtype casts such as model.float() now also cast these
    # float64 tables; confirm that is acceptable.
    for i, w in enumerate(w_list):
      self.register_buffer(f'cpw_{i}', construct_w_poly(w, n_root)[0], persistent=False)

    layers = [nn.Linear(input_dim, dim), nn.ReLU(), nn.Dropout(p=dropout)]
    for _ in range(num_hidden):
      layers += [nn.Linear(dim, dim), nn.ReLU(), nn.Dropout(p=dropout)]
    # Per sub-band: n_root real parts, n_root imaginary parts, one phase.
    layers.append(nn.Linear(dim, num_seq*(n_root*2+1)))
    self.nn = nn.Sequential(*layers)

  @property
  def cpw_list(self):
    '''Power tables of the sub-bands, in grid order (one per sub-band).'''
    return [getattr(self, f'cpw_{i}') for i in range(self.num_seq)]

  def forward(self,x):
    '''
    Input:  x: (batch, input_dim) tensor
    Output: (batch, num_freq) complex tensor; sub-band results concatenated
            along dim 1 in grid order
    '''
    x = self.nn(x)
    width = self.nr*2 + 1

    b_list = []
    for i, cpw in enumerate(self.cpw_list):
      temp = x[:, i*width:(i+1)*width]
      a0_real = temp[:, :self.nr]
      a0_imag = temp[:, self.nr:2*self.nr]
      theta = temp[:, 2*self.nr:2*self.nr+1]

      num_spectra = len(cpw)
      p1 = eval_poly(a0_real, a0_imag, cpw, num_spectra)
      p1c = eval_poly(a0_real, a0_imag, cpw, num_spectra, conj=True)
      # theta is (batch, 1) and broadcasts across the sub-band's points.
      b_list.append(torch.exp(1j*theta) * (p1/p1c))

    return torch.cat(b_list, dim=1)