import torch.nn as nn
import torch
from train_cifar_classifier.wide_resnet_pytorch.networks.resnetv2 import ResNetv2
from train_cifar_classifier.wide_resnet_pytorch.networks.resnetv2 import ResNetv2_rej
import torch.nn.functional as F
from train_cifar_classifier.wide_resnet_pytorch.networks.wide_resnet import Wide_ResNet
import clip  # Ensure you have the official CLIP repository installed
import preprocess as pre

class MNIST_CNN(nn.Module):
    """Two-convolution CNN for single-channel 28x28 images.

    Args:
        n_agents: Number of output logits.
        dropout: Dropout probability applied just before the output layer.
    """

    def __init__(self, n_agents=1, dropout=0):
        super().__init__()
        self.n_agents = n_agents
        # Grayscale input -> 1 channel; padding=1 keeps the spatial size unchanged.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # The single 2x2 pool halves 28x28 to 14x14, giving 64 * 14 * 14 features.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(64 * 14 * 14, 128)
        self.fc2 = nn.Linear(128, self.n_agents)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map a batch of images to ``n_agents`` logits per sample."""
        feats = self.relu(self.conv1(x))
        feats = self.pool(self.relu(self.conv2(feats)))
        # Keep the batch dimension, flatten everything else.
        flat = torch.flatten(feats, 1)
        hidden = self.dropout(self.relu(self.fc1(flat)))
        return self.fc2(hidden)

class CIFAR100_CNN(nn.Module):
    """Two-convolution CNN for 3-channel 32x32 images (CIFAR/SVHN-sized).

    Args:
        n_agents: Number of output logits.
        dropout: Dropout probability applied just before the output layer.
    """

    def __init__(self, n_agents=1, dropout=0):
        super(CIFAR100_CNN, self).__init__()
        self.n_agents = n_agents
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)  # 3 input channels for RGB images
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # Both convs keep 32x32 (padding=1); the single 2x2 pool halves it to 16x16,
        # so the flattened feature map has 64 channels * 16 * 16 positions.
        # (Numerically equal to the previously written 256 * 8 * 8, which did not
        # reflect the actual tensor layout.)
        self.fc1 = nn.Linear(64 * 16 * 16, 128)
        self.fc2 = nn.Linear(128, self.n_agents)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map a batch of images to ``n_agents`` logits per sample."""
        x = self.relu(self.conv1(x))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)  # Flatten the feature map for fully connected layers
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

class CIFAR100_WIDERESNET(nn.Module):
    """torchvision ``wide_resnet50_2`` (via torch.hub) with its ``fc`` head replaced."""

    def __init__(self, n_agents, model_name='wideresnet28_10', pretrained=True):
        """
        Initializes the CIFAR100_WIDERESNET model.

        Args:
            n_agents (int): Number of outputs of the replacement final ``fc`` layer.
            model_name (str): Currently ignored; the architecture loaded is always
                ``wide_resnet50_2`` from ``pytorch/vision:v0.10.0``. Kept for
                backward compatibility with existing callers.
            pretrained (bool): Currently ignored; the backbone is always loaded with
                ``pretrained=False`` (randomly initialised weights).

        Note:
            ``torch.hub.load`` fetches/caches the ``pytorch/vision:v0.10.0`` repo,
            which may need network access the first time it runs.
        """
        super(CIFAR100_WIDERESNET, self).__init__()
        # Always the same backbone, regardless of model_name / pretrained.
        self.model = torch.hub.load('pytorch/vision:v0.10.0', 'wide_resnet50_2', pretrained=False)

        # Unconditionally replace the classifier so the output width is n_agents.
        in_features = self.model.fc.in_features
        self.n_agents = n_agents
        self.model.fc = nn.Linear(in_features, self.n_agents)

    def forward(self, x):
        """Return the wrapped network's ``n_agents`` outputs for input ``x``."""
        return self.model(x)




class CIFAR100_RESNET18(nn.Module):
    """torchvision ResNet-18 (via torch.hub, randomly initialised) with a new ``fc`` head."""

    def __init__(self, n_agents, dropout=0):
        """
        Args:
            n_agents (int): Number of outputs of the replacement ``fc`` layer.
            dropout (float): Currently unused by this class; accepted so the
                signature matches ``CIFAR100_RESNET18_rej``, which does apply it.
        """
        super(CIFAR100_RESNET18, self).__init__()
        # Load torchvision's ResNet-18 from torch.hub, without pretrained weights.
        self.model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)

        # Always replace the classifier so the output width is n_agents.
        in_features = self.model.fc.in_features
        self.n_agents = n_agents
        self.model.fc = nn.Linear(in_features, self.n_agents)

    def forward(self, x):
        """Return the wrapped ResNet-18's ``n_agents`` outputs for input ``x``."""
        return self.model(x)


