
import torch
from torch.utils.data import Dataset, DataLoader, Subset
import math, random

def sample_sticky_dirichlet_seq(T, K, kappa=20.0, alpha0=0.5, alpha_global=2.0, device='cpu'):
    """
    Generate a sticky gate sequence g_{1:T} with Dirichlet Markov prior.

    Sampling scheme:
        vartheta ~ Dirichlet(alpha_global * 1_K)
        g_1      ~ Dirichlet(alpha0 * vartheta)
        g_t      ~ Dirichlet(kappa * g_{t-1} + alpha0 * vartheta)   for t > 1

    Args:
        T: number of time steps. ``T <= 0`` yields an empty sequence.
        K: number of gate components.
        kappa: weight of the previous gate in the next step's concentration
            (larger values keep consecutive gates closer together).
        alpha0: weight of the global vector ``vartheta`` in every concentration.
        alpha_global: symmetric concentration used to draw ``vartheta``.
        device: device on which the concentration tensors are created.

    Returns:
        Tuple ``(g_seq, vartheta)`` where ``g_seq`` has shape ``(T, K)`` and
        ``vartheta`` has shape ``(K,)``.
    """
    # Multiplying a ones tensor (rather than torch.full) keeps a floating
    # dtype even when alpha_global is passed as a Python int.
    theta = torch.ones(K, device=device) * alpha_global
    vartheta = torch.distributions.Dirichlet(theta).sample()

    # An empty horizon has no gates; do not emit the initial sample.
    if T <= 0:
        return vartheta.new_empty((0, K)), vartheta

    base_conc = alpha0 * vartheta
    g_prev = torch.distributions.Dirichlet(base_conc).sample()
    g_seq = [g_prev]
    for _ in range(1, T):
        # Sticky transition: concentration is dominated by the previous gate.
        g_prev = torch.distributions.Dirichlet(kappa * g_prev + base_conc).sample()
        g_seq.append(g_prev)
    return torch.stack(g_seq, dim=0), vartheta

class MultiTaskToyDataset(Dataset):
    """
    Synthetic multi-task, multi-dim sequences with shared subsequences.

    Generative recipe:
    - K latent skills given by the columns of an orthonormal ``Q_true`` (d x K).
    - Three piecewise-constant phase templates are shared by all tasks; each
      sequence takes one template (round-robin) with its expert columns
      randomly permuted, then samples sticky Dirichlet gates around it.
    - Coefficients ``z`` are Gaussian with standard deviation 0.5.
    - Actions are ``a_t = Q_true (g_t * z_t)``.
    - No residual is generated: ``R_true`` is ``None`` and the ``'R'`` entry of
      every item is all zeros.

    States ``s_t`` contain: one-hot task id, time features (sin/cos), and the
    previous action (zeros at t = 0), for minimal context.

    Args:
        num_tasks: number of tasks (width of the one-hot task id).
        seq_len: length T of every sequence; must be >= 1.
        num_seq_per_task: sequences generated per task.
        d: action dimension.
        K: number of latent skills; must satisfy 1 <= K <= d.
        seed: seeds both the global torch RNG (via ``torch.manual_seed``) and
            a private Python RNG used for the phase templates, so the same
            seed always yields the same dataset.

    Raises:
        ValueError: if ``seq_len < 1`` or ``K`` is outside ``1..d``.
    """

    _NUM_BASE_PATTERNS = 3

    def __init__(self, num_tasks=3, seq_len=40, num_seq_per_task=200, d=6, K=4, seed=0):
        super().__init__()
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        if not 1 <= K <= d:
            raise ValueError(f"K must satisfy 1 <= K <= d, got K={K}, d={d}")

        # NOTE: this reseeds the *global* torch RNG (Dirichlet sampling offers
        # no generator argument), which is a side effect visible to callers.
        torch.manual_seed(seed)
        # Private Python RNG: the global `random` module is not affected by
        # torch.manual_seed, so using it would make the templates irreproducible.
        py_rng = random.Random(seed)

        self.num_tasks = num_tasks
        self.seq_len = seq_len
        self.d = d
        self.K = K

        # Build ground-truth orthonormal Q_true via QR
        Q_true, _ = torch.linalg.qr(torch.randn(d, K), mode='reduced')
        self.Q_true = Q_true
        # No residual basis is used by this dataset.
        self.R_true = None

        base_patterns = self._build_base_patterns(py_rng, seq_len, K)

        # Time features depend only on t, so compute them once for all sequences.
        time_feats = torch.tensor(
            [
                [math.sin(2 * math.pi * (t / seq_len)), math.cos(2 * math.pi * (t / seq_len))]
                for t in range(seq_len)
            ]
        )

        # Build dataset
        data = []
        for task in range(num_tasks):
            task_oh = torch.zeros(num_tasks)
            task_oh[task] = 1.0
            for n in range(num_seq_per_task):
                # pick a base pattern and permute experts to create variation
                perm = torch.randperm(K)
                pat = base_patterns[n % len(base_patterns)][:, perm]
                g_seq = self._sample_gates(pat)
                # sample coefficients
                z = torch.randn(seq_len, K) * 0.5
                # decode action: a_t = Q_true (g_t * z_t)
                a = (g_seq * z) @ self.Q_true.t()  # (T,d)
                r = None  # no residual term is generated
                S = self._build_states(task_oh, time_feats, a)
                data.append((S, a, g_seq, z, r, task, perm))
        self.data = data

        # State layout is [task one-hot | sin, cos | previous action]; computing
        # the width directly also works when the dataset is empty.
        self.s_dim = num_tasks + 2 + d
        self.a_dim = self.d

    @classmethod
    def _build_base_patterns(cls, rng, seq_len, K):
        """Create the shared piecewise-constant gate templates, each (seq_len, K)."""
        # Boundaries are drawn from [5, seq_len - 5); use fewer than two
        # when the sequence is too short to offer two distinct positions.
        num_boundaries = min(2, max(0, seq_len - 10))
        patterns = []
        for _ in range(cls._NUM_BASE_PATTERNS):
            pat = torch.zeros(seq_len, K)
            boundaries = sorted(rng.sample(range(5, seq_len - 5), num_boundaries))
            segs = [0] + boundaries + [seq_len]
            for start, end in zip(segs[:-1], segs[1:]):
                # up to 2 equally weighted dominant experts per segment
                active = rng.sample(range(K), k=min(2, K))
                gi = torch.zeros(K)
                gi[active] = 1.0
                pat[start:end] = gi / gi.sum()
            # soften so every expert keeps strictly positive mass
            patterns.append(pat * 0.9 + 0.1 / K)
        return patterns

    @staticmethod
    def _sample_gates(template):
        """Sample sticky Dirichlet gates (T, K) that track ``template`` (T, K)."""
        gates = []
        prev = template[0]
        for target in template:
            # concentration mixes the template with the previous gate (sticky)
            conc = 30.0 * target + 40.0 * prev
            prev = torch.distributions.Dirichlet(conc).sample()
            gates.append(prev)
        return torch.stack(gates, dim=0)

    @staticmethod
    def _build_states(task_oh, time_feats, actions):
        """Assemble states [task one-hot | time features | previous action], shape (T, s_dim)."""
        seq_len, a_dim = actions.shape
        # Previous action is zero at t = 0 and a_{t-1} afterwards.
        prev_actions = torch.cat([torch.zeros(1, a_dim), actions[:-1]], dim=0)
        return torch.cat([task_oh.expand(seq_len, -1), time_feats, prev_actions], dim=1)

    def __len__(self):
        """Return the total number of sequences (num_tasks * num_seq_per_task)."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Return sequence ``idx`` as a dict with float tensors
        'S' (states), 'A' (actions), 'G' (gates), 'Z' (coefficients),
        'R' (residual; zeros when none was stored) and the int 'task' id.
        """
        S, a, g, z, r, task, perm = self.data[idx]
        return {
            'S': S.float(),
            'A': a.float(),
            'G': g.float(),
            'Z': z.float(),
            'R': torch.zeros_like(a) if r is None else r.float(),
            'task': int(task)
        }

