import argparse
import os
import torch
import torchvision.models as models
import torchvision.datasets as datasets
from torchvision import transforms
from PIL import Image
from diffusers import StableDiffusionPipeline
from transformers import CLIPModel, CLIPProcessor

class WeightNet(torch.nn.Module):
    """Small MLP head that maps a 2048-dimensional feature vector to a score in (0, 1).

    The layers are stored in ``self.net`` in the order Linear(2048, 128),
    ReLU, Linear(128, 1), Sigmoid, so the state-dict keys are ``net.0.*``
    and ``net.2.*``.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn
        layers = [
            nn.Linear(2048, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Score ``x`` and remove every size-1 dimension from the result."""
        scores = self.net(x)
        return scores.squeeze()

def load_models_and_data(device, model_ckpt_path, stable_diffusion_model_path, weight_net_ckpt_path):
    """Load the ResNet-50 classifier, the Stable Diffusion pipeline and the WeightNet.

    Args:
        device: Device (string or ``torch.device``) every model is moved to.
        model_ckpt_path: Checkpoint whose ``'state_dict'`` entry holds the
            weights of a ResNet-50 with a 102-output final layer.
        stable_diffusion_model_path: Path or identifier handed to
            ``StableDiffusionPipeline.from_pretrained``.
        weight_net_ckpt_path: Checkpoint containing the ``WeightNet`` state
            dict directly (no wrapping key).

    Returns:
        Tuple ``(resnet50, pipe, weight_net)``. ``resnet50`` and
        ``weight_net`` are on ``device`` and in eval mode.

    Raises:
        KeyError: If the classifier checkpoint has no ``'state_dict'`` entry.
    """
    # Randomly initialised backbone; calling resnet50() without arguments
    # avoids the deprecated ``pretrained`` keyword while keeping the same
    # "no pretrained weights" default.
    resnet50 = models.resnet50()
    resnet50.fc = torch.nn.Linear(resnet50.fc.in_features, 102)

    # map_location lets checkpoints saved on a different device (e.g. a GPU)
    # be restored on the device selected for this run.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_ckpt_path, map_location=device)

    # Drop a leading 'module.' from parameter names (presumably left by a
    # DataParallel-style wrapper at training time) so they match a bare model.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint['state_dict'].items()
    }

    resnet50.load_state_dict(state_dict)
    resnet50 = resnet50.to(device)
    resnet50.eval()

    # NOTE(review): float16 weights are requested regardless of ``device``;
    # confirm half precision is acceptable if CPU runs are intended.
    pipe = StableDiffusionPipeline.from_pretrained(
        stable_diffusion_model_path, torch_dtype=torch.float16
    ).to(device)

    weight_net = WeightNet().to(device)
    weight_net.load_state_dict(torch.load(weight_net_ckpt_path, map_location=device))
    weight_net.eval()

    return resnet50, pipe, weight_net

class CLIPRegularizer:
    """CLIP-based dissimilarity penalty against the images of one class folder.

    On construction, CLIP image features are extracted for every readable
    ``.png``/``.jpg``/``.jpeg`` file in ``target_class_dir`` and kept
    (un-normalised) in ``self.target_features``.
    """

    def __init__(self, clip_model_path, target_class_dir, device):
        """Load the CLIP model/processor and precompute the target features.

        Args:
            clip_model_path: Path or identifier passed to ``from_pretrained``.
            target_class_dir: Directory with the reference images.
            device: Device the model and all inputs are moved to.

        Raises:
            ValueError: If no image in ``target_class_dir`` could be processed.
        """
        self.device = device
        self.model = CLIPModel.from_pretrained(clip_model_path).to(device)
        self.processor = CLIPProcessor.from_pretrained(clip_model_path)
        self.target_features = self._load_target_features(target_class_dir)

    def _load_target_features(self, target_class_dir):
        """Return the stacked CLIP features of all usable images in a directory.

        Files whose name does not end in ``.png``, ``.jpg`` or ``.jpeg``
        (case-insensitive) are ignored. Images that fail to load or encode
        are reported on stdout and skipped.

        Raises:
            ValueError: If no image produced a feature vector.
        """
        features = []

        for img_name in os.listdir(target_class_dir):
            if not img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue

            img_path = os.path.join(target_class_dir, img_name)
            try:
                # The context manager closes the underlying file as soon as
                # the RGB copy exists, instead of leaving one open handle
                # per image until garbage collection.
                with Image.open(img_path) as raw_image:
                    image = raw_image.convert("RGB")
                with torch.no_grad():
                    inputs = self.processor(images=image, return_tensors="pt").to(self.device)
                    feature = self.model.get_image_features(**inputs)
            except Exception as e:
                # Deliberate best effort: one bad file must not abort the class.
                print(f"Error processing {img_name}: {str(e)}")
            else:
                features.append(feature)

        if not features:
            raise ValueError(f"No valid images found in {target_class_dir}")

        return torch.cat(features, dim=0)

    def compute_loss(self, generated_image):
        """Return ``1 - mean cosine similarity`` between an image and the targets.

        Args:
            generated_image: A PIL image, or a tensor that
                ``transforms.ToPILImage`` accepts once squeezed and moved to
                the CPU.

        Returns:
            A scalar tensor. The image features are computed under
            ``torch.no_grad()``, so the result carries no gradient with
            respect to ``generated_image``.
        """
        if not isinstance(generated_image, Image.Image):
            generated_image = transforms.ToPILImage()(generated_image.squeeze().cpu())
        inputs = self.processor(images=generated_image, return_tensors="pt").to(self.device)

        with torch.no_grad():
            gen_feature = self.model.get_image_features(**inputs)

        # Unit-normalise both sides so the matrix product yields cosine similarities.
        gen_feature = gen_feature / gen_feature.norm(dim=1, keepdim=True)
        target_features = self.target_features / self.target_features.norm(dim=1, keepdim=True)

        avg_similarity = (gen_feature @ target_features.T).mean()
        return 1 - avg_similarity

