import torch
import torch.nn as nn
from nflows.flows import Flow
from nflows.distributions import StandardNormal
from nflows.transforms import AffineCouplingTransform, RandomPermutation, CompositeTransform
from nflows.transforms import AffineCouplingTransform, CompositeTransform, RandomPermutation
import torch.optim as optim
from common.variables import *
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import pickle as pkl
import numpy as np

def create_alternating_binary_mask(features, even=True):
    """Return a float 0/1 mask of length ``features`` that alternates by position.

    With ``even=True`` the positions 0, 2, 4, ... are 1.0; otherwise the
    positions 1, 3, 5, ... are 1.0. All other entries are 0.0.
    """
    parity = 0 if even else 1
    positions = torch.arange(features)
    return (positions % 2 == parity).float()

class TransformNet(nn.Module):
    """Three-layer MLP: Linear -> ReLU -> Linear -> ReLU -> Linear.

    Used in this file as the conditioner network built by ``create_flow`` for
    each affine coupling layer. ``context`` is accepted so the call signature
    matches what the coupling transform presumably passes, but it is ignored.
    """

    def __init__(self, in_features, out_features, hidden_dim):
        super().__init__()
        widths = [in_features, hidden_dim, hidden_dim, out_features]
        layers = []
        for position, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if position > 0:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        # Sequential indices (0, 2, 4 for the Linears) match previously saved weights.
        self.net = nn.Sequential(*layers)

    def forward(self, x, context=None):
        del context  # not used by this network
        return self.net(x)

