from torchvision import transforms
import random
import torch
import torch.nn as nn
import torch.optim as optim

from PIL import Image
from torch.utils.data import Dataset
import sys
sys.path.append('/home/***/work/doob_apps/hug')
from src.models.CT_model_predictor import RotationPredictorCNN
from src.models.CT_autoencoder import Autoencoder, SmallAutoencoder, Autoencoder32
from src.preference.CT_learn_preference import CTImageDataset, RandomRotationWithLabel

import datetime

import os, json

import wandb

def train_autoencoder(autoencoder, train_dataloader, val_dataloader, criterion, device, config, model_dir):
    """Train ``autoencoder`` as a denoising model, validating and checkpointing every epoch.

    For every batch, Gaussian noise with standard deviation
    ``config["noise_std"]`` is added to the inputs and the model is optimised
    (Adam) to reconstruct the clean inputs under ``criterion``.

    After each epoch the function logs the epoch loss to wandb, runs
    ``eval_autoencoder`` on ``val_dataloader`` and writes the weights to
    ``<model_dir>/autoencoder_epoch_<epoch>.pth``.

    Args:
        autoencoder: Model to train; called as ``autoencoder(noised_inputs, noise_std)``.
        train_dataloader: Iterable yielding ``(inputs, _)`` pairs; the second
            element is ignored.
        val_dataloader: Loader forwarded to ``eval_autoencoder``.
        criterion: Callable ``criterion(outputs, inputs)`` returning a scalar loss tensor.
        device: Device the input batches are moved to.
        config: Mapping providing ``"num_epochs"``, ``"learning_rate"`` and
            ``"noise_std"`` (``eval_autoencoder`` reads its own keys from it).
        model_dir: Existing directory that receives the per-epoch checkpoints.

    Raises:
        ValueError: If ``train_dataloader`` yields no samples.
    """
    # NOTE(review): the model is called with (inputs, noise_std) here but with a
    # single argument in eval_autoencoder -- confirm both forms are supported.
    n_epochs = config["num_epochs"]
    learning_rate = config["learning_rate"]
    noise_std = config["noise_std"]

    optimizer = optim.Adam(autoencoder.parameters(), lr=learning_rate)

    for epoch in range(n_epochs):
        # eval_autoencoder() switches the module between eval and train mode,
        # so training mode must be re-asserted at the start of every epoch.
        autoencoder.train()
        loss_epoch = 0.0
        data_count = 0
        n_batches = 0
        for inputs, _ in train_dataloader:
            inputs = inputs.to(device)
            # Corrupt the input; the reconstruction target stays the clean batch.
            noise = torch.randn(inputs.shape).to(inputs.device) * noise_std
            noised_inputs = inputs + noise
            optimizer.zero_grad()
            outputs = autoencoder(noised_inputs, noise_std)
            loss = criterion(outputs, inputs)
            loss.backward()
            optimizer.step()
            loss_epoch += loss.item()
            # Count the samples actually seen: the last batch may be smaller
            # than the configured batch size.
            data_count += inputs.size(0)
            n_batches += 1

        if data_count == 0:
            raise ValueError("train_dataloader yielded no samples")

        # Sum of per-batch losses over the number of samples, i.e. the same
        # normalisation eval_autoencoder uses for "Validation Loss".
        mean_loss = loss_epoch / data_count
        print(f'Epoch {epoch}, Iteration {n_batches - 1}, Loss: {mean_loss}')
        wandb.log({"Loss": mean_loss})
        wandb.log({"Epoch": epoch})
        eval_autoencoder(autoencoder, val_dataloader, criterion, device, config)
        model_path = os.path.join(model_dir, 'autoencoder_epoch_'+str(epoch)+'.pth')
        torch.save(autoencoder.state_dict(), model_path)

    print("Training complete")

