import math
import torch
import torchcde
import random
import numpy as np
from irregular_sampled_datasets import PersonData, Walker2dImitationData
import torch.utils.data as data
from torchmetrics.functional import accuracy
import torch.nn as nn
import logging

# Compute device: use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

# Logging setup.
# basicConfig installs the root handler and sets the INFO threshold. No filename is
# passed here, so per the logging documentation the `filemode` argument has no effect;
# the log file itself is opened in append mode by the FileHandler below.
logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO, filemode="w")
logger = logging.getLogger("NCDE_person.log")
file_handler = logging.FileHandler("NCDE_person.log", mode='a', encoding='utf-8', delay=False)
file_handler.setFormatter(logging.Formatter(fmt='%(asctime)s %(message)s'))
logger.addHandler(file_handler)
logger.info("Logging has started")

# Training objective shared by every run: cross-entropy over the class logits.
criterion = nn.CrossEntropyLoss()

######################
# A CDE model looks like
#
# z_t = z_0 + \int_0^t f_\theta(z_s) dX_s
#
# Where X is your data and f_\theta is a neural network. So the first thing we need to do is define such an f_\theta.
# That's what this CDEFunc class does.
# Here we've built a small single-hidden-layer neural network, whose hidden layer is of width 128.
######################
class CDEFunc(torch.nn.Module):
    """Vector field f_theta of the controlled differential equation.

    A single-hidden-layer network that maps the hidden state z to a matrix of
    shape (hidden_channels, input_channels), i.e. a linear map from
    R^input_channels to R^hidden_channels.
    """

    def __init__(self, input_channels, hidden_channels, hidden_hidden_channels=128):
        ######################
        # input_channels is the number of input channels in the data X. (Determined by the data.)
        # hidden_channels is the number of channels for z_t. (Determined by you!)
        # hidden_hidden_channels is the width of the network's hidden layer; the default of 128
        # keeps the original architecture.
        ######################
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.hidden_hidden_channels = hidden_hidden_channels

        # The layers are created in this fixed order so that seeded initialisation is reproducible.
        self.linear1 = torch.nn.Linear(hidden_channels, hidden_hidden_channels)
        self.linear2 = torch.nn.Linear(hidden_hidden_channels, input_channels * hidden_channels)

    ######################
    # For most purposes the t argument can probably be ignored; unless you want your CDE to behave differently at
    # different times, which would be unusual. But it's there if you need it!
    ######################
    def forward(self, t, z):
        """Evaluate the vector field at hidden state ``z``.

        Args:
            t: Current time; unused by this vector field.
            z: Tensor of shape (..., hidden_channels).

        Returns:
            Tensor of shape (..., hidden_channels, input_channels) with entries in [-1, 1].
        """
        z = self.linear1(z)
        z = z.relu()
        z = self.linear2(z)
        ######################
        # Easy-to-forget gotcha: Best results tend to be obtained by adding a final tanh nonlinearity.
        ######################
        z = z.tanh()
        ######################
        # Ignoring the leading (batch) dimensions, the shape of the output tensor must be a matrix,
        # because we need it to represent a linear map from R^input_channels to R^hidden_channels.
        # All leading dimensions are kept as they are, so any number of batch dimensions works.
        ######################
        z = z.view(*z.shape[:-1], self.hidden_channels, self.input_channels)
        return z


