import gc
import random
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.utils.data as Data
from torch.utils.data import Dataset

class ThreeBodyDataset(Dataset):
    """In-memory dataset of ``(t, X, y)`` triples read from three ``.npy`` files.

    Each file is loaded with ``np.load`` and converted to a tensor of the
    default floating dtype. A sample is obtained by indexing all three tensors
    at the same position along their first axis.
    """

    def __init__(self, t_path: str, X_path: str, y_path: str):
        # Load the three arrays the same way; order matters for the unpacking.
        self.t, self.X, self.y = (
            torch.tensor(np.load(path), dtype=torch.get_default_dtype())
            for path in (t_path, X_path, y_path)
        )

    def __len__(self):
        # The sample count is taken from the first axis of the time tensor.
        return len(self.t)

    def __getitem__(self, index):
        sample = (self.t[index], self.X[index], self.y[index])
        return sample

class NaturalCubicSpline:
    """Batched piecewise-cubic interpolant defined by per-interval coefficients.

    Arguments:
        times: Knot times, shape (batch_size, n_take).
        coeffs: Tuple ``(a, b, two_c, three_d)``; each entry is indexed as
            ``[batch, interval, channel]``. On interval ``i`` the spline is
            ``a + b * s + (two_c / 2) * s ** 2 + (three_d / 3) * s ** 3``
            where ``s = t - times[:, i]``.
    """

    def __init__(self, times, coeffs, **kwargs):
        super(NaturalCubicSpline, self).__init__(**kwargs)
        (a, b, two_c, three_d) = coeffs
        self._times = times # times.shape == (batch_size, n_take)
        self._a = a
        self._b = b
        # as we're typically computing derivatives, we store the multiples of these coefficients that are more useful
        self._two_c = two_c
        self._three_d = three_d
        # Row selector for advanced indexing. It is created on the same device as
        # `times` so that it never mixes a CPU index with device-resident indices.
        self.range = torch.arange(0, self._times.size(0), device=self._times.device)

    def _interpret_t(self, t):
        """Return ``(offset into the interval, interval index)`` for time ``t``.

        The offset has shape (batch_size, 1) and the index has shape (batch_size,).
        """
        maxlen = self._b.size(-2) - 1
        index = (t > self._times).sum(dim=1) - 1 # index.size == (batch_size)
        index = index.clamp(0, maxlen)  # clamp because t may go outside of [t[0], t[-1]]; this is fine
        # will never access the last element of self._times; this is correct behaviour
        fractional_part = t - self._times[self.range, index]
        return fractional_part.unsqueeze(dim=1), index

    def evaluate(self, t):
        """Evaluates the natural cubic spline interpolation at a point t, which should be a scalar tensor.

        Returns a tensor of shape (batch_size, channels).
        """
        fractional_part, index = self._interpret_t(t)
        # Horner evaluation of a + b*s + c*s^2 + d*s^3 from the stored 2c and 3d.
        inner = 0.5 * self._two_c[self.range, index, :] + self._three_d[self.range, index, :] * fractional_part / 3
        inner = self._b[self.range, index, :] + inner * fractional_part
        return self._a[self.range, index, :] + inner * fractional_part

    def derivative(self, t):
        """Evaluates the derivative of the natural cubic spline at a point t, which should be a scalar tensor.

        Returns a tensor of shape (batch_size, channels).
        """
        fractional_part, index = self._interpret_t(t)
        # Horner evaluation of b + 2c*s + 3d*s^2.
        inner = self._two_c[self.range, index, :] + self._three_d[self.range, index, :] * fractional_part
        deriv = self._b[self.range, index, :] + inner * fractional_part
        return deriv

class CDEFunc(nn.Module):
    """Two-layer network mapping a hidden state to a (hidden, input) matrix.

    ``forward`` applies Linear -> ReLU -> Linear -> tanh and then splits the last
    axis into ``(hidden_channels, input_channels)``.
    """

    def __init__(self, input_channels, hidden_channels):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.linear1 = nn.Linear(hidden_channels, hidden_channels)
        self.linear2 = nn.Linear(hidden_channels, input_channels * hidden_channels)

    def forward(self, z):
        # z.shape == (..., hidden_channels)
        activated = torch.relu(self.linear1(z))
        # tanh bounds every entry of the output matrix to (-1, 1)
        bounded = torch.tanh(self.linear2(activated))

        # Split the flat last axis: (..., hidden_channels * input_channels)
        # -> (..., hidden_channels, input_channels)
        matrix_shape = tuple(bounded.shape[:-1]) + (self.hidden_channels, self.input_channels)
        return bounded.view(matrix_shape)

