import os
import pandas as pd
import torch
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import warnings
from torch.optim.lr_scheduler import CosineAnnealingLR

from model_factory_mae import get_mae_vit_tiny_encoder, MAEDecoder, MaskedAutoencoderViT

warnings.filterwarnings('ignore')
torch.set_float32_matmul_precision('high')


class UnlabeledImageDataset(Dataset):
    """Unlabeled image dataset for self-supervised (MAE) pre-training.

    Each item is only the augmented image tensor; no label is returned.
    Images that fail to load or transform are skipped: the next row
    (wrapping around to the start) is tried instead, so a few corrupt files
    do not abort training.

    Args:
        dataframe: DataFrame with a ``'Path'`` column holding image file paths.
        image_size: Side length in pixels that every image is resized to.
    """

    def __init__(self, dataframe, image_size=224):
        # Fresh 0..n-1 index, decoupled from the caller's frame.
        self.data = dataframe.reset_index(drop=True)
        self.transform = A.Compose([
            A.Resize(image_size, image_size),
            A.HorizontalFlip(p=0.5),
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.8),
            # Scale 0..255 pixel values to floats in [0, 1] before tensor conversion.
            A.ToFloat(max_value=255.0),
            ToTensorV2(),
        ])

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the transformed image at ``index``, skipping unreadable files.

        If the image at ``index`` cannot be loaded or transformed, the
        following rows are tried in order, wrapping around. Each row is tried
        at most once.

        Raises:
            IndexError: if ``index`` itself is out of range (sequence protocol).
            RuntimeError: if no row in the dataset yields a usable image.
        """
        num_items = len(self)
        current = index
        attempts = 0
        while True:
            # Kept outside the try so an out-of-range index raises IndexError as usual.
            img_path = self.data.iloc[current]['Path']
            try:
                # The context manager closes the file promptly; convert() returns an
                # independent, fully loaded image.
                with Image.open(img_path) as image:
                    image_np = np.array(image.convert("RGB"))
                return self.transform(image=image_np)['image']
            except Exception as e:
                # Deliberate best-effort: any load/transform failure skips this file.
                print(f"Warning: Could not load image at {img_path}. Error: {e}. Skipping.")
            attempts += 1
            if attempts >= num_items:
                # Bounded retry instead of unbounded recursion (which hit RecursionError).
                raise RuntimeError(
                    f"Could not load any image from the {num_items} rows of the dataset "
                    f"(starting at index {index})."
                )
            current = (current + 1) % num_items


def _train_one_epoch(model, data_loader, optimizer, scheduler, device, epoch, num_epochs):
    """Run one optimization pass over ``data_loader`` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    num_batches = 0
    pbar = tqdm(data_loader, desc=f"Pretrain Epoch {epoch}/{num_epochs}")
    for imgs in pbar:
        # non_blocking lets the host-to-device copy overlap compute when memory is pinned.
        imgs = imgs.to(device, non_blocking=True)
        optimizer.zero_grad()
        # The model returns a 3-tuple; only its first element (the loss) is used here.
        loss, _, _ = model(imgs)
        loss.backward()
        optimizer.step()
        # Read the scalar once: every .item() call forces a device synchronization.
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        pbar.set_postfix(loss=f"{batch_loss:.4f}", lr=f"{scheduler.get_last_lr()[0]:.6f}")
    # Guard against division by zero if the loader yielded no batches.
    return total_loss / max(num_batches, 1)


def _save_encoder_checkpoint(model, output_dir, epoch):
    """Save only the encoder's state_dict for ``epoch`` and return the file path."""
    # Unwrap a ``.module``-style wrapper if one is present.
    encoder_to_save = model.module.encoder if hasattr(model, 'module') else model.encoder
    save_path = os.path.join(output_dir, f'mae_vit_tiny_encoder_epoch_{epoch}.pth')
    torch.save(encoder_to_save.state_dict(), save_path)
    return save_path


