import torch
import os
import numpy as np
from torch.distributions import Normal, Categorical
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.mixture_same_family import MixtureSameFamily
import matplotlib.pyplot as plt

# torch.manual_seed(4)
# np.random.seed(0)

def sample_gaussian_mixture(pis, d, mus, sigmas, num_samples=10000):
    """Draw ``num_samples`` points from a Gaussian mixture.

    Args:
        pis: mixture weights, one per component (passed to ``Categorical``).
        d: dimensionality of each sample.
        mus: component means, indexed as ``mus[i]`` (presumably shape
            ``(k, d)`` — matches every caller in this file).
        sigmas: per-component scale. For ``d >= 2`` the element-wise square
            ``sigmas[i] ** 2`` is used as the covariance matrix, so a diagonal
            ``sigmas[i]`` holds standard deviations on its diagonal. For
            ``d == 1`` ``sigmas[i]`` is the standard deviation of a ``Normal``.
        num_samples: number of points to draw.

    Returns:
        Tensor of shape ``(num_samples, d)`` on the same device as ``mus``.
    """
    # Pick the mixture component of every sample up front.
    components = torch.distributions.Categorical(pis).sample((num_samples,))

    # Allocate on the device of the means instead of a hard-coded 'cuda', so
    # the function behaves the same for GPU callers and also runs on CPU.
    samples = torch.zeros((num_samples, d), device=mus.device)

    for i in range(len(pis)):
        mask = components == i
        num_component_samples = mask.sum().item()

        if d >= 2:
            dist = torch.distributions.MultivariateNormal(mus[i], covariance_matrix=sigmas[i] ** 2)
        else:
            dist = torch.distributions.Normal(mus[i], sigmas[i])

        samples[mask] = dist.sample((num_component_samples,))

    return samples

def sample_gaussian(pis, d, mus, sigmas, num_samples=50000):
    """Draw a (source, target) pair of sample sets.

    ``samples_1`` is drawn from the Gaussian mixture described by ``pis``,
    ``mus`` and ``sigmas`` via ``sample_gaussian_mixture``; ``samples_0`` is
    drawn from a standard normal in ``d`` dimensions.

    Returns:
        ``(samples_0, samples_1)``, both of shape ``(num_samples, d)``.
    """
    samples_1 = sample_gaussian_mixture(pis, d, mus, sigmas, num_samples)

    # Keep the source on the same device as the target rather than a
    # hard-coded 'cuda'.
    device = samples_1.device
    if d >= 2:
        samples_0 = MultivariateNormal(torch.zeros(d, device=device),
                                       torch.eye(d, device=device)).sample((num_samples,))
    else:
        samples_0 = Normal(torch.tensor(0.0, device=device),
                           torch.tensor(1.0, device=device)).sample((num_samples,))
        # Normal with scalar parameters yields shape (num_samples,); add the feature axis.
        samples_0 = samples_0[:, None]
    return samples_0, samples_1

def data_config():
    """Build the three-blob 2-D source/target pair used in the triangle toy.

    Both distributions are equally weighted mixtures of ``COMP`` isotropic
    Gaussians with variance ``VAR``. The target triangle is the source
    triangle mirrored across the x-axis. Each side gets 10000 CPU samples.

    Returns:
        ``(samples_0, samples_1)``, each of shape ``(10000, 2)``.
    """
    D = 1.
    VAR = 0.003
    COMP = 3

    h = D * np.sqrt(3) / 2.
    half = D / 2.

    def _mixture(centres):
        # Uniform weights over COMP components sharing the covariance VAR * I.
        weights = Categorical(torch.full((COMP,), 1 / COMP))
        covariances = VAR * torch.eye(2).repeat(COMP, 1, 1)
        blobs = MultivariateNormal(torch.tensor(centres).float(), covariances)
        return MixtureSameFamily(weights, blobs)

    source = _mixture([[h, half], [-h, half], [0.0, -h]])
    samples_0 = source.sample([10000])

    target = _mixture([[h, -half], [-h, -half], [0.0, h]])
    samples_1 = target.sample([10000])

    return samples_0, samples_1