class GroupLinear_1_to_n(nn.Module):
    """Shared linear + ReLU layer followed by one linear map per block.

    A single input vector is widened to ``2 * input_size`` by a shared layer and
    then projected by ``n_block`` independent weight matrices and biases.
    """

    def __init__(self, input_size, output_size, n_block):
        super().__init__()
        widened = 2 * input_size
        self.linear = nn.Linear(input_size, widened)
        # Per-block weights start as small uniform values in [0, 0.01); biases at zero.
        self.w = nn.Parameter(0.01 * torch.rand(n_block, widened, output_size))
        self.b = nn.Parameter(torch.zeros(n_block, output_size))

    def forward(self, x):
        # x.shape == (input_size)
        shared = torch.relu(self.linear(x))
        per_block = torch.matmul(shared, self.w)
        # return.shape == (n_block, output_size)
        return per_block + self.b

class GroupLinear_n_to_m(nn.Module):
    """Shared linear + ReLU encoder followed by one square linear map per block.

    Every row of the input sequence is encoded once by the shared layer and then
    transformed separately by each of the ``n_block`` weight matrices.
    """

    def __init__(self, input_size, output_size, n_block):
        super().__init__()
        self.n_block = n_block
        self.linear = nn.Linear(input_size, output_size)
        # Per-block weights start as small uniform values in [0, 0.01); biases at zero.
        self.w = nn.Parameter(0.01 * torch.rand(n_block, output_size, output_size))
        self.b = nn.Parameter(torch.zeros(n_block, output_size))

    def forward(self, x):
        # x.shape == (n_take, input_size)
        encoded = torch.relu(self.linear(x))  # (n_take, output_size)
        projected = torch.matmul(encoded, self.w)  # (n_block, n_take, output_size)
        # Insert a length-1 sequence axis so each block's bias broadcasts over n_take.
        block_bias = self.b.unsqueeze(dim=1)  # (n_block, 1, output_size)
        return projected + block_bias

