import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torchvision.utils import save_image

from utils import ExponentialMovingAverage
from data.load_dataset import load_dataset
from global_config import ROOT_DIRECTORY
import argparse
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader

class CVaRDataset(Dataset):
    """Dataset that pairs every image with a precomputed per-sample weight.

    The weights are computed once, at construction time, as
    ``normalizing_constant / importance_weights`` and are then looked up by
    index alongside the images.

    Args:
        images: Indexable collection of images; its length defines the
            dataset size (presumably a tensor of shape (N, C, H, W) --
            confirm against the caller).
        importance_weights: Divisor used to build the per-sample weights
            (presumably a tensor of shape (N,) -- confirm against the caller).
        normalizing_constant: Numerator used to build the per-sample weights.
            The resulting ratio must be indexable by sample index.
    """

    def __init__(self, images, importance_weights, normalizing_constant):
        # Keep the raw inputs around so callers can still inspect them.
        self.normalizing_constant = normalizing_constant
        self.importance_weights = importance_weights
        self.images = images

        # Final weight attached to each sample when it is fetched.
        self.weights = normalizing_constant / importance_weights

    def __len__(self):
        """Return the number of stored images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, weight)`` pair stored at position ``idx``."""
        return self.images[idx], self.weights[idx]

def get_average_loss(loss_function, dataloader, device="cuda:0"):
    """Compute, print and return the mean loss over a dataloader.

    Every batch is moved to ``device`` and passed to ``loss_function`` with
    gradient tracking disabled. All elements of the returned loss are summed,
    and the grand total is divided by the number of samples seen.

    Args:
        loss_function: Callable taking a batch of images and returning a
            tensor of losses (presumably one entry per sample -- confirm
            against the caller).
        dataloader: Iterable of batches. When a batch is a list or tuple,
            only its first element is used as the images.
        device: Device the images are moved to before evaluation.

    Returns:
        The summed loss divided by the total number of samples, as a float.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    running_total = 0.0
    seen = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            # (images, labels, ...) style batches: keep only the images.
            inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
            inputs = inputs.to(device)

            running_total += torch.sum(loss_function(inputs)).item()
            seen += len(inputs)

    avg_loss = running_total / seen
    print(f"Average loss: {avg_loss:.5f}")
    return avg_loss
def get_weighted_dataset(dataloader, loss_function=None, normalizing_constant=None, device="cuda:0", batch_size=32, algorithm_name=""):
    """Materialize a dataloader into a ``CVaRDataset`` with per-sample weights.

    Every batch is evaluated with ``loss_function`` (gradients disabled); the
    images and their losses are collected and handed to ``CVaRDataset``, which
    weights each sample by ``normalizing_constant / importance_weights``.
    The average loss over all samples is printed as a side effect.

    Args:
        dataloader: Iterable of batches. When a batch is a list or tuple,
            only its first element is used as the images.
        loss_function: Callable mapping a batch of images to a tensor of
            losses that can be concatenated along dim 0. If ``None``, a
            constant loss of one per sample is used and
            ``normalizing_constant`` is forced to ``1.0``.
        normalizing_constant: Value multiplied into a ones-tensor shaped like
            the collected losses. May be ``None`` only when it is not needed,
            i.e. when ``loss_function`` is ``None`` or ``algorithm_name`` is
            ``"cvar"`` / ``"unweighted_is_cvar"``.
        device: Device used for loss evaluation and for the weight tensors
            passed to the dataset. The images themselves are kept on the CPU.
        batch_size: Currently unused; kept for interface compatibility.
        algorithm_name: Selects the weighting scheme:
            ``"cvar"`` / ``"unweighted_is_cvar"`` replace both the losses and
            the normalizing constant with ones (uniform weights);
            ``"reversed_is_cvar"`` swaps the two (weights become
            ``loss / normalizing_constant``); any other value keeps
            ``normalizing_constant / loss``.

    Returns:
        A ``CVaRDataset`` holding all images and their weights.

    Raises:
        TypeError: If ``loss_function`` is given, ``normalizing_constant`` is
            ``None`` and the selected algorithm actually needs the constant.
        ZeroDivisionError: If the dataloader yields no samples.
    """
    uniform_weights = algorithm_name in ("unweighted_is_cvar", "cvar")

    if loss_function is None:
        # No loss supplied: fall back to a constant loss of one per sample.
        def loss_function(x):
            return torch.ones(x.shape[0], device=device)

        normalizing_constant = 1.0
    elif normalizing_constant is None:
        if not uniform_weights:
            # Fail fast with a clear message instead of crashing on
            # ``Tensor * None`` after the whole dataloader has been consumed.
            raise TypeError(
                "normalizing_constant must be provided when loss_function is "
                "given, unless algorithm_name is 'cvar' or 'unweighted_is_cvar'"
            )
        # The constant is overwritten with ones below; any placeholder works.
        normalizing_constant = 1.0

    losses = 0.0
    num_samples = 0
    all_images = []
    all_losses = []

    with torch.no_grad():
        for batch in dataloader:
            images = batch[0] if isinstance(batch, (list, tuple)) else batch
            images = images.to(device)

            loss = loss_function(images)  # Loss values for this batch
            losses += torch.sum(loss).item()
            num_samples += images.shape[0]

            all_images.append(images.cpu())  # Keep images in CPU memory
            all_losses.append(loss.cpu())  # Kept for importance weighting

    avg_loss = losses / num_samples
    print(f"Average loss: {avg_loss:.5f}")

    # Concatenate the per-batch tensors into single tensors.
    all_images = torch.cat(all_images, dim=0)
    all_losses = torch.cat(all_losses, dim=0)
    normalizing_constant = torch.ones_like(all_losses).to(device) * normalizing_constant

    if uniform_weights:
        # Uniform weighting: every sample ends up with weight 1.
        all_losses = torch.ones_like(all_losses)
        normalizing_constant = torch.ones_like(normalizing_constant)
    elif algorithm_name == "reversed_is_cvar":
        # Reversed weighting: swap numerator and denominator.
        all_losses, normalizing_constant = normalizing_constant, all_losses

    return CVaRDataset(all_images, all_losses.to(device), normalizing_constant.to(device))