import gc
import random
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.utils.data as Data
from torch.utils.data import Dataset

class ThreeBodyDataset(Dataset):
    """Map-style dataset built from three ``.npy`` files: times, inputs and targets.

    The loaded time array is rewritten along axis 1 into step increments:
    column ``k`` (``k >= 1``) becomes ``t[:, k] - t[:, k - 1]`` of the loaded
    values, and column 0 is set to 1. Presumably the file stores cumulative
    timestamps -- confirm against the data generator.

    Every array is converted with ``torch.Tensor`` (the default float dtype),
    and item ``i`` is the tuple ``(t[i], X[i], y[i])``.
    """

    def __init__(self, t_path: str, X_path: str, y_path: str):
        stamps = torch.Tensor(np.load(t_path))
        # The right-hand side is materialised before the assignment, so each
        # difference is taken between original (not yet overwritten) columns.
        stamps[:, 1:] = stamps[:, 1:] - stamps[:, :-1]
        stamps[:, 0] = 1
        self.t = stamps
        self.X = torch.Tensor(np.load(X_path))
        self.y = torch.Tensor(np.load(y_path))

    def __getitem__(self, index):
        times, inputs, targets = self.t, self.X, self.y
        return times[index], inputs[index], targets[index]

    def __len__(self):
        return self.t.size(0)

class CTRNN(nn.Module):
    """Continuous-time RNN integrated with explicit Euler steps, plus an autoregressive rollout.

    The hidden-state derivative is
    ``scale * tanh(kernel(x) + recurrent_kernel(h)) - tau * h``; the ``- tau * h``
    term is omitted when ``tau <= 0``. Each observed step ``i`` is integrated
    over its own increment ``t[:, i]``, split into ``num_unfolds`` equal Euler
    sub-steps. The decoded state is then fed back as the next input for
    ``pred_n - 1`` further steps of length 1.

    Args:
        device: stored as ``self.device`` for backward compatibility;
            ``forward`` allocates its temporaries on the device of its inputs.
        input_size: feature size of each step of ``x``.
        hidden_size: width of the hidden state.
        output_size: decoder width. Must equal ``input_size`` when
            ``pred_n > 1``, because predictions are fed back as inputs.
        pred_n: number of decoded steps returned by ``forward``.
        num_unfolds: Euler sub-steps per time step.
        tau: coefficient of the leak term (see the note in ``dfdt``).
    """

    def __init__(self, device, input_size: int, hidden_size: int, output_size: int, pred_n: int = 3, num_unfolds: int = 3, tau: int = 1):
        super().__init__()
        self.device = device
        self.units = hidden_size
        self.state_size = hidden_size
        self.num_unfolds = num_unfolds
        self.tau = tau
        self.pred_n = pred_n

        # Input projection (with bias) and recurrent projection (no bias).
        self.kernel = nn.Linear(input_size, hidden_size)
        self.recurrent_kernel = nn.Linear(hidden_size, hidden_size, bias=False)
        # Learnable per-unit gain applied after the tanh.
        self.scale = nn.Parameter(torch.ones(hidden_size))

        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, t, x):
        """Encode the observed sequence, then roll out the decoded predictions.

        Args:
            t: (batch_size, n_take) time increment of every observed step.
            x: (batch_size, n_take, input_size) observed inputs.

        Returns:
            Tensor of shape (batch_size, max(pred_n, 1), output_size).
        """
        batch = t.size(0)
        # Allocate on the inputs' device/dtype rather than the constructor's
        # ``device``, which goes stale if the module is moved afterwards.
        hidden_state = x.new_zeros((batch, self.units))
        for i in range(t.size(1)):
            delta_t = t[:, i] / self.num_unfolds
            for _ in range(self.num_unfolds):
                hidden_state = self.euler(x[:, i, :], hidden_state, delta_t)
        preds = [self.decoder(hidden_state).unsqueeze(dim=1)]

        # NOTE(review): rollout steps use a fixed increment of 1 -- confirm this
        # matches the sampling of the prediction targets.
        rollout_dt = t.new_ones(batch) / self.num_unfolds
        for _step in range(1, self.pred_n):
            # squeeze(1) removes only the step axis. A bare squeeze() also
            # dropped the batch axis (batch_size == 1) or the feature axis
            # (output_size == 1), which breaks the input projection.
            feedback = preds[-1].squeeze(1)
            for _ in range(self.num_unfolds):
                hidden_state = self.euler(feedback, hidden_state, rollout_dt)
            preds.append(self.decoder(hidden_state).unsqueeze(dim=1))
        return torch.cat(preds, dim=1)

    def dfdt(self, inputs, hidden_state):
        """Return dh/dt for ``inputs`` (batch, input_size) and ``hidden_state`` (batch, hidden)."""
        dh_in = self.scale * (self.kernel(inputs) + self.recurrent_kernel(hidden_state)).tanh()
        if self.tau > 0:
            # NOTE(review): the usual CT-RNN leak is ``-h / tau``. Here h is
            # *multiplied* by tau; the two agree only for the default tau == 1.
            # Kept unchanged to preserve behaviour -- confirm the intent.
            return dh_in - hidden_state * self.tau
        return dh_in

    def euler(self, inputs, hidden_state, delta_t):
        """Advance ``hidden_state`` one explicit Euler step with per-sample step size ``delta_t`` (shape (batch,))."""
        return hidden_state + delta_t.unsqueeze(-1) * self.dfdt(inputs, hidden_state)

