import torch
import torch.nn as nn
import torchvision
import math
from eos_line_search.experiment import *
from torchinfo import summary
import vit_pytorch as vit
from efficientnet_pytorch import EfficientNet
from collections import OrderedDict


def adapt_resnet_for_cifar(model, keep_maxpool=True):
    """
    Convert an ImageNet-style torchvision ResNet to a CIFAR-style stem in-place,
    but keep an option to retain the initial maxpool to avoid excessive activation
    memory growth on large batches with small images.

    Behavior:
      - Replace conv1 (7x7, stride=2) with a 3x3, stride=1, padding=1 conv that has
        the same in/out channel counts and the same "has bias" setting.
        * The new conv is moved to the device (and floating dtype) of the old
          conv's weight so the model stays internally consistent.
        * When the old kernel is at least 3x3, its centered 3x3 patch is copied
          into the new weight. The center is computed per axis, so non-square
          kernels are handled as well. If the shapes still do not match (e.g. a
          grouped conv), the default initialization is kept.
      - If keep_maxpool=True (default), keep model.maxpool (if present); a maxpool
        that was previously replaced by nn.Identity is restored as
        MaxPool2d(kernel_size=3, stride=2, padding=1).
      - If keep_maxpool=False, maxpool is replaced by nn.Identity (original
        "cifar stem"), giving larger HxW in later layers, which can increase memory.

    This function only modifies the stem and maxpool; it does not otherwise change
    layer channel counts or block structure.

    Args:
        model: module that may expose `conv1` and/or `maxpool` attributes.
        keep_maxpool: whether the stem max-pooling should be present afterwards.

    Returns:
        The same `model` object, modified in place.
    """
    # Replace conv1: 7x7 stride2 -> 3x3 stride=1 (CIFAR-style)
    if hasattr(model, "conv1"):
        old_conv = model.conv1
        new_conv1 = nn.Conv2d(
            in_channels=old_conv.in_channels,
            out_channels=old_conv.out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=old_conv.bias is not None,
        )

        old_weight = getattr(old_conv, "weight", None)
        if isinstance(old_weight, torch.Tensor):
            # Follow the original layer's placement; created on the default
            # device first so the RNG stream used for initialization is unchanged.
            placement = {"device": old_weight.device}
            if old_weight.is_floating_point():
                placement["dtype"] = old_weight.dtype
            new_conv1.to(**placement)

            # Best-effort preservation of initialization: copy the centered 3x3
            # patch. Explicit shape checks replace a blanket try/except.
            if old_weight.dim() == 4:
                k_h, k_w = old_weight.shape[2], old_weight.shape[3]
                if k_h >= 3 and k_w >= 3:
                    top = k_h // 2 - 1
                    left = k_w // 2 - 1
                    center3 = old_weight.detach()[
                        :, :, top : top + 3, left : left + 3
                    ]
                    if center3.shape == new_conv1.weight.shape:
                        with torch.no_grad():
                            new_conv1.weight.copy_(center3)

        model.conv1 = new_conv1

    # Keep or remove the initial maxpool according to keep_maxpool
    if hasattr(model, "maxpool"):
        if keep_maxpool:
            # Leave an existing maxpool as-is; if it was removed previously
            # (Identity), restore a standard maxpool.
            if isinstance(model.maxpool, nn.Identity):
                model.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        else:
            # remove maxpool (CIFAR stem)
            model.maxpool = nn.Identity()

    return model


def print_spatial_shapes(model, input_size=(1, 3, 32, 32), device="cpu"):
    """
    Run a single forward and print shapes of outputs of layer1..layer4 to help debug spatial dims.

    The forward pass runs in eval mode under `torch.no_grad()`. Afterwards the
    train/eval flag of every submodule is restored to what it was on entry, and
    the temporary forward hooks are removed even if the forward pass raises.

    Args:
        model: module to probe. Only submodules whose qualified name is exactly
            "layer1", "layer2", "layer3" or "layer4" are reported, and only when
            their output is a 4-D tensor.
        input_size: shape of the random input tensor.
        device: device the random input is created for. If the model has any
            parameters, the input is moved to the first parameter's device
            instead, so this only matters for parameter-free models.

    Returns:
        None. One line per recorded layer is printed to stdout.
    """
    # Remember per-module modes so this debug probe has no lasting side effect.
    previous_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    x = torch.randn(input_size).to(device)
    shapes = {}
    hooks = []

    def make_hook(name):
        def hook(m, inp, out):
            if isinstance(out, torch.Tensor) and out.dim() == 4:
                shapes[name] = tuple(out.shape)

        return hook

    try:
        # register hooks on top-level layer modules (if present)
        for nm, m in model.named_modules():
            if nm in ("layer1", "layer2", "layer3", "layer4"):
                hooks.append(m.register_forward_hook(make_hook(nm)))

        # Ensure model and input are on the same device. Whether parameters
        # exist (not whether they require grad) decides where the model lives.
        first_param = next(model.parameters(), None)
        device_of_model = first_param.device if first_param is not None else x.device

        with torch.no_grad():
            _ = model(x.to(device_of_model))
    finally:
        for h in hooks:
            h.remove()
        for module, was_training in previous_modes:
            module.training = was_training

    for k, v in shapes.items():
        print(f"{k} output shape: {v}")