def eval_autoencoder(autoencoder, val_dataloader, criterion, device, config):
    """Evaluate ``autoencoder`` on ``val_dataloader`` and log the results to wandb.

    Batches alternate between two settings:

    * even batch index: noisy inputs, model in eval mode;
    * odd batch index: clean inputs, model in train mode.

    For the first two batches (one per setting) every input image, its
    noise-perturbed encoding and its reconstruction are logged as wandb
    images. The reported "Validation Loss" is the sum of the per-batch
    ``criterion`` values divided by the number of samples seen.

    The model's train/eval mode is restored to its value at entry before
    returning, so calling this function in the middle of training does not
    change the mode used by the following epoch.

    Args:
        autoencoder: Model exposing ``__call__(inputs)``, ``encoder(inputs)``,
            ``train()``/``eval()`` and a ``training`` flag.
        val_dataloader: Iterable yielding ``(inputs, _)`` pairs; the second
            element is ignored.
        criterion: Callable ``criterion(outputs, inputs)`` returning a scalar loss tensor.
        device: Device the input batches are moved to.
        config: Mapping providing ``"noise_std"``.
    """
    noise_std = config["noise_std"]
    was_training = autoencoder.training
    loss_sum = 0.0
    data_count = 0
    autoencoder.eval()
    try:
        with torch.no_grad():
            for i, (inputs, _) in enumerate(val_dataloader):
                inputs = inputs.to(device)
                # Drawn for every batch (even when unused) so the amount of
                # randomness consumed does not depend on the branch taken.
                noise = torch.randn(inputs.shape).to(inputs.device) * noise_std
                if i % 2 == 0:
                    noised_inputs = inputs + noise
                    autoencoder.eval()
                else:
                    noised_inputs = inputs
                    autoencoder.train()

                outputs = autoencoder(noised_inputs)
                encoded = autoencoder.encoder(noised_inputs)
                # Perturb the latent code; it is only used for the image log below.
                encoded += torch.randn(encoded.shape).to(encoded.device) * noise_std

                assert outputs.shape == inputs.shape
                loss = criterion(outputs, inputs)
                loss_sum += loss.item()
                # The last batch may be smaller than the configured batch size.
                data_count += inputs.size(0)

                # Log one batch of each setting (i == 0: noisy/eval, i == 1: clean/train).
                if i < 2:
                    for j, img in enumerate(inputs):
                        wandb.log({"image": wandb.Image(img),
                                   "Encoded": wandb.Image(encoded[j]),
                                   "Reconstructed": wandb.Image(outputs[j])
                                   })
    finally:
        # Hand the model back in the mode the caller had it in.
        autoencoder.train(was_training)

    if data_count == 0:
        # Nothing to average; avoid a ZeroDivisionError on an empty loader.
        print('Validation Loss: n/a (val_dataloader yielded no samples)')
        return

    print(f'Validation Loss: {loss_sum/data_count}')
    wandb.log({"Validation Loss": loss_sum/data_count})
    

def main():
    """Train an autoencoder on the HeadCT images and save weights and config.

    Steps: build the dataset, make a seeded 95/5 train/validation split,
    construct the model, write ``config.json`` into a timestamped run
    directory, train with per-epoch validation and checkpoints, save the
    final weights and run one last evaluation. Metrics go to the wandb
    project ``CT_autoencoder``.
    """
    # Dataset location (relative to the current working directory).
    image_dir = 'hug/data/HeadCT'

    # Preprocessing applied to every image.
    # (Normalisation with mean=0.5 / std=0.5 is intentionally not applied.)
    data_transforms = transforms.Compose([
        transforms.Resize((64, 64)),         # resize
        transforms.ToTensor(),               # convert to tensor
        RandomRotationWithLabel(degrees=45)  # random rotation
    ])

    dup_num = 4
    batch_size = 8
    n_epochs = 20
    learning_rate = 0.0001
    noise_std = 0.2

    # Model variant: 'large', '32', or anything else for the small model.
    small_or_large = '32'

    config = {
        "batch_size": batch_size,
        "num_epochs": n_epochs,
        "learning_rate": learning_rate,
        "duplicate_num": dup_num,
        "noise_std": noise_std
    }

    # Create the dataset.
    dataset = CTImageDataset(image_dir, transform=data_transforms, duplicate_num=dup_num)

    # Seed before the split AND before the model is built, so that both the
    # split and the weight initialisation are reproducible.
    random.seed(0)
    torch.manual_seed(0)

    train_size = int(0.95 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Not shuffled: eval_autoencoder treats batches by index (which ones get
    # noise, which ones are logged as images), so a fixed order keeps the
    # validation output comparable across epochs.
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize the autoencoder (after seeding, see above).
    model_classes = {'large': Autoencoder, '32': Autoencoder32}
    autoencoder = model_classes.get(small_or_large, SmallAutoencoder)()

    now = datetime.datetime.now()
    now_str = now.strftime('%Y%m%d_%H%M')
    # Where the model and its config are saved.
    model_dir = os.path.join("hug/src/pretrain/autoencoder", now_str+"_"+small_or_large+"_"+str(noise_std))
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, 'autoencoder.pth')
    config_path = os.path.join(model_dir, 'config.json')
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=4)

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    autoencoder.to(device)

    # wandb
    wandb.init(project='CT_autoencoder',
               config=config
    )
    wandb.watch(autoencoder)

    # Define loss function
    criterion = nn.MSELoss()  # Use MSE loss for reconstruction
    # Train
    train_autoencoder(autoencoder, train_dataloader, val_dataloader, criterion, device, config, model_dir)
    # Save the final model
    torch.save(autoencoder.state_dict(), model_path)
    print(f"Model saved at {model_path}")
    # Final evaluation
    eval_autoencoder(autoencoder, val_dataloader, criterion, device, config)
    # Finish the wandb run
    wandb.finish()

# Script entry point: run the training pipeline only when executed directly.
if __name__ == "__main__":
    main()