import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
import torch.optim as optim
import matplotlib.pyplot as plt
import time
from transformers import ViTModel
from utilies import normalize_matrix

# Import necessary components from your existing implementation
from vit import ResponseGTImageDataset, transform, target_transform, SSIMLoss, RidgeLoss

# Define the ViTUNetForImageReconstruction model
class ViTUNetForImageReconstruction(nn.Module):
    """U-Net-style image decoder on top of a pretrained Hugging Face ViT encoder.

    Patch tokens from four encoder depths (``hidden_states[3]``, ``[6]``, ``[9]``
    and ``last_hidden_state``) are reshaped into square feature maps. The deepest
    map is decoded by bilinear-upsample + conv blocks
    (15x15 -> 30x30 -> 60x60 -> 120x120 -> 240x240 -> 449x449). The three
    intermediate maps are projected with 1x1 convs and added as skip connections
    at 30x30, 60x60 and 120x120. The output has 3 channels squashed to [0, 1]
    by a Sigmoid.

    Args:
        pretrained_model_name: Name or path accepted by ``ViTModel.from_pretrained``.
            The encoder must expose at least 9 hidden layers (indices 3/6/9 are used).
    """

    def __init__(self, pretrained_model_name):
        super().__init__()
        self.vit = ViTModel.from_pretrained(pretrained_model_name)

        # Token width of the encoder. Read it from the config instead of hard-coding
        # 768, so the conv input channels always match the loaded checkpoint.
        self.hidden_dim = self.vit.config.hidden_size

        # Resize each token grid to the resolution its decoder stage expects. When the
        # grid is smaller than the target, adaptive average pooling repeats values
        # rather than averaging them.
        self.adaptive_pool1 = nn.AdaptiveAvgPool2d((60, 60))  # For early features
        self.adaptive_pool2 = nn.AdaptiveAvgPool2d((30, 30))  # For middle features
        self.adaptive_pool3 = nn.AdaptiveAvgPool2d((15, 15))  # For late features
        self.adaptive_pool_final = nn.AdaptiveAvgPool2d((15, 15))  # For final features

        # 1x1 projections so each skip map matches the channel count of the decoder
        # stage it is added to.
        self.skip_conn1 = nn.Conv2d(self.hidden_dim, 128, kernel_size=1)  # added after up3 (120x120)
        self.skip_conn2 = nn.Conv2d(self.hidden_dim, 256, kernel_size=1)  # added after up2 (60x60)
        self.skip_conn3 = nn.Conv2d(self.hidden_dim, 512, kernel_size=1)  # added after up1 (30x30)

        # Decoder. Attribute names and Sequential ordering are part of the
        # state_dict keys, so they must stay stable for saved checkpoints to load.
        # First upsampling block (15x15 -> 30x30)
        self.up1 = nn.Sequential(
            nn.Upsample(size=(30, 30), mode='bilinear', align_corners=True),
            nn.Conv2d(self.hidden_dim, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )

        # Second upsampling block (30x30 -> 60x60)
        self.up2 = nn.Sequential(
            nn.Upsample(size=(60, 60), mode='bilinear', align_corners=True),
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )

        # Third upsampling block (60x60 -> 120x120)
        self.up3 = nn.Sequential(
            nn.Upsample(size=(120, 120), mode='bilinear', align_corners=True),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )

        # Final upsampling blocks (120x120 -> 240x240 -> 449x449, 3 output channels)
        self.up4 = nn.Sequential(
            nn.Upsample(size=(240, 240), mode='bilinear', align_corners=True),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Upsample(size=(449, 449), mode='bilinear', align_corners=True),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def extract_intermediate_features(self, pixel_values):
        """Run the ViT once and return token sequences from four depths.

        Returns:
            Tuple ``(early, middle, late, final)``. The first three are
            ``hidden_states[3]``, ``[6]`` and ``[9]``; ``final`` is
            ``last_hidden_state``.
        """
        outputs = self.vit(pixel_values, output_hidden_states=True)

        # hidden_states[0] is the embedding output, so index k is the output of layer k.
        early_layer = outputs.hidden_states[3]
        middle_layer = outputs.hidden_states[6]
        late_layer = outputs.hidden_states[9]
        final_layer = outputs.last_hidden_state

        return early_layer, middle_layer, late_layer, final_layer

    def reshape_features(self, features):
        """Turn a ViT token sequence into a square spatial feature map.

        Args:
            features: Tensor of shape ``(batch, 1 + H*W, hidden_dim)`` whose first
                token is the CLS token.

        Returns:
            Tensor of shape ``(batch, hidden_dim, H, W)`` built from the patch
            tokens in row-major order.

        Raises:
            ValueError: If the number of patch tokens is not a perfect square.
        """
        batch_size, num_tokens, hidden_dim = features.shape
        num_patches = num_tokens - 1  # drop the CLS token
        side = round(num_patches ** 0.5)
        if side * side != num_patches:
            raise ValueError(
                f"Expected a square grid of patch tokens, got {num_patches} patches"
            )

        # (B, N, C) -> (B, C, N) -> (B, C, side, side); reshape copies only if needed.
        patch_tokens = features[:, 1:, :].transpose(1, 2)
        return patch_tokens.reshape(batch_size, hidden_dim, side, side)

    def forward(self, pixel_values):
        """Reconstruct an image from ``pixel_values``.

        Returns:
            Tensor of shape ``(batch, 3, 449, 449)`` with values in [0, 1].
        """
        early_features, middle_features, late_features, final_features = \
            self.extract_intermediate_features(pixel_values)

        # Token sequences -> (batch, channels, height, width)
        early_spatial = self.reshape_features(early_features)
        middle_spatial = self.reshape_features(middle_features)
        late_spatial = self.reshape_features(late_features)
        final_spatial = self.reshape_features(final_features)

        # Resize skip sources to their stage resolutions
        skip1 = self.adaptive_pool1(early_spatial)   # 60x60
        skip2 = self.adaptive_pool2(middle_spatial)  # 30x30
        skip3 = self.adaptive_pool3(late_spatial)    # 15x15

        # Deepest encoder output is the decoder input
        x = self.adaptive_pool_final(final_spatial)  # 15x15

        # Stage 1: 15x15 -> 30x30, then add the projected late-layer skip
        x = self.up1(x)
        skip3_upsampled = F.interpolate(self.skip_conn3(skip3), size=(30, 30),
                                        mode='bilinear', align_corners=True)
        x = x + skip3_upsampled

        # Stage 2: 30x30 -> 60x60, then add the projected middle-layer skip
        x = self.up2(x)
        skip2_upsampled = F.interpolate(self.skip_conn2(skip2), size=(60, 60),
                                        mode='bilinear', align_corners=True)
        x = x + skip2_upsampled

        # Stage 3: 60x60 -> 120x120, then add the projected early-layer skip
        x = self.up3(x)
        skip1_upsampled = F.interpolate(self.skip_conn1(skip1), size=(120, 120),
                                        mode='bilinear', align_corners=True)
        x = x + skip1_upsampled

        # Final upsampling to the target size (120x120 -> 449x449)
        return self.up4(x)


def setup(rank, world_size):
    """
    Initialise the NCCL process group for one DDP worker.

    Args:
        rank: This process's rank; also used as its CUDA device index.
        world_size: Total number of worker processes.

    MASTER_ADDR / MASTER_PORT default to localhost:12355. Values already present
    in the environment are respected, so the rendezvous port can be changed
    (e.g. when 12355 is in use) without editing this file.
    """
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    # Bind this worker to its own GPU before creating the NCCL group. Otherwise
    # implicit current-device CUDA calls (such as torch.cuda.empty_cache() in the
    # training loop) default to GPU 0 in every worker.
    torch.cuda.set_device(rank)

    # Initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    """
    Destroy the default process group if one exists.

    Safe to call when no group was initialised (e.g. setup failed or cleanup
    already ran), so it can be used unconditionally in a ``finally`` clause.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def visualize_new(low_res, outputs, high_res, index=0, save_dir="results"):
    """Save an Input / Predicted / Ground Truth comparison image for one sample.

    From each tensor (any device, may require grad) a single 2-D slice is
    plotted in grayscale after ``normalize_matrix``. Leading dimensions are
    indexed at 0 until two remain, so a ``(batch, channel, H, W)`` tensor shows
    sample 0 / channel 0 and a ``(C, H, W)`` tensor shows channel 0. Each tensor
    is handled independently, so the inputs may have different ranks.

    Args:
        low_res: Model input tensor.
        outputs: Model prediction tensor.
        high_res: Ground-truth tensor.
        index: Integer used in the file name ``prediction_{index:03d}.png``.
        save_dir: Output directory; created if it does not exist.

    Returns:
        Path of the saved PNG (for logging).
    """
    def _first_2d_slice(tensor):
        # Host copy without autograd history, then peel leading dims down to 2-D.
        array = tensor.detach().cpu().numpy()
        while array.ndim > 2:
            array = array[0]
        return array

    panels = [
        (low_res, 'Input'),
        (outputs, 'Predicted'),
        (high_res, 'Ground Truth'),
    ]

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"prediction_{index:03d}.png")

    fig, axes = plt.subplots(1, 3, figsize=(12, 6))
    try:
        for ax, (tensor, title) in zip(axes, panels):
            ax.imshow(normalize_matrix(_first_2d_slice(tensor)), cmap='gray')
            ax.set_title(title)
            ax.axis('off')
        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # Close even on failure so repeated calls during training don't leak figures.
        plt.close(fig)

    return save_path


def train_vitunet_distributed(rank, world_size, response_dir, gt_dir, num_epochs, learning_rate, 
                            batch_size, loss_func, model_load_path, model_save_path, 
                            check_num, save_interval=5, visualize_interval=1000):
    """Train ViTUNetForImageReconstruction as one worker of a DDP job.

    Meant to be launched once per GPU via ``torch.multiprocessing.spawn``, which
    supplies ``rank`` as the first argument.

    Args:
        rank: Process rank; also used as the CUDA device index.
        world_size: Total number of processes (GPUs).
        response_dir: Input directory passed to ResponseGTImageDataset.
        gt_dir: Ground-truth directory passed to ResponseGTImageDataset.
        num_epochs: Number of passes over this rank's shard of the dataset.
        learning_rate: Adam learning rate.
        batch_size: Global batch size; each rank uses ``batch_size // world_size``
            (minimum 1).
        loss_func: One of ``'ridge'``, ``'ssim'`` or ``'mse'``.
        model_load_path: state_dict path to load, or ``'none'`` to start from the
            pretrained ViT.
        model_save_path: Directory for checkpoints, metrics and visualizations.
        check_num: Offset added to epoch numbers in saved file names and
            visualization indices (useful when resuming).
        save_interval: Rank 0 saves a checkpoint, metrics and a visualization every
            this many epochs.
        visualize_interval: Rank 0 saves a batch visualization every this many
            iterations.

    Raises:
        ValueError: If ``loss_func`` is not a supported loss name.
    """
    # Validate before creating the process group, so a typo fails immediately
    # instead of surfacing later as a NameError on ``criterion``.
    if loss_func not in ('ridge', 'ssim', 'mse'):
        raise ValueError(
            f"Unsupported loss_func {loss_func!r}; expected 'ridge', 'ssim' or 'mse'"
        )

    # Setup the distributed environment
    setup(rank, world_size)
    try:
        # For tracking loss and timing
        loss_history = []
        epoch_times = []
        total_start_time = time.time()

        # Create model and move it to GPU with id rank
        device = torch.device(f"cuda:{rank}")
        model = ViTUNetForImageReconstruction("google/vit-base-patch16-224").to(device)

        # Load model if specified
        if model_load_path != 'none':
            # map_location=device puts every saved tensor on this rank's GPU,
            # whichever device the checkpoint was written from.
            model.load_state_dict(torch.load(model_load_path, map_location=device))
            if rank == 0:
                print(f'Loading Path: {model_load_path}')
                print("Model weights loaded.")

        # find_unused_parameters=True: DDP otherwise errors when some registered
        # parameters receive no gradient. NOTE(review): presumably the ViT pooler,
        # whose output forward() never uses - confirm.
        model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)

        # Each rank iterates over its own shard of the dataset
        dataset = ResponseGTImageDataset(response_dir, gt_dir, transform=transform, target_transform=target_transform)
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)

        # `batch_size` is global; split it across ranks (at least 1 per GPU)
        batch_size_per_gpu = batch_size // world_size
        if batch_size_per_gpu < 1:
            batch_size_per_gpu = 1
            if rank == 0:
                print(f"Warning: Adjusted batch size to 1 per GPU (total: {world_size})")

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size_per_gpu,
            sampler=sampler,
            num_workers=2,  # Adjust based on your system
            pin_memory=True
        )

        # Set up optimizer and loss function
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        if loss_func == 'ridge':
            criterion = RidgeLoss(alpha=0.5).to(device)
            loss_message = 'Using Ridge Loss with alpha=0.5'
        elif loss_func == 'ssim':
            criterion = SSIMLoss(alpha=0.05).to(device)
            loss_message = 'Using SSIM Loss with alpha=0.05'
        else:  # 'mse' - the only remaining value after the validation above
            criterion = nn.MSELoss().to(device)
            loss_message = 'Using MSE Loss'

        if rank == 0:
            print(loss_message)
            print(f'Start training with {loss_func} loss, lr = {learning_rate}, batch size = {batch_size} (per GPU: {batch_size_per_gpu})')
            print(f'Total epochs: {num_epochs}, Save interval: {save_interval}')

        vis_dir = os.path.join(model_save_path, 'visualizations')
        # Most recent batch; stays None if the loader has never yielded anything.
        responses = outputs = labels = None

        for epoch in range(num_epochs):
            # Give the sampler a new shuffle seed each epoch
            sampler.set_epoch(epoch)

            model.train()
            total_loss = 0.0
            epoch_start_time = time.time()

            for i, (responses, labels) in enumerate(dataloader):
                # Move data to the correct device
                responses = responses.float().to(device)
                labels = labels.float().to(device)

                optimizer.zero_grad()

                # Forward pass
                outputs = model(responses)

                # Only print shapes on the first GPU during first iteration of first epoch
                if i == 0 and epoch == 0 and rank == 0:
                    print(f"Response shape: {responses.shape}")
                    print(f"Output shape: {outputs.shape}")
                    print(f"Label shape: {labels.shape}")

                # The decoder emits a fixed size; resample to the label size if they differ
                if outputs.shape != labels.shape:
                    outputs = F.interpolate(outputs, size=(labels.shape[2], labels.shape[3]),
                                            mode='bilinear', align_corners=True)

                # Loss, backward pass and optimizer step
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

                # Memory cleanup to avoid OOM errors
                if i % 10 == 0:
                    torch.cuda.empty_cache()

                # Visualization during training (only on rank 0)
                if rank == 0 and i % visualize_interval == 0 and i > 0:
                    with torch.no_grad():
                        os.makedirs(vis_dir, exist_ok=True)
                        vis_path = visualize_new(
                            responses[:1], outputs[:1], labels[:1],
                            index=epoch*1000+i+check_num,
                            save_dir=vis_dir
                        )
                        print(f"Batch {i} visualization saved to {vis_path}")

            epoch_time = time.time() - epoch_start_time

            # Average the per-batch loss over the batches of all ranks
            loss_tensor = torch.tensor(total_loss).to(device)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
            num_batches = len(dataloader) * world_size
            avg_loss = loss_tensor.item() / num_batches if num_batches else 0.0

            if rank == 0:
                loss_history.append(avg_loss)
                epoch_times.append(epoch_time)

                print(f"Epoch [{epoch + 1}/{num_epochs}], Average Loss: {avg_loss:.4f}, Time: {epoch_time:.2f}s")

                # Periodic checkpoint (only on the first GPU)
                if (epoch + 1) % save_interval == 0:
                    os.makedirs(vis_dir, exist_ok=True)

                    # Visualize the last batch of this run (first sample only to save memory)
                    if outputs is not None:
                        with torch.no_grad():
                            vis_path = visualize_new(
                                responses[:1], outputs[:1], labels[:1],
                                index=epoch+check_num,
                                save_dir=vis_dir
                            )
                        print(f"Epoch {epoch+1} visualization saved to {vis_path}")

                    # Save the model without the DDP wrapper
                    save_path = os.path.join(model_save_path, f'ViTUNet_epoch{epoch+check_num}_lr{learning_rate}_batch{batch_size}_{loss_func}.pth')
                    torch.save(model.module.state_dict(), save_path)
                    print(f"Model saved to {save_path}")

                    # Save loss history and timing
                    metrics_path = os.path.join(model_save_path, f'metrics_epoch{epoch+check_num}_lr{learning_rate}_batch{batch_size}_{loss_func}.pt')
                    total_time = time.time() - total_start_time
                    torch.save({
                        'loss_history': loss_history,
                        'epoch_times': epoch_times,
                        'total_time': total_time,
                        'average_epoch_time': sum(epoch_times) / len(epoch_times) if epoch_times else 0,
                        'current_epoch': epoch + check_num
                    }, metrics_path)
                    print(f"Training metrics saved to {metrics_path}")

                    # Memory cleanup
                    torch.cuda.empty_cache()

        # At the end of training, save the final metrics
        if rank == 0:
            total_time = time.time() - total_start_time
            average_epoch_time = sum(epoch_times) / len(epoch_times) if epoch_times else 0
            metrics_path = os.path.join(model_save_path, f'metrics_final_lr{learning_rate}_batch{batch_size}_{loss_func}.pt')
            torch.save({
                'loss_history': loss_history,
                'epoch_times': epoch_times,
                'total_time': total_time,
                'average_epoch_time': average_epoch_time,
                'total_epochs': num_epochs
            }, metrics_path)
            print(f"Final training metrics saved to {metrics_path}")
            print(f"Total training time: {total_time:.2f}s, Average epoch time: {average_epoch_time:.2f}s")
    finally:
        # Always tear down the process group, even if training raised
        cleanup()


