from get_data import *
from get_data_MNIST2Fashion import *
from get_data_MNIST2Fashion1class import get_MNIST2MNISTFashion1class_data, get_MNIST2MNISTFashion1sample_data
from get_data_MNIST2halfMNISThalfFashion import get_halfMNISThalfFashion, get_halfMNISThalfFashion_1over4, get_halfMNISThalfFashion_1over16, get_halfMNISThalfFashion_veritcal

import random 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
from utils import *
from enum import Enum, auto
import torch
from torch import Tensor
from torch.nn.utils import parametrize
from torch.nn.modules import Module
from torch.nn import functional as F
from typing import Optional
import argparse
import os


# Dataset names handled by the loader-selection chain further down. Any other
# value would leave train_loader/test_loader undefined, so it is rejected at
# parse time (before the save directory is created or any data is loaded).
DATASET_CHOICES = (
    'MNISTmakeup',
    'MNIST2MNISTFashion',
    'MNIST2MNISTFashion1class',
    'halfMNISThalfFashion',
    'halfMNISThalfFashion_1over4',
    'halfMNISThalfFashion_1over16',
    'halfMNISThalfFashion_veritcal',
)

parser = argparse.ArgumentParser(description='Train model on MNIST convert')
# Example save_dir values: MNISTmakeup_dipdnn, MNIST2MNISTFashion_dipdnn,
# MNIST2MNISTFashion1class_dipdnn
parser.add_argument(
    '--save_dir',
    default="./MNISTmakeup_dipdnn",
    type=str,
    help='directory to save results',
)
parser.add_argument(
    '--dataset',
    default='MNISTmakeup',
    type=str,
    choices=DATASET_CHOICES,  # argparse lists these in --help and rejects typos
    help='dataset',
)

args = parser.parse_args()



def set_seed(seed):
    """Seed all random number generators used by this script.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and the current CUDA
    device) with the same value, and switches cuDNN to deterministic mode
    with benchmarking turned off.

    Args:
        seed: Integer seed passed unchanged to every generator.
    """
    # cuDNN: deterministic kernels, no auto-tuning of algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    seeders = (
        random.seed,             # Python random module
        np.random.seed,          # NumPy random module
        torch.manual_seed,       # PyTorch CPU
        torch.cuda.manual_seed,  # PyTorch GPU
    )
    for seed_fn in seeders:
        seed_fn(seed)

# Set the seed before initializing the model or any random operations
set_seed(42)  # You can use any seed value of your choice

# Detect device: CUDA for Ubuntu GPU, MPS for MacBook with M1/M2/M3, or CPU fallback
if torch.cuda.is_available():
    device = torch.device('cuda')
    print("Using CUDA (GPU) for training.")
elif torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using Metal (MPS) for training on MacBook Pro with M1/M2/M3.")
else:
    device = torch.device('cpu')
    print("Using CPU for training.")
# To force CPU execution, override here: device = torch.device('cpu')

# Create the output directory. makedirs also creates missing parent
# directories and, with exist_ok=True, does not fail when the directory
# already exists (no separate isdir check that could race with another run).
os.makedirs(args.save_dir, exist_ok=True)

#################################################################
# Build train_loader / test_loader for the dataset named on the command line.
# Every branch must define both loaders; unknown names raise instead of
# falling through and failing later with a NameError.
from torch.utils.data import Subset  # used by the two subset-based branches

