from torchvision import transforms
import random
import torch
from PIL import Image
from torch.utils.data import Dataset
import sys
sys.path.append('/home/***/work/doob_apps/hug')
from src.models.CT_model_predictor import RotationPredictorCNN
from src.models.CT_autoencoder import Autoencoder, SmallAutoencoder, Autoencoder32
from src.preference.CT_learn_preference import CTImageDataset, RandomRotationWithLabel
import os
import wandb
import torch.nn.functional as F

import datetime
from tqdm import tqdm

from diffusers import DDPMScheduler, UNet2DModel, DDPMPipeline
from diffusers.optimization import get_cosine_schedule_with_warmup

def train_diffusion(autoencoder, unet, train_dataloader, noise_scheduler, latent_size, criterion, device, config, dir):
    """Train ``unet`` to predict the noise added to pre-encoded latents.

    For every batch a random timestep is drawn per sample, Gaussian noise is
    mixed into the latents with ``noise_scheduler.add_noise`` and ``unet`` is
    optimised (AdamW) so that its output matches that noise under
    ``criterion``. The per-epoch average loss is printed and logged to wandb,
    and every 5th epoch a ``DDPMPipeline`` checkpoint is written to
    ``<dir>/pipeline_<epoch>``.

    Args:
        autoencoder: Only switched to eval mode here; it is not called.
        unet: Model invoked as ``unet(noisy, timesteps, return_dict=False)[0]``.
        train_dataloader: Iterable yielding tensors whose element count is
            ``batch * latent_size * latent_size`` (first dim is the batch).
        noise_scheduler: Scheduler exposing ``config.num_train_timesteps``
            and ``add_noise``.
        latent_size: Spatial side length of a latent.
        criterion: Callable ``(prediction, target) -> scalar loss``. The epoch
            average assumes it returns a per-batch *mean* (e.g. the default
            ``torch.nn.MSELoss()``).
        device: Device the batches are moved to.
        config: Mapping with the keys ``"learning_rate"`` and ``"num_epochs"``.
        dir: Directory under which checkpoints are saved.

    Raises:
        ValueError: If a batch cannot be reshaped to
            ``(batch, 1, latent_size, latent_size)``, if the model output
            shape differs from the noise shape, or if the dataloader yields
            no data.
    """
    learning_rate = config["learning_rate"]
    num_epochs = config["num_epochs"]

    autoencoder.eval()
    unet.train()
    optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)

    # Read the value from the scheduler config; direct attribute access on the
    # scheduler object is deprecated in diffusers.
    num_train_timesteps = noise_scheduler.config.num_train_timesteps

    for epoch in tqdm(range(num_epochs)):
        loss_sum = 0.0
        data_count = 0
        for step, encoded_images in enumerate(train_dataloader):
            batch_size = encoded_images.shape[0]
            if step % 1000 == 0:
                print(f"Step: {step}")

            # Validate *before* reshaping: checking the shape after a
            # successful reshape could never fail.
            if encoded_images.numel() != batch_size * latent_size * latent_size:
                raise ValueError(f"Unexpected input shape: {encoded_images.shape}")
            encoded_images = encoded_images.reshape(
                batch_size, 1, latent_size, latent_size
            ).to(device)

            # Sample the noise directly on the target device.
            noise = torch.randn_like(encoded_images)
            # Sample a random timestep for each latent.
            timesteps = torch.randint(
                0, num_train_timesteps, (batch_size,), device=encoded_images.device
            ).long()

            optimizer.zero_grad()

            # Add noise to the clean latents according to each timestep.
            noisy_images = noise_scheduler.add_noise(encoded_images, noise, timesteps)

            # Predict the noise that was added.
            noise_pred = unet(noisy_images, timesteps, return_dict=False)[0]
            if noise_pred.shape != noise.shape:
                raise ValueError(
                    f"Model output shape {tuple(noise_pred.shape)} does not match "
                    f"noise shape {tuple(noise.shape)}"
                )

            loss = criterion(noise_pred, noise)
            loss.backward()
            optimizer.step()

            # Weight the batch mean by the batch size so that
            # loss_sum / data_count is a true per-sample average, also when
            # the last batch is smaller than the others.
            loss_sum += loss.item() * batch_size
            data_count += batch_size

        if data_count == 0:
            raise ValueError("train_dataloader yielded no data")

        epoch_loss = loss_sum / data_count
        print(f"Epoch:{epoch+1}, loss-Epoch: {epoch_loss}")
        wandb.log({"Epoch": epoch+1, "Loss-Epoch": epoch_loss})

        # Periodic checkpoint of the model together with its scheduler.
        if (epoch+1) % 5 == 0:
            filename = os.path.join(dir, 'pipeline_'+str(epoch+1))
            image_pipe = DDPMPipeline(unet=unet, scheduler=noise_scheduler)
            image_pipe.save_pretrained(filename)

