import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import math
import torch.optim as optim
import torch.nn.functional as F
import random
from tqdm import tqdm


def icl_reg_data_batch_wise(T, max_context, sigma, d, d_eff=None, batch_size=1_000_000):
    """Generate noisy linear-regression prompts in batches of ``batch_size``.

    For each of the ``T`` prompts: ``X`` is standard normal with shape
    ``(max_context, d)``; ``beta ~ N(0, I / d_eff)`` acts on the first
    ``d_eff`` coordinates only; a noise std is picked uniformly at random from
    ``sigma``; and ``Y = X[:, :d_eff] @ beta + noise``. Working batch by batch
    bounds the size of the temporary arrays.

    Args:
        T: Number of prompts.
        max_context: Number of (x, y) points per prompt.
        sigma: Sequence of candidate noise standard deviations.
        d: Input dimension.
        d_eff: Number of leading input coordinates that influence ``Y``
            (defaults to ``d``).
        batch_size: Number of prompts generated per iteration.

    Returns:
        XY: ``(T, max_context, d + 1)`` tensor; ``XY[..., :-1]`` holds X and
            ``XY[..., -1]`` holds Y.
        beta_list: ``(T, d_eff)`` tensor of regression weights.
        noise_list: ``(T, max_context)`` tensor of the additive noise.

    Raises:
        ValueError: if ``d_eff`` is not in ``[1, d]``.
    """
    if d_eff is None:
        d_eff = d
    if not 1 <= d_eff <= d:
        raise ValueError(f"d_eff must be in [1, d={d}], got {d_eff}")
    Sigma = (1 / d_eff) * torch.eye(d_eff)

    XY = torch.empty(T, max_context, d + 1)
    # beta only spans the first d_eff coordinates; sizing this buffer with d
    # made the per-batch assignment fail whenever d_eff < d.
    beta_list = torch.empty(T, d_eff)
    noise_list = torch.empty(T, max_context)

    sigmas = np.array(sigma)  # candidate noise levels
    print("generating data....")
    for i in tqdm(range(0, T, batch_size), desc="Batches", dynamic_ncols=True):
        batch_T = min(batch_size, T - i)

        # Create data tensors for the current batch
        X_batch = torch.randn(batch_T, max_context, d).float()
        beta_batch = np.random.multivariate_normal(np.zeros(d_eff), Sigma, size=batch_T)

        # One noise std per prompt, broadcast across its context points.
        chosen_sigmas = np.random.choice(sigmas, size=batch_T)
        noise_batch = np.random.normal(0, chosen_sigmas[:, np.newaxis], size=(batch_T, max_context))

        beta_batch = torch.from_numpy(beta_batch).float()
        noise_batch = torch.from_numpy(noise_batch).float()

        # Y depends only on the first d_eff coordinates of X.
        Y_batch = torch.einsum('tni,ti->tn', X_batch[:, :, :d_eff], beta_batch) + noise_batch

        XY[i:i + batch_T, :, :-1] = X_batch
        XY[i:i + batch_T, :, -1] = Y_batch
        beta_list[i:i + batch_T, :] = beta_batch
        noise_list[i:i + batch_T, :] = noise_batch

    return XY, beta_list, noise_list

class icl_reg_batch_wise_Dataset(Dataset):
    """Map-style dataset over prompts built by ``icl_reg_data_batch_wise``.

    Each item is one prompt tensor of shape ``(max_context, d + 1)``: the
    inputs in the first ``d`` columns and the target in the last column.

    Args:
        sigmas: Sequence of candidate noise standard deviations.
        d: Input dimension.
        max_context: Number of points per prompt.
        T: Number of prompts.
        d_eff: Number of leading coordinates that influence the target
            (defaults to ``d`` inside the generator).
    """

    def __init__(self, sigmas, d, max_context, T, d_eff=None):
        self.sigmas = sigmas
        self.d = d
        self.d_eff = d_eff
        self.max_context = max_context
        self.T = T

        self.xy, self.beta, self.noise = icl_reg_data_batch_wise(T, max_context, sigmas, d, d_eff)

    def __len__(self):
        # The dataset stores the packed ``xy`` tensor; there is no ``y``
        # attribute (reading it raised AttributeError).
        return self.xy.shape[0]

    def __getitem__(self, idx):
        return self.xy[idx, :, :]


