import torch
import subprocess
import os
import torch.nn.functional as F
from tqdm import tqdm
from stage_clip_model import clip_model, clip_processor, compute_CLIP_score, preprocess_for_clip, get_clip_features_folder, clip_processor
import torchvision.transforms.functional as TF
from torchvision import transforms
from PIL import Image
import random as r
import pandas as pd
from load_dataset import GroundTruthData
import copy
from fid_scores import FIDScoreCalculator
import random
# from diffusers.utils import create_sdxl_condition

#Remove existing attention hooks
def remove_all_forward_hooks(module: torch.nn.Module) -> None:
    """Recursively clear forward hooks registered on every descendant of *module*.

    Only children (and their descendants) are visited; hooks registered directly
    on *module* itself are left in place. Each cleared submodule is reported by
    its local child name.
    """
    for child_name, submodule in module.named_children():
        registered = getattr(submodule, "_forward_hooks", None)
        if registered:
            removed_count = len(registered)
            registered.clear()
            print(f"Removed {removed_count} forward hook(s) from: {child_name}")

        # Descend into the submodule's own children.
        remove_all_forward_hooks(submodule)


def get_attention_hook(name, attention_activations):
    """Build a forward hook that stores the second element of a module's output.

    The returned hook writes ``output[1]`` (detached and moved to CPU) into
    ``attention_activations[name]``, overwriting any previous entry. It assumes
    the hooked module returns a tuple/sequence whose index 1 holds the attention
    probabilities — TODO confirm for each hooked module.
    """
    def _capture_attention(module, input, output):
        attention_activations[name] = output[1].detach().cpu()

    return _capture_attention

