import os
import gc
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.utils.data as Data
from torch.utils.data.dataset import TensorDataset

def load_data(path, suffix, batch_size = 128):
    """Load the spring-simulation arrays, split them 80/20 and wrap both splits in DataLoaders.

    Reads ``loc_<suffix>.npy``, ``vel_<suffix>.npy`` and ``edges_<suffix>.npy`` from *path*.
    loc/vel are laid out as [num_samples, num_timesteps, num_dims, num_atoms].

    Positions and velocities are each rescaled to [-1, 1] using the TRAINING split's min/max
    (test values can therefore fall slightly outside that range).  Per sample, the features are
    the flattened loc followed by the flattened vel at every time step.

    Edge labels are flattened to num_atoms**2 entries, mapped through (e + 1) / 2 with int64
    truncation (presumably the files store edges as -1/1 -- confirm), converted to float, and
    reduced to the off-diagonal (i != j) entries.

    Returns:
        (train_loader, test_loader); only the training loader shuffles.
    """
    def _load(prefix):
        return np.load(os.path.join(path, prefix + suffix + '.npy'))

    (loc_tr, loc_te, vel_tr, vel_te, edge_tr, edge_te) = train_test_split(
        _load('loc_'), _load('vel_'), _load('edges_'),
        test_size=0.2, random_state=42)

    n_atoms = loc_tr.shape[3]

    # Normalisation statistics come from the training split only.
    loc_hi, loc_lo = loc_tr.max(), loc_tr.min()
    vel_hi, vel_lo = vel_tr.max(), vel_tr.min()

    def _scale(arr, lo, hi):
        # Affine map of [lo, hi] onto [-1, 1].
        return (arr - lo) * 2 / (hi - lo) - 1

    # Flat indices of the off-diagonal cells of an n_atoms x n_atoms matrix.
    off_diag = np.ravel_multi_index(
        np.where(np.ones((n_atoms, n_atoms)) - np.eye(n_atoms)),
        [n_atoms, n_atoms])

    def _build(loc, vel, edges):
        loc = _scale(loc, loc_lo, loc_hi)
        vel = _scale(vel, vel_lo, vel_hi)
        feats = np.concatenate(
            [loc.reshape(*loc.shape[:2], -1), vel.reshape(*vel.shape[:2], -1)], axis=-1)
        labels = np.array((np.reshape(edges, [-1, n_atoms ** 2]) + 1) / 2, dtype=np.int64)
        return TensorDataset(torch.FloatTensor(feats), torch.FloatTensor(labels)[:, off_diag])

    train_set = _build(loc_tr, vel_tr, edge_tr)
    test_set = _build(loc_te, vel_te, edge_te)

    return (Data.DataLoader(train_set, batch_size=batch_size, shuffle=True),
            Data.DataLoader(test_set, batch_size=batch_size))