# Linear Regression (use MSE Loss)
class LinearRegression(nn.Module):
    """Single affine layer applied to flattened inputs (meant for MSE loss).

    `run.dataset` must provide `input_dim`, `output_dim` and `label_avg`.
    The layer starts with all-zero weights and a bias equal to `label_avg`,
    so the initial prediction is `label_avg` for every input.
    """

    def __init__(self, run):
        super().__init__()
        self.flatten = nn.Flatten()
        in_features = run.dataset.input_dim
        out_features = run.dataset.output_dim
        self.model = nn.Sequential(nn.Linear(in_features, out_features, bias=True))
        self.initialize_weights(run)

    def forward(self, x):
        """Flatten `x` (keeping the batch dimension) and apply the linear layer."""
        return self.model(self.flatten(x))

    def initialize_weights(self, run):
        """Zero every Linear weight and set every Linear bias to `run.dataset.label_avg`."""
        linear_layers = [
            layer for layer in self.modules() if isinstance(layer, nn.Linear)
        ]
        for layer in linear_layers:
            nn.init.constant_(layer.weight, 0)
            if layer.bias is None:
                continue
            nn.init.constant_(layer.bias, run.dataset.label_avg)


# Logistic Regression (use BCEWithLogits Loss)
class LogisticRegression(nn.Module):
    """Single affine layer producing logits from flattened inputs.

    Meant to be trained with a BCE-with-logits loss. `run.dataset` must provide
    `input_dim`, `output_dim` and a scalar `label_avg` (mean label).
    Weights start at zero and the bias starts at the logit of `label_avg`, so the
    initial predicted probability equals the label mean.
    """

    def __init__(self, run):
        super().__init__()
        self.flatten = nn.Flatten()
        self.model = nn.Sequential(
            nn.Linear(
                run.dataset.input_dim,
                run.dataset.output_dim,
                bias=True,
            )
        )
        self.initialize_weights(run)

    def forward(self, x):
        """Flatten `x` (keeping the batch dimension) and return the logits."""
        x = self.flatten(x)
        predictions = self.model(x)
        return predictions

    def initialize_weights(self, run):
        """Zero the weights and set the bias to logit(`run.dataset.label_avg`).

        The logit is infinite at the two extremes, so finite stand-ins are used:
        +1 when `label_avg == 1` and, symmetrically, -1 when `label_avg == 0`
        (the latter used to fail with a math domain error).

        Raises:
            ValueError: if `label_avg` lies outside [0, 1] (from `math.log`).
        """
        label_avg = run.dataset.label_avg
        if label_avg == 1:
            bias_value = 1
        elif label_avg == 0:
            # Mirror of the label_avg == 1 case; log(0) is undefined.
            bias_value = -1
        else:
            bias_value = math.log(label_avg / (1 - label_avg))

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.constant_(m.weight, 0)
                if m.bias is not None:
                    nn.init.constant_(m.bias, bias_value)


