import os

import numpy as np
import torch

from stock_model import *
# Registry of selectable stock-model names -> simulator classes (imported via
# ``from stock_model import *``). The "sine_*" names reuse the same classes;
# the requested name is forwarded to the class as ``model_name`` by
# create_data, which presumably lets the class pick the variant — confirm in
# stock_model.
STOCK_MODELS = dict(
    BlackScholes=BlackScholes,
    Heston=Heston,
    OrnsteinUhlenbeck=OrnsteinUhlenbeck,
    HestonWOFeller=HestonWOFeller,
    combined=Combined,
    sine_BlackScholes=BlackScholes,
    sine_Heston=Heston,
    sine_OrnsteinUhlenbeck=OrnsteinUhlenbeck,
)

def create_data(n_samples_train, no_timesteps, device="cuda", stock_model_name="BlackScholes", seed=0):
    """Simulate ``n_samples_train`` paths from a registered stock model.

    Args:
        n_samples_train: number of simulated paths.
        no_timesteps: number of points per path (the model is asked for
            ``no_timesteps - 1`` steps).
        device: device on which the returned path tensor is placed.
        stock_model_name: key into ``STOCK_MODELS``.
        seed: seed applied to both torch and NumPy before simulation.

    Returns:
        (paths, times): ``paths`` is a float32 tensor of shape
        ``(n_samples_train, no_timesteps)`` on ``device``; ``times`` holds
        ``no_timesteps`` equally spaced points in [0, 1], created on torch's
        default device.

    Raises:
        KeyError: if ``stock_model_name`` is not a key of ``STOCK_MODELS``.
    """
    # Seed both RNGs; which one the stock model draws from is not visible here.
    torch.manual_seed(seed)
    np.random.seed(seed)

    hyperparams = {
        'drift': 2., 'volatility': 0.3, 'mean': 4,
        'speed': 2., 'correlation': 0.5,
        'nb_paths': n_samples_train, 'nb_steps': no_timesteps - 1,
        'S0': 1, 'maturity': 1., 'dimension': 1,
        'obs_perc': 0.1,
        'scheme': 'euler', 'return_vol': False, 'v0': 1,
        'model_name': stock_model_name,
    }

    try:
        stock_model_class = STOCK_MODELS[stock_model_name]
    except KeyError:
        raise KeyError(
            f"Unknown stock model: {stock_model_name!r}; "
            f"expected one of {sorted(STOCK_MODELS)}"
        ) from None

    stock_paths, _dt = stock_model_class(**hyperparams).generate_paths()
    # Flatten the model output to one row per path; this only works while
    # 'dimension' is 1.
    stock_paths = stock_paths.reshape(n_samples_train, no_timesteps)
    # Convert straight to float32 instead of materialising a float64 copy on
    # ``device`` first and then casting it.
    paths = torch.tensor(stock_paths, dtype=torch.float32, device=device)
    times = torch.linspace(0., 1, no_timesteps)

    return paths, times

class synthetic_data:
    """Bind a stock-model name to the module-level ``create_data`` generator.

    The name is checked against ``STOCK_MODELS`` when the object is built, so
    an unknown model fails at construction instead of at generation time.
    """

    def __init__(self, stock_model_name="BlackScholes"):
        known = stock_model_name in STOCK_MODELS
        if not known:
            raise ValueError(f"Unknown stock model: {stock_model_name}")
        self.stock_model_name = stock_model_name

    def create_data(self, n_samples, no_timesteps, device="cuda", seed=0):
        """Generate ``n_samples`` paths with this instance's stock model.

        Forwards everything to the module-level ``create_data`` (same argument
        order: samples, timesteps, device, model name, seed) and returns its
        result unchanged.
        """
        model = self.stock_model_name
        return create_data(n_samples, no_timesteps, device, model, seed)

def random_subsample(data, times, no_timesteps, subsample_time, random_seed=0, device="cuda"):
    """Keep ``subsample_time`` random time points per row, paired with their times.

    Args:
        data: 2-D tensor ``(batch, no_timesteps)`` of path values.
        times: 1-D tensor indexed by time position (the time grid).
        no_timesteps: expected column count of ``data`` (checked via assert).
        subsample_time: number of time points kept per row.
        random_seed: reseeds torch's CPU and CUDA generators before drawing.
        device: device on which the random indices are generated; ``data``
            and ``times`` must accept index tensors from this device.

    Returns:
        Tensor ``(batch, subsample_time, 2)``: ``[..., 0]`` values and
        ``[..., 1]`` matching times, in increasing time order along dim 1.
    """
    for reseed in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        reseed(random_seed)

    n_rows, n_cols = data.shape
    assert n_cols == no_timesteps, "Mismatch in no_timesteps and data shape."

    # argsort of uniform noise gives an independent random permutation per
    # row; keep its first subsample_time entries and sort them so the chosen
    # time points appear chronologically.
    noise = torch.rand(n_rows, n_cols, device=device)
    chosen = noise.argsort(dim=1)[:, :subsample_time].sort(dim=1).values

    rows = torch.arange(n_rows, device=device).unsqueeze(1).expand(-1, subsample_time)
    values = data[rows, chosen]
    stamps = times[chosen]

    return torch.cat((values[..., None], stamps[..., None]), dim=2)

