from torchvision import transforms
import random
import torch
from PIL import Image
from torch.utils.data import Dataset
import sys
sys.path.append('/home//work/doob_apps/hug')
from src.models.CT_model_predictor import RotationPredictorCNN
import os

import wandb

class RandomRotationWithLabel:
    """Rotate an image by a random angle and return that angle as a regression label.

    The angle is drawn uniformly from ``[-degrees, degrees]`` with the
    module-level ``random`` generator, so ``random.seed`` makes it reproducible.

    Args:
        degrees: Maximum absolute rotation angle, in degrees.
    """

    def __init__(self, degrees):
        self.degrees = degrees

    def __call__(self, img):
        """Rotate ``img`` and return ``(rotated_tensor, angle)``.

        ``img`` goes through ``transforms.ToPILImage`` first, so it must be
        something that transform accepts (in this file, the output of
        ``transforms.ToTensor``). Corners uncovered by the rotation are filled
        with the colour of the top-left pixel so they blend into the background.
        ``angle`` is the value passed to ``PIL.Image.Image.rotate``, i.e. degrees
        counter-clockwise.
        """
        limit = self.degrees
        angle = random.uniform(-limit, limit)

        pil_img = transforms.ToPILImage()(img)
        background = pil_img.getpixel((0, 0))
        turned = pil_img.rotate(angle, resample=Image.BICUBIC, fillcolor=background)

        return transforms.ToTensor()(turned), angle

# Preprocessing applied to every CT image:
#   1. resize to 64x64,
#   2. convert the PIL image to a float tensor in [0, 1],
#   3. rotate by a random angle in [-45, 45] degrees; the pipeline therefore
#      returns (image_tensor, angle) rather than a bare tensor.
# Normalization (transforms.Normalize(mean=[0.5], std=[0.5])) is currently disabled.
data_transforms = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        RandomRotationWithLabel(degrees=45),
    ]
)

