import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import SequentialLR, LambdaLR, StepLR
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import os
import pickle
import kornia
import time
from rotation import *
from functorch.einops import rearrange
from sklearn.mixture import GaussianMixture
from sklearn.metrics import adjusted_rand_score
import numpy as np



# Directory (relative path) that holds the pickled dataset file.
data_path = 'datasets/'
# File name of the pickled subtomogram dataset; joined with ``data_path`` when loaded.
dataset_name = 'yeast_mwa_30_snr_001_cellular_protein_mixture_subtomograms.pkl'

def init_weights(m):
    """Initialise a single module in place (meant for ``nn.Module.apply``).

    Linear, Conv3d and ConvTranspose3d layers receive Xavier-uniform weights
    and, when they have a bias, an all-zero bias. Every other module type is
    left untouched.
    """
    if not isinstance(m, (nn.Linear, nn.ConvTranspose3d, nn.Conv3d)):
        return
    nn.init.xavier_uniform_(m.weight)
    if m.bias is None:
        return
    nn.init.zeros_(m.bias)

def normalize_to_range(arr, new_min=-10, new_max=10):
    """Linearly rescale ``arr`` so its minimum maps to ``new_min`` and its maximum to ``new_max``.

    Args:
        arr: Array-like of numbers (any shape).
        new_min: Lower bound of the target range.
        new_max: Upper bound of the target range.

    Returns:
        A new floating-point ``numpy.ndarray`` with the same shape as ``arr``.
        Floating inputs keep their dtype; every other dtype is promoted to
        float64. A constant input maps to the midpoint of the target range and
        an empty input yields an empty array.
    """
    arr = np.asarray(arr)
    # Work in floating point: with an integer array ``np.full_like`` below
    # would truncate a fractional midpoint (e.g. 0.5 -> 0).
    if not np.issubdtype(arr.dtype, np.floating):
        arr = arr.astype(np.float64)

    # min()/max() are undefined on empty arrays; there is nothing to rescale.
    if arr.size == 0:
        return arr.copy()

    old_min, old_max = arr.min(), arr.max()

    # Avoid division by zero if all values are the same
    if old_min == old_max:
        return np.full_like(arr, (new_min + new_max) / 2)

    normalized = (arr - old_min) / (old_max - old_min)  # scale to [0, 1]
    return normalized * (new_max - new_min) + new_min  # scale to [new_min, new_max]

def normalize_volumes_to_range(data, new_min=-10, new_max=10):
    """Rescale every volume of a 4-D stack independently to ``[new_min, new_max]``.

    The minimum and maximum are taken over axes 1, 2 and 3, i.e. separately
    for each entry along axis 0. A volume whose values are all equal ends up
    filled with ``new_min`` (its span is replaced by 1 to avoid dividing by
    zero).
    """
    volumes = np.array(data)
    reduce_axes = (1, 2, 3)

    lows = volumes.min(axis=reduce_axes, keepdims=True)
    highs = volumes.max(axis=reduce_axes, keepdims=True)

    # Per-volume value span; a zero span is swapped for 1 so the division is safe.
    span = highs - lows
    span = np.where(span == 0, 1, span)

    unit = (volumes - lows) / span
    return unit * (new_max - new_min) + new_min