def random_times(no_timesteps, subsample_time, random_seed = None, device="cuda", equidist="False"):
    """Return a sorted 1-D tensor of time indices in ``[0, no_timesteps)``.

    Args:
        no_timesteps: length of the full time grid.
        subsample_time: target number of observation times.
        random_seed: if not None, reseeds torch (CPU and CUDA) before drawing
            random indices. Unused in the equidistant branch.
        device: device on which the index tensor is created.
        equidist: ``True`` or the string ``"True"`` selects an equidistant
            grid; anything else (default ``"False"``) selects random indices.
            The string form is kept for backward compatibility.

    Returns:
        Equidistant: every ``max(1, no_timesteps // subsample_time)``-th index
        starting at 0. This can give more than ``subsample_time`` entries,
        e.g. 11 for ``(101, 10)``.
        Random: ``subsample_time - 1`` distinct indices, sorted ascending.
    """
    # Previously only the string "True" was recognised, so a bool True silently
    # fell through to the random branch.
    if equidist is True or equidist == "True":
        # Guard against a zero slice step when subsample_time > no_timesteps.
        jump_every = max(1, no_timesteps // subsample_time)
        return torch.arange(0, no_timesteps, device=device)[::jump_every]

    if random_seed is not None:
        torch.cuda.manual_seed(random_seed)
        torch.manual_seed(random_seed)
        torch.cuda.manual_seed_all(random_seed)
    # NOTE(review): this keeps subsample_time - 1 indices, while the
    # equidistant branch and random_subsample keep a different count; left
    # unchanged because callers may rely on it — confirm intent.
    perm = torch.randperm(no_timesteps, device=device)[:subsample_time - 1]
    return perm.sort().values
class load_data:
    """Load the cached ``data_bs/`` splits and build training/evaluation inputs.

    Expects ``data_bs/train_data_0.pt``, ``data_bs/train_times_0.pt``,
    ``data_bs/val_data_100.pt`` and ``data_bs/test_data_200.pt`` to exist.
    """

    def __init__(self, seed):
        # Seed for the per-path random subsampling of the training set.
        self.seed = seed

    def get_data(self, subsample_time, no_timesteps, batch_size = 100, device = "cuda", equidist = "True"):
        """Build the training DataLoader and return the evaluation tensors.

        Args:
            subsample_time: observation times kept per training path.
            no_timesteps: length of the stored time grid (checked by
                random_subsample against the training data).
            batch_size: DataLoader batch size.
            device: device for all loaded tensors and for index generation.
            equidist: forwarded to random_times for the evaluation grid.

        Returns:
            (dataloader, val_data, test_data, times_eval): a shuffled
            DataLoader over the subsampled training tensor, the full
            validation and test paths, and the evaluation time indices.
        """
        def _load(path):
            # map_location places tensors directly on ``device`` and lets
            # files saved from CUDA tensors load on CPU-only hosts.
            return torch.load(path, map_location=device)

        train_data = _load("data_bs/train_data_0.pt")
        train_times = _load("data_bs/train_times_0.pt")
        val_data = _load("data_bs/val_data_100.pt")
        test_data = _load("data_bs/test_data_200.pt")

        data_train = random_subsample(train_data, train_times, no_timesteps, subsample_time, random_seed=self.seed, device=device)

        dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, shuffle=True)

        # Forward ``device``: without it random_times fell back to "cuda"
        # regardless of the requested device.
        times_eval = random_times(no_timesteps, subsample_time, random_seed=0, device=device, equidist=equidist)

        return dataloader, val_data, test_data, times_eval


# Directory holding the cached train/val/test splits read by load_data.
_DATA_DIR = "data_bs"


def _ensure_split(name, n_samples, seed, no_timesteps=101):
    """Create ``{name}_data_{seed}.pt`` and ``{name}_times_{seed}.pt`` if either is missing.

    Both files are regenerated together so the data and its time grid stay
    consistent. Generation uses CUDA when available and CPU otherwise, so
    importing this module does not crash on CPU-only hosts. The generated
    tensors are local and are released after saving.
    """
    data_path = os.path.join(_DATA_DIR, f"{name}_data_{seed}.pt")
    times_path = os.path.join(_DATA_DIR, f"{name}_times_{seed}.pt")
    if os.path.exists(data_path) and os.path.exists(times_path):
        return
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data, times = create_data(n_samples, no_timesteps, device=device, seed=seed)
    torch.save(data, data_path)
    torch.save(times, times_path)


# Import-time side effect: make sure the Black-Scholes splits exist on disk.
os.makedirs(_DATA_DIR, exist_ok=True)
_ensure_split("train", 128 * 1000, seed=0)
_ensure_split("val", 5000, seed=100)
_ensure_split("test", 5000, seed=200)