class IndCDE(nn.Module):
    """Block-structured controlled differential equation model with attention between blocks.

    The hidden state holds ``n_blocks`` vectors of size ``hidden_size``. It is driven by the
    derivative of a cubic spline fitted through the encoded inputs and integrated with fixed
    Euler sub-steps of size 1/3. A matrix of attention logits between blocks is integrated
    alongside the hidden state. After the observed window, ``pred_n - 1`` further predictions
    are produced autoregressively by feeding each prediction back in as the next input.

    Note: when ``pred_n > 1`` each prediction is passed through ``x_encoder``, so
    ``output_size`` must equal ``input_size``.
    """

    def __init__(self, input_size: int = 28, hidden_size: int = 16, n_blocks: int = 6, output_size: int = 10, pred_n: int = 3):
        super(IndCDE, self).__init__()
        self.hidden_size = hidden_size
        self.n_blocks = n_blocks
        self.pred_n = pred_n

        self.x_encoder = GroupLinear_n_to_m(input_size, 2 * input_size, n_blocks)

        self.h_encoder = GroupLinear_1_to_n(input_size, hidden_size, n_blocks)
        self.h_decoder = nn.Linear(n_blocks * hidden_size, output_size)

        self.hidden_func = CDEFunc(2 * input_size, hidden_size)

        # key_w_T is stored already transposed: (n_blocks^2, hidden_size).
        self.key_w_T = nn.Parameter(0.01 * torch.randn(n_blocks * n_blocks, hidden_size))
        self.query_w = nn.Parameter(0.01 * torch.randn(hidden_size, n_blocks * n_blocks))

        self.key_b = nn.Parameter(torch.zeros(n_blocks))
        self.query_b = nn.Parameter(torch.zeros(n_blocks))

    def transpose_for_scores(self, x):
        """Reshape (1, n_blocks, n_blocks^2) to (1, 1, n_blocks, n_blocks^2)."""
        new_x_shape = x.size()[:-1] + (1, self.n_blocks * self.n_blocks)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def _euler_step(self, z, logits, dx):
        """Advance the hidden state and the attention logits by one Euler sub-step of size 1/3.

        Arguments:
            z: Hidden state, shape (n_blocks, hidden_size).
            logits: Attention logits, shape (n_blocks, n_blocks).
            dx: Spline derivative at the current time, shape (n_blocks, 2 * input_size).

        Returns:
            The updated ``(z, logits)`` pair.
        """
        attention = torch.softmax(logits, dim=-1)

        # hidden_func output has shape (n_blocks, hidden_size, 2 * input_size)
        dz = torch.bmm(
            self.hidden_func(torch.matmul(attention, z)),
            dx.unsqueeze(dim=-1)
        ).squeeze(dim=-1)

        # dz.shape == (n_blocks, hidden_size)
        z = z + dz / 3

        n_sq = self.n_blocks * self.n_blocks
        key_bias = self.key_b.expand(n_sq, -1)          # (n_blocks^2, n_blocks)
        query_bias_T = self.query_b.expand(n_sq, -1).T  # (n_blocks, n_blocks^2)
        dz_query = torch.matmul(dz, self.query_w)       # shared by two of the terms below

        # Increment of the logits; the already-updated z is used, as in the original scheme.
        dl = torch.matmul(torch.matmul(dz_query, self.key_w_T), z.T)
        dl = dl + torch.matmul(torch.matmul(torch.matmul(z, self.query_w), self.key_w_T), dz.T)
        dl = dl + torch.matmul(dz_query, key_bias)
        dl = dl + torch.matmul(torch.matmul(query_bias_T, self.key_w_T), dz.T)

        logits = logits + dl / self.n_blocks / 3
        return z, logits

    def forward(self, t, x):
        """Return ``pred_n`` stacked predictions, shape (pred_n, output_size).

        Arguments:
            t: Observation times, shape (n_take).
            x: Observations, shape (n_take, input_size).
        """
        preds = []
        # z0.shape == (1, n_blocks, hidden_size)
        z0 = self.h_encoder(x[0]).unsqueeze(dim=0)

        k = self.transpose_for_scores(torch.matmul(z0, self.key_w_T.T) + self.key_b.expand(self.n_blocks * self.n_blocks, -1).T)
        q = self.transpose_for_scores(torch.matmul(z0, self.query_w) + self.query_b.expand(self.n_blocks * self.n_blocks, -1).T)

        # z0.shape == (n_blocks, hidden_size)
        z0 = z0.squeeze(dim=0)

        # k.shape, q.shape == (1, 1, n_blocks, n_blocks * n_blocks)
        # Index the two leading singleton axes explicitly: a bare squeeze() would also
        # drop the block axes when n_blocks == 1.
        l0 = torch.matmul(q, k.transpose(-1, -2))[0, 0] / self.n_blocks # l0.shape == (n_blocks, n_blocks)

        x = self.x_encoder(x) # x.shape == (n_blocks, n_take, 2 * input_size)

        temp_t = t.expand(self.n_blocks, -1)
        x_spline = NaturalCubicSpline(temp_t, cubic_spline(temp_t, x))

        # Three sub-steps per unit of time across the observed window.
        insert_t = torch.linspace(int(t[0]), int(t[-1]), 3 * int(t[-1] - t[0]) + 1, device=t.device)

        for i in insert_t:
            z0, l0 = self._euler_step(z0, l0, x_spline.derivative(i))

        preds.append(self.h_decoder(z0.reshape(-1)))

        for _ in range(1, self.pred_n):
            # Slide the window forward by one time unit and append the encoded last prediction.
            t = torch.hstack([t[1:], t[-1] + 1])
            x = torch.cat([x[:, 1:, :], self.x_encoder(preds[-1].unsqueeze(dim=0))], dim=1)

            temp_t = t.expand(self.n_blocks, -1)
            x_spline = NaturalCubicSpline(temp_t, cubic_spline(temp_t, x))

            # Integrate over the newly added unit interval in three sub-steps.
            for dt in torch.linspace(0, 1, 4)[1:]:
                z0, l0 = self._euler_step(z0, l0, x_spline.derivative(t[-1] - 1 + dt))

            preds.append(self.h_decoder(z0.reshape(-1)))

        return torch.stack(preds)

