import random
import numpy as np
import torch
import torchvision
from torchvision import datasets, transforms
import numpy
import time
from torch import nn
import os
import tqdm
import logging

# Directory that receives one log file per trained configuration.
# ``exist_ok=True`` replaces the separate existence check, which could race
# with another process creating the directory in between.
os.makedirs('log_MLP', exist_ok=True)

# Expose only the GPU with physical index 1 to this process. The variable is
# read when CUDA is initialised, so it has to be set before any CUDA call.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

RANDOM_SEED = 1  # any random number

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU, the current CUDA
    device and all CUDA devices), exports ``PYTHONHASHSEED`` and configures
    cuDNN for deterministic execution with auto-tuning switched off.

    Args:
        seed: Integer seed handed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    os.environ['PYTHONHASHSEED'] = str(seed)

    # Trade cuDNN speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class MetricsLogger:
    """Echo training metrics to stdout and append them to a log file.

    Each instance attaches a ``logging.FileHandler`` for ``log_file`` to the
    logger named after that file. Because ``logging.getLogger`` returns the
    same logger object for the same name, two instances created with the same
    ``log_file`` share one logger (and each adds its own handler to it).
    """

    def __init__(self, log_file='training.log'):
        """Create the logger and attach a file handler.

        Args:
            log_file: Path of the file to write to; also used as logger name.
        """
        self.log_file = log_file
        self.logger = logging.getLogger(log_file)
        self.logger.setLevel(logging.INFO)
        # Messages are written verbatim, without timestamp or level prefix.
        formatter = logging.Formatter('%(message)s')
        file_handler = logging.FileHandler(self.log_file)
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(formatter)
        self.logger.addHandler(file_handler)

    def _emit(self, message):
        """Print ``message`` and write it to the log file."""
        print(message)
        self.logger.info(message)

    def log_metrics(self, train_acc, test_acc, iter, iter_loss, epoch):
        """Log one line of accuracy/loss metrics.

        Args:
            train_acc: Training accuracy in percent.
            test_acc: Test accuracy in percent.
            iter: Iteration counter (name kept for caller compatibility even
                though it shadows the builtin inside this method).
            iter_loss: Loss value to report.
            epoch: Epoch number.
        """
        log_string = ("Train Acc: {:.4f}%, Test Acc: {:.4f}%, Iter: {:.0f}, Iter Loss: {:.6f}, Epoch: {:.0f}".format
                      (train_acc, test_acc, iter, iter_loss, epoch))
        self._emit(log_string)

    def log_metrics_extra_loss(self, train_acc, test_acc, iter, iter_loss, lj_loss, epoch):
        """Log one line of metrics including the additional LJ loss term.

        Args:
            train_acc: Training accuracy in percent.
            test_acc: Test accuracy in percent.
            iter: Iteration counter.
            iter_loss: Loss value to report.
            lj_loss: Value of the LJ loss term to report.
            epoch: Epoch number.
        """
        log_string = ("Train Acc: {:.4f}%, Test Acc: {:.4f}%, Iter: {:.0f}, Iter Loss: {:.6f}, LJ Loss: {:.6f}, Epoch: {:.0f}".format
                      (train_acc, test_acc, iter, iter_loss, lj_loss, epoch))
        self._emit(log_string)

    def log_direct(self, args):
        """Print and log ``args`` as-is, without any formatting."""
        self._emit(args)

    def close(self):
        """Close and detach every handler of the underlying logger."""
        # Iterate over a snapshot: ``removeHandler`` mutates
        # ``logger.handlers``, and looping over the live list would skip every
        # second handler, leaving it open and still attached.
        for handler in list(self.logger.handlers):
            handler.close()
            self.logger.removeHandler(handler)

import torch
from torch.nn.modules.loss import _Loss
import torch.nn as nn
import torch.nn.functional as F