def visualize_subtomogram(arr, axis, fname='tmp.png'):
    """Save a tiled grayscale overview of all slices of a 3-D volume.

    Slices taken along ``axis`` are laid out row by row on a square grid that
    is just large enough to hold them; unused tiles are drawn as blank
    images. All tiles share the intensity window ``[-5*std, +5*std]`` of the
    whole volume.

    Args:
        arr: 3-D ``numpy`` array to visualise.
        axis: Axis to slice along; must be 0, 1 or 2.
        fname: Path of the image file to write.

    Raises:
        NotImplementedError: If ``axis`` is not 0, 1 or 2.
    """
    # Validate before allocating a figure so a bad argument leaves nothing behind.
    if axis not in (0, 1, 2):
        raise NotImplementedError

    # Number of slices along the requested axis (not always axis 0), so
    # non-cubic volumes are indexed within bounds.
    num_slices = arr.shape[axis]
    slice_shape = tuple(size for dim, size in enumerate(arr.shape) if dim != axis)
    grid = int(np.ceil(np.sqrt(num_slices)))
    limit = np.std(arr) * 5

    fig = plt.figure(figsize=(grid, grid))  # square canvas, one unit per tile
    try:
        for idx in range(grid * grid):
            tile = fig.add_subplot(grid, grid, idx + 1)
            tile.xaxis.set_visible(False)
            tile.yaxis.set_visible(False)
            tile.set_aspect('equal')
            if idx < num_slices:
                tile.imshow(np.take(arr, idx, axis=axis), vmin=-limit, vmax=limit, cmap='gray')
            else:
                # Pad the grid with empty tiles once the slices run out.
                tile.imshow(np.zeros(slice_shape), cmap='gray')

        fig.subplots_adjust(wspace=0, hspace=0)
        fig.savefig(fname, bbox_inches="tight")
    finally:
        # Release the figure; otherwise every call leaks one pyplot figure.
        plt.close(fig)


class MuyEncoder(nn.Module):
    """3-D convolutional encoder producing a flat vector of width ``latent_dims + 6 + 3``.

    Four strided Conv3d layers are followed by two fully connected layers.
    ``fc`` expects 256 input features, which is what the convolution stack
    yields for a single-channel 48x48x48 input (spatial size 1x1x1).

    Args:
        latent_dims: Number of latent components in the output vector.
        pixel: Stored as ``self.pixel``; not used to size any layer.
    """

    def __init__(self, latent_dims, pixel):
        # nn.Module must be initialised before any attribute is assigned on
        # the instance (PyTorch contract), so call the parent first.
        super().__init__()
        self.pixel = pixel
        theta_dim = 6
        self.conv1 = nn.Conv3d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv3d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv3d(64, 128, kernel_size=5, stride=2, padding=1)
        self.conv4 = nn.Conv3d(128, 256, kernel_size=5, stride=1, padding=0)
        self.fc = nn.Linear(256, 256)
        self.fc_last = nn.Linear(256, latent_dims+theta_dim+3)
        self.dropout = nn.Dropout3d(0.1)

        # Apply Xavier initialization
        self.apply(init_weights)

    def forward(self, x, training=False):
        """Encode ``x`` and return the raw output of the last linear layer.

        Dropout after ``conv3`` and ``conv4`` is only applied when
        ``training`` is true.
        """
        x = F.elu(self.conv1(x))
        x = F.elu(self.conv2(x))
        x = F.elu(self.conv3(x))
        if training:
            x = self.dropout(x)
        x = F.elu(self.conv4(x))
        if training:
            x = self.dropout(x)
        x = x.view(x.size(0), -1)  # flatten everything except the batch axis
        x = F.elu(self.fc(x))
        phi = self.fc_last(x)
        return phi