def data_config1():
    """Ring of six equally weighted 2-D Gaussians (dataset name ``'mog'``).

    Component means sit at angles ``0, 2*pi/k, ..., (k-1)*2*pi/k`` on a circle
    of radius ``radius``. Each ``sigmas`` entry is a diagonal matrix with
    ``sigma`` on the first two axes and 1 elsewhere (so all ``sigma`` for
    ``d == 2``). Draws 500000 source/target samples via ``sample_gaussian``.

    Returns:
        ``(pis, d, mus, sigmas, name, sample_0, sample_1)``.
    """
    name = 'mog'
    k, d = 6, 2
    radius = 2
    sigma = 0.2
    num_samples = 500000

    # Uniform weights, renormalised exactly as before.
    pis = torch.ones(k).to('cuda') / k
    pis = pis / pis.sum()

    # k evenly spaced angles; the final linspace point (2*pi) repeats 0 and is dropped.
    angles = torch.linspace(0, 2 * torch.pi, k + 1)[:k]
    mus = torch.zeros((k, d)).to('cuda')
    for idx, theta in enumerate(angles):
        mus[idx, :2] = torch.tensor([radius * torch.cos(theta),
                                     radius * torch.sin(theta)]).to('cuda')
    offset = torch.zeros((k, d)).to('cuda')
    mus = mus + offset

    # The same diagonal scale matrix is shared by every component.
    diag_entries = torch.ones(d).to('cuda')
    diag_entries[:2] = diag_entries[:2] * sigma
    sigmas = torch.diag(diag_entries).repeat(k, 1, 1)

    sample_0, sample_1 = sample_gaussian(pis, d, mus, sigmas, num_samples)
    return pis, d, mus, sigmas, name, sample_0, sample_1

def data_config2():
    """Single 2-D Gaussian target centred at (5, 5) (dataset ``'easy_gaussian'``).

    The lone component has an identity ``sigmas`` entry. Source and target
    samples (50000 each) come from ``sample_gaussian``.

    Returns:
        ``(pis, d, mus, sigmas, name, sample_0, sample_1)``.
    """
    name = 'easy_gaussian'
    k, d = 1, 2
    num_samples = 50000

    pis = torch.ones(k).to('cuda') / k

    mus = torch.zeros((k, d)).to('cuda')
    mus[0] = torch.tensor([5., 5.0]).to('cuda')

    sigmas = torch.zeros((k, d, d)).to('cuda')
    sigmas[0] = torch.eye(d).to('cuda')

    sample_0, sample_1 = sample_gaussian(pis, d, mus, sigmas, num_samples)
    return pis, d, mus, sigmas, name, sample_0, sample_1

def data_config3():
    """Random 1-D mixture of ``k`` Gaussians (dataset ``'1d_random'``).

    Mixture weights are drawn from U(2, 3) and normalised. Means are drawn
    from N(0, 1), and per-component standard deviations from U(0.3, 1.0). All
    draws use torch's global RNG, so the mixture differs between calls unless
    the caller seeds it.

    Returns:
        ``(pis, d, mus, sigmas, name, sample_0, sample_1)`` with 50000 samples each.
    """
    name = '1d_random'
    k = 6
    d = 1

    # Size the weight vector by k (previously a literal 6) so the component
    # count is defined in one place.
    pis = torch.rand(k).to('cuda') + 2
    pis /= torch.sum(pis)
    mus = torch.randn(k, d).to('cuda')
    sigmas = torch.rand(k, d).to('cuda') * 0.7 + 0.3
    num_samples = 50000
    sample_0, sample_1 = sample_gaussian(pis, d, mus, sigmas, num_samples)
    return pis, d, mus, sigmas, name, sample_0, sample_1