def tridiagonal_solve(b_, A_upper_, A_diagonal_, A_lower_):
    """Solves a batch of tridiagonal systems Ax = b with the Thomas algorithm.

    Letting U = A_upper_, D = A_diagonal_ and L = A_lower_ for one batch element, the
    matrix A has size (k, k) with entries:

    D[0] U[0]
    L[0] D[1] U[1]
         L[1] D[2] U[2]                     0
              L[2] D[3] U[3]
                  .    .    .
                       .      .      .
           0              L[k - 3] D[k - 2] U[k - 2]
                                   L[k - 2] D[k - 1]

    The same matrix is solved against every channel row of b_.

    Arguments:
        b_: A tensor of shape (batch, channels, k).
        A_upper_: A tensor of shape (batch, k - 1), or (batch, channels, k - 1).
        A_diagonal_: A tensor of shape (batch, k), or (batch, channels, k).
        A_lower_: A tensor of shape (batch, k - 1), or (batch, channels, k - 1).

    Returns:
        A tensor of shape (batch, channels, k), corresponding to the x solving Ax = b.

    Warning:
        This implementation isn't super fast. You probably want to cache the result, if possible.
    """

    # This implementation is very much written for clarity rather than speed.

    def _as_rows(band):
        # Match b_'s dtype/device and add a channel axis where it is missing, so that
        # each band broadcasts across the channel rows without a per-batch copy loop.
        band = band.to(dtype=b_.dtype, device=b_.device)
        return band.unsqueeze(1) if band.dim() == b_.dim() - 1 else band

    upper = _as_rows(A_upper_)
    diagonal = _as_rows(A_diagonal_)
    lower = _as_rows(A_lower_)

    size = b_.size(-1)  # k, the number of unknowns per system

    # Forward sweep: eliminate the lower band.
    swept_diagonal = [diagonal[..., 0]]
    swept_rhs = [b_[..., 0]]
    for i in range(1, size):
        w = lower[..., i - 1] / swept_diagonal[i - 1]
        swept_diagonal.append(diagonal[..., i] - w * upper[..., i - 1])
        swept_rhs.append(b_[..., i] - w * swept_rhs[i - 1])

    # Back substitution, from the last unknown to the first.
    solution = [None] * size
    solution[size - 1] = swept_rhs[size - 1] / swept_diagonal[size - 1]
    for i in range(size - 2, -1, -1):
        solution[i] = (swept_rhs[i] - upper[..., i] * solution[i + 1]) / swept_diagonal[i]

    # Each entry is (batch, channels); stacking on the last axis gives (batch, channels, k).
    return torch.stack(solution, dim=-1)

def cubic_spline(times, x):
    """Compute natural cubic spline coefficients through the points ``x`` at ``times``.

    Arguments:
        times: Knot times, shape (batch_size, n_take).
        x: Values at the knots, shape (batch_size, n_take, channels).

    Returns:
        A tuple ``(a, b, two_c, three_d)``, each of shape (batch_size, n_take - 1, channels),
        holding for every interval the value, the first derivative, twice the quadratic
        coefficient and three times the cubic coefficient at the interval's left knot.
    """
    # Work with time on the last axis: (batch_size, channels, n_take)
    path = x.transpose(-1, -2)

    # Reciprocal knot spacings, shape (batch_size, n_take - 1)
    inv_dt = (times[:, 1:] - times[:, :-1]).reciprocal()
    inv_dt_sq = inv_dt ** 2
    # The same quantities with a channel axis for broadcasting against `path`
    inv_dt_rows = inv_dt.unsqueeze(dim=1)
    inv_dt_sq_rows = inv_dt_sq.unsqueeze(dim=1)

    three_dx = 3 * (path[..., 1:] - path[..., :-1])
    six_dx = 2 * three_dx

    # scaled_dx.shape == (batch_size, channels, n_take - 1)
    scaled_dx = three_dx * inv_dt_sq_rows

    # Solve a tridiagonal linear system to find the derivatives at the knots.
    # Every knot receives a contribution from the interval on each side of it; the
    # zero padding accounts for the interval missing at either end.
    diag_terms = inv_dt.to(dtype=path.dtype, device=path.device)
    diag_pad = diag_terms.new_zeros(diag_terms.size(0), 1)
    system_diagonal = 2 * (torch.cat([diag_terms, diag_pad], dim=1)
                           + torch.cat([diag_pad, diag_terms], dim=1))

    rhs_terms = scaled_dx.to(dtype=path.dtype)
    rhs_pad = rhs_terms.new_zeros(*rhs_terms.shape[:-1], 1)
    system_rhs = (torch.cat([rhs_terms, rhs_pad], dim=-1)
                  + torch.cat([rhs_pad, rhs_terms], dim=-1))

    knot_derivatives = tridiagonal_solve(system_rhs, inv_dt, system_diagonal, inv_dt)
    left_deriv = knot_derivatives[..., :-1]
    right_deriv = knot_derivatives[..., 1:]

    a = path[..., :-1]
    b = left_deriv
    two_c = (six_dx * inv_dt_rows
            - 4 * left_deriv
            - 2 * right_deriv) * inv_dt_rows
    three_d = (-six_dx * inv_dt_rows
            + 3 * (left_deriv
                    + right_deriv)) * inv_dt_sq_rows

    # Move time back to the middle axis: (batch_size, n_take - 1, channels)
    return tuple(coeff.transpose(-1, -2) for coeff in (a, b, two_c, three_d))

