import gc
import random
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.utils.data as Data
from torch.utils.data import Dataset

class ThreeBodyDataset(Dataset):
    """Map-style dataset of ``(t, X, y)`` triples loaded from three ``.npy`` files.

    Item ``k`` is row ``k`` (first axis) of each array, converted to
    ``float32`` tensors. The time array is indexed as ``[:, column]``, so it
    must be at least 2-D. It is turned from timestamps into the difference
    between consecutive columns, and the first column is overwritten with 1.
    """

    def __init__(self, t_path: str, X_path: str, y_path: str):
        times = torch.Tensor(np.load(t_path))
        # The right-hand side is evaluated into a new tensor before the
        # assignment, so every gap is computed from the original timestamps.
        times[:, 1:] = times[:, 1:] - times[:, :-1]
        times[:, 0] = 1
        self.t = times
        self.X = torch.Tensor(np.load(X_path))
        self.y = torch.Tensor(np.load(y_path))

    def __getitem__(self, index):
        """Return the ``(t, X, y)`` rows stored at ``index``."""
        sample = (self.t[index], self.X[index], self.y[index])
        return sample

    def __len__(self):
        """Number of samples, i.e. the size of the time array's first axis."""
        return self.t.size(0)

class CTGRU(nn.Module):
    """CT-GRU-style recurrent model with a bank of ``M`` log-spaced timescales.

    The state ``h_hat`` has shape ``(batch, hidden_size, M)``: one trace per
    hidden unit per timescale. At every step the retrieval and update layers
    produce, per (unit, timescale), a value that is compared against the log
    timescale table through a softmax over the ``M`` axis. Afterwards each
    trace is multiplied by ``exp(-dt / tau)``.

    After the observed sequence is consumed, one prediction is decoded from
    the state. Each prediction is then fed back as the next input, with a
    time step of 1, until ``pred_n`` predictions exist (at least one is
    always produced). Because of this feedback, ``output_size`` must equal
    ``input_size`` whenever ``pred_n > 1``.

    Args:
        device: Device on which the timescale tables are first created.
        input_size: Features per observed step.
        hidden_size: Number of hidden units.
        output_size: Features per prediction.
        pred_n: Number of predictions returned by ``forward``.
        tau: Smallest timescale; timescale ``i`` is ``tau * 10 ** (i / 2)``.
        M: Number of timescales.

    Raises:
        ValueError: If ``tau`` is not positive (its log is used).
    """

    def __init__(self, device, input_size: int, hidden_size: int, output_size: int, pred_n: int = 3, tau: int = 1, M: int = 8):
        super().__init__()
        if tau <= 0:
            raise ValueError(f'tau must be positive, got {tau}')
        self.device = device
        self.units = hidden_size
        self.M = M
        self.state_size = hidden_size * M
        self.pred_n = pred_n

        # Log-spaced timescales and their logs. As non-persistent buffers they
        # follow ``model.to(...)`` but add no state_dict keys.
        tau_table = torch.tensor([float(tau) * 10.0 ** (0.5 * i) for i in range(M)], device=device)
        self.register_buffer('tau_table', tau_table, persistent=False)
        self.register_buffer('ln_tau_table', torch.log(tau_table), persistent=False)

        # Retrieval/update layers emit one value per (unit, timescale); these
        # values are compared against ``ln_tau_table`` in ``_cell``.
        self.retrieval_layer = nn.Linear(input_size + hidden_size, M * hidden_size)
        self.detect_layer = nn.Sequential(
            nn.Linear(input_size + hidden_size, hidden_size),
            nn.Tanh()
        )
        self.update_layer = nn.Linear(input_size + hidden_size, M * hidden_size)

        self.decoder = nn.Linear(hidden_size * self.M, output_size)

    def _cell(self, inp, h_hat, dt):
        """Advance the multi-timescale state by one step.

        Args:
            inp: Step input, shape ``(batch, input_size)``.
            h_hat: Current state, shape ``(batch, units, M)``.
            dt: Decay interval per sample, shape ``(batch, 1)``.

        Returns:
            The new state, shape ``(batch, units, M)``.
        """
        h = torch.sum(h_hat, dim=2)
        fused_input = torch.cat([inp, h], dim=-1)

        # Softmax over the M axis: weights favour timescales whose log is
        # closest to the retrieval layer's output.
        ln_tau_r = self.retrieval_layer(fused_input).reshape(-1, self.units, self.M)
        rki = torch.softmax(-torch.square(ln_tau_r - self.ln_tau_table), dim=2)
        q_input = torch.sum(rki * h_hat, dim=2)

        qk = self.detect_layer(torch.cat([inp, q_input], dim=1)).unsqueeze(dim=-1)

        ln_tau_s = self.update_layer(fused_input).reshape(-1, self.units, self.M)
        ski = torch.softmax(-torch.square(ln_tau_s - self.ln_tau_table), dim=2)

        base_term = (1 - ski) * h_hat + ski * qk
        # dt (batch, 1) broadcasts against the M timescales, so each sample is
        # decayed by its own interval; result reshaped to (batch, 1, M).
        exp_term = torch.exp(-dt / self.tau_table).unsqueeze(dim=1)
        return base_term * exp_term

    def forward(self, t, x):
        """Encode an irregularly sampled sequence and predict ``pred_n`` steps.

        Args:
            t: Per-step decay intervals, shape ``(batch, n_take)``. The state
                produced at step ``i`` is multiplied by ``exp(-t[:, i] / tau)``.
            x: Observations, shape ``(batch, n_take, input_size)``.

        Returns:
            Tensor of shape ``(batch, max(pred_n, 1), output_size)``.
        """
        batch = t.size(0)
        # Create the state on the input's device so the model keeps working
        # after ``model.to(...)``.
        h_hat = torch.zeros((batch, self.units, self.M), device=x.device)

        for i in range(t.size(1)):
            # Slice (not index) to keep shape (batch, 1) for broadcasting.
            h_hat = self._cell(x[:, i, :], h_hat, t[:, i:i + 1])

        pred = self.decoder(h_hat.reshape(-1, self.units * self.M))
        preds = [pred]

        # Autoregressive decoding with an assumed unit time step.
        unit_dt = torch.ones((batch, 1), device=x.device, dtype=h_hat.dtype)
        for _ in range(1, self.pred_n):
            h_hat = self._cell(pred, h_hat, unit_dt)
            pred = self.decoder(h_hat.reshape(-1, self.units * self.M))
            preds.append(pred)
        return torch.stack(preds, dim=1)