def main_train_vitunet(response_dir, gt_dir, num_epochs, learning_rate, batch_size, 
                     loss_func, model_load_path, model_save_path, check_num, 
                     save_interval=5, visualize_interval=1000):
    """Train ViTUNetForImageReconstruction on all visible GPUs.

    With more than one GPU, one ``train_vitunet_distributed`` worker per GPU is
    spawned. With exactly one GPU, training runs in this process without DDP.
    Arguments have the same meaning as for ``train_vitunet_distributed``;
    ``batch_size`` is the global batch size.

    Raises:
        ValueError: If ``loss_func`` is not ``'ridge'``, ``'ssim'`` or ``'mse'``.
        RuntimeError: If no CUDA device is visible.
    """
    # Fail fast on a bad loss name, before spawning workers or loading data
    if loss_func not in ('ridge', 'ssim', 'mse'):
        raise ValueError(
            f"Unsupported loss_func {loss_func!r}; expected 'ridge', 'ssim' or 'mse'"
        )

    # Number of GPUs available
    world_size = torch.cuda.device_count()
    print(f"Using {world_size} GPUs for training")
    if world_size == 0:
        raise RuntimeError("No CUDA device available; ViTUNet training requires at least one GPU.")

    # Create model save directory if it doesn't exist
    os.makedirs(model_save_path, exist_ok=True)

    if world_size > 1:
        # Set environment variable for detailed debug information
        os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"

        # One process per GPU; spawn passes the rank as the first argument
        mp.spawn(
            train_vitunet_distributed,
            args=(world_size, response_dir, gt_dir, num_epochs, learning_rate,
                  batch_size, loss_func, model_load_path, model_save_path,
                  check_num, save_interval, visualize_interval),
            nprocs=world_size,
            join=True
        )
    else:
        print("Only one GPU detected. Using single GPU training.")
        _train_single_gpu(response_dir, gt_dir, num_epochs, learning_rate, batch_size,
                          loss_func, model_load_path, model_save_path, check_num,
                          save_interval, visualize_interval)


