import pandas as pd
import copy
import numpy as np
try:
    import torch
    from torch.utils.data import Dataset
except:
    print('No module named torch. Please pip install torch')

def get_var_df(df, var):
    """Collect the columns of *df* whose names begin with *var* as a NumPy array.

    Matching columns keep their DataFrame order. The selection is a column
    list, so the result is a 2-D array (rows x matching columns).
    """
    selected = list(filter(lambda name: name.startswith(var), df.columns))
    subset = df[selected]
    return subset.to_numpy()

def cat(data_list, axis=1):
    """Concatenate a sequence of tensors or arrays along *axis*.

    ``torch.cat`` is tried first. When torch is unavailable (NameError) or the
    items are not tensors (TypeError, e.g. NumPy arrays), the call falls back
    to ``np.concatenate``. Other errors (shape/device mismatches, interrupts)
    now propagate instead of being swallowed by a bare ``except``.

    Args:
        data_list: sequence of ``torch.Tensor`` or array-like objects.
        axis: dimension to concatenate along (default 1, i.e. columns).

    Returns:
        A ``torch.Tensor`` when ``torch.cat`` succeeds, otherwise an ndarray.
    """
    try:
        return torch.cat(data_list, axis)
    except (NameError, TypeError):
        return np.concatenate(data_list, axis)

def split(data, split_ratio=0.5):
    """Cut *data* into two independent deep copies at ``split_ratio``.

    With ``k = int(data.length * split_ratio)``, the first copy is trimmed to
    rows ``[0, k)`` and the second to rows ``[k, data.length)``. Each copy is
    trimmed via its own ``split(start, end)`` method; *data* is not modified.
    """
    head = copy.deepcopy(data)
    tail = copy.deepcopy(data)

    cut = int(data.length * split_ratio)
    total = data.length
    head.split(0, cut)
    tail.split(cut, total)

    return head, tail

class CausalDataset(object):
    """Container for a train/test pair of ``getDataset`` splits.

    Data are loaded from ``path + 'train.csv'`` and ``path + 'test.csv'`` when
    *path* is given; otherwise the *train* / *test* DataFrames are wrapped.
    A split that is not supplied is stored as ``None`` and skipped by the
    bulk conversion helpers (``tensor``, ``double``, ``to``, ...).
    """

    def __init__(self, train=None, test=None, path=None):
        self.path = path
        if path is not None:
            # path is used as a plain string prefix (e.g. 'data/' or 'data/run1_').
            self.train = getDataset(pd.read_csv(path + 'train.csv'))
            self.test = getDataset(pd.read_csv(path + 'test.csv'))
        else:
            # Only wrap what was supplied: getDataset(None) would fail on len(None).
            self.train = None if train is None else getDataset(train)
            self.test = None if test is None else getDataset(test)

    def _parts(self):
        """Return the splits that are present, train first."""
        return [part for part in (self.train, self.test) if part is not None]

    def split(self, split_ratio=0.5, data=None):
        """Split *data* (default: the train split) into ``self.data1``/``self.data2``.

        Returns:
            The ``(data1, data2)`` pair that was stored on the instance.

        Raises:
            ValueError: if no *data* is given and there is no train split.
        """
        if data is None:
            data = self.train
        if data is None:
            raise ValueError("No data to split: the train split is missing.")

        data1, data2 = split(data, split_ratio)
        self.data1 = data1
        self.data2 = data2
        return data1, data2

    def get_train(self):
        return self.train

    def get_test(self):
        return self.test

    def tensor(self):
        for part in self._parts():
            part.tensor()

    def double(self):
        for part in self._parts():
            part.double()

    def float(self):
        for part in self._parts():
            part.float()

    def detach(self):
        for part in self._parts():
            part.detach()

    def to(self, device='cpu'):
        for part in self._parts():
            part.to(device)

    def cpu(self):
        for part in self._parts():
            part.cpu()

    def numpy(self):
        for part in self._parts():
            part.numpy()

    def merge(self, other):
        """Append *other*'s splits onto this dataset's splits in place.

        A split missing here is taken over from *other*.

        Raises:
            TypeError: if *other* is not a ``CausalDataset``.
        """
        if not isinstance(other, CausalDataset):
            raise TypeError("Only CausalDataset can be merged.")
        if self.train is not None and other.train is not None:
            self.train.merge(other.train)
        elif other.train is not None:
            self.train = other.train
        if self.test is not None and other.test is not None:
            self.test.merge(other.test)
        elif other.test is not None:
            self.test = other.test

class TorchDataset(Dataset):
    """``torch.utils.data.Dataset`` adapter over a ``getDataset``-like object.

    On construction the wrapped object is converted in place: ``tensor()``
    when *type* is ``'tensor'``, otherwise ``double()``. It is then moved to
    *device*. Each item is a dict mapping every name in ``data.Vars`` to that
    variable's entry at ``idx``.
    """

    def __init__(self, data, device='cpu', type='tensor'):
        if type == 'tensor':
            data.tensor()
        else:
            data.double()
        data.to(device)

        self.data = data

    def __getitem__(self, idx):
        # The variables are attributes of the wrapped object, not of this
        # adapter; the old exec-based lookup on ``self`` raised AttributeError.
        return {var: getattr(self.data, var)[idx] for var in self.data.Vars}

    def __len__(self):
        return self.data.length