def _save_loss_csv(column, values, stem):
    """Write ``values`` as one CSV column, falling back to ``{stem}_1.csv`` on OSError."""
    try:
        pd.DataFrame({column: values}).to_csv(f'{stem}.csv')
    except OSError:
        # e.g. the target file is locked by another program; keep the log anyway.
        print(f'Fail to save the file {stem}.csv')
        pd.DataFrame({column: values}).to_csv(f'{stem}_1.csv')


def main(version, hidden_size):
    """Train and evaluate one CTGRU on the irregular three-body data.

    Training feeds one sample at a time (DataLoader batch size 1). Gradients
    are accumulated over ``batch_size`` samples, and also over the final
    partial group of each epoch, before each clipped Adam step and cosine LR
    step. After every epoch the mean test MSE is computed. The model is then
    saved to ``{version}.pkl`` and the loss histories to
    ``{version}_Train.csv`` / ``{version}_Test.csv``.

    Args:
        version: Tag used in every output file name.
        hidden_size: Hidden size passed to ``CTGRU``.
    """
    import math
    import os

    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # cuBLAS (CUDA >= 10.2) is only deterministic with this workspace config,
    # and it must be set before the first cuBLAS call.
    os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
    # This is a function and must be called; assigning to it would replace it.
    # warn_only avoids aborting training on an op with no deterministic kernel.
    torch.use_deterministic_algorithms(True, warn_only=True)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    train_dataset = ThreeBodyDataset(
        t_path='irregular_three_body/train_t.npy',
        X_path='irregular_three_body/train_x.npy',
        y_path='irregular_three_body/train_y.npy'
    )

    test_dataset = ThreeBodyDataset(
        t_path='irregular_three_body/test_t.npy',
        X_path='irregular_three_body/test_x.npy',
        y_path='irregular_three_body/test_y.npy'
    )

    train_loader = Data.DataLoader(
        dataset=train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=0
    )

    test_loader = Data.DataLoader(
        dataset=test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0
    )

    print(device)

    model = CTGRU(device, input_size=9, hidden_size=hidden_size, output_size=9, pred_n=3).to(device)

    # Number of single-sample backward passes accumulated per optimizer step.
    batch_size = 128
    num_epochs = 100
    n_train = len(train_loader)
    # One optimizer/scheduler step per full group plus one for a trailing partial group.
    steps_per_epoch = math.ceil(n_train / batch_size)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(1, num_epochs * steps_per_epoch), eta_min=0.00005, last_epoch=-1
    )

    criterion = nn.MSELoss()

    train_loss = []
    test_loss = []

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        batch_loss = 0
        n_accumulated = 0
        model.train()
        optimizer.zero_grad()
        for i, (t, x, y) in enumerate(tqdm(train_loader)):
            t, x, y = t.to(device), x.to(device), y.to(device)
            output = model(t, x)
            loss = criterion(output.view(-1), y.view(-1))
            loss.backward()
            batch_loss += loss.item()
            n_accumulated += 1

            # Step after exactly ``batch_size`` samples, and on the epoch's last
            # sample so no gradients carry over into the next epoch.
            if n_accumulated == batch_size or i == n_train - 1:
                print(' ', version, 'Epoch {}/{}'.format(epoch + 1, num_epochs))
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1, norm_type=2)
                optimizer.step()
                optimizer.zero_grad()
                train_loss.append(batch_loss / n_accumulated)
                print(' ', train_loss[-1])
                batch_loss = 0
                n_accumulated = 0
                scheduler.step()

        epoch_loss = 0
        model.eval()
        with torch.no_grad():
            for t, x, y in tqdm(test_loader):
                t, x, y = t.to(device), x.to(device), y.to(device)
                output = model(t, x)
                loss = criterion(output.view(-1), y.view(-1))
                epoch_loss += loss.item()
            test_loss.append(epoch_loss / len(test_loader))
            print(' ', test_loss[-1])

        torch.save(model, f'{version}.pkl')

        _save_loss_csv('Train Loss', train_loss, f'{version}_Train')
        _save_loss_csv('Test Loss', test_loss, f'{version}_Test')

    gc.collect()

if __name__ == '__main__':
    # One independent training run per hidden size; every output file of a
    # run is tagged with a version string that embeds that size.
    for size in (512, 1024, 2048):
        main(f'CTGRU_IrrBody_{size}', size)