from text_encoder_opc import Auxilary_layer, AverageMeter, accuracy
from tqdm.auto import tqdm
from dataset import UnlearnCanvasDataset_classifier_text_only
import numpy as np
import torch
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer
from collections import defaultdict
import random
import torch
import torch.nn as nn

# Attribute the classifier is trained on. It is passed to the dataset loader,
# selects the `{unlearn_type}_dist` target field in `collate_fn`, and prefixes
# the checkpoint file names.
unlearn_type = 'style'

# The loader returns a (dataset, label) pair; `label` is not used in the part
# of the script visible here.
dataset, label = UnlearnCanvasDataset_classifier_text_only(
    image_dir="/workspace/unlearncanvas",
    unlearn_type=unlearn_type
)

tokenizer = CLIPTokenizer.from_pretrained(
    "/workspace/unlearncanvas_ckpt/diffusion/diffuser/style50", subfolder="tokenizer", revision=None
)

pretrained_text_encoder = CLIPTextModel.from_pretrained(
    "/workspace/unlearncanvas_ckpt/diffusion/diffuser/style50", subfolder='text_encoder', revision=None
)
# The optimizer in this script only receives the linear head's parameters and
# only the head's state_dict is checkpointed, so the encoder is never updated.
# Freeze it: leaving requires_grad on makes every backward pass compute (and
# endlessly accumulate) encoder gradients that nothing consumes or clears.
pretrained_text_encoder.requires_grad_(False)
# from_pretrained already returns the model in eval mode; stated explicitly
# because the encoder is used purely as a feature extractor.
pretrained_text_encoder.eval()
pretrained_text_encoder.to('cuda:0')

# Linear head on top of the encoder features: 51 outputs when classifying
# styles, 20 otherwise.
num_classes = 51 if unlearn_type == 'style' else 20
model = nn.Linear(pretrained_text_encoder.config.hidden_size, num_classes)
model.to('cuda:0')
model.train()

# Name of the field in each batch that holds the caption text.
caption_column = 'text'


def tokenize_captions(examples, is_train=True):
    """Turn the captions of a batch into fixed-length token ids.

    Args:
        examples: Mapping whose `caption_column` entry is an iterable of
            captions. Each caption is either a string or a list / numpy
            array of candidate strings.
        is_train: When True, one candidate is drawn at random from a
            multi-caption entry; when False, the first candidate is used.

    Returns:
        The `input_ids` produced by the module-level `tokenizer`, padded and
        truncated to `tokenizer.model_max_length` and returned as PyTorch
        tensors.

    Raises:
        ValueError: If a caption is neither a string nor a list / numpy array.
    """

    def _pick_caption(entry):
        # Plain strings pass straight through.
        if isinstance(entry, str):
            return entry
        if not isinstance(entry, (list, np.ndarray)):
            raise ValueError(
                f"Caption column `{caption_column}` should contain either strings or lists of strings."
            )
        # Several candidates: sample one for training, otherwise take the first.
        return random.choice(entry) if is_train else entry[0]

    chosen = [_pick_caption(entry) for entry in examples[caption_column]]
    encoded = tokenizer(
        chosen,
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    return encoded.input_ids

def preprocess_train(examples):
    """Add tokenized caption ids to a batch of examples.

    The batch is modified in place: the result of `tokenize_captions` is
    stored under the "input_ids" key, and the same mapping is returned.
    """
    token_ids = tokenize_captions(examples)
    examples["input_ids"] = token_ids
    return examples


# Register the tokenization step on the dataset.
# NOTE(review): presumably a Hugging Face `datasets` object whose
# `with_transform` applies the function when examples are accessed - confirm.
dataset = dataset.with_transform(preprocess_train)

def collate_fn(examples):
    """Assemble a list of per-example dicts into one batch dict.

    Args:
        examples: Sequence of dicts, each holding an "input_ids" tensor and a
            `{unlearn_type}_dist` target entry.

    Returns:
        A dict with "input_ids" (the per-example tensors stacked along a new
        leading dimension) and "dist" (the targets gathered into one tensor).
    """
    target_key = f"{unlearn_type}_dist"
    token_rows = [sample["input_ids"] for sample in examples]
    batch = {"input_ids": torch.stack(token_rows)}
    batch["dist"] = torch.tensor([sample[target_key] for sample in examples])
    return batch

# Data pipeline: shuffled mini-batches of 64 assembled by `collate_fn`,
# loaded by 4 worker processes into pinned memory.
loader_options = {
    "shuffle": True,
    "collate_fn": collate_fn,
    "batch_size": 64,
    "num_workers": 4,
    "pin_memory": True,
}
train_dataloader = torch.utils.data.DataLoader(dataset, **loader_options)

# AdamW over the parameters of the linear head `model` only.
adamw_options = {
    "lr": 1e-5,
    "betas": (0.9, 0.999),
    "weight_decay": 1e-2,
    "eps": 1e-08,
}
optimizer = torch.optim.AdamW(model.parameters(), **adamw_options)

# Classification objective applied to the head's logits.
criteria = torch.nn.CrossEntropyLoss()

# Echo which attribute this run is training a classifier for.
print(unlearn_type)

# Train the linear head `model` on features from the text encoder and
# checkpoint it periodically.
epochs = 400
checkpoint_interval = 100
checkpoint_dir = '/workspace/sd_opc/text_encoder_only/unlearncanvas'

for epoch in range(epochs):
    # Running (sample-weighted) averages, reset at the start of every epoch.
    retain_losses = AverageMeter()
    retain_accs = AverageMeter()

    # Progress bar for the epoch
    pbar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs}")

    for batch in pbar:
        # Only `model` is handed to the optimizer, so the encoder acts as a
        # fixed feature extractor. Running it under no_grad avoids building
        # its autograd graph and stops gradients from piling up in encoder
        # parameters that optimizer.zero_grad() never clears.
        with torch.no_grad():
            # With return_dict=False the encoder returns a tuple; element 1 is
            # the pooled output per the transformers CLIPTextModel docs.
            output = pretrained_text_encoder(batch['input_ids'].to('cuda:0'), return_dict=False)[1]
        logit = model(output)
        target = batch['dist'].to('cuda:0')

        retain_loss = criteria(logit, target)
        # The objective currently has a single term.
        loss = retain_loss
        # detach() replaces the legacy `.data`: the metric must not be tracked
        # by autograd.
        retain_acc = accuracy(logit.detach(), target)[0]

        retain_losses.update(retain_loss.item(), logit.size(0))
        retain_accs.update(retain_acc.item(), logit.size(0))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Update progress bar with current metrics
        pbar.set_postfix({
            'Loss': f'{retain_losses.avg:.4f}',
            'Acc': f'{retain_accs.avg:.4f}'
        })

    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{epochs} Summary:")
    print(f"Average Loss: {retain_losses.avg:.4f}")
    print(f"Average Accuracy: {retain_accs.avg:.4f}\n")

    # Save every `checkpoint_interval` epochs and always after the last epoch.
    # Folding the final save into this condition avoids writing the same file
    # twice when `epochs` is a multiple of the interval.
    # NOTE: the file name carries the 0-based `epoch`, while the stored
    # 'epoch' field is 1-based; names are kept as-is so existing consumers of
    # these checkpoints keep working.
    is_last_epoch = epoch + 1 == epochs
    if (epoch + 1) % checkpoint_interval == 0 or is_last_epoch:
        torch.save({'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'acc': retain_accs.avg,
                    'optimizer_state_dict': optimizer.state_dict()},
                   f'{checkpoint_dir}/{unlearn_type}_{epoch}.pt')