import torch
import random
import matplotlib.pyplot as plt
import os
def create_data(n_samples_train, no_timesteps, seed = 0, device="cuda", rand_jump=0.8):
    """Generate two groups of noisy paths with opposite quadratic drift.

    The first ``n_samples_train // 2`` rows gain ``(t / no_timesteps) ** 2`` at
    step ``t``; the remaining rows lose the same amount.  Every path starts
    from a standard-normal value.  At each later step a path either moves by
    the drift alone ("smooth") or by the drift plus Gaussian noise scaled by
    0.5 ("jump").

    Args:
        n_samples_train: Requested number of paths.  Two groups of
            ``n_samples_train // 2`` paths are built, so an odd request
            yields one path fewer.
        no_timesteps: Number of time points per path.
        seed: Value handed to ``torch.manual_seed`` before any sampling.
        device: Device on which every tensor is created.
        rand_jump: Per-step probability of the *smooth* update; a jump
            happens with probability ``1 - rand_jump``.

    Returns:
        ``(result, time)``: ``result`` has shape
        ``(2 * (n_samples_train // 2), no_timesteps)`` with the upward group
        first, and ``time`` holds ``0, 1, ..., no_timesteps - 1``.
    """
    torch.manual_seed(seed)
    time = torch.linspace(0., no_timesteps - 1, no_timesteps, device=device)
    drift = (time / no_timesteps) ** 2

    group_size = n_samples_train // 2

    # Only column 0 of these draws survives; later columns are overwritten in
    # the loop.  The full-size draws are kept so the random stream (and thus
    # the generated data for a given seed) stays the same.
    up_paths = torch.randn(group_size, no_timesteps, device=device)
    down_paths = torch.randn(group_size, no_timesteps, device=device)

    for step in range(1, no_timesteps):
        # Draw order (mask, mask, noise, noise) is part of the seeded output.
        up_jump = ~(torch.rand(group_size, device=device) < rand_jump)
        down_jump = ~(torch.rand(group_size, device=device) < rand_jump)
        up_noise = 0.5 * torch.randn(up_jump.sum(), device=device)
        down_noise = 0.5 * torch.randn(down_jump.sum(), device=device)

        # Every path receives the drift; jumping paths also receive noise.
        up_next = up_paths[:, step - 1] + drift[step]
        down_next = down_paths[:, step - 1] - drift[step]
        up_next[up_jump] = up_next[up_jump] + up_noise
        down_next[down_jump] = down_next[down_jump] + down_noise

        up_paths[:, step] = up_next
        down_paths[:, step] = down_next

    return torch.cat([up_paths, down_paths], dim=0), time

class synthetic_data():
    """Select a trajectory sampler by name and call it with a fixed seed."""

    def __init__(self, sampler):
        """Bind ``self.sampler`` to the generator function named by ``sampler``.

        Args:
            sampler: One of ``"real_jump"``, ``"jump"``, ``"random_jumps"`` or
                ``"toy"``.

        Raises:
            ValueError: If ``sampler`` is not one of the names above.
        """
        # The branches are resolved lazily on purpose: each name is looked up
        # only when its option is chosen.
        # NOTE(review): create_jump_real, create_data_jump and
        # create_data_random_jumps are not defined in the visible part of this
        # file; selecting them raises NameError unless they exist elsewhere.
        if sampler == "real_jump":
            self.sampler = create_jump_real
        elif sampler == "jump":
            self.sampler = create_data_jump
        elif sampler == "random_jumps":
            self.sampler = create_data_random_jumps
        elif sampler == "toy":
            self.sampler = create_data
        else:
            # Fail here rather than with an AttributeError on first use.
            raise ValueError(
                f"Unknown sampler {sampler!r}; expected one of "
                "'real_jump', 'jump', 'random_jumps', 'toy'."
            )

    def create_data(self, n_samples, no_timesteps, device = "cuda", seed = 0):
        """Seed torch and return whatever the selected sampler produces.

        Args:
            n_samples: Number of samples forwarded to the sampler.
            no_timesteps: Number of time points forwarded to the sampler.
            device: Device forwarded to the sampler.
            seed: Seed applied via ``torch.manual_seed`` and forwarded to the
                sampler.

        Returns:
            The sampler's return value, unchanged.
        """
        torch.manual_seed(seed)
        # Keywords are required: the third positional parameter of
        # ``create_data`` is ``seed``, so passing ``device`` positionally would
        # seed the generator with the device string.
        # NOTE(review): assumes every sampler accepts ``seed`` and ``device``
        # keywords like ``create_data`` does -- confirm for the other three.
        return self.sampler(n_samples, no_timesteps, seed=seed, device=device)