def icl_reg_data(T, max_context, sigma, d, d_eff=None):
    """Generate noisy linear-regression prompts in a single shot.

    For each of the ``T`` prompts: ``X`` is standard normal with shape
    ``(max_context, d)``; ``beta ~ N(0, I / d_eff)`` acts on the first
    ``d_eff`` coordinates only; a noise std is picked uniformly at random from
    ``sigma``; and ``Y = X[:, :d_eff] @ beta + noise``.

    Args:
        T: Number of prompts.
        max_context: Number of (x, y) points per prompt.
        sigma: Sequence of candidate noise standard deviations.
        d: Input dimension.
        d_eff: Number of leading input coordinates that influence ``Y``
            (defaults to ``d``).

    Returns:
        X: ``(T, max_context, d)`` float32 tensor.
        Y: ``(T, max_context)`` float32 tensor.
        beta: ``(T, d_eff)`` float32 tensor of regression weights.
        noise: ``(T, max_context)`` float32 tensor of additive noise.

    Raises:
        ValueError: if ``d_eff`` is not in ``[1, d]``.
    """
    if d_eff is None:
        d_eff = d
    if not 1 <= d_eff <= d:
        raise ValueError(f"d_eff must be in [1, d={d}], got {d_eff}")
    Sigma = (1 / d_eff) * torch.eye(d_eff)

    # Create data tensors
    X = torch.randn(T, max_context, d).float()
    beta = np.random.multivariate_normal(np.zeros(d_eff), Sigma, size=T)

    # One noise std per prompt, chosen uniformly from the candidates and
    # broadcast across that prompt's context points.
    sigmas = np.array(sigma)
    chosen_sigmas = np.random.choice(sigmas, size=T)
    noise = np.random.normal(0, chosen_sigmas[:, np.newaxis], size=(T, max_context))

    beta = torch.from_numpy(beta).float()
    noise = torch.from_numpy(noise).float()

    # Batched dot product using only the first d_eff coordinates of X.
    Y = torch.einsum('tni,ti->tn', X[:, :, :d_eff], beta) + noise
    return X, Y, beta, noise


class icl_reg_Dataset(Dataset):
    """Map-style dataset of noisy linear-regression prompts.

    All prompts are generated eagerly by ``icl_reg_data`` at construction
    time. Each item is an ``(inputs, targets)`` pair for one prompt.

    Args:
        sigmas: Sequence of candidate noise standard deviations.
        d: Input dimension.
        max_context: Number of points per prompt.
        T: Number of prompts.
        d_eff: Number of leading coordinates that influence the target.
    """

    def __init__(self, sigmas, d, max_context, T, d_eff=None):
        self.sigmas, self.d = sigmas, d
        self.max_context, self.T = max_context, T

        generated = icl_reg_data(T, max_context, sigmas, d, d_eff)
        self.x, self.y, self.beta, self.noise = generated

    def __len__(self):
        return self.y.size(0)

    def __getitem__(self, idx):
        prompt = self.x[idx, :, :]
        targets = self.y[idx, :]
        return prompt, targets