def pretrain_mae(config):
    """Pre-train a ViT-Tiny masked autoencoder on the images listed in a CSV file.

    Only the encoder weights are checkpointed. A checkpoint is written every
    ``save_freq`` epochs and after the final epoch.

    Args:
        config: dict with keys
            data_csv: path to a CSV with a 'Path' column of image files.
            output_dir: directory for encoder checkpoints (created if missing).
            img_size, batch_size, num_workers: data-pipeline settings.
            lr: AdamW learning rate, annealed with a cosine schedule.
            epochs: number of training epochs (also the cosine period).
            save_freq: checkpoint interval in epochs.
            compile: whether to wrap the model with torch.compile.

    Returns:
        None. Returns early, after printing an error, in three cases:
        the CSV is missing, the CSV is empty, or it lists fewer images
        than one batch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("--- Preparing Data for MAE Pre-training from CSV file ---")
    try:
        all_images_df = pd.read_csv(config['data_csv'])
        print(f"Successfully loaded {len(all_images_df)} image paths from {config['data_csv']}")
    except FileNotFoundError:
        print(f"Error: CSV file not found at '{config['data_csv']}'. Please run the create_filepath_csv.py script first.")
        return
    except pd.errors.EmptyDataError:
        print(f"Error: CSV file '{config['data_csv']}' is empty.")
        return

    # With drop_last=True, fewer rows than batch_size gives zero batches per epoch,
    # and an empty frame makes the shuffling sampler raise. Bail out early instead.
    if len(all_images_df) < config['batch_size']:
        print(
            f"Error: only {len(all_images_df)} image paths found, fewer than "
            f"batch_size={config['batch_size']}; nothing to train on."
        )
        return

    # Created here as well so the function also works when called directly.
    os.makedirs(config['output_dir'], exist_ok=True)

    dataset = UnlabeledImageDataset(dataframe=all_images_df, image_size=config['img_size'])
    # Pinned host memory only benefits transfers to a CUDA device.
    data_loader = DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=device.type == 'cuda',
        drop_last=True,
    )

    # --- Model ---
    print("--- Initializing MAE Model ---")
    encoder = get_mae_vit_tiny_encoder()
    decoder = MAEDecoder()
    model = MaskedAutoencoderViT(encoder, decoder).to(device)

    if config['compile']:
        print("--- Compiling Model ---")
        model = torch.compile(model)

    optimizer = optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=0.05, betas=(0.9, 0.95))
    # Stepped once per epoch, so T_max is expressed in epochs.
    scheduler = CosineAnnealingLR(optimizer, T_max=config['epochs'], eta_min=1e-6)

    # --- Training loop ---
    print("--- Starting MAE Pre-training ---")
    for epoch in range(1, config['epochs'] + 1):
        avg_loss = _train_one_epoch(
            model, data_loader, optimizer, scheduler, device, epoch, config['epochs']
        )
        scheduler.step()
        print(f"Epoch {epoch} Average Loss: {avg_loss:.6f}")

        if epoch % config['save_freq'] == 0 or epoch == config['epochs']:
            save_path = _save_encoder_checkpoint(model, config['output_dir'], epoch)
            print(f"Encoder checkpoint saved to {save_path}")

if __name__ == "__main__":
    # Pre-training configuration; keys are read by pretrain_mae().
    config = dict(
        # 1. CSV file listing the paths of all pre-training images.
        data_csv="pretrain.csv",
        # 2. Directory where encoder checkpoints are written.
        output_dir="./mae_pretrained_new",
        # 3. Training hyperparameters.
        epochs=200,
        batch_size=512,
        lr=1.5e-4,
        num_workers=16,
        img_size=224,
        save_freq=20,
        # 4. Whether to wrap the model with torch.compile.
        compile=True,
    )

    checkpoint_dir = config["output_dir"]
    os.makedirs(checkpoint_dir, exist_ok=True)
    pretrain_mae(config)