import torch
import numpy as np
from torchdiffeq import odeint
import os
from tqdm import tqdm
import pickle
import matplotlib.pyplot as plt


def double_pendulum(t, x, params, l1=1, l2=1):
    """Hamiltonian right-hand side of the planar double pendulum.

    Args:
        t: Time. Unused (the system is autonomous); present only so the
            function matches the ``odeint`` ``func(t, y)`` convention.
        x: State tensor whose last axis holds ``(alpha1, alpha2, p1, p2)``,
            i.e. the two arm angles and their conjugate momenta. Accepts a
            single state of shape ``(4,)`` or a batch of shape ``(..., 4)``.
        params: ``(m1, m2, g)``, the two bob masses and gravity.
        l1, l2: Arm lengths. Default 1, the values this model always used.

    Returns:
        Tensor with the same shape as ``x`` holding
        ``(alpha1', alpha2', p1', p2')`` along the last axis.
    """
    m1, m2, g = params

    # Ellipsis indexing handles both the single-state and batched cases.
    alpha1 = x[..., 0]
    alpha2 = x[..., 1]
    p1 = x[..., 2]
    p2 = x[..., 3]

    delta = alpha1 - alpha2
    sin_d = torch.sin(delta)
    cos_d = torch.cos(delta)
    # Common factor m1 + m2*sin^2(alpha1 - alpha2) that appears in every term.
    mass_term = m1 + m2 * sin_d**2

    A1 = (p1 * p2 * sin_d) / (l1 * l2 * mass_term)
    # Only the mass term is squared: the denominator is
    # 2 * l1^2 * l2^2 * (m1 + m2 sin^2(delta))^2, from differentiating the
    # Hamiltonian's kinetic term with respect to alpha1.
    A2_numerator = (
        p1**2 * m2 * l2**2
        - 2 * p1 * p2 * m2 * l1 * l2 * cos_d
        + p2**2 * (m1 + m2) * l1**2
    ) * torch.sin(2 * delta)
    A2 = A2_numerator / (2 * l1**2 * l2**2 * mass_term**2)

    alpha1_prime = (p1 * l2 - p2 * l1 * cos_d) / (l1**2 * l2 * mass_term)
    alpha2_prime = (p2 * (m1 + m2) * l1 - p1 * m2 * l2 * cos_d) / (m2 * l1 * l2**2 * mass_term)

    p1_prime = -(m1 + m2) * g * l1 * torch.sin(alpha1) - A1 + A2
    p2_prime = -m2 * g * l2 * torch.sin(alpha2) + A1 - A2

    # Stack on the last axis so batched input (B, 4) yields (B, 4), not (4, B).
    return torch.stack((alpha1_prime, alpha2_prime, p1_prime, p2_prime), dim=-1)


def spring(t, x, params, input_force):
    """Right-hand side of a damped mass-spring system with a constant force.

    Args:
        t: Time. Unused; kept for the ``odeint`` ``func(t, y)`` convention.
        x: State indexed as ``x[0]`` = position, ``x[1]`` = velocity.
        params: Indexable ``(k, b, m)``: stiffness, damping, mass.
        input_force: External force applied to the mass.

    Returns:
        ``(velocity, acceleration)`` stacked along a new leading axis.
    """
    stiffness, damping, mass = params[0], params[1], params[2]
    position, velocity = x[0], x[1]

    acceleration = (
        -stiffness / mass * position
        - damping / mass * velocity
        + 1 / mass * input_force
    )
    return torch.stack((velocity, acceleration))


def angle_length_to_x(l1, l2, alpha1, alpha2):
    """Return the Cartesian position of the outer (second) pendulum bob.

    Angles are measured from the downward vertical (a hanging pendulum,
    ``alpha1 = alpha2 = 0``, sits at ``(0, -(l1 + l2))``).

    Args:
        l1, l2: Lengths of the inner and outer arms.
        alpha1, alpha2: Broadcast-compatible tensors of arm angles in radians.

    Returns:
        Tensor of shape ``(2, *shape)`` whose rows are the tip ``x`` and ``y``
        coordinates.
    """
    x_tip = l1 * torch.sin(alpha1) + l2 * torch.sin(alpha2)
    y_tip = -l1 * torch.cos(alpha1) - l2 * torch.cos(alpha2)
    return torch.stack((x_tip, y_tip))