def split_attention_heads(attention_activations, num_of_attention_heads):
    """Split per-layer activations into a dict of per-head slices.

    Args:
        attention_activations: mapping ``layer_name -> tensor``. A 2-D tensor of
            shape ``(tokens, channels)`` is reshaped to
            ``(tokens, heads, channels // heads)`` using
            ``num_of_attention_heads[layer_name]``. Any other tensor is used
            as-is, and dim 1 is treated as the head axis.
        num_of_attention_heads: mapping ``layer_name -> number of heads``. It is
            only consulted for 2-D activations.

    Returns:
        ``{layer_name: {"head.<i>": activation[:, i, :]}}``.
    """
    per_layer = {}
    for layer, tensor in attention_activations.items():
        if tensor.ndim == 2:
            tokens, width = tensor.shape
            heads = num_of_attention_heads[layer]
            tensor = tensor.view(tokens, heads, width // heads)

        per_layer[layer] = {
            f"head.{idx}": tensor[:, idx, :] for idx in range(tensor.shape[1])
        }
    return per_layer



    

def generate_sample_images_and_compute_FID_CLIPScore(args, pipe, clip_concepts, unlearn_concepts, epoch, step, folder_name, fid_scores, clip_scores, compute_clip = False, compute_FID = False,  num_of_images=10, fid_score_calculator=None):
    """Generate sample images for each concept and optionally record CLIP / FID scores.

    Images are written to ``{folder_name}/epoch_{epoch}/step_{step}/{concept}/{i}.png``.

    Args:
        args: run configuration; ``args.model_name`` selects the SDXL call form.
        pipe: text-to-image pipeline called as ``pipe(prompt=..., num_inference_steps=50)``.
        clip_concepts: prompts to generate images for.
        unlearn_concepts: concepts whose FID entries are recorded (plus "other_concepts").
        epoch, step: used to build the output path and as the score dict keys.
        folder_name: root output directory.
        fid_scores, clip_scores: nested dicts ``{epoch: {concept: [scores...]}}``,
            updated in place and also returned.
        compute_clip: when True, the per-concept average CLIP score is printed and
            appended to ``clip_scores``. When False, ``clip_scores`` is not touched.
        compute_FID: when True, FID is computed via ``fid_score_calculator``.
        num_of_images: images generated per concept.
        fid_score_calculator: object exposing
            ``compute_fid_for_targetted_and_benign_concepts(path)``. It is required
            when ``compute_FID`` is True.

    Returns:
        ``(fid_scores, clip_scores)``.

    Raises:
        ValueError: if ``compute_FID`` is True but no ``fid_score_calculator`` is given.
    """
    # Fail before spending minutes generating images.
    if compute_FID and fid_score_calculator is None:
        raise ValueError("compute_FID=True requires a fid_score_calculator")

    generated_paths = f"{folder_name}/epoch_{epoch}/step_{step}"
    os.makedirs(generated_paths, exist_ok=True)
    is_sdxl = args.model_name == "stabilityai/stable-diffusion-xl-base-1.0"

    for concept in tqdm(clip_concepts, desc=f"Generating images for concepts at epoch {epoch}, step {step}"):
        concept_dir = f"{generated_paths}/{concept}"
        # Create the directory once per concept rather than once per image.
        os.makedirs(concept_dir, exist_ok=True)
        clip_total = 0.0
        for i in tqdm(range(num_of_images), desc=f"Generating {num_of_images} images for concept: {concept}"):
            # Generate with 50 inference steps using the current model.
            if is_sdxl:
                result = pipe(prompt=concept, num_inference_steps=50, added_cond_kwargs={})
            else:
                result = pipe(prompt=concept, num_inference_steps=50)
            image = result.images[0]
            image.save(f"{concept_dir}/{i}.png")
            if compute_clip:
                clip_total += compute_CLIP_score(image, concept)

        # Only report/record a CLIP average when one was actually computed,
        # otherwise a meaningless 0.0 would be logged as a real score.
        if compute_clip and num_of_images > 0:
            avg_clip_score = clip_total / num_of_images
            print(f"Average CLIP score for concept '{concept}' at epoch {epoch}, step {step}: {avg_clip_score}")
            clip_scores.setdefault(epoch, {}).setdefault(concept, []).append(avg_clip_score)

    # Compute FID scores over everything generated for this epoch/step.
    if compute_FID:
        fid_scores_computed = fid_score_calculator.compute_fid_for_targetted_and_benign_concepts(generated_paths)
        for concept in [*unlearn_concepts, "other_concepts"]:
            fid_scores.setdefault(epoch, {}).setdefault(concept, []).append(fid_scores_computed[concept])

    return fid_scores, clip_scores


def compute_per_head_importance(weight_grad, num_heads):
    """Score each attention head by the mean absolute value of its gradient rows.

    ``weight_grad`` is indexed as a 2-D tensor. Its rows are split into
    ``num_heads`` contiguous chunks of ``weight_grad.shape[0] // num_heads`` rows.
    Any leftover rows (when the row count is not divisible) are ignored.

    Returns:
        1-D tensor of length ``num_heads`` with one score per head.
    """
    rows_per_head = weight_grad.shape[0] // num_heads
    return torch.stack([
        weight_grad[h * rows_per_head:(h + 1) * rows_per_head, :].abs().mean()
        for h in range(num_heads)
    ])

def detect_outliers_zscore(tensor, threshold=2.0):
    """Compute z-scores for *tensor* and flag elements beyond *threshold*.

    Uses the tensor-wide mean and (unbiased) standard deviation.

    Args:
        tensor: floating-point tensor of any shape.
        threshold: absolute z-score above which an element counts as an outlier.

    Returns:
        ``(z_scores, mask)``. Both have the shape of *tensor*, and ``mask`` is 1.0
        where ``|z| > threshold`` and 0.0 elsewhere. When the standard deviation
        is zero or undefined (constant input, or fewer than two elements), every
        z-score is 0 instead of NaN, so nothing is flagged.
    """
    mean = tensor.mean()
    # The unbiased std of fewer than two elements is NaN (and warns); treat it as 0.
    std = tensor.std() if tensor.numel() > 1 else tensor.new_tensor(0.0)
    if bool(torch.isfinite(std)) and std > 0:
        z_scores = (tensor - mean) / std
    else:
        # Without spread, no element deviates from the mean. Avoid 0/0 -> NaN.
        z_scores = torch.zeros_like(tensor)
    mask = (torch.abs(z_scores) > threshold).float()
    return z_scores, mask

# --- Forward pass: compute gradients with the current loss function, then use them to build the final loss that is back-propagated to the selected set of concept neurons ---


def save_image(epoch, decoded, output_dir):
    """Save the first image of *decoded* as the next numbered PNG for *epoch*.

    Files go to ``{output_dir}/epoch_{epoch}/{n}.png``, where ``n`` is the number
    of ``.png`` files already in that folder. Pixel values are clamped to
    [-1, 1], mapped to [0, 1] and then scaled to uint8.
    """
    target_dir = f"{output_dir}/epoch_{epoch}"
    os.makedirs(target_dir, exist_ok=True)

    # First item of the batch, moved to CPU and mapped from [-1, 1] to [0, 255] uint8.
    first = decoded[0].detach().cpu().clamp(-1, 1)
    as_uint8 = ((first + 1) / 2).mul(255).byte()
    pil_image = TF.to_pil_image(as_uint8)

    already_saved = sum(
        1 for fname in os.listdir(target_dir) if fname.lower().endswith('.png')
    )
    pil_image.save(os.path.join(target_dir, f"{already_saved}.png"))
    return



def _mean_of_parameters(module):
    """Return the mean over every parameter element of *module*, cast to the first parameter's dtype.

    Per-tensor sums are accumulated instead of concatenating all parameters. This
    avoids materialising a full flattened copy of the model's weights.
    """
    params = [p.detach() for p in module.parameters()]
    total_elements = sum(p.numel() for p in params)
    total = sum(p.sum(dtype=torch.float64) for p in params)
    return (total / total_elements).to(params[0].dtype)


def _shift_pipeline_parameters(pipe, delta):
    """Add *delta* in place to every parameter of the pipeline's UNet, text encoder and VAE."""
    with torch.no_grad():
        for module in (pipe.unet, pipe.text_encoder, pipe.vae):
            for param in module.parameters():
                param.add_(delta)


def compute_noise_at_timestep(pipe, error, timestep, dataset):
    """Predict noise for a random dataset sample with all pipeline weights shifted by a scalar.

    The shift is ``e = mean(UNet params) * 1e-4 * error``. It is added to every
    parameter of ``pipe.unet``, ``pipe.text_encoder`` and ``pipe.vae`` for the
    duration of the call (including VAE encoding of the sample) and then
    subtracted again, even if the forward pass raises.

    Args:
        pipe: diffusers-style pipeline (scheduler, tokenizer, text_encoder, unet, vae).
        error: multiplier for the perturbation. 0 means no perturbation.
        timestep: index into ``pipe.scheduler.timesteps``.
        dataset: dataset passed to ``get_random_image_latent_and_caption``.

    Returns:
        ``(noise_pred, e)``: the UNet noise prediction and the applied shift.
    """
    t = pipe.scheduler.timesteps[timestep]
    e = _mean_of_parameters(pipe.unet) * (0.0001) * error

    perturbed = bool(e != 0.0)
    if perturbed:
        _shift_pipeline_parameters(pipe, e)
    try:
        gt_image_latent, gt_caption = get_random_image_latent_and_caption(pipe, dataset=dataset)

        true_noise = torch.randn_like(gt_image_latent).to("cuda")
        noisy_latent = pipe.scheduler.add_noise(gt_image_latent, true_noise, t)

        caption_input = pipe.tokenizer(gt_caption, return_tensors="pt").input_ids.to("cuda")
        caption_embeddings = pipe.text_encoder(caption_input)[0]

        latent_model_input = pipe.scheduler.scale_model_input(noisy_latent, t)
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=caption_embeddings).sample
    finally:
        # Always undo the perturbation so a failed forward pass cannot leave the weights corrupted.
        if perturbed:
            _shift_pipeline_parameters(pipe, -e)

    return noise_pred, e