# i = 1.0
def data_config4(mean=1.0, sigma=0.2):
    """Symmetric 1-D two-Gaussian target (dataset ``'1d_two_gaussian'``).

    Equal-weight components at ``-mean`` and ``+mean``, both with standard
    deviation ``sigma``. The global RNG state is left unchanged by this call.

    Args:
        mean: distance of each component mean from the origin.
        sigma: per-component standard deviation. Non-default values are
            appended to the dataset name.

    Returns:
        ``(pis, d, mus, sigmas, name, sample_0, sample_1)`` with 50000 samples each.
    """
    name = '1d_two_gaussian'
    if sigma != 0.2:
        name = name + f'_{sigma}'
    k = 2
    d = 1

    pis = torch.tensor([0.5, 0.5]).to('cuda')
    mus = torch.tensor([-mean, mean]).to('cuda')[:, None]
    sigmas = torch.ones(k, d).to('cuda') * sigma
    num_samples = 50000

    # The sampling below draws from tensors placed on 'cuda', so it advances
    # the CUDA generator. Saving only torch.random.get_rng_state() (the CPU
    # generator) did not isolate it. fork_rng snapshots the CPU and CUDA
    # generators and restores both on exit, even if sampling raises.
    with torch.random.fork_rng():
        # torch.manual_seed(3)
        sample_0, sample_1 = sample_gaussian(pis, d, mus, sigmas, num_samples)
    return pis, d, mus, sigmas, name, sample_0, sample_1

# _slices = torch.linspace(-2, 2, 8)
_slices = [1.2, 1.1, 0.9,0.8]

def data_config4_high_dim(sigma=0.2, dim=10):
    """Two Gaussians separated along the first axis in ``dim`` dimensions.

    Equal-weight components with means ``(+1, 0, ..., 0)`` and
    ``(-1, 0, ..., 0)``. Every ``sigmas`` entry is diagonal with ``sigma`` on
    all axes. The target samples are randomly permuted before returning.

    Returns:
        ``(pis, d, mus, sigmas, name, sample_0, sample_1)`` with 50000 samples each.
    """
    mean = 1.0
    k = 2
    d = dim

    suffix = ''
    if d != 10:
        suffix += f'_{d}'
    if sigma != 0.2:
        suffix += f'_{sigma}'
    name = 'two_gaussian_high_dim' + suffix

    pis = torch.tensor([0.5, 0.5]).to('cuda')
    mus = torch.zeros((k, d)).to('cuda')
    mus[0, 0] = mean
    mus[1, 0] = -mean

    # 1-D configs store one scale per component; otherwise a full matrix.
    sigma_shape = (k, d) if d == 1 else (k, d, d)
    sigmas = torch.zeros(sigma_shape).to('cuda')
    scale_matrix = torch.diag(torch.ones(d).to('cuda') * sigma)
    for comp in range(k):
        sigmas[comp] = scale_matrix

    num_samples = 50000
    sample_0, sample_1 = sample_gaussian(pis, d, mus, sigmas, num_samples)
    shuffle = torch.randperm(sample_1.size(0))
    sample_1 = sample_1[shuffle]
    return pis, d, mus, sigmas, name, sample_0, sample_1