def random_subsample(data, times, no_timesteps, subsample_time, random_seed=0, device="cuda"):
    """Keep a random, time-ordered subset of columns for every row of ``data``.

    Each row gets its own set of ``subsample_time`` distinct column indices,
    sorted ascending, and the matching entries of ``times`` are attached.

    Args:
        data: 2-D tensor of shape ``(batch, no_timesteps)``.
        times: Tensor indexed by column position to obtain the time stamps.
        no_timesteps: Expected second dimension of ``data``.
        subsample_time: Number of columns kept per row.
        random_seed: Seed applied to the CPU and CUDA generators.
        device: Device on which the index tensors are created.

    Returns:
        Tensor of shape ``(batch, subsample_time, 2)``; ``[..., 0]`` holds the
        selected values and ``[..., 1]`` the corresponding times.

    Raises:
        AssertionError: If ``data.shape[1]`` differs from ``no_timesteps``.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(random_seed)

    n_rows, n_cols = data.shape
    assert n_cols == no_timesteps, "Mismatch in no_timesteps and data shape."

    # Ranking i.i.d. uniform scores gives an independent random permutation per
    # row; its first ``subsample_time`` entries are the kept columns.
    scores = torch.rand(n_rows, n_cols, device=device)
    picked = scores.argsort(dim=1)[:, :subsample_time]
    picked = picked.sort(dim=1).values

    # Row index paired with every picked column for advanced indexing.
    row_ids = torch.arange(n_rows, device=device)[:, None].expand(-1, subsample_time)

    values = data[row_ids, picked]
    stamps = times[picked]

    # (B, T, 1) + (B, T, 1) -> (B, T, 2)
    return torch.cat([values[..., None], stamps[..., None]], dim=2)
def random_times(no_timesteps, subsample_time, random_seed = None, device="cuda", equidist="False"):
    """Return a sorted 1-D tensor of time indices in ``[0, no_timesteps)``.

    Args:
        no_timesteps: Size of the index range to choose from.
        subsample_time: Controls how many indices are returned (see below).
        random_seed: If not ``None``, seeds the CPU and CUDA generators before
            the random draw.  Ignored in the equidistant branch.
        device: Device on which the index tensor is created.
        equidist: ``"True"`` (string) or ``True`` (bool) selects the
            equidistant grid; anything else selects random indices.

    Returns:
        Equidistant branch: every ``no_timesteps // subsample_time``-th index
        starting at 0.  This yields more than ``subsample_time`` indices when
        the division is not exact.
        Random branch: ``subsample_time - 1`` distinct indices in ascending
        order.
        NOTE(review): the random branch returns one index fewer than
        ``subsample_time``; this matches the original code -- confirm that it
        is intended.

    Raises:
        ValueError: In the equidistant branch when
            ``subsample_time > no_timesteps`` (the slice step becomes zero).
    """
    # Historically the flag was the string "True"; a boolean True used to fall
    # through to the random branch, so both spellings are accepted now.
    if equidist is True or equidist == "True":
        stride = no_timesteps // subsample_time
        return torch.arange(0, no_timesteps, device=device)[::stride]

    if random_seed is not None:
        torch.cuda.manual_seed(random_seed)
        torch.manual_seed(random_seed)
        torch.cuda.manual_seed_all(random_seed)

    shuffled = torch.randperm(no_timesteps, device=device)
    return shuffled[:subsample_time - 1].sort().values

class load_data():
    """Load the cached splits from ``data/`` and build a training loader."""

    def __init__(self, seed):
        """Store the seed used for subsampling the training set."""
        self.seed = seed

    def get_data(self,subsample_time, no_timesteps, batch_size = 100, device = "cuda",equidist = "True"):
        """Load the cached tensors and prepare training and evaluation inputs.

        Args:
            subsample_time: Number of time points kept per training sample;
                also forwarded to ``random_times``.
            no_timesteps: Number of time points in the stored training data.
            batch_size: Batch size of the returned loader.
            device: Device on which the tensors are placed.
            equidist: Forwarded to ``random_times`` to choose between the
                equidistant and the random evaluation grid.

        Returns:
            ``(dataloader, val_data, test_data, times_eval)``.  The loader
            shuffles the output of ``random_subsample`` and drops the last
            incomplete batch; validation and test tensors are returned as
            stored; ``times_eval`` is the output of ``random_times``.
        """
        # map_location places the tensors on ``device`` directly, so files
        # written from a GPU run can be read on a machine without that GPU.
        train_data = torch.load("data/train_data_0.pt", map_location=device)
        train_times = torch.load("data/train_times_0.pt", map_location=device)
        val_data = torch.load("data/val_data_100.pt", map_location=device)
        test_data = torch.load("data/test_data_200.pt", map_location=device)

        data_train = random_subsample(train_data, train_times,no_timesteps, subsample_time, random_seed =self.seed, device = device)

        dataloader = torch.utils.data.DataLoader(data_train, batch_size = batch_size, shuffle = True, drop_last=True)

        # The evaluation grid uses the fixed seed 0, not ``self.seed``.
        # ``device`` must be forwarded; otherwise the indices are created on
        # random_times' own default device regardless of the argument here.
        times_eval = random_times(no_timesteps, subsample_time, random_seed = 0, device = device, equidist = equidist)

        return dataloader, val_data, test_data, times_eval




# Dataset cache: runs at import time.  A split is generated only when at least
# one of its two files is missing; otherwise the files on disk are reused.
# create_data is called without a device, so it uses its own default.
os.makedirs("data", exist_ok=True)

# Train: 128*5000 samples, 50 time steps, seed 0.
train_files = ("data/train_data_0.pt", "data/train_times_0.pt")
if not all(os.path.exists(path) for path in train_files):
    train_data, train_times = create_data(128*5000, 50, seed=0)
    torch.save(train_data, train_files[0])
    torch.save(train_times, train_files[1])

# Val: 5000 samples, 50 time steps, seed 100.
val_files = ("data/val_data_100.pt", "data/val_times_100.pt")
if not all(os.path.exists(path) for path in val_files):
    val_data, val_times = create_data(5000, 50, seed=100)
    torch.save(val_data, val_files[0])
    torch.save(val_times, val_files[1])

# Test: 5000 samples, 50 time steps, seed 200.
test_files = ("data/test_data_200.pt", "data/test_times_200.pt")
if not all(os.path.exists(path) for path in test_files):
    test_data, test_times = create_data(5000, 50, seed=200)
    torch.save(test_data, test_files[0])
    torch.save(test_times, test_files[1])