def compute_grad_loss_with_noise(args, pipe):
    """Finite-difference loss of the noise prediction w.r.t. a uniform weight shift.

    Evaluates ``compute_noise_at_timestep`` on a deep copy of *pipe* with the
    weights shifted by +e (and by -e for the second-order case), and on *pipe*
    itself without a shift. The shifted passes run under ``torch.no_grad``, so
    only the unshifted pass carries gradients into *pipe*.

    Every evaluation uses the same timestep and, because the RNG state is
    rewound between them, the same dataset sample, VAE sample and noise. This
    way the difference reflects the weight shift rather than resampling.

    Args:
        args: needs ``unlearn_concepts_dataset_path`` and ``derivative``
            (``"first_order"`` or ``"second_order"``).
        pipe: diffusers-style pipeline.

    Returns:
        Scalar tensor: the mean of ``(f(+e) - f(0)) / e`` (first order) or
        ``(f(+e) - 2 f(0) + f(-e)) / |e1 * e2|`` (second order).

    Raises:
        ValueError: if ``args.derivative`` is not a supported mode.
    """
    if args.derivative not in ("first_order", "second_order"):
        raise ValueError(
            f"Unsupported derivative mode {args.derivative!r}; expected 'first_order' or 'second_order'"
        )

    dataset = GroundTruthData(path = args.unlearn_concepts_dataset_path)

    # Compute the loss at a single random timestep shared by all evaluations.
    random_timestep = torch.randint(0, len(pipe.scheduler.timesteps), (1,)).item()
    pipe_clone = copy.deepcopy(pipe).to("cuda")

    # Snapshot every RNG the evaluation consumes: Python `random` for sample
    # selection, and the torch CPU/CUDA generators for VAE sampling and noise.
    py_rng_state = r.getstate()
    cpu_rng_state = torch.get_rng_state()
    cuda_rng_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None

    def _rewind_rng():
        r.setstate(py_rng_state)
        torch.set_rng_state(cpu_rng_state)
        if cuda_rng_states is not None:
            torch.cuda.set_rng_state_all(cuda_rng_states)

    error = 1
    with torch.no_grad():
        loss_1, e1 = compute_noise_at_timestep(pipe_clone, error, random_timestep, dataset)
    _rewind_rng()
    loss_2, _ = compute_noise_at_timestep(pipe, 0, random_timestep, dataset)

    if args.derivative == "second_order":
        _rewind_rng()
        with torch.no_grad():
            loss_3, e2 = compute_noise_at_timestep(pipe_clone, -error, random_timestep, dataset)
        final_loss = (loss_1 - 2*loss_2 + loss_3)/(torch.abs(e1*e2))
    else:
        final_loss = (loss_1 - loss_2)/e1

    return final_loss.mean()
     