def data_config5_high_dim(sigma=0.2, dim=2):
    """Two nearly degenerate Gaussians offset along the first two axes.

    Equal-weight components with means ``(+0.1, +0.1, 0, ...)`` and
    ``(-0.1, -0.1, 0, ...)``; for ``dim == 1`` only the first coordinate is
    offset. Every ``sigmas`` entry is diagonal with ``sigma`` on axis 0 and
    1e-6 on all other axes. The target samples are randomly permuted.

    Returns:
        ``(pis, d, mus, sigmas, name, sample_0, sample_1)`` with 50000 samples each.
    """
    mean = 0.1
    k = 2
    d = dim
    name = 'two_gaussian_high_dim_5'
    # NOTE(review): compares against 10 (as in data_config4_high_dim) although
    # the default dim here is 2, so the default name always gets '_2'. Kept
    # unchanged so existing save paths stay valid.
    if d != 10:
        name = name + f'_{d}'
    if sigma != 0.2:
        name = name + f'_{sigma}'

    pis = torch.tensor([0.5, 0.5]).to('cuda')
    mus = torch.zeros((k, d)).to('cuda')
    mus[0][0] = mean
    mus[1][0] = -mean
    # The second-axis offset only exists when there is a second axis. Without
    # this guard dim=1 raised IndexError, even though the sigma construction
    # below explicitly supports d == 1.
    if d >= 2:
        mus[0][1] = 0.1
        mus[1][1] = -0.1

    if d == 1:
        sigmas = torch.zeros((k, d)).to('cuda')
    else:
        sigmas = torch.zeros((k, d, d)).to('cuda')
    for i in range(k):
        diagonal = torch.ones(d).to('cuda') * 1e-6
        diagonal[0] = sigma
        sigmas[i] = torch.diag(diagonal)
    num_samples = 50000

    sample_0, sample_1 = sample_gaussian(pis, d, mus, sigmas, num_samples)
    sample_1 = sample_1[torch.randperm(sample_1.size(0))]
    return pis, d, mus, sigmas, name, sample_0, sample_1


def data_config5_high_dim_identity(sigma=0.2, dim=2):
    """Two strongly correlated 2-D Gaussians centred at (1, 1) and (-1, -1).

    Both components get the ``sigmas`` entry
    ``[[0.2, 0.99999*0.2], [0.99999*0.2, 0.2]]``. ``sample_gaussian_mixture``
    squares ``sigmas`` element-wise to form the covariance. The target samples
    are randomly permuted.

    Args:
        sigma: only used to build the dataset-name suffix.
        dim: must be 2; the means and covariance are hard-coded for two dims.

    Returns:
        ``(pis, d, mus, sigmas, name, sample_0, sample_1)`` with 50000 samples each.

    Raises:
        ValueError: if ``dim != 2``.
    """
    # Fail fast with a clear message. Other dims previously crashed later with
    # an IndexError (dim=1) or a covariance shape mismatch (dim>2).
    if dim != 2:
        raise ValueError(
            f'data_config5_high_dim_identity only supports dim=2, got dim={dim}')

    k = 2
    d = dim
    # NOTE(review): data_config5_high_dim() produces this same name for the
    # same (sigma, dim), so both experiments share one save_path. Left
    # unchanged for compatibility with existing results; consider a distinct
    # name such as 'two_gaussian_high_dim_5_identity'.
    name = 'two_gaussian_high_dim_5'
    if d != 10:
        name = name + f'_{d}'
    if sigma != 0.2:
        name = name + f'_{sigma}'

    pis = torch.tensor([0.5, 0.5]).to('cuda')
    mus = torch.zeros((k, d)).to('cuda')
    mus[0][0] = 1
    mus[0][1] = 1
    mus[1][0] = -1
    mus[1][1] = -1

    # Nearly singular 2x2 scale matrix (off-diagonal 0.99999 of the diagonal),
    # shared by both components.
    cov = 0.2
    block = [[cov, 0.99999 * cov],
             [0.99999 * cov, cov]]
    sigmas = torch.tensor([block, block]).to('cuda')

    num_samples = 50000

    sample_0, sample_1 = sample_gaussian(pis, d, mus, sigmas, num_samples)
    sample_1 = sample_1[torch.randperm(sample_1.size(0))]
    return pis, d, mus, sigmas, name, sample_0, sample_1