def _build_unet(latent_size, num_layers, block_out_channels, num_blocks):
    """Create a single-channel ``UNet2DModel`` with 1, 2 or 3 resolution levels.

    Args:
        latent_size: Spatial side length of the latents (``sample_size``).
        num_layers: ResNet layers per UNet block.
        block_out_channels: Channel width of the deepest level; shallower
            levels use ``// 2`` and ``// 4`` of it as listed below.
        num_blocks: Number of down/up levels (1, 2 or 3).

    Raises:
        ValueError: If ``num_blocks`` is not 1, 2 or 3.
    """
    # num_blocks -> (channel widths, down block types, up block types)
    layouts = {
        1: (
            (block_out_channels,),
            ("AttnDownBlock2D",),
            ("AttnUpBlock2D",),
        ),
        2: (
            (block_out_channels // 4, block_out_channels),
            ("AttnDownBlock2D", "AttnDownBlock2D"),
            ("AttnUpBlock2D", "AttnUpBlock2D"),
        ),
        3: (
            (block_out_channels // 4, block_out_channels // 2, block_out_channels),
            ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
            ("AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
        ),
    }
    if num_blocks not in layouts:
        raise ValueError(f"Unsupported num_blocks: {num_blocks!r} (expected 1, 2 or 3)")
    channels, down_types, up_types = layouts[num_blocks]
    return UNet2DModel(
        sample_size=latent_size,
        in_channels=1,
        out_channels=1,
        layers_per_block=num_layers,
        block_out_channels=channels,
        down_block_types=down_types,
        up_block_types=up_types,
    )


def _sample_latents(unet, noise_scheduler, num_samples, latent_size, device):
    """Run the full reverse diffusion chain and return raw latents.

    ``DDPMPipeline.__call__`` is intentionally not used for this: it rescales
    its result with ``x / 2 + 0.5``, clamps to [0, 1] and converts to images,
    so the values would no longer match the latents the UNet was trained on.

    Returns:
        Tensor of shape ``(num_samples, 1, latent_size, latent_size)``.
    """
    unet.eval()
    noise_scheduler.set_timesteps(noise_scheduler.config.num_train_timesteps)
    latents = torch.randn(num_samples, 1, latent_size, latent_size, device=device)
    with torch.no_grad():
        for t in noise_scheduler.timesteps:
            noise_pred = unet(latents, t, return_dict=False)[0]
            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
    return latents


def main():
    """Encode the CT dataset, train a latent DDPM on it and log samples.

    Steps: load the pre-trained autoencoder, encode (or load) the dataset,
    build (or load) the UNet and noise scheduler, train with
    ``train_diffusion``, save the final pipeline, then sample four latents,
    decode them and log both to wandb.

    Raises:
        ValueError: For an unsupported ``autoencoder_size`` / ``num_blocks``
            or when no image survives the brightness filter.
    """
    # Device
    device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")

    # Initialize the autoencoder
    autoencoder_size = '32'
    if autoencoder_size == 'large':
        autoencoder = Autoencoder()
    elif autoencoder_size == '32':
        autoencoder = Autoencoder32()
    else:
        raise ValueError(f"Unsupported autoencoder_size: {autoencoder_size!r}")

    # Load the pre-trained weights; map_location lets a checkpoint saved on
    # another device be loaded on the one selected above.
    autoencoder_path = "/home/***/work/doob_apps/hug/src/pretrain/autoencoder/20240918_2031_32_0.2/autoencoder_epoch_19.pth"
    autoencoder.load_state_dict(torch.load(autoencoder_path, map_location=device))
    autoencoder.eval()
    autoencoder.to(device)

    # Dataset location
    image_dir = 'hug/data/HeadCT'

    # Preprocessing applied to every image
    data_transforms = transforms.Compose([
        transforms.Resize((64, 64)),         # resize
        transforms.ToTensor(),               # convert to tensor
        RandomRotationWithLabel(degrees=45)  # random rotation
    ])

    dup_num = 2
    image_size = 64
    latent_size = 32

    # Set to the path of a previously saved encoded dataset to skip encoding.
    we_have_dataset_path = None
    if we_have_dataset_path is None:
        # Build the dataset
        train_dataset = CTImageDataset(image_dir, transform=data_transforms, duplicate_num=dup_num)
        for i in range(3):
            print(train_dataset[i][0].shape)
        # Some images are completely black; drop them.
        train_dataset = [data for data in train_dataset if data[0].max() > 0.2]
        if not train_dataset:
            raise ValueError("No images left after removing black images")
        # Encode every image once up front. No gradients are needed, so do
        # not build an autograd graph for each forward pass.
        with torch.no_grad():
            encoded_dataset = torch.stack([
                autoencoder.encoder(data[0].view(1, 1, image_size, image_size).to(device))
                for data in train_dataset
            ])
    else:
        encoded_dataset = torch.load(we_have_dataset_path, map_location=device)

    print("encoded_dataset.shape:", encoded_dataset.shape)
    print("data encoded")

    # Set seed.
    # NOTE(review): this runs after the dataset was built, so any randomness
    # in the transforms above is not covered by the seed - confirm intent.
    random.seed(0)
    torch.manual_seed(0)

    # Hyper-parameters
    batch_size = 8
    n_epochs = 15
    learning_rate = 0.0001
    lr_warmup_steps = 500
    num_train_timesteps = 1000
    block_out_channels = 256
    num_layers = 3
    num_blocks = 2

    config = {
        "batch_size": batch_size,
        "num_epochs": n_epochs,
        "learning_rate": learning_rate,
        "lr_warmup_steps": lr_warmup_steps,
        "autoencoder_path": autoencoder_path,
        "num_train_timesteps": num_train_timesteps,
        "block_out_channels": block_out_channels,
        "num_layers": num_layers,
        "num_blocks": num_blocks,
    }

    # Timestamped output directories
    now_str = datetime.datetime.now().strftime("%Y%m%d_%H%M")
    output_dir = 'hug/src/pretrain/CT_diffusion/'+now_str
    dataset_dir = os.path.join('hug/src/datasets/CT_diffusion', now_str)

    # Data loader
    train_dataloader = torch.utils.data.DataLoader(encoded_dataset, batch_size=batch_size, shuffle=True)

    # Save the freshly encoded dataset for later reuse
    if we_have_dataset_path is None:
        os.makedirs(dataset_dir, exist_ok=True)
        torch.save(encoded_dataset, os.path.join(dataset_dir, 'encoded_dataset.pth'))

    # Set to a saved DDPMPipeline directory to resume from it.
    pretrained_pipeline = None

    if pretrained_pipeline is None:
        unet = _build_unet(latent_size, num_layers, block_out_channels, num_blocks)
        unet.to(device)
        # Set the noise scheduler
        noise_scheduler = DDPMScheduler(
            num_train_timesteps=num_train_timesteps, beta_schedule="squaredcos_cap_v2"
        )
    else:
        # Record which pipeline the run was resumed from
        config["pretrained_pipeline"] = pretrained_pipeline
        image_pipe = DDPMPipeline.from_pretrained(pretrained_pipeline)
        unet = image_pipe.unet
        unet.to(device)
        noise_scheduler = image_pipe.scheduler

    # Initialize wandb
    wandb.init(project='CTImage-diffusion', config=config)
    wandb.watch(autoencoder)
    wandb.watch(unet)

    train_diffusion(autoencoder, unet, train_dataloader, noise_scheduler, latent_size,
                    torch.nn.MSELoss(), device, config, output_dir)

    # Save the final model
    image_pipe = DDPMPipeline(unet=unet, scheduler=noise_scheduler)
    image_pipe.save_pretrained(os.path.join(output_dir, 'my_pipeline'))

    # Sample latents, decode them and log both views.
    num_samples = 4
    sampled_latents = _sample_latents(unet, noise_scheduler, num_samples, latent_size, device)
    for i in range(num_samples):
        image_i = sampled_latents[i].view(1, 1, latent_size, latent_size)
        with torch.no_grad():
            decoded_image_i = autoencoder.decoder(image_i).squeeze(0)
        wandb.log({"decoded_image": wandb.Image(decoded_image_i), "encoded_image": wandb.Image(image_i)})

    wandb.finish()

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()