import os
import torch
import torchvision
from PIL import Image
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F


def plot_images(images):
    """Display a batch of images side by side in one matplotlib figure.

    The batch is moved to the CPU, split along its first dimension, and the
    pieces are joined along their last dimension into one wide strip.

    Args:
        images: Tensor whose first dimension indexes the images. Each image
            must be 3-D, since the strip is permuted with ``(1, 2, 0)``
            before display (presumably channel-first input - confirm with
            callers).
    """
    # One 3-D tensor per image, all on the CPU.
    tiles = images.cpu().unbind(0)
    strip = torch.cat(tiles, dim=-1)

    plt.figure(figsize=(32, 32))
    # Move the leading dimension to the end before handing the strip to imshow.
    plt.imshow(strip.permute(1, 2, 0))
    plt.show()


def save_images(images, path, **kwargs):
    """Tile a batch of images into one grid and write it to ``path``.

    Args:
        images: Batch passed straight to ``torchvision.utils.make_grid``.
        path: Destination handed to ``PIL.Image.Image.save``.
        **kwargs: Extra keyword arguments forwarded to ``make_grid``.
    """
    tiled = torchvision.utils.make_grid(images, **kwargs)
    # The grid's first dimension is moved last before conversion to NumPy.
    # NOTE(review): Image.fromarray presumably needs an integer (uint8) array
    # here - confirm that callers pass uint8 images.
    pixels = tiled.permute(1, 2, 0).cpu().numpy()
    Image.fromarray(pixels).save(path)


def get_data(args):
    """Build a shuffled ``DataLoader`` over an image-folder dataset.

    Args:
        args: Object providing ``image_size``, ``dataset_path`` and
            ``batch_size`` attributes.

    Returns:
        A ``DataLoader`` over ``torchvision.datasets.ImageFolder`` rooted at
        ``args.dataset_path``, with ``shuffle=True``.
    """
    tv = torchvision.transforms
    half = (0.5,) * 3

    steps = [
        # NOTE(review): 80 is hard-coded; it equals image_size + image_size / 4
        # only when image_size is 64. Kept as-is to preserve behaviour.
        tv.Resize(80),
        tv.RandomResizedCrop(args.image_size, scale=(0.8, 1.0)),
        tv.ToTensor(),
        # Per-channel (x - 0.5) / 0.5.
        tv.Normalize(half, half),
    ]
    pipeline = tv.Compose(steps)

    image_folder = torchvision.datasets.ImageFolder(
        args.dataset_path, transform=pipeline
    )
    return DataLoader(image_folder, batch_size=args.batch_size, shuffle=True)


def setup_logging(run_name):
    """Create the ``models`` and ``results`` directories for a run.

    Both top-level directories and their ``run_name`` sub-directories are
    created relative to the current working directory. Directories that
    already exist are left untouched.

    Args:
        run_name: Name of the sub-directory created under each root.
    """
    roots = ("models", "results")
    # Roots first, then the per-run sub-directories, in that order.
    targets = [*roots, *(os.path.join(root, run_name) for root in roots)]
    for directory in targets:
        os.makedirs(directory, exist_ok=True)

class UNet(nn.Module):
    """Small convolutional regressor producing one scalar per input.

    Three unpadded 3x3 convolutions (each followed by ReLU) are averaged over
    the last two dimensions and passed through a single linear layer.

    Args:
        c_in: Number of input channels of the first convolution.
        device: Stored as ``self.device`` only; the module is not moved to
            this device by the constructor.
    """

    def __init__(self, c_in=2, device="cuda"):
        super().__init__()
        self.device = device

        # (in_channels, out_channels, bias) per convolution; only the last
        # convolution has a bias term.
        conv_specs = ((c_in, 16, False), (16, 32, False), (32, 64, True))
        stages = []
        for n_in, n_out, with_bias in conv_specs:
            stages.append(
                nn.Conv2d(n_in, n_out, kernel_size=3, padding=0, bias=with_bias)
            )
            stages.append(nn.ReLU())
        self.first_conv = nn.Sequential(*stages)

        self.d = nn.Linear(64, 1)

    def forward(self, x):
        """Return the flattened (1-D) output of the final linear layer."""
        features = self.first_conv(x)
        # Average over the last two dimensions, leaving 64 values per sample.
        pooled = features.mean(dim=(-2, -1))
        return self.d(pooled).view(-1)