class ConvEncoder(nn.Module):
    """3-D convolutional encoder producing a flat vector of width ``latent_dims + 6 + 3``.

    Four Conv3d layers are followed by two fully connected layers. ``fc``
    expects 6912 input features, which is what the convolution stack yields
    for a single-channel 48x48x48 input (256 channels x 3x3x3).

    Args:
        latent_dims: Number of latent components in the output vector.
        pixel: Stored as ``self.pixel``; not used to size any layer.
    """

    def __init__(self, latent_dims, pixel):
        # nn.Module must be initialised before any attribute is assigned on
        # the instance (PyTorch contract), so call the parent first.
        super().__init__()
        self.pixel = pixel
        theta_dim = 6
        self.conv1 = nn.Conv3d(1, 32, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv3d(32, 64, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv3d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv3d(128, 256, kernel_size=4, stride=1, padding=0)
        self.fc = nn.Linear(6912, 1024)
        self.fc_last = nn.Linear(1024, latent_dims+theta_dim+3)
        self.fc_dropout = nn.Dropout(0.2)

        # Apply Xavier initialization
        self.apply(init_weights)

    def forward(self, x, training=False):
        """Encode ``x`` and return the raw output of the last linear layer.

        Dropout on the hidden fully connected features is only applied when
        ``training`` is true.
        """
        x = F.elu(self.conv1(x))
        x = F.elu(self.conv2(x))
        x = F.elu(self.conv3(x))
        x = F.elu(self.conv4(x))
        x = x.view(x.size(0), -1)  # flatten everything except the batch axis
        x = F.elu(self.fc(x))
        if training:
            x = self.fc_dropout(x)
        phi = self.fc_last(x)
        return phi
    
class ConvDecoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector to a 1x48x48x48 volume.

    A linear layer expands the latent code to 256 channels of size 3x3x3,
    which four ConvTranspose3d layers upsample to 6, 12, 24 and finally 48
    voxels per side. ``pixel`` is accepted for signature parity with the
    encoders but is not used.
    """

    def __init__(self, latent_dims, pixel):
        super().__init__()
        self.fc1 = nn.Linear(latent_dims, 256 * 3 * 3 * 3)
        # Spatial size per side after each layer: 3 -> 6 -> 12 -> 24 -> 48.
        self.deconv1 = nn.ConvTranspose3d(256, 128, kernel_size=4, stride=1, padding=0)
        self.deconv2 = nn.ConvTranspose3d(128, 64, kernel_size=4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, padding=1)
        self.deconv4 = nn.ConvTranspose3d(32, 1, kernel_size=4, stride=2, padding=1)
        self.dropout = nn.Dropout(0.1)

        # Xavier-initialise every linear / (transposed) convolution layer.
        self.apply(init_weights)

    def forward(self, z, training=False):
        """Decode ``z``; dropout on the expanded features only when ``training`` is true."""
        hidden = F.elu(self.fc1(z))
        if training:
            hidden = self.dropout(hidden)
        volume = hidden.view(-1, 256, 3, 3, 3)
        for upsample in (self.deconv1, self.deconv2, self.deconv3):
            volume = F.elu(upsample(volume))
        # The last layer is left linear (no activation on the output volume).
        return self.deconv4(volume)


def transform3d(volume, angle3d, translations, padding_mode="zeros"):
    """Rotate a batch of volumes about their centre, then translate them.

    The rotation and the translation are applied as two separate
    ``kornia.geometry.warp_affine3d`` passes.

    Args:
        volume: Tensor of shape ``(B, C, D, H, W)``.
        angle3d: Per-sample rotation parameters. A last dimension of 6 is
            converted with ``s2s2_to_SO3``; a last dimension of 3 is passed
            to kornia as ``angles``.
        translations: Per-sample translation passed to kornia's
            ``get_affine_matrix3d``.
        padding_mode: Padding mode forwarded to both warps.

    Raises:
        NotImplementedError: If ``angle3d`` has any other last dimension.
    """
    B, C, D, H, W = volume.shape

    def _tiled(row):
        # One row per sample, on the volume's device, as float32.
        return torch.tensor([row]).to(volume).repeat(B, 1).float()

    # NOTE(review): kornia documents ``center`` as (x, y, z); the (D, H, W)
    # order used here only coincides with that for cubic volumes — confirm.
    center = _tiled([D / 2, H / 2, W / 2])
    scale = _tiled([1.0])
    no_shift = _tiled([0.0, 0.0, 0.0])
    no_rotation = _tiled([0.0, 0.0, 0.0])

    def _warp(source, matrix):
        return kornia.geometry.warp_affine3d(
            source, matrix, dsize=(D, H, W), align_corners=False, padding_mode=padding_mode
        )

    pose_dim = angle3d.shape[-1]
    if pose_dim == 6:
        rotation = convert_SO3_to_kornia_affine_matrix(s2s2_to_SO3(angle3d), center)
    elif pose_dim == 3:
        rotation = kornia.geometry.get_affine_matrix3d(
            translations=no_shift, center=center, scale=scale, angles=angle3d
        )[:, :3, :]
    else:
        raise NotImplementedError

    # Pass 1: pure rotation. Pass 2: pure translation of the rotated volume.
    rotated = _warp(volume, rotation)
    shift = kornia.geometry.get_affine_matrix3d(
        translations=translations, center=center, scale=scale, angles=no_rotation
    )[:, :3, :]
    return _warp(rotated, shift)

def transform3d_cross(volume, transform_vectors):
    """Apply every transform to every volume (full cross product).

    Args:
        volume: Tensor of shape ``(B, C, D, H, W)``.
        transform_vectors: Tensor of shape ``(N, 6)``; columns 0-2 are handed
            to ``transform3d`` as angles and columns 3-5 as translations.

    Returns:
        Tensor of shape ``(B, N, C, D, H, W)`` where entry ``[b, n]`` is
        volume ``b`` under transform ``n``.
    """
    num_volumes, num_transforms = volume.shape[0], transform_vectors.shape[0]

    # Flat index b * N + n: each volume repeated N times in a row, paired with
    # the whole transform list tiled once per volume.
    flat_volumes = volume.repeat_interleave(num_transforms, dim=0)
    flat_vectors = transform_vectors.repeat(num_volumes, 1)

    warped = transform3d(flat_volumes, flat_vectors[:, :3], flat_vectors[:, 3:])
    return warped.reshape(num_volumes, num_transforms, *warped.shape[1:])

def transform2d(volume, angle2d, translations, padding_mode="zeros"):
    """Rotate each volume in-plane by its own angle, then translate it in 3-D.

    The last input axis is treated as the slice axis: every slice along it is
    rotated with a 2-D affine warp using that sample's angle, after which the
    whole volume is translated with a 3-D affine warp.

    Args:
        volume: Tensor of shape ``(B, C, W, H, D)``.
        angle2d: Tensor of shape ``(B,)`` with one rotation angle per sample
            (presumably degrees, per kornia's convention — confirm).
        translations: Tensor of shape ``(B, 3)`` passed to kornia's
            ``get_affine_matrix3d``.
        padding_mode: Padding mode forwarded to the 3-D warp.

    Returns:
        Transformed tensor with the same ``(B, C, W, H, D)`` layout as the input.
    """
    volume = rearrange(volume, "b c w h d -> b c d h w")
    B, C, D, H, W = volume.shape
    # NOTE(review): kornia documents ``center`` as (x, y); (H, W) only matches
    # that for square slices — confirm if non-cubic volumes are ever used.
    center = torch.tensor([[H / 2, W / 2]]).to(volume).repeat(B, 1)
    scale = torch.ones(1, 2).to(volume).repeat(B, 1)
    rotation_matrix2d = kornia.geometry.transform.get_rotation_matrix2d(center=center, angle=angle2d, scale=scale)
    images = rearrange(volume, "b c d h w -> (b d) c h w")
    # Slices are laid out batch-major (index b * D + d), so each sample's
    # matrix must be repeated D times consecutively. Tiling the whole batch
    # with ``repeat(D, 1, 1)`` would pair slices with other samples' angles.
    per_slice_matrix = rotation_matrix2d.repeat_interleave(D, dim=0)
    images_rotated = kornia.geometry.transform.warp_affine(
        images, per_slice_matrix, dsize=(H, W), align_corners=False
    )
    volume_rotated = rearrange(images_rotated, "(b d) c h w -> b c d h w", b=B)
    # Second pass: pure 3-D translation (zero rotation, unit scale).
    zero_angles = torch.zeros(1, 3).to(volume).repeat(B, 1)
    center = torch.tensor([[D / 2, H / 2, W / 2]]).to(volume).repeat(B, 1)
    scale = torch.ones(1, 1).to(volume).repeat(B, 1)
    affine_matrix3d = kornia.geometry.get_affine_matrix3d(
        translations=translations, center=center, scale=scale, angles=zero_angles
    )
    projective_matrix = affine_matrix3d[:, :3, :]
    volume_transformed = kornia.geometry.warp_affine3d(
        volume_rotated, projective_matrix, dsize=(D, H, W), align_corners=False, padding_mode=padding_mode
    )
    # Restore the caller's original axis order.
    volume_transformed = rearrange(volume_transformed, "b c d h w -> b c w h d")
    return volume_transformed

def create_soft_spherical_mask_torch(box_size, particle_diameter, pixel_size, falloff=5, device='cpu'):
    """Build a cubic spherical mask with a raised-cosine soft edge.

    The sphere is centred on voxel ``box_size // 2`` and has a radius of
    ``(particle_diameter / 2) / pixel_size`` voxels. Voxels inside are 1,
    voxels outside are 0, and within the outermost ``falloff`` percent of the
    radius the value falls from 1 to 0 along a half cosine. The radius is
    printed to stdout.

    Returns:
        float32 tensor of shape ``(box_size, box_size, box_size)``.
    """
    radius = (particle_diameter / 2) / pixel_size
    print(radius)
    edge_width = (falloff / 100) * radius

    # Signed voxel offsets from the centre along one axis.
    offsets = torch.arange(box_size, device=device) - box_size // 2
    gx, gy, gz = torch.meshgrid(offsets, offsets, offsets, indexing='ij')
    dist = torch.sqrt(gx**2 + gy**2 + gz**2)

    # Hard sphere first: 1 inside (including the boundary), 0 outside.
    inside = dist <= radius
    mask = inside.to(dist.dtype)

    # Then overwrite the outer shell with the cosine ramp.
    shell = inside & (dist > (radius - edge_width))
    ramp = 1 + torch.cos(np.pi * (dist[shell] - radius + edge_width) / edge_width)
    mask[shell] = 0.5 * ramp

    return mask.float()

class AutoEncoder(nn.Module):
    """Encoder/decoder pair that also re-poses the input with the predicted transform.

    The encoder output ``phi`` is read as: columns 0-5 rotation parameters,
    columns 6-8 translation (before squashing), and the trailing
    ``latent_dims`` columns the latent code.
    """

    def __init__(self, latent_dims, pixel):
        super().__init__()
        self.encoder = ConvEncoder(latent_dims, pixel)
        self.decoder = ConvDecoder(latent_dims, pixel)
        self.latent_dims = latent_dims

    def forward(self, x, training=False):
        """Return ``(decoded volume, input warped by the predicted pose, phi)``."""
        rot_dim, shift_dim = 6, 3
        phi = self.encoder(x, training)

        theta = phi[:, :rot_dim]
        # tanh bounds every translation component to just over +/-2.
        trans = F.tanh(phi[:, rot_dim:rot_dim + shift_dim]) * (2.0 + 0.0001)

        z = phi[:, -self.latent_dims:]
        if training:
            # NOTE(review): ``z`` is a view of ``phi``, so this in-place noise
            # also changes the ``phi`` that is returned — confirm that is intended.
            z += 0.01 * torch.randn_like(z)

        image_z = self.decoder(z, training)
        image_x_theta = transform3d(x, theta, trans)
        return image_z, image_x_theta, phi

class Siamese(nn.Module):
    """Runs one shared ``AutoEncoder`` on an input batch and on a randomly perturbed copy."""

    def __init__(self, latent_dims, pixel):
        super().__init__()
        self.autoencoder = AutoEncoder(latent_dims, pixel)

    def forward(self, image, training=False, initial=True):
        """Return ``(image_z1, image_z2, image_x_theta1, image_x_theta2, phi1, phi2)``.

        Index 1 refers to the original batch, index 2 to the perturbed copy.
        ``initial`` is accepted but not used.
        """
        batch_size = image.shape[0]
        # The perturbation is data augmentation only; no gradients needed.
        with torch.no_grad():
            # One integer angle per sample, drawn from [-6, 6].
            angle = torch.randint(-6, 7, (batch_size,)).to(image)
            # NOTE(review): this samples translations from [-0.5, 1.5), which
            # is not symmetric around zero — confirm that is intended.
            translations = torch.rand(batch_size, 3).to(image) * 2 - 0.5
            transformed_image = transform2d(image, angle2d=angle, translations=translations)

        original_out = self.autoencoder(image, training=training)
        perturbed_out = self.autoencoder(transformed_image, training=training)

        image_z1, image_x_theta1, phi1 = original_out
        image_z2, image_x_theta2, phi2 = perturbed_out
        return image_z1, image_z2, image_x_theta1, image_x_theta2, phi1, phi2

def create_transform_vectors(epoch, t_size, device=None):
    """Sample ``t_size`` random transform vectors of the form ``[3 angles, 3 translations]``.

    Angles are integers drawn from [-90, 90] while ``epoch <= 40`` and from
    [-30, 30] afterwards. Translations are uniform in [-0.5, 1.5).

    Returns:
        float32 tensor of shape ``(t_size, 6)`` on ``device``.
    """
    # Wide angular search early in training, narrower later on.
    max_angle = 90 if epoch <= 40 else 30
    angles = torch.randint(-max_angle, max_angle + 1, (t_size, 3), device=device).float()
    shifts = torch.rand(t_size, 3, device=device).float() * 2 - 0.5
    return torch.cat([angles, shifts], dim=1)

def get_maxout_loss(x1, x2, epoch, num_transforms=96):
    """Best-of-N reconstruction loss under random rigid transforms.

    ``x1`` is warped with ``num_transforms`` random transforms; for each
    sample the smallest summed squared difference between ``x2`` and any
    warped copy is kept, and these minima are averaged over the batch.

    Args:
        x1: Tensor of shape ``(B, 1, D, H, W)`` that gets transformed.
        x2: Tensor of shape ``(B, 1, D, H, W)`` to compare against.
        epoch: Current epoch; selects the angular range of the transforms.
        num_transforms: Number of random candidate transforms.

    Returns:
        Scalar tensor with the mean best-match loss.
    """
    device = x1.device  # Get device from input tensor
    transform_vectors = create_transform_vectors(epoch, num_transforms, device=device)
    candidates = transform3d_cross(x1, transform_vectors)  # (B, N, 1, D, H, W)
    # Drop only the channel axis. A bare ``squeeze()`` would also remove a
    # batch (or candidate) axis of size 1 and break the broadcasting below.
    candidates = candidates.squeeze(2)  # (B, N, D, H, W)
    targets = x2.squeeze(1)  # (B, D, H, W)
    diff = targets[:, None] - candidates
    diff = torch.sum(diff ** 2, dim=(2, 3, 4))  # (B, N) squared error per candidate
    diff = torch.min(diff, dim=1)[0]  # best candidate per sample
    return torch.mean(diff)

def loss_fn(image_z1, image_z2, image_x_theta1, image_x_theta2, phi1, phi2, epoch, latent_dims=50):
    """Total training loss for one Siamese forward pass.

    Sums four terms: the maxout reconstruction loss of each branch, the
    squared difference between the two re-posed inputs, and the L1 distance
    between the two latent codes.

    Args:
        image_z1, image_z2: Decoder outputs of the two branches.
        image_x_theta1, image_x_theta2: Inputs warped by the predicted poses.
        phi1, phi2: Encoder outputs; the trailing ``latent_dims`` columns are
            taken as the latent code.
        epoch: Current epoch, forwarded to ``get_maxout_loss``.
        latent_dims: Width of the latent code. Must match the model's latent
            size; defaults to the previously hard-coded 50.

    Returns:
        Scalar loss tensor.
    """
    n = image_x_theta1.size(0)
    recon_loss1 = get_maxout_loss(image_z1, image_x_theta1, epoch)
    recon_loss2 = get_maxout_loss(image_z2, image_x_theta2, epoch)
    #For Harmony3D use the following commented out loss functions
    #recon_loss1 = F.mse_loss(image_x_theta1, image_z1, reduction='sum').div(n)
    #recon_loss2 = F.mse_loss(image_x_theta2, image_z2, reduction='sum').div(n)
    # Both branches should agree on the re-posed input.
    branch_loss = torch.mean(torch.sum((image_x_theta1-image_x_theta2) ** 2, dim=(2, 3, 4)))
    loss = recon_loss1 + recon_loss2 + branch_loss
    z1 = phi1[:, -latent_dims:]
    z2 = phi2[:, -latent_dims:]
    # Mean absolute difference per latent component and sample.
    z_loss = F.l1_loss(z1, z2, reduction="sum").div(latent_dims*n)
    return loss + z_loss



# Load the pickled dataset: an iterable of records that each carry a 'subtomo' volume.
# NOTE: unpickling can execute arbitrary code — only load dataset files from a trusted source.
with open(os.path.join(data_path, dataset_name), 'rb') as dataset_file:
    records = pickle.load(dataset_file)  # no longer shadows the builtin ``dict``
data = np.array([record['subtomo'] for record in records]).astype(np.float32)
# Transpose if mwa is non-zero
data = np.transpose(data, (0,3,2,1))
print('dataset size', data.shape)


# Per-volume standardisation (zero mean, unit variance), vectorised over the dataset.
means = np.mean(data, axis=(1, 2, 3), keepdims=True)
stds = np.std(data, axis=(1, 2, 3), keepdims=True)
# A constant volume has zero std; divide by 1 instead so it becomes all zeros, not NaN.
stds[stds == 0] = 1
data = (data - means) / stds


small_mask = create_soft_spherical_mask_torch(
    box_size=48,
    particle_diameter=320,  # Angstroms
    pixel_size=7.5,         # Angstroms per pixel
    falloff=1,              # 1% cosine edge
)



model = Siamese(latent_dims=50, pixel=48).cuda()
encoder_params = sum(p.numel() for p in model.autoencoder.encoder.parameters() if p.requires_grad)
decoder_params = sum(p.numel() for p in model.autoencoder.decoder.parameters() if p.requires_grad)

print('Encoder params: ', encoder_params/10**6, 'M')
print('Decoder params: ', decoder_params/10**6, 'M')


# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Apply the soft spherical mask to every volume (broadcast over the dataset axis).
tensor_data = torch.from_numpy(data).float()
masked_data = tensor_data* small_mask[None, :, :, :]
print('Masked data shape', masked_data.shape)
loader = torch.utils.data.DataLoader(masked_data, batch_size=16, shuffle=True, drop_last=True)

visualize_subtomogram(masked_data[0].cpu().detach().numpy(), 0, 'masked_data.png')

num_epochs = 200

# Set ``train`` to False to skip training and load the saved weights instead.
train=True
initial=False
train_losses = []
# NOTE(review): the checkpoint names say "mwa_0_snr_01" while the dataset file
# is "mwa_30_snr_001" — confirm the naming is intended.
if train:
    # Make sure the checkpoint directory exists up front, so the first save
    # cannot fail after many epochs of training.
    os.makedirs('models', exist_ok=True)
    print('Training Harmony (Axis Angle) + Maxout model')
    for epoch in range(1,num_epochs+1):
        epoch_start_time = time.time()
        epoch_loss = 0
        model.train()
        for batch_idx, images in enumerate(loader):
            images = images.cuda()
            images = images.unsqueeze(1).float()  # add the channel axis
            image_z1, image_z2, image_x_theta1, image_x_theta2, phi1, phi2 = model(images, training=True, initial=initial)
            loss = loss_fn(image_z1, image_z2, image_x_theta1, image_x_theta2, phi1, phi2, epoch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        epoch_time = time.time() - epoch_start_time
        avg_loss = epoch_loss/len(loader)
        mins = int(epoch_time // 60)
        secs = int(epoch_time % 60)
        current_lr = optimizer.param_groups[0]['lr']
        print(f'Epoch: {epoch}, Loss: {avg_loss:.6f}, LR: {current_lr:.6f}, Time: {mins} mins {secs} secs')
        train_losses.append(avg_loss)

        # Step the learning rate scheduler
        #scheduler.step()

        # Periodic checkpoint every 20 epochs (epochs start at 1).
        if epoch%20==0:
            torch.save(model.state_dict(), f"models/yeast_mwa_0_snr_01_cellular_protein_mixture_{epoch}.pt")
            print('Model Saved!')
    torch.save(model.state_dict(), f"models/yeast_mwa_0_snr_01_cellular_protein_mixture.pt")
    print('Model Saved!')
    plt.clf()
    plt.plot(range(1, num_epochs + 1), train_losses)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.savefig('models/loss_curve_yeast_mwa_0_snr_01_cellular_protein_mixture.png')
else:
    model.load_state_dict(torch.load(f"models/yeast_mwa_0_snr_01_cellular_protein_mixture.pt"))
    print('Model Loaded!')

print('Doing Inference:')

model.eval()
# Keep the final partial batch so every subtomogram gets an embedding.
loader = torch.utils.data.DataLoader(masked_data, batch_size=100, shuffle=False, drop_last=False)

with torch.no_grad():
    midall = []
    for batch_idx, images in enumerate(loader):
        images = images.cuda()
        images = images.unsqueeze(1).float()  # add the channel axis
        phi = model.autoencoder.encoder(images)
        # Columns 0-8 hold the pose outputs; the 50 latent components follow.
        midall.append(phi[:,9:9+50].cpu().detach().numpy())
# Batches may differ in length, so concatenate along the sample axis.
mid = np.concatenate(midall, axis=0)

import umap.umap_ as umap

reducer = umap.UMAP(n_components=2, random_state=777)
m = reducer.fit_transform(mid)


# Ground-truth labels: assumes 4 classes of 1000 consecutive samples each.
labels = []
for i in range(4):
    labels.extend([i]*1000)

if len(m) > len(labels):
    raise ValueError(
        f"{len(m)} embeddings but only {len(labels)} ground-truth labels; "
        "update the label construction for this dataset"
    )
labels = labels[:len(m)]

os.makedirs('results', exist_ok=True)  # make sure the output directory exists
plt.clf()
scatter = plt.scatter(m[:,0],m[:,1], c=labels,cmap='rainbow', alpha=0.25)
unique_labels = np.unique(labels)
# Build legend patches with the same colours the scatter plot assigned.
legend_handles = [
    mpatches.Patch(color=scatter.cmap(scatter.norm(label)), label=f'Label {label}')
    for label in unique_labels
]
plt.legend(handles=legend_handles, title="Labels")
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')
plt.savefig("results/umap_yeast_mwa_0_snr_01_cellular_protein_mixture_with_GT_labels.png",bbox_inches='tight',dpi=200)

# Cluster the 2-D UMAP embedding into 4 groups.
gmm = GaussianMixture(n_components=4, random_state=777)
gmm.fit(m)
pred_labels = gmm.predict(m)

# Calculate ARI between predicted labels and ground truth labels
ari_score = adjusted_rand_score(labels, pred_labels)
print(f"Adjusted Rand Index (ARI): {ari_score:.4f}")


from scipy.optimize import linear_sum_assignment

def hungarian_match_accuracy(y_true, y_pred):
    """Clustering accuracy under the best one-to-one cluster-to-class matching.

    A contingency table of true labels versus predicted cluster ids is built
    and the Hungarian algorithm picks the one-to-one assignment that
    maximises the number of agreeing samples. Samples in clusters left
    without a partner (more clusters than classes) count as errors, even if
    their raw cluster id happens to equal a true label value.

    Args:
        y_true: Array-like of ground-truth labels.
        y_pred: Array-like of predicted cluster ids, same shape as ``y_true``.

    Returns:
        Fraction of correctly assigned samples as a ``float`` in [0, 1].

    Raises:
        ValueError: If the inputs differ in shape or are empty.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must have the same shape, got {y_true.shape} and {y_pred.shape}"
        )
    if y_true.size == 0:
        raise ValueError("cannot compute accuracy of empty label arrays")

    # ``return_inverse`` yields each sample's row / column index directly.
    true_labels, true_idx = np.unique(y_true, return_inverse=True)
    pred_labels, pred_idx = np.unique(y_pred, return_inverse=True)

    cm = np.zeros((len(true_labels), len(pred_labels)), dtype=np.int64)
    np.add.at(cm, (true_idx.ravel(), pred_idx.ravel()), 1)

    # linear_sum_assignment accepts rectangular matrices, so no padding is
    # needed; maximise the number of samples on matched (class, cluster) pairs.
    row_ind, col_ind = linear_sum_assignment(cm, maximize=True)

    return float(cm[row_ind, col_ind].sum() / y_true.size)

# Report clustering accuracy after matching GMM cluster ids to the ground-truth labels.
print('Accuracy: ', hungarian_match_accuracy(labels, pred_labels))
