from PIL import Image
from matplotlib import pyplot as plt
import argparse
import torch
import os
import numpy as np
from PIL import Image
from tqdm.auto import tqdm
from utils.utils import *
import random
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms import ToPILImage

def _split_concepts(concepts):
    """Split a comma-separated concept string into a list of stripped names."""
    return [name.strip() for name in concepts.split(',')]


def _pair_concepts(target_concepts, anchor_concepts):
    """Pair every target concept with an anchor, broadcasting a single anchor.

    Raises:
        ValueError: if the two lists differ in length and there is more than
            one anchor, so no unambiguous pairing exists.
    """
    if len(anchor_concepts) != len(target_concepts):
        if len(anchor_concepts) != 1:
            print(anchor_concepts, target_concepts)
            raise ValueError("Erase from concepts length need to match erase concepts length")
        anchor_concepts = anchor_concepts * len(target_concepts)
    return [[target, anchor] for target, anchor in zip(target_concepts, anchor_concepts)]


def _select_steps(steps, scores, lam):
    """Return the steps whose min-max normalised score is at least ``lam``."""
    scores = np.asarray(scores, dtype=float)
    lowest = scores.min()
    span = scores.max() - lowest
    if span == 0:
        # Every step scored the same: 0/0 would yield NaNs (and a warning),
        # so treat all steps as sitting at the bottom of the normalised range.
        normalized = np.zeros_like(scores)
    else:
        normalized = (scores - lowest) / span
    return [step for step, value in zip(steps, normalized) if value >= lam]


def step_selection(target_concept, anchor_concept, niters, device, lam):
    """Select the diffusion steps at which the target prompt fits best.

    For each step in ``1 .. 49`` (``nsteps`` is fixed at 50) the function
    draws ``niters`` random images from ``./data/<first target concept>``
    (spaces removed, lower-cased), encodes and noises them with the diffuser,
    and predicts the noise once conditioned on the target prompt and once on
    the anchor prompt. Each prediction is scored as ``exp(-MSE)`` against the
    injected noise. The per-step value is
    ``mean(target) / (mean(target) + mean(anchor))``. These values are
    min-max normalised over all steps, and the steps whose normalised value
    is ``>= lam`` are returned.

    Args:
        target_concept (str): comma-separated target concept names.
        anchor_concept (str | None): comma-separated anchor concept names.
            ``None`` means "same as target"; a single anchor is broadcast to
            all targets. When a target equals its anchor, the empty-prompt
            embedding is used on the anchor side instead.
        niters (int): number of random images evaluated per step; must be >= 1.
        device: device passed to ``.to(...)`` for the model and tensors.
        lam (float): threshold applied to the normalised per-step value.

    Returns:
        list: the selected step numbers, in increasing order.

    Raises:
        ValueError: if ``niters < 1`` or the concept lists cannot be paired.
        FileNotFoundError: if the data folder is missing or holds no
            ``.png`` / ``.jpg`` / ``.jpeg`` files.
    """
    nsteps = 50

    if niters < 1:
        # With no samples the per-step means would be NaN and nothing useful
        # could be selected.
        raise ValueError(f"niters must be >= 1, got {niters}")

    if anchor_concept is None:
        anchor_concept = target_concept

    target_concept = _split_concepts(target_concept)
    anchor_concept = _split_concepts(anchor_concept)
    class_concept = _pair_concepts(target_concept, anchor_concept)
    print(target_concept)

    # Validate the image folder before the (expensive) model load.
    data_folder = "./data"
    data_path = f"{data_folder}/{target_concept[0].replace(' ','').lower()}"
    # os.listdir order is arbitrary; sorting keeps seeded runs reproducible.
    image_paths = sorted(
        os.path.join(data_path, f)
        for f in os.listdir(data_path)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    if not image_paths:
        raise FileNotFoundError(f"No .png/.jpg/.jpeg images found in {data_path}")

    diffuser = StableDiffuser(scheduler='DDIM').to(device)
    # NOTE(review): the model is only used under torch.no_grad() below;
    # train() is kept as in the original, confirm eval() is not intended.
    diffuser.train()
    torch.cuda.empty_cache()

    size = 512
    image_transforms = transforms.Compose(
        [
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.RandomCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    steps = list(range(1, nsteps))
    scores = []

    for step in tqdm(steps):
        with torch.no_grad():
            index = np.random.choice(len(target_concept), 1, replace=False)[0]
            target_name, anchor_name = class_concept[index]

            neutral_embeddings = diffuser.get_text_embeddings([''], n_imgs=1)
            target_embeddings = diffuser.get_text_embeddings([f"{target_name}"], n_imgs=1)
            anchor_embeddings = diffuser.get_text_embeddings([f"{anchor_name}"], n_imgs=1)
            if target_name == anchor_name:
                # No distinct anchor: compare against the empty prompt instead.
                anchor_embeddings = neutral_embeddings

            diffuser.set_scheduler_timesteps(nsteps)
            iteration = torch.full((1,), step)

            target_scores = []
            anchor_scores = []
            for _ in range(niters):
                noise = torch.randn(1, 4, 64, 64).to(device)
                image_path = random.choice(image_paths)

                # Close the file handle promptly; convert() returns a
                # fully-loaded copy that outlives the context manager.
                with Image.open(image_path) as raw_image:
                    image = raw_image.convert('RGB')
                image = image_transforms(image)
                image = image.unsqueeze(0).float().to(device)

                latents = diffuser.encode(image)
                latents = diffuser.add_noise(latents, noise, step)

                target_pred = diffuser.predict_noise(iteration, latents, target_embeddings, guidance_scale=1)
                anchor_pred = diffuser.predict_noise(iteration, latents, anchor_embeddings, guidance_scale=1)

                target_loss = F.mse_loss(target_pred.float(), noise.float(), reduction="mean")
                anchor_loss = F.mse_loss(anchor_pred.float(), noise.float(), reduction="mean")

                # exp(-loss) maps a lower error to a higher score in (0, 1].
                target_scores.append(np.exp(-target_loss.item()))
                anchor_scores.append(np.exp(-anchor_loss.item()))

            target_mean = np.mean(target_scores)
            anchor_mean = np.mean(anchor_scores)
            scores.append(target_mean / (target_mean + anchor_mean))

    return _select_steps(steps, scores, lam)