def _train_single_gpu(response_dir, gt_dir, num_epochs, learning_rate, batch_size,
                      loss_func, model_load_path, model_save_path, check_num,
                      save_interval, visualize_interval):
    """Run the training loop on cuda:0 without DDP (mirrors the distributed loop)."""
    device = torch.device("cuda:0")
    model = ViTUNetForImageReconstruction("google/vit-base-patch16-224").to(device)

    if model_load_path != 'none':
        # map_location=device loads checkpoints saved from any device onto cuda:0
        model.load_state_dict(torch.load(model_load_path, map_location=device))
        print(f'Loading Path: {model_load_path}')
        print("Model weights loaded.")

    # Set up the dataset and loader
    dataset = ResponseGTImageDataset(response_dir, gt_dir, transform=transform, target_transform=target_transform)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )

    # Set up optimizer and loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    if loss_func == 'ridge':
        criterion = RidgeLoss(alpha=0.5).to(device)
        print('Using Ridge Loss with alpha=0.5')
    elif loss_func == 'ssim':
        criterion = SSIMLoss(alpha=0.05).to(device)
        print('Using SSIM Loss with alpha=0.05')
    elif loss_func == 'mse':
        criterion = nn.MSELoss().to(device)
        print('Using MSE Loss')
    else:
        raise ValueError(
            f"Unsupported loss_func {loss_func!r}; expected 'ridge', 'ssim' or 'mse'"
        )

    loss_history = []
    epoch_times = []
    total_start_time = time.time()
    vis_dir = os.path.join(model_save_path, 'visualizations')
    # Most recent batch; stays None if the loader has never yielded anything.
    responses = outputs = labels = None

    print(f'Start training with {loss_func} loss, lr = {learning_rate}, batch size = {batch_size}')

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        epoch_start_time = time.time()

        for i, (responses, labels) in enumerate(dataloader):
            responses = responses.float().to(device)
            labels = labels.float().to(device)

            optimizer.zero_grad()
            outputs = model(responses)

            if i == 0 and epoch == 0:
                print(f"Response shape: {responses.shape}")
                print(f"Output shape: {outputs.shape}")
                print(f"Label shape: {labels.shape}")

            # The decoder emits a fixed size; resample to the label size if they differ
            if outputs.shape != labels.shape:
                outputs = F.interpolate(outputs, size=(labels.shape[2], labels.shape[3]),
                                        mode='bilinear', align_corners=True)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            if i % visualize_interval == 0 and i > 0:
                with torch.no_grad():
                    os.makedirs(vis_dir, exist_ok=True)
                    vis_path = visualize_new(
                        responses[:1], outputs[:1], labels[:1],
                        index=epoch*1000+i+check_num,
                        save_dir=vis_dir
                    )
                    print(f"Batch {i} visualization saved to {vis_path}")

        epoch_time = time.time() - epoch_start_time
        num_batches = len(dataloader)
        avg_loss = total_loss / num_batches if num_batches else 0.0

        loss_history.append(avg_loss)
        epoch_times.append(epoch_time)

        print(f"Epoch [{epoch + 1}/{num_epochs}], Average Loss: {avg_loss:.4f}, Time: {epoch_time:.2f}s")

        if (epoch + 1) % save_interval == 0:
            os.makedirs(vis_dir, exist_ok=True)

            if outputs is not None:
                with torch.no_grad():
                    vis_path = visualize_new(
                        responses[:1], outputs[:1], labels[:1],
                        index=epoch+check_num,
                        save_dir=vis_dir
                    )
                print(f"Epoch {epoch+1} visualization saved to {vis_path}")

            save_path = os.path.join(model_save_path, f'ViTUNet_epoch{epoch+check_num}_lr{learning_rate}_batch{batch_size}_{loss_func}.pth')
            torch.save(model.state_dict(), save_path)
            print(f"Model saved to {save_path}")

            metrics_path = os.path.join(model_save_path, f'metrics_epoch{epoch+check_num}_lr{learning_rate}_batch{batch_size}_{loss_func}.pt')
            total_time = time.time() - total_start_time
            torch.save({
                'loss_history': loss_history,
                'epoch_times': epoch_times,
                'total_time': total_time,
                'average_epoch_time': sum(epoch_times) / len(epoch_times) if epoch_times else 0,
                'current_epoch': epoch + check_num
            }, metrics_path)
            print(f"Training metrics saved to {metrics_path}")

    # Final metrics, matching what the distributed path writes at the end of training
    total_time = time.time() - total_start_time
    average_epoch_time = sum(epoch_times) / len(epoch_times) if epoch_times else 0
    metrics_path = os.path.join(model_save_path, f'metrics_final_lr{learning_rate}_batch{batch_size}_{loss_func}.pt')
    torch.save({
        'loss_history': loss_history,
        'epoch_times': epoch_times,
        'total_time': total_time,
        'average_epoch_time': average_epoch_time,
        'total_epochs': num_epochs
    }, metrics_path)
    print(f"Final training metrics saved to {metrics_path}")
    print(f"Total training time: {total_time:.2f}s, Average epoch time: {average_epoch_time:.2f}s")


if __name__ == '__main__':
    # Dataset location (responses and ground truth are read from the same folder)
    response_file_path = '/home/share/0125_dataset/01092025 test slide 1/'
    gt_file_path = '/home/share/0125_dataset/01092025 test slide 1/'

    # Release cached, unused allocator blocks held by this process
    torch.cuda.empty_cache()

    # Also reclaim CUDA IPC memory that has already been released (aggressive cleanup)
    torch.cuda.ipc_collect()

    print('GPU Clean')

    print('ViTUNet Model Training with DDP')

    main_train_vitunet(
        response_dir=response_file_path,
        gt_dir=gt_file_path,
        num_epochs=50,
        learning_rate=1e-4,
        batch_size=150,  # Global batch size; adjust based on your GPU memory
        loss_func='ssim',
        model_load_path='none',  # a missing comma here was a SyntaxError
        model_save_path='/home/dan5/optics_recon/Optics_Recon_Project/vunt_earlyskip_param/continue/',
        check_num=0,
        save_interval=5,
        visualize_interval=1000
    )