import copy
import os
import pickle
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision
import torchvision.datasets
import torchvision.transforms.v2 as transforms
import yaml
from setproctitle import setproctitle
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
# Order matters: the second import rebinds `Image` to the PIL.Image module
# (needed for `Image.open`), shadowing the class imported on the first line.
from PIL.Image import Image
from PIL import Image, ImageOps, ImageFilter

from iclr.common import *

# Use the GPU when one is available; otherwise fall back to the CPU so the
# script still runs (slowly) instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class CMNISTDataset(Dataset):
    """
    CMNIST dataset class.

    Images are read from ``{root_dir}/{train|test}/{class}/{index}.png``,
    converted to RGB and kept in memory as PIL images. Per-channel mean/std
    are computed over every pixel of the loaded split and used to normalise
    every tensor this dataset returns.

    Args:
        root_dir (str): root directory of the dataset.
        train (bool, optional): load the ``train`` split if True, else ``test``.
        transform (callable, optional): image -> tensor transform used only
            while computing the normalisation statistics; defaults to
            ``ToTensor()``. Afterwards ``self.transform`` is replaced by a
            ``ToTensor + Normalize`` pipeline built from those statistics.
        no_aug (bool, optional): if True, ``__getitem__`` returns
            ``(normalised image, label)``; otherwise it returns two
            independently augmented views of the image (and no label).
        num_classes (int, optional): number of class folders to load
            (``0 .. num_classes-1``); defaults to every class in the
            per-class size table (10).
        image_size (int, optional): output size of ``RandomResizedCrop``;
            defaults to the module-level ``input_size``.

    Raises:
        ValueError: if no images are loaded (e.g. ``num_classes=0``).
    """

    def __init__(self, root_dir, train=True, transform=None, no_aug=False,
                 num_classes=None, image_size=None):
        self.root_dir = root_dir
        self.transform = transform
        self.train = train
        self.no_aug = no_aug
        # Number of images in each class folder, indexed by class id.
        if train:
            self.dataset_size = [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]
        else:
            self.dataset_size = [980, 1135, 1032, 1010, 982, 892, 958, 1028, 974, 1009]
        # Number of class folders to load.
        self.dataset_type = len(self.dataset_size) if num_classes is None else num_classes

        if self.transform is None:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
            ])

        # Computed outside the f-string: nesting the same quote type inside an
        # f-string is a SyntaxError before Python 3.12.
        split = 'train' if train else 'test'
        self.labels = []
        self.img = []
        channel_sum = 0.0
        channel_sq_sum = 0.0
        pixel_count = 0
        for t in range(self.dataset_type):
            for item in range(self.dataset_size[t]):
                path = os.path.join(self.root_dir, split, str(t), f'{item}.png')
                # convert() returns a loaded copy, so the file can be closed at once.
                with Image.open(path) as raw:
                    image = raw.convert('RGB')
                self.labels.append(t)
                self.img.append(image)
                pixels = self.transform(image).double()
                channel_sum = channel_sum + pixels.sum(dim=[1, 2])
                channel_sq_sum = channel_sq_sum + pixels.pow(2).sum(dim=[1, 2])
                pixel_count += pixels.shape[1] * pixels.shape[2]

        if pixel_count == 0:
            raise ValueError("CMNISTDataset loaded no images; check num_classes / root_dir")

        pooled_mean = channel_sum / pixel_count
        # Pooled variance E[x^2] - E[x]^2 over every pixel of the split. Averaging
        # per-image variances would ignore the spread between images' means and
        # underestimate the dataset std.
        pooled_var = (channel_sq_sum / pixel_count - pooled_mean ** 2).clamp(min=0.0)
        mean = pooled_mean.float()
        std = pooled_var.sqrt().float()
        print(mean, std)

        mean_list = mean.tolist()
        std_list = std.tolist()
        crop_size = input_size if image_size is None else image_size
        noise_sigma = 0.1  # standard deviation of the additive Gaussian noise

        def make_view_augment():
            # Active augmentations: random resized crop, then (with p=0.5)
            # additive Gaussian noise. No flip, colour jitter, blur or
            # solarisation is applied. Both views use this same recipe.
            return transforms.Compose([
                transforms.RandomResizedCrop(crop_size, scale=(0.7, 1.0)),
                transforms.ToTensor(),
                transforms.RandomApply(
                    [transforms.GaussianNoise(sigma=noise_sigma)],
                    p=0.5
                ),
                transforms.Normalize(mean=mean_list, std=std_list),
            ])

        self.augment = make_view_augment()
        self.augment_prime = make_view_augment()
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean_list, std=std_list),
        ])

    def __len__(self):
        return len(self.img)

    def __getitem__(self, idx):
        """Return ``(image, label)`` if ``no_aug`` is set, else two augmented views."""
        image_original = self.img[idx]
        if self.no_aug:
            return self.transform(image_original), self.labels[idx]
        return self.augment(image_original), self.augment_prime(image_original)