class UNetBig(nn.Module):
    """Convolutional regressor with four conv layers and a two-layer head.

    Four unpadded 3x3 convolutions (each followed by ReLU) are averaged over
    the last two dimensions, then passed through ``Linear(64, 16)``, the
    selected activation, and ``Linear(16, 1)``.

    Args:
        c_in: Number of input channels of the first convolution.
        device: Stored as ``self.device`` only; the module is not moved to
            this device by the constructor.
        final_act: Activation between the two linear layers, either
            ``'relu'`` or ``'tanh'``.

    Raises:
        ValueError: If ``final_act`` is not one of the supported names.
    """

    def __init__(self, c_in=2, device="cuda", final_act='relu'):
        super().__init__()

        # Fail at construction for an unknown activation name. Previously
        # ``self.a1`` was silently left undefined and ``forward`` crashed
        # later with an AttributeError.
        activations = {'relu': nn.ReLU, 'tanh': nn.Tanh}
        if final_act not in activations:
            raise ValueError(
                f"Unsupported final_act {final_act!r}; "
                f"expected one of {sorted(activations)}"
            )

        self.device = device

        self.first_conv = nn.Sequential(
            nn.Conv2d(c_in, 16, kernel_size=3, padding=0, bias=False),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=0, bias=False),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=0),
            nn.ReLU(),
        )

        self.d1 = nn.Linear(64, 16)
        self.a1 = activations[final_act]()
        self.d2 = nn.Linear(16, 1)

    def forward(self, x):
        """Return the flattened (1-D) output of the final linear layer."""
        x = self.first_conv(x)
        # Average over the last two dimensions, leaving 64 values per sample.
        x = torch.mean(x, dim=(-2, -1))
        x = self.d1(x)
        x = self.a1(x)
        x = self.d2(x)
        return x.view(-1)
    


class TwoStreamIQA(nn.Module):
    """Two-stream regressor over a two-channel input.

    Index 0 along dimension 1 feeds the "spatial" stream and index 1 feeds
    the "informational" stream (assumes ``x`` is (batch, 2, H, W) - confirm
    with callers). Each stream is two 3x3 convolutions with ReLU, averaged
    over the last two dimensions to 64 values per sample. The concatenated
    128 values pass through two linear layers, are gated element-wise by a
    sigmoid attention branch, and are projected to a single score.
    """

    def __init__(self):
        super().__init__()

        plain = dict(kernel_size=3, stride=1, padding=1)
        # Same settings plus dilation=2; padding stays at 1 for these layers.
        dilated = dict(plain, dilation=2)

        # Stream 1: spatial map.
        self.spatial_conv1 = nn.Conv2d(1, 32, **plain)
        self.spatial_conv2 = nn.Conv2d(32, 64, **plain)

        # Stream 2: informational map (dilated convolutions).
        self.info_conv1 = nn.Conv2d(1, 32, **dilated)
        self.info_conv2 = nn.Conv2d(32, 64, **dilated)

        # Fusion of the two pooled 64-value stream outputs.
        self.fusion_fc1 = nn.Linear(64 * 2, 128)
        self.fusion_fc2 = nn.Linear(128, 64)

        # Gating branch: produces one sigmoid weight per fused feature.
        self.attention = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.Sigmoid(),
        )

        # Final scalar projection.
        self.output_fc = nn.Linear(64, 1)

    @staticmethod
    def _pool(feature_maps):
        """Average over the last two dimensions and reshape to (batch, -1)."""
        pooled = feature_maps.mean(dim=(-2, -1))
        return pooled.view(pooled.size(0), -1)

    def forward(self, x):
        """Return a 1-D tensor holding the flattened output scores."""
        # Select each map and restore a singleton channel dimension.
        spatial_map = x[:, 0, None]
        informational_map = x[:, 1, None]

        spatial = F.relu(self.spatial_conv1(spatial_map))
        spatial = self._pool(F.relu(self.spatial_conv2(spatial)))

        info = F.relu(self.info_conv1(informational_map))
        info = self._pool(F.relu(self.info_conv2(info)))

        fused = torch.cat((spatial, info), dim=1)
        fused = F.relu(self.fusion_fc1(fused))
        fused = F.relu(self.fusion_fc2(fused))

        # Element-wise gating of the fused features.
        gated = fused * self.attention(fused)

        return self.output_fc(gated).view(-1)
    


