import torch
from torch import device
from itertools import cycle

from objectives import (coral_loss, geo_adapt, ddc_mmd, minimal_entropy_correlation_alignment, central_moment_discrepancy,
                        higher_order_moment_matching, geo_log_euclidean)


def baseline(source_tr_x, f, g, opt, device):
    """
    Trains the encoder/decoder pair for one epoch on reconstruction loss only.

    Each batch is encoded with ``f``, decoded with ``g``, and the mean squared
    error between the reconstruction and the input is minimized with ``opt``.

    Args:
        source_tr_x: Iterable of batches; the input tensor of each batch is its
            first element (``batch[0]``).
        f: The feature extractor (encoder) neural network model.
        g: The decoder neural network model.
        opt: The optimizer used for updating model parameters.
        device: The device to use for training (e.g., 'cpu' or 'cuda').

    Returns:
        float: Average per-batch reconstruction loss over the epoch, or 0.0 if
        ``source_tr_x`` yielded no batches.
    """
    f.train()
    g.train()

    total_loss = 0.0
    n_batches = 0

    for batch in source_tr_x:
        opt.zero_grad()

        x = batch[0].to(device)
        x_hat = g(f(x))

        loss = torch.nn.functional.mse_loss(x_hat, x)
        loss.backward()
        opt.step()

        total_loss += loss.item()
        n_batches += 1

    # Average over the batches actually seen: works for iterables without
    # __len__ and avoids a ZeroDivisionError on an empty loader.
    return total_loss / n_batches if n_batches else 0.0


def _endless_batches(iterable):
    """Yield items from *iterable* forever, starting a fresh pass each time it runs out."""
    from collections.abc import Iterator

    # A re-iterable (e.g. a DataLoader) is iterated afresh on every pass, so a
    # shuffling loader is reshuffled and no batches are cached in memory
    # (itertools.cycle would keep a copy of every item from the first pass).
    # A one-shot iterator cannot be restarted, so it falls back to cycle().
    stream = cycle(iterable) if isinstance(iterable, Iterator) else iterable
    while True:
        produced = False
        for item in stream:
            produced = True
            yield item
        if not produced:
            # Without this, zip() would stop immediately and the epoch would
            # silently perform no training at all.
            raise ValueError('target_tr_x did not yield any batches')


def _adaptation_loss(mode, z_combined, n_source, highest_moment, geo_adapt_metric):
    """Compute the domain adaptation loss selected by *mode* on the stacked features."""
    if mode == 'ddc':
        return ddc_mmd(z_combined, n_source)
    if mode == 'coral':
        return coral_loss(z_combined, n_source)
    if mode == 'log_coral':
        return minimal_entropy_correlation_alignment(z_combined, n_source)
    if mode == 'cmd':
        return central_moment_discrepancy(z_combined, n_source, n_moments=highest_moment)
    if mode == 'homm':
        return higher_order_moment_matching(z_combined, n_source, order=highest_moment)
    if mode == 'geo_adapt':
        return geo_adapt(z_combined, n_source, geo_adapt_metric)
    if mode == 'geo_log_euclidean':
        return geo_log_euclidean(z_combined, n_source)
    raise ValueError(f'Unsupported domain adaptation mode: {mode}')