def data_config5_high_dim_single(sigma=0.2, dim=2):
    """Single nearly degenerate Gaussian at ``(0.1, 0.1, 0, ...)``.

    For ``dim == 1`` the mean is just ``(0.1,)``. The ``sigmas`` entry is
    diagonal with ``sigma`` on axis 0 and 1e-6 on every other axis. The target
    samples are randomly permuted.

    Returns:
        ``(pis, d, mus, sigmas, name, sample_0, sample_1)`` with 50000 samples each.
    """
    mean = 0.1
    k = 1
    d = dim
    name = 'two_gaussian_high_dim_5_single'
    if d != 10:
        name = name + f'_{d}'
    if sigma != 0.2:
        name = name + f'_{sigma}'

    # Float weights, like every other config. The previous torch.tensor([1])
    # was an int64 tensor.
    pis = torch.tensor([1.0]).to('cuda')
    mus = torch.zeros((k, d)).to('cuda')
    mus[0][0] = mean
    # Only offset the second axis when it exists. Without the guard, dim=1
    # raised IndexError although the sigma branch below handles d == 1.
    if d >= 2:
        mus[0][1] = 0.1

    if d == 1:
        sigmas = torch.zeros((k, d)).to('cuda')
    else:
        sigmas = torch.zeros((k, d, d)).to('cuda')
    for i in range(k):
        diagonal = torch.ones(d).to('cuda') * 1e-6
        diagonal[0] = sigma
        sigmas[i] = torch.diag(diagonal)
    num_samples = 50000

    sample_0, sample_1 = sample_gaussian(pis, d, mus, sigmas, num_samples)
    sample_1 = sample_1[torch.randperm(sample_1.size(0))]
    return pis, d, mus, sigmas, name, sample_0, sample_1

def data_config5():
    """Four narrow 1-D Gaussians (dataset ``'1d_four_sharp_gaussian'``).

    Equal weights; means at -1.5, -0.5, 0.5, 1.5; standard deviation 0.05 each.

    Returns:
        ``(pis, d, mus, sigmas, name, sample_0, sample_1)`` with 50000 samples each.
    """
    name = '1d_four_sharp_gaussian'
    k, d = 4, 1
    num_samples = 50000

    centres = [-1.5, -.5, .5, 1.5]
    pis = torch.ones(k).to('cuda') / k
    mus = torch.tensor(centres).to('cuda')[:, None]
    sigmas = 0.05 * torch.ones(k, d).to('cuda')

    sample_0, sample_1 = sample_gaussian(pis, d, mus, sigmas, num_samples)
    return pis, d, mus, sigmas, name, sample_0, sample_1

def data_config6():
    """Four extremely narrow 1-D Gaussians (dataset ``'1d_four_very_sharp_gaussian'``).

    Equal weights; means at -1.5, -0.5, 0.5, 1.5; standard deviation 1e-4 each.

    Returns:
        ``(pis, d, mus, sigmas, name, sample_0, sample_1)`` with 50000 samples each.
    """
    name = '1d_four_very_sharp_gaussian'
    k, d = 4, 1
    num_samples = 50000

    centres = [-1.5, -.5, .5, 1.5]
    pis = torch.ones(k).to('cuda') / k
    mus = torch.tensor(centres).to('cuda')[:, None]
    sigmas = 0.0001 * torch.ones(k, d).to('cuda')

    sample_0, sample_1 = sample_gaussian(pis, d, mus, sigmas, num_samples)
    return pis, d, mus, sigmas, name, sample_0, sample_1


def data_config7():
    """Asymmetric 1-D two-Gaussian target (dataset ``'1d_two_gaussian_mean_shift'``).

    Equal weights; means at -0.5 and 2; standard deviation 0.2 each.

    Returns:
        ``(pis, d, mus, sigmas, name, sample_0, sample_1)`` with 50000 samples each.
    """
    name = '1d_two_gaussian_mean_shift'
    k, d = 2, 1
    num_samples = 50000

    weights = [0.5, 0.5]
    centres = [-0.5, 2]
    pis = torch.tensor(weights).to('cuda')
    mus = torch.tensor(centres).to('cuda')[:, None]
    sigmas = 0.2 * torch.ones(k, d).to('cuda')

    sample_0, sample_1 = sample_gaussian(pis, d, mus, sigmas, num_samples)
    return pis, d, mus, sigmas, name, sample_0, sample_1