# MLP
class MLP(nn.Module):
    """Fully connected network whose depth and width come from `run.model`.

    Reads `run.dataset.input_dim`, `run.dataset.output_dim`,
    `run.model.num_layers` and, when more than one layer is requested,
    `run.model.width` and `run.model.activation_fn` (a zero-argument factory).

    Submodules of `self.model` are named "flatten", "layer<k>" and
    "activation<k>" (k starts at 1). Every Linear layer except the last one is
    followed by an activation; a one-layer network maps input to output directly.
    """

    def __init__(self, run):
        super().__init__()
        stages = OrderedDict()
        stages["flatten"] = nn.Flatten()

        depth = run.model.num_layers
        if depth == 1:
            # Single layer: no hidden width and no activation involved.
            stages["layer1"] = nn.Linear(
                run.dataset.input_dim,
                run.dataset.output_dim,
                bias=True,
            )
        else:
            final_index = depth - 1
            for index in range(depth):
                position = index + 1
                fan_in = run.dataset.input_dim if index == 0 else run.model.width
                fan_out = (
                    run.dataset.output_dim
                    if index == final_index
                    else run.model.width
                )
                stages[f"layer{position}"] = nn.Linear(fan_in, fan_out, bias=True)
                if index != final_index:
                    stages[f"activation{position}"] = run.model.activation_fn()

        self.model = nn.Sequential(stages)

    def forward(self, x):
        """Return the network output for `x` (flattening is part of `self.model`)."""
        return self.model(x)

    def initialize_weights(self, initialization="default"):
        """Re-initialize the weight of every Linear layer with the named scheme.

        Recognized names: "xavier_normal", "xavier_uniform" (both with gain 1),
        "kaiming_normal", "kaiming_uniform" (both with mode "fan_in"). Any other
        value, including the default, leaves the weights untouched. Biases are
        never modified.
        """
        schemes = (
            ("xavier_normal", nn.init.xavier_normal_, {"gain": 1}),
            ("xavier_uniform", nn.init.xavier_uniform_, {"gain": 1}),
            ("kaiming_normal", nn.init.kaiming_normal_, {"mode": "fan_in"}),
            ("kaiming_uniform", nn.init.kaiming_uniform_, {"mode": "fan_in"}),
        )
        selected = [
            (init_fn, options)
            for scheme_name, init_fn, options in schemes
            if initialization == scheme_name
        ]
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            for init_fn, options in selected:
                init_fn(module.weight, **options)


# CNN
class CNN(nn.Module):
    """Two conv/activation/pool stages followed by a linear classifier.

    Reads `run.dataset.image_colour_channels`, `image_height`, `image_width`,
    `output_dim` and `run.model.activation_fn` (zero-argument factory),
    `run.model.pooling` (factory called with `run.model.window_size`).

    The classifier's input size is measured by passing a zero tensor of the
    configured image size through the feature layers, instead of assuming that
    the two poolings shrink each side by exactly 4. For the usual window size 2
    this yields the same value as `32 * int(H / 4) * int(W / 4)`; for other
    window sizes or pooling types it yields the size the layers really produce.
    The layer order (and therefore the `state_dict` keys) is unchanged.
    """

    def __init__(self, run):
        super().__init__()
        channels = run.dataset.image_colour_channels
        feature_layers = [
            nn.Conv2d(
                channels,
                32,
                bias=True,
                kernel_size=3,
                padding=1,
            ),
            run.model.activation_fn(),
            run.model.pooling(run.model.window_size),
            nn.Conv2d(
                32,
                32,
                bias=True,
                kernel_size=3,
                padding=1,
            ),
            run.model.activation_fn(),
            run.model.pooling(run.model.window_size),
            nn.Flatten(),
        ]
        flat_features = self._infer_flat_features(
            feature_layers,
            channels,
            run.dataset.image_height,
            run.dataset.image_width,
        )
        self.model = nn.Sequential(
            *feature_layers,
            nn.Linear(
                flat_features,
                run.dataset.output_dim,
                bias=True,
            ),
        )

    @staticmethod
    def _infer_flat_features(layers, channels, height, width):
        """Return the per-sample feature count `layers` produce for one image."""
        # Zeros (not random) so the global RNG stream is left untouched.
        probe = torch.zeros(1, channels, int(height), int(width))
        with torch.no_grad():
            for layer in layers:
                probe = layer(probe)
        return probe.shape[1]

    def forward(self, x):
        """Return class scores for a batch of images `x`."""
        predictions = self.model(x)
        return predictions