def lj_loss(features, sigma=1.0, n=6, clamp_max=5.0):
    """Lennard-Jones-style potential over pairwise cosine distances.

    ``features`` is treated as a rank-3 tensor: vectors are L2-normalised
    along ``dim=2`` and compared pairwise along dimension 1, independently for
    each index of dimension 0. For every pair with cosine distance ``d`` the
    potential ``(sigma / d) ** (2 * n) - (sigma / d) ** n`` is evaluated,
    capped at ``clamp_max`` and averaged over all entries.

    Args:
        features: Tensor indexed as (batch, item, feature).
        sigma: Distance at which the potential is zero.
        n: Exponent of the attractive term; the repulsive term uses ``2 * n``.
        clamp_max: Upper bound applied to each pairwise potential.

    Returns:
        Scalar tensor holding the mean clamped potential.
    """
    unit_vectors = F.normalize(features, dim=2)
    similarity = unit_vectors @ unit_vectors.transpose(1, 2)
    distance = 1 - similarity

    # Self-pairs get distance ``sigma`` so their potential is exactly zero.
    item_count = distance.size(1)
    self_pairs = torch.eye(item_count, dtype=torch.bool, device=features.device)
    distance = distance.masked_fill(self_pairs, sigma)

    # Lower bound keeps the ratio finite for (near-)identical vectors.
    distance = distance.clamp(min=1e-3)
    ratio = sigma / distance
    potential = ratio ** (2 * n) - ratio ** n
    return potential.clamp(max=clamp_max).mean()


class LJLoss(nn.Module):
    """``nn.Module`` wrapper around ``lj_loss`` with a fixed weighting factor.

    The constructor defaults are the default settings of the LJ loss.
    """

    def __init__(self, epsilon=0.1, sigma=0.5, n=6, clamp_max=5.):
        """Store the loss hyper-parameters.

        Args:
            epsilon: Factor the raw loss is multiplied by.
            sigma: Forwarded to ``lj_loss`` as ``sigma``.
            n: Forwarded to ``lj_loss`` as ``n``.
            clamp_max: Forwarded to ``lj_loss`` as ``clamp_max``.
        """
        super().__init__()
        self.epsilon = epsilon
        self.sigma = sigma
        self.n = n
        self.clamp_max = clamp_max

    def forward(self, feat):
        """Return ``epsilon`` times the LJ loss of ``feat``."""
        unscaled = lj_loss(
            feat,
            sigma=self.sigma,
            n=self.n,
            clamp_max=self.clamp_max,
        )
        return self.epsilon * unscaled