def chessboard(device='cuda'):
    """Uniform samples on the dark squares of a centred 4x4 chessboard.

    A cell ``(i, j)`` is dark when ``i + j`` is even. Each dark cell gets the
    same number of uniform samples; the board is then shifted so it is centred
    on the origin, and the samples are shuffled. The source is a standard
    2-D normal with the same number of samples as the target.

    Args:
        device: device for the returned tensors (defaults to ``'cuda'``).

    Returns:
        ``(pis, d, mus, sigmas, name, samples_0, samples_1)``. The mixture
        fields are ``None`` because the target is not a Gaussian mixture.
    """
    name = 'chessboard'
    d = 2
    pis = mus = sigmas = None

    grid_size = 4
    num_samples = 500000
    cell_length = 1.0

    # Enumerate the dark cells directly. The former closed form
    # (grid_size // 2) ** 2 counted 4 instead of 8 on a 4x4 board, which made
    # the target twice as large as the source.
    dark_cells = [(i, j)
                  for i in range(grid_size)
                  for j in range(grid_size)
                  if (i + j) % 2 == 0]
    per_cell = num_samples // len(dark_cells)

    pieces = []
    for i, j in dark_cells:
        # Draw on the CPU generator, as before, then move to the target device.
        x = i * cell_length + cell_length * torch.rand(per_cell).to(device)
        y = j * cell_length + cell_length * torch.rand(per_cell).to(device)
        pieces.append(torch.stack([x, y], dim=-1))

    samples_1 = torch.cat(pieces, dim=0)
    # Centre the board on the origin.
    samples_1 = samples_1 - (grid_size * cell_length / 2)
    # Shuffle so samples are not grouped by cell.
    samples_1 = samples_1[torch.randperm(samples_1.size(0))]

    samples_0 = MultivariateNormal(torch.zeros(d, device=device),
                                   torch.eye(d, device=device)).sample((samples_1.shape[0],))

    return pis, d, mus, sigmas, name, samples_0, samples_1

def wave(device='cuda'):
    """Two offset, noisy copies of one period of a sine wave.

    ``x`` is an evenly spaced grid (``torch.linspace``, not random) over
    ``[-pi, pi]`` and ``y = amplitude * sin(frequency * x)``. The first half
    of the points is shifted by ``(+pi/4, +0.1)`` and the second half by
    ``(-pi/4, -0.1)``. Isotropic Gaussian noise with std ``noise_std`` is then
    added, and the points are shuffled. The source is a standard 2-D normal.

    Args:
        device: device for the returned tensors (defaults to ``'cuda'``).

    Returns:
        ``(pis, d, mus, sigmas, name, samples_0, samples_1)``. The mixture
        fields are ``None``.
    """
    pis, mus, sigmas = None, None, None
    d = 2
    name = 'wave'
    num_samples = 500000
    frequency = 1
    amplitude = 1
    noise_std = 0.1
    half = num_samples // 2

    # Evenly spaced x over [-pi, pi].
    x = torch.linspace(-torch.pi, torch.pi, steps=num_samples).to(device)
    y = amplitude * torch.sin(frequency * x)

    # Split the curve into two branches offset in opposite directions.
    x[:half] += torch.pi / 4
    x[half:] -= torch.pi / 4
    y[:half] += 0.1
    y[half:] -= 0.1

    samples_1 = torch.stack([x, y], dim=1)
    samples_1 = samples_1 + torch.randn_like(samples_1) * noise_std
    # Shuffle like the other shape datasets, so the two branches (and the
    # sweep along x) are not stored in contiguous runs.
    samples_1 = samples_1[torch.randperm(samples_1.size(0))]
    samples_0 = MultivariateNormal(torch.zeros(d, device=device),
                                   torch.eye(d, device=device)).sample((num_samples,))
    return pis, d, mus, sigmas, name, samples_0, samples_1

