import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Image
import torchvision
from datasets import load_dataset
from torchvision import transforms

import torch.nn as nn

from diffusers import DDPMScheduler, UNet2DModel
import json

class ModelPotential(nn.Module):
    """Scalar-valued network: a diffusers ``UNet2DModel`` followed by a linear head.

    The UNet is evaluated at timestep 0; its output is globally average-pooled
    over the spatial dimensions and mapped by a single ``nn.Linear`` layer to one
    value per sample.

    Args:
        device: Device both submodules are moved to at construction time.
        config_path: Optional path to a JSON config file that must contain an
            ``"image_size"`` key. When ``None`` (the default), the historical
            hard-coded path ``_DEFAULT_CONFIG_PATH`` is used.
    """

    # NOTE(review): machine-specific absolute path; the empty segment in
    # "/home//work" looks like a redacted user name. Pass ``config_path`` to
    # override it instead of editing this constant.
    _DEFAULT_CONFIG_PATH = "/home//work/doob_apps/hug/configs/configs.json"

    # Number of channels the UNet emits; the linear head consumes exactly this
    # many features after pooling, so both layers must share this value.
    _OUT_CHANNELS = 3

    def __init__(self, device="cuda", config_path=None):
        super().__init__()
        self._load_config(config_path)
        self.unet = UNet2DModel(
            sample_size=self.image_size,  # the target image resolution
            in_channels=3,  # the number of input channels, 3 for RGB images
            out_channels=self._OUT_CHANNELS,  # the number of output channels
            layers_per_block=2,  # how many ResNet layers to use per UNet block
            block_out_channels=self.blocks,  # More channels -> more parameters
            down_block_types=(
                "DownBlock2D",  # a regular ResNet downsampling block
                "DownBlock2D",
                "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
            ),
            up_block_types=(
                "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
                "UpBlock2D",
                "UpBlock2D",  # a regular ResNet upsampling block
            ),
        ).to(device)

        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),  # (batch, channels, H, W) -> (batch, channels, 1, 1)
            nn.Flatten(),  # -> (batch, channels)
            nn.Linear(self._OUT_CHANNELS, 1),  # fully connected layer -> (batch, 1)
        ).to(device)

    def forward(self, x):
        """Return one scalar per sample in ``x``.

        Args:
            x: Input batch for the UNet; presumably shaped
                (batch, 3, image_size, image_size) given ``in_channels=3`` and
                ``sample_size=image_size`` — TODO confirm against callers.

        Returns:
            Tensor of shape ``(batch,)``. The batch dimension is kept even when
            the batch size is 1.
        """
        # Run the UNet at a fixed timestep of 0 and take its output tensor.
        features = self.unet(x, 0).sample
        # Pool + linear head: shape (batch, 1).
        output = self.fc(features)
        # Drop only the trailing singleton dim; a bare squeeze() would also
        # remove the batch dim when batch == 1 and return a 0-d tensor.
        return output.squeeze(-1)

    def _load_config(self, config_path=None):
        """Read ``image_size`` from the JSON config and set the UNet block widths.

        Args:
            config_path: Path to the JSON config; ``None`` selects
                ``_DEFAULT_CONFIG_PATH``.

        Raises:
            OSError: If the config file cannot be opened.
            KeyError: If the config lacks an ``"image_size"`` entry.
        """
        path = self._DEFAULT_CONFIG_PATH if config_path is None else config_path
        with open(path, "r", encoding="utf-8") as f:
            config = json.load(f)
        self.image_size = config["image_size"]
        # Output channels of each UNet resolution level (one per down/up block).
        self.blocks = (64, 128, 256)
    