import pandas as pd
import torch
import numpy as np
from torch.utils.data import Dataset


class cdr_DATA(Dataset):
    """Map-style dataset over a DataFrame with ``x_data``, ``t_data`` and ``u_data`` columns.

    Each column is kept as an ``(N, 1)`` NumPy array and converted to a
    ``float32`` tensor on access.

    NOTE(review): the name suggests a convection-diffusion-reaction (CDR)
    problem with ``x`` a spatial coordinate, ``t`` time and ``u`` the solution
    value -- inferred from naming only, confirm with the callers.
    """

    def __init__(self, df):
        """Wrap *df*, which must contain ``x_data``, ``t_data`` and ``u_data`` columns.

        Raises:
            KeyError: if any of the three columns is missing from *df*.
        """
        self.df = df
        # Column vectors of shape (N, 1): indexing one row yields a length-1 array.
        self.x_data = self.df['x_data'].values.reshape(-1, 1)
        self.t_data = self.df['t_data'].values.reshape(-1, 1)
        self.u_data = self.df['u_data'].values.reshape(-1, 1)

    @staticmethod
    def _as_float_tensor(array):
        """Return a float32 tensor copy of *array*."""
        # torch.tensor always copies, so callers cannot mutate the dataset's
        # arrays through the returned tensor (the legacy FloatTensor
        # constructor may share memory with float32 NumPy input).
        return torch.tensor(array, dtype=torch.float32)

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(x, t, u)`` for row(s) *idx*; each is shape ``(1,)`` for an int index."""
        return (
            self._as_float_tensor(self.x_data[idx]),
            self._as_float_tensor(self.t_data[idx]),
            self._as_float_tensor(self.u_data[idx]),
        )

    def get_all_x(self):
        """Return every ``x_data`` value as an ``(N, 1)`` float32 tensor."""
        return self._as_float_tensor(self.x_data)

    def get_all_t(self):
        """Return every ``t_data`` value as an ``(N, 1)`` float32 tensor."""
        return self._as_float_tensor(self.t_data)

    def get_all_u(self):
        """Return every ``u_data`` value as an ``(N, 1)`` float32 tensor."""
        return self._as_float_tensor(self.u_data)

    def get_all(self):
        """Return the whole dataset as an ``(N, 3)`` float32 tensor with columns ``(x, t, u)``."""
        return torch.cat((self.get_all_x(), self.get_all_t(), self.get_all_u()), dim=1)

    def split_lb_ub(self, lb=0.0, ub=2 * torch.pi, atol=1e-5):
        """Select the rows whose ``x`` value lies on the lower or upper boundary.

        Args:
            lb: lower boundary value of ``x`` (default ``0``).
            ub: upper boundary value of ``x`` (default ``2*pi``).
            atol: absolute tolerance for matching a boundary. Exact float
                equality is fragile once values pass through float32 or a
                text round-trip (e.g. ``6.283185`` read back from a CSV), which
                would silently yield empty results. Pass ``atol=0.0`` for
                exact matching.

        Returns:
            ``(lb_rows, ub_rows)``: two ``(k, 3)`` float32 tensors with
            columns ``(x, t, u)``.
        """
        data = self.get_all()
        x = data[:, 0]
        on_lb = (x - lb).abs() <= atol
        on_ub = (x - ub).abs() <= atol
        return data[on_lb], data[on_ub]

    @staticmethod
    def concat(data1, data2):
        """Return a new ``cdr_DATA`` with the rows of *data1* followed by *data2*.

        The combined DataFrame gets a fresh ``0..N-1`` index (``ignore_index=True``).
        """
        df_concat = pd.concat([data1.df, data2.df], ignore_index=True)
        return cdr_DATA(df_concat)