import torch
import torchvision
import random
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

# Module-wide compute device: the GPU when CUDA is usable, the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class MonolithicModel(nn.Module):
    """Fully connected network mapping a group of ``m`` 3x32x32 images to ``m`` rows of 10 logits.

    The input is flattened from dimension 1 onwards, passed through five
    Linear -> BatchNorm1d -> ReLU stages (512, 512, 512, 256, 128 units) and a
    final Linear layer whose output is reshaped to ``(batch, m, 10)``.
    """

    def __init__(self, m):
        super().__init__()

        def dense_stage(n_in, n_out, *leading):
            # Optional leading layers followed by Linear -> BatchNorm1d -> ReLU.
            return nn.Sequential(
                *leading,
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out),
                nn.ReLU(),
            )

        input_dim = m * 3 * 32 ** 2
        hidden = [512, 512, 512, 256, 128]

        # Only the first stage flattens; each stage stays its own nn.Sequential
        # so sub-module names (net.<stage>.<layer>) are unchanged.
        stages = [dense_stage(input_dim, hidden[0], nn.Flatten(start_dim=1, end_dim=-1))]
        stages.extend(dense_stage(n_in, n_out) for n_in, n_out in zip(hidden, hidden[1:]))

        # Output head: m * 10 logits reshaped into one row of 10 per image.
        stages.append(
            nn.Sequential(
                nn.Linear(hidden[-1], m * 10),
                nn.Unflatten(dim=1, unflattened_size=(m, 10)),
            )
        )

        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return the ``(batch, m, 10)`` output of the network for input ``x``."""
        return self.net(x)


class ModularModel(nn.Module):
    """Network made of ``num_modules`` input branches that share one trunk.

    Each branch flattens the input and projects it to ``bottleneck_size``
    features (Linear -> BatchNorm1d -> ReLU).  Every branch output is passed
    through the same shared trunk (256 -> 128 -> 64 units); the resulting
    features are concatenated and mapped to ``(batch, m, 10)`` logits.
    """

    def __init__(self, m, num_modules, bottleneck_size, name=None):
        super().__init__()
        self.name = name

        input_dim = m * 3 * 32 ** 2

        # One independent bottleneck branch per module; index 1 of each branch
        # is its Linear projection.
        self.module_fc = nn.ModuleList([
            nn.Sequential(
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(input_dim, bottleneck_size),
                nn.BatchNorm1d(bottleneck_size),
                nn.ReLU(),
            )
            for _ in range(num_modules)
        ])

        # Trunk applied to every branch output with the same weights.
        trunk_widths = [bottleneck_size, 256, 128, 64]
        self.module_shared = nn.Sequential(*[
            nn.Sequential(
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out),
                nn.ReLU(),
            )
            for n_in, n_out in zip(trunk_widths, trunk_widths[1:])
        ])

        # Head over the concatenation of all trunk outputs.
        self.out_fc = nn.Sequential(
            nn.Linear(num_modules * trunk_widths[-1], m * 10),
            nn.Unflatten(dim=1, unflattened_size=(m, 10)),
        )

    def forward(self, x):
        """Run every branch through the shared trunk and classify the concatenated features."""
        branch_features = [self.module_shared(branch(x)) for branch in self.module_fc]
        return self.out_fc(torch.cat(branch_features, dim=1))


# Concatenate m images into input and labels
class ConcatDataset(torch.utils.data.Dataset):
    """Map-style dataset whose samples are groups of ``m`` items of a base dataset.

    The grouping is drawn once, at construction time, from a fixed seed, so
    the same base dataset length, ``m`` and ``dataset_size`` always give the
    same groups.

    Args:
        dataset: Indexable base dataset; ``dataset[i]`` must return an
            ``(image, label)`` pair where ``image`` is a tensor and ``label``
            an integer class index.
        m: Number of base items per sample.
        dataset_size: Number of grouped samples exposed by this dataset.
        num_classes: Width of the one-hot label vectors (defaults to 10).
    """

    def __init__(self, dataset, m, dataset_size, num_classes=10):
        self.dataset = dataset
        self.m = m
        self.dataset_size = dataset_size
        self.num_classes = num_classes
        # Flat list of base-dataset indices; sample k uses entries [k*m, (k+1)*m).
        self.indices = self._generate_indices()

    def _generate_indices(self):
        """Draw ``dataset_size`` groups of ``m`` distinct base indices.

        Raises:
            ValueError: If ``m`` exceeds the size of the base dataset.
        """
        # A private generator with the fixed seed yields the same sequence as
        # seeding the global `random` module with 99, without clobbering the
        # process-wide random state of the caller.
        rng = random.Random(99)
        population = range(len(self.dataset))
        indices = []
        for _ in range(self.dataset_size):
            indices.extend(rng.sample(population, self.m))
        return indices

    def __getitem__(self, index):
        """Return ``(images, one_hot_labels)`` for grouped sample ``index``.

        Returns:
            images: the ``m`` image tensors stacked along a new leading dimension.
            one_hot_labels: float tensor of shape ``(m, num_classes)``.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        if index < 0:
            index += self.dataset_size
        if not 0 <= index < self.dataset_size:
            # IndexError (not a torch.stack failure on an empty list) lets
            # plain `for sample in dataset` iteration terminate correctly.
            raise IndexError('ConcatDataset index out of range')

        start = index * self.m
        pairs = [self.dataset[i] for i in self.indices[start:start + self.m]]

        images = torch.stack([image for image, _ in pairs], dim=0)
        labels = torch.tensor([label for _, label in pairs])

        one_hot_labels = torch.zeros((self.m, self.num_classes))
        one_hot_labels.scatter_(1, labels.view(-1, 1), 1)
        return images, one_hot_labels

    def __len__(self):
        """Number of grouped samples."""
        return self.dataset_size