if args.dataset == "MNISTmakeup":
    # Load MNIST dataset (download=False: the files must already be under data_path)
    data_path = '../data/'
    mnist_train = torchvision.datasets.MNIST(root=data_path, train=True, download=False)
    mnist_test = torchvision.datasets.MNIST(root=data_path, train=False, download=False)

    # Keep only the first 30 samples of each split
    subset_indices = list(range(30))
    mnist_train = Subset(mnist_train, subset_indices)
    mnist_test = Subset(mnist_test, subset_indices)

    # Define any additional transforms (e.g., normalization)
    additional_transforms = transforms.Compose([
        transforms.ToTensor(),
    ])

    # Create the paired datasets
    paired_mnist_train = PairedMNISTDataset(mnist_train, transform=additional_transforms)
    paired_mnist_test = PairedMNISTDataset(mnist_test, transform=additional_transforms)
    # Sanity output: size attribute of the first raw image and its label
    print(paired_mnist_train.original_dataset[0][0].size)
    print(paired_mnist_train.original_dataset[0][1])

    # Define batch size
    batch_size = 128

    # Create DataLoaders for training and testing
    train_loader = DataLoader(paired_mnist_train, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(paired_mnist_test, batch_size=batch_size, shuffle=False)

elif args.dataset == "MNIST2MNISTFashion":
    # Identity mapping from MNIST digit label to FashionMNIST class label
    mnist_to_fashionmnist_mapping = {
        0: 0,  # MNIST (0) -> FashionMNIST (T-shirt/top)
        1: 1,  # MNIST (1) -> FashionMNIST (Trouser)
        2: 2,  # MNIST (2) -> FashionMNIST (Pullover)
        3: 3,  # MNIST (3) -> FashionMNIST (Dress)
        4: 4,  # MNIST (4) -> FashionMNIST (Coat)
        5: 5,  # MNIST (5) -> FashionMNIST (Sandal)
        6: 6,  # MNIST (6) -> FashionMNIST (Shirt)
        7: 7,  # MNIST (7) -> FashionMNIST (Sneaker)
        8: 8,  # MNIST (8) -> FashionMNIST (Bag)
        9: 9   # MNIST (9) -> FashionMNIST (Ankle boot)
    }
    # Transformation to convert PIL Images to Tensors
    transform = transforms.Compose([
        transforms.ToTensor()  # Converts PIL Image to Tensor and scales pixel values to [0, 1]
    ])

    data_path = '../data'
    mnist_train = datasets.MNIST(root=data_path, train=True, download=True, transform=transform)
    mnist_test = datasets.MNIST(root=data_path, train=False, download=True, transform=transform)

    fashionmnist_train = datasets.FashionMNIST(root=data_path, train=True, download=True, transform=transform)
    fashionmnist_test = datasets.FashionMNIST(root=data_path, train=False, download=True, transform=transform)

    def group_fashionmnist_by_label(fashionmnist_dataset):
        """Return a mapping ``label -> list of dataset indices`` with that label.

        Iterates the whole dataset once, so every sample is loaded (and
        transformed) in the process.
        """
        label_to_indices = defaultdict(list)
        for idx, (image, label) in enumerate(fashionmnist_dataset):
            label_to_indices[label].append(idx)
        return label_to_indices

    # Label -> indices mapping for FashionMNIST
    fashion_label_to_indices_train = group_fashionmnist_by_label(fashionmnist_train)
    fashion_label_to_indices_test = group_fashionmnist_by_label(fashionmnist_test)

    # Create the paired datasets
    paired_train_dataset = PairedMNISTFashionMNIST(
        mnist_dataset=mnist_train,
        fashionmnist_dataset=fashionmnist_train,
        label_mapping=mnist_to_fashionmnist_mapping,
        fashion_label_to_indices=fashion_label_to_indices_train,
        transform_input=None,  # Already transformed with ToTensor()
        transform_target=None,  # Already transformed with ToTensor()
        shuffle=True
    )

    paired_test_dataset = PairedMNISTFashionMNIST(
        mnist_dataset=mnist_test,
        fashionmnist_dataset=fashionmnist_test,
        label_mapping=mnist_to_fashionmnist_mapping,
        fashion_label_to_indices=fashion_label_to_indices_test,
        transform_input=None,
        transform_target=None,
        shuffle=True
    )

    # Keep only the first 30 pairs of each split
    subset_indices = list(range(30))
    paired_train_dataset = Subset(paired_train_dataset, subset_indices)
    paired_test_dataset = Subset(paired_test_dataset, subset_indices)

    # Create DataLoaders for the paired datasets
    batch_size = 128

    train_loader = DataLoader(paired_train_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(paired_test_dataset, batch_size=batch_size, shuffle=False)

elif args.dataset == "MNIST2MNISTFashion1class":
    # Alternative: get_MNIST2MNISTFashion1class_data(samples=30, batch_size=128)
    train_loader, test_loader = get_MNIST2MNISTFashion1sample_data(samples=10, batch_size=64)

elif args.dataset == "halfMNISThalfFashion":
    train_loader, test_loader = get_halfMNISThalfFashion(samples=30, batch_size=64)

elif args.dataset == "halfMNISThalfFashion_1over4":
    train_loader, test_loader = get_halfMNISThalfFashion_1over4(samples=30, batch_size=64)

elif args.dataset == "halfMNISThalfFashion_1over16":
    train_loader, test_loader = get_halfMNISThalfFashion_1over16(samples=30, batch_size=64)

elif args.dataset == "halfMNISThalfFashion_veritcal":
    train_loader, test_loader = get_halfMNISThalfFashion_veritcal(samples=30, batch_size=64)

else:
    # Without this branch the script would continue and only fail in the
    # training loop with an unhelpful NameError on train_loader.
    raise ValueError(f"Unknown dataset: {args.dataset!r}")

#################################################################
##################################################

##################################################

# Example: Generate dataset with color inversion
# transformed_dataset_train = generate_dataset(transform_type='invert_colors', data=mnist_data_train)
# transformed_dataset_test = generate_dataset(transform_type='invert_colors', data=mnist_data_test)

# # Visualize the transformed examples (color inversion in this case)
# visualize_transformed_examples(transformed_dataset_train)
# visualize_transformed_examples(transformed_dataset_test)

# # Prepare the dataset for training the neural network
# inputs_train, outputs_train = prepare_for_nn(transformed_dataset_train)
# inputs_test, outputs_test = prepare_for_nn(transformed_dataset_test)

# # inputs and outputs are now ready to be fed into a neural network
# print(f"inputs_train: {inputs_train.shape}")
# print(f"outputs_train: {outputs_train.shape}")

# print(f"inputs_test: {inputs_test.shape}")
# print(f"outputs_test: {outputs_test.shape}")


# batch_size = 128
# train_tensor = TensorDataset(inputs_train, outputs_train)
# train_loader = DataLoader(train_tensor, batch_size=batch_size, shuffle=False)

# test_tensor = TensorDataset(inputs_test, outputs_test)
# test_loader = DataLoader(test_tensor, batch_size=batch_size, shuffle=False)


##################################################################################

##################################################################################

class SmallNetwork(nn.Module):
    """Two triangular linear layers with leaky-ReLU activations and an explicit inverse.

    ``fc1`` is constrained to a lower-triangular weight and ``fc2`` to an
    upper-triangular weight, both with a diagonal of ones. The constraint is
    imposed at construction, re-imposed by :meth:`apply_weight_masks`, and
    gradients outside the allowed triangle are zeroed by backward hooks.
    Because the weights are triangular with a non-zero diagonal and the
    leaky-ReLU slope is positive, :meth:`inverse` can undo :meth:`forward`
    exactly (up to floating-point error) with two triangular solves.

    Args:
        input_dim: Size of the (flattened) input; both layers are
            ``input_dim x input_dim``.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, input_dim)
        self.fc2 = nn.Linear(input_dim, input_dim)

        # 0/1 masks selecting the entries each weight is allowed to use.
        # Buffers follow the module across .to(device) calls.
        self.register_buffer('mask_tril', torch.tril(torch.ones_like(self.fc1.weight)))
        self.register_buffer('mask_triu', torch.triu(torch.ones_like(self.fc2.weight)))

        # Project the freshly initialised weights onto the constraint set
        # (triangular, unit diagonal).
        self.apply_weight_masks()

        # Zero the gradient outside the allowed triangle on every backward pass.
        self.fc1.weight.register_hook(lambda grad: grad * self.mask_tril)
        self.fc2.weight.register_hook(lambda grad: grad * self.mask_triu)

        # Slope of the leaky ReLU for negative inputs; its reciprocal is the
        # slope of the inverse activation.
        self.negative_slope = 0.5

    def forward(self, x):
        """Apply ``leaky_relu(fc2(leaky_relu(fc1(x))))`` to a ``(batch, input_dim)`` tensor."""
        fc1_fwd = F.leaky_relu(self.fc1(x), negative_slope=self.negative_slope)
        fc2_fwd = F.leaky_relu(self.fc2(fc1_fwd), negative_slope=self.negative_slope)
        return fc2_fwd

    @staticmethod
    def _solve_rows(weight, rhs, upper):
        """Solve ``weight @ x_i = rhs_i`` for every row ``rhs_i`` of ``rhs``.

        The rows are passed as the columns of a single multi-right-hand-side
        system, so the weight matrix does not have to be expanded per sample.
        """
        solution = torch.linalg.solve_triangular(weight, rhs.T, upper=upper)
        return solution.T.contiguous()

    def inverse(self, y):
        """Map an output of :meth:`forward` back to its input.

        Args:
            y: Tensor of shape ``(batch, input_dim)``.

        Returns:
            Tensor of shape ``(batch, input_dim)``.
        """
        # A leaky ReLU with slope s is inverted by a leaky ReLU with slope 1/s.
        inverse_slope = 1 / self.negative_slope

        # Undo the last activation, the fc2 bias, then the upper-triangular fc2 weight.
        y1 = F.leaky_relu(y, negative_slope=inverse_slope) - self.fc2.bias
        fc2_inv = self._solve_rows(self.fc2.weight, y1, upper=True)

        # Undo the first activation, the fc1 bias, then the lower-triangular fc1 weight.
        y2 = F.leaky_relu(fc2_inv, negative_slope=inverse_slope) - self.fc1.bias
        fc1_inv = self._solve_rows(self.fc1.weight, y2, upper=False)
        return fc1_inv

    def apply_weight_masks(self):
        """Re-impose the triangular structure and unit diagonal on both weights in place."""
        with torch.no_grad():
            for layer, mask in ((self.fc1, self.mask_tril), (self.fc2, self.mask_triu)):
                layer.weight.mul_(mask)
                layer.weight.diagonal().fill_(1.0)
    

class DipDNN_Block(nn.Module):
    """Single invertible block: a thin wrapper delegating to one ``SmallNetwork``.

    Args:
        input_dim: Feature dimension handed to the wrapped ``SmallNetwork``.
        alpha: Accepted for interface compatibility; currently not used.
    """

    def __init__(self, input_dim, alpha=0.9):
        super().__init__()
        self.residual_function = SmallNetwork(input_dim)

    def forward(self, x):
        """Run the wrapped network on ``x``."""
        wrapped = self.residual_function
        return wrapped(x)

    def inverse(self, y):
        """Run the wrapped network's inverse on ``y``."""
        wrapped = self.residual_function
        return wrapped.inverse(y)

    def apply_weight_masks(self):
        """Ask the wrapped network to re-apply its weight masks."""
        wrapped = self.residual_function
        wrapped.apply_weight_masks()


class DipDNN(nn.Module):
    """Stack of ``DipDNN_Block`` modules operating on flattened square images.

    Inputs are flattened to ``(batch, -1)``, passed through the blocks, and
    reshaped to ``(batch, 1, side, side)`` with ``side = int(sqrt(input_dim))``.

    Args:
        input_dim: Number of features after flattening.
        num_blocks: How many blocks to chain.
    """

    def __init__(self, input_dim, num_blocks):
        super().__init__()
        stacked = []
        for _ in range(num_blocks):
            stacked.append(DipDNN_Block(input_dim))
        self.blocks = nn.ModuleList(stacked)
        self.input_dim = input_dim

    def _to_image(self, flat):
        # Restore the single-channel square layout from a flat batch.
        side = int(np.sqrt(self.input_dim))
        return flat.reshape(flat.shape[0], 1, side, side)

    def forward(self, x):
        """Apply all blocks in order and return an image-shaped tensor."""
        hidden = x.reshape(x.shape[0], -1)  # (batchsize, -1)
        for layer in self.blocks:
            hidden = layer(hidden)
        return self._to_image(hidden)

    def inverse(self, y):
        """Apply every block's inverse in reverse order and return an image-shaped tensor."""
        hidden = y.reshape(y.shape[0], -1)  # (batchsize, -1)
        for layer in reversed(self.blocks):
            hidden = layer.inverse(hidden)
        return self._to_image(hidden)

    def count_parameters(self):
        """Return the total number of elements in parameters with ``requires_grad``."""
        total = 0
        for param in self.parameters():
            if param.requires_grad:
                total += param.numel()
        return total

    def apply_weight_masks(self):
        """Re-apply the weight masks of every block."""
        for layer in self.blocks:
            layer.apply_weight_masks()


# Hyperparameters. Images are treated as square with side image_l and are
# flattened to input_dim features before entering the model.
image_l = 28
input_dim = image_l ** 2  # 28 * 28 = 784
num_blocks = 3
learning_rate = 1e-3
num_epochs = 2000

# Build the DipDNN model, move it to the selected device, and report its size
model = DipDNN(input_dim=input_dim, num_blocks=num_blocks)
model = model.to(device)
print(f'Total trainable parameters: {model.count_parameters()}')


def l1_regularization(model, lambda_l1):
    """Return ``lambda_l1`` times the sum of absolute values of all model parameters."""
    total_abs = 0
    for param in model.parameters():
        total_abs = total_abs + param.abs().sum()
    return lambda_l1 * total_abs

def l2_regularization(model, lambda_l2):
    """Return ``lambda_l2`` times the sum of squares of all model parameters."""
    total_sq = 0
    for param in model.parameters():
        total_sq = total_sq + param.pow(2).sum()
    return lambda_l2 * total_sq

# Loss: mean squared error between the model output and the target image
criterion = nn.MSELoss()

# Optimizer: Adam with default betas/eps and a very small weight decay
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-10, amsgrad=False)

# Number of times a weight matrix had near-zero diagonal entries repaired
reg_count = 0

# Stand-alone copies of the triangular masks. The loop below does not use
# them (the model re-applies its own registered masks); they are kept so the
# module-level names remain available.
mask_tril = torch.tril(torch.ones_like(model.blocks[0].residual_function.fc1.weight))
mask_triu = torch.triu(torch.ones_like(model.blocks[0].residual_function.fc2.weight))

# Safety net against near-singular triangular weights: diagonal entries whose
# magnitude drops below diag_threshold are overwritten with diag_reset_value.
diag_threshold = 1e-3
diag_reset_value = 1e-1

# Training loop
for epoch in range(1, num_epochs + 1):
    model.train()
    epoch_loss = 0.0
    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Forward pass and loss. The previous per-batch model.inverse(targets)
        # call was dropped: its result was never used, so it only cost extra
        # triangular solves and autograd graph memory.
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Project the updated weights back onto the triangular / unit-diagonal structure
        model.apply_weight_masks()

        # Weight by batch size so the epoch average is per sample
        epoch_loss += loss.item() * inputs.size(0)

        with torch.no_grad():
            # apply_weight_masks() above has just reset every weight diagonal
            # to 1.0, so this guard is not expected to fire for the current
            # model; it is kept as a safety net.
            for name, param in model.named_parameters():
                if 'weight' not in name:
                    continue
                diag_view = param.diagonal()  # view: in-place edits write into param
                too_small = diag_view.abs() < diag_threshold
                if too_small.any():
                    diag_view.masked_fill_(too_small, diag_reset_value)
                    reg_count += 1

    avg_loss = epoch_loss / len(train_loader.dataset)
    print(f'Epoch [{epoch}/{num_epochs}], Forward Pass Loss: {avg_loss}')

print("reg_count", reg_count)


# print(model.blocks[0].residual_function.fc1.weight)
# print(model.blocks[0].residual_function.fc2.weight)
# exit()

# After training: measure how well the analytic inverse recovers the inputs
# on the training loader, both from the model's own outputs and from the
# ground-truth targets.
model.eval()
with torch.no_grad():
    inverse_error = 0.0       # accumulated MSE of inverse(model(inputs)) vs inputs
    inverse_error_real = 0.0  # accumulated MSE of inverse(targets) vs inputs
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        n_in_batch = inputs.size(0)

        # Forward pass, then invert the prediction and the true target
        outputs = model(inputs)
        reconstructed_inputs = model.inverse(outputs)
        reconstructed_inputs_real = model.inverse(targets)

        # Per-batch mean errors, re-weighted by batch size for a per-sample average
        inverse_error += criterion(reconstructed_inputs, inputs).item() * n_in_batch
        inverse_error_real += criterion(reconstructed_inputs_real, inputs).item() * n_in_batch

    n_total = len(train_loader.dataset)
    avg_inverse_error = inverse_error / n_total
    avg_inverse_error_real = inverse_error_real / n_total
    print(f'Inverse Pass Error (MSE) after training: {avg_inverse_error}')
    print(f'Inverse Pass Error (MSE) after training real: {avg_inverse_error_real}')


# Save sample visualisations for the train and test loaders
visualize_results(
    model,
    train_loader,
    device,
    num_samples=3,
    save_path=f"{args.save_dir}/{args.dataset}_Samples_train.png",
)
visualize_results(
    model,
    test_loader,
    device,
    num_samples=3,
    save_path=f"{args.save_dir}/{args.dataset}_Samples_test.png",
)




