import contextlib
import os
import random

import torch
import torch.nn as nn
import torchvision.datasets
from PIL.Image import Image
from setproctitle import setproctitle
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import torch.utils.data
import torchvision.transforms.v2 as transforms
import torchvision
from torch.utils.data import Dataset, DataLoader
# NOTE: this rebinds ``Image`` to the PIL.Image *module*, shadowing the class imported
# above; the rest of the file relies on the module (Image.open, Image.BICUBIC), so this
# line must stay after ``from PIL.Image import Image``.
from PIL import Image, ImageOps, ImageFilter

# Plot styling: seaborn theme with the Computer Modern (cmr10) font; mathtext tick
# formatting is enabled, presumably so minus signs render correctly with cmr10.
sns.set_theme(style="whitegrid", font="cmr10", font_scale=1.6)
plt.rcParams["axes.formatter.use_mathtext"] = True
colors = ["blue", "red", "green", "yellow"]
# Fall back to CPU instead of failing at the first .to(device) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
size = 42  # output side length (pixels) of RandomResizedCrop in the augmentations
dataset_type = 10  # number of classes / digit folders (0-9) in the dataset
setproctitle("cmnist")


class CMNISTDataset(Dataset):
    """
    CMNIST dataset (presumably Colored MNIST) backed by PNG files on disk.

    Images are read from ``<root_dir>/<split>/<digit>/<index>.png`` where ``split``
    is ``train`` or ``test``; the number of images per digit is hard-coded below.
    Every image is decoded into memory as an RGB PIL image at construction time.

    Args:
        root_dir (str): root directory for the dataset
        train (bool, optional): load the ``train`` split if True, else ``test``
        transform (callable, optional): transform applied when ``no_aug`` is True;
            defaults to ``ToTensor()``
        no_aug (bool, optional): if True, ``__getitem__`` returns
            ``(transform(image), label)``; otherwise it returns two independently
            augmented views ``(augment(image), augment_prime(image))`` and no label
    """

    def __init__(self, root_dir, train=True, transform=None, no_aug=False):
        self.root_dir = root_dir
        self.transform = transform
        self.train = train
        self.no_aug = no_aug
        if train:
            self.dataset_size = [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]
        else:
            self.dataset_size = [980, 1135, 1032, 1010, 982, 892, 958, 1028, 974, 1009]
        self.dataset_type = dataset_type

        if self.transform is None:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
            ])
        # The two views differ only in blur / solarization probabilities.
        self.augment = self._build_augment(blur_p=1.0, solarize_p=0.0)
        self.augment_prime = self._build_augment(blur_p=0.1, solarize_p=0.2)

        self.labels = []
        self.img = []

        split = 'train' if train else 'test'
        for t in range(self.dataset_type):
            for item in range(self.dataset_size[t]):
                self.labels.append(t)
                # Built with os.path.join rather than a nested same-quote f-string,
                # which is a SyntaxError before Python 3.12.
                path = os.path.join(self.root_dir, split, str(t), f'{item}.png')
                # convert() returns a new, fully loaded image, so the source file
                # handle can be closed immediately instead of being left open.
                with Image.open(path) as src:
                    self.img.append(src.convert('RGB'))

    @staticmethod
    def _build_augment(blur_p, solarize_p):
        """Return the augmentation pipeline with the given blur / solarization probabilities."""
        return transforms.Compose([
            transforms.RandomResizedCrop(size, scale=(0.25, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                        saturation=0.2, hue=0.1)],
                p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=blur_p),
            Solarization(p=solarize_p),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.img)

    def __getitem__(self, idx):
        image_original = self.img[idx]
        if self.no_aug:
            return self.transform(image_original), self.labels[idx]
        else:
            return self.augment(image_original), self.augment_prime(image_original)


class GaussianBlur:
    """Randomly apply a PIL Gaussian blur to an image.

    Args:
        p (float): probability of blurring on each call; when applied, the blur
            standard deviation is drawn uniformly from [0.1, 2.0).
    """

    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        # Gate draw comes first; the std draw only happens when the blur is applied.
        if not random.random() < self.p:
            return img
        std = 0.1 + 1.9 * random.random()
        return img.filter(ImageFilter.GaussianBlur(std))


class Solarization:
    """Randomly solarize a PIL image with ``ImageOps.solarize``.

    Args:
        p (float): probability of solarizing on each call.
    """

    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        apply = random.random() < self.p
        return ImageOps.solarize(img) if apply else img


