import argparse
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline
from dataset import PositiveNegativeIdsDataset
from torch.utils.data import DataLoader, random_split
import copy
from tqdm import tqdm
from utils import set_seed

# Command-line interface: every training hyper-parameter is exposed as a flag.
# The two metric flags are restricted to the names the rest of the script
# knows how to handle, so a typo is rejected here instead of crashing in the
# middle of the first training batch.
SUPPORTED_METRICS = ('l2', 'cosine')

parser = argparse.ArgumentParser(description="Stable Diffusion Training Script")
parser.add_argument('--total_epoch', type=int, default=5, help='Number of training epochs')
parser.add_argument('--device', type=str, default='cuda:2', help='Device to use for training')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
# alpha scales the loss computed on the negative prompts.
parser.add_argument('--alpha', type=float, default=0.2, help='Alpha coefficient for loss calculation')
# beta scales the loss computed on the positive prompts. The default stays the
# int 1 (argparse only applies `type` to string defaults) because str(beta)
# ends up in the checkpoint filename.
parser.add_argument('--beta', type=float, default=1, help='Beta coefficient for loss calculation')
parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility')
parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate for the optimizer')
# metric1 is applied to the positive prompts, metric2 to the negative prompts.
parser.add_argument('--metric1', type=str, default='l2', choices=SUPPORTED_METRICS, help='Metric 1')
parser.add_argument('--metric2', type=str, default='cosine', choices=SUPPORTED_METRICS, help='Metric 2')

args = parser.parse_args()

# Unpack the parsed arguments into the module-level names used below.
total_epoch = args.total_epoch
device = args.device
batch_size = args.batch_size
alpha = args.alpha
beta = args.beta
seed = args.seed
lr = args.lr
metric_1 = args.metric1
metric_2 = args.metric2

# Load the pretrained Stable Diffusion v1.4 pipeline in half precision; only
# its text encoder and tokenizer are used below.
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)

set_seed(seed)

# `model` is the text encoder that gets fine-tuned; `origin_model` is a deep
# copy of it that serves as the reference when the losses are computed.
model = pipe.text_encoder.to(device)
origin_model = copy.deepcopy(model).to(device)

# Positive prompts come from the "safe" CSV, negative prompts from the "nsfw" CSV.
prompt_data = PositiveNegativeIdsDataset(
    positive_csv='./data/civitai_safe_prompts_30k_train.csv',
    negative_csv='./data/civitai_nsfw_prompts_30k_train.csv',
    tokenizer=pipe.tokenizer,
    max_length=pipe.tokenizer.model_max_length,
)

# 80/20 train/validation split, driven by a dedicated generator seeded with
# the same seed so the split is reproducible.
g = torch.Generator()
g.manual_seed(seed)
train_size = int(0.8 * len(prompt_data))
val_size = len(prompt_data) - train_size
train_dataset, val_dataset = random_split(
    prompt_data,
    [train_size, val_size],
    generator=g,
)

# Only the training batches are shuffled.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


def _build_criterion(metric):
    """Return the torch loss/similarity module for *metric*.

    Args:
        metric: Either ``'l2'`` (``torch.nn.MSELoss``) or ``'cosine'``
            (``torch.nn.CosineSimilarity``).

    Raises:
        ValueError: If *metric* is not one of the supported names. Without
            this check the criterion would silently stay undefined.
    """
    if metric == 'l2':
        return torch.nn.MSELoss()
    if metric == 'cosine':
        # NOTE(review): CosineSimilarity() uses its default dim=1. If the
        # encoder outputs are (batch, seq, hidden), the similarity is taken
        # along the sequence axis rather than the hidden axis -- confirm this
        # is intended before changing it.
        return torch.nn.CosineSimilarity()
    raise ValueError(f"Unsupported metric {metric!r}; expected 'l2' or 'cosine'")


# criterion_1 is used on the positive prompts, criterion_2 on the negative ones.
criterion_1 = _build_criterion(metric_1)
criterion_2 = _build_criterion(metric_2)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Placeholders for best-checkpoint tracking; not read anywhere in the visible script.
best_val_loss = float('inf')
best_neg_loss = float('inf')

def _signed_losses(pos_out, ref_pos_out, neg_out, ref_neg_out):
    """Return ``(loss_pos, loss_neg)`` for one batch as scalar tensors.

    The positive term decreases as the fine-tuned outputs get closer to the
    reference outputs (L2 distance, or negated mean |cosine similarity|).
    The negative term decreases as they move apart (negated L2 distance, or
    mean |cosine similarity|).

    Raises:
        ValueError: If ``metric_1`` / ``metric_2`` is not 'l2' or 'cosine'.
    """
    # The reference encoder stays in fp16 while `model` is fp32 outside the
    # training forward pass; cast so both operands always share a dtype.
    ref_pos_out = ref_pos_out.to(pos_out.dtype)
    ref_neg_out = ref_neg_out.to(neg_out.dtype)

    if metric_1 == 'l2':
        loss_pos = criterion_1(pos_out, ref_pos_out)
    elif metric_1 == 'cosine':
        loss_pos = -torch.abs(criterion_1(pos_out, ref_pos_out)).mean()
    else:
        raise ValueError(f"Unsupported metric_1 {metric_1!r}")

    if metric_2 == 'l2':
        loss_neg = -criterion_2(neg_out, ref_neg_out)
    elif metric_2 == 'cosine':
        loss_neg = torch.abs(criterion_2(neg_out, ref_neg_out)).mean()
    else:
        raise ValueError(f"Unsupported metric_2 {metric_2!r}")

    return loss_pos, loss_neg