class NaturalCubicSpline:
    """Batched piecewise-cubic interpolant with one knot sequence per batch row.

    On interval j, starting at knot time t_j, the value at time t is
        a_j + b_j s + c_j s^2 + d_j s^3,   with s = t - t_j.
    Coefficients are stored as (a, b, 2c, 3d), the layout produced by ``cubic_spline``.

    Arguments:
        times: (batch, n_knots) knot times.
        coeffs: tuple (a, b, two_c, three_d), each (batch, n_knots - 1, channels).
    """

    def __init__(self, times, coeffs, **kwargs):
        super(NaturalCubicSpline, self).__init__(**kwargs)
        (a, b, two_c, three_d) = coeffs
        self._times = times  # times.shape == (batch_size, n_take)
        self._a = a
        self._b = b
        # As we're typically computing derivatives, we store the multiples of these
        # coefficients that are more useful.
        self._two_c = two_c
        self._three_d = three_d
        # Row selector for advanced indexing.  It lives on the knots' device so that every
        # evaluate()/derivative() call does not have to copy it over from the CPU.
        self.range = torch.arange(0, self._times.size(0), device=self._times.device)

    def _interpret_t(self, t):
        """Locate the interval containing scalar tensor *t* in every row.

        Returns:
            (fractional_part, index): fractional_part is (batch, 1) holding t minus the start of
            the chosen interval; index is (batch,).  t outside the knot range is clamped to the
            first or last interval, which then extrapolates.
        """
        maxlen = self._b.size(-2) - 1
        index = (t > self._times).sum(dim=1) - 1  # index.size == (batch_size)
        index = index.clamp(0, maxlen)  # clamp because t may go outside of [t[0], t[-1]]; this is fine
        # Will never access the last element of self._times; this is correct behaviour.
        fractional_part = t - self._times[self.range, index]
        return fractional_part.unsqueeze(dim=1), index

    def _pick(self, coeff, index):
        """Select, for every row, the coefficient vector of its own interval -> (batch, channels)."""
        return coeff[self.range, index, :]

    def evaluate(self, t):
        """Evaluates the natural cubic spline interpolation at a point t, which should be a scalar tensor."""
        fractional_part, index = self._interpret_t(t)
        inner = 0.5 * self._pick(self._two_c, index) + self._pick(self._three_d, index) * fractional_part / 3
        inner = self._pick(self._b, index) + inner * fractional_part
        return self._pick(self._a, index) + inner * fractional_part

    def derivative(self, t):
        """Evaluates the derivative of the natural cubic spline at a point t, which should be a scalar tensor."""
        fractional_part, index = self._interpret_t(t)
        inner = self._pick(self._two_c, index) + self._pick(self._three_d, index) * fractional_part
        return self._pick(self._b, index) + inner * fractional_part

class CDEFunc(nn.Module):
    """Vector field of the CDE: maps a hidden state to a (hidden, input) matrix.

    Two linear layers (ReLU in between, tanh on the output) turn a state of width
    ``hidden_channels`` into ``hidden_channels * input_channels`` values.  These are reshaped so
    that the matrix can be multiplied with dX/dt.
    """

    def __init__(self, input_channels, hidden_channels):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.linear1 = nn.Linear(hidden_channels, hidden_channels)
        self.linear2 = nn.Linear(hidden_channels, input_channels * hidden_channels)

    def forward(self, z):
        # z: (..., hidden_channels) -> result: (..., hidden_channels, input_channels)
        hidden = torch.relu(self.linear1(z))
        flat = torch.tanh(self.linear2(hidden))
        return flat.view(*flat.shape[:-1], self.hidden_channels, self.input_channels)

class GroupLinear_1_to_n(nn.Module):
    """Encode one vector into ``n_block`` separate output vectors.

    A shared linear layer with ReLU widens the input to 2 * input_size.  Each block then applies
    its own weight matrix ``w[block]`` and bias ``b[block]``.
    """

    def __init__(self, input_size, output_size, n_block):
        super().__init__()
        widened = 2 * input_size
        self.linear = nn.Linear(input_size, widened)
        self.w = nn.Parameter(0.01 * torch.rand(n_block, widened, output_size))
        self.b = nn.Parameter(torch.zeros(n_block, output_size))

    def forward(self, x):
        # x: (batch_size, input_size) -> (batch_size, n_block, output_size)
        shared = torch.relu(self.linear(x))        # (batch_size, 2 * input_size)
        per_block = torch.matmul(shared, self.w)   # (n_block, batch_size, output_size)
        return per_block.transpose(0, 1) + self.b

class GroupLinear_n_to_m(nn.Module):
    """Encode a sequence into ``n_block`` separate sequences.

    A shared linear layer with ReLU maps each time step to ``output_size`` features.  Each block
    then applies its own square weight ``w[block]`` and bias ``b[block]`` at every step.
    """

    def __init__(self, input_size, output_size, n_block):
        super().__init__()
        self.n_block = n_block
        self.linear = nn.Linear(input_size, output_size)
        self.w = nn.Parameter(0.01 * torch.rand(n_block, output_size, output_size))
        self.b = nn.Parameter(torch.zeros(n_block, output_size))

    def forward(self, x):
        # x: (batch_size, n_take, input_size) -> (batch_size, n_block, n_take, output_size)
        hidden = torch.relu(self.linear(x))  # (batch_size, n_take, output_size)
        # A singleton block axis lets matmul broadcast one weight matrix per block; the bias
        # gets a singleton time axis so it is shared by every step.
        return torch.matmul(hidden.unsqueeze(1), self.w) + self.b.unsqueeze(1)