class SphericalMappingNetWithFlows(nn.Module):
    """Map paired inputs ``(A, B)`` onto a unit sphere and model ``cat(A, B)`` with a flow.

    Angles come from two small MLPs:
    ``phi = tanh(fcA2(relu(fcA1(A)))) * pi/2`` (bounded to (-pi/2, pi/2)) and
    ``theta = fcB2(relu(fcB1(B))) * pi`` (unbounded). They are converted to
    Cartesian coordinates of a unit vector. Separately, a flow made of
    ``flow_layers`` (RandomPermutation, AffineCouplingTransform) pairs is applied
    to ``cat(A, B)``. Its log-determinant is returned by ``forward`` and used by
    ``log_prob``. The spherical output itself does not pass through the flow.

    NOTE(review): ``invert``/``decode`` feed their argument straight into the flow
    inverse, which only type-checks when ``output_dim == input_dim_A + input_dim_B``.
    Confirm this with callers.
    """

    def __init__(self, input_dim_A, input_dim_B, hidden_dim, output_dim=OUTPUT_DIM, flow_layers=FLOW_LENGTH, restrict_to_positive_quadrant=False):
        super(SphericalMappingNetWithFlows, self).__init__()
        self.input_dim_A = input_dim_A
        self.input_dim_B = input_dim_B
        self.output_dim = output_dim
        self.restrict_to_positive_quadrant = restrict_to_positive_quadrant

        if self.input_dim_A + self.input_dim_B <= 3:
            # Low-dimensional inputs: one latitude and one longitude -> 3-D unit vector.
            phi_dim, theta_dim = 1, 1
        else:
            # (n//2 - 1) + n//2 angles: exactly n - 1 for even output_dim n.
            phi_dim = max(1, self.output_dim // 2 - 1)
            theta_dim = max(1, self.output_dim // 2)
        # Creation order (and therefore parameter init / state_dict keys) is unchanged.
        self.fcA1 = nn.Linear(input_dim_A, hidden_dim)
        self.fcA2 = nn.Linear(hidden_dim, phi_dim)
        self.fcB1 = nn.Linear(input_dim_B, hidden_dim)
        self.fcB2 = nn.Linear(hidden_dim, theta_dim)

        self.flow = self.create_flow(input_dim_A + input_dim_B, hidden_dim, flow_layers)
        self.base_dist = StandardNormal(shape=[input_dim_A + input_dim_B])

    def forward(self, A, B):
        """Return ``(cartesian_output, log_det)``; dispatches on total input dimension."""
        if self.input_dim_A + self.input_dim_B <= 3:
            return self.forward_3D(A, B)
        else:
            return self.forward_ND(A, B)

    def _angles(self, A, B):
        """Return ``(phi, theta)``: bounded angles from ``A`` and unbounded angles from ``B``."""
        phi = torch.tanh(self.fcA2(F.relu(self.fcA1(A)))) * (torch.pi / 2)
        theta = self.fcB2(F.relu(self.fcB1(B))) * torch.pi
        return phi, theta

    def forward_ND(self, A, B):
        """Map to ``output_dim`` Cartesian coordinates and compute the flow log-determinant.

        Raises:
            RuntimeError: if the flow rejects both ``cat(A, B)`` and its squeezed form.
        """
        phi, theta = self._angles(A, B)
        cartesian_output = self.spherical_to_cartesian(phi, theta)

        if self.restrict_to_positive_quadrant:
            cartesian_output = torch.clamp(cartesian_output, min=0)

        inputs = torch.cat((A, B), dim=-1)
        try:
            _, log_det = self.flow.forward(inputs)
        except Exception as first_error:
            # Retry with singleton dimensions removed (e.g. (batch, 1, features) inputs).
            try:
                _, log_det = self.flow.forward(inputs.squeeze())
            except Exception as second_error:
                # Previously this only printed, then failed with UnboundLocalError on log_det.
                raise RuntimeError(
                    f"Both inputs failed. Error 1: {first_error}, Error 2: {second_error}"
                ) from second_error

        return cartesian_output, log_det

    def forward_3D(self, A, B):
        """Map to a 3-D unit vector (latitude ``phi``, longitude ``theta``) plus flow log-det.

        ``fcA2``/``fcB2`` emit one angle each, so for 2-D ``A``/``B`` the output
        has shape ``(batch, 1, 3)``.
        """
        phi, theta = self._angles(A, B)
        cos_phi = torch.cos(phi)
        cartesian_output = torch.stack(
            [torch.cos(theta) * cos_phi, torch.sin(theta) * cos_phi, torch.sin(phi)], dim=-1
        )

        if self.restrict_to_positive_quadrant:
            cartesian_output = torch.clamp(cartesian_output, min=0)

        _, log_det = self.flow.forward(torch.cat((A, B), dim=-1))
        return cartesian_output, log_det

    def spherical_to_cartesian(self, phi, theta):
        """Convert angles to a unit vector in ``R^output_dim`` (hyperspherical coordinates).

        With ``a = cat(phi, theta)`` truncated (or zero-padded) to ``n - 1`` angles,
        ``x_k = sin(a_1)...sin(a_{k-1}) cos(a_k)`` for ``k < n`` and
        ``x_n = sin(a_1)...sin(a_{n-1})``. This gives ``||x|| = 1`` for any angles.
        For ``n == 2`` it reduces to ``(cos(phi_0), sin(phi_0))``.
        """
        n = self.output_dim
        num_angles = max(n - 1, 0)
        angles = torch.cat((phi, theta), dim=-1)[..., :num_angles]
        missing = num_angles - angles.shape[-1]
        if missing > 0:
            # Odd output_dim >= 5 yields one angle too few; a zero angle stays on the sphere.
            padding = angles.new_zeros(angles.shape[:-1] + (missing,))
            angles = torch.cat((angles, padding), dim=-1)

        sin_prefix = torch.ones_like(phi[..., 0])
        coords = []
        for k in range(num_angles):
            coords.append(sin_prefix * torch.cos(angles[..., k]))
            sin_prefix = sin_prefix * torch.sin(angles[..., k])
        coords.append(sin_prefix)

        return torch.stack(coords, dim=-1)

    def create_flow(self, input_dim, hidden_dim, flow_layers):
        """Build ``flow_layers`` (RandomPermutation, AffineCouplingTransform) pairs.

        Coupling masks alternate between even and odd feature positions per layer.
        """
        transforms = []
        for i in range(flow_layers):
            transforms.append(RandomPermutation(features=input_dim))
            transforms.append(
                AffineCouplingTransform(
                    mask=create_alternating_binary_mask(input_dim, even=(i % 2 == 0)),
                    transform_net_create_fn=lambda in_features, out_features: TransformNet(in_features, out_features, hidden_dim)
                )
            )
        return CompositeTransform(transforms)

    def map_to_sphere(self, A, B):
        """Return only the Cartesian sphere coordinates from ``forward``."""
        cartesian_output, _ = self.forward(A, B)
        return cartesian_output

    def log_prob(self, A, B):
        """Log-density of ``cat(A, B)``: base log-prob of the flow output plus its log-det."""
        inputs = torch.cat((A, B), dim=-1)
        z, log_det = self.flow.forward(inputs)
        log_prob_base = self.base_dist.log_prob(z)
        log_prob = log_prob_base + log_det
        return log_prob

    def invert(self, z):
        """Run the flow inverse on ``z`` and split the result into ``(A, B)``.

        A 1-D ``z`` is treated as a single sample. Inputs with more than two dims
        are flattened to ``(z.size(0), -1)``.
        """
        if z.dim() == 1:
            # Previously view(D, -1) turned one sample into D one-feature samples.
            z = z.unsqueeze(0)
        elif z.dim() != 2:
            # reshape (unlike view) also accepts non-contiguous tensors.
            z = z.reshape(z.size(0), -1)
        # Single inverse pass (the original computed it twice and discarded one).
        inputs, _ = self.flow.inverse(z)
        A = inputs[:, :self.input_dim_A]
        B = inputs[:, self.input_dim_A:]
        return A, B

    def decode(self, cartesian_coords):
        """Alias for ``invert``."""
        return self.invert(cartesian_coords)

    def save_weights(self, weight_path):
        """Save ``state_dict()`` to ``weight_path`` with ``torch.save``."""
        torch.save(self.state_dict(), weight_path)

class InverseMappingNet(nn.Module):
    """Two-headed MLP predicting ``(x1, x2)`` from ``y``.

    A shared trunk of two (Linear -> LayerNorm -> LeakyReLU(0.01) -> Dropout)
    stages feeds two independent Linear + Sigmoid heads, so both outputs lie in (0, 1).
    """

    def __init__(self, input_dim, x1_dim, x2_dim, hidden_dim=512, dropout_rate=0.1):
        super().__init__()
        trunk = []
        for stage_in in (input_dim, hidden_dim):
            trunk.extend([
                nn.Linear(stage_in, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.LeakyReLU(negative_slope=0.01),
                nn.Dropout(dropout_rate),
            ])
        self.shared_layers = nn.Sequential(*trunk)
        self.x1_head = self._sigmoid_head(hidden_dim, x1_dim)
        self.x2_head = self._sigmoid_head(hidden_dim, x2_dim)

    @staticmethod
    def _sigmoid_head(in_dim, out_dim):
        """Linear projection followed by a sigmoid."""
        return nn.Sequential(nn.Linear(in_dim, out_dim), nn.Sigmoid())

    def forward(self, y):
        features = self.shared_layers(y)
        return self.x1_head(features), self.x2_head(features)

class PercentageThresholdLoss(nn.Module):
    """Element-wise MSE, up-weighted wherever the error exceeds a relative tolerance.

    An element is "within threshold" when
    ``|predicted - target| <= |target| * threshold_percentage / 100``. Such
    elements contribute their plain squared error; every other element
    contributes ``high_penalty_factor`` times it. Returns the mean over all elements.
    """

    def __init__(self, threshold_percentage=5.0, high_penalty_factor=10.0):
        super().__init__()
        # Stored as a fraction, e.g. 5.0 -> 0.05.
        self.threshold_percentage = threshold_percentage / 100.0
        self.high_penalty_factor = high_penalty_factor

    def forward(self, predicted, target):
        squared_error = F.mse_loss(predicted, target, reduction='none')
        penalized = self.high_penalty_factor * squared_error
        tolerance = torch.abs(target) * self.threshold_percentage
        is_close = torch.abs(predicted - target) <= tolerance
        return torch.where(is_close, squared_error, penalized).mean()

class HashTableInverseMapping:
    """Memorize ``y -> (x1, x2)`` pairs and answer inverse queries by nearest neighbour.

    Keys are flattened ``y`` rows stored as tuples of Python floats. A lookup does
    not need an exact match: ``inverse_hash`` returns the values stored under the
    ``num_neighbors`` keys closest (Euclidean distance) to each query.
    """

    def __init__(self):
        # tuple(flattened y row) -> (x1 row as list, x2 row as list)
        self.hash_table = {}

    def memorize_mappings(self, y, x1, x2):
        """Store row ``i`` of ``y`` as a key for ``(x1[i], x2[i])``.

        A key that is already present is overwritten.
        """
        y_hashable = [tuple(y_i.flatten().detach().tolist()) for y_i in y]
        for i, y_val in enumerate(y_hashable):
            self.hash_table[y_val] = (x1[i].detach().tolist(), x2[i].detach().tolist())

    def inverse_hash(self, y_queries, num_neighbors=1):
        """Return the stored ``(x1, x2)`` values of the nearest keys for each query.

        Args:
            y_queries: tensor whose first dim indexes queries; the remaining dims
                are flattened and compared against the stored keys.
            num_neighbors: number of nearest keys to return per query.

        Returns:
            ``(x1_results, x2_results)``: stacked tensors of shape
            ``(num_queries, k, ...)`` with ``k = min(num_neighbors, len(table))``.
            Ties keep insertion order.
        """
        # reshape (not view) accepts non-contiguous queries; .cpu() allows GPU tensors.
        flat_queries = y_queries.detach().cpu().reshape(y_queries.shape[0], -1).numpy()
        keys = list(self.hash_table)
        # Build the key matrix once instead of converting every key for every query.
        key_matrix = np.asarray(keys, dtype=np.float64)

        x1_results = []
        x2_results = []
        for query in flat_queries:
            if keys:
                distances = np.linalg.norm(key_matrix - query, axis=1)
                # Stable sort reproduces the original tie-breaking (insertion order).
                nearest = np.argsort(distances, kind="stable")[:num_neighbors]
            else:
                nearest = []
            neighbours = [self.hash_table[keys[j]] for j in nearest]
            x1_results.append(torch.tensor([x1_val for x1_val, _ in neighbours]))
            x2_results.append(torch.tensor([x2_val for _, x2_val in neighbours]))
        return torch.stack(x1_results), torch.stack(x2_results)

def negative_log_likelihood(flow_model, A, B):
    """Mean negative log-likelihood of ``cat(A, B)`` under ``flow_model``.

    ``flow_model.log_prob`` already returns ``base_log_prob(z) + log_det``, so the
    log-determinant must not be added a second time (the original did).

    Returns:
        ``(nll_loss, cartesian_output)``, where ``cartesian_output`` is the first
        element of ``flow_model(A, B)``.
    """
    cartesian_output, _ = flow_model(A, B)
    nll_loss = -flow_model.log_prob(A, B).mean()
    return nll_loss, cartesian_output

def lipschitz_loss_with_jacobian(flow, data, lambda_lipschitz=LAMBDA_LIPSCHITZ):
    """Hinge penalty on the per-sample Jacobian Frobenius norm of ``flow``'s forward map.

    Returns ``mean_i max(0, ||J_i||_F - lambda_lipschitz)``, where ``J_i`` is the
    Jacobian of ``flow.forward(x)[0]`` for sample ``i`` with respect to that same
    sample's input.

    Args:
        flow: transform whose ``forward(x)`` returns ``(outputs, logabsdet)``.
        data: ``(batch, features)`` tensor, or a single ``(features,)`` sample.
        lambda_lipschitz: norm threshold below which no penalty is applied.

    Raises:
        ValueError: if ``data`` is not 1-D or 2-D.
    """
    if data.dim() not in (1, 2):
        raise ValueError(f"expected 1-D or 2-D data, got shape {tuple(data.shape)}")
    batch = data.unsqueeze(0) if data.dim() == 1 else data
    # create_graph=True keeps the Jacobian attached to the flow parameters. Without it
    # the penalty is a constant and contributes no gradient during training.
    jac = torch.autograd.functional.jacobian(
        lambda x: flow.forward(x)[0], batch, create_graph=True
    )  # (batch, out_features, batch, in_features)
    # Keep each sample's Jacobian w.r.t. its own input (the batch "diagonal").
    # The old norm over dims (1, 2) mixed output features with the batch axis.
    idx = torch.arange(batch.shape[0], device=batch.device)
    per_sample_jac = jac[idx, :, idx, :]  # (batch, out_features, in_features)
    norms = per_sample_jac.flatten(start_dim=1).norm(dim=1)
    return torch.mean((norms - lambda_lipschitz).clamp(min=0.0))

def perturb_loss(output, A, B, manifold_output, scale=0.01):
    """Penalize how far the sphere mapping moves under small Gaussian input noise.

    ``A`` and ``B`` are perturbed separately by ``scale * N(0, I)``. The mean
    Euclidean (chord) distance between ``manifold_output`` and each perturbed
    mapping is computed, and the two means are summed.

    Args:
        output: the model, exposing ``map_to_sphere(A, B)`` (name kept for compatibility).
        A, B: unperturbed inputs.
        manifold_output: mapping of the unperturbed inputs. It may be a squeezed
            version of ``map_to_sphere``'s output (e.g. ``(batch, 3)`` vs ``(batch, 1, 3)``).
        scale: standard deviation of the perturbation noise.

    Returns:
        Scalar tensor.
    """
    A_perturbed = A + scale * torch.randn_like(A)
    B_perturbed = B + scale * torch.randn_like(B)
    # Match manifold_output's layout so the difference is per sample. A (batch, 3) vs
    # (batch, 1, 3) mismatch would otherwise broadcast to (batch, batch, 3).
    manifold_A_perturbed = output.map_to_sphere(A_perturbed, B).reshape(manifold_output.shape)
    manifold_B_perturbed = output.map_to_sphere(A, B_perturbed).reshape(manifold_output.shape)
    distance_A = torch.norm(manifold_output - manifold_A_perturbed, dim=-1).mean()
    distance_B = torch.norm(manifold_output - manifold_B_perturbed, dim=-1).mean()
    return distance_A + distance_B

def geodesic_repulsion_loss(manifold_output, n=NUM_SAMPLES, sigma=SIGMA_REPULSION, r=1.0):
    """Repulsion energy ``sum_{i != j} exp(-d(x_i, x_j) / sigma)`` over sphere points.

    Points are normalized to unit length. ``d`` is the arc distance
    ``acos(<x_i, x_j>)``, with the dot product clamped to [-0.9999, 0.9999].
    Each unordered pair is counted twice because the full matrix is summed.
    ``n`` and ``r`` are unused and kept only for call compatibility.

    Args:
        manifold_output: ``(..., features)`` tensor; all leading dims are treated
            as separate points (a single ``(features,)`` point yields 0).
        sigma: kernel length scale.
    """
    # Flatten leading dims so (features,), (batch, features) and (batch, 1, features)
    # all become a (num_points, features) matrix.
    points = manifold_output.reshape(-1, manifold_output.shape[-1])
    # F.normalize's eps avoids NaN for all-zero rows (e.g. after a positive-quadrant clamp).
    points = F.normalize(points, dim=-1)
    num_points = points.shape[0]
    cosines = torch.clamp(points @ points.T, -0.9999, 0.9999)
    geodesic_distances = torch.acos(cosines)
    # Infinite self-distance makes exp(-d / sigma) vanish on the diagonal.
    self_pairs = torch.eye(num_points, device=points.device, dtype=torch.bool)
    geodesic_distances = geodesic_distances.masked_fill(self_pairs, float('inf'))
    return torch.exp(-geodesic_distances / sigma).sum()

def refined_isoline_loss(model, A, B, manifold_output, num_points=10, epsilon=0.05, target_radius=0.2):
    """Push B-perturbations to land at a fixed geodesic radius from the unperturbed point.

    ``num_points`` Gaussian perturbations ``B + epsilon * N(0, I)`` are drawn, with
    ``A`` held fixed. For each one, compute the squared difference between the
    arc distance to the unperturbed mapping and ``target_radius``, and average it
    over the batch. These per-perturbation terms are summed.

    ``manifold_output`` is unused; the unperturbed point is recomputed with
    ``model.map_to_sphere``.

    Returns:
        Scalar tensor (zero when ``num_points == 0``).
    """
    # Both ends must be unit vectors for acos(dot) to be a geodesic distance.
    # The original normalized only the perturbed points.
    central_point = F.normalize(model.map_to_sphere(A, B), dim=-1)
    # Draw all perturbations first, as before, to keep RNG consumption order.
    perturbations = [B + epsilon * torch.randn_like(B) for _ in range(num_points)]
    isoline_loss = central_point.new_zeros(())
    for perturbed_B in perturbations:
        perturbed_output = F.normalize(model.map_to_sphere(A, perturbed_B), dim=-1)
        dot_product = torch.sum(central_point * perturbed_output, dim=-1)
        dot_product = torch.clamp(dot_product, -0.9999, 0.9999)
        geodesic_distance = torch.acos(dot_product)
        isoline_loss = isoline_loss + (geodesic_distance - target_radius).pow(2).mean()
    return isoline_loss

def inversion_loss(A, B, A_inverted, B_inverted):
    """Sum of the mean-squared reconstruction errors of ``A`` and ``B``."""

    def _mse(original, reconstructed):
        return torch.mean((original - reconstructed) ** 2)

    return _mse(A, A_inverted) + _mse(B, B_inverted)

def inversion_loss_AB(A, B, A_inverted, B_inverted, autoendoder_balance=AUTOENCODER_BALANCE):
    """Return the weighted Huber reconstruction losses ``(A_loss, B_loss)``.

    ``autoendoder_balance[0]`` scales the A term and ``autoendoder_balance[1]``
    the B term. Both use ``nn.HuberLoss()`` with its default settings.
    """
    huber = nn.HuberLoss()
    weight_A = autoendoder_balance[0]
    weight_B = autoendoder_balance[1]
    return weight_A * huber(A_inverted, A), weight_B * huber(B_inverted, B)

def _weighted_loss_term(name, compute_term):
    """Return ``compute_term() * NFLOW_LOSS_DICT[name]`` (weight defaults to 1.0).

    Returns a zero tensor instead when the weight is not positive. In that case
    ``compute_term`` is never called, so disabled terms cost nothing.
    """
    weight = NFLOW_LOSS_DICT.get(name, 1.0)
    if weight > 0:
        return compute_term() * weight
    return torch.tensor(0.0)


def train_flow_minibatch(flow_model, data_A, data_B,
                         batch_size=BATCH_SIZE,
                         epochs=NUM_EPOCHS,
                         lr=NF_LEARN_RATE,
                         lambda_lipschitz=1.0,
                         sigma=SIGMA_REPULSION,
                         target_radius=0.2,
                         patience=EARLY_STOPPING_PATIENCE,
                         min_delta=EARLY_STOPPING_MIN_DELTA,
                         shuffle=False):
    """Train ``flow_model`` with Adam on a weighted sum of losses, with early stopping.

    Each term is scaled by its weight in ``NFLOW_LOSS_DICT``. Missing keys default
    to 1.0, and a non-positive weight disables the term. The terms are
    ``nll_loss``, ``lipschitz_loss``, ``perturb_loss``, ``geodesic_repulsion_loss``,
    ``isoline_loss`` and ``inv_loss``. Gradients are clipped to norm 1.0.

    Args:
        flow_model: model exposing ``__call__(A, B) -> (sphere_points, log_det)``,
            ``log_prob``, ``invert``, ``map_to_sphere`` and a ``flow`` attribute.
        data_A, data_B: tensors sharing their first (sample) dimension.
        batch_size, epochs, lr: mini-batch size, maximum epochs, Adam learning rate.
        lambda_lipschitz: Jacobian-norm threshold of the Lipschitz hinge penalty.
        sigma: length scale of the geodesic repulsion kernel.
        target_radius: target geodesic radius for the isoline loss.
        patience, min_delta: stop after ``patience`` consecutive epochs without the
            average loss improving by more than ``min_delta``.
        shuffle: reshuffle the data each epoch (default False, the previous behavior).

    Returns:
        list[float]: the average total loss of every completed epoch.

    Raises:
        ValueError: if the dataset is empty.
    """
    dataset = TensorDataset(data_A, data_B)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    if len(dataloader) == 0:
        raise ValueError("train_flow_minibatch received an empty dataset")
    optimizer = optim.Adam(flow_model.parameters(), lr=lr)
    losses = []
    best_loss = float('inf')
    patience_counter = 0

    for epoch in range(epochs):
        flow_model.train()
        epoch_loss = 0.0

        for batch_A, batch_B in dataloader:
            optimizer.zero_grad()
            manifold_output_raw, _ = flow_model(batch_A, batch_B)
            manifold_output = manifold_output_raw.squeeze()

            # The lambdas below are invoked immediately, so capturing loop variables is safe.
            # log_prob already includes the flow log-determinant; adding log_det again
            # (as before) double-counted it.
            nll_loss = _weighted_loss_term(
                "nll_loss",
                lambda: -torch.clamp(flow_model.log_prob(batch_A, batch_B), min=-1e9, max=1e9).mean(),
            )
            lipschitz_loss = _weighted_loss_term(
                "lipschitz_loss",
                lambda: lipschitz_loss_with_jacobian(
                    flow_model.flow, torch.cat([batch_A, batch_B], dim=-1), lambda_lipschitz
                ),
            )
            perturb_loss_value = _weighted_loss_term(
                "perturb_loss",
                lambda: perturb_loss(flow_model, batch_A, batch_B, manifold_output),
            )
            geodesic_repulsion_loss_value = _weighted_loss_term(
                "geodesic_repulsion_loss",
                lambda: geodesic_repulsion_loss(manifold_output, sigma=sigma),
            )
            isoline_loss_value = _weighted_loss_term(
                "isoline_loss",
                lambda: refined_isoline_loss(
                    flow_model, batch_A, batch_B, manifold_output, target_radius=target_radius
                ),
            )

            def _inversion_term():
                A_inverted, B_inverted = flow_model.invert(manifold_output)
                A_recon_loss, B_recon_loss = inversion_loss_AB(batch_A, batch_B, A_inverted, B_inverted)
                return A_recon_loss + B_recon_loss

            inv_loss = _weighted_loss_term("inv_loss", _inversion_term)

            total_loss = nll_loss + lipschitz_loss + perturb_loss_value + geodesic_repulsion_loss_value + isoline_loss_value + inv_loss
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(flow_model.parameters(), max_norm=1.0)
            optimizer.step()
            epoch_loss += total_loss.item()

        avg_epoch_loss = epoch_loss / len(dataloader)
        # Only per-epoch averages go into the returned curve (the original also
        # appended the last batch's loss every 10 epochs, corrupting it).
        losses.append(avg_epoch_loss)

        if avg_epoch_loss < best_loss - min_delta:
            best_loss = avg_epoch_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if epoch % 10 == 0:
            # Per-term values are from the last mini-batch of the epoch.
            print(f"Epoch {epoch}: Total Loss = {total_loss.item()}, NLL Loss = {nll_loss.item()}, "
                  f"Lipschitz Loss = {lipschitz_loss.item()}, Perturb Loss = {perturb_loss_value.item()}, Inversion Loss = {inv_loss.item()} "
                  f"Geodesic Repulsion Loss = {geodesic_repulsion_loss_value.item()}, Isoline Loss = {isoline_loss_value.item()}")
            print(f"Epoch {epoch}: Average Loss = {avg_epoch_loss}, Best Loss = {best_loss}")

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch} with loss {avg_epoch_loss}")
            break

    return losses

def initialize_weights_xavier(m):
    """Xavier-uniform init for an ``nn.Linear`` weight, with its bias (if any) zeroed.

    Any other module type is left untouched. Presumably meant to be used as
    ``model.apply(initialize_weights_xavier)``; confirm with callers.
    """
    if not isinstance(m, torch.nn.Linear):
        return
    torch.nn.init.xavier_uniform_(m.weight)
    if m.bias is not None:
        torch.nn.init.zeros_(m.bias)

def initialize_weights_kaiming(model):
    """Kaiming-normal init (fan_in, ReLU gain) for every Linear/Conv2d inside ``model``.

    Biases of those layers are zeroed. Other module types are left untouched.
    """
    for module in model.modules():
        if not isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            continue
        torch.nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)

def save_hash_mapping(hash_mapping, file_path='hash_mapping.pkl'):
    """Pickle ``hash_mapping`` to ``file_path`` (overwriting it) and print a confirmation."""
    with open(file_path, 'wb') as handle:
        pkl.dump(hash_mapping, handle)
    message = f"Hash mapping saved to {file_path}"
    print(message)

def load_hash_mapping(file_path='hash_mapping.pkl'):
    """Unpickle and return the object stored at ``file_path``, printing a confirmation.

    Unpickling can execute arbitrary code, so only load files from trusted sources.
    """
    with open(file_path, 'rb') as handle:
        loaded = pkl.load(handle)
    print(f"Hash mapping loaded from {file_path}")
    return loaded