for epoch in range(total_epoch):
    # ---- Training phase ----
    model.train()
    total_epoch_loss = 0.0
    total_epoch_pos_loss = 0.0
    total_epoch_neg_loss = 0.0
    num_batches = 0

    for positive_batch, negative_batch in tqdm(train_loader):
        # Forward/backward run in fp16; the optimizer step below runs in fp32.
        # NOTE(review): this call rounds the fp32 weights written by the
        # previous optimizer.step() back to fp16, so updates below fp16
        # resolution are lost. torch.autocast with fp32 master weights would
        # avoid that; left unchanged here to keep the training numerics.
        model.half()
        positive_batch = positive_batch.to(device)
        negative_batch = negative_batch.to(device)
        optimizer.zero_grad()

        positive_outputs = model(positive_batch)[0]
        negative_outputs = model(negative_batch)[0]

        # The reference encoder is never trained, so no graph is needed.
        with torch.no_grad():
            origin_positive_outputs = origin_model(positive_batch)[0]
            origin_negative_outputs = origin_model(negative_batch)[0]

        total_loss_pos, total_loss_neg = _signed_losses(
            positive_outputs,
            origin_positive_outputs,
            negative_outputs,
            origin_negative_outputs,
        )
        total_loss = beta * total_loss_pos + alpha * total_loss_neg

        # The graph is used for a single backward pass, so it is not retained.
        total_loss.backward()
        # Cast weights (and their gradients) to fp32 before the Adam update.
        model.float()
        optimizer.step()

        total_epoch_loss += total_loss.item()
        total_epoch_pos_loss += total_loss_pos.item()
        total_epoch_neg_loss += total_loss_neg.item()
        num_batches += 1

    # max(..., 1) keeps the averages defined when the loader yields no batch.
    train_denominator = max(num_batches, 1)
    avg_epoch_loss = total_epoch_loss / train_denominator
    avg_epoch_pos_loss = total_epoch_pos_loss / train_denominator
    avg_epoch_neg_loss = total_epoch_neg_loss / train_denominator

    print(f'Epoch [{epoch+1}/{total_epoch}], Avg Loss: {avg_epoch_loss}, Avg Pos Loss: {avg_epoch_pos_loss}, Avg Neg Loss: {avg_epoch_neg_loss}')

    # ---- Validation phase ----
    model.eval()
    val_loss_pos = 0.0
    val_loss_neg = 0.0
    num_val_batches = 0
    with torch.no_grad():
        for val_positive_batch, val_negative_batch in tqdm(val_loader):
            val_positive_batch = val_positive_batch.to(device)
            val_negative_batch = val_negative_batch.to(device)

            val_positive_outputs = model(val_positive_batch)[0]
            val_negative_outputs = model(val_negative_batch)[0]

            origin_val_positive_outputs = origin_model(val_positive_batch)[0]
            origin_val_negative_outputs = origin_model(val_negative_batch)[0]

            batch_loss_pos, batch_loss_neg = _signed_losses(
                val_positive_outputs,
                origin_val_positive_outputs,
                val_negative_outputs,
                origin_val_negative_outputs,
            )
            # Accumulate plain floats so the totals are ordinary numbers even
            # when the validation loader is empty.
            val_loss_pos += batch_loss_pos.item()
            val_loss_neg += batch_loss_neg.item()
            num_val_batches += 1

    val_denominator = max(num_val_batches, 1)
    avg_val_loss = (beta * val_loss_pos + alpha * val_loss_neg) / val_denominator
    avg_val_loss_pos = val_loss_pos / val_denominator
    avg_val_loss_neg = val_loss_neg / val_denominator

    print(f'Validation Avg Loss: {avg_val_loss}, Avg Pos Loss: {avg_val_loss_pos}, Avg Neg Loss: {avg_val_loss_neg}')

# Save the final text-encoder weights once, after all epochs have run.
# `total_epoch` is used instead of the loop variable `epoch`: it equals
# `epoch + 1` whenever the loop ran, and it is still defined when
# `total_epoch` is 0 (where `epoch` would raise NameError).
model_name = model_id.split('/')[-1]
checkpoint_path = (
    f'model_epoch_{total_epoch}_{model_name}'
    f'_orthogonal_civitai_alpha{alpha}_beta{beta}_bs{batch_size}_lr{lr}.pth'
)
torch.save(model.state_dict(), checkpoint_path)
print(f'Model saved for epoch {total_epoch}')