def train(model, train_loader, optimizer):
    """Run one optimisation pass over ``train_loader``.

    For every batch the model output and the labels are collapsed to 2-D over
    their last axis, the target class is taken as the argmax of each label
    row, and one optimizer step is taken on the cross-entropy loss.  Loss and
    accuracy are printed for every batch.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` averaged over the batches.
    """
    model.train()

    batch_losses, batch_accs = [], []
    for iteration, (x, labels) in enumerate(train_loader):
        x, labels = x.to(device), labels.to(device)

        logits = model(x)

        # Merge all leading axes so every row is one prediction / one label vector.
        logits = logits.view(-1, logits.shape[-1])
        targets = torch.argmax(labels.view(-1, labels.shape[-1]), dim=1)

        # Cross-entropy against the class indices recovered from the label rows.
        loss = F.cross_entropy(logits, targets)

        # Fraction of rows whose predicted class matches the target.
        acc = (torch.argmax(logits, dim=1) == targets).float().mean()

        loss_value = loss.detach().cpu().item()
        acc_value = acc.detach().cpu().item()
        print('Train Iteration: ', iteration, ' Loss: ', loss_value, ' Accuracy: ',
              acc_value)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss_value)
        batch_accs.append(acc_value)

    return np.mean(batch_losses), np.mean(batch_accs)