class IndCDE(nn.Module):
    """Neural CDE with ``n_blocks`` hidden blocks coupled through attention.

    ``forward`` proceeds as follows:
      1. ``h_encoder`` maps the first observation to the initial hidden state z0,
         shape (batch, n_blocks, hidden_size).
      2. Attention logits l0, shape (batch, n_blocks, n_blocks), are built from z0 as
         (Z Q + q)(Z K + k)^T / n_blocks, where:
           - Q = ``query_w`` (hidden x n^2) and K^T = ``key_w_T`` (n^2 x hidden);
           - q and k are (n x n^2) matrices whose row i is ``query_b[i]`` / ``key_b[i]``.
      3. ``x_encoder`` gives every block its own encoding of the input path.  Each encoding is
         interpolated by a natural cubic spline.
      4. z and l are advanced together by forward-Euler steps of size 1 / ``_STEPS_PER_UNIT``:
             dz = f(softmax(l) z) dX/dt,
         where f is ``hidden_func``.  l is updated with dl/dt, obtained by differentiating the
         formula in step 2.
      5. The final z is flattened and decoded by ``h_decoder``.
    """

    # Euler sub-steps per unit of time.  The grid size and both update scalings in forward()
    # use this single constant so they cannot drift apart.
    _STEPS_PER_UNIT = 3

    def __init__(self, input_size: int = 28, hidden_size: int = 16, n_blocks: int = 6, output_size: int = 10):
        super(IndCDE, self).__init__()
        self.hidden_size = hidden_size
        self.n_blocks = n_blocks

        self.x_encoder = GroupLinear_n_to_m(input_size, 2 * input_size, n_blocks)

        self.h_encoder = GroupLinear_1_to_n(input_size, hidden_size, n_blocks)
        self.h_decoder = nn.Linear(n_blocks * hidden_size, output_size)

        self.hidden_func = CDEFunc(2 * input_size, hidden_size)

        self.key_w_T = nn.Parameter(0.01 * torch.randn(n_blocks * n_blocks, hidden_size))
        self.query_w = nn.Parameter(0.01 * torch.randn(hidden_size, n_blocks * n_blocks))

        self.key_b = nn.Parameter(torch.zeros(n_blocks))
        self.query_b = nn.Parameter(torch.zeros(n_blocks))

    def transpose_for_scores(self, x):
        """(batch, n_blocks, n_blocks**2) -> (batch, 1, n_blocks, n_blocks**2)."""
        new_x_shape = x.size()[:-1] + (1, self.n_blocks * self.n_blocks)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, t, x):
        """Integrate the block-coupled CDE along the input path and decode the final state.

        Args:
            t: (batch_size, n_take) knot times of each sample.  The Euler grid is built from the
                first row only, from int(t[0, 0]) to int(t[0, -1]).
            x: (batch_size, n_take, input_size) observed path.

        Returns:
            (batch_size, output_size) output of ``h_decoder``.
        """
        n = self.n_blocks
        steps = self._STEPS_PER_UNIT

        # z0.shape == (batch_size, n_blocks, hidden_size)
        z0 = self.h_encoder(x[:, 0, :])

        # Bias matrices:
        #   key_bias[j, i] == key_b[i]    (n^2 x n), i.e. k^T;
        #   query_bias[i, j] == query_b[i] (n x n^2), i.e. q.
        key_bias = self.key_b.expand(n * n, -1)
        query_bias = self.query_b.expand(n * n, -1).T

        k = self.transpose_for_scores(torch.matmul(z0, self.key_w_T.T) + key_bias.T)
        q = self.transpose_for_scores(torch.matmul(z0, self.query_w) + query_bias)

        # k.shape, q.shape == (batch_size, 1, n_blocks, n_blocks * n_blocks)
        l0 = torch.matmul(q, k.transpose(-1, -2)).squeeze(dim=1) / n  # (batch_size, n_blocks, n_blocks)

        x = self.x_encoder(x)  # (batch_size, n_blocks, n_take, 2 * input_size)

        # One spline per (sample, block) pair: the block axis is folded into the batch axis.
        knot_t = t.expand(n, -1, -1).permute(1, 0, 2).reshape(-1, t.size(1))
        x_spline = NaturalCubicSpline(knot_t, cubic_spline(knot_t, x.reshape(-1, t.size(1), x.size(-1))))

        insert_t = torch.linspace(int(t[0, 0]), int(t[0, -1]), steps * int(t[0, -1] - t[0, 0]) + 1, device=t.device)

        # Loop-invariant factor (q K^T) of the q K^T Z'^T term of dl/dt; shape (n_blocks, hidden_size).
        query_bias_keys = torch.matmul(query_bias, self.key_w_T)

        for i in insert_t:
            a0 = torch.softmax(l0, dim=-1)
            dXdt = x_spline.derivative(i).reshape(t.size(0), n, -1)  # (batch_size, n_blocks, 2 * input_size)

            # hidden_func(bmm(a0, z0)).shape == (batch_size, n_blocks, hidden_size, 2 * input_size)
            # dz.shape == (batch_size, n_blocks, hidden_size)
            dz = torch.matmul(
                self.hidden_func(torch.bmm(a0, z0)),
                dXdt.unsqueeze(dim=-1)
            ).squeeze(dim=-1)

            z0 = z0 + dz / steps

            # n = n_blocks, h = hidden_size;  Z: n x h,  Q, K: h x n^2,  q, k: n x n^2
            #   l      = (ZQ + q)(ZK + k)^T
            #   dl/dt  = Z'QK^TZ^T + ZQK^TZ'^T + Z'Qk^T + qK^TZ'^T
            # NOTE: the Z factors below use the already-updated z0.
            dz_q = torch.matmul(dz, self.query_w)
            dl = torch.bmm(torch.matmul(dz_q, self.key_w_T), z0.permute(0, 2, 1))
            dl = dl + torch.bmm(torch.matmul(torch.matmul(z0, self.query_w), self.key_w_T), dz.permute(0, 2, 1))
            dl = dl + torch.matmul(dz_q, key_bias)
            dl = dl + torch.matmul(query_bias_keys, dz.permute(0, 2, 1))

            l0 = l0 + dl / n / steps

        return self.h_decoder(z0.reshape(-1, n * self.hidden_size))