class CIFAR100_RESNET18_rej(nn.Module):
    """ResNet-18 trunk (via torch.hub) + dropout + a freshly initialised linear head.

    ``self.features`` reuses every child of the hub model except its final
    ``fc``. ``self.model`` itself stays registered as a submodule, so its original
    ``fc`` layer still exists but is never called in ``forward``.
    """

    def __init__(self, n_agents, dropout=0):
        super().__init__()
        self.model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)
        head_width = self.model.fc.in_features
        self.n_agents = n_agents

        # Everything up to (but excluding) the last child, which is the hub model's fc.
        trunk = list(self.model.children())[:-1]
        self.features = nn.Sequential(*trunk)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(head_width, self.n_agents)

    def forward(self, x):
        """Trunk features -> flatten per sample -> dropout -> linear head."""
        pooled = self.features(x).flatten(1)
        return self.fc(self.dropout(pooled))

class LeNet(nn.Module):
    """Classic LeNet-5 style network for 3-channel 32x32 inputs.

    The ``16 * 5 * 5`` fc1 width follows from 32 -> conv5 -> 28 -> pool -> 14
    -> conv5 -> 10 -> pool -> 5.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return ``num_classes`` logits per sample."""
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        x = x.view(x.size(0), -1)
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return self.fc3(x)


class TinyCNN(nn.Module):
    """Minimal CNN: one 3x3 conv (8 filters), ReLU, 2x2 max-pool, one linear layer.

    The linear layer's input width ``8 * 16 * 16`` means the network expects
    3-channel 32x32 images (32 -> pool -> 16).
    """

    def __init__(self, num_classes=100):
        super(TinyCNN, self).__init__()
        # A single convolutional layer with a small number of filters
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)

        # A single fully connected layer
        self.fc1 = nn.Linear(8 * 16 * 16, num_classes)

    def forward(self, x):
        """Return ``num_classes`` logits per sample."""
        # Apply the convolution, followed by ReLU and a pooling layer
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)

        # Flatten per sample, keeping the batch dimension. A fixed
        # view(-1, 8*16*16) would silently re-split the batch for non-32x32
        # inputs; this way a size mismatch raises in fc1 instead.
        x = x.view(x.size(0), -1)

        # Output layer directly maps to the number of classes
        return self.fc1(x)

class linearRegression(nn.Module):
    """Single affine layer mapping ``num_features`` inputs to one output."""

    def __init__(self, num_features):
        super().__init__()
        self.linear = nn.Linear(num_features, 1)

    def forward(self, x):
        """Return a ``(..., 1)`` prediction for inputs ``x``."""
        prediction = self.linear(x)
        return prediction

class MLP(nn.Module):
    """One-hidden-layer perceptron (16 ReLU units).

    Args:
        num_classes: Output width.
        in_features: Input width (defaults to 8).
    """

    def __init__(self, num_classes, in_features=8):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 16)
        self.fc2 = nn.Linear(16, num_classes)

    def forward(self, x):
        """Return ``num_classes`` outputs per row of ``x``."""
        return self.fc2(F.relu(self.fc1(x)))

class CombinedModel(nn.Module):
    """Two independent heads built from the same MLP class over a shared input.

    Nothing is shared between the heads: the classifier and the regressor are
    two separately constructed modules that both read the same input tensor.
    """

    def __init__(
        self,
        mlp_cls,
        num_classes:    int,
        num_regressors: int,
        in_features: int,
    ):
        """
        Args:
            mlp_cls:        Class (or callable) invoked as
                            ``mlp_cls(num_outputs, in_features=in_features)``;
                            it must return an ``nn.Module``.
            num_classes:    Output width of the classifier head.
            num_regressors: Output width of the regressor head.
            in_features:    Input width passed to both heads.
        """
        super().__init__()
        # Two separate modules from the same class (no parameter sharing).
        self.classifier = mlp_cls(num_classes, in_features=in_features)
        self.regressor  = mlp_cls(num_regressors, in_features=in_features)

    def forward(self, x: torch.Tensor):
        """
        Args:
          x: Input accepted by both heads; for an MLP head this is presumably
             ``(N, in_features)``.
        Returns:
          logits: Output of the classifier head (``num_classes`` wide).
          preds:  Output of the regressor head (``num_regressors`` wide).
        """
        logits = self.classifier(x)
        preds  = self.regressor(x)
        return logits, preds


