import os
import random
import csv
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
from tqdm import tqdm
import numpy as np
import timm
import argparse

from src.vit_inpainting import InpaintingViT
from src.inpainting_utils import *

# Default compute device: CUDA when PyTorch reports it is available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _build_base_dataset(dataset_name):
    """Return the raw torchvision dataset selected by ``dataset_name``.

    Images are resized to 224x224, converted to tensors and normalized per
    channel with mean 0.5 / std 0.5. Data is read from ``./data`` with
    ``download=False``, so the files must already be on disk.

    Raises:
        ValueError: if ``dataset_name`` is not a supported dataset.
    """
    # Shared by both datasets (previously duplicated per branch).
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    if dataset_name == "OxfordIIITPet":
        return torchvision.datasets.OxfordIIITPet(
            root='./data', split='trainval', download=False, transform=transform)
    if dataset_name == "OxfordFlowers102":
        return torchvision.datasets.Flowers102(
            root='./data', split='train', download=False, transform=transform)
    raise ValueError(
        "Unsupported dataset. Choose 'OxfordIIITPet' or 'OxfordFlowers102'.")


def _gather_gt_patches(original_img, mask_info_list):
    """Cut the target region out of each image and flatten it.

    Args:
        original_img: batch of unmasked images, indexed as
            ``original_img[i, :, rows, cols]``.
        mask_info_list: one ``(top, left, size)`` tuple per image in the batch,
            describing the square region whose pixels are the regression target.

    Returns:
        A tensor with one flattened patch per row, on the same device as
        ``original_img``.
    """
    patches = [
        original_img[i, :, top:top + size, left:left + size].flatten()
        for i, (top, left, size) in enumerate(mask_info_list)
    ]
    return torch.stack(patches)


def _train_one_epoch(model, loader, optimizer, criterion, device, desc):
    """Run one training epoch and return the sample-weighted mean loss.

    Shows a tqdm progress bar labelled ``desc``. Returns ``nan`` if ``loader``
    yields no samples (instead of dividing by zero).
    """
    model.train()
    total_loss = 0.0
    total_samples = 0
    with tqdm(total=len(loader), desc=desc, unit="batch") as pbar:
        for masked_img, original_img, mask_info_list in loader:
            masked_img = masked_img.to(device)
            original_img = original_img.to(device)
            pred_patch_flat = model(masked_img)
            gt_patches = _gather_gt_patches(original_img, mask_info_list)

            loss = criterion(pred_patch_flat, gt_patches)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            bs = masked_img.size(0)
            # Weight by batch size so a short final batch is not over-counted.
            total_loss += loss.item() * bs
            total_samples += bs
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})
            pbar.update(1)
    return total_loss / total_samples if total_samples else float("nan")