def get_cmnist_dataloaders(root_dir, train_batch_size, test_batch_size, num_workers=6):
    """
    Build the CMNIST data loaders for self-supervised training and linear probing.

    Args:
        root_dir (str): dataset's root directory
        train_batch_size (int): batch size of the augmented train loader; the
            no-augmentation train loader uses 4x this batch size
        test_batch_size (int): batch size of the test loader
        num_workers (int, optional): worker processes per DataLoader (default 6)

    Returns:
        tuple: (train_loader, test_loader, train_noaug_loader)
            - train_loader: shuffled train split yielding pairs of augmented views
            - test_loader: unshuffled test split yielding (image, label)
            - train_noaug_loader: shuffled train split yielding (image, label)
    """
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_dataset = CMNISTDataset(root_dir=root_dir, train=True)
    test_dataset = CMNISTDataset(
        root_dir=root_dir,
        train=False,
        transform=eval_transform,
        no_aug=True,
    )
    train_noaug_dataset = CMNISTDataset(root_dir=root_dir, train=True, no_aug=True)

    def make_loader(dataset, batch_size, shuffle):
        # All loaders share the same worker count and pinned-memory setting.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
        )

    train_loader = make_loader(train_dataset, train_batch_size, shuffle=True)
    test_loader = make_loader(test_dataset, test_batch_size, shuffle=False)
    train_noaug_loader = make_loader(train_noaug_dataset, train_batch_size * 4, shuffle=True)
    return train_loader, test_loader, train_noaug_loader


class BarlowTwins(nn.Module):
    """ResNet-18 encoder plus a small projector head for Barlow Twins training.

    The ResNet stem is replaced by a 3x3, stride-1 convolution and the initial
    max-pool is removed; the classification layer is replaced by an identity,
    so the backbone emits ``size_in``-dimensional features. The projector maps
    those features to ``size_out`` (= 32) dimensions.
    """

    def __init__(self):
        super().__init__()
        net = torchvision.models.resnet18(zero_init_residual=True, weights=None)
        net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        net.maxpool = nn.Identity()
        self.backbone = net
        self.size_in = net.fc.in_features
        self.size_out = 32
        net.fc = nn.Identity()

        # Projector: (Linear -> BatchNorm -> ReLU) for every hidden layer, then a final Linear.
        dims = [self.size_in, self.size_out, self.size_out]
        blocks = []
        for d_in, d_out in zip(dims[:-2], dims[1:-1]):
            blocks += [
                nn.Linear(d_in, d_out, bias=False),
                nn.BatchNorm1d(d_out),
                nn.ReLU(inplace=True),
            ]
        blocks.append(nn.Linear(dims[-2], dims[-1], bias=False))
        self.projector = nn.Sequential(*blocks)

    def forward(self, x, use_projector=True):
        """Return projector embeddings, or raw backbone features if ``use_projector`` is False."""
        features = self.backbone(x)
        return self.projector(features) if use_projector else features


class BTLoss(nn.Module):
    """Barlow Twins redundancy-reduction loss.

    Both embeddings are standardised per feature by a non-affine BatchNorm. Their
    cross-correlation ``C = bn(x0)^T @ bn(x1) / N`` is formed, and the loss is
    ``sum_i (C_ii - 1)^2 + lambd * sum_{i != j} C_ij^2``.

    Args:
        in_feature (int): embedding dimension (number of BatchNorm features).
        lambd (float, optional): weight of the off-diagonal term. Defaults to 1.
    """

    def __init__(self, in_feature, lambd=1):
        super().__init__()
        self.bn = nn.BatchNorm1d(in_feature, affine=False)
        self.lambd = lambd

    def forward(self, x0, x1):
        """Return ``(C, loss)``, where ``C`` is a copy of the cross-correlation matrix.

        ``x0`` and ``x1`` are the two views' embeddings; assumed shape
        (N, in_feature) — TODO confirm against callers.
        """
        n_samples = x0.shape[0]
        # x0 is normalised before x1, keeping the BatchNorm running-stat update order.
        left = self.bn(x0)
        right = self.bn(x1)
        corr = (left.T @ right) / n_samples
        diag_term = (torch.diagonal(corr) - 1).pow(2).sum()
        offdiag_term = off_diagonal(corr).pow(2).sum()
        return corr.clone(), diag_term + self.lambd * offdiag_term


def off_diagonal(x):
    """Return the off-diagonal elements of square matrix ``x`` as a 1-D tensor.

    Elements appear in row-major order. Raises AssertionError if ``x`` is not square.
    """
    rows, cols = x.shape
    assert rows == cols
    flat = x.flatten()
    # Dropping the last element leaves n*n - 1 = (n - 1) * (n + 1) values. Viewed as
    # (n - 1) rows of width n + 1, every diagonal entry lands in column 0.
    shifted = flat[:-1].view(rows - 1, rows + 1)
    return shifted[:, 1:].flatten()