def _seed_everything(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # Must be *called*: assigning ``torch.use_deterministic_algorithms = True``
    # replaces the function with a bool and enables nothing. warn_only=True
    # (PyTorch >= 1.11) turns ops lacking a deterministic implementation (e.g.
    # cuBLAS without CUBLAS_WORKSPACE_CONFIG) into warnings instead of errors.
    torch.use_deterministic_algorithms(True, warn_only=True)


def _make_loader(split: str, shuffle: bool):
    """Build a batch-size-1 DataLoader over irregular_three_body/{split}_{t,x,y}.npy."""
    dataset = ThreeBodyDataset(
        t_path=f'irregular_three_body/{split}_t.npy',
        X_path=f'irregular_three_body/{split}_x.npy',
        y_path=f'irregular_three_body/{split}_y.npy'
    )
    return Data.DataLoader(
        dataset=dataset,
        batch_size=1,
        shuffle=shuffle,
        num_workers=0
    )


def _save_loss_csv(values, column: str, stem: str) -> None:
    """Write ``values`` as a one-column CSV to ``{stem}.csv``, falling back to ``{stem}_1.csv``."""
    frame = pd.DataFrame({column: values})
    try:
        frame.to_csv(f'{stem}.csv')
    except OSError:
        # e.g. the file is locked by another program; keep the curve anyway.
        print(f'Fail to save the file {stem}.csv')
        frame.to_csv(f'{stem}_1.csv')


def main(version, hidden_size):
    """Train and evaluate one CTRNN on the irregular three-body data.

    Samples are fed one at a time, and gradients are accumulated over
    ``batch_size`` samples per optimizer step. After every epoch the model is
    pickled to ``{version}.pkl`` and the loss curves are written to
    ``{version}_Train.csv`` and ``{version}_Test.csv``.

    Args:
        version: tag used in log lines and output file names.
        hidden_size: hidden width of the CTRNN.
    """
    _seed_everything(42)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    train_loader = _make_loader('train', shuffle=True)
    test_loader = _make_loader('test', shuffle=False)

    print(device)

    model = CTRNN(device, input_size=9, hidden_size=hidden_size, output_size=9, pred_n=3).to(device)

    batch_size = 128  # samples accumulated per optimizer step
    num_epochs = 100
    n_train = len(train_loader)
    # The trailing partial batch is also a step, so the cosine schedule spans
    # exactly the number of scheduler.step() calls and ends at eta_min.
    steps_per_epoch = (n_train + batch_size - 1) // batch_size

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(1, num_epochs * steps_per_epoch), eta_min=0.00005, last_epoch=-1
    )

    criterion = nn.MSELoss()

    train_loss = []
    test_loss = []

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        model.train()
        optimizer.zero_grad()  # never carry gradients over from the previous epoch
        batch_loss = 0.0
        n_accum = 0
        for i, (t, x, y) in enumerate(tqdm(train_loader)):
            t, x, y = t.to(device), x.to(device), y.to(device)
            output = model(t, x)
            loss = criterion(output.view(-1), y.view(-1))
            loss.backward()
            batch_loss += loss.item()
            n_accum += 1

            # Step after exactly `batch_size` samples; the old
            # `i % batch_size == 0 and i != 0` test made the first step use 129.
            # The final partial batch is flushed rather than leaked into the
            # next epoch.
            if n_accum == batch_size or i == n_train - 1:
                print(' ', version, 'Epoch {}/{}'.format(epoch + 1, num_epochs))
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1, norm_type=2)
                optimizer.step()
                optimizer.zero_grad()
                train_loss.append(batch_loss / n_accum)
                print(' ', train_loss[-1])
                batch_loss = 0.0
                n_accum = 0
                scheduler.step()

        epoch_loss = 0
        model.eval()
        with torch.no_grad():
            for t, x, y in tqdm(test_loader):
                t, x, y = t.to(device), x.to(device), y.to(device)
                output = model(t, x)
                loss = criterion(output.view(-1), y.view(-1))
                epoch_loss += loss.item()
        test_loss.append(epoch_loss / max(len(test_loader), 1))
        print(' ', test_loss[-1])

        torch.save(model, f'{version}.pkl')
        _save_loss_csv(train_loss, 'Train Loss', f'{version}_Train')
        _save_loss_csv(test_loss, 'Test Loss', f'{version}_Test')

    # Drop the large references first; gc.collect() frees nothing while they
    # are still bound. This matters because several runs share one process.
    del model, optimizer, scheduler, train_loader, test_loader
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

if __name__ == '__main__':
    # Sweep the hidden width; each run writes its artefacts under its own tag.
    for hidden_size in (512, 1024, 2048):
        main(f'CTRNN_IrrBody_{hidden_size}', hidden_size)