def tridiagonal_solve(b_, A_upper_, A_diagonal_, A_lower_):
    """Solves a tridiagonal system Ax = b for a batch of right-hand sides (Thomas algorithm).

    Letting U = A_upper, D = A_diagonal and L = A_lower, every system matrix is (k, k) with entries

        D[0]   U[0]
        L[0]   D[1]   U[1]
               L[1]   D[2]   U[2]                   0
                      .      .      .
                             .      .      .
                             L[k-3] D[k-2] U[k-2]
           0                        L[k-2] D[k-1]

    The diagonals are aligned with b on the LEADING (batch) dimension, not by standard trailing
    broadcasting.  Singleton dimensions are inserted right after the batch dimension, so one set
    of diagonals per batch row is shared by all of that row's right-hand sides.

    Arguments:
        b_: tensor of shape (batch, ..., k); every slice along the middle dimensions is solved
            independently.
        A_upper_: tensor of shape (batch, k - 1), or (batch, ..., k - 1).
        A_diagonal_: tensor of shape (batch, k), or (batch, ..., k).
        A_lower_: tensor of shape (batch, k - 1), or (batch, ..., k - 1).

    Returns:
        A tensor with the shape, dtype and device of b_, holding the x solving Ax = b.

    Warning:
        No pivoting is performed, so A should be diagonally dominant (or otherwise safe for
        plain Gaussian elimination).
    """

    def _align(coeff, target):
        # Cast to target's dtype/device, then add singleton dims after the batch dim until the
        # ranks match, and expand (a view, no copy).
        coeff = coeff.to(dtype=target.dtype, device=target.device)
        while 0 < coeff.dim() < target.dim():
            coeff = coeff.unsqueeze(1)
        return coeff.expand_as(target)

    b = b_
    k = b.size(-1)
    upper = _align(A_upper_, b[..., :-1])
    lower = _align(A_lower_, b[..., :-1])
    diagonal = _align(A_diagonal_, b)

    # Forward elimination: remove the sub-diagonal, updating the pivots and right-hand side.
    pivots = [diagonal[..., 0]]
    rhs = [b[..., 0]]
    for i in range(1, k):
        w = lower[..., i - 1] / pivots[i - 1]
        pivots.append(diagonal[..., i] - w * upper[..., i - 1])
        rhs.append(b[..., i] - w * rhs[i - 1])

    # Back substitution, from the last unknown to the first.
    solution = [rhs[k - 1] / pivots[k - 1]]
    for i in range(k - 2, -1, -1):
        solution.append((rhs[i] - upper[..., i] * solution[-1]) / pivots[i])
    solution.reverse()

    return torch.stack(solution, dim=-1)

def cubic_spline(times, x):
    """Compute natural cubic spline coefficients for a batch of paths.

    Args:
        times: (batch, length) knot times, one row per path.  Integer tensors are accepted:
            ``reciprocal`` promotes their differences to floating point.
        x: (batch, length, channels) values of each path at its knots.

    Returns:
        Tuple (a, b, two_c, three_d), each of shape (batch, length - 1, channels).  On interval
        j the spline is a + b s + c s^2 + d s^3, with s the time since knot j; c and d are stored
        as 2c and 3d (the layout NaturalCubicSpline expects).

    Raises:
        ValueError: if fewer than two knots are given, or times and x disagree on the number
            of knots.
    """
    path = x.transpose(-1, -2)  # (batch, channels, length)
    length = path.size(-1)
    # With a single knot the tridiagonal system below is singular and every coefficient tensor
    # would be empty, so fail early with a clear message.
    if length < 2:
        raise ValueError(f'cubic_spline needs at least two knots, got {length}')
    if times.size(-1) != length:
        raise ValueError(f'times has {times.size(-1)} knots but x has {length}')

    # Set up some intermediate values
    time_diffs = times[:, 1:] - times[:, :-1]
    time_diffs_reciprocal = time_diffs.reciprocal()
    time_diffs_reciprocal_squared = time_diffs_reciprocal ** 2

    three_path_diffs = 3 * (path[..., 1:] - path[..., :-1])
    six_path_diffs = 2 * three_path_diffs

    # path_diffs_scaled.shape == (batch, channels, length - 1)
    path_diffs_scaled = three_path_diffs * time_diffs_reciprocal_squared.unsqueeze(dim=1)

    # Solve a tridiagonal linear system to find the derivatives at the knots
    system_diagonal = torch.empty(times.size(0), length, dtype=path.dtype, device=path.device)
    system_diagonal[:, :-1] = time_diffs_reciprocal
    system_diagonal[:, -1] = 0
    system_diagonal[:, 1:] += time_diffs_reciprocal
    system_diagonal *= 2
    system_rhs = torch.empty(*path.shape, dtype=path.dtype, device=path.device)
    system_rhs[..., :-1] = path_diffs_scaled
    system_rhs[..., -1] = 0
    system_rhs[..., 1:] += path_diffs_scaled

    knot_derivatives = tridiagonal_solve(system_rhs, time_diffs_reciprocal, system_diagonal, time_diffs_reciprocal)

    a = path[..., :-1]
    b = knot_derivatives[..., :-1]
    two_c = (six_path_diffs * time_diffs_reciprocal.unsqueeze(dim=1)
            - 4 * knot_derivatives[..., :-1]
            - 2 * knot_derivatives[..., 1:]) * time_diffs_reciprocal.unsqueeze(dim=1)
    three_d = (-six_path_diffs * time_diffs_reciprocal.unsqueeze(dim=1)
            + 3 * (knot_derivatives[..., :-1]
                    + knot_derivatives[..., 1:])) * time_diffs_reciprocal_squared.unsqueeze(dim=1)

    return a.transpose(-1, -2), b.transpose(-1, -2), two_c.transpose(-1, -2), three_d.transpose(-1, -2)

def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # cuBLAS is only deterministic on CUDA >= 10.2 with this workspace setting; keep any
    # user-provided value.
    os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
    # This must be CALLED: assigning to it replaces the torch function and enables nothing.
    # warn_only keeps ops without a deterministic implementation from aborting training.
    torch.use_deterministic_algorithms(True, warn_only=True)


def _run_epoch(model, loader, criterion, device, optimizer=None, scheduler=None):
    """Run one pass over *loader*: train when *optimizer* is given, otherwise evaluate.

    Returns:
        (mean loss per edge label, fraction of edge labels predicted correctly); both are NaN
        for an empty loader.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    total_correct = 0
    total_labels = 0
    with torch.set_grad_enabled(training):
        for x, y in tqdm(loader):
            x, y = x.to(device), y.to(device)
            # Integer time stamps 0..T-1, one row per sample, T taken from the data itself.
            t = torch.arange(x.size(1), device=device).expand(x.size(0), -1)
            output = model(t, x)
            loss = criterion(output, y)
            if training:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                if scheduler is not None:
                    scheduler.step()

            total_correct += int(torch.sum((output > 0).int() == y))
            # The criterion averages over every label, so weight it by the label count.
            total_loss += loss.item() * y.numel()
            total_labels += y.numel()

    if total_labels == 0:
        return float('nan'), float('nan')
    return total_loss / total_labels, total_correct / total_labels


def _save_csv(frame, path, fallback_path, label):
    """Write *frame* to *path*; if that fails with an OS error (e.g. the file is locked), use *fallback_path*."""
    try:
        frame.to_csv(path)
    except OSError:
        print(f'Fail to save the file {label}')
        frame.to_csv(fallback_path)


def main(version, hidden_size, num_epochs=25, batch_size=128, data_path='spring', data_suffix='springs5'):
    """Train and evaluate IndCDE on the spring dataset, checkpointing after every epoch.

    Args:
        version: prefix of the output files '{version}.pkl', '{version}_Train.csv' and
            '{version}_Test.csv'.
        hidden_size: per-block hidden width of the model.
        num_epochs: number of training epochs.
        batch_size: mini-batch size for both loaders.
        data_path, data_suffix: forwarded to ``load_data``.
    """
    _seed_everything(42)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    train_loader, test_loader = load_data(data_path, data_suffix, batch_size)

    print(device)

    model = IndCDE(input_size=20, hidden_size=hidden_size, n_blocks=10, output_size=20).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # The scheduler is stepped once per batch, hence T_max counts batches, not epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs * len(train_loader), eta_min=0.00005, last_epoch=-1)

    criterion = nn.BCEWithLogitsLoss()

    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []

    for epoch in range(num_epochs):
        print(version, 'Epoch {}/{}'.format(epoch + 1, num_epochs))

        loss, acc = _run_epoch(model, train_loader, criterion, device, optimizer, scheduler)
        train_loss.append(loss)
        train_acc.append(acc)
        print(' ', train_loss[-1], train_acc[-1])

        loss, acc = _run_epoch(model, test_loader, criterion, device)
        test_loss.append(loss)
        test_acc.append(acc)
        print(' ', test_loss[-1], test_acc[-1])

        torch.save(model, f'{version}.pkl')

        _save_csv(pd.DataFrame({'Train Loss': train_loss, 'Train Acc': train_acc}),
                  f'{version}_Train.csv', f'{version}_Train_1.csv', 'Train.csv')
        _save_csv(pd.DataFrame({'Test Loss': test_loss, 'Test Acc': test_acc}),
                  f'{version}_Test.csv', f'{version}_Test_1.csv', 'Test.csv')

    gc.collect()

if __name__ == '__main__':
    # Sweep the per-block hidden width; every run writes its own model and CSV files.
    for width in (128, 256, 512):
        main(f'IndCDE_Spring_10_{width}', width)