def icl_NNrelu_data(T, max_context, d, hid_dim=100, batch_size=500_000):
    """Generate prompts labelled by random one-hidden-layer ReLU networks.

    Every prompt gets its own network with standard-normal weights
    ``w1: (d, hid_dim)`` and ``w2: (hid_dim, 1)``, and
    ``y = sqrt(2 / hid_dim) * relu(x @ w1) @ w2``. Prompts are produced in
    batches of ``batch_size`` to bound the size of the per-batch weight tensors.

    Args:
        T: Number of prompts.
        max_context: Number of points per prompt.
        d: Input dimension.
        hid_dim: Hidden-layer width.
        batch_size: Number of prompts generated per iteration.

    Returns:
        X: ``(T, max_context, d)`` float32 tensor.
        Y: ``(T, max_context, 1)`` float32 tensor.
    """
    n_batches = int(math.ceil(T / batch_size))
    scale = math.sqrt(2 / hid_dim)

    # Preallocate the outputs and fill them batch by batch. Collecting batches
    # in lists and concatenating needed about twice the final memory at peak,
    # and torch.cat([]) crashed for T == 0.
    X_all = torch.empty(T, max_context, d, dtype=torch.float32)
    Y_all = torch.empty(T, max_context, 1, dtype=torch.float32)

    for i in range(n_batches):
        batch_start = i * batch_size
        batch_end = min(batch_start + batch_size, T)
        batch_size_current = batch_end - batch_start

        # Fresh inputs and per-prompt network weights for this batch.
        X = torch.randn(batch_size_current, max_context, d).float()
        w1 = torch.randn(batch_size_current, d, hid_dim).float()
        w2 = torch.randn(batch_size_current, hid_dim, 1).float()

        relu_w1x = torch.nn.functional.relu(torch.einsum('tnd,tdh->tnh', X, w1))
        w2_relu_w1x = torch.einsum('tnh,thd->tnd', relu_w1x, w2)

        X_all[batch_start:batch_end] = X
        Y_all[batch_start:batch_end] = w2_relu_w1x * scale

        print(f"Processed batch {i + 1}/{n_batches}")

    return X_all, Y_all



class icl_NNrelu_Dataset(Dataset):
    """Map-style dataset of prompts labelled by random ReLU networks.

    Data come from ``icl_NNrelu_data``. Each item is an ``(inputs, targets)``
    pair with shapes ``(max_context, d)`` and ``(max_context,)``.

    Args:
        d: Input dimension.
        max_context: Number of points per prompt.
        T: Number of prompts.
        hid_dim: Hidden-layer width of the random networks.
    """

    def __init__(self, d, max_context, T, hid_dim=100):
        self.d = d
        self.max_context = max_context
        self.T = T

        self.x, self.y = icl_NNrelu_data(T, max_context, d, hid_dim=hid_dim)

        # Drop only the trailing output dimension (T, max_context, 1) ->
        # (T, max_context). A bare squeeze() also removed the T or max_context
        # axis when either was 1, which broke __getitem__'s y[idx, :].
        self.y = self.y.squeeze(-1)

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, idx):
        return (self.x[idx, :, :],
                self.y[idx, :])