def semicircle(device='cuda'):
    """Unit circle split into two shifted halves, with Gaussian noise.

    Points are evenly spaced in angle over ``[-pi, pi]``. The first half of
    the angle grid (the lower half of the circle) is moved by ``(+0.5, +0.1)``
    and the second half by ``(-0.5, -0.1)``. Isotropic noise with std
    ``noise_std`` is then added, and the points are shuffled. The source is a
    standard 2-D normal.

    Args:
        device: device for the returned tensors (defaults to ``'cuda'``).

    Returns:
        ``(pis, d, mus, sigmas, name, samples_0, samples_1)``. The mixture
        fields are ``None``.
    """
    pis, mus, sigmas = None, None, None
    d = 2
    name = 'semicircle'
    num_samples = 500000
    radius = 1
    noise_std = 0.1
    half = int(num_samples / 2)

    angles = torch.linspace(-torch.pi, torch.pi, steps=num_samples).to(device)
    x = radius * torch.cos(angles)
    y = radius * torch.sin(angles)

    # Pull the two halves of the circle apart.
    x[:half] = x[:half] + 0.5
    x[half:] = x[half:] - 0.5
    y[:half] = y[:half] + 0.1
    y[half:] = y[half:] - 0.1

    samples_1 = torch.stack([x, y], dim=1)
    samples_1 = samples_1 + torch.randn_like(samples_1) * noise_std

    samples_1 = samples_1[torch.randperm(samples_1.size(0))]
    samples_0 = MultivariateNormal(torch.zeros(d, device=device),
                                   torch.eye(d, device=device)).sample((num_samples,))
    return pis, d, mus, sigmas, name, samples_0, samples_1

def spiral(device='cuda'):
    """Noisy Archimedean spiral with ``num_turns`` turns.

    Angle and radius both increase linearly over the same ``num_samples``
    grid, from 0 to ``num_turns * 2 * pi`` and from 0 to ``max_radius``.
    Isotropic noise with std ``noise_std`` is then added, and the points are
    shuffled. The source is a standard 2-D normal.

    Args:
        device: device for the returned tensors (defaults to ``'cuda'``).

    Returns:
        ``(pis, d, mus, sigmas, name, samples_0, samples_1)``. The mixture
        fields are ``None``.
    """
    pis, mus, sigmas = None, None, None
    d = 2
    name = 'spiral'
    num_samples = 500000
    num_turns = 2
    max_radius = 2
    noise_std = 0.1

    # Angle sweeps num_turns full revolutions.
    angles = torch.linspace(0, num_turns * 2 * torch.pi, steps=num_samples).to(device)
    # Radius grows linearly with the angle.
    radii = torch.linspace(0, max_radius, steps=num_samples).to(device)

    # Polar -> Cartesian.
    x = radii * torch.cos(angles)
    y = radii * torch.sin(angles)

    samples_1 = torch.stack((x, y), dim=1)
    samples_1 = samples_1 + torch.randn_like(samples_1) * noise_std
    samples_1 = samples_1[torch.randperm(samples_1.size(0))]
    samples_0 = MultivariateNormal(torch.zeros(d, device=device),
                                   torch.eye(d, device=device)).sample((num_samples,))
    return pis, d, mus, sigmas, name, samples_0, samples_1