class getDataset(Dataset):
    """Tabular dataset whose columns are grouped by their first character.

    Columns ``x0``, ``x1``, ... of the input DataFrame become the 2-D array
    ``self.x``, ``t0`` becomes ``self.t``, and so on. ``self.Vars`` lists the
    attribute names that hold data and ``self.length`` is the row count.
    Variables are read and written with ``getattr``/``setattr`` (not ``exec``)
    so any column prefix character is handled safely.
    """

    def __init__(self, df):
        self.length = len(df)
        # dict.fromkeys de-duplicates while keeping first-seen column order, so
        # Vars (and therefore the column order of pandas()) is deterministic;
        # a set's iteration order can differ between interpreter runs.
        self.Vars = list(dict.fromkeys(col[0] for col in df.columns))

        for var in self.Vars:
            setattr(self, var, get_var_df(df, var))

        if 'y' in self.Vars and self.y.ndim > 1:
            # y_cf keeps every y* column; y is reduced to the first one.
            # NOTE(review): presumably column 0 is the factual outcome and the
            # others are counterfactuals -- confirm against the data files.
            self.y_cf = self.y.copy()
            self.Vars.append('y_cf')
            self.y = self.y[:, 0:1]
        if not hasattr(self, 'i') and hasattr(self, 'z'):
            # Fall back to z for i; both names refer to the same array object.
            self.i = self.z
            self.Vars.append('i')

    # ------------------------------------------------------------------ helpers
    def _map_vars(self, fn):
        """Replace every variable with ``fn(variable)``, best effort.

        A variable for which ``fn`` raises (e.g. a NumPy array has no
        ``.cpu()``) is left unchanged and the remaining variables are still
        processed -- the previous ``break`` silently skipped them.
        """
        for var in self.Vars:
            try:
                setattr(self, var, fn(getattr(self, var)))
            except Exception:
                continue

    @staticmethod
    def _to_dtype(value, dtype):
        """Return *value* as a torch tensor of *dtype*.

        Converting directly avoids the float32 detour of ``torch.Tensor(...)``,
        which lost float64 precision and rejected tensors of another dtype.
        """
        if isinstance(value, torch.Tensor):
            return value.to(dtype)
        return torch.tensor(value, dtype=dtype)

    @staticmethod
    def _zeros_like(value):
        """Zeros with the shape/dtype of *value* (NumPy array or torch tensor)."""
        if isinstance(value, np.ndarray):
            return np.zeros_like(value)
        return torch.zeros_like(value)

    # ------------------------------------------------------------------ API
    def append(self, var):
        """Register *var*; if it has no data yet, fill it with zeros shaped like ``self.t``."""
        if var not in self.Vars:
            self.Vars.append(var)
        if not hasattr(self, var):
            setattr(self, var, self._zeros_like(self.t))

    def split(self, start, end):
        """Keep only rows ``[start:end]`` of every variable, in place."""
        for var in self.Vars:
            try:
                setattr(self, var, getattr(self, var)[start:end])
            except Exception:
                pass

        # Compute the row count with slice semantics so bounds past the end
        # (or negative bounds) do not produce a wrong length.
        self.length = len(range(self.length)[start:end])

    def cpu(self):
        self._map_vars(lambda v: v.cpu())

    def cuda(self, n=0):
        self._map_vars(lambda v: v.cuda(n))

    def to(self, device='cpu'):
        self._map_vars(lambda v: v.to(device))

    def tensor(self):
        # torch.Tensor(...) produces tensors of the default float dtype.
        self._map_vars(lambda v: torch.Tensor(v))

    def float(self):
        self._map_vars(lambda v: self._to_dtype(v, torch.float32))

    def double(self):
        self._map_vars(lambda v: self._to_dtype(v, torch.float64))

    def detach(self):
        self._map_vars(lambda v: v.detach())

    def numpy(self):
        """Convert every variable to a NumPy array.

        Autograd history is dropped and data are moved to the CPU first; each
        step skips variables that do not support it.
        """
        self.detach()
        self.cpu()
        self._map_vars(lambda v: v.numpy())

    def pandas(self, path=None):
        """Flatten the variables into a DataFrame, optionally writing it to *path*.

        Each variable must be 2-D; variable ``v`` with k columns becomes the
        columns ``v0`` ... ``v{k-1}``, in ``self.Vars`` order.
        """
        blocks = [getattr(self, var) for var in self.Vars]
        columns = [
            f'{var}{d}'
            for var, block in zip(self.Vars, blocks)
            for d in range(block.shape[1])
        ]
        df = pd.DataFrame(np.concatenate(blocks, axis=1), columns=columns)

        if path is not None:
            df.to_csv(path, index=False)
        return df

    def __getitem__(self, idx):
        return {var: getattr(self, var)[idx] for var in self.Vars}

    def __len__(self):
        return self.length

    def merge(self, other):
        """Append *other*'s rows to every shared variable, in place.

        Variables present only in *other* are ignored. A variable that cannot
        be merged is reported with ``print`` and left unchanged. ``length`` is
        then re-read from the first variable.

        Raises:
            TypeError: if *other* is not a ``getDataset``.
        """
        if not isinstance(other, getDataset):
            raise TypeError("Only getDataset objects can be merged.")

        for var in self.Vars:
            if not hasattr(other, var):
                continue
            try:
                mine, theirs = getattr(self, var), getattr(other, var)
                # Check NumPy first so NumPy data still merge when torch is
                # not installed (the torch name would raise NameError).
                if isinstance(mine, np.ndarray):
                    merged = np.concatenate((mine, theirs), axis=0)
                elif isinstance(mine, torch.Tensor):
                    merged = torch.cat((mine, theirs), dim=0)
                else:
                    raise TypeError(f"Unsupported data type for {var}")
                setattr(self, var, merged)
            except Exception as e:
                print(f"Failed to merge {var}: {e}")
        self.length = getattr(self, self.Vars[0]).shape[0]