class MLP_wide(nn.Module):
    """Four-layer ReLU MLP with widths in_features -> 512 -> 256 -> 128 -> num_classes."""

    def __init__(self, num_classes, in_features=1):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return ``num_classes`` outputs per row of ``x``."""
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            x = F.relu(hidden_layer(x))
        return self.fc4(x)


class MLP_extra(nn.Module):
    """Deep MLP with BatchNorm + ReLU blocks (dropout 0.1 on the first three).

    Args:
        num_classes: Output width.
        in_features: Input width. Defaults to 8, the previously hard-coded value.
    """

    def __init__(self, num_classes, in_features=8):
        super(MLP_extra, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.1),

            nn.Linear(1024, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.1),

            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.1),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),

            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),

            nn.Linear(128, num_classes)  # num_classes outputs per sample
        )

    def forward(self, x):
        """Return ``num_classes`` outputs per row of ``x``.

        BatchNorm1d in training mode needs more than one sample per batch.
        """
        return self.net(x)

def denormalize(img, args):
    """
    Undo per-channel normalization of an image tensor: ``img * std + mean``.

    The per-channel statistics come from ``pre.images_std_mean(args.dataset)``.

    Args:
        img (Tensor): Normalized image tensor. A 4-D tensor is treated as a
            batch ``[B, C, H, W]``; any other rank is treated as a single image
            whose channel axis is the leading one (e.g. ``[C, H, W]``).
        args: Config object; only ``args.dataset`` is read.

    Returns:
        Tensor: Denormalized image tensor on ``img``'s device.
    """
    # NOTE(review): the helper is named "std_mean" but its result is unpacked
    # as (mean, std) — confirm the return order in preprocess.images_std_mean.
    mean, std = pre.images_std_mean(args.dataset)

    # Reshape the per-channel stats so they broadcast along the channel axis.
    stat_shape = (1, -1, 1, 1) if img.ndimension() == 4 else (-1, 1, 1)
    mean = torch.tensor(mean).view(stat_shape).to(img.device)
    std = torch.tensor(std).view(stat_shape).to(img.device)

    # Inverse of (x - mean) / std.
    return img * std + mean

class clip_architecture(nn.Module):
    def __init__(self, clip_model, num_classes=100, hidden_dim=512, args=None):
        """
        Initializes the architecture using the full CLIP model's image encoder as a feature extractor
        and a 2-layer feedforward network as the classifier.

        Args:
            clip_model: Ignored. The argument is immediately shadowed by a fresh
                ``clip.load("ViT-B/32")`` call (no ``device`` argument is passed, so
                clip.load's own default device is used).
            num_classes: Number of classes for classification (default is 100 for CIFAR-100).
            hidden_dim: The dimension of the hidden layer in the feedforward classifier.
            args: Config object stored as ``self.config`` and passed to ``denormalize``
                in ``forward`` (which reads ``args.dataset``).
        """
        super(clip_architecture, self).__init__()
        # Load ViT-B/32 and keep the whole CLIP model so forward can call encode_image.
        clip_model, preprocess = clip.load("ViT-B/32")
        self.clip_model = clip_model

        # Remove the transforms at indices 3 and then 2 (3 first, so index 2 still
        # refers to the same element). TODO(review): confirm what these are for the
        # installed CLIP version — in upstream openai/CLIP they appear to be
        # ToTensor and the RGB conversion, which would keep CLIP's Normalize in the
        # pipeline rather than "deleting normalization".
        del preprocess.transforms[3]
        del preprocess.transforms[2]
        self.preprocess = preprocess
        self.config = args

        # Define a two-layer feedforward classifier.
        self.classifier = nn.Sequential(
            nn.Linear(clip_model.visual.output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x):
        """
        Forward pass using the CLIP model's encode_image function for feature extraction.
        CLIP features are computed under ``torch.no_grad()``, so no gradients flow into
        the CLIP weights from this forward pass. Note that the CLIP parameters are not
        marked ``requires_grad=False`` here, and the CLIP model is not put in eval mode.

        Args:
            x: Input tensor (batch of images), normalized with the dataset statistics
               that ``denormalize`` undoes.

        Returns:
            logits: The output logits from the classifier.
        """
        # Undo the dataset normalization, then apply the (trimmed) CLIP preprocess.
        denorm_image = denormalize(x, self.config)
        x = self.preprocess(denorm_image)
        with torch.no_grad():
            features = self.clip_model.encode_image(x)
        # Cast to float32 for the classifier (encode_image may return another dtype).
        logits = self.classifier(features.float())
        return logits


class ResidualMLP(nn.Module):
    """Two-layer MLP with one residual connection, LayerNorm, and a linear classifier.

    forward: h = ReLU(fc1(x)); y = fc_out(ReLU(LayerNorm(fc2(h) + h))).
    """

    def __init__(self, in_dim: int, hidden_dim: int, num_classes: int):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)         # input projection
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)     # residual branch
        self.fc_out = nn.Linear(hidden_dim, num_classes) # classifier
        self.act = nn.ReLU(inplace=True)
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        """Return ``num_classes`` logits per row of ``x``."""
        projected = self.act(self.fc1(x))
        merged = self.norm(self.fc2(projected) + projected)
        return self.fc_out(self.act(merged))

import torch
import torch.nn as nn
import torch.nn.functional as F
import clip

# A single MLP-based residual block over a D-dimensional vector
class MLPResNetBlock(nn.Module):
    """Residual block over a ``dim``-wide vector: ReLU(LayerNorm(x + fc2(ReLU(fc1(x)))))."""

    def __init__(self, dim: int):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Apply the two-layer branch, add the skip connection, normalize, then ReLU."""
        branch = self.fc2(F.relu(self.fc1(x)))
        return F.relu(self.norm(branch + x))

# A "ResNet-style" MLP head: one input projection, several residual blocks, then output projection
class MLPResNetHead(nn.Module):
    """MLP classifier head: input projection -> ``num_blocks`` residual blocks -> output layer.

    Args:
        in_dim: Width of the incoming feature vector.
        hidden_dim: Width used inside the residual stack.
        num_classes: Number of output logits.
        num_blocks: How many ``MLPResNetBlock`` layers to stack.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        num_classes: int,
        num_blocks: int = 3
    ):
        super().__init__()
        self.input_layer = nn.Linear(in_dim, hidden_dim)
        self.blocks = nn.ModuleList(
            MLPResNetBlock(hidden_dim) for _ in range(num_blocks)
        )
        self.output_layer = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return ``num_classes`` logits per row of ``x``."""
        hidden = F.relu(self.input_layer(x))
        for residual_block in self.blocks:
            hidden = residual_block(hidden)
        return self.output_layer(hidden)

class clip_architecture2(nn.Module):
    """
    Uses CLIP's image encoder as a frozen feature extractor,
    plus a small ResNet-style MLP head for classification.
    """
    def __init__(
        self,
        clip_model,
        num_classes: int = 100,
        hidden_dim: int = 512,
        args = None,
        num_blocks: int = 4
    ):
        """
        Args:
            clip_model: Ignored; ``ViT-B/32`` is always (re)loaded with ``clip.load``.
            num_classes: Number of output logits.
            hidden_dim: Width of the residual MLP head.
            args: Config object. ``args.device`` (when present and not None) is
                forwarded to ``clip.load``; ``args`` is also handed to
                ``denormalize`` in ``forward``.
            num_blocks: Number of residual blocks in the head.
        """
        super().__init__()
        # Only forward a real device. Passing device=None explicitly would override
        # clip.load's own default device selection instead of falling back to it.
        load_kwargs = {}
        device = getattr(args, "device", None)
        if device is not None:
            load_kwargs["device"] = device
        clip_model, preprocess = clip.load("ViT-B/32", **load_kwargs)
        self.clip_model = clip_model
        # Freeze all CLIP weights
        for p in self.clip_model.parameters():
            p.requires_grad = False

        # Remove the transforms at indices 3 then 2 of CLIP's preprocess pipeline
        # (adjust indices if CLIP version changes).
        # TODO(review): confirm which transforms these are in the installed CLIP —
        # upstream they appear to be ToTensor and the RGB conversion.
        del preprocess.transforms[3]
        del preprocess.transforms[2]
        self.preprocess = preprocess

        self.config = args
        in_dim = clip_model.visual.output_dim
        # Create the ResNet-style classifier head
        self.classifier = MLPResNetHead(
            in_dim=in_dim,
            hidden_dim=hidden_dim,
            num_classes=num_classes,
            num_blocks=num_blocks
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Denormalize ``x``, extract frozen CLIP image features, and classify them."""
        # Ensure CLIP is in eval mode (the outer module's .train() would flip it back)
        self.clip_model.eval()
        # Undo the dataset normalization (denormalize reads self.config.dataset)
        denorm_image = denormalize(x, self.config)
        # Apply the trimmed CLIP preprocess pipeline
        x = self.preprocess(denorm_image)
        # Extract features without gradient
        with torch.no_grad():
            features = self.clip_model.encode_image(x)
        # Classify in float32 (encode_image may return a different dtype)
        logits = self.classifier(features.float())
        return logits


def model_choice(name, depth=4, args=None, n_agents=None, dropout=0):
    """
    Build the model called ``name`` for the dataset/task described by ``args``.

    Args:
        name (str): Architecture key ('cnn', 'resnet', 'resnet4_rej', 'lenet',
            'tiny', 'wideresnet', 'clip', 'clip2', 'linear', 'MLP', 'MLP_wide',
            'MLP_extra', 'rejector_regressor').
        depth (int): Depth for ``ResNetv2_rej`` when ``name == 'resnet'`` on the
            CIFAR-like datasets.
        args: Config object. ``args.dataset`` is always read. ``args.task`` is read
            when no image-model branch returned. ``args.num_classes`` is read for
            'rejector_regressor'.
        n_agents (int): Output width of the returned model.
        dropout (float): Dropout rate, forwarded only to architectures that take one.

    Returns:
        nn.Module: The constructed model.

    Raises:
        ValueError: If ``args`` is None, the dataset/task is unsupported, the model
            name is unknown, or the regression dataset has no known input width
            for a model that needs one.
    """
    if args is None:
        raise ValueError('model_choice requires args (reads args.dataset / args.task)')

    if args.dataset == 'mnist':
        if name == 'cnn':
            return MNIST_CNN(n_agents)
        if name == 'resnet4_rej':
            return ResNetv2_rej(4, n_agents, dropout=dropout)
    if args.dataset in ('cifar100', 'cifar10', 'svhn', 'cifar10H'):
        if name == 'cnn':
            return CIFAR100_CNN(n_agents)
        if name == 'resnet':
            return ResNetv2_rej(depth, n_agents, dropout=dropout)
        if name == 'lenet':
            return LeNet(n_agents)
        if name == 'tiny':
            return TinyCNN(n_agents)
        if name == 'wideresnet':
            return Wide_ResNet(28, 10, dropout, n_agents)
        if name == 'clip':
            return clip_architecture(clip_model=None, num_classes=n_agents, args=args)
        if name == 'clip2':
            return clip_architecture2(clip_model=None, num_classes=n_agents, args=args)

    if args.task != 'regression':
        raise ValueError('Dataset not supported')

    # Input width of the tabular regression datasets; None when unknown, so models
    # that do not need it ('linear', 'MLP_extra') can still be built.
    if args.dataset == 'california':
        in_features = 8
    elif args.dataset == 'ames':
        in_features = 261
    else:
        in_features = None

    if name == 'linear':
        # NOTE(review): n_agents is used as the *input* width here rather than
        # in_features — confirm this is intended.
        return linearRegression(num_features=n_agents)
    if name == 'MLP':
        # MLP defaults to 8 inputs, which matches 'california' but not 'ames';
        # pass the dataset's width whenever it is known.
        if in_features is None:
            return MLP(n_agents)
        return MLP(n_agents, in_features=in_features)
    if name == 'MLP_extra':
        return MLP_extra(n_agents)
    if name in ('MLP_wide', 'rejector_regressor'):
        if in_features is None:
            raise ValueError(
                f'Unknown input width for dataset {args.dataset!r}; cannot build {name!r}'
            )
        if name == 'MLP_wide':
            return MLP_wide(n_agents, in_features=in_features)
        return CombinedModel(MLP, num_classes=n_agents+1, num_regressors=args.num_classes, in_features=in_features)

    # Previously this fell through and silently returned None.
    raise ValueError(f'Unknown model name {name!r} for dataset {args.dataset!r}')
