import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
import torch
import torch.optim as optim
from transformers import AutoModelForImageClassification
from utilies import normalize_matrix, down_sample_matrix, mat2vec, vec2mat, mask_response_circle

from opticsNN_data_loader import ResponseGTImageDataset, RidgeLoss, SSIMLoss, visualize_predictions, transform_single_to_three_channel
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt


'''
-----------------UNET----------------------
'''


class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> batch norm -> ReLU) stages, the basic UNet unit.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the second stage.
        mid_channels: Channels between the two stages; falls back to
            ``out_channels`` when omitted (or otherwise falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels

        # Build both stages in order so the Sequential indices (0..5) and
        # therefore the state_dict keys are the same as a hand-written list.
        stages = []
        for stage_in, stage_out in ((in_channels, hidden), (hidden, out_channels)):
            stages.append(nn.Conv2d(stage_in, stage_out, kernel_size=3, padding=1, bias=False))
            stages.append(nn.BatchNorm2d(stage_out))
            stages.append(nn.ReLU(inplace=True))

        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages; padding=1 keeps the spatial size unchanged."""
        features = self.double_conv(x)
        return features


class Down(nn.Module):
    """Encoder step: halve the spatial size with 2x2 max-pooling, then DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Index 0 is the pool, index 1 the conv block.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Return the pooled and convolved feature map."""
        reduced = self.maxpool_conv(x)
        return reduced


class Up(nn.Module):
    """Decoder step: upsample, merge with the skip connection, then DoubleConv.

    Args:
        in_channels: Channel count after concatenating the skip connection
            with the upsampled tensor.
        out_channels: Channels produced by the trailing DoubleConv.
        bilinear: If True, upsample with fixed bilinear interpolation and let
            the DoubleConv reduce channels (via ``in_channels // 2`` as its
            middle width); otherwise use a learned transposed convolution
            that halves the channel count itself.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half_channels = in_channels // 2

        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half_channels, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s height/width, concatenate and convolve.

        ``x2`` (the skip connection) comes first along the channel axis.
        """
        upsampled = self.up(x1)

        # An odd-sized encoder map pools to floor(size / 2), so the upsampled
        # tensor can be smaller than the skip connection; pad the difference,
        # putting the extra pixel (if any) on the right / bottom.
        gap_h = x2.size()[2] - upsampled.size()[2]
        gap_w = x2.size()[3] - upsampled.size()[3]
        pad_left = gap_w // 2
        pad_top = gap_h // 2
        padding = [pad_left, gap_w - pad_left, pad_top, gap_h - pad_top]
        upsampled = F.pad(upsampled, padding)

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Output head: a 1x1 convolution mapping features to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
        )

    def forward(self, x):
        """Project each pixel's feature vector to the output channels."""
        projected = self.conv(x)
        return projected


class UNet(nn.Module):
    """UNet with four down/up levels, a sigmoid output and a fixed output size.

    The encoder widens 64 -> 128 -> 256 -> 512 -> 1024 // factor channels and
    the decoder mirrors it using skip connections. The result always passes
    through a sigmoid and is then resized (bilinear, ``align_corners=True``)
    to ``output_size`` when its height/width differ from it.

    Args:
        n_channels: Number of input channels.
        n_classes: Number of output channels.
        bilinear: Use bilinear upsampling in the decoder (``factor == 2``)
            instead of transposed convolutions.
        output_size: Target ``(height, width)`` of the output, a single int
            for a square output, or ``None`` to skip the final resize.
    """

    def __init__(self, n_channels=3, n_classes=3, bilinear=True, output_size=(449, 449)):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.output_size = output_size
        
        # Encoder path
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        # With bilinear upsampling the channel reduction is done by the
        # decoder convolutions, so the bottleneck is built half as wide.
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        
        # Decoder path with skip connections
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        
        # Final convolution
        self.outc = OutConv(64, n_classes)
        
        # Sigmoid keeps every output value in the range (0, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the network on a 4-D batch and return the resized prediction."""
        # Ensure the input is float32
        x = x.float()
        
        # Contracting path (encoder)
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        
        # Expanding path (decoder) with skip connections
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        
        # Final convolution and activation
        x = self.outc(x)
        x = self.sigmoid(x)
        
        # Resize to match target size. Only the spatial dims are compared:
        # comparing the full 4-D size against a 2-tuple is always unequal and
        # would trigger a redundant interpolation on every call.
        target = self.output_size
        if target is not None:
            if isinstance(target, int):
                target = (target, target)
            target = tuple(target)
            if tuple(x.shape[-2:]) != target:
                x = F.interpolate(x, size=target, mode='bilinear', align_corners=True)
            
        return x


# Preprocessing pipelines handed to ResponseGTImageDataset.
_INPUT_SIZE = (224, 224)     # spatial size of the network input
_TARGET_SIZE = (449, 449)    # spatial size of the ground-truth images
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Response (input) images: tensor -> three channels -> resize -> ImageNet normalization.
_input_steps = [
    transforms.ToTensor(),
    transforms.Lambda(transform_single_to_three_channel),
    transforms.Resize(_INPUT_SIZE),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
transform = transforms.Compose(_input_steps)

# Ground-truth images: tensor -> three channels -> resize (no normalization).
_target_steps = [
    transforms.ToTensor(),
    transforms.Lambda(transform_single_to_three_channel),
    transforms.Resize(_TARGET_SIZE),
]
target_transform = transforms.Compose(_target_steps)


def train_unet_model(response_dir, gt_dir, num_epochs, learning_rate, batch_size, 
                    loss_func, model_load_path, model_save_path, check_num, 
                    save_interval=5, visualize_interval=1000):
    """Train the UNet in a single process (CUDA if available, otherwise CPU).

    Args:
        response_dir: Directory passed to ResponseGTImageDataset for inputs.
        gt_dir: Directory passed to ResponseGTImageDataset for ground truth.
        num_epochs: Number of passes over the dataset.
        learning_rate: Adam learning rate.
        batch_size: Samples per optimization step.
        loss_func: One of ``'ridge'``, ``'ssim'`` or ``'mse'``.
        model_load_path: Checkpoint (state dict) to resume from, or the
            string ``'none'`` to start from fresh weights.
        model_save_path: Directory for checkpoints; created if missing.
        check_num: Offset added to the zero-based epoch index in checkpoint
            file names.
        save_interval: Save a checkpoint every this many epochs.
        visualize_interval: Call ``visualize_predictions`` every this many
            batches.

    Returns:
        The trained model.

    Raises:
        ValueError: If ``loss_func`` is not supported or the dataset yields
            no batches.
    """
    # Fail fast: an unknown loss name would otherwise only surface as an
    # unbound ``criterion`` after the model and data have been set up.
    supported_losses = ('ridge', 'ssim', 'mse')
    if loss_func not in supported_losses:
        raise ValueError(
            f"Unsupported loss_func {loss_func!r}; expected one of {supported_losses}."
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Create model with the correct output size
    model = UNet(n_channels=3, n_classes=3, bilinear=True, output_size=(449, 449)).to(device)

    if model_load_path != 'none':
        print(f'Loading Path: {model_load_path}')
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # host (and vice versa).
        model.load_state_dict(torch.load(model_load_path, map_location=device))
        print("Model weights loaded.")

    # DataLoader setup
    dataset = ResponseGTImageDataset(response_dir, gt_dir, transform=transform, target_transform=target_transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    if len(dataloader) == 0:
        raise ValueError(
            f"No training batches found (response_dir={response_dir!r}, gt_dir={gt_dir!r})."
        )
    
    # Check sizes from the first batch
    sample_batch = next(iter(dataloader))
    input_shape = sample_batch[0].shape
    target_shape = sample_batch[1].shape
    print(f"Input shape: {input_shape}, Target shape: {target_shape}")
    
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    if loss_func == 'ridge':
        criterion = RidgeLoss(alpha=0.5).to(device)
        print('Using Ridge Loss with alpha=0.5')
    elif loss_func == 'ssim':
        criterion = SSIMLoss(alpha=0.05).to(device)
        print('Using SSIM Loss with alpha=0.05')
    else:  # 'mse' (validated above)
        criterion = nn.MSELoss().to(device)
        print('Using MSE Loss')

    print(f'Start training with {loss_func} loss, lr = {learning_rate}, batch size = {batch_size}')

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for i, (responses, labels) in enumerate(dataloader):
            # Convert tensors to float
            responses = responses.float().to(device)
            labels = labels.float().to(device)
            
            optimizer.zero_grad()
            outputs = model(responses)
            
            # Check the sizes
            if i == 0 and epoch == 0:
                print(f"Response shape: {responses.shape}")
                print(f"Output shape: {outputs.shape}")
                print(f"Label shape: {labels.shape}")
            
            # Make sure outputs match the size of labels
            if outputs.shape != labels.shape:
                outputs = F.interpolate(outputs, size=(labels.shape[2], labels.shape[3]), 
                                        mode='bilinear', align_corners=True)
            
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            if i % visualize_interval == 0:
                visualize_predictions(responses, outputs, labels)

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Average Loss: {avg_loss:.4f}")

        if (epoch + 1) % save_interval == 0:
            # torch.save does not create missing directories.
            os.makedirs(model_save_path, exist_ok=True)
            # NOTE: the file name uses the zero-based epoch index plus
            # check_num; kept as-is so existing checkpoint names stay valid.
            save_path = os.path.join(model_save_path, f'UNet_epoch{epoch+check_num}_lr{learning_rate}_batch{batch_size}_{loss_func}.pth')
            torch.save(model.state_dict(), save_path)
            print(f"Model saved to {save_path}")

    return model


def setup(rank, world_size):
    """
    Setup distributed training

    Initializes the default NCCL process group for this process.

    Args:
        rank: Rank of the calling process within the group.
        world_size: Total number of processes in the group.

    The rendezvous address and port default to localhost:12355 but are only
    set when not already present in the environment, so a launcher (or a
    user running several jobs on one host) can override them through
    MASTER_ADDR / MASTER_PORT.
    """
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    
    # Initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    

def cleanup():
    """
    Clean up distributed training

    Destroys the default process group if one is initialized. It is a no-op
    otherwise, so it is safe to call from error-handling / ``finally`` paths
    and to call more than once.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


import matplotlib.pyplot as plt
import os

def visualize_new(low_res, outputs, high_res, index=0, save_dir="results"):
    """Save a side-by-side figure of input, prediction and ground truth.

    Args:
        low_res: Input tensor.
        outputs: Predicted tensor.
        high_res: Ground-truth tensor.
        index: Number embedded (zero-padded to three digits) in the file name.
        save_dir: Directory for the PNG; created if it does not exist.

    Returns:
        Path of the written ``prediction_<index>.png`` file.
    """
    def _to_numpy(tensor):
        # Bring CUDA tensors to the host, drop the autograd graph, convert.
        host = tensor.cpu() if tensor.is_cuda else tensor
        return host.detach().numpy()

    panels = (
        ('Input', _to_numpy(low_res)),
        ('Predicted', _to_numpy(outputs)),
        ('Ground Truth', _to_numpy(high_res)),
    )

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"prediction_{index:03d}.png")

    plt.figure(figsize=(12, 6))

    # The layout decision is driven by the input array alone: when it is
    # 4-D ([batch, channel, height, width]) every panel shows sample 0,
    # channel 0. Otherwise each array is handled on its own: the first slice
    # if it has more than two dims, or the array itself.
    batched = len(panels[0][1].shape) == 4
    for slot, (title, array) in enumerate(panels, start=1):
        if batched:
            image = array[0, 0]
        elif len(array.shape) > 2:
            image = array[0]
        else:
            image = array

        plt.subplot(1, 3, slot)
        plt.imshow(normalize_matrix(image), cmap='gray')
        plt.title(title)
        plt.axis('off')

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

    return save_path


def train_unet_distributed(rank, world_size, response_dir, gt_dir, num_epochs, learning_rate, 
                         batch_size, loss_func, model_load_path, model_save_path, 
                         check_num, save_interval=5, visualize_interval=1000):
    """Per-process entry point for multi-GPU (DistributedDataParallel) training.

    Meant to be started once per GPU (e.g. through ``mp.spawn``); ``rank`` is
    used both as the process rank and as the CUDA device index.

    Args:
        rank: Rank of this process; also the index of the GPU it uses.
        world_size: Total number of processes.
        response_dir: Directory passed to ResponseGTImageDataset for inputs.
        gt_dir: Directory passed to ResponseGTImageDataset for ground truth.
        num_epochs: Number of passes over the dataset.
        learning_rate: Adam learning rate.
        batch_size: Global batch size; each process uses
            ``batch_size // world_size`` samples (at least 1).
        loss_func: One of ``'ridge'``, ``'ssim'`` or ``'mse'``.
        model_load_path: Checkpoint (state dict) to resume from, or the
            string ``'none'``.
        model_save_path: Directory for checkpoints, metrics and
            visualizations (written by rank 0 only).
        check_num: Offset added to the zero-based epoch index in file names.
        save_interval: Save checkpoint/metrics/visualization every this many
            epochs.
        visualize_interval: Accepted for signature parity with
            ``train_unet_model``; not used by this function.

    Raises:
        ValueError: If ``loss_func`` is not supported or the dataset yields
            no batches.
    """
    import time

    # Validate before any distributed state exists, so a bad argument does
    # not leave a half-initialized process group behind.
    supported_losses = ('ridge', 'ssim', 'mse')
    if loss_func not in supported_losses:
        raise ValueError(
            f"Unsupported loss_func {loss_func!r}; expected one of {supported_losses}."
        )

    device = torch.device(f"cuda:{rank}")
    # Bind this process to its own GPU before the NCCL group is created, as
    # recommended by the DistributedDataParallel documentation.
    torch.cuda.set_device(device)

    # Setup the distributed environment
    setup(rank, world_size)

    try:
        # For tracking loss and timing
        loss_history = []
        epoch_times = []
        total_start_time = time.time()

        def average_epoch_time():
            # Guarded so that num_epochs == 0 does not divide by zero.
            return sum(epoch_times) / len(epoch_times) if epoch_times else 0

        # Create model and move it to GPU with id rank
        model = UNet(n_channels=3, n_classes=3, bilinear=True, output_size=(449, 449)).to(device)

        # Load model if specified
        if model_load_path != 'none':
            # Map every stored tensor onto this process's device regardless
            # of the device it was saved from.
            model.load_state_dict(torch.load(model_load_path, map_location=device))
            if rank == 0:
                print(f'Loading Path: {model_load_path}')
                print("Model weights loaded.")

        # Wrap model with DDP
        model = DDP(model, device_ids=[rank], output_device=rank)

        # Set up the dataset and distributed sampler
        dataset = ResponseGTImageDataset(response_dir, gt_dir, transform=transform, target_transform=target_transform)
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
        # A global batch smaller than world_size would otherwise give a
        # per-process batch size of 0, which DataLoader rejects.
        per_rank_batch = max(1, batch_size // world_size)
        dataloader = DataLoader(dataset, batch_size=per_rank_batch, sampler=sampler, num_workers=4, pin_memory=True)
        if len(dataloader) == 0:
            raise ValueError(
                f"No training batches found (response_dir={response_dir!r}, gt_dir={gt_dir!r})."
            )

        if rank == 0:
            # torch.save does not create missing directories; make sure the
            # final metrics can be written even if no periodic save ran.
            os.makedirs(model_save_path, exist_ok=True)

        # Set up optimizer and loss function
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        if loss_func == 'ridge':
            criterion = RidgeLoss(alpha=0.5).to(device)
            if rank == 0:
                print('Using Ridge Loss with alpha=0.5')
        elif loss_func == 'ssim':
            criterion = SSIMLoss(alpha=0.05).to(device)
            if rank == 0:
                print('Using SSIM Loss with alpha=0.05')
        else:  # 'mse' (validated above)
            criterion = nn.MSELoss().to(device)
            if rank == 0:
                print('Using MSE Loss')

        if rank == 0:
            print(f'Start training with {loss_func} loss, lr = {learning_rate}, batch size = {batch_size}')

        for epoch in range(num_epochs):
            # Set the epoch for the sampler so each epoch is shuffled differently
            sampler.set_epoch(epoch)

            model.train()
            total_loss = 0
            epoch_start_time = time.time()

            for i, (responses, labels) in enumerate(dataloader):
                # Move data to the correct device
                responses = responses.float().to(device)
                labels = labels.float().to(device)

                optimizer.zero_grad()
                outputs = model(responses)

                # Only print shapes on the first GPU during first iteration of first epoch
                if i == 0 and epoch == 0 and rank == 0:
                    print(f"Response shape: {responses.shape}")
                    print(f"Output shape: {outputs.shape}")
                    print(f"Label shape: {labels.shape}")

                # Make sure outputs match the size of labels
                if outputs.shape != labels.shape:
                    outputs = F.interpolate(outputs, size=(labels.shape[2], labels.shape[3]), 
                                          mode='bilinear', align_corners=True)

                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            # Calculate epoch time
            epoch_time = time.time() - epoch_start_time

            # Synchronize losses across GPUs: sum the per-process totals and
            # divide by the total number of batches across all processes.
            loss_tensor = torch.tensor(total_loss, device=device)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
            avg_loss = loss_tensor.item() / (len(dataloader) * world_size)

            if rank == 0:
                # Store loss and time
                loss_history.append(avg_loss)
                epoch_times.append(epoch_time)

                print(f"Epoch [{epoch + 1}/{num_epochs}], Average Loss: {avg_loss:.4f}, Time: {epoch_time:.2f}s")

                # Save model periodically (only on the first GPU)
                if (epoch + 1) % save_interval == 0:
                    # Visualize the last batch of this epoch and save the figure
                    vis_path = visualize_new(responses, outputs, labels, index=epoch+check_num, 
                                            save_dir=os.path.join(model_save_path, 'visualizations'))
                    print(f"Visualization saved to {vis_path}")

                    # Save the model without DDP wrapper so the state dict
                    # keys carry no "module." prefix.
                    save_path = os.path.join(model_save_path, f'UNet_epoch{epoch+check_num}_lr{learning_rate}_batch{batch_size}_{loss_func}.pth')
                    torch.save(model.module.state_dict(), save_path)
                    print(f"Model saved to {save_path}")

                    # Save loss history and timing
                    metrics_path = os.path.join(model_save_path, f'metrics_epoch{epoch+check_num}_lr{learning_rate}_batch{batch_size}_{loss_func}.pt')
                    total_time = time.time() - total_start_time
                    torch.save({
                        'loss_history': loss_history,
                        'epoch_times': epoch_times,
                        'total_time': total_time,
                        'average_epoch_time': average_epoch_time(),
                        'current_epoch': epoch + check_num
                    }, metrics_path)
                    print(f"Training metrics saved to {metrics_path}")

        # At the end of training, save the final metrics
        if rank == 0:
            total_time = time.time() - total_start_time
            metrics_path = os.path.join(model_save_path, f'metrics_final_lr{learning_rate}_batch{batch_size}_{loss_func}.pt')
            torch.save({
                'loss_history': loss_history,
                'epoch_times': epoch_times,
                'total_time': total_time,
                'average_epoch_time': average_epoch_time(),
                'total_epochs': num_epochs
            }, metrics_path)
            print(f"Final training metrics saved to {metrics_path}")
            print(f"Total training time: {total_time:.2f}s, Average epoch time: {average_epoch_time():.2f}s")
    finally:
        # Always release the process group, including when training fails.
        cleanup()


def main_train_unet(response_dir, gt_dir, num_epochs, learning_rate, batch_size, 
                   loss_func, model_load_path, model_save_path, check_num, 
                   save_interval=5, visualize_interval=1000):
    """Launch UNet training, choosing distributed or single-process mode.

    With more than one visible CUDA device, one ``train_unet_distributed``
    process is spawned per device. Otherwise training runs in this process
    through ``train_unet_model``, which uses the single GPU if present and
    the CPU if not.

    All arguments are forwarded unchanged to the selected training function.
    """
    # Number of GPUs available
    world_size = torch.cuda.device_count()
    print(f"Using {world_size} GPUs for training")
    
    if world_size > 1:
        # mp.spawn passes the process index as the first argument (rank);
        # the remaining arguments come from ``args``.
        mp.spawn(
            train_unet_distributed,
            args=(world_size, response_dir, gt_dir, num_epochs, learning_rate, 
                 batch_size, loss_func, model_load_path, model_save_path, 
                 check_num, save_interval, visualize_interval),
            nprocs=world_size,
            join=True
        )
        return

    # Single-process fallback; report what was actually detected.
    if world_size == 1:
        print("Only one GPU detected. Using single GPU training.")
    else:
        print("No GPU detected. Using single-process CPU training.")
    train_unet_model(response_dir, gt_dir, num_epochs, learning_rate, batch_size, 
                   loss_func, model_load_path, model_save_path, check_num, 
                   save_interval, visualize_interval)


# Script entry point. The guard matters here: processes started by mp.spawn
# import this module again, and must not re-launch training themselves.
if __name__ == '__main__':
    # Responses and ground truth are read from the same directory.
    dataset_root = '/home/share/0125_dataset/01092025 test slide 1/'
    response_file_path = dataset_root
    gt_file_path = dataset_root

    print('50, 80')

    run_config = {
        'response_dir': response_file_path,
        'gt_dir': gt_file_path,
        'num_epochs': 50,
        'learning_rate': 1e-4,
        'batch_size': 150,  # global batch size, split across GPUs in distributed mode
        'loss_func': 'ssim',
        'model_load_path': 'none',
        'model_save_path': '/home/dan5/optics_recon/Optics_Recon_Project/unet_param/',
        'check_num': 0,
        'save_interval': 5,
        'visualize_interval': 50,
    }
    main_train_unet(**run_config)