class GaussianBlur:
    """Randomly blur a PIL image with a Gaussian filter.

    Args:
        p (float): probability of blurring on each call.
        sigma_min (float, optional): lower bound of the blur radius, drawn
            uniformly per call (default 0.1).
        sigma_max (float, optional): upper bound of the blur radius
            (default 2.0).
    """

    def __init__(self, p, sigma_min=0.1, sigma_max=2.0):
        self.p = p
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def __call__(self, img):
        if random.random() >= self.p:
            return img
        sigma = random.uniform(self.sigma_min, self.sigma_max)
        return img.filter(ImageFilter.GaussianBlur(sigma))


class Solarization:
    """Randomly solarize a PIL image via ``ImageOps.solarize``.

    Args:
        p (float): probability of solarizing on each call.
        threshold (int, optional): threshold passed to ``ImageOps.solarize``;
            defaults to 128, that function's own default.
    """

    def __init__(self, p, threshold=128):
        self.p = p
        self.threshold = threshold

    def __call__(self, img):
        if random.random() >= self.p:
            return img
        return ImageOps.solarize(img, threshold=self.threshold)


def get_cmnist_dataloaders(root_dir, train_batch_size, test_batch_size):
    """
    Build the CMNIST data loaders.

    Args:
        root_dir (str): dataset's root directory
        train_batch_size (int): batch size of train dataset
        test_batch_size (int): batch size of test dataset

    Returns:
        tuple: (train_loader, test_loader, train_noaug_loader)
            - train_loader: shuffled pairs of augmented views of the train split
            - test_loader: (image, label) batches of the test split, normalised
              with the train-split statistics
            - train_noaug_loader: unshuffled (image, label) batches of the train
              split without augmentation
    """
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_dataset = CMNISTDataset(
        root_dir=root_dir,
        train=True,
    )

    # Same images, labels and statistics as train_dataset; only the item format
    # (no_aug) differs. Share the in-memory images instead of decoding every PNG
    # of the train split a second time.
    train_noaug_dataset = copy.copy(train_dataset)
    train_noaug_dataset.no_aug = True

    test_dataset = CMNISTDataset(
        root_dir=root_dir,
        train=False,
        transform=eval_transform,
        no_aug=True,
    )
    # Normalise test images with the training statistics rather than the test
    # split's own, so evaluation does not depend on test-set statistics.
    test_dataset.transform = train_dataset.transform

    def make_loader(dataset, batch_size, shuffle):
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=6,
            pin_memory=True,
        )

    train_loader = make_loader(train_dataset, train_batch_size, shuffle=True)
    test_loader = make_loader(test_dataset, test_batch_size, shuffle=False)
    train_noaug_loader = make_loader(train_noaug_dataset, train_batch_size, shuffle=False)
    return train_loader, test_loader, train_noaug_loader


class BarlowTwins(nn.Module):
    """ResNet-18 encoder followed by a two-layer projection head.

    The torchvision ResNet-18 stem is modified: the first convolution becomes a
    3x3 / stride-1 convolution and the stem max-pool is removed. The final
    ``fc`` layer is replaced by an identity, so the backbone outputs
    ``size_in``-dimensional features.

    Attributes:
        size_in (int): width of the backbone features (ResNet-18 ``fc.in_features``).
        size_out (int): width of the projector output (128).
    """

    def __init__(self):
        super().__init__()
        # Modules are created in a fixed order so that seeded weight
        # initialisation and state_dict keys stay stable.
        self.backbone = torchvision.models.resnet18(zero_init_residual=True, weights=None)
        stem = self.backbone
        stem.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        stem.maxpool = nn.Identity()
        self.size_in = stem.fc.in_features
        self.size_out = 128
        stem.fc = nn.Identity()

        # Projector: Linear -> BatchNorm -> ReLU -> Linear, all linear layers bias-free.
        hidden_width = self.size_out
        self.projector = nn.Sequential(
            nn.Linear(self.size_in, hidden_width, bias=False),
            nn.BatchNorm1d(hidden_width),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_width, self.size_out, bias=False),
        )

    def forward(self, x, use_projector=True):
        """Return projector embeddings, or raw backbone features if ``use_projector`` is False."""
        features = self.backbone(x)
        return self.projector(features) if use_projector else features


