import torch
import torch.nn as nn
import numpy as np
from scipy import signal
import matplotlib.pyplot as plt
import os



from argparse import ArgumentParser

# Lorenz system generator
def sample_lorenz_data(n, device="cpu", dt=0.01, sigma=10, beta=8/3, rho=28):
    """Integrate the Lorenz system with forward-Euler steps.

    The trajectory starts at (x, y, z) = (0, 1, 0) and is advanced
    ``n - 1`` times with step size ``dt``.

    Args:
        n: Number of samples to produce, including the initial state.
        device: Torch device on which the returned tensor is created.
        dt: Euler step size.
        sigma: Lorenz parameter scaling the (y - x) term.
        beta: Lorenz parameter scaling the z decay term.
        rho: Lorenz parameter in the (rho - z) term.

    Returns:
        A ``torch.float32`` tensor of shape ``(n, 3)`` whose columns are
        x, y and z.
    """
    # One row per time step, one column per coordinate (x, y, z).
    traj = np.zeros((n, 3))
    traj[0] = (0, 1, 0)

    for step in range(1, n):
        # Every update reads only the previous state (explicit Euler).
        px, py, pz = traj[step - 1]
        traj[step, 0] = px + dt * sigma * (py - px)
        traj[step, 1] = py + dt * (px * (rho - pz) - py)
        traj[step, 2] = pz + dt * (px * py - beta * pz)

    return torch.tensor(traj, device=device, dtype=torch.float32)

# R2 function
def r2_score(y_pred, y_true):
    """Return the coefficient of determination (R^2) as a Python float.

    Computed as ``1 - SS_res / SS_tot``, where ``SS_res`` is the sum of
    squared differences between ``y_true`` and ``y_pred`` and ``SS_tot`` is
    the sum of squared deviations of ``y_true`` from its mean taken over
    all elements.

    Args:
        y_pred: Predicted values (tensor).
        y_true: Ground-truth values (tensor) with the same shape as
            ``y_pred``.

    Returns:
        The R^2 value extracted with ``.item()``.
    """
    assert y_true.shape == y_pred.shape, "Shape of y_true and y_pred must be the same"

    residual = y_true - y_pred
    # Deviation from the mean over every element of y_true.
    deviation = y_true - torch.mean(y_true)

    ss_res = torch.sum(residual ** 2)
    ss_tot = torch.sum(deviation ** 2)

    return (1 - ss_res / ss_tot).item()

# Define model class
class NonAutRNN(nn.Module):
    """Single-layer ``nn.RNN`` followed by a linear readout.

    Args:
        in_out_size: Feature size of both the input and the output.
        width: Hidden size of the recurrent layer.
        nonlinearity: Nonlinearity passed to ``nn.RNN``.
        fully_train: When true, the recurrent weights, the hidden bias and
            the readout weight get ``requires_grad=True`` and the readout
            has a bias term. When false, those parameters get
            ``requires_grad=False`` and the readout has no bias. The
            ``requires_grad`` flag of the input bias is not modified here.
    """

    def __init__(
        self,
        in_out_size=1,
        width=512,
        nonlinearity="relu",
        fully_train=False
    ):
        super().__init__()

        # Recurrent layer, batch-first: inputs are (batch, seq, feature).
        self.recurrent = nn.RNN(
            input_size=in_out_size,
            hidden_size=width,
            nonlinearity=nonlinearity,
            batch_first=True,
        )
        # Linear readout from the hidden state back to the input size.
        self.fc = nn.Linear(width, in_out_size, bias=fully_train)

        rnn = self.recurrent

        # Weight matrices: zero-mean normal with std 1/sqrt(size of dim 1).
        # The iteration order is kept fixed so seeded draws are reproducible.
        for weight in (rnn.weight_hh_l0, rnn.weight_ih_l0, self.fc.weight):
            fan_in = weight.size(1)
            nn.init.normal_(weight, std=1 / fan_in ** 0.5)
            weight.requires_grad_(fully_train)

        # Hidden-to-hidden bias starts at zero.
        nn.init.zeros_(rnn.bias_hh_l0)
        rnn.bias_hh_l0.requires_grad_(fully_train)

        # Input bias is drawn uniformly from [0, 1].
        nn.init.uniform_(rnn.bias_ih_l0, 0, 1)

    def forward(
        self,
        X,
        h=None
    ):
        """Run the RNN over ``X`` and apply the readout to every time step.

        Args:
            X: Input sequence passed to the recurrent layer.
            h: Optional initial hidden state forwarded to ``nn.RNN``.

        Returns:
            A tuple ``(out, hf)``: the readout of the recurrent layer's
            output sequence, and the final hidden state from ``nn.RNN``.
        """
        hidden_seq, hf = self.recurrent(X, h)
        return self.fc(hidden_seq), hf