class MLPModel(nn.Module):
    """Three-head MLP classifier that processes one input channel per head.

    The input is indexed as ``x[:, c]`` for ``c`` in 0..2; each slice is
    flattened and fed to its own MLP head. The three head outputs are either
    averaged or concatenated and passed to a linear classifier. Optionally an
    ``LJLoss`` term is computed on the stacked head outputs.
    """

    def __init__(self, input_dim, num_classes, num_hidden_layers=0, hidden_dim=512, combine_strategy='average', enable_ljloss=False):
        """Build the heads and the classifier.

        Args:
            input_dim: Flattened size of a single channel.
            num_classes: Number of output logits.
            num_hidden_layers: Number of extra ``ReLU -> Linear`` stages
                appended after the first linear layer of each head.
            hidden_dim: Output width of every head.
            combine_strategy: ``'average'`` or ``'concat'``.
            enable_ljloss: Whether ``forward`` also evaluates ``LJLoss``.

        Raises:
            ValueError: If ``combine_strategy`` is not a supported value.
        """
        super(MLPModel, self).__init__()
        # Fail at construction time: an unknown strategy used to build a model
        # without a classifier that only blew up later inside ``forward``.
        if combine_strategy not in ('average', 'concat'):
            raise ValueError(
                "combine_strategy must be 'average' or 'concat', got {!r}".format(combine_strategy)
            )
        self.num_hidden_layers = num_hidden_layers
        self.combine_strategy = combine_strategy
        self.enable_ljloss = enable_ljloss
        # Initialize LJLoss if needed
        if self.enable_ljloss:
            self.lj_loss = LJLoss()
        # One independent MLP head per channel.
        self.heads = nn.ModuleList([self._create_mlp(input_dim, hidden_dim, num_hidden_layers) for _ in range(3)])
        # Concatenation triples the classifier's input width.
        classifier_in = hidden_dim if combine_strategy == 'average' else hidden_dim * 3
        self.classifier = nn.Linear(classifier_in, num_classes)

    def _create_mlp(self, input_dim, hidden_dim, num_hidden_layers):
        """Return ``Linear`` followed by ``num_hidden_layers`` x (``ReLU``, ``Linear``).

        With ``num_hidden_layers == 0`` the head is a single linear mapping.
        """
        layers = [nn.Linear(input_dim, hidden_dim)]
        for _ in range(num_hidden_layers):
            layers.append(nn.ReLU())
            layers.append(nn.Linear(hidden_dim, hidden_dim))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Compute class logits and the optional LJ loss.

        Args:
            x: Tensor indexed as (batch, channel, ...) with at least three
                channels; everything after the channel axis is flattened.

        Returns:
            Tuple ``(logits, lj_val)``. ``lj_val`` is the ``LJLoss`` tensor
            when ``enable_ljloss`` is set and the integer ``0`` otherwise.
        """
        batch_size = x.size(0)
        # ``reshape`` (unlike ``view``) also accepts non-contiguous slices.
        head_outputs = [
            head(x[:, channel].reshape(batch_size, -1))
            for channel, head in enumerate(self.heads)
        ]
        outputs = torch.stack(head_outputs, dim=1)  # Shape: (batch_size, 3, hidden_dim)
        lj_val = self.lj_loss(outputs) if self.enable_ljloss else 0
        if self.combine_strategy == 'average':
            combined = outputs.mean(dim=1)  # Shape: (batch_size, hidden_dim)
        else:
            combined = outputs.reshape(batch_size, -1)  # Shape: (batch_size, hidden_dim * 3)
        logits = self.classifier(combined)
        return logits, lj_val

def _to_hsv(img):
    """Convert a PIL image to HSV mode.

    A named module-level function (instead of a lambda) so the transform can
    be pickled when DataLoader workers are started with the spawn method.
    """
    return img.convert('HSV')


def _train_one_epoch(model, data_loader, cost, optimizer, device, enable_ljloss):
    """Run one optimisation pass over ``data_loader``.

    Returns:
        Tuple ``(avg_ce_loss, avg_lj_loss, train_acc)``: the two losses
        averaged over batches and the training accuracy in percent.
        ``avg_lj_loss`` is 0.0 when ``enable_ljloss`` is false.
    """
    model.train()
    ce_loss_sum = 0.0
    lj_loss_sum = 0.0
    num_correct = 0
    num_seen = 0
    for x, labels in tqdm.tqdm(data_loader, total=len(data_loader)):
        x, labels = x.to(device), labels.to(device)
        outputs, ljloss = model(x)
        loss = cost(outputs, labels)
        ce_loss_sum += loss.item()
        if enable_ljloss:
            lj_loss_sum += ljloss.item()
            total_loss = loss + ljloss
        else:
            total_loss = loss
        _, pred = torch.max(outputs, 1)
        num_seen += labels.size(0)
        num_correct += (pred == labels).sum().item()
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
    num_batches = len(data_loader)
    return ce_loss_sum / num_batches, lj_loss_sum / num_batches, 100 * num_correct / num_seen


def main():
    """Train and evaluate every MLP configuration on HSV-converted CIFAR-10.

    Sweeps over the number of hidden layers, the head-combination strategy
    and whether the LJ loss is enabled. Each configuration is trained with
    SGD and a multi-step learning-rate schedule, evaluated on the test split
    after every epoch, and logged to its own file under ``log_MLP/``.
    """
    import itertools  # only needed for the configuration sweep below

    data_dir = './data'
    batch_size = 256
    n_epochs = 100
    Lr = 0.01
    momentum = 0.9
    weight_decay = 1e-4
    # Never request more loader workers than the machine has CPUs.
    num_workers = min(24, os.cpu_count() or 1)

    hsv_transform = transforms.Compose([
        transforms.Lambda(_to_hsv),
        transforms.ToTensor(),
    ])

    train_dataset = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=hsv_transform)
    test_dataset = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=hsv_transform)
    # Take the class count from the dataset; the previous hard-coded 100 did
    # not match CIFAR-10 and produced 90 output logits that no label ever used.
    num_classes = len(train_dataset.classes)

    data_loader_train = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    data_loader_test = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    input_dim = 32 * 32
    hidden_dim = 512

    num_hidden_layers_list = [0, 1, 2, 3]
    combine_strategies = ['average', 'concat']
    enable_ljloss_list = [True, False]

    # ``product`` varies the last factor fastest, i.e. the same order as three
    # nested loops over these lists.
    sweep = itertools.product(num_hidden_layers_list, combine_strategies, enable_ljloss_list)
    for num_hidden_layers, combine_strategy, enable_ljloss in sweep:
        # Re-seed per configuration so every run starts from the same RNG
        # state and its result does not depend on the runs that preceded it.
        set_seed(RANDOM_SEED)

        print(f"Training MLP with {num_hidden_layers} hidden layers, '{combine_strategy}' combine strategy, LJLoss enabled: {enable_ljloss}.")
        model = MLPModel(input_dim=input_dim, num_classes=num_classes, num_hidden_layers=num_hidden_layers,
                         hidden_dim=hidden_dim, combine_strategy=combine_strategy, enable_ljloss=enable_ljloss).to(device)

        ljloss_str = 'with_LJLoss' if enable_ljloss else 'without_LJLoss'
        log_file = f"log_MLP/MLP_{num_hidden_layers}layers_{combine_strategy}_{ljloss_str}.log"
        logger = MetricsLogger(log_file=log_file)

        # ``finally`` guarantees the log file is closed even if training fails.
        try:
            cost = nn.CrossEntropyLoss().to(device)
            optimizer = torch.optim.SGD(model.parameters(), lr=Lr, momentum=momentum, weight_decay=weight_decay)
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[41, 61, 81], gamma=0.1, last_epoch=-1)

            global_step = 0
            since = time.time()
            for epoch in range(n_epochs):
                print("Epoch {}/{}".format(epoch + 1, n_epochs))
                avg_training_loss, avg_lj_loss, train_acc = _train_one_epoch(
                    model, data_loader_train, cost, optimizer, device, enable_ljloss)
                global_step += len(data_loader_train)
                test_acc = eval(model, data_loader_test, device)
                scheduler.step()
                if enable_ljloss:
                    logger.log_metrics_extra_loss(train_acc, test_acc, global_step, avg_training_loss, avg_lj_loss, epoch + 1)
                else:
                    logger.log_metrics(train_acc, test_acc, global_step, avg_training_loss, epoch + 1)
            time_used = time.time() - since
            logger.log_direct('Training complete in {:.0f}m {:.0f}s'.format(time_used // 60, time_used % 60))
        finally:
            logger.close()

def eval(model, data_loader_test, device):
    """Return the top-1 accuracy of ``model`` on ``data_loader_test`` in percent.

    The model is switched to evaluation mode (and left in it) and gradients
    are disabled for the whole pass. ``model`` must return a pair whose first
    element holds the class scores.

    Args:
        model: Module to evaluate.
        data_loader_test: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for x_test, label_test in data_loader_test:
            x_test = x_test.to(device)
            label_test = label_test.to(device)
            scores, _ = model(x_test)
            predicted = torch.max(scores.data, 1)[1]
            hits += (predicted == label_test).sum().item()
            seen += label_test.size(0)
    return 100 * hits / seen

# Run the experiment sweep only when this file is executed as a script, so
# importing the module does not start training.
if __name__ == '__main__':
    main()