def import_dataset(train_batch_size, test_batch_size, ratio=0.2, base_dir="/home/dataset/cmnist"):
    """Build the CMNIST loaders for the variant stored under ``<base_dir>/<ratio>``.

    Args:
        train_batch_size (int): batch size of the augmented training loader.
        test_batch_size (int): batch size of the test loader.
        ratio (float, optional): selects the dataset sub-directory; its string form
            is the directory name (e.g. ``0.05`` -> ``<base_dir>/0.05``).
        base_dir (str, optional): directory holding one sub-directory per ratio.
            Defaults to the previously hard-coded location.

    Returns:
        tuple: (train_loader, test_loader, train_noaug_loader).
    """
    root_dir = os.path.join(base_dir, str(ratio))  # dataset root directory

    # Create the DataLoaders.
    train_loader, test_loader, train_noaug_loader = get_cmnist_dataloaders(
        root_dir=root_dir,
        train_batch_size=train_batch_size,
        test_batch_size=test_batch_size,
    )
    print(f"train dataset size: {len(train_loader.dataset)}")
    print(f"test dataset size: {len(test_loader.dataset)}")
    return train_loader, test_loader, train_noaug_loader


@contextlib.contextmanager
def scaled_model(model, scale):
    """Temporarily multiply the model's "weight" parameters by ``scale``.

    Use as ``with scaled_model(model, s) as m: ...``. On exit, whether normal or via
    an exception, the same parameters are multiplied by ``1 / scale`` to undo the
    change. Because of floating-point rounding, restored values may differ from the
    originals in the last bits.

    Raises:
        ValueError: if ``scale`` is 0, since the scaling could not be undone.
    """
    # Explicit raise instead of assert, which is stripped under ``python -O``.
    if scale == 0:
        raise ValueError("scale must be non-zero")
    # Scale before entering ``try`` so the undo in ``finally`` only runs once the
    # scaling has actually been applied.
    scale_weights(model, scale)
    try:
        yield model
    finally:
        scale_weights(model, 1 / scale)


@torch.no_grad()
def scale_weights(model, scale: float = 1) -> None:
    """Multiply, in place, every parameter whose name contains "weight" by ``scale``.

    Parameters without "weight" in their name (e.g. biases) are untouched; BatchNorm
    affine weights do match. A ``scale`` of 1 leaves the model unchanged.
    """
    if scale == 1:
        return
    for param_name, tensor in model.named_parameters():
        if "weight" in param_name:
            tensor.copy_(tensor * scale)


# AdamW learning rate and weight decay for the Barlow Twins model, and the batch
# size of the augmented training loader.
base_lr, weight_decay, batch_size = 4e-6, 1e-6, 128


def train_eval(train, test, train_noaug, epochs, eval_every=1):
    """Train a Barlow Twins model and track linear-probe test accuracy.

    The linear probe (a single ``nn.Linear`` on frozen backbone features) is fit for
    12 epochs on ``train_noaug`` and evaluated on ``test``. This happens once before
    training, then every ``eval_every`` epochs, and always after the final epoch.
    Each probe run restarts from the same initial probe weights, so runs are comparable.

    Args:
        train: DataLoader yielding pairs of augmented views ``(image1, image2)``.
        test: DataLoader yielding ``(image, label)`` for evaluation.
        train_noaug: DataLoader yielding ``(image, label)`` used to fit the probe.
        epochs (int): number of self-supervised training epochs.
        eval_every (int, optional): probe interval in epochs. Default 1 means
            evaluate after every epoch.

    Returns:
        tuple: ``(losses, accuracy, accuracy_types, epoch_list)``:
            - losses: Barlow Twins loss of every training step,
            - accuracy: overall test accuracy (%) of every probe run,
            - accuracy_types: per-class test accuracy (%) tensors on CPU, one per run,
            - epoch_list: epoch at which each probe run happened (0 = before training).

    Raises:
        ValueError: if ``eval_every`` is smaller than 1.
    """
    if eval_every < 1:
        raise ValueError(f"eval_every must be >= 1, got {eval_every}")

    model = BarlowTwins().to(device)
    # Shrink every "weight" parameter (BatchNorm weights included) by 0.09 before training.
    scale_weights(model, 9e-2)
    size_in = model.size_in
    size_out = model.size_out
    loss_fn = BTLoss(size_out, lambd=5e-3).to(device)
    optimizer_main = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
    linear_regression = nn.Linear(size_in, dataset_type).to(device)
    # state_dict() tensors share storage with the live parameters. Clone them, otherwise
    # load_state_dict() below would copy the already-trained weights onto themselves
    # and never reset the probe.
    initial_probe_state = {
        key: value.detach().clone() for key, value in linear_regression.state_dict().items()
    }

    losses = []
    accuracy = []
    accuracy_types = []
    epoch_list = []
    bar = tqdm(range(epochs), mininterval=10, maxinterval=10)

    def update_accuracy_lin(current_epoch):
        """Fit a freshly reset linear probe on frozen features, then record test accuracy."""
        model.eval()
        linear_regression.load_state_dict(initial_probe_state)
        optimizer_classifier = torch.optim.AdamW(linear_regression.parameters())
        linear_regression.train()
        probe_bar = tqdm(range(12), mininterval=10, maxinterval=10)
        for _ in probe_bar:
            seen = 0
            loss_sum = 0.0
            for image, label in train_noaug:
                with torch.no_grad():
                    image = image.to(device, non_blocking=True)
                    labels = label.to(device, non_blocking=True)
                    features = model(image, use_projector=False)
                outputs = linear_regression(features)
                loss = F.cross_entropy(outputs, labels)
                optimizer_classifier.zero_grad()
                loss.backward()
                optimizer_classifier.step()
                batch = image.shape[0]
                seen += batch
                # cross_entropy returns the batch mean; weight by batch size so the
                # reported figure is the per-sample mean over the epoch.
                loss_sum += loss.item() * batch
            probe_bar.set_description(f"loss linear : {loss_sum / max(seen, 1)}", refresh=False)

        linear_regression.eval()
        acc_test = 0.0
        acc_test_types = torch.zeros(dataset_type, device=device)
        total_test = 0.0
        total_test_types = torch.zeros(dataset_type, device=device)
        with torch.no_grad():
            for image, label in tqdm(test):
                image = image.to(device, non_blocking=True)
                labels = label.to(device, non_blocking=True)

                outputs = linear_regression(model(image, use_projector=False))
                pred = torch.argmax(outputs, dim=1)
                correct = pred == labels  # boolean mask

                acc_test += correct.float().sum().item()
                total_test += labels.shape[0]
                total_test_types += torch.bincount(labels, minlength=dataset_type)
                acc_test_types += torch.bincount(labels[correct], minlength=dataset_type)
        acc_test_types = (acc_test_types / total_test_types * 100.0).cpu()
        acc_test = acc_test / total_test * 100.0
        accuracy_types.append(acc_test_types)
        accuracy.append(acc_test)
        epoch_list.append(current_epoch)
        print(
            f"Test Accuracy: {acc_test:.4f}, types: {acc_test_types}")

    update_accuracy_lin(current_epoch=0)
    for epoch in bar:
        model.train()
        num_steps = 0
        loss_total = 0.0
        for image1, image2 in train:
            image1 = image1.to(device, non_blocking=True)
            image2 = image2.to(device, non_blocking=True)
            # Both views go through the model in one batch, then are split back apart.
            outputs = model(torch.cat([image1, image2], dim=0))
            z1, z2 = outputs.chunk(2, dim=0)
            loss = loss_fn(z1, z2)[1]
            optimizer_main.zero_grad()
            loss.backward()
            optimizer_main.step()

            step_loss = loss.item()
            losses.append(step_loss)
            num_steps += 1
            loss_total += step_loss
        # The Barlow Twins loss is a per-batch quantity; report its mean over steps.
        bar.set_description(f"Loss: {(loss_total / max(num_steps, 1)):.4f}", refresh=False)
        if (epoch + 1) % eval_every == 0 or epoch == epochs - 1:
            update_accuracy_lin(current_epoch=epoch + 1)

    return losses, accuracy, accuracy_types, epoch_list


# Run the full experiment once per dataset ratio, saving per-ratio results as .npy files.
losses = []
acc = []
acc_types = []
epoch_lists = []
epochs = 60
ratios = [0.05, 0.10, 0.15]
output_dir = "datas"
# np.save does not create missing directories; create it up front so results are not
# lost with a FileNotFoundError at the end of a long training run.
os.makedirs(output_dir, exist_ok=True)

for r in ratios:
    train, test, train_noaug = import_dataset(train_batch_size=batch_size, test_batch_size=512, ratio=r)
    loss, accuracy, accuracy_types, epoch_list = train_eval(train, test, train_noaug, epochs)
    losses.append(loss)
    acc.append(accuracy)
    acc_types.append(accuracy_types)
    epoch_lists.append(epoch_list)

    loss_array = np.array(loss)
    acc_array = np.array(accuracy)
    # accuracy_types holds one CPU tensor of per-class accuracies per evaluation;
    # stack them explicitly into a (num_evals, num_classes) array.
    acc_types_array = np.stack([t.numpy() for t in accuracy_types])
    epoch_lists_array = np.array(epoch_list)
    np.save(os.path.join(output_dir, f"{r}_loss_array"), loss_array)
    np.save(os.path.join(output_dir, f"{r}_acc_array"), acc_array)
    np.save(os.path.join(output_dir, f"{r}_acc_types_array"), acc_types_array)
    np.save(os.path.join(output_dir, f"{r}_epoch_lists_array"), epoch_lists_array)
    print(loss)
    print(accuracy)
    print(accuracy_types)
    print(epoch_list)
