
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pickle
import kornia
import os
import matplotlib.pyplot as plt
import math
from rotation import *
from functorch.einops import rearrange
import mrcfile

# Input/output locations used by the script below.
data_path = 'datasets/'
dataset_name = 'yeast_mwa_30_snr_01_cellular_protein_mixture_subtomograms.pkl'
model_path = 'models/'
# NOTE(review): the model name ("mwa_0") differs from the dataset name
# ("mwa_30") — confirm this pairing is intentional.
model_name = 'yeast_mwa_0_snr_01_cellular_protein_mixture'
# Number of clusters (K is assigned again right before the GMM fit).
K =4

class ConvEncoder(nn.Module):
    """3-D convolutional encoder for single-channel cubic volumes.

    Produces one flat vector per sample with ``latent_dims + 6 + 3`` columns.
    ``AutoEncoder.forward`` reads columns ``[:6]`` as rotation parameters,
    ``[6:9]`` as translation parameters and the last ``latent_dims`` columns
    as the latent code.

    Args:
        latent_dims: size of the latent code.
        pixel: edge length of the cubic input volume. The flattened feature
            size feeding ``fc`` is derived from it (6912 for ``pixel=48``).

    Raises:
        ValueError: if ``pixel`` < 32 (the conv stack would shrink to nothing).
    """

    def __init__(self, latent_dims, pixel):
        # nn.Module.__init__ must run before any attribute is assigned.
        super(ConvEncoder, self).__init__()
        self.pixel = pixel
        theta_dim = 6
        # Each k=4/s=2/p=1 conv floor-halves the edge (three of them -> pixel // 8);
        # the final k=4/s=1/p=0 conv trims 3 more voxels.
        side = pixel // 8 - 3
        if side < 1:
            raise ValueError(f"pixel must be at least 32 for this encoder, got {pixel}")
        self.conv1 = nn.Conv3d(1, 32, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv3d(32, 64, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv3d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv3d(128, 256, kernel_size=4, stride=1, padding=0)
        self.fc = nn.Linear(256 * side ** 3, 1024)
        self.fc_last = nn.Linear(1024, latent_dims + theta_dim + 3)
        self.fc_dropout = nn.Dropout(0.2)

    def forward(self, x, training=False):
        """Encode a 5-D single-channel volume batch into the flat code ``phi``.

        Dropout after ``fc`` is only invoked when ``training`` is True; the
        ``nn.Dropout`` module is additionally a no-op while in eval mode.
        """
        x = F.elu(self.conv1(x))
        x = F.elu(self.conv2(x))
        x = F.elu(self.conv3(x))
        x = F.elu(self.conv4(x))
        x = x.view(x.size(0), -1)
        x = F.elu(self.fc(x))
        if training:
            x = self.fc_dropout(x)
        phi = self.fc_last(x)
        return phi

class ConvDecoder(nn.Module):
    """Transposed-convolution decoder mapping a latent code to a 1-channel volume.

    Args:
        latent_dims: size of the input latent code.
        pixel: target edge length. The seed grid has edge ``pixel // 8 - 3``;
            deconv1 adds 3 voxels and each later deconv doubles the edge, so the
            output edge is ``8 * (pixel // 8)`` — exactly ``pixel`` when it is a
            multiple of 8 (48: 3 -> 6 -> 12 -> 24 -> 48).

    Raises:
        ValueError: if ``pixel`` < 32 (no valid seed grid).
    """

    def __init__(self, latent_dims, pixel):
        super(ConvDecoder, self).__init__()
        # Same spatial arithmetic as ConvEncoder, run in reverse.
        seed = pixel // 8 - 3
        if seed < 1:
            raise ValueError(f"pixel must be at least 32 for this decoder, got {pixel}")
        self.seed = seed
        self.fc1 = nn.Linear(latent_dims, 256 * seed ** 3)
        self.deconv1 = nn.ConvTranspose3d(256, 128, kernel_size=4, stride=1, padding=0)  # edge: seed -> seed + 3
        self.deconv2 = nn.ConvTranspose3d(128, 64, kernel_size=4, stride=2, padding=1)   # edge doubles
        self.deconv3 = nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, padding=1)    # edge doubles
        self.deconv4 = nn.ConvTranspose3d(32, 1, kernel_size=4, stride=2, padding=1)     # edge doubles
        self.dropout = nn.Dropout(0.1)

    def forward(self, z, training=False):
        """Decode latent codes ``z`` into volumes of shape (B, 1, E, E, E).

        Dropout after ``fc1`` is only invoked when ``training`` is True (and is
        a no-op while the module is in eval mode). The last layer has no
        activation.
        """
        x = F.elu(self.fc1(z))
        if training:
            x = self.dropout(x)
        x = x.view(-1, 256, self.seed, self.seed, self.seed)
        x = F.elu(self.deconv1(x))
        x = F.elu(self.deconv2(x))
        x = F.elu(self.deconv3(x))
        x = self.deconv4(x)
        return x

def transform3d(volume, angle3d, translations, padding_mode="zeros"):
    """Rotate, then translate, a batch of volumes using kornia affine warps.

    Args:
        volume: 5-D tensor unpacked as (B, C, D, H, W).
        angle3d: per-sample rotation. If its last dimension is 6 it goes through
            ``s2s2_to_SO3`` and ``convert_SO3_to_kornia_affine_matrix``; if it
            is 3 it is used as the ``angles`` argument of
            ``kornia.geometry.get_affine_matrix3d``.
        translations: translations given to ``get_affine_matrix3d`` for the
            second (translation-only) warp.
        padding_mode: forwarded to both ``warp_affine3d`` calls.

    Returns:
        The transformed volume, resampled to the input's (D, H, W).

    Raises:
        NotImplementedError: if ``angle3d``'s last dimension is neither 6 nor 3.
    """
    batch, _, depth, height, width = volume.shape
    out_size = (depth, height, width)

    def per_sample(template):
        # Move a (1, k) template to volume's device/dtype, tile it over the
        # batch, and finally cast to float32.
        return template.to(volume).repeat(batch, 1).float()

    def warp(vol, matrix):
        return kornia.geometry.warp_affine3d(
            vol, matrix, dsize=out_size, align_corners=False, padding_mode=padding_mode
        )

    center = per_sample(torch.tensor([[depth / 2, height / 2, width / 2]]))
    scale = per_sample(torch.ones(1, 1))
    no_shift = per_sample(torch.zeros(1, 3))
    no_turn = per_sample(torch.zeros(1, 3))

    rot_dim = angle3d.shape[-1]
    if rot_dim == 6:
        rotation = convert_SO3_to_kornia_affine_matrix(s2s2_to_SO3(angle3d), center)
    elif rot_dim == 3:
        rotation = kornia.geometry.get_affine_matrix3d(
            translations=no_shift, center=center, scale=scale, angles=angle3d
        )[:, :3, :]
    else:
        raise NotImplementedError

    # Second pass: pure translation (zero angles, unit scale).
    shift = kornia.geometry.get_affine_matrix3d(
        translations=translations, center=center, scale=scale, angles=no_turn
    )[:, :3, :]
    return warp(warp(volume, rotation), shift)

def transform3d_cross(volume, transform_vectors):
    """Apply every transform in ``transform_vectors`` to every volume in ``volume``.

    Args:
        volume: 5-D batch of B volumes.
        transform_vectors: 2-D tensor of N transforms; columns ``[:3]`` are
            passed to ``transform3d`` as angles and the remaining columns as
            translations.

    Returns:
        Tensor of shape (B, N, C, D, H, W) where ``[b, n]`` is volume ``b``
        under transform ``n``.
    """
    n_volumes = volume.shape[0]
    n_transforms = transform_vectors.shape[0]
    # Flattened pair index b * N + n: each volume repeated N times in a row,
    # and the full list of transforms tiled B times.
    stacked_volumes = volume.repeat_interleave(n_transforms, dim=0)
    stacked_transforms = transform_vectors.repeat(n_volumes, 1)
    warped = transform3d(
        stacked_volumes, stacked_transforms[:, :3], stacked_transforms[:, 3:]
    )
    return rearrange(warped, "(b n) c d h w -> b n c d h w", b=n_volumes)

def transform2d(volume, angle2d, translations, padding_mode="zeros"):
    """Rotate every slice of each volume in-plane, then translate the volume in 3-D.

    The input is permuted ``b c w h d -> b c d h w`` so that its last axis
    becomes the slice axis. Every slice of sample ``b`` is rotated by
    ``angle2d[b]``, the stack is translated with ``warp_affine3d``, and the axes
    are permuted back.

    Args:
        volume: 5-D tensor.
        angle2d: per-sample rotation angles for
            ``kornia.geometry.transform.get_rotation_matrix2d`` (degrees in
            kornia's API).
        translations: per-sample translations for the 3-D warp.
        padding_mode: padding used by both the 2-D rotation and 3-D translation warps.

    Returns:
        Tensor with the same layout as ``volume``.
    """
    volume = rearrange(volume, "b c w h d -> b c d h w")
    B, _, D, H, W = volume.shape
    center = torch.tensor([[H / 2, W / 2]]).to(volume).repeat(B, 1)
    scale = torch.ones(1, 2).to(volume).repeat(B, 1)
    rotation_matrix2d = kornia.geometry.transform.get_rotation_matrix2d(center=center, angle=angle2d, scale=scale)
    # "(b d)" flattens batch-major: row b * D + d holds slice d of sample b.
    images = rearrange(volume, "b c d h w -> (b d) c h w")
    # Each sample's matrix must be repeated D times *consecutively* to match
    # that layout. Tensor.repeat would tile [m0..m(B-1), m0..m(B-1), ...] and
    # rotate one sample's slices with other samples' angles when B > 1.
    per_slice_matrix = rotation_matrix2d.repeat_interleave(D, dim=0)
    images_rotated = kornia.geometry.transform.warp_affine(
        images, per_slice_matrix, dsize=(H, W), align_corners=False, padding_mode=padding_mode
    )
    volume_rotated = rearrange(images_rotated, "(b d) c h w -> b c d h w", b=B)
    zero_angles = torch.zeros(1, 3).to(volume).repeat(B, 1)
    center = torch.tensor([[D / 2, H / 2, W / 2]]).to(volume).repeat(B, 1)
    scale = torch.ones(1, 1).to(volume).repeat(B, 1)
    affine_matrix3d = kornia.geometry.get_affine_matrix3d(
        translations=translations, center=center, scale=scale, angles=zero_angles
    )
    projective_matrix = affine_matrix3d[:, :3, :]
    volume_transformed = kornia.geometry.warp_affine3d(
        volume_rotated, projective_matrix, dsize=(D, H, W), align_corners=False, padding_mode=padding_mode
    )
    return rearrange(volume_transformed, "b c d h w -> b c w h d")

class AutoEncoder(nn.Module):
    """Encoder/decoder pair whose code also carries a pose (rotation + translation).

    ``forward`` splits the encoder output ``phi`` into 6 rotation parameters,
    3 translation parameters (squashed by tanh) and a latent code. The latent
    code is decoded, and the input is warped by the predicted pose with
    ``transform3d``.
    """

    def __init__(self, latent_dims, pixel):
        super(AutoEncoder, self).__init__()
        self.latent_dims = latent_dims
        self.encoder = ConvEncoder(latent_dims, pixel)
        self.decoder = ConvDecoder(latent_dims, pixel)

    def forward(self, x, training=False):
        """Return ``(decoded_latent, x_warped_by_pose, phi)``."""
        rot_dim = 6
        phi = self.encoder(x, training)
        theta = phi[:, :rot_dim]
        trans = torch.tanh(phi[:, rot_dim:rot_dim + 3]) * (2.0 + 0.0001)
        z = phi[:, -self.latent_dims:]
        if training:
            # NOTE(review): in-place add on a view of phi, so the returned phi
            # also carries this noise — confirm that is intended.
            z.add_(torch.randn_like(z) * 0.01)
        decoded = self.decoder(z, training)
        warped = transform3d(x, theta, trans)
        return decoded, warped, phi

class Siamese(nn.Module):
    """Runs one shared AutoEncoder on an input batch and on a randomly transformed copy."""

    def __init__(self, latent_dims, pixel):
        super(Siamese, self).__init__()
        self.autoencoder = AutoEncoder(latent_dims, pixel)

    def forward(self, image, training=False, initial=True):
        """Return both branches' outputs interleaved as (z1, z2, xθ1, xθ2, phi1, phi2).

        ``initial`` selects integer angles in [0, 360) versus [-6, 7) for the
        random in-plane rotation used to build the second view.
        """
        batch = image.shape[0]
        low, high = (0, 360) if initial else (-6, 7)
        with torch.no_grad():
            angle = torch.randint(low, high, (batch,)).to(image)
            # NOTE(review): rand * 2 - 0.5 spans [-0.5, 1.5), not a symmetric
            # range — confirm this is intended.
            translations = torch.rand(batch, 3).to(image) * 2 - 0.5
            augmented = transform2d(image, angle2d=angle, translations=translations)
        z_a, warped_a, phi_a = self.autoencoder(image, training=training)
        z_b, warped_b, phi_b = self.autoencoder(augmented, training=training)
        return z_a, z_b, warped_a, warped_b, phi_a, phi_b

def visualize_subtomogram(arr, axis, fname):
    """Save a square grid of 2-D slices of a 3-D volume as an image.

    Args:
        arr: 3-D array. ``arr.shape[0]`` sets the number of slices drawn and the
            grid size (ceil(sqrt(n)) per side); the volume is assumed cubic so
            slicing along any axis yields that many slices.
        axis: axis (0, 1 or 2) along which slices are taken.
        fname: output image path passed to ``savefig``.

    Raises:
        NotImplementedError: if ``axis`` is not 0, 1 or 2.
    """
    # Validate before creating any figure so nothing is left open on error.
    if axis not in (0, 1, 2):
        raise NotImplementedError
    subtomogram_size = arr.shape[0]
    grid = int(np.ceil(np.sqrt(subtomogram_size)))
    # Symmetric grey-level window of +/- 5 standard deviations around zero.
    s = np.std(arr) * 5
    fig = plt.figure(figsize=(grid, grid))  # equal aspect ratio
    try:
        for i in range(grid * grid):
            a = fig.add_subplot(grid, grid, i + 1)
            a.xaxis.set_visible(False)
            a.yaxis.set_visible(False)
            a.set_aspect('equal')
            if i < subtomogram_size:
                a.imshow(np.take(arr, i, axis=axis), vmin=-s, vmax=s, cmap='gray')
            else:
                # Fill unused grid cells with blank tiles.
                a.imshow(np.zeros((subtomogram_size, subtomogram_size)), cmap='gray')
        fig.subplots_adjust(wspace=0, hspace=0)
        fig.savefig(fname, dpi=150, bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until closed; release it.
        plt.close(fig)


# Load the subtomogram records and stack them into one float32 array.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open(os.path.join(data_path, dataset_name), 'rb') as f:
    records = pickle.load(f)
data = np.array([d['subtomo'] for d in records]).astype(np.float32)
print('dataset size', data.shape)
# Reverse the three spatial axes' order (axes 1 and 3 swapped).
data = np.transpose(data, (0,3,2,1))
print('dataset size', data.shape)


def normalize_volumes_to_range(data, new_min=-10, new_max=10):
    """Min-max rescale each volume independently into ``[new_min, new_max]``.

    Minimum and maximum are taken per item over axes 1, 2 and 3. A constant
    volume (zero range) is divided by 1 instead, so it maps to ``new_min``.

    Returns:
        A new array with the same shape as ``np.array(data)``.
    """
    volumes = np.array(data)
    reduce_axes = (1, 2, 3)
    low = volumes.min(axis=reduce_axes, keepdims=True)
    span = volumes.max(axis=reduce_axes, keepdims=True) - low
    span = np.where(span == 0, 1, span)
    unit = (volumes - low) / span
    return unit * (new_max - new_min) + new_min

# Standardize each subtomogram independently (zero mean, unit std).
means = np.mean(data, axis=(1, 2, 3), keepdims=True)
stds = np.std(data, axis=(1, 2, 3), keepdims=True)
# A constant subtomogram has std 0; leave it unscaled instead of producing NaNs.
stds[stds == 0] = 1
data = (data - means) / stds

# Build the network and load the trained weights from model_path/model_name.
model = Siamese(latent_dims=50, pixel=48).cuda()
model.load_state_dict(torch.load(os.path.join(model_path, f"{model_name}.pt")))
print('Model Loaded!')
encoder_params = sum(p.numel() for p in model.autoencoder.encoder.parameters() if p.requires_grad)
decoder_params = sum(p.numel() for p in model.autoencoder.decoder.parameters() if p.requires_grad)

print('Encoder params: ', encoder_params/10**6, 'M')
print('Decoder params: ', decoder_params/10**6, 'M')

tensor_data = torch.from_numpy(data).float()

def create_soft_spherical_mask_torch(box_size, particle_diameter, pixel_size, falloff=5, device='cpu'):
    """Build a cubic spherical mask with a raised-cosine edge.

    Args:
        box_size: edge length of the cubic mask in voxels; the sphere is centred
            at voxel ``box_size // 2`` on every axis.
        particle_diameter: sphere diameter, in the same unit as ``pixel_size``.
        pixel_size: size of one voxel.
        falloff: width of the cosine edge as a percentage of the radius.
        device: torch device for the coordinate grid and the result.

    Returns:
        float32 tensor of shape (box_size, box_size, box_size). It is 1 inside
        ``radius - edge`` and 0 beyond ``radius``, with a cosine ramp from 1
        down to 0 in between.
    """
    radius_px = (particle_diameter / 2) / pixel_size
    edge_px = (falloff / 100) * radius_px

    offsets = torch.arange(box_size, device=device) - box_size // 2
    gx, gy, gz = torch.meshgrid(offsets, offsets, offsets, indexing='ij')
    dist = torch.sqrt(gx**2 + gy**2 + gz**2)

    # Hard sphere first, then overwrite the edge shell with the cosine ramp.
    hard = torch.where(dist > radius_px, torch.zeros_like(dist), torch.ones_like(dist))
    ramp = 0.5 * (1 + torch.cos(np.pi * (dist - radius_px + edge_px) / edge_px))
    in_edge = (dist > (radius_px - edge_px)) & (dist <= radius_px)
    return torch.where(in_edge, ramp, hard).float()


# Soft spherical mask for a 48^3 box (diameter and pixel size in Angstroms).
mask = create_soft_spherical_mask_torch(
    box_size=48,
    particle_diameter=320,  # Angstroms
    pixel_size=7.5,         # Angstroms per pixel
    falloff=1,              # 1% cosine edge
)

masked_data = tensor_data * mask[None, :, :, :]

# Keep the final partial batch so every subtomogram gets an embedding
# (drop_last=True would silently discard up to batch_size - 1 samples).
loader = torch.utils.data.DataLoader(masked_data, batch_size=100, shuffle=False, drop_last=False)
latent_dims = model.autoencoder.latent_dims
midall = []
model.eval()
with torch.no_grad():
    for images in loader:
        images = images.cuda()
        images = images.unsqueeze(1).float()
        phi = model.autoencoder.encoder(images)
        # The latent code is the last latent_dims columns, the same slice AutoEncoder.forward uses.
        mid = phi[:, -latent_dims:]
        midall.append(mid.cpu().numpy())
# Batches can differ in size now, so concatenate rather than reshape.
mid = np.concatenate(midall, axis=0)

import umap.umap_ as umap

# Embed the latent codes in 2-D with UMAP (fixed seed for reproducibility).
reducer = umap.UMAP(n_components=2, random_state=777)
m = reducer.fit_transform(mid)


import matplotlib.patches as mpatches
from sklearn.mixture import GaussianMixture
from sklearn.metrics import adjusted_rand_score

# Cluster the 2-D embedding into K groups with a Gaussian mixture.
K = 4
gmm = GaussianMixture(n_components=K, random_state=777)
gmm.fit(m)
labels = gmm.predict(m)


plt.clf()
# Colour each point by its GMM cluster label.
scatter = plt.scatter(m[:, 0], m[:, 1], c=labels, cmap='rainbow', alpha=0.25)
unique_labels = np.unique(labels)
legend_handles = [
    mpatches.Patch(color=scatter.cmap(scatter.norm(label)), label=f'Structure {label}')
    for label in unique_labels
]
plt.legend(handles=legend_handles, title="Labels")
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')
# The results directory may not exist yet at this point in the script.
os.makedirs('results', exist_ok=True)
plt.savefig('results/umap_with_gmm_labels.png')

out_path = 'results/'
os.makedirs(out_path, exist_ok=True)

labels = np.array(labels)
model.eval()
for group_id in range(K):
    members = mid[labels == group_id]
    if len(members) == 0:
        # A GMM component can end up with no points; the median of an empty set
        # is NaN and would decode to garbage.
        print(f'Skipping group {group_id}: no members')
        continue
    # Decode the per-cluster median latent code into a representative volume.
    group_med = np.median(members, axis=0)
    group_med_tensor = torch.from_numpy(group_med).float().unsqueeze(0).cuda()

    with torch.no_grad():
        out = model.autoencoder.decoder(group_med_tensor).squeeze()
    volume = out.cpu().numpy()

    fname = os.path.join(out_path, f"{group_id}_decoded_output.png")
    fname_mrc = os.path.join(out_path, f"{group_id}_decoded_output.mrc")
    with mrcfile.new(fname_mrc, overwrite=True) as mrc:
        mrc.set_data(volume)

    visualize_subtomogram(volume, 1, fname)