class DecisionTreeDataset(Dataset):
    """In-context regression prompts labelled by random depth-4 decision trees.

    Each sample draws its own complete binary tree. The 15 internal nodes each
    split on a uniformly chosen input coordinate (go right when it is > 0), and
    the 16 leaves hold values drawn from N(0, 1). Each sample also draws ``N``
    points ``x ~ N(0, I_d)``; a point's target is the value of the leaf the
    tree routes it to.

    Args:
        N: Number of points per prompt (largest context size).
        d: Input dimension.
        n_samples: Number of prompts (trees) to generate.

    Attributes:
        trees: List of ``(coords, leaf_values)`` tuples, one per sample.
        data: List of ``(xis, fxis)`` NumPy pairs, one per sample.
        xi: ``(n_samples, N, d)`` float32 tensor of inputs.
        fxi: ``(n_samples, N)`` float32 tensor of targets.
    """

    _DEPTH = 4
    _N_INTERNAL = 2 ** _DEPTH - 1  # 15 internal nodes, indices 0..14
    _N_LEAVES = 2 ** _DEPTH  # 16 leaves, indices 15..30

    def __init__(self, N, d, n_samples):
        self.N = N  # largest context size
        self.d = d
        self.n_samples = n_samples
        self.trees = []
        self.data = []

        for _ in range(n_samples):
            tree = self._generate_tree()
            self.trees.append(tree)
            self.data.append(self._generate_prompt(tree))

        if self.data:
            xi = np.stack([xis for xis, _ in self.data], axis=0)
            fxi = np.stack([fxis for _, fxis in self.data], axis=0)
        else:
            # np.stack([]) raises; keep well-shaped empty tensors instead.
            xi = np.empty((0, N, d))
            fxi = np.empty((0, N))
        self.xi = torch.from_numpy(xi).float()
        self.fxi = torch.from_numpy(fxi).float()

    def _generate_tree(self):
        """Draw split coordinates for the internal nodes and N(0, 1) leaf values."""
        coords = np.random.randint(0, self.d, size=self._N_INTERNAL)
        leaf_values = np.random.normal(0, 1, size=self._N_LEAVES)
        return (coords, leaf_values)

    def _evaluate_tree(self, tree, x):
        """Route a single point ``x`` down ``tree`` and return its leaf value."""
        coords, leaf_values = tree
        node = 0
        for _ in range(self._DEPTH):
            # Heap-style indexing: children of node k are 2k+1 (left), 2k+2 (right).
            if x[coords[node]] > 0:
                node = 2 * node + 2
            else:
                node = 2 * node + 1
        return leaf_values[node - self._N_INTERNAL]

    def _evaluate_tree_batch(self, tree, xs):
        """Vectorized ``_evaluate_tree`` over the rows of the 2-D array ``xs``."""
        coords, leaf_values = tree
        rows = np.arange(xs.shape[0])
        node = np.zeros(xs.shape[0], dtype=np.int64)
        for _ in range(self._DEPTH):
            go_right = xs[rows, coords[node]] > 0
            node = 2 * node + 1 + go_right  # +1 more for the right child
        return leaf_values[node - self._N_INTERNAL]

    def _generate_prompt(self, tree):
        """Sample ``N`` points from N(0, I_d) and label them with ``tree``."""
        xis = np.random.normal(0, 1, (self.N, self.d))
        fxis = self._evaluate_tree_batch(tree, xis)
        return xis, fxis

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        # Index the tensors precomputed in __init__ instead of re-converting
        # the NumPy arrays on every access. These are views into the dataset's
        # storage.
        return self.xi[idx], self.fxi[idx]





class SparseLinearDataset(Dataset):
    """Prompts labelled by random ``s``-sparse linear functions.

    Args:
        k: Maximum context length (points per prompt).
        d: Input dimension.
        s: Number of non-zero weights per prompt (must not exceed ``d``).
        num_samples: Number of prompts.

    Attributes:
        xi: ``(num_samples, k, d)`` standard-normal inputs.
        w: ``(num_samples, d)`` weights with ``s`` randomly chosen
            coordinates kept per row and the rest zeroed.
        fx: ``(num_samples, k)`` targets, ``xi @ w`` per prompt.
    """

    def __init__(self, k, d=20, s=3, num_samples=1000):
        self.k, self.d, self.s = k, d, s  # k: max context length
        self.num_samples = num_samples

        # Inputs are drawn before the dense weights (RNG order matters).
        self.xi = torch.randn(num_samples, k, d)
        dense_w = torch.randn(num_samples, d)

        # Choose s distinct coordinates per row uniformly; zero out the rest.
        support = torch.multinomial(torch.ones(num_samples, d), s, replacement=False)
        keep = torch.zeros(num_samples, d)
        keep.scatter_(1, support, 1)
        self.w = dense_w * keep

        # Batched matrix-vector product: (n, k, d) x (n, d, 1) -> (n, k).
        self.fx = torch.bmm(self.xi, self.w[:, :, None])[:, :, 0]

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        inputs, targets = self.xi[idx], self.fx[idx]
        return inputs, targets