if __name__ == "__main__":
    # Train NonAutRNN to predict the x-coordinate of a Lorenz trajectory
    # tau steps ahead, checkpointing whenever the test R2 improves.

    # Import command line arguments
    parser = ArgumentParser()
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--epochs", type=int, default=10000)
    args = parser.parse_args()

    # Set seed and device
    seed = args.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set parameters
    tau = 27                # prediction horizon, in time steps
    seq_len = tau * 40      # model input length per sequence
    dt = 0.01
    system = "lorenz"
    n_test = 1              # number of held-out sequences
    batch_size = 64         # number of training sequences
    width = 1024
    fully_train = True
    log_every = 100         # evaluate / print / checkpoint interval

    # Generate one long Lorenz trajectory, keep only the x-coordinate and cut
    # it into consecutive windows of shape
    # (batch_size + n_test, seq_len + tau, 1). The reshape is done directly
    # on the tensor, so no device -> numpy -> device round trip is needed.
    trajectory = sample_lorenz_data(
        n=(batch_size + n_test) * (seq_len + tau),
        device=device,
        dt=dt
    )
    sig = trajectory[:, 0].reshape(
        batch_size + n_test, seq_len + tau, 1
    ).contiguous()

    X_train = sig[:-n_test]
    X_test = sig[-n_test:]

    # Inputs drop the last tau steps; targets are the same windows shifted
    # forward by tau. These slices never change, so build them once.
    train_inputs, train_targets = X_train[:, :-tau], X_train[:, tau:]
    test_inputs, test_targets = X_test[:, :-tau], X_test[:, tau:]

    # Initialize model
    model = NonAutRNN(
        width=width,
        fully_train=fully_train
    ).to(device)

    epochs = args.epochs
    lr = 0.0001

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad],
        lr=lr,
        weight_decay=0.1
    )

    highest_r2 = -float("inf")

    for epoch in range(epochs):
        should_log = epoch % log_every == 0

        # The validation loss is only reported on logging epochs, so only
        # compute it then. It is measured before the optimizer step (same
        # weights as the reported train loss) and without autograd tracking,
        # since it is never back-propagated.
        if should_log:
            with torch.no_grad():
                val_loss = criterion(model(test_inputs)[0], test_targets)

        optimizer.zero_grad()

        # Predict tau-shifted window and compute loss with ground-truth
        pred = model(train_inputs)[0]
        loss = criterion(pred, train_targets)
        loss.backward()

        optimizer.step()

        if should_log:
            # R2 is measured after the step, i.e. on the weights that would
            # be saved below.
            with torch.no_grad():
                curr_r2 = r2_score(model(test_inputs)[0], test_targets)

            print(f"Epoch {epoch}: train loss = {loss.item():.4f}, val loss = {val_loss.item():.4f}, r2 = {curr_r2:.4f},")

            # If R2 is the highest so far, save model
            if curr_r2 > highest_r2:
                # NOTE: this pickles the whole module, so loading the file
                # requires the NonAutRNN class to be importable.
                torch.save(model, f"{system}_s{seed}_full.pt")
                highest_r2 = curr_r2

    print("Training complete!")