def main(version, hidden_size):
    """Train and evaluate one IndCDE model on the three-body data, saving results every epoch.

    Arguments:
        version: Tag used as the prefix of the saved model and loss CSV file names.
        hidden_size: Hidden size passed to the IndCDE model.

    Samples are fed one at a time and gradients are accumulated over ``batch_size``
    samples before each optimizer step. After every epoch the model is evaluated on the
    test set, then the model and both loss histories are written to disk.
    """
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # This is a function and must be called; assigning True to it would overwrite the
    # function and enable nothing. warn_only keeps operations that have no
    # deterministic implementation running (with a warning) instead of raising.
    torch.use_deterministic_algorithms(True, warn_only=True)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    train_dataset = ThreeBodyDataset(
        t_path='three_body/train_t.npy',
        X_path='three_body/train_x.npy',
        y_path='three_body/train_y.npy'
    )

    test_dataset = ThreeBodyDataset(
        t_path='three_body/test_t.npy',
        X_path='three_body/test_x.npy',
        y_path='three_body/test_y.npy'
    )

    # The loaders yield single samples; batching is done by gradient accumulation below.
    train_loader = Data.DataLoader(
        dataset=train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=0
    )

    test_loader = Data.DataLoader(
        dataset=test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0
    )

    print(device)

    model = IndCDE(input_size=9, hidden_size=hidden_size, n_blocks=3, output_size=9, pred_n=3).to(device)

    batch_size = 128  # number of samples accumulated per optimizer step
    num_epochs = 100

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = num_epochs * len(train_loader) / batch_size, eta_min = 0.00005, last_epoch = -1)

    criterion = nn.MSELoss()

    train_loss = []
    test_loss = []

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        epoch_loss = 0
        model.train()
        # Drop gradients left over from the previous epoch's incomplete window so that
        # every optimizer step covers exactly the samples counted in its reported loss.
        optimizer.zero_grad()
        for i, (t, x, y) in enumerate(tqdm(train_loader)):
            t, x, y = t.to(device), x.to(device), y.to(device)
            output = model(t[0], x[0])
            loss = criterion(output.view(-1), y.view(-1))
            loss.backward()
            epoch_loss += loss.item()

            # Step after every batch_size-th sample (i is zero-based), so that each
            # window, including the first, holds exactly batch_size samples.
            if (i + 1) % batch_size == 0:
                print(' ', version, 'Epoch {}/{}'.format(epoch + 1, num_epochs))
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1, norm_type=2)
                optimizer.step()
                optimizer.zero_grad()
                train_loss.append(epoch_loss / batch_size)
                print(' ', train_loss[-1])
                epoch_loss = 0
                scheduler.step()

        epoch_loss = 0
        model.eval()
        with torch.no_grad():
            for i, (t, x, y) in enumerate(tqdm(test_loader)):
                t, x, y = t.to(device), x.to(device), y.to(device)
                output = model(t[0], x[0])
                loss = criterion(output.view(-1), y.view(-1))
                epoch_loss += loss.item()
            test_loss.append(epoch_loss / len(test_loader))
            print(' ', test_loss[-1])
            epoch_loss = 0

        torch.save(model, f'{version}.pkl')

        # If the primary CSV cannot be written, fall back to an alternative file name.
        # Only OS-level failures are caught, so Ctrl-C and programming errors still surface.
        try:
            pd.DataFrame({'Train Loss': train_loss}).to_csv(f'{version}_Train.csv')
        except OSError:
            print('Fail to save the file Train.csv')
            pd.DataFrame({'Train Loss': train_loss}).to_csv(f'{version}_Train_1.csv')

        try:
            pd.DataFrame({'Test Loss': test_loss}).to_csv(f'{version}_Test.csv')
        except OSError:
            print('Fail to save the file Test.csv')
            pd.DataFrame({'Test Loss': test_loss}).to_csv(f'{version}_Test_1.csv')

    gc.collect()

if __name__ == '__main__':
    # Run one full training per hidden size; each run is tagged with that size.
    for hidden_size in (512, 1024, 2048):
        main(f'IndCDE_Body_{hidden_size}', hidden_size)