def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` without updating it.

    The model is switched to eval mode and every batch is scored exactly as
    in training: outputs and labels are collapsed to 2-D over their last
    axis, the target class is the argmax of each label row, and cross-entropy
    loss and accuracy are computed and printed.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` averaged over the batches.
    """
    model.eval()

    losses = []
    accs = []
    # Evaluation never calls backward(), so disable autograd to avoid building
    # (and holding in memory) a computation graph for every batch.
    with torch.no_grad():
        for iteration, (x, labels) in enumerate(test_loader):
            x = x.to(device)
            labels = labels.to(device)

            out = model(x)

            # Merge all leading axes so every row is one prediction / one label vector.
            out = out.view(-1, out.shape[-1])
            targets = torch.argmax(labels.view(-1, labels.shape[-1]), dim=1)

            # Compute the cross-entropy loss
            loss = F.cross_entropy(out, targets).item()

            # Compute the accuracy
            acc = (torch.argmax(out, dim=1) == targets).float().mean().item()

            print('Test Iteration: ', iteration, ' Loss: ', loss, ' Accuracy: ',
                  acc)

            losses.append(loss)
            accs.append(acc)

    return np.mean(losses), np.mean(accs)


# Find Y^T K^{-1} Y in a differentiable way
def solve_linsys(K, Y):
    """Return ``Y^T K^{-1} Y`` as a float64 scalar that is differentiable w.r.t. ``K``.

    The forward value comes from solving ``K beta = Y``.  The backward pass
    does not go through the solver: the returned tensor carries the gradient
    of ``-beta^T K beta`` with ``beta`` held constant, so gradients reach
    ``K`` only through that matrix-vector product.

    Args:
        K: Square matrix, one row/column per entry of ``Y``.
        Y: 1-D target vector.
    """
    K64 = K.to(dtype=torch.float64)
    Y64 = Y.to(dtype=torch.float64)

    # beta = K^{-1} Y, solved with Y as a single column.
    beta = torch.linalg.solve(K64, Y64.unsqueeze(1)).squeeze(dim=1)
    beta_const = beta.detach()

    # Forward value: beta^T (K beta), i.e. Y^T K^{-1} Y.
    value = torch.sum(beta * torch.einsum('ij,j->i', K64, beta))

    # Surrogate whose gradient w.r.t. K is -beta beta^T.
    surrogate = -torch.sum(torch.einsum('ij,j->i', K64, beta_const) * beta_const)

    # Straight-through combination: numerically `value`, gradient of `surrogate`.
    return surrogate - surrogate.detach() + value.detach()


def _unit_norm_columns(u_unnormed):
    """Scale every column of ``u_unnormed`` to unit Euclidean norm."""
    return u_unnormed / torch.sqrt(torch.sum(u_unnormed ** 2, dim=0, keepdim=True))


def learn_one_module(
        dataloader,
        inp_dim,
        m,
        b,
        lr=1e-2,
        sigma=20.0,
        eps=1e-5,
        num_classes=10,
):
    """Learn ``b`` unit-norm projection directions for one module.

    One (position, class) coordinate of the ``(m, num_classes)`` label grid is
    picked at random and used as a scalar regression target ``Y`` per sample.
    For every full-size batch, the flattened inputs are projected onto the
    column-normalised directions, a Gaussian kernel matrix ``K`` is built from
    the pairwise distances of the projections, and ``Y^T K^{-1} Y`` (computed
    by ``solve_linsys``) is reduced with one Adam step on the unnormalised
    directions.  The objective is printed for every batch.

    Args:
        dataloader: Iterable of ``(x, labels)`` batches; ``x`` is flattened
            from dimension 1 to ``inp_dim`` features and ``labels`` is indexed
            as ``(batch, m, num_classes)``.
        inp_dim: Number of features of a flattened input.
        m: Number of label rows per sample.
        b: Number of projection directions to learn.
        lr: Adam learning rate.
        sigma: Bandwidth of the Gaussian kernel.
        eps: Value added to the diagonal of ``K`` for numerical stability.
        num_classes: Number of label columns per row.

    Returns:
        Detached float32 tensor of shape ``(inp_dim, b)`` holding the
        column-normalised directions after the last optimizer step.  If the
        dataloader yields no batch, the normalised random initialisation is
        returned.
    """
    u_unnormed = torch.nn.Parameter(torch.randn(inp_dim, b).to(device, dtype=torch.float64))

    optimizer = torch.optim.Adam([u_unnormed], lr=lr)

    # One-hot selector over the (m, num_classes) label grid.
    rand = torch.zeros((m * num_classes,)).to(device)
    rand[torch.randint(0, len(rand), size=(1,)).item()] = 1
    rand = rand.view(m, num_classes)

    batch_size = None
    for x, labels in dataloader:
        if batch_size is None:
            batch_size = x.shape[0]

        # Stop at the first batch whose size differs from the first one seen.
        if x.shape[0] != batch_size:
            break

        u = _unit_norm_columns(u_unnormed)
        assert not torch.isnan(u).any()

        X = x.to(device).to(torch.float64).flatten(start_dim=1, end_dim=-1)
        Y_samp = labels.to(device)
        Y = torch.einsum('ijk, jk->i', Y_samp, rand)  # Shape (batch_size,)

        proj_u = torch.einsum('ij, jk->ik', X, u)  # Shape (batch_size, b), project along u

        # Gaussian kernel on the pairwise squared distances of the projections.
        sq_dists = torch.sum((proj_u.unsqueeze(0) - proj_u.unsqueeze(1)) ** 2, dim=2)
        K = torch.exp(-sq_dists / 2 / sigma ** 2)
        assert not torch.isnan(K).any()
        K = K + eps * torch.eye(K.shape[0]).to(device, dtype=torch.float64)  # For numerical stability

        # Objective Y^T K^{-1} Y; K has shape (batch_size, batch_size)
        error = solve_linsys(K, Y)
        print('Error: ', error.detach().cpu().item())
        assert error > 0

        error.backward()
        assert not torch.isnan(u_unnormed.grad).any()

        optimizer.step()
        optimizer.zero_grad()

    # Normalise again after the loop so the result includes the final
    # optimizer step (and is defined even when no batch was processed).
    with torch.no_grad():
        u_final = _unit_norm_columns(u_unnormed)

    return u_final.to(dtype=torch.float32)


# Initialize model with our method
def init_model(model, dataset, m, b):
    """Overwrite the first Linear layer of every branch in ``model.module_fc``.

    For each branch a fresh set of ``b`` directions is learned with
    ``learn_one_module`` and its transpose is copied into the weight of the
    branch's layer at index 1.

    Args:
        model: Object exposing ``module_fc``, a sequence of indexable branches.
        dataset: Passed to ``learn_one_module`` as its ``dataloader`` argument.
        m: Number of images per sample; the input size is ``32 ** 2 * 3 * m``.
        b: Number of directions learned per branch.

    Returns:
        The same ``model`` object, modified in place.
    """
    input_dim = 32 ** 2 * 3 * m

    for branch in model.module_fc:
        directions = learn_one_module(dataset, input_dim, m, b)

        # Plain in-place copy of the weights; not tracked by autograd.
        with torch.no_grad():
            branch[1].weight.data.copy_(directions.T)

    return model


def gen_datasets(m, train_dataset_size, test_dataset_size, batch_size, transform):
    """Build the training and test dataloaders of ``m``-image CIFAR-10 groups.

    Both CIFAR-10 splits are loaded from (and downloaded to) ``./data`` with
    ``transform`` applied, wrapped in ``ConcatDataset`` and served by a
    ``DataLoader`` with 4 workers and pinned memory.  Only the training
    loader shuffles.

    Returns:
        Tuple ``(trainloader, testloader)``.
    """

    def build_loader(is_train, num_samples):
        # CIFAR-10 split -> grouped dataset -> dataloader.
        cifar = torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
        grouped = ConcatDataset(cifar, m, num_samples)
        return torch.utils.data.DataLoader(grouped, batch_size=batch_size, shuffle=is_train, num_workers=4,
                                           pin_memory=True)

    trainloader = build_loader(True, train_dataset_size)
    testloader = build_loader(False, test_dataset_size)
    return trainloader, testloader