# Deeper MLP
class DeepMLP(nn.Module):
    """Seven-layer ReLU MLP with fixed hidden widths 128-128-64-64-32-32.

    Input and output sizes come from `run.dataset.input_dim` and
    `run.dataset.output_dim`. The layers keep PyTorch's default initialization;
    `initialize_weights` is not invoked by the constructor.
    """

    def __init__(self, run):
        super().__init__()
        hidden_widths = (128, 128, 64, 64, 32, 32)
        stack = [nn.Flatten()]
        fan_in = run.dataset.input_dim
        for fan_out in hidden_widths:
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
            fan_in = fan_out
        # Output layer: no activation after it.
        stack.append(nn.Linear(fan_in, run.dataset.output_dim))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return the network output for `x` (flattening is part of `self.model`)."""
        return self.model(x)

    def initialize_weights(self, initialization="default"):
        """Re-initialize the weight of every Linear layer with the named scheme.

        Recognized names: "xavier_normal", "xavier_uniform" (both with gain 1),
        "kaiming_normal", "kaiming_uniform" (both with mode "fan_in"). Any other
        value, including the default, leaves the weights untouched. Biases are
        never modified.
        """
        schemes = (
            ("xavier_normal", nn.init.xavier_normal_, {"gain": 1}),
            ("xavier_uniform", nn.init.xavier_uniform_, {"gain": 1}),
            ("kaiming_normal", nn.init.kaiming_normal_, {"mode": "fan_in"}),
            ("kaiming_uniform", nn.init.kaiming_uniform_, {"mode": "fan_in"}),
        )
        selected = [
            (init_fn, options)
            for scheme_name, init_fn, options in schemes
            if initialization == scheme_name
        ]
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            for init_fn, options in selected:
                init_fn(module.weight, **options)


def select_model(run, device):
    """Build the model named by `run.model.model_type` and move it to `device`.

    After construction the function prints a torchinfo summary and the number
    of visible CUDA devices, wraps the model in `nn.DataParallel` when more than
    one CUDA device is visible, and moves it to `device`.

    Args:
        run: experiment configuration. `run.model.model_type` selects the
            architecture; the individual branches additionally read
            `run.dataset.*` (e.g. `output_dim`, `image_colour_channels`,
            `image_height`, `name`) and `run.model.*` fields.
        device: target passed to `model.to(...)`.

    Returns:
        The constructed model (possibly wrapped in `nn.DataParallel`).

    Raises:
        ValueError: if `model_type` is not recognized, or if `model_type` is
            "tinyVIT" and `run.dataset.name` has no ViT configuration (this
            case previously surfaced later as an unbound-variable error).
    """
    model_type = run.model.model_type
    if model_type == "linear_regression":
        model = LinearRegression(run)
    elif model_type == "logistic_regression":
        model = LogisticRegression(run)
    elif model_type == "MLP":
        model = MLP(run)
    elif model_type == "CNN":
        model = CNN(run)
    elif model_type == "deepMLP":
        model = DeepMLP(run)
    elif model_type == "vgg11":
        model = torchvision.models.vgg11(
            num_classes=run.dataset.output_dim,
            dropout=0.0,
        )
        if run.dataset.image_colour_channels == 1:
            # Single-channel input: swap the first conv and the first pooling layer.
            model.features[0] = torch.nn.Conv2d(1, 64, kernel_size=3, padding=1)
            model.features[2] = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=1)
    elif model_type == "resnet18":
        # norm_layer=nn.Identity: normalization layers become no-ops.
        model = torchvision.models.resnet18(
            num_classes=run.dataset.output_dim, norm_layer=nn.Identity
        )
    elif model_type == "resnet34":
        model = torchvision.models.resnet34(
            num_classes=run.dataset.output_dim, norm_layer=nn.Identity
        )
        if run.dataset.image_colour_channels == 1:
            model.conv1 = torch.nn.Conv2d(
                1, 64, kernel_size=7, stride=2, padding=3, bias=False
            )
        if run.model.bias is False:
            # Rebuild the classifier head without a bias term.
            in_features = model.fc.in_features
            out_features = model.fc.out_features
            model.fc = torch.nn.Linear(in_features, out_features, bias=False)

    elif model_type == "resnet34-leakyrelu":
        model = torchvision.models.resnet34(
            num_classes=run.dataset.output_dim, norm_layer=nn.Identity
        )
        replace_activation(
            model, nn.ReLU, nn.LeakyReLU, negative_slope=0.01, inplace=False
        )
        model = adapt_resnet_for_cifar(
            model
        )  # now conv1 is 3x3 stride1, maxpool maintained
        print_spatial_shapes(model, input_size=(1, 3, 32, 32))

    elif model_type == "resnet34-init":
        model = torchvision.models.resnet34(
            num_classes=run.dataset.output_dim, norm_layer=nn.Identity
        )
        scale_biases(model, 20)

    elif model_type == "densenet121":
        model = torchvision.models.densenet121(
            num_classes=run.dataset.output_dim, drop_rate=0.0
        )
        if run.dataset.image_colour_channels == 1:
            model.features.conv0 = torch.nn.Conv2d(
                1, 64, kernel_size=3, stride=2, padding=1, bias=False
            )
            model.features.pool0 = torch.nn.MaxPool2d(
                kernel_size=5, stride=1, padding=0
            )
    elif model_type == "wide_resnet50_2":
        model = torchvision.models.resnet.wide_resnet50_2(
            num_classes=run.dataset.output_dim, norm_layer=nn.Identity
        )
        if run.dataset.image_colour_channels == 1:
            model.conv1 = torch.nn.Conv2d(
                1, 64, kernel_size=7, stride=2, padding=3, bias=False
            )
    elif model_type == "convnext_tiny":
        model = torchvision.models.convnext_tiny(num_classes=run.dataset.output_dim)
    elif model_type == "mobilenetV2":
        model = torchvision.models.mobilenetv2.MobileNetV2(
            num_classes=run.dataset.output_dim
        )
    elif model_type == "swin_t":
        model = torchvision.models.swin_transformer.swin_t(
            num_classes=run.dataset.output_dim
        )
    elif model_type == "maxvit_t":
        model = torchvision.models.maxvit.maxvit_t(
            num_classes=run.dataset.output_dim,
        )
        # Get the original first conv layer parameters
        original_conv = model.stem[0][0]  # First conv layer in stem

        # Create new first conv layer with stride=1 instead of stride=2
        new_first_conv = nn.Conv2d(
            in_channels=3,
            out_channels=original_conv.out_channels,
            kernel_size=original_conv.kernel_size,
            stride=1,  # Changed from 2 to 1
            padding=original_conv.padding,
            bias=original_conv.bias is not None,
        )
        # Replace the first conv layer
        model.stem[0][0] = new_first_conv
    elif model_type == "efficientnet":
        model = EfficientNet.from_name(
            "efficientnet-b0",
            num_classes=run.dataset.output_dim,
            in_channels=run.dataset.image_colour_channels,
        )
    elif model_type == "efficientnet_v2_s":
        model = torchvision.models.efficientnet.efficientnet_v2_s(
            num_classes=run.dataset.output_dim,
        )
    elif model_type == "tinyVIT":
        # Per-dataset ViT hyperparameters; image size and class count always
        # come from the dataset itself.
        dataset_name = run.dataset.name
        if dataset_name in ("CIFAR10", "CIFAR100", "SVHN"):
            vit_config = {
                "patch_size": 8,
                "dim": 256,
                "depth": 4,
                "heads": 8,
                "mlp_dim": 512,
            }
        elif dataset_name in ("MNIST", "EMNIST"):
            vit_config = {
                "patch_size": 7,
                "dim": 256,
                "depth": 4,
                "heads": 8,
                "mlp_dim": 512,
            }
        elif dataset_name == "Imagenet":
            vit_config = {
                "patch_size": 32,
                "dim": 192,
                "depth": 12,
                "heads": 3,
                "mlp_dim": 768,
            }
        elif dataset_name == "Imagenette":
            print("image size", run.dataset.image_height)
            vit_config = {
                "patch_size": 20,
                "dim": 256,
                "depth": 4,
                "heads": 8,
                "mlp_dim": 512,
            }
        else:
            # Fail here with a clear message instead of reaching summary(model)
            # with `model` unbound.
            raise ValueError(
                f"tinyVIT has no configuration for dataset {dataset_name!r}"
            )
        model = vit.SimpleViT(
            image_size=run.dataset.image_height,
            num_classes=run.dataset.output_dim,
            **vit_config,
        )
    else:
        raise ValueError("Not a valid model")
    summary(model)
    print("How many GPU's? ", torch.cuda.device_count(), flush=True)
    if torch.cuda.device_count() > 1:
        print("Let's use ", torch.cuda.device_count(), " GPUs!", flush=True)
        model = nn.DataParallel(model)
    model.to(device)
    return model


def replace_activation(
    model,
    orig_activation_fn=nn.ReLU,
    new_activation_fn_class=nn.LeakyReLU,
    **new_activation_kwargs,
):
    """Recursively swap activation modules inside `model`, in place.

    Every descendant that is an instance of `orig_activation_fn` is replaced by
    a freshly constructed `new_activation_fn_class(**new_activation_kwargs)`;
    each replacement is announced on stdout. Modules that do not match are
    searched recursively. `model` itself is never replaced.

    Returns:
        The number of modules that were replaced.
    """
    swapped = 0
    for child_name, child in list(model.named_children()):
        if not isinstance(child, orig_activation_fn):
            # Not a target: look for targets further down this branch.
            swapped += replace_activation(
                child,
                orig_activation_fn,
                new_activation_fn_class,
                **new_activation_kwargs,
            )
            continue

        # A separate instance per site, so replacements never share state.
        replacement = new_activation_fn_class(**new_activation_kwargs)
        print(f"Replacing {child_name}: {child} with {replacement}")
        setattr(model, child_name, replacement)
        swapped += 1
    return swapped


def replace_max_with_avg_pooling(model):
    """Recursively replace every `nn.MaxPool2d` inside `model` by `nn.AvgPool2d`, in place.

    The replacement keeps the original layer's `kernel_size`, `stride`,
    `padding` and `ceil_mode`, so the output spatial size is unchanged.
    `AvgPool2d` has no `dilation` or `return_indices` option, so those settings
    of the original layer cannot be carried over.

    Returns:
        None; `model` is modified in place.
    """
    for name, module in model.named_children():
        if isinstance(module, nn.MaxPool2d):
            avg_pool_layer = nn.AvgPool2d(
                kernel_size=module.kernel_size,
                stride=module.stride,
                padding=module.padding,
                # Without this, pools built with ceil_mode=True would shrink
                # their output after the swap.
                ceil_mode=module.ceil_mode,
            )
            setattr(model, name, avg_pool_layer)
        else:
            replace_max_with_avg_pooling(module)


def scale_parameters(model, scale_factor=10.0):
    """Multiply every parameter of `model` by `scale_factor`, in place.

    Each parameter tensor is scaled exactly once. (Walking
    `module.parameters()` from inside `model.apply` would visit a nested
    parameter once per ancestor module and scale it repeatedly.)

    After scaling, the mean and standard deviation of the `weight` of every
    module that has one are printed for verification.

    Args:
        model: module whose parameters are scaled.
        scale_factor: multiplier applied to every parameter.

    Returns:
        The same `model` object.
    """
    scaled_tensors = 0
    with torch.no_grad():
        # model.parameters() is already recursive and yields shared
        # parameters only once.
        for param in model.parameters():
            param.mul_(scale_factor)
            scaled_tensors += 1

    # Print verification
    print("\nParameter Scaling Statistics:")
    print(f"Scale factor: {scale_factor}")
    print(f"Total parameters scaled: {scaled_tensors}")
    print("\nVerification of random samples:")

    # Report weight statistics for every module that owns a weight
    for name, module in model.named_modules():
        if hasattr(module, "weight") and module.weight is not None:
            param_mean = module.weight.data.mean().item()
            param_std = module.weight.data.std().item()
            print(f"\nLayer: {name}")
            print(f"Weight shape: {list(module.weight.shape)}")
            print(f"Mean: {param_mean:.6f}")
            print(f"Std: {param_std:.6f}")

    return model


def scale_biases(model, scale_factor=10.0):
    """
    Multiply only bias parameters in the model by `scale_factor`, in place.

    Scales any `module.bias` that is a tensor (Linear, Conv2d, BatchNorm, etc).
    Modules whose `bias` attribute is missing, None, or not a tensor are
    skipped; for example recurrent layers keep a boolean flag in `bias`, and
    their actual bias tensors live under other attribute names and are not
    touched. Statistics of every scaled bias are printed.

    Args:
        model: module whose biases are scaled.
        scale_factor: multiplier applied to each bias tensor.

    Returns:
        The same `model` object.
    """
    total_scaled = 0
    for name, module in model.named_modules():
        bias = getattr(module, "bias", None)
        if not isinstance(bias, torch.Tensor):
            continue
        # in-place scaling
        with torch.no_grad():
            bias.mul_(scale_factor)
        total_scaled += bias.numel()
        param_mean = bias.detach().mean().item()
        param_std = bias.detach().std().item()
        print(f"\nLayer: {name}")
        print(f"Bias shape: {list(bias.shape)}")
        print(f"Mean: {param_mean:.6f}")
        print(f"Std: {param_std:.6f}")
    print(f"Scaled {total_scaled} bias parameters")
    return model