def solve_ode_pendulum(params):
    """Simulate the double pendulum and return the trajectory of its tip.

    The state starts at ``(alpha1, alpha2, p1, p2) = (pi/2, pi/2, 2, 2)`` and
    is integrated at 200 evenly spaced times on ``[0, 5]``.

    Args:
        params: ``(m1, m2, g)``, forwarded to ``double_pendulum``.

    Returns:
        ``(t, tip, params)``. ``t`` holds the 200 sample times, ``tip`` is the
        ``(200, 2)`` tensor of tip ``(x, y)`` positions, and ``params`` is the
        input passed through unchanged.
    """
    def rhs(time, state):
        return double_pendulum(time, state, params)

    times = torch.linspace(0, 5, 200)
    initial_state = torch.Tensor([np.pi/2, np.pi/2, 2, 2])
    trajectory = odeint(rhs, initial_state, times)

    # Unit arm lengths, matching the lengths double_pendulum integrates with.
    arm1 = arm2 = 1
    tip_rows = angle_length_to_x(arm1, arm2, trajectory[:, 0], trajectory[:, 1])
    tip = tip_rows.transpose(0, 1)
    return times, tip, params


def solve_ode(params, ode):
    """Integrate one trajectory of the requested system.

    Args:
        params: System parameters: ``(m1, m2, g)`` for ``'pendulum'``,
            ``(k, b, m)`` for ``'spring'``.
        ode: Name of the system, ``'pendulum'`` or ``'spring'``.

    Returns:
        ``(t, y, params)``. For ``'pendulum'``, this is whatever
        ``solve_ode_pendulum`` returns. For ``'spring'``, ``t`` holds 100
        times on ``[0, 10]`` and ``y`` is the ``(100, 1)`` position trace.

    Raises:
        ValueError: If ``ode`` is not a supported system name.
    """
    if ode == 'pendulum':
        return solve_ode_pendulum(params)

    if ode == 'spring':
        initial = torch.Tensor([1, 1])  # position, velocity
        t = torch.linspace(0, 10, 100)
        y = odeint(lambda t_, x: spring(t_, x, params, 1), initial, t)
        # Keep position only; the 0:1 slice preserves a trailing feature axis.
        return t, y[:, 0:1], params

    # Previously fell through to an UnboundLocalError on `t`/`y`.
    raise ValueError(f"Unknown ode {ode!r}; expected 'pendulum' or 'spring'")


def generate_dataset_pendulum(nn):
    """Simulate ``nn`` double-pendulum trajectories with random parameters.

    Each trajectory draws ``m1, m2 ~ U(0.5, 1.5)`` and ``g ~ U(5, 15)`` from
    the global NumPy RNG, in that order.

    Returns:
        List of ``solve_ode(..., 'pendulum')`` results, one per parameter row.
    """
    # Draw order (m1, m2, g) is kept fixed so seeded runs stay reproducible.
    mass_one = np.random.uniform(0.5, 1.5, size=(nn, 1))
    mass_two = np.random.uniform(0.5, 1.5, size=(nn, 1))
    gravity = np.random.uniform(5, 15, size=(nn, 1))
    param_rows = np.hstack([mass_one, mass_two, gravity])

    return [solve_ode(row, 'pendulum') for row in tqdm(param_rows)]

def generate_dataset_spring(n):
    """Simulate ``n`` mass-spring trajectories with random parameters.

    Stiffness ``k``, damping ``b`` and mass ``m`` are each drawn from
    ``U(0.5, 2.5)`` using the global NumPy RNG, in that order.

    Returns:
        List of ``solve_ode(..., 'spring')`` results, one per parameter row.
    """
    # Draw order (k, b, m) is kept fixed so seeded runs stay reproducible.
    stiffness = np.random.uniform(0.5, 2.5, size=(n, 1))
    damping = np.random.uniform(0.5, 2.5, size=(n, 1))
    mass = np.random.uniform(0.5, 2.5, size=(n, 1))
    param_rows = np.hstack([stiffness, damping, mass])

    return [solve_ode(row, 'spring') for row in tqdm(param_rows)]