def adaptation(source_tr_x, target_tr_x, f, g, opt, adapt_loss_mode, lamda, highest_moment, geo_adapt_metric, epoch, initial_epochs, device):
    """
    Conducts one epoch of domain adaptation by jointly optimizing the source
    reconstruction loss and a domain adaptation loss between source and target
    features.

    The epoch runs over ``source_tr_x`` once; ``target_tr_x`` is re-iterated as
    often as needed to supply one target batch per source batch.

    Args:
        source_tr_x: Source domain training data; iterable of batches whose
            first element (``batch[0]``) is the input tensor.
        target_tr_x: Target domain training data, same batch structure. A
            re-iterable (e.g. a DataLoader) is restarted on every pass.
        f: The feature extractor (encoder) neural network model.
        g: The decoder neural network model.
        opt: The optimizer used for updating model parameters.
        adapt_loss_mode: One of 'ddc', 'coral', 'log_coral', 'cmd', 'homm',
            'geo_adapt' or 'geo_log_euclidean'.
        lamda: Weight applied to the domain adaptation loss when combining
            losses.
        highest_moment: Passed as ``n_moments`` for 'cmd' and as ``order`` for
            'homm'; unused by the other modes.
        geo_adapt_metric: Metric passed to ``geo_adapt`` in 'geo_adapt' mode.
        epoch: The current training epoch. Starts from 1.
        initial_epochs: While ``epoch <= initial_epochs`` only the
            reconstruction loss is optimized.
        device: The device to use for training (e.g., 'cpu' or 'cuda').

    Returns:
        tuple: ``(avg_recon_loss, avg_adapt_loss, lamda)`` — per-batch averages
        over the epoch (0.0 when no batches were processed; the adaptation
        average is 0.0 during warm-up) and ``lamda`` returned unchanged.

    Raises:
        ValueError: If ``adapt_loss_mode`` is unsupported (checked once the
            warm-up epochs are over) or ``target_tr_x`` yields no batches.
    """
    f.train()
    g.train()

    total_recon_loss = 0.0
    total_adapt_loss = 0.0
    n_batches = 0

    for src_batch, tgt_batch in zip(source_tr_x, _endless_batches(target_tr_x)):
        opt.zero_grad()

        x_s = src_batch[0].to(device)
        x_t = tgt_batch[0].to(device)

        # NOTE: the target is encoded even during warm-up, where z_t is unused;
        # kept so any train-mode layer state in f (e.g. batch-norm statistics,
        # if present) evolves exactly as before.
        z_s = f(x_s)
        z_t = f(x_t)

        # Only the source domain is reconstructed.
        recon_loss = torch.nn.functional.mse_loss(g(z_s), x_s)

        if epoch > initial_epochs:
            # Source features first, then target; len(x_s) is passed to the
            # objective as the source count (presumably the split point).
            z_combined = torch.cat([z_s, z_t], dim=0)
            adapt_loss = _adaptation_loss(adapt_loss_mode, z_combined, len(x_s), highest_moment, geo_adapt_metric)
            loss = recon_loss + lamda * adapt_loss
        else:
            adapt_loss = torch.tensor(0.0)
            loss = recon_loss

        loss.backward()
        opt.step()

        total_recon_loss += recon_loss.item()
        total_adapt_loss += adapt_loss.item()
        n_batches += 1

    if n_batches == 0:
        return 0.0, 0.0, lamda
    return total_recon_loss / n_batches, total_adapt_loss / n_batches, lamda


def evaluation(f, g, data_loader, device, return_images=False, num_images_to_return=6):
    """
    Evaluates the encoder/decoder pair on the given data loader.

    Each batch is either ``(input,)`` — the input is its own ground truth — or
    ``(input, ground_truth)``; a bare tensor batch is treated as ``(input,)``.

    Args:
        f: Encoder model
        g: Decoder model
        data_loader: Iterable of evaluation batches
        device: Device to run evaluation on
        return_images: If True, also return a sample of reconstructions and
            ground truths from the first batch for visualization
        num_images_to_return: Maximum number of samples to return if
            return_images is True

    Returns:
        tuple: ``(avg_mse, reconstructed_img, ground_truth_img)`` where
        ``avg_mse`` is the mean squared error per ground-truth element over the
        whole loader (0.0 if it is empty). The two arrays are NumPy arrays
        when ``return_images`` is True and at least one batch was seen;
        otherwise both are None.

    Raises:
        ValueError: If a batch has neither 1 nor 2 elements.
    """
    f.eval()
    g.eval()

    total_mse = 0.0
    num_elements = 0

    # Sample tensors for visualization (only filled from the first batch).
    reconstructed_list = []
    ground_truth_list = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(data_loader):
            # len() of a bare tensor is its batch size, so wrap it explicitly.
            if isinstance(batch, torch.Tensor):
                batch = (batch,)

            if len(batch) == 1:
                # Images serve as both input and ground truth.
                img_batch = batch[0].to(device)
                gt_batch = img_batch
            elif len(batch) == 2:
                # First element is the input, second is the ground truth.
                img_batch = batch[0].to(device)
                gt_batch = batch[1].to(device)
            else:
                raise ValueError(
                    f'Expected batches of 1 (input) or 2 (input, ground truth) elements, got {len(batch)}'
                )

            reconstructed_batch = g(f(img_batch))

            batch_mse = torch.nn.functional.mse_loss(reconstructed_batch, gt_batch, reduction='sum')
            total_mse += batch_mse.item()
            # Normalize by the number of compared (ground-truth) elements, which
            # can differ from the input size when input and target shapes differ.
            num_elements += gt_batch.numel()

            if return_images and batch_idx == 0:
                num_to_store = min(num_images_to_return, img_batch.size(0))
                reconstructed_list.append(reconstructed_batch[:num_to_store].cpu())
                ground_truth_list.append(gt_batch[:num_to_store].cpu())

    avg_mse = total_mse / num_elements if num_elements > 0 else 0.0

    # An empty loader leaves nothing to concatenate (torch.cat([]) would raise).
    if return_images and reconstructed_list:
        reconstructed_img = torch.cat(reconstructed_list, dim=0).numpy()
        ground_truth_img = torch.cat(ground_truth_list, dim=0).numpy()
        return avg_mse, reconstructed_img, ground_truth_img
    return avg_mse, None, None