def vis(samples_0, samples_1):
    """Show the target samples ``samples_1`` for a quick visual check.

    With two or more columns, scatter the first two coordinates on a fixed
    ``[-10, 10] x [-10, 10]`` window. With one column, draw a 100-bin density
    histogram over ``[-3, 3]``. ``samples_0`` is accepted but not plotted.
    """
    n_cols = samples_1.shape[1]
    if n_cols >= 2:
        plt.figure(figsize=(4, 4))
        plt.xlim(-10, 10)
        plt.ylim(-10, 10)
        plt.title(r'Samples from $\pi_0$ and $\pi_1$')
        points = samples_1[:, :2].cpu().numpy()
        plt.scatter(points[:, 0], points[:, 1], alpha=0.1, label=r'$\pi_1$', s=1)
        plt.legend()
        plt.tight_layout()
        plt.show()
        plt.close()
    elif n_cols == 1:
        plt.figure()
        values = samples_1.cpu().numpy()
        plt.hist(values, bins=100, density=True, alpha=0.3,
                 cumulative=False, range=(-3, 3))
        plt.show()
        plt.close()


# ---------------------------------------------------------------------------
# Experiment selection. Exactly one dataset line below should be active; the
# commented-out lines are alternative experiments kept for quick switching.
# ---------------------------------------------------------------------------
train_only_data_and_gaussian = False

# pis, d, mus, sigmas, name, samples_0, samples_1 = data_config4()
# iterations = 10000
# batchsize = 2000

# pis, d, mus, sigmas, name, samples_0, samples_1 = data_config4_high_dim(0.02, dim=25)

# pis, d, mus, sigmas, name, samples_0, samples_1 = data_config4_high_dim(0.002, dim=10)
# iterations = 10000
# batchsize = 50000
# train_only_data_and_gaussian = True

# _pis,_d,_mus,_sigmas,_name,_samples_0, samples_1_test = data_config4_high_dim(0.02, dim=1)
# iterations = 2000
# batchsize = 50000

# pis, d, mus, sigmas, name, samples_0, samples_1 = data_config1()
pis, d, mus, sigmas, name, samples_0, samples_1 = chessboard()
# pis, d, mus, sigmas, name, samples_0, samples_1 = spiral()
# pis, d, mus, sigmas, name, samples_0, samples_1 = semicircle()
iterations = 10000
batchsize = 2000

# pis, d, mus, sigmas, name, samples_0, samples_1 = data_config4_high_dim(0.2, dim=10)
# iterations = 10000
# batchsize = 2000

# Other (iterations, batchsize) pairs tried: (10000, 50000), (20000, 64),
# (2000, 50000).

# pis, d, mus, sigmas, name, samples_0, samples_1 = data_config5_high_dim(0.2)
# iterations = 10000
# batchsize = 2000

# pis, d, mus, sigmas, name, samples_0, samples_1 = data_config5_high_dim_identity(0.2)
# iterations = 10000
# batchsize = 2000

probability_flow = True
# probability_flow = False
# train_by_gt = True
train_by_gt = False

# Build the complete output path first and only then create it. Creating it
# before the suffix was appended made the un-suffixed directory, and the path
# actually stored in `config` was never created.
save_path = rf'./saved/{name}-iter{iterations}-b{batchsize}-nums{samples_0.shape[0]}'
if train_only_data_and_gaussian:
    save_path = save_path + '-train_only_data_gaussian'
os.makedirs(save_path, exist_ok=True)

# Settings bundle exported by this module (presumably imported by the
# training script — confirm against the consumer).
config = {
    'pis': pis,
    'd': d,
    'mus': mus,
    'sigmas': sigmas,
    'name': name,
    'num_samples': samples_0.shape[0],
    'samples_1': samples_1,
    # 'samples_1_test': samples_1_test,
    'samples_0': samples_0,
    'probability_flow': probability_flow,
    'save_path': save_path,
    'train_by_gt': train_by_gt,
    'train_only_data_and_gaussian': train_only_data_and_gaussian,
    'batchsize': batchsize,
}


if __name__ == '__main__':
    # Visual sanity check of the currently selected dataset. `os` is already
    # imported at module level, so no local import is needed.
    # NOTE(review): this is set after torch/numpy/matplotlib have already been
    # imported; confirm it still suppresses the duplicate-OpenMP error here.
    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
    vis(samples_0, samples_1)