def _evaluate(model, loader, criterion, device):
    """Return the sample-weighted mean loss of ``model`` over ``loader``.

    Runs in eval mode without gradient tracking. Returns ``nan`` if
    ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for masked_img, original_img, mask_info_list in loader:
            masked_img = masked_img.to(device)
            original_img = original_img.to(device)
            pred_patch_flat = model(masked_img)
            gt_patches = _gather_gt_patches(original_img, mask_info_list)
            loss = criterion(pred_patch_flat, gt_patches)
            bs = masked_img.size(0)
            total_loss += loss.item() * bs
            total_samples += bs
    return total_loss / total_samples if total_samples else float("nan")


def main():
    """
    Main Function

    Parse CLI arguments and split the chosen dataset 80/20 into train and
    validation sets. The dataset is wrapped in ``InpaintingDatasetWrapper``.
    Then, for each of ``--runs`` runs:
      - load a pretrained timm ViT and freeze its parameters;
      - wrap it in ``InpaintingViT`` with the chosen pooling strategy;
      - train ``model.decoder`` with Adam + MSE for ``--epochs`` epochs;
      - record the validation loss.
    Finally print the mean and standard deviation of the per-run validation
    losses.
    """
    parser = argparse.ArgumentParser(description="Inpainting Experiment with ViT Pooling")
    parser.add_argument('--model_type', type=str, default='small_vit',
                        choices=['small_vit', 'base_vit'],
                        help="Type of ViT model: 'small_vit' or 'base_vit'.")
    parser.add_argument('--pooling', type=str, default='attention',
                        choices=['cls', 'avg', 'sum', 'attention', 'weighted_avg', 'max'],
                        help="Pooling strategy to use.")
    parser.add_argument('--epochs', type=int, default=10, help="Number of training epochs.")
    parser.add_argument('--lr', type=float, default=1e-2, help="Learning rate.")
    parser.add_argument('--runs', type=int, default=2, help="Number of runs per pooling strategy.")
    parser.add_argument('--mask_size', type=int, default=32, help="Size of the masked patch.")
    parser.add_argument('--num_iterations', type=int, default=1, help="Number of iterations for svd pooling.")
    parser.add_argument('--dataset_name', type=str, default="OxfordIIITPet",
                        choices=["OxfordIIITPet", "OxfordFlowers102"], help="Name of dataset.")
    args = parser.parse_args()
    # Zero runs would leave nothing to aggregate (mean/std would be NaN).
    if args.runs < 1:
        parser.error("--runs must be at least 1.")

    # Choose timm model based on model_type.
    model_names = {
        "small_vit": "vit_small_patch16_224",
        "base_vit": "vit_base_patch16_224",
    }
    if args.model_type not in model_names:
        raise ValueError("Unsupported model type. Choose 'small_vit' or 'base_vit'.")
    model_name = model_names[args.model_type]

    dataset_name = args.dataset_name
    base_dataset = _build_base_dataset(dataset_name)
    full_dataset = InpaintingDatasetWrapper(base_dataset, mask_size=args.mask_size)
    total_len = len(full_dataset)
    train_len = int(0.8 * total_len)
    val_len = total_len - train_len
    # NOTE(review): the split is unseeded, so it differs between invocations.
    train_dataset, val_dataset = random_split(full_dataset, [train_len, val_len])
    print(f"\n=== Dataset: {dataset_name} ===")
    print(f"Total images: {total_len} | Train: {train_len} | Val: {val_len}")

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                              num_workers=2, collate_fn=custom_collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False,
                            num_workers=2, collate_fn=custom_collate_fn)

    tag = args.pooling.upper()
    print(f"Pooling Strategy: {tag} for {dataset_name}")
    val_loss_list = []
    for run_idx in range(args.runs):
        print(f"--- Run {run_idx+1}/{args.runs} ---")
        # A fresh pretrained backbone per run so runs are independent.
        base_model = timm.create_model(
            model_name,
            pretrained=True,
            num_classes=0,
            global_pool=None)
        base_model.to(device)

        # Freeze the pretrained backbone so its weights are never updated.
        for param in base_model.parameters():
            param.requires_grad = False

        model = InpaintingViT(
            base_model=base_model,
            pooling_strategy=args.pooling,
            mask_size=args.mask_size,
            num_iterations=args.num_iterations)

        # Replace the decoder's final layer so it emits one flattened
        # 3 x mask_size x mask_size patch per image.
        out_dim = 3 * args.mask_size * args.mask_size
        model.decoder[-1] = nn.Linear(base_model.embed_dim, out_dim)
        model.to(device)

        # NOTE(review): only model.decoder is optimized. If InpaintingViT's
        # pooling has learnable parameters (e.g. 'attention' or
        # 'weighted_avg'), they are never updated. Confirm this is intended.
        optimizer = optim.Adam(model.decoder.parameters(), lr=args.lr)
        criterion = nn.MSELoss()
        for epoch in range(args.epochs):
            desc = f"[{tag}|Run{run_idx+1}] Epoch {epoch+1}/{args.epochs}"
            avg_loss = _train_one_epoch(
                model, train_loader, optimizer, criterion, device, desc)
            print(f"Epoch {epoch+1}/{args.epochs} Training Loss: {avg_loss:.4f}")

        # Validation.
        avg_val_loss = _evaluate(model, val_loader, criterion, device)
        val_loss_list.append(avg_val_loss)
        print(f"Validation Loss (Run {run_idx+1}/{args.runs}) for pooling={tag} on {dataset_name}: {avg_val_loss:.4f}\n")

    losses_t = torch.tensor(val_loss_list, dtype=torch.float)
    mean_val = losses_t.mean().item()
    # torch.std is the sample std and is NaN for a single element; with only
    # one run there is no spread to report, so use 0.0.
    std_val = losses_t.std().item() if len(val_loss_list) > 1 else 0.0
    print(f"*** FINAL for pooling={tag} on {dataset_name} *** -> Avg Val Loss: {mean_val:.4f}, Std: {std_val:.4f}\n")


if __name__ == "__main__":
    main()
