import argparse
import torch
from torch.utils.data import DataLoader
import numpy as np

from helper import *
from NN import *
from PBNN import *

# Which dataset to train on; the loader below accepts 'ODE', 'Lorent', 'Hole' or 'ADM'.
DATASET = 'ODE'

# Device: use the first CUDA GPU when one is present, otherwise fall back to the
# CPU so the script does not crash on machines without CUDA.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
    print(f"GPU:{device}")
else:
    device = torch.device("cpu")
    print(f"GPU unavailable, falling back to {device}")

# Loss functions:
#   mse      -- squared error averaged over every element (scalar result)
#   mse_none -- element-wise squared error, left unreduced
mse, mse_none = nn.MSELoss(), nn.MSELoss(reduction='none')

# Loading Data: map each supported dataset name to its .npy file. Later code
# reads Z.real / Z.imag, so the stored array is expected to be complex-valued.
_DATASET_FILES = {
    'ODE': './datasets/Z_ODE.npy',
    'Lorent': './datasets/Z_Lorent_4.npy',
    'Hole': './datasets/Z_holesinsi.npy',
    'ADM': './datasets/Z_ADM.npy',
}
if DATASET not in _DATASET_FILES:
    raise NotImplementedError('No Such Dataset')
Z = np.load(_DATASET_FILES[DATASET])

Z = torch.tensor(Z)
INPUT_DIM = Z.shape[1]
# The NN baseline's target concatenates the real and imaginary parts along
# dim 1, so it needs twice as many outputs as inputs.
OUTPUT_DIM = 2 * INPUT_DIM

# Preparation of Data: magnitude, phase (cos/sin) and Cartesian views of Z.
magZ = torch.abs(Z)  # |Z|
realZ = Z.real
imagZ = Z.imag
# cos/sin of the phase of Z. NOTE: entries with |Z| == 0 become NaN here.
z_cos = realZ / magZ
z_sin = imagZ / magZ
z_real = realZ
z_imag = imagZ

# Permute the rows with a fixed seed; the first 20% form the training pool and
# the remaining 80% the test set.
np.random.seed(42)
indices = np.random.permutation(len(magZ))
n_train = len(magZ) // 10 * 2
train_indices = indices[:n_train]
test_indices = indices[n_train:]

train_magZ = magZ[train_indices].float()
train_zcos = z_cos[train_indices]
train_zsin = z_sin[train_indices]
test_magZ = magZ[test_indices].float()
test_zcos = z_cos[test_indices]
test_zsin = z_sin[test_indices]

train_zreal = z_real[train_indices]
train_zimag = z_imag[train_indices]
test_zreal = z_real[test_indices]
test_zimag = z_imag[test_indices]

# Few-shot regime: draw NUM_SAMPLE *distinct* rows from the training pool.
# np.random.choice samples with replacement by default, which could silently
# duplicate rows and shrink the effective training set; replace=False
# guarantees unique indices (and raises if NUM_SAMPLE exceeds the pool size).
np.random.seed(42)
NUM_SAMPLE = 10
train_samples = np.random.choice(len(train_magZ), NUM_SAMPLE, replace=False)

trloader = DataLoader(
    MyDataset2(
        train_magZ[train_samples],
        train_zreal[train_samples].float(),
        train_zimag[train_samples].float(),
    ),
    batch_size=256,
    shuffle=True,
)
teloader = DataLoader(
    MyDataset2(test_magZ, test_zreal.float(), test_zimag.float()),
    batch_size=512,
    shuffle=False,
)

## Train NN
# Baseline: an MLP built by `seq_maker`, trained to regress the concatenated
# (real, imag) targets from the loader's first field (presumably |Z| — it is
# the first argument given to MyDataset2).
# Example hyper-parameters
num_hidden = 256
num_layer = 2
dropout = 0.1
net = seq_maker(num_hidden, num_layer, OUTPUT_DIM, INPUT_DIM, dropout=dropout)
net.to(device)
optimizer = optim.Adam(net.parameters(), lr=1e-4)

# Best test loss observed after any epoch. Start from +inf (not 1) so the
# reported minimum is always a loss that was actually measured.
min_test_loss = float('inf')
for _ in range(6000):
    net.train()
    for x, y1, y2 in trloader:
        optimizer.zero_grad()  # zero the gradient buffers
        result = net(x.to(device))
        target = torch.cat([y1, y2], dim=1).to(device)
        loss = mse(result, target)
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        net.eval()
        # Train loss is computed for monitoring only; it is not used below.
        train_loss = eval_loader_nn(net, trloader, device)
        test_loss = eval_loader_nn(net, teloader, device)
        if test_loss < min_test_loss:
            min_test_loss = test_loss
print('Minimum NN Test Loss:', min_test_loss)

## Train PBNN
# Example hyper-parameters
num_hidden = 256
num_layer = 2
dropout = 0.1
model = PBNN(dim=num_hidden, n_root=5, num_seq=1, dropout=dropout, num_hidden=num_layer)
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Reset the running minimum for this model. Previously `min_test_loss` was
# inherited from the NN run above, so the reported PBNN minimum could actually
# be the NN's loss.
min_test_loss = float('inf')

for epoch in range(6000):
    model.train()
    for tx, tyreal, tyimag in trloader:
        tx = tx.to(device)
        optimizer.zero_grad()
        # NOTE(review): the model output is scaled by the input and its
        # .real/.imag are fitted to the targets, so PBNN presumably returns a
        # complex phase factor — confirm against the PBNN implementation.
        predict = model(tx) * tx
        loss = mse(predict.real, tyreal.to(device)) + mse(predict.imag, tyimag.to(device))
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        model.eval()
        trainloss = eval_loader(model, trloader, device)
        testloss = eval_loader(model, teloader, device)
        # print(f'epoch:{epoch},train loss:{trainloss}, test loss:{testloss}')
        if testloss < min_test_loss:
            min_test_loss = testloss
print('Minimum BPNN Test Loss:', min_test_loss)
