# Copyright 2023 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

from torchvision.models import resnet18
from torchvision.models import resnet50

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlockwithLoggingBeforeReLU(nn.Module):
    """Two-convolution residual block that can also report its pre-ReLU sum.

    The block consumes and produces a 2-tuple ``(feature_map, logged_value)``.
    The incoming ``logged_value`` is never read: each block emits its own
    value, which is ``None`` unless ``layer_index == 4`` and
    ``stride_index == 1``.

    Args:
        in_planes: Number of channels of the incoming feature map.
        planes: Number of channels produced by both 3x3 convolutions.
        stride: Stride of the first convolution and of the shortcut projection.
        layer_index: Identifier of the stage this block belongs to.
        stride_index: Position of this block inside its stage.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, layer_index=0, stride_index=-1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main branch: conv-bn-relu followed by conv-bn.
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # Shortcut branch: an empty Sequential (identity) unless the spatial
        # size or the channel count changes, in which case a 1x1 conv + bn
        # projects the input to the shape of the main branch.
        if stride != 1 or in_planes != out_planes:
            shortcut = [
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        else:
            shortcut = []
        self.downsample = nn.Sequential(*shortcut)

        self.layer_index = layer_index
        self.stride_index = stride_index

    def forward(self, bundled_x):
        """Return ``(relu(residual_sum), logged_value)`` for ``bundled_x``.

        ``logged_value`` is the residual sum before the final ReLU, average
        pooled with a 4x4 window and flattened per sample, but only for the
        block with ``layer_index == 4`` and ``stride_index == 1``; every other
        block returns ``None`` in that slot.
        """
        # The second tuple entry is intentionally dropped; see class docstring.
        x, _ = bundled_x

        hidden = F.relu(self.bn1(self.conv1(x)))
        pre_activation = self.bn2(self.conv2(hidden))
        pre_activation += self.downsample(x)

        logs_here = self.layer_index == 4 and self.stride_index == 1
        if logs_here:
            pooled = F.avg_pool2d(pre_activation, 4)
            out_before_relu = pooled.view(pooled.size(0), -1)
        else:
            out_before_relu = None

        return F.relu(pre_activation), out_before_relu

class ResNetfromScratch(nn.Module):
    """Four-stage residual backbone with a 3x3, stride-1 stem.

    The stages exchange ``(feature_map, logged_value)`` tuples, so ``block``
    must accept and return such a tuple and be constructible as
    ``block(in_planes, planes, stride, layer_index, stride_index)``.

    Args:
        block: Block class; must expose an ``expansion`` class attribute.
        num_blocks: Indexable with the number of blocks for stages 1 to 4.
    """

    def __init__(self, block, num_blocks):
        super().__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # (planes, stride of the first block) for layer1 .. layer4.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for stage_number, (planes, first_stride) in enumerate(stage_specs, start=1):
            stage = self._make_layer(
                block,
                planes,
                num_blocks[stage_number - 1],
                stride=first_stride,
                layer_index=stage_number,
            )
            setattr(self, f"layer{stage_number}", stage)

        self.num_features = 512

    def _make_layer(self, block, planes, num_blocks, stride, layer_index):
        """Build one stage and advance ``self.in_planes`` past it.

        Only the first block uses ``stride``; the remaining ones use stride 1.
        Because the first stride entry is added unconditionally, a stage holds
        at least one block even when ``num_blocks`` is 0.
        """
        per_block_strides = [stride] + [1] * (num_blocks - 1)

        blocks = []
        for position, block_stride in enumerate(per_block_strides):
            blocks.append(
                block(self.in_planes, planes, block_stride, layer_index, position)
            )
            # The next block consumes what this one produced.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return ``(features, out_before_relu)``.

        ``features`` is the last stage's output after 4x4 average pooling,
        flattened per sample; ``out_before_relu`` is whatever the last stage
        returned in the second tuple slot.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out_before_relu = None
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out, out_before_relu = stage((out, out_before_relu))

        pooled = F.avg_pool2d(out, 4)
        return pooled.view(pooled.size(0), -1), out_before_relu

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    Args:
        in_planes: Number of channels of the incoming feature map.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution and of the shortcut projection.
    """

    expansion: int = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut (empty Sequential) unless the resolution or the
        # channel count changes; then project with a 1x1 conv + batch norm.
        if stride != 1 or in_planes != out_planes:
            shortcut = [
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        else:
            shortcut = []
        self.downsample = nn.Sequential(*shortcut)

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        summed = self.bn2(self.conv2(hidden))
        summed += self.downsample(x)
        return F.relu(summed)

class ResNetwithFinalLinear(nn.Module):
    """Four-stage residual backbone with a 7x7 stride-2 stem and a linear head.

    The stem is conv(7x7, stride 2) -> batch norm -> ReLU -> 3x3 max pool with
    stride 2. After the four stages the features are average pooled with a 4x4
    window, flattened and passed through a 512 -> 512 linear layer.

    Args:
        block: Block class, constructible as ``block(in_planes, planes, stride)``
            and exposing an ``expansion`` class attribute.
        num_blocks: Indexable with the number of blocks for stages 1 to 4.
    """

    def __init__(self, block, num_blocks):
        super().__init__()
        self.inplanes = 64

        # An earlier revision used a 3x3, stride-1, padding-2 stem here:
        #   nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False)
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        # Swap in ``nn.Identity()`` to disable the extra projection.
        self.final_linear = nn.Linear(512, 512, bias=True)
        self.num_features = 512

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage and advance ``self.inplanes`` past it.

        Only the first block uses ``stride``; the remaining ones use stride 1.
        Because the first stride entry is added unconditionally, a stage holds
        at least one block even when ``num_blocks`` is 0.
        """
        per_block_strides = [stride] + [1] * (num_blocks - 1)

        blocks = []
        for block_stride in per_block_strides:
            blocks.append(block(self.inplanes, planes, block_stride))
            # The next block consumes what this one produced.
            self.inplanes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return the output of ``final_linear`` for the input batch ``x``."""
        feats = self.maxpool(F.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)

        # NOTE(review): the pooling window is fixed at 4x4 (stride defaults to
        # the kernel size), and ``final_linear`` needs exactly 512 values per
        # sample. A last-stage map larger than 4x4 but smaller than 8x8 is
        # reduced to 1x1 using only its top-left 4x4 window - confirm this is
        # intended for the input resolution in use.
        pooled = F.avg_pool2d(feats, 4)
        flat = pooled.view(pooled.size(0), -1)

        # Extra linear projection on top of the pooled features.
        return self.final_linear(flat)

def resnet18_from_scratch():
    """Return a ``ResNetfromScratch`` with two logging blocks in each of its four stages."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNetfromScratch(BasicBlockwithLoggingBeforeReLU, blocks_per_stage)

def resnet18_with_final_linear():
    """Return a ``ResNetwithFinalLinear`` with two ``BasicBlock`` per stage."""
    blocks_per_stage = [2, 2, 2, 2]
    model = ResNetwithFinalLinear(BasicBlock, blocks_per_stage)

    # Manual check against torchvision, kept for reference (not executed):
    #   torchvision_model = resnet18(weights=None)
    #   compare_pytorch_models_detailed(model, torchvision_model, check_forward_pass=True)
    return model


def _attribute_differences(name: str, module1: nn.Module, module2: nn.Module,
                           atol: float, rtol: float) -> list:
    """Describe every public, non-callable instance attribute that differs."""
    attrs1 = {k: v for k, v in module1.__dict__.items() if not k.startswith('_') and not callable(v)}
    attrs2 = {k: v for k, v in module2.__dict__.items() if not k.startswith('_') and not callable(v)}

    messages = []
    # Sorted so that the report order is reproducible between runs.
    for k in sorted(attrs1.keys() | attrs2.keys()):
        v1 = attrs1.get(k)
        v2 = attrs2.get(k)

        # A missing attribute and an attribute set to None are treated alike.
        if v1 is None and v2 is None:
            continue
        if v1 is None:
            messages.append(f"Attribute '{k}' differs for module '{name}': Model1 has None, Model2 has {v2}")
            continue
        if v2 is None:
            messages.append(f"Attribute '{k}' differs for module '{name}': Model1 has {v1}, Model2 has None")
            continue

        v1_is_tensor = isinstance(v1, torch.Tensor)
        v2_is_tensor = isinstance(v2, torch.Tensor)
        if v1_is_tensor and v2_is_tensor:
            # torch.allclose raises on incompatible shapes or dtypes, so those
            # cases are reported explicitly instead of aborting the comparison.
            if v1.shape != v2.shape or v1.dtype != v2.dtype:
                messages.append(
                    f"Attribute '{k}' differs for module '{name}': "
                    f"{tuple(v1.shape)} {v1.dtype} vs {tuple(v2.shape)} {v2.dtype} (Tensor diff)"
                )
            elif not torch.allclose(v1, v2, atol=atol, rtol=rtol):
                messages.append(f"Attribute '{k}' differs for module '{name}': {v1} vs {v2} (Tensor diff)")
        elif v1_is_tensor or v2_is_tensor:
            # ``tensor != other`` is element-wise and has no single truth
            # value, so a tensor on only one side is reported by type.
            messages.append(
                f"Attribute '{k}' differs for module '{name}': "
                f"{type(v1).__name__} vs {type(v2).__name__} (type mismatch)"
            )
        elif v1 != v2:
            messages.append(f"Attribute '{k}' differs for module '{name}': {v1} vs {v2}")
    return messages


def _state_differences(name: str, module1: nn.Module, module2: nn.Module,
                       atol: float, rtol: float) -> list:
    """Describe every directly owned parameter or buffer that differs."""
    # Buffers first, parameters second: a parameter wins if a name were ever
    # present in both, which is the lookup order the comparison relies on.
    state1 = dict(module1.named_buffers(recurse=False))
    state1.update(module1.named_parameters(recurse=False))
    state2 = dict(module2.named_buffers(recurse=False))
    state2.update(module2.named_parameters(recurse=False))

    messages = []
    # Sorted so that the report order is reproducible between runs.
    for key in sorted(state1.keys() | state2.keys()):
        val1 = state1.get(key)
        val2 = state2.get(key)

        if val1 is None:
            messages.append(f"Missing parameter/buffer '{key}' in module '{name}' (Model1)")
        elif val2 is None:
            messages.append(f"Missing parameter/buffer '{key}' in module '{name}' (Model2)")
        elif val1.shape != val2.shape:
            messages.append(f"Shape mismatch for parameter/buffer '{key}' in module '{name}': {val1.shape} vs {val2.shape}")
        elif val1.dtype != val2.dtype:
            # torch.allclose refuses mixed dtypes; report instead of raising.
            messages.append(f"Dtype mismatch for parameter/buffer '{key}' in module '{name}': {val1.dtype} vs {val2.dtype}")
        elif not torch.allclose(val1, val2, atol=atol, rtol=rtol):
            messages.append(f"Value mismatch for parameter/buffer '{key}' in module '{name}'. Max diff: {(val1 - val2).abs().max().item()}")
    return messages


def compare_pytorch_models_detailed(model1: nn.Module, model2: nn.Module,
                                     input_shape=(1, 3, 224, 224),
                                     atol=1e-6, rtol=1e-5,
                                     check_forward_pass=True):
    """Compare two ``nn.Module`` instances and report where they differ.

    Modules are matched by their ``named_modules()`` name and checked, in this
    order, for: presence in both models, identical Python type, equal public
    non-callable instance attributes (e.g. ``kernel_size``, ``stride``), and
    matching parameters/buffers (shape, dtype, then values). A module whose
    attributes differ is not checked for parameter/buffer differences.
    If no difference was found and ``check_forward_pass`` is true, both models
    are run on the same random input and their outputs compared.

    Args:
        model1: First model.
        model2: Second model.
        input_shape: Shape of the random input used for the forward pass.
        atol: Absolute tolerance passed to ``torch.allclose``.
        rtol: Relative tolerance passed to ``torch.allclose``.
        check_forward_pass: Whether to compare forward-pass outputs.

    Returns:
        A list of human-readable difference descriptions; empty when the
        models match. The same information is printed to stdout.

    Side effects:
        Both models are switched to eval mode and moved to CPU. When the
        forward pass runs and CUDA is available they are moved to CUDA.
        Neither the training mode nor the original device is restored.
    """
    different_modules = []

    # Eval mode keeps BatchNorm/Dropout deterministic during the comparison.
    model1.eval()
    model2.eval()

    # ``Module.cpu()`` moves the module in place and returns it.
    modules1 = dict(model1.cpu().named_modules())
    modules2 = dict(model2.cpu().named_modules())

    print("--- Detailed Module Comparison ---")
    for name in sorted(modules1.keys() | modules2.keys()):
        module1 = modules1.get(name)
        module2 = modules2.get(name)

        if module1 is None:
            different_modules.append(f"Missing module in model1: '{name}' (Type: {type(module2).__name__})")
            continue
        if module2 is None:
            different_modules.append(f"Missing module in model2: '{name}' (Type: {type(module1).__name__})")
            continue

        # Exact type identity on purpose: a subclass counts as a different module.
        if type(module1) is not type(module2):
            different_modules.append(f"Type mismatch for module '{name}': Model1 is {type(module1).__name__}, Model2 is {type(module2).__name__}")
            continue

        attribute_messages = _attribute_differences(name, module1, module2, atol, rtol)
        if attribute_messages:
            # Differently configured modules are not compared value by value.
            different_modules.extend(attribute_messages)
            continue

        different_modules.extend(_state_differences(name, module1, module2, atol, rtol))

    if check_forward_pass and not different_modules:
        print("\n--- Forward Pass Output Comparison ---")
        # Best effort: any failure here (e.g. non-tensor outputs, mismatched
        # output shapes) is recorded as a difference rather than propagated.
        try:
            dummy_input = torch.randn(input_shape)

            if torch.cuda.is_available():
                model1.cuda()
                model2.cuda()
                dummy_input = dummy_input.cuda()

            with torch.no_grad():
                output1 = model1(dummy_input)
                output2 = model2(dummy_input)

            if not torch.allclose(output1, output2, atol=atol, rtol=rtol):
                different_modules.append(f"Final output from forward pass differs. Max absolute difference: {(output1 - output2).abs().max().item()}")
            else:
                print("Final forward pass outputs are identical.")
        except Exception as e:
            different_modules.append(f"Error during forward pass comparison: {e}")
            print(f"Error during forward pass comparison: {e}")
    elif check_forward_pass:
        print("\nSkipping forward pass comparison due to structural/parameter differences already found.")

    if different_modules:
        print("\n--- Summary of Differences Found ---")
        for diff in different_modules:
            print(f"- {diff}")
        print(f"\nTotal differences found: {len(different_modules)}")
    else:
        print("\nNo differences found between the models (architecture, parameters, and optional forward pass).")

    return different_modules

# Public names of this module: the two torchvision constructors imported at
# the top of the file plus the two factory functions defined here.
__all__ = [
    "resnet18",
    "resnet50",
    "resnet18_from_scratch",
    "resnet18_with_final_linear",
]