def make_dataloaders(batch_size=16, **ds_kwargs):
    """
    Build a MultiTaskToyDataset and a shuffling DataLoader over all of it.

    Args:
        batch_size: batch size of the returned loader.
        **ds_kwargs: forwarded unchanged to ``MultiTaskToyDataset``.

    Returns:
        Tuple ``(dataset, loader)``.
    """
    dataset = MultiTaskToyDataset(**ds_kwargs)
    loader = DataLoader(dataset, shuffle=True, batch_size=batch_size)
    return dataset, loader

def make_data_splits(
    batch_size: int = 16,
    val_batch_size: int | None = None,
    test_batch_size: int | None = None,
    val_ratio: float = 0.15,
    test_ratio: float = 0.15,
    shuffle: bool = True,
    split_seed: int | None = None,
    **ds_kwargs,
):
    """
    Build a MultiTaskToyDataset and split it into train/val/test loaders.

    Split sizes are ``int(N * ratio)`` for val and test; train receives the
    remaining count (at least 1). Indices are assigned in the order
    train, val, test, with test taking everything after the val slice.

    Args:
        batch_size: batch size of the training loader, and the fallback for
            the other two loaders.
        val_batch_size: validation batch size; ``None`` means ``batch_size``.
        test_batch_size: test batch size; ``None`` means ``batch_size``.
        val_ratio: fraction of the dataset assigned to validation.
        test_ratio: fraction of the dataset assigned to test.
        shuffle: whether indices are permuted before splitting. The training
            loader shuffles its batches regardless of this flag.
        split_seed: seed of the split permutation. When ``None`` it is derived
            as the dataset ``seed`` kwarg (0 if absent) plus 12345.
        **ds_kwargs: forwarded unchanged to ``MultiTaskToyDataset``.

    Returns: (ds, dl_train, dl_val, dl_test)
    """
    ds = MultiTaskToyDataset(**ds_kwargs)
    total = len(ds)

    test_count = int(total * test_ratio)
    val_count = int(total * val_ratio)
    train_count = max(1, total - val_count - test_count)

    if shuffle:
        if split_seed is None:
            # derive split seed from dataset seed, if provided
            perm_seed = int(ds_kwargs.get('seed', 0)) + 12345
        else:
            perm_seed = int(split_seed)
        rng = torch.Generator()
        rng.manual_seed(perm_seed)
        order = torch.randperm(total, generator=rng)
    else:
        order = torch.arange(total)

    val_end = train_count + val_count
    train_subset = Subset(ds, order[:train_count].tolist())
    val_subset = Subset(ds, order[train_count:val_end].tolist())
    test_subset = Subset(ds, order[val_end:].tolist())

    val_bs = batch_size if val_batch_size is None else val_batch_size
    test_bs = batch_size if test_batch_size is None else test_batch_size

    dl_train = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    dl_val = DataLoader(val_subset, batch_size=val_bs, shuffle=False)
    dl_test = DataLoader(test_subset, batch_size=test_bs, shuffle=False)
    return ds, dl_train, dl_val, dl_test