class TwoStreamIQABIG(nn.Module):
    """Wider two-stream regressor with three convolutions per stream.

    Index 0 along dimension 1 feeds the "spatial" stream and index 1 feeds
    the "informational" stream (assumes ``x`` is (batch, 2, H, W) - confirm
    with callers). Each stream ends with 128 channels that are averaged over
    the last two dimensions. The concatenated 256 values pass through two
    linear layers, are gated element-wise by a sigmoid attention branch, and
    are projected to a single score.
    """

    def __init__(self):
        super().__init__()

        plain = dict(kernel_size=3, stride=1, padding=1)
        # Same settings plus dilation=2; padding stays at 1 for these layers.
        dilated = dict(plain, dilation=2)

        # Stream 1: spatial map.
        self.spatial_conv1 = nn.Conv2d(1, 32, **plain)
        self.spatial_conv2 = nn.Conv2d(32, 64, **plain)
        self.spatial_conv3 = nn.Conv2d(64, 128, **plain)

        # Stream 2: informational map. Only the first two layers are dilated.
        self.info_conv1 = nn.Conv2d(1, 32, **dilated)
        self.info_conv2 = nn.Conv2d(32, 64, **dilated)
        self.info_conv3 = nn.Conv2d(64, 128, **plain)

        # Fusion of the two pooled 128-value stream outputs.
        self.fusion_fc1 = nn.Linear(128 * 2, 128)
        self.fusion_fc2 = nn.Linear(128, 128)

        # Gating branch: produces one sigmoid weight per fused feature.
        self.attention = nn.Sequential(
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Linear(32, 128),
            nn.Sigmoid(),
        )

        # Final scalar projection.
        self.output_fc = nn.Linear(128, 1)

    @staticmethod
    def _pool(feature_maps):
        """Average over the last two dimensions and reshape to (batch, -1)."""
        pooled = feature_maps.mean(dim=(-2, -1))
        return pooled.view(pooled.size(0), -1)

    def forward(self, x):
        """Return a 1-D tensor holding the flattened output scores."""
        # Select each map and restore a singleton channel dimension.
        spatial_map = x[:, 0, None]
        informational_map = x[:, 1, None]

        spatial = spatial_map
        for conv in (self.spatial_conv1, self.spatial_conv2, self.spatial_conv3):
            spatial = F.relu(conv(spatial))
        spatial = self._pool(spatial)

        info = informational_map
        for conv in (self.info_conv1, self.info_conv2, self.info_conv3):
            info = F.relu(conv(info))
        info = self._pool(info)

        fused = torch.cat((spatial, info), dim=1)
        fused = F.relu(self.fusion_fc1(fused))
        fused = F.relu(self.fusion_fc2(fused))

        # Element-wise gating of the fused features.
        gated = fused * self.attention(fused)

        return self.output_fc(gated).view(-1)
    


class TwoStreamIQABIG2(nn.Module):
    """Largest two-stream regressor: four convolutions per stream.

    Index 0 along dimension 1 of the batch feeds the "spatial" stream and
    index 1 feeds the "informational" stream (assumes the batch is
    (batch, 2, H, W) - confirm with callers). Each stream ends with 512
    channels that are averaged over the last two dimensions. The
    concatenated 1024 values pass through three linear layers, are gated
    element-wise by a sigmoid attention branch, and are projected to a
    single score.
    """

    def __init__(self):
        super().__init__()

        # Stream 1: Spatial Map Convolutional Layers
        self.spatial_conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.spatial_conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.spatial_conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.spatial_conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)

        # Stream 2: Informational Map Convolutional Layers (no dilation is
        # used in this variant; the layers mirror stream 1)
        self.info_conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.info_conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.info_conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.info_conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)

        # Fusion Layers: the input is the two pooled 512-value stream outputs
        self.fusion_fc1 = nn.Linear(512 * 2, 1024)
        self.fusion_fc2 = nn.Linear(1024, 512)
        self.fusion_fc3 = nn.Linear(512, 256)

        # Attention Layer: one sigmoid gate per fused feature
        self.attention = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.Sigmoid()
        )

        # Final Output Layer
        self.output_fc = nn.Linear(256, 1)

    def forward(self, x):
        """Return a 1-D tensor holding the flattened output scores.

        Args:
            x: Either a 4-D batch tensor, which is used directly (as in the
                sibling two-stream models), or any other indexable input
                whose element ``x[0]`` is the batch (for example a
                tuple/list, or a tensor with an extra leading dimension).
        """
        # The original code always took x[0], which made a plain 4-D batch
        # fail with a shape error. That form is kept for every input except
        # a 4-D tensor, so existing call patterns behave exactly as before.
        if not (isinstance(x, torch.Tensor) and x.dim() == 4):
            x = x[0]

        spatial_map, informational_map = x[:, 0].unsqueeze(1), x[:, 1].unsqueeze(1)

        # Stream 1: Spatial Map
        x1 = F.relu(self.spatial_conv1(spatial_map))
        x1 = F.relu(self.spatial_conv2(x1))
        x1 = F.relu(self.spatial_conv3(x1))
        x1 = F.relu(self.spatial_conv4(x1))
        # Average over the last two dimensions: 512 values per sample.
        x1 = torch.mean(x1, dim=(-2, -1))

        # Stream 2: Informational Map
        x2 = F.relu(self.info_conv1(informational_map))
        x2 = F.relu(self.info_conv2(x2))
        x2 = F.relu(self.info_conv3(x2))
        x2 = F.relu(self.info_conv4(x2))
        x2 = torch.mean(x2, dim=(-2, -1))

        # Flatten and Concatenate
        x1 = x1.view(x1.size(0), -1)
        x2 = x2.view(x2.size(0), -1)
        x = torch.cat((x1, x2), dim=1)

        # Fusion Layers
        x = F.relu(self.fusion_fc1(x))
        x = F.relu(self.fusion_fc2(x))
        x = F.relu(self.fusion_fc3(x))

        # Attention Layer: element-wise gating of the fused features
        attention_weights = self.attention(x)
        x = x * attention_weights

        # Output
        output = self.output_fc(x)
        return output.view(-1)
    