import pytorch_lightning as pl
import sys
sys.path.insert(0,"../")
#import utils
import torch
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset, TensorDataset
import os
import argparse
import numpy as np
from scipy.integrate import odeint
import pandas as pd
import random
import pickle
from tumor.data_utils import data_to_torch_tensor, process_data 


#python train_model.py --dataset_name tumor --logger_type=wandb --seed=313 --max_epochs 20 --batch_size 256 --num_blocks 1 --lr 0.001 --loss_func mse --reg1 0.00 --reg2 0.00 --state_dim 20






def transform_to_our_input(data_X, data_A, len = 58):
    """Split simulator tensors into (observation, covariate, time) model inputs.

    Args:
        data_X: rank-3 tensor (presumably ``[batch, time, feature]`` — confirm
            against ``data_to_torch_tensor``). Feature 0 becomes the observed
            series; features ``1:`` become covariates.
        data_A: rank-3 tensor concatenated with ``data_X[:, :, 1:]`` along the
            last axis, so its first two dimensions must match ``data_X``.
        len: number of steps in the generated time grid. (Shadows the builtin
            ``len``; the name is kept because callers may pass it by keyword.)

    Returns:
        ``(data_o, data_c, data_Time)``:
        ``data_o`` -- ``data_X[:, :, 0]`` with a trailing singleton axis, float32 (CPU);
        ``data_c`` -- ``data_X[:, :, 1:]`` and ``data_A`` concatenated on the last axis, float32 (CPU);
        ``data_Time`` -- float32 (CPU) tensor of shape ``[data_X.shape[0], len]``
        whose every row is ``0, 1, ..., len - 1``.
    """
    data_o = data_X[:, :, 0].unsqueeze(-1)
    data_c = torch.cat((data_X[:, :, 1:], data_A), dim=-1)
    # Build the time grid natively in torch: the previous np.tile-on-a-tensor
    # round-tripped through NumPy and cannot handle CUDA inputs.
    data_Time = torch.arange(len, dtype=torch.float32).unsqueeze(0).repeat(data_X.shape[0], 1)
    return data_o.type('torch.FloatTensor'), data_c.type('torch.FloatTensor'), data_Time

#def get_observation_mask(data, len=58):
#    #print('obs mask')
#    #print(data["intensity"][5])
#    #print(np.diff(data["intensity"], n=1)[5])
#    #print(data["active_entries"][:,:,0][5])
#    data_ind = (np.diff(data["intensity"], n=1) + data["active_entries"][:,:,0])
#    data_ind = np.insert(data_ind, 0, 2, axis=1)
#    data_ind = (data_ind==2)[:,:len, np.newaxis]
#    return torch.from_numpy(data_ind)




class TumorDataModule(pl.LightningDataModule):
    """Lightning data module over a pickled cancer-simulation file.

    The pickle is split into train/validation/test with ``process_data`` and
    each split is exposed as a ``TensorDataset`` of
    ``(data_o, data_c, data_Time, mask)``, where ``data_c`` already has the
    mask appended on its last axis.
    """

    # Number of leading time steps kept from every sequence; matches the
    # default ``len`` of ``transform_to_our_input``.
    SEQ_LEN = 58

    def __init__(self, coeff, kappa, batch_size, seed, data_path, num_workers = 4, data_dir: str = "./", **kwargs):
        """Record configuration and resolve the pickle path.

        Args:
            coeff: simulation coefficient; appears twice in the file name.
            kappa: simulation kappa; appears in the file name.
            batch_size: batch size used by every dataloader.
            seed: stored on the instance (not used inside this class).
            data_path: directory (relative to ``..``) holding the pickle.
            num_workers: worker processes for each ``DataLoader``.
            data_dir: accepted for interface compatibility; unused.
            **kwargs: ignored, so a full argparse namespace can be splatted in.
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.input_dim = 1
        self.output_dim = 1
        self.action_dim = 6
        self.seed = seed
        self.pickle_map = None

        self.transformed_datapath = f"../{data_path}/new_cancer_sim_{coeff}_{coeff}_kappa_{kappa}.p"
        print(f"Loading transformed data from {self.transformed_datapath}")

    def _load_pickle_map(self):
        """Unpickle ``self.transformed_datapath`` into ``self.pickle_map``."""
        # NOTE: pickle can execute arbitrary code while loading -- only point
        # this at trusted files.
        with open(self.transformed_datapath, "rb") as f:
            self.pickle_map = pickle.load(f)

    def prepare_data(self):
        """Load the pickled simulation data.

        Lightning calls this on a single process only, so ``setup`` re-loads
        the data if ``pickle_map`` is missing (e.g. on other ranks).
        """
        self._load_pickle_map()
        print("Processing dataset")

    def _build_dataset(self, processed):
        """Convert one processed split into a ``TensorDataset`` of model inputs."""
        sample_proportion = 1
        data_X, data_A, _, _, _, _, _, data_mask = data_to_torch_tensor(
            processed,
            sample_prop=sample_proportion,
        )
        data_o, data_c, data_Time = transform_to_our_input(data_X, data_A, len=self.SEQ_LEN)
        # Truncate the mask to SEQ_LEN steps so it lines up with the other
        # tensors (the mask from data_to_torch_tensor may be longer).
        mask = data_mask[:, :self.SEQ_LEN, :].type('torch.FloatTensor')
        data_c = torch.cat((data_c, mask), dim=-1).type('torch.FloatTensor')
        return TensorDataset(data_o, data_c, data_Time, mask)

    def setup(self, stage: str):
        """Build the datasets needed for ``stage``.

        ``"fit"`` builds train and val, ``"validate"`` builds val, and
        ``"test"``/``"predict"`` build the test split.
        """
        # prepare_data runs on one process only, so its state is not guaranteed
        # to exist here; load lazily when missing.
        if getattr(self, "pickle_map", None) is None:
            self._load_pickle_map()
        self.training_processed, self.validation_processed, self.test_processed = process_data(self.pickle_map)

        if stage == "fit":
            self.train = self._build_dataset(self.training_processed)
        if stage in ("fit", "validate"):
            self.val = self._build_dataset(self.validation_processed)
        # The test stage previously assigned self.val, leaving test_dataloader()
        # without a dataset; both test and predict now read the test split.
        if stage in ("test", "predict"):
            self.test = self._build_dataset(self.test_processed)

    def train_dataloader(self):
        # NOTE(review): training batches are not shuffled -- confirm this is intended.
        return DataLoader(self.train, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size, num_workers=self.num_workers)

    def predict_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size, num_workers=self.num_workers)

    @classmethod
    def add_dataset_specific_args(cls, parent):
        """Return a parser extending ``parent`` with this dataset's CLI options."""
        parser = argparse.ArgumentParser(parents=[parent], add_help=False)
        parser.add_argument('--seed', type=int, default=42)
        parser.add_argument('--batch_size', type=int, default=1) #128
        parser.add_argument('--coeff', type=int, default=4)
        parser.add_argument('--kappa', type=int, default=1)
        parser.add_argument("--data_path", type=str, default=None)
        return parser