class CTImageDataset(Dataset):
    """CT image dataset whose labels are random rotation angles.

    Each item is loaded as a grayscale ('L') PIL image and passed through
    ``transform``. The transform is expected to return ``(image, angle)`` (for
    example a ``Compose`` ending in ``RandomRotationWithLabel``). The angle is
    returned as a 1-element float32 tensor.

    Args:
        image_dir: Directory scanned (non-recursively) for image files.
        transform: Callable mapping a PIL image to ``(image, angle)``. If ``None``,
            the unrotated PIL image is returned with angle ``0.0``.
        duplicate_num: Number of times the whole file list is repeated, so each
            image can be drawn several times (with a different random rotation
            each time the transform is applied).
    """

    _VALID_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff')
    # Specific files excluded from the dataset: '000029.png' ... '000058.png'.
    # NOTE(review): the reason for excluding these slices is not visible here — confirm.
    _EXCLUDED_FILENAMES = frozenset('0000' + str(i) + '.png' for i in range(29, 59))

    def __init__(self, image_dir, transform=None, duplicate_num=1):
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort so item indices
        # (and therefore a seeded random_split) are reproducible across machines.
        self.image_paths = [
            os.path.join(image_dir, fname)
            for fname in sorted(os.listdir(image_dir))
            if not fname.startswith('.')  # ignore hidden files such as .DS_Store
            and self.is_image_file(fname)
        ]
        # Repeat the file list duplicate_num times.
        self.image_paths = self.image_paths * duplicate_num

    def is_image_file(self, filename):
        """Return True if ``filename`` has an image extension and is not excluded."""
        return (
            filename.lower().endswith(self._VALID_EXTENSIONS)
            and filename not in self._EXCLUDED_FILENAMES
        )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, angle_tensor)`` for item ``idx``."""
        # Close the file handle deterministically; convert() returns an
        # independent, fully loaded copy, so it stays valid after closing.
        with Image.open(self.image_paths[idx]) as raw:
            image = raw.convert('L')

        # Without a transform the image is not rotated, so its label is 0.
        angle = 0.0
        if self.transform is not None:
            image, angle = self.transform(image)
        return image, torch.tensor([angle], dtype=torch.float32)

def train_model(model, dataloader, val_dataloader, criterion, optimizer, num_epochs, batch_size, device):
    """Train ``model`` for ``num_epochs`` epochs, running validation after each epoch.

    After every epoch the per-sample average training loss is printed and logged
    to wandb together with the epoch number, and ``eval_model`` is run on
    ``val_dataloader``.

    Args:
        model: Network to train, updated in place. The caller is expected to
            have moved it to ``device`` already.
        dataloader: Training loader yielding ``(inputs, labels)`` batches.
        val_dataloader: Validation loader forwarded to ``eval_model``.
        criterion: Loss function. Each batch loss is weighted by that batch's
            sample count. For ``reduction='mean'`` losses such as the
            ``MSELoss`` used in this file, the epoch value is therefore the
            per-sample mean.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over ``dataloader``.
        batch_size: Nominal batch size, only forwarded to ``eval_model``.
            Actual batch sizes are read from the tensors, so a short final batch
            is counted correctly.
        device: Device each batch is moved to.
    """
    for epoch in range(num_epochs):
        # Re-enter training mode every epoch: eval_model puts the model in eval
        # mode, which would otherwise remain active from the second epoch on
        # (affecting dropout / batch-norm layers).
        model.train()
        running_loss = 0.0
        data_count = 0

        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # loss is a batch mean; weight it by the real batch size so the epoch
            # average is a true per-sample mean even when the last batch is short.
            n = inputs.size(0)
            running_loss += loss.item() * n
            data_count += n

        # An empty loader (e.g. every image filtered out) would otherwise divide by zero.
        epoch_loss = running_loss / data_count if data_count else float('nan')
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")
        # Log loss and epoch in a single call so they share one wandb step.
        wandb.log({"Training Loss": epoch_loss, "Epoch": epoch+1})
        eval_model(model, val_dataloader, criterion, batch_size, device)

    print("Training complete")

def eval_model(model, val_dataloader, criterion, batch_size, device, num_logged_batches=4):
    """Compute the validation loss over the whole loader and log sample predictions.

    For the first ``num_logged_batches`` batches, predictions and targets are
    printed, and every image is logged to wandb with its predicted and true
    angle. The loss is averaged over every sample in ``val_dataloader``, printed,
    and logged as "Validation Loss". The model's previous train/eval mode is
    restored before returning.

    Args:
        model: Network to evaluate.
        val_dataloader: Loader yielding ``(inputs, labels)`` batches.
        criterion: Loss function. Each batch loss is weighted by that batch's
            sample count, which gives the per-sample mean for
            ``reduction='mean'`` losses.
        batch_size: Unused; kept for call compatibility. Actual batch sizes are
            read from the tensors.
        device: Device each batch is moved to.
        num_logged_batches: Number of leading batches whose predictions are
            printed and whose images are sent to wandb.

    Returns:
        The per-sample average validation loss (NaN if the loader is empty).
    """
    was_training = model.training
    model.eval()
    loss_sum = 0.0
    data_count = 0
    try:
        with torch.no_grad():
            for batch_idx, (inputs, labels) in enumerate(val_dataloader):
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)

                # Only the first few batches are printed / sent to wandb. The loss
                # below is still accumulated for every batch.
                if batch_idx < num_logged_batches:
                    print("Predicted angles:", outputs)
                    print("True angles:", labels)
                    for img, pred, true in zip(inputs, outputs, labels):
                        wandb.log({"image": wandb.Image(img),
                                   "Predicted angles": pred.cpu().numpy(),
                                   "True angles": true.cpu().numpy()
                                   })

                loss = criterion(outputs, labels)
                n = inputs.size(0)
                loss_sum += loss.item() * n
                data_count += n
    finally:
        # Leave the model in the mode the caller had it in (e.g. training).
        model.train(was_training)

    # An empty validation split (possible with very small datasets) would
    # otherwise divide by zero.
    val_loss = loss_sum / data_count if data_count else float('nan')
    print(f"Validation Loss: {val_loss}")
    wandb.log({"Validation Loss": val_loss})
    return val_loss

if __name__ == '__main__':
    import datetime

    # Training hyperparameters (also sent to wandb as the run config).
    config = {
        "batch_size": 2,
        "num_epochs": 10,
        "learning_rate": 0.001
    }
    batch_size = config['batch_size']
    num_epochs = config['num_epochs']
    learning_rate = config['learning_rate']

    # Fix seeds. torch's global RNG drives random_split and DataLoader shuffling;
    # Python's `random` drives the rotation angles in RandomRotationWithLabel.
    torch.manual_seed(0)
    random.seed(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the dataset; every image is listed 4 times. Paths are relative to
    # the current working directory.
    image_dir = 'hug/data/HeadCT'
    dataset = CTImageDataset(image_dir, duplicate_num=4, transform=data_transforms)
    # Drop (nearly) black images.
    # NOTE: this materializes the dataset into a list. Each sample's random
    # rotation is therefore drawn once here and stays fixed for every epoch (and
    # is what gets saved below); it is not re-drawn per epoch.
    dataset = [data for data in dataset if data[0].max() > 0.2]

    # 90/10 train/validation split, then data loaders.
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

    # Model, loss and optimizer.
    model = RotationPredictorCNN().to(device)
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    wandb.init(project='CTImageRotation', config=config)
    # Validation runs inside train_model after every epoch.
    train_model(model, train_dataloader, val_dataloader, criterion, optimizer, num_epochs=num_epochs, batch_size=batch_size, device=device)
    print("Evaluation complete")

    # Save the trained weights under a timestamped directory.
    now_str = datetime.datetime.now().strftime("%Y%m%d_%H%M")
    model_dir = 'hug/src/preference/CT_predictor_' + now_str
    os.makedirs(model_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(model_dir, 'rotation_predictor.pth'))

    # Save the train/val splits (with their fixed rotations) under hug/src/data/<timestamp>/.
    data_dir = os.path.join('hug/src/data', now_str)
    os.makedirs(data_dir, exist_ok=True)
    torch.save(train_dataset, os.path.join(data_dir, 'CTImageDataset_train.pth'))
    torch.save(val_dataset, os.path.join(data_dir, 'CTImageDataset_val.pth'))