class ODEDataset(torch.utils.data.Dataset):
    """Dataset of simulated ODE trajectories split into support/query points.

    Each item is one stored trajectory whose time points are randomly
    partitioned into a support set of ``support_size`` points and a query set
    containing the remaining points.

    Trajectories are cached as a pickle under ``./data/<ode>_data/``. If the
    cache file for ``phase`` exists it is loaded and ``n`` is ignored;
    otherwise ``n`` trajectories are generated and written to it.

    NOTE(review): the spring cache name hard-codes ``1000`` regardless of
    ``n``, so a cache built with a different ``n`` is silently reused.
    """

    # Cache-file templates per supported system, formatted with ``phase``.
    _CACHE_FILES = {
        'pendulum': './data/pendulum_data/pendulum_m_g_different_init_{phase}.pkl',
        'spring': './data/spring_data/spring_1000_{phase}.pkl',
    }

    def __init__(self, support_size, n, phase, ode, device='cpu'):
        """Load the cached trajectories for ``ode``/``phase``, or generate them.

        Args:
            support_size: Number of time points placed in the support set.
            n: Number of trajectories to generate when no cache exists.
            phase: Split name used in the cache file name (e.g. ``"train"``).
            ode: ``'pendulum'`` or ``'spring'``.
            device: Device every returned tensor is moved to.

        Raises:
            ValueError: If ``ode`` is not a supported system name.
        """
        if ode not in self._CACHE_FILES:
            raise ValueError(
                f"Unknown ode {ode!r}; expected one of {sorted(self._CACHE_FILES)}"
            )
        file_name = self._CACHE_FILES[ode].format(phase=phase)

        self.support_size = support_size
        self.ode = ode
        self.device = device

        if os.path.exists(file_name):
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # cache files produced by this project.
            with open(file_name, 'rb') as f:
                self.data = pickle.load(f)
            return

        print(f"Generating: {file_name}...")
        if ode == 'spring':
            self.data = generate_dataset_spring(n)
        else:
            self.data = generate_dataset_pendulum(n)

        # The cache directory may not exist yet on a fresh checkout.
        os.makedirs(os.path.dirname(file_name), exist_ok=True)
        with open(file_name, 'wb') as f:
            pickle.dump(self.data, f)
            print("Saved")

    def __len__(self):
        return len(self.data)

    def _split_indices(self, length):
        """Return a random (support, query) partition of ``range(length)``."""
        order = np.random.permutation(length)
        return order[:self.support_size], order[self.support_size:]

    def __getitem__(self, idx):
        """Return ``(t_support, y_support, t_query, y_query)`` for trajectory ``idx``.

        Pendulum items give 1-D ``t`` (cast to float32) and 1-D ``y`` taken
        from column 1 of the stored trajectory. Spring items give ``t`` with
        a trailing axis of size 1 and ``y`` as stored. All four tensors are
        moved to ``self.device``.
        """
        t, y, _ = self.data[idx]
        # as_tensor accepts both torch tensors (as produced by solve_ode) and
        # NumPy arrays (e.g. externally generated caches); from_numpy rejected
        # the former.
        t = torch.as_tensor(t)
        y = torch.as_tensor(y)

        if self.ode == "pendulum":
            t = t.float()
            # For trajectories from solve_ode_pendulum, column 1 is the tip's
            # vertical coordinate.
            y = y[:, 1]
        else:
            t = t.unsqueeze(-1)

        support_idx, query_idx = self._split_indices(len(t))
        return (
            t[support_idx].to(self.device),
            y[support_idx].to(self.device),
            t[query_idx].to(self.device),
            y[query_idx].to(self.device),
        )


if __name__ == '__main__':
    # Build (or load from cache) the pendulum training split. `device` must be
    # supplied; omitting it raised a TypeError before.
    dset = ODEDataset(20, 1000, "train", "pendulum", device="cpu")