def optimize_text_embedding(resnet50, pipe, weight_net, clip_reg, device,
                            prompt_embeds, targets, num_steps=400, lambda_clip=0.8,
                            class_save_dir='', lr=0.0001):
    """Run the embedding optimisation loop for one prompt embedding.

    Each step generates an image from ``prompt_embeds``, scores it with
    ``weight_net`` on ResNet-50 features, adds the CLIP penalty and applies
    an Adam step with loss ``-utility + lambda_clip * clip_loss``. The
    embedding is saved into ``class_save_dir`` every 100 steps.

    Args:
        resnet50: Classifier; all of its children except the last are used
            as the feature extractor.
        pipe: Pipeline called as ``pipe(prompt_embeds=..., guidance_scale=2)``
            whose result exposes ``.images``.
        weight_net: Module mapping the extracted features to a utility score.
        clip_reg: Object providing ``compute_loss(image)``.
        device: Device the preprocessed image tensor is moved to.
        prompt_embeds: Initial embedding; it is cloned and detached, so the
            caller's tensor is never modified.
        targets: Not used by this function; kept for call compatibility.
        num_steps: Number of optimisation steps.
        lambda_clip: Weight of the CLIP penalty.
        class_save_dir: Directory for the periodic snapshots. An empty string
            means the current working directory.
        lr: Adam learning rate.

    Returns:
        The embedding tensor (a leaf with ``requires_grad=True``).
    """
    prompt_embeds = prompt_embeds.clone().detach().requires_grad_(True)
    optimizer = torch.optim.Adam([prompt_embeds], lr=lr)

    if num_steps <= 0:
        # Nothing to optimise; skip building the preprocessing stack.
        return prompt_embeds

    # Loop-invariant pieces are built once instead of on every step.
    resnet_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    feature_extractor = torch.nn.Sequential(*list(resnet50.children())[:-1])
    reported_missing_grad = False

    for step in range(num_steps):
        with torch.no_grad():
            image = pipe(prompt_embeds=prompt_embeds, guidance_scale=2).images[0]

        image_tensor = resnet_transform(image).unsqueeze(0).to(device)
        features = feature_extractor(image_tensor).squeeze()
        utility = weight_net(features)

        clip_loss = clip_reg.compute_loss(image)

        total_loss = -utility + lambda_clip * clip_loss

        optimizer.zero_grad()
        total_loss.backward()

        # NOTE(review): generation above runs under torch.no_grad(), so the
        # loss is not connected to prompt_embeds and Adam skips a parameter
        # whose grad is None. Report it once instead of silently doing nothing.
        if prompt_embeds.grad is None and not reported_missing_grad:
            print("Warning: prompt_embeds received no gradient; "
                  "optimizer.step() will leave the embedding unchanged")
            reported_missing_grad = True

        optimizer.step()

        if (step + 1) % 100 == 0:
            # os.path.join keeps an empty class_save_dir relative to the
            # working directory rather than resolving to the filesystem root.
            # The file is written with torch.save, so despite the extension
            # it uses torch's own serialisation format.
            save_path = os.path.join(class_save_dir, f'learned_embeds_step_{step+1}.safetensors')
            torch.save(prompt_embeds, save_path)

    return prompt_embeds