def import_dataset(train_batch_size, test_batch_size, ratio, base_dir="/home/dataset/cmnist"):
    """
    Build the CMNIST loaders for one ``ratio`` variant of the dataset.

    Args:
        train_batch_size (int): batch size of the train and train-noaug loaders.
        test_batch_size (int): batch size of the test loader.
        ratio: name of the sub-directory of ``base_dir`` to load.
        base_dir (str, optional): parent directory containing one folder per
            ratio; defaults to the previously hard-coded location.

    Returns:
        tuple: (train_loader, test_loader, train_noaug_loader)
    """
    root_dir = os.path.join(base_dir, str(ratio))  # dataset root directory

    # DataLoader
    train_loader, test_loader, train_noaug_loader = get_cmnist_dataloaders(
        root_dir=root_dir,
        train_batch_size=train_batch_size,
        test_batch_size=test_batch_size,
    )
    print(f"train dataset size: {len(train_loader.dataset)}")
    print(f"test dataset size: {len(test_loader.dataset)}")
    return train_loader, test_loader, train_noaug_loader


def train_eval(config, train, test, train_noaug, epochs, file_path, lambd=None):
    """
    Pre-train a BarlowTwins model with a self-supervised loss and evaluate it
    periodically with a linear probe.

    The loss is selected by the module-level ``method`` ("BarlowTwins" ->
    BTLoss, "VICReg" -> VICRegLoss, anything else -> SimCLRLoss). Once before
    training, and then every ``configs["report_frequency"]`` optimisation steps,
    a linear classifier is reset to its initial weights and fitted for 30 epochs
    on frozen backbone features from ``train_noaug``; its test accuracy (%) is
    logged. The first time accuracy reaches each threshold in ``save_point``,
    the model weights are saved to ``{file_path}/{threshold:.2f}.pt``.

    Args:
        config (dict): run settings; reads "scaling", "lr" and, for SimCLR,
            "temperature".
        train (DataLoader): yields pairs of augmented views.
        test (DataLoader): yields (image, label) batches for evaluation.
        train_noaug (DataLoader): yields (image, label) batches for the probe.
        epochs (int): number of passes over ``train``.
        file_path (str): output location passed to the logging helpers and
            used for checkpoints.
        lambd (str | float, optional): Barlow Twins weight forwarded to
            ``BTLoss`` as ``lambd``; when None it is not forwarded (relies on
            BTLoss having its own default - TODO confirm).

    Also reads the module globals ``method``, ``batch_size``, ``weight_decay``,
    ``dataset_type`` and ``configs``.
    """
    model = BarlowTwins().to(device)
    size_in = model.size_in
    size_out = model.size_out

    scale_weights(model, float(config["scaling"]), scale_bias=True)
    base_lr = float(config["lr"])
    if method == "BarlowTwins":
        # float(None) would raise, so only forward lambd when one was given.
        bt_kwargs = {} if lambd is None else {"lambd": float(lambd)}
        loss_fn = BTLoss(use_batchnorm=True, in_feature=size_out, **bt_kwargs).to(device)
    elif method == "VICReg":
        loss_fn = VICRegLoss().to(device)
    else:
        loss_fn = SimCLRLoss(batch_size, temperature=float(config["temperature"])).to(device)
    optimizer_main = torch.optim.SGD(model.parameters(), lr=base_lr, weight_decay=weight_decay, momentum=0.9)

    linear_probe = nn.Linear(size_in, dataset_type).to(device)
    # state_dict() tensors share storage with the live parameters, so clone them
    # to get a real snapshot of the initial weights. Without the clone, every
    # load_state_dict() below would copy the already-trained weights onto
    # themselves and the probe would never be reset between evaluations.
    probe_init_state = {
        name: tensor.detach().clone()
        for name, tensor in linear_probe.state_dict().items()
    }

    save_point = [69.9, 70.2]  # accuracy (%) thresholds that trigger a checkpoint
    current_save = 0
    report_every = configs["report_frequency"]

    epoch_bar = tqdm(range(epochs), mininterval=10, maxinterval=10)

    def update_accuracy_lin(current_steps):
        """Fit a freshly reset linear probe on frozen features; log and return test accuracy (%)."""
        model.eval()
        linear_probe.load_state_dict(probe_init_state)
        optimizer_classifier = torch.optim.Adam(linear_probe.parameters())
        linear_probe.train()
        probe_bar = tqdm(range(30), mininterval=10, maxinterval=10)
        for _ in probe_bar:
            seen = 0
            loss_sum = 0.0
            for image, label in train_noaug:
                # The backbone is frozen: only the probe receives gradients.
                with torch.no_grad():
                    image = image.to(device, non_blocking=True)
                    labels = label.to(device, non_blocking=True)
                    features = model(image, use_projector=False)
                loss = F.cross_entropy(linear_probe(features), labels)
                optimizer_classifier.zero_grad()
                loss.backward()
                optimizer_classifier.step()
                seen += image.shape[0]
                loss_sum += image.shape[0] * loss.item()
            probe_bar.set_description(f"loss linear : {loss_sum / max(seen, 1)}", refresh=False)

        linear_probe.eval()
        correct = 0
        seen_test = 0
        with torch.no_grad():
            for image, label in tqdm(test):
                image = image.to(device, non_blocking=True)
                labels = label.to(device, non_blocking=True)
                pred = linear_probe(model(image, use_projector=False)).argmax(dim=1)
                correct += (pred == labels).sum().item()
                seen_test += labels.shape[0]
        acc_test = correct / seen_test * 100.0
        write_realtime(file_path, "accuracy", acc_test)
        write_realtime(file_path, "steps", current_steps)
        model.train()
        return acc_test

    update_accuracy_lin(current_steps=0)
    step = 1
    model.train()
    for _ in epoch_bar:
        losses = []
        for image1, image2 in train:
            image1 = image1.to(device, non_blocking=True)
            image2 = image2.to(device, non_blocking=True)
            # Both views go through one forward pass (shared BatchNorm batch statistics).
            outputs = model(torch.cat([image1, image2], dim=0))
            z1, z2 = outputs.chunk(2, dim=0)
            # loss_fn returns (c, loss); c is presumably the cross-correlation
            # matrix - TODO confirm for VICRegLoss / SimCLRLoss.
            c, loss = loss_fn(z1, z2)
            optimizer_main.zero_grad()
            loss.backward()
            losses.append(loss.item())
            optimizer_main.step()

            if step % report_every == 0:
                with torch.no_grad():
                    # Squared singular values of c.
                    eigenvalues = (torch.linalg.svdvals(c) ** 2).cpu()
                    write_realtime_seperate(file_path, "eig", eigenvalues)
                acc_test = update_accuracy_lin(current_steps=step)
                if current_save < len(save_point) and acc_test >= save_point[current_save]:
                    torch.save(model.state_dict(), f"{file_path}/{save_point[current_save]:.2f}.pt")
                    current_save += 1
            step += 1
        write_realtime_list(file_path, "loss", losses)


with open("config.yaml", 'r', encoding='utf-8') as f:
    configs = yaml.safe_load(f)

# Experiment selection. train_eval and CMNISTDataset read these as module globals.
method = "BarlowTwins"
model = "ResNet-18"
dataset_type = 10
input_size = configs["input_size"]

if __name__ == "__main__":
    # Guard: DataLoader workers started with the "spawn" method re-import this
    # module, and without the guard each worker would rerun the experiment.
    lambd = "0.05"
    specific_config = configs[method]["lambdas"][lambd][model]
    epochs = specific_config["epochs"]
    # Loop-invariant hyper-parameters; train_eval reads these module globals.
    weight_decay = float(configs["weight_decay"])
    batch_size = int(configs["batch_size"])
    seed = int(configs[method]["seed"])

    for r in configs["ratios"]:
        setproctitle(f"cmnist-{r}")
        print(f"lambda: {lambd}, ratio: {r}")
        set_seed(seed)
        train, test, train_noaug = import_dataset(train_batch_size=batch_size, test_batch_size=512, ratio=r)
        set_seed(seed)
        train_eval(specific_config, train, test, train_noaug, epochs, f"datas/lambda_{lambd}_{seed}/{r}", lambd=lambd)