def get_image_latent(pipe, image_path):
    """Encode the image at *image_path* into a scaled VAE latent.

    The image is converted to RGB, resized to 512x512, normalised to [-1, 1],
    moved to CUDA and encoded with ``pipe.vae``. A latent is sampled from the
    resulting distribution and multiplied by the VAE's scaling factor.

    Args:
        pipe: pipeline whose ``vae`` exposes ``encode(...).latent_dist.sample()``.
        image_path: path to an image file readable by PIL.

    Returns:
        The scaled latent tensor (batch of 1).
    """
    # Close the underlying file handle. convert() returns an independent, loaded copy.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    preprocess = transforms.Compose([
        transforms.Resize((512,512)),
        transforms.ToTensor(),
        transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
    ])

    image_tensor = preprocess(image).unsqueeze(0).to("cuda")

    # The latent scale depends on the VAE: SD 1.x/2.x use 0.18215, SDXL uses 0.13025.
    # Read it from the VAE config and fall back to the SD 1.x value if it is absent.
    vae_config = getattr(pipe.vae, "config", None)
    scaling_factor = getattr(vae_config, "scaling_factor", None) or 0.18215

    with torch.no_grad():
        latent_dist = pipe.vae.encode(image_tensor)
        latents = latent_dist.latent_dist.sample() * scaling_factor

    return latents
    
def get_random_image_latent_and_caption(pipe, dataset):
    """Pick a uniformly random dataset entry and return its image latent and caption.

    Each entry is read via ``dataset[idx]`` and must provide ``"image_path"`` and
    ``"image_caption"`` keys. The image is encoded with ``get_image_latent``.

    Returns:
        ``(image_latent, caption)``.
    """
    sample = dataset[r.randrange(len(dataset))]
    latent = get_image_latent(pipe, sample["image_path"])
    return latent, sample["image_caption"]
    
def compute_preservation_loss(args, pipe, num_samples=5):
    """Average denoising MSE over random samples from the preserved-concepts dataset.

    For each of *num_samples* draws, the function:
      1. picks a random scheduler timestep and a random image/caption,
      2. noises the image latent with fresh Gaussian noise,
      3. has the UNet predict that noise from the caption-conditioned input.

    The per-element squared errors are averaged over samples and then reduced
    to a scalar with ``.mean()``.
    """
    dataset = GroundTruthData(path = args.preserved_concepts_dataset_path)
    squared_error_sum = 0.0
    scheduler = pipe.scheduler

    for _ in range(num_samples):
        step_index = torch.randint(0, len(scheduler.timesteps), (1,)).item()
        timestep = scheduler.timesteps[step_index]

        # Reference latent and caption drawn from the ground-truth images.
        latent, caption = get_random_image_latent_and_caption(pipe, dataset)
        noise = torch.randn_like(latent).to("cuda")

        token_ids = pipe.tokenizer(caption, return_tensors="pt").input_ids.to("cuda")
        text_embeddings = pipe.text_encoder(token_ids)[0]
        noisy_latent = scheduler.add_noise(latent, noise, timestep)

        # The UNet should recover exactly the noise that was added.
        model_input = scheduler.scale_model_input(noisy_latent, timestep)
        predicted = pipe.unet(model_input, timestep, encoder_hidden_states=text_embeddings).sample

        squared_error_sum += (predicted - noise) ** 2

    return (squared_error_sum / num_samples).mean()
        
    
def compute_noise_at_timestep_wrt_change_in_concept(pipe, error, dataset, timestep):
    """Predict noise for a random sample after shifting its caption embedding by a scalar.

    The shift is ``error * mean(caption_embeddings) * 1e-4``. It is added to every
    element of the text-encoder output before the UNet call. The shift is computed
    from the embeddings themselves, so it stays part of the autograd graph.

    Args:
        pipe: diffusers-style pipeline (scheduler, tokenizer, text_encoder, unet).
        error: multiplier for the embedding shift.
        dataset: dataset passed to ``get_random_image_latent_and_caption``.
        timestep: index into ``pipe.scheduler.timesteps``.

    Returns:
        ``(noise_pred, shift)``.
    """
    scheduler = pipe.scheduler
    t = scheduler.timesteps[timestep]

    latent, caption = get_random_image_latent_and_caption(pipe, dataset=dataset)
    noise = torch.randn_like(latent).to("cuda")
    noisy_latent = scheduler.add_noise(latent, noise, t)

    token_ids = pipe.tokenizer(caption, return_tensors="pt").input_ids.to("cuda")
    embeddings = pipe.text_encoder(token_ids)[0]
    shift = error * torch.mean(embeddings) * (0.0001)
    shifted_embeddings = embeddings + shift  # perturb the concept (caption) embedding

    model_input = scheduler.scale_model_input(noisy_latent, t)
    prediction = pipe.unet(model_input, t, encoder_hidden_states=shifted_embeddings).sample

    return prediction, shift