def data_expansion(resnet50, pipe, weight_net, device, total_split, split,
                   class_to_idx, num_steps, lambda_clip, dataset_dir,
                   save_dir, clip_model_path, embedding_dir, lr):
    """Optimise and save the text embedding of every class in one split.

    The class sub-directories of ``dataset_dir`` are sorted and divided into
    ``total_split`` contiguous chunks; the last chunk also receives the
    remainder. For each class of chunk ``split`` the textual-inversion
    embedding is loaded, the prompt ``"a photo of a <new_{class}>"`` is
    encoded, optimised with ``optimize_text_embedding`` and written to
    ``save_dir/flowers-16-400-{num_steps}/{class}/``.

    Args:
        resnet50, pipe, weight_net: Models forwarded to
            ``optimize_text_embedding``; ``pipe`` also supplies the tokenizer
            and text encoder.
        device: Device for the tokenised prompt and the models.
        total_split: Number of chunks the class list is divided into.
        split: Zero-based index of the chunk to process.
        class_to_idx: Mapping from class name to integer label.
        num_steps: Optimisation steps per class.
        lambda_clip: Weight of the CLIP penalty.
        dataset_dir: Directory with one sub-directory per class.
        save_dir: Root output directory (created if missing).
        clip_model_path: CLIP model used by ``CLIPRegularizer``.
        embedding_dir: Directory holding
            ``{class}/learned_embeds.safetensors`` for each class.
        lr: Learning rate.

    Raises:
        ValueError: If ``total_split`` is not positive or ``split`` is outside
            ``[0, total_split)``.
    """
    if total_split <= 0 or not 0 <= split < total_split:
        raise ValueError(
            f"split must satisfy 0 <= split < total_split, got split={split}, "
            f"total_split={total_split}"
        )

    os.makedirs(save_dir, exist_ok=True)

    # Only directories are classes; stray files such as README or .DS_Store
    # would otherwise be treated as a class and fail further down.
    class_names = sorted(
        name for name in os.listdir(dataset_dir)
        if os.path.isdir(os.path.join(dataset_dir, name))
    )
    classes_per_split = len(class_names) // total_split
    start_idx = split * classes_per_split
    # The last split absorbs the classes left over by the integer division.
    end_idx = start_idx + classes_per_split if split < total_split - 1 else len(class_names)

    for class_name in class_names[start_idx:end_idx]:
        print(f"Processing class: {class_name}")

        target_class_dir = os.path.join(dataset_dir, class_name)
        # NOTE(review): a new CLIPRegularizer, and with it a new CLIP model
        # load, is created for every class.
        clip_reg = CLIPRegularizer(clip_model_path, target_class_dir, device)

        embedding_path = os.path.join(embedding_dir, class_name, 'learned_embeds.safetensors')
        pipe.load_textual_inversion(embedding_path)

        prompt = f"a photo of a <new_{class_name}>"
        text_inputs = pipe.tokenizer(
            prompt,
            padding="max_length",
            max_length=pipe.tokenizer.model_max_length,
            return_tensors="pt"
        ).to(device)
        # The optimiser works on a detached clone, so no autograd graph is
        # needed for the initial encoding.
        with torch.no_grad():
            prompt_embeds = pipe.text_encoder(text_inputs.input_ids)[0]

        class_save_dir = os.path.join(save_dir, f"flowers-16-400-{num_steps}", class_name)
        os.makedirs(class_save_dir, exist_ok=True)

        optimized_embeds = optimize_text_embedding(
            resnet50, pipe, weight_net, clip_reg, device,
            prompt_embeds, torch.tensor([class_to_idx[class_name]]),
            num_steps=num_steps, lambda_clip=lambda_clip,
            class_save_dir=class_save_dir,
            lr=lr
        )

        # Written with torch.save: torch's serialisation format despite the
        # ".safetensors" file name.
        torch.save(optimized_embeds, os.path.join(class_save_dir, "learned_embeds_final.safetensors"))

if __name__ == "__main__":
    # Command-line interface: (flag, value type, help text). Every option is
    # mandatory and registered in this order.
    cli_options = [
        ('--total_split', int, 'Total number of splits for the task'),
        ('--split', int, 'Current split to execute (starting from 0)'),
        ('--optimization_steps', int, 'Number of optimization steps for text embedding'),
        ('--lambda_clip', float, 'Weight for CLIP regularization loss'),
        ('--lr', float, 'Learning rate for optimizing text embeddings'),
        ('--model_ckpt', str, 'Path to model checkpoint'),
        ('--stable_diffusion_model', str, 'Path to Stable Diffusion model'),
        ('--weight_net_ckpt', str, 'Path to WeightNet checkpoint'),
        ('--clip_model', str, 'Path to CLIP model'),
        ('--dataset_dir', str, 'Dataset directory for both training and target images'),
        ('--save_dir', str, 'Directory to save optimized embeddings'),
        ('--embedding_dir', str, 'Directory containing initial textual inversion embeddings'),
    ]
    parser = argparse.ArgumentParser(description='Optimize text embeddings for each class')
    for flag, value_type, help_text in cli_options:
        parser.add_argument(flag, type=value_type, required=True, help=help_text)
    args = parser.parse_args()

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"

    resnet50, pipe, weight_net = load_models_and_data(
        device,
        model_ckpt_path=args.model_ckpt,
        stable_diffusion_model_path=args.stable_diffusion_model,
        weight_net_ckpt_path=args.weight_net_ckpt,
    )

    # The ImageFolder dataset is built to obtain its class-name -> index mapping.
    preprocessing = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    train_dataset = datasets.ImageFolder(root=args.dataset_dir, transform=preprocessing)

    data_expansion(
        resnet50,
        pipe,
        weight_net,
        device,
        total_split=args.total_split,
        split=args.split,
        class_to_idx=train_dataset.class_to_idx,
        num_steps=args.optimization_steps,
        lambda_clip=args.lambda_clip,
        dataset_dir=args.dataset_dir,
        save_dir=args.save_dir,
        clip_model_path=args.clip_model,
        embedding_dir=args.embedding_dir,
        lr=args.lr,
    )