######################
# Next, we need to package CDEFunc up into a model that computes the integral.
######################
class NeuralCDE(torch.nn.Module):
    """Neural CDE classifier: embed the first observation, integrate the CDE, read out logits.

    Args:
        input_channels: Number of channels of the interpolated path X.
        hidden_channels: Number of channels of the hidden state z_t.
        output_channels: Size of the prediction produced by the linear readout.
        interpolation: Either "cubic" or "linear"; selects how the coefficients
            passed to ``forward`` are turned into a continuous path.
    """

    def __init__(self, input_channels, hidden_channels, output_channels, interpolation="cubic"):
        super().__init__()

        # Sub-modules are created in a fixed order: vector field, initial embedding, readout.
        self.func = CDEFunc(input_channels, hidden_channels)
        self.initial = torch.nn.Linear(input_channels, hidden_channels)
        self.readout = torch.nn.Linear(hidden_channels, output_channels)
        self.interpolation = interpolation

    def forward(self, coeffs):
        """Solve the CDE driven by the path described by ``coeffs`` and return predictions.

        Raises:
            ValueError: If ``self.interpolation`` is neither 'cubic' nor 'linear'.
        """
        # Reject unsupported schemes first, then pick the matching path class.
        if self.interpolation not in ('cubic', 'linear'):
            raise ValueError("Only 'linear' and 'cubic' interpolation methods are implemented.")
        if self.interpolation == 'cubic':
            path = torchcde.CubicSpline(coeffs)
        else:
            path = torchcde.LinearInterpolation(coeffs)

        # Easy to forget gotcha: the initial hidden state is a function of the first observation.
        first_observation = path.evaluate(path.interval[0])
        initial_state = self.initial(first_observation)

        # Integrate the CDE over the whole interval on which the path is defined.
        trajectory = torchcde.cdeint(X=path, z0=initial_state, func=self.func, t=path.interval)

        # cdeint returns the hidden state at both ends of the interval; index 1 along the
        # time dimension is the terminal value, which the linear readout turns into predictions.
        terminal_state = trajectory[:, 1]
        return self.readout(terminal_state)


######################
# Now we need some data.
# Here we have a simple example which generates some spirals, some going clockwise, some going anticlockwise.
######################
def get_data(num_timepoints=100, num_samples=128):
    """Generate a toy dataset of noisy, decaying spirals for binary classification.

    The first ``num_samples // 2`` spirals (before shuffling) have their x coordinate
    mirrored and are labelled 1; the remaining ones are labelled 0.

    Args:
        num_timepoints: Number of regularly spaced observation times in [0, 4*pi].
        num_samples: Total number of spirals to generate (default 128).

    Returns:
        A tuple ``(X, y)`` where ``X`` has shape (num_samples, num_timepoints, 3) with
        channels (time, x, y), and ``y`` has shape (num_samples,) holding labels 0. or 1.
    """
    t = torch.linspace(0., 4 * math.pi, num_timepoints)
    num_mirrored = num_samples // 2

    # The random draws below happen in a fixed order (start angles, x noise, y noise,
    # permutation) so that a given seed always produces the same dataset.
    start = torch.rand(num_samples) * 2 * math.pi
    angle = start.unsqueeze(1) + t.unsqueeze(0)
    decay = 1 + 0.5 * t
    x_pos = torch.cos(angle) / decay
    x_pos[:num_mirrored] *= -1  # mirroring x reverses the direction of rotation
    y_pos = torch.sin(angle) / decay
    x_pos += 0.01 * torch.randn_like(x_pos)
    y_pos += 0.01 * torch.randn_like(y_pos)
    ######################
    # Easy to forget gotcha: time should be included as a channel; Neural CDEs need to be explicitly told the
    # rate at which time passes. Here, we have a regularly sampled dataset, so appending time is pretty simple.
    ######################
    X = torch.stack([t.unsqueeze(0).repeat(num_samples, 1), x_pos, y_pos], dim=2)
    y = torch.zeros(num_samples)
    y[:num_mirrored] = 1

    # Shuffle samples and labels together so the two classes are interleaved.
    perm = torch.randperm(num_samples)
    X = X[perm]
    y = y[perm]

    ######################
    # X is a tensor of observations, of shape (batch=num_samples, sequence=num_timepoints, channels=3)
    # y is a tensor of labels, of shape (batch=num_samples,), either 0 or 1 corresponding to anticlockwise or
    # clockwise respectively.
    ######################
    return X, y


def load_dataset():
    """Load the PersonData train/test splits as tensors ready for a Neural CDE.

    For each split the time stamps are prepended to the features as the first
    channel(s), and only the label at the last sequence position is kept as the
    classification target.

    Returns:
        A tuple ``(train_x, train_y, test_x, test_y)``; the ``*_x`` entries are float
        tensors and the ``*_y`` entries are long tensors.
    """
    dataset = PersonData()

    def _prepare(features, times, labels):
        # Convert one split to tensors, prepend time as a channel, keep the final label.
        features = torch.Tensor(features)
        labels = torch.LongTensor(labels)
        times = torch.Tensor(times)
        # NOTE(review): assumes features and times are both (batch, sequence, channels) and
        # labels are (batch, sequence) -- confirm against PersonData.
        # NOTE(review): a Neural CDE expects the time channel to be increasing along the
        # sequence; confirm that PersonData provides time stamps rather than per-step deltas.
        return torch.cat([times, features], dim=2), labels[:, -1]

    train_x, train_y = _prepare(dataset.train_x, dataset.train_t, dataset.train_y)
    test_x, test_y = _prepare(dataset.test_x, dataset.test_t, dataset.test_y)

    return train_x, train_y, test_x, test_y

# Main function
def _seed_everything(seed):
    # Seed every random number generator used by the script so a run is repeatable.
    random.seed(seed)  # Python random module
    np.random.seed(seed)  # Numpy library
    torch.manual_seed(seed)  # PyTorch random number generator
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)  # CUDA randomness
        torch.cuda.manual_seed_all(seed)  # CUDA all GPUs
        torch.backends.cudnn.deterministic = True  # CuDNN behavior


def _evaluate_accuracy(model, test_X, test_y):
    # Return the fraction of test samples whose arg-max logit equals the label.
    model.eval()
    # No gradients are needed for evaluation; disabling autograd avoids building a graph
    # over the whole test set.
    with torch.no_grad():
        test_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(test_X)
        logits = model(test_coeffs)
    # Computed with plain torch so the result does not depend on the installed torchmetrics
    # version (newer releases require an extra `task` argument for `accuracy`).
    return (logits.argmax(dim=-1) == test_y).float().mean()


# Main function
def main(num_epochs=200, seeds=(8, 10, 35, 69, 96)):
    """Train and evaluate a Neural CDE on PersonData once per seed.

    Args:
        num_epochs: Number of passes over the training set for each run.
        seeds: Iterable of integer seeds; one independent run is performed per seed.

    Side effects:
        Prints and logs the training loss after every epoch and the test accuracy
        after every run.
    """
    for run, seed in enumerate(seeds):
        # Set different seeds from the list for each run
        _seed_everything(seed)

        train_X, train_y, test_X, test_y = load_dataset()
        # The number of input channels is read from the data (time plus features).
        # output_channels=7 is the number of logits handed to the cross-entropy loss --
        # presumably the number of PersonData classes; confirm against the dataset.
        model = NeuralCDE(input_channels=train_X.size(-1), hidden_channels=64, output_channels=7)
        model.to(device)
        optimizer = torch.optim.Adam(model.parameters())
        train_X = train_X.to(device)
        train_y = train_y.to(device)
        test_X = test_X.to(device)
        test_y = test_y.to(device)
        # Interpolation coefficients are computed once, on the target device, so the
        # batches produced by the DataLoader below are already on that device.
        train_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(train_X)
        train_dataset = torch.utils.data.TensorDataset(train_coeffs, train_y)
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=256)

        for epoch in range(num_epochs):
            for batch_coeffs, batch_y in train_dataloader:
                optimizer.zero_grad()
                pred_y = model(batch_coeffs)
                loss = criterion(pred_y, batch_y)
                loss.backward()
                optimizer.step()
            # The reported value is the loss of the last mini-batch of the epoch.
            last_loss = loss.item()
            print(f'Run: {run}, Seed: {seed}, Epoch: {epoch}, Training loss: {last_loss}')
            logger.info(f"Run: {run}, Seed: {seed}, Epoch: {epoch} - Train Loss:{last_loss};")

        proportion_correct = _evaluate_accuracy(model, test_X, test_y)
        print(f'Run: {run}, Seed: {seed}, Test Accuracy:', proportion_correct)
        logger.info(f"Run: {run}, Seed: {seed}, Test Accuracy : {proportion_correct},")

if __name__ == '__main__':
    main()