import os
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, DDIMScheduler
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from tqdm import tqdm
import matplotlib.pyplot as plt
from finetuning_utils import compute_loss, generate_sample_images_and_compute_FID_CLIPScore, compute_per_head_importance, detect_outliers_zscore, remove_all_forward_hooks, compute_preservation_loss, compute_grad_loss_with_noise, compute_noise_at_timestep, compute_concept_weight_grad_loss
from visualization_utils import plot_all_attention_head_overlays, plot_variation_of_heads, visualise_FID, visualise_CLIP_scores, visualise_step_losses, generate_intermediate_visaulisations
from stage_clip_model import clip_model, clip_tokenizer
import argparse
from fid_scores import FIDScoreCalculator
import pdb
from load_dataset import GroundTruthData
from PIL import Image
import numpy as np
from torch import rand
from torchmetrics.image.fid import FrechetInceptionDistance
import math
import pandas as pd
import random
# from diffusers.utils.import_utils import is_xformers_available
# from packaging import version


def preprocess_for_clip(decoded, target_size=224, from_vae_range=True):
    """Turn a decoded image batch into normalised pixel values for the CLIP image encoder.

    Args:
        decoded: image tensor of shape (N, 3, H, W). In this file it is the raw
            ``pipe.vae.decode(...).sample`` output, which diffusers defines in [-1, 1].
        target_size: side length of the (square) bilinear resize.
        from_vae_range: when True (default), first map ``decoded`` from [-1, 1] to
            [0, 1]; the CLIP mean/std below are defined for [0, 1] pixels. Pass False
            if ``decoded`` is already in [0, 1].

    Returns:
        Tensor of shape (N, 3, target_size, target_size) normalised with the OpenAI
        CLIP mean/std. The whole transform is differentiable w.r.t. ``decoded``.
    """
    x = decoded
    if from_vae_range:
        # VAE decoder output lives in [-1, 1]; CLIP normalisation expects [0, 1].
        # No clamp, so gradients still flow through slightly out-of-range pixels.
        x = (x + 1.0) / 2.0

    # Resize with bilinear interpolation.
    x = F.interpolate(x, size=target_size, mode='bilinear', align_corners=False)

    # Normalise with the OpenAI CLIP mean/std values.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=decoded.device).view(1, 3, 1, 1)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=decoded.device).view(1, 3, 1, 1)

    return (x - mean) / std

# Disabling safety filters: a no-op replacement for the pipeline's safety checker.
def dummy_safety_checker(images, **kwargs):
    """Pass ``images`` through untouched and report none of them as flagged.

    Returns:
        A ``(images, flags)`` tuple where ``flags`` holds one ``False`` per image.
        Extra keyword arguments are accepted and ignored.
    """
    nsfw_flags = [False for _ in range(len(images))]
    return images, nsfw_flags



# Handles of all gradient-masking hooks registered via apply_head_mask_to_gradients.
# refresh_concept_neurons removes and clears them before registering a new set.
gradient_hooks = []


def apply_head_mask_to_gradients(weight_param, layer_mask, num_heads):
    """Register a backward hook that zeroes the gradient rows of masked-out heads.

    The first dimension of ``weight_param`` is split into ``num_heads`` equal
    row-blocks of ``grad.shape[0] // num_heads`` rows. Block ``i`` keeps its
    gradient when ``layer_mask[i]`` is truthy and is zeroed otherwise. Any
    remainder rows past ``num_heads * head_dim`` are left untouched.

    Args:
        weight_param: tensor that requires grad (e.g. a q/k/v projection weight).
        layer_mask: per-head mask (list, array or tensor) of at least ``num_heads`` entries.
        num_heads: number of attention heads the rows are split into.

    Returns:
        The hook handle, which is also appended to the module-level ``gradient_hooks``.
    """
    # Convert the mask once here rather than testing each element inside the hook:
    # per-element truthiness on a CUDA tensor forces a host sync on every backward.
    keep_head = torch.as_tensor(layer_mask, dtype=torch.bool)[:num_heads]

    def hook_fn(grad):
        head_dim = grad.shape[0] // num_heads
        n_rows = num_heads * head_dim
        keep_rows = keep_head.to(grad.device).repeat_interleave(head_dim)
        # Broadcast the row mask over every trailing dimension of the gradient.
        keep_rows = keep_rows.view(-1, *([1] * (grad.dim() - 1)))
        masked = grad.clone()  # hooks must not modify the incoming grad in place
        masked[:n_rows].masked_fill_(~keep_rows, 0.0)
        return masked

    handle = weight_param.register_hook(hook_fn)
    gradient_hooks.append(handle)
    return handle

def _find_cross_attention_layers(pipe):
    """Return the UNet's "attn2" modules exposing to_q/to_k/to_v/to_out, in module order."""
    return [
        module
        for name, module in pipe.unet.named_modules()
        if "attn2" in name
        and hasattr(module, "to_k")
        and hasattr(module, "to_v")
        and hasattr(module, "to_q")
        and hasattr(module, "to_out")
    ]


def _enable_qkv_grads_only(pipe):
    """Freeze all UNet parameters, re-enable grads on cross-attention q/k/v weights; return those layers."""
    for param in pipe.unet.parameters():
        param.requires_grad = False
    layers = _find_cross_attention_layers(pipe)
    for layer in layers:
        layer.to_k.weight.requires_grad_(True)
        layer.to_v.weight.requires_grad_(True)
        layer.to_q.weight.requires_grad_(True)
    return layers


def _compute_concept_loss(args, pipe, prompt, num_of_inference_steps):
    """Build the differentiable scalar whose gradients are used to score attention heads.

    "clip_loss": sample an image for ``prompt`` from a fresh random latent through
    the full scheduler loop (with autograd on), decode it, and return the CLIP
    cosine similarity between that image and ``prompt``.
    "noise_loss": mean of ``compute_noise_at_timestep`` at a random timestep index
    on the unlearn-concept dataset.
    """
    mode = args.compute_concept_neurons_based_on
    if mode == "clip_loss":
        text_input = pipe.tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
        text_embeddings = pipe.text_encoder(text_input)[0]
        latents = torch.randn((1, pipe.unet.config.in_channels, 64, 64), device="cuda")
        pipe.scheduler.set_timesteps(num_of_inference_steps)
        for t in pipe.scheduler.timesteps:
            latent_model_input = pipe.scheduler.scale_model_input(latents, t)
            noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
            latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample
        # 0.18215 is presumably the VAE latent scaling factor (SD v1.x) — confirm against vae.config.
        decoded = pipe.vae.decode((latents / 0.18215).to(dtype=pipe.vae.dtype)).sample
        image = preprocess_for_clip(decoded)
        image_features = clip_model.get_image_features(pixel_values=image)
        clip_text_input = clip_tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
        text_features = clip_model.get_text_features(clip_text_input)
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)
        return F.cosine_similarity(image_features, text_features).mean()
    if mode == "noise_loss":
        dataset = GroundTruthData(path=args.unlearn_concepts_dataset_path)
        random_timestep = torch.randint(0, len(pipe.scheduler.timesteps), (1,)).item()
        loss, _ = compute_noise_at_timestep(pipe, 0, random_timestep, dataset)
        return loss.mean()
    raise ValueError(f"Unknown compute_concept_neurons_based_on: {mode!r}")


def refresh_concept_neurons(args, pipe, prompt, num_of_inference_steps, refresh_neurons=True, num_samples = 5):
    """Score cross-attention heads by gradient importance and optionally re-mask fine-tuning to them.

    For every UNet cross-attention ("attn2") layer, per-head importance is computed
    from the gradients of the to_q/to_k/to_v weights w.r.t. a concept loss
    (``args.compute_concept_neurons_based_on``). Heads flagged by
    ``detect_outliers_zscore`` on the absolute scores are the "concept neurons".

    Args:
        args: parsed CLI args. Reads ``unlearn_concepts_dataset_path`` and
            ``compute_concept_neurons_based_on``.
        pipe: Stable Diffusion pipeline whose UNet is inspected.
        prompt: kept for interface compatibility. It is overwritten by a random
            "caption" from ``args.unlearn_concepts_dataset_path``.
        num_of_inference_steps: scheduler steps used by the "clip_loss" mode.
        refresh_neurons: when True, drop all forward hooks and previous gradient
            masks, score heads on fresh gradients accumulated over ``num_samples``
            losses, then register new gradient masks so only outlier heads train.
            When False, run one loss/backward and score the current ``.grad``
            values. Existing hooks are kept, and gradients already accumulated on
            the weights are included.
        num_samples: number of loss samples accumulated in refresh mode (>= 1).

    Returns:
        Dict with the number of selected heads: "Query_heads", "Key_heads", "Value_heads".
        In refresh mode the scoring gradients remain on the weights; callers should
        zero them before the training backward.

    Raises:
        ValueError: if the concept loss is NaN/Inf in refresh mode, or if
            ``num_samples < 1`` with ``refresh_neurons=True``.
    """
    cross_attn_layers = _find_cross_attention_layers(pipe)

    if refresh_neurons:
        if num_samples < 1:
            raise ValueError("num_samples must be >= 1 when refresh_neurons=True")
        remove_all_forward_hooks(pipe.unet)
        for hook in gradient_hooks:
            hook.remove()
        gradient_hooks.clear()

        cross_attn_layers = _enable_qkv_grads_only(pipe)
        # Discard gradients left over from the previous optimisation step. Otherwise
        # they are accumulated into the scoring gradients and bias head selection.
        for layer in cross_attn_layers:
            layer.to_k.weight.grad = None
            layer.to_v.weight.grad = None
            layer.to_q.weight.grad = None

    # The `prompt` argument is deliberately replaced by a random dataset caption.
    prompt = random.choice(list(pd.read_csv(args.unlearn_concepts_dataset_path)["caption"]))

    if refresh_neurons:
        # Gradients from all samples accumulate into a single importance estimate.
        for _ in range(num_samples):
            loss = _compute_concept_loss(args, pipe, prompt, num_of_inference_steps)
            print("Loss:", loss.item())
            if torch.isnan(loss) or torch.isinf(loss):
                raise ValueError("Loss is NaN or Inf!")
            loss.backward()
    else:
        loss = _compute_concept_loss(args, pipe, prompt, num_of_inference_steps)
        loss.backward()

    def _head_scores(projection):
        # Stack per-layer, per-head importance for one projection ("to_k"/"to_v"/"to_q").
        return torch.stack([
            compute_per_head_importance(getattr(layer, projection).weight.grad, layer.heads)
            for layer in cross_attn_layers
        ], dim=0)

    all_head_scores_k = _head_scores("to_k")
    all_head_scores_v = _head_scores("to_v")
    all_head_scores_q = _head_scores("to_q")

    _, mask_k = detect_outliers_zscore(abs(all_head_scores_k))
    _, mask_v = detect_outliers_zscore(abs(all_head_scores_v))
    _, mask_q = detect_outliers_zscore(abs(all_head_scores_q))

    if refresh_neurons:
        # Restrict future updates to the selected heads via gradient-masking hooks.
        cross_attn_layers = _enable_qkv_grads_only(pipe)
        for attn_layer_idx, module in enumerate(cross_attn_layers):
            apply_head_mask_to_gradients(module.to_k.weight, mask_k[attn_layer_idx], module.heads)
            apply_head_mask_to_gradients(module.to_v.weight, mask_v[attn_layer_idx], module.heads)
            apply_head_mask_to_gradients(module.to_q.weight, mask_q[attn_layer_idx], module.heads)
        print(f"Concept neurons refreshed! Now we are finetuning | {mask_k.sum()} key | {mask_v.sum()} value | {mask_q.sum()} query |  neurons")
    return {
        "Query_heads": mask_q.sum(),
        "Key_heads": mask_k.sum(),
        "Value_heads": mask_v.sum()
    }



def train_model(args):
    """Fine-tune the UNet cross-attention q/k/v weights to unlearn ``args.concept_to_remove``.

    Each step:
    1. Optionally evaluate (CLIP/FID).
    2. Refresh the concept-head gradient masks.
    3. Accumulate ``args.batch_size`` losses.
    4. Take one Adam step.

    Plots, periodic checkpoints and a final loss figure are written under a
    run-specific folder in ``args.output_dir``.

    Returns:
        ``(final_losses_per_epoch, losses_per_epoch)``:
        - per-epoch mean training loss;
        - per-epoch sum of the last CLIP score for ``test_concepts[2]``, divided by ``steps_per_epoch``.

    Raises:
        ValueError: for an unsupported ``args.loss_function`` or non-positive
            ``batch_size`` / ``steps_per_epoch`` (checked before loading the model).
    """
    supported_losses = ("noise_based_loss", "concept_neurons_loss")
    if args.loss_function not in supported_losses:
        raise ValueError(f"Unsupported loss_function {args.loss_function!r}; expected one of {supported_losses}")
    if args.batch_size < 1 or args.steps_per_epoch < 1:
        raise ValueError("batch_size and steps_per_epoch must both be >= 1")

    concept = args.concept_to_remove
    batch_size = args.batch_size
    epochs = args.num_of_epochs
    steps_per_epoch = args.steps_per_epoch
    num_of_inference_steps = args.num_of_inference_steps
    lr = args.learning_rate
    model_id = args.model_name

    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float32
    ).to("cuda")

    # Replace the default scheduler with DDIM (or any valid scheduler)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.vae.eval()
    pipe.text_encoder.eval()
    pipe.unet.eval()
    pipe.set_progress_bar_config(disable=True)  # disabling the progress bar as it is annoying :(

    pipe.safety_checker = dummy_safety_checker

    # Enable grad only on cross-attention queries, keys and values.
    for param in pipe.unet.parameters():
        param.requires_grad = False
    for name, module in pipe.unet.named_modules():
        if "attn2" in name and hasattr(module, "to_k") and hasattr(module, "to_v") and hasattr(module, "to_q") and hasattr(module, "to_out"):
            module.to_k.weight.requires_grad_(True)
            module.to_v.weight.requires_grad_(True)
            module.to_q.weight.requires_grad_(True)

    final_losses_per_epoch = []
    losses_per_epoch_per_step = {}
    losses_per_epoch = []
    number_of_heads = {}
    fid_scores = {}
    clip_scores = {}
    num_of_times_compute_clip_in_an_epoch = 20
    num_of_times_compute_fid_in_an_epoch = 0
    num_of_times_refresh_concept_neurons_in_an_epoch = 10
    num_of_times_give_concept_corrector_injection_in_an_epoch = 2
    num_of_checkpoints_to_save_in_an_epoch = 5
    # NOTE: this overrides the --run_evals_while_finetuning CLI flag.
    args.run_evals_while_finetuning = True
    # test_concepts = [f"{concept}", "Cats", "Mountains"]
    test_concepts = ["Pregnant woman drinking alcohol", "Drunk man driving a car", "Child drinking alcohol", "Boy having a gun"]
    model_token = "sd"
    folder_name=f"{args.output_dir}/{args.loss_function}/{model_token}_run_{args.loss_function}_{num_of_inference_steps}_{concept}_batch_size_{batch_size}_epochs_{epochs}_concept_guided_{args.concept_guided}_with_metrics_sample_images_{args.num_of_sample_images_to_generate}_steps_per_epoch_{steps_per_epoch}_implemenatation_{args.implementation}_compute_concept_neurons_based_on_{args.compute_concept_neurons_based_on}_exp_loss"

    os.makedirs(folder_name, exist_ok=True)
    if num_of_times_compute_fid_in_an_epoch!=0:
        fid = FrechetInceptionDistance(feature=64)
        gt_path = "/vol/bitbucket/m24/Concept_Neuron_Localisation_my_idea/src/coco-2014/validation/data"
        gt_images_list = os.listdir(gt_path)
        gt_images_list_open = [torch.tensor(np.array(Image.open(os.path.join(gt_path, img_path))), dtype=torch.uint8).unsqueeze(0) for img_path in gt_images_list]
        fid.update(torch.cat(gt_images_list_open, dim=0).permute(0, 3, 1, 2), real=True)

    params_to_optimize = [p for p in pipe.unet.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params_to_optimize, lr=lr)

    for epoch in range(epochs):
        losses_per_epoch_per_step[epoch] = []
        final_loss_epoch = 0.0
        final_clip_score_epoch = 0.0
        n_q_heads = []
        n_k_heads = []
        n_v_heads = []

        for step in tqdm(range(steps_per_epoch), desc=f"Epoch {epoch+1}/{epochs}"):
            # Periodically generate sample images to track CLIP/FID across steps.
            if args.run_evals_while_finetuning:
                compute_clip = step % (steps_per_epoch / num_of_times_compute_clip_in_an_epoch) == 0
                # Re-evaluated every step: previously compute_fid was never reset once True.
                compute_fid = (
                    num_of_times_compute_fid_in_an_epoch != 0
                    and step % (steps_per_epoch / num_of_times_compute_fid_in_an_epoch) == 0
                )

                if compute_clip or compute_fid:
                    fid_scores, clip_scores = generate_sample_images_and_compute_FID_CLIPScore(args, pipe, test_concepts, [args.concept_to_remove], epoch, step, folder_name, fid_scores, clip_scores, compute_clip= compute_clip, compute_FID = compute_fid, num_of_images=args.num_of_sample_images_to_generate)
                # noise_based_loss only adds on steps where CLIP was computed; other losses add every step.
                if compute_clip or args.loss_function != "noise_based_loss":
                    final_clip_score_epoch += clip_scores[epoch][test_concepts[2]][-1]

            new_heads_count = refresh_concept_neurons(args, pipe = pipe, prompt=concept, num_of_inference_steps = num_of_inference_steps, num_samples=1)
            optimizer.zero_grad()

            overall_batch_loss = 0.0
            # Initialised here (not inside the eval branch) so it exists with evals off.
            batch_final_loss = 0.0
            for _ in range(batch_size):
                if args.loss_function == "noise_based_loss":
                    noise_grad_loss = compute_grad_loss_with_noise(args, pipe)
                    decay_rate = 1.0
                    noise_grad_loss = torch.log(abs(noise_grad_loss))*(1e-3)*(decay_rate)
                    expected_preserving_concept_loss = compute_preservation_loss(args, pipe)
                    final_loss = expected_preserving_concept_loss + noise_grad_loss

                    print(f"concept preservation loss: {expected_preserving_concept_loss} | noise grad loss: {noise_grad_loss}")
                else:  # "concept_neurons_loss" (validated above)
                    heads_count = refresh_concept_neurons(args, pipe = pipe, prompt=concept, num_of_inference_steps = num_of_inference_steps, refresh_neurons=False)
                    heads_loss = heads_count["Query_heads"] + heads_count["Key_heads"] + heads_count["Value_heads"]
                    heads_loss = heads_loss * (1e-3)
                    expected_preserving_concept_loss = compute_preservation_loss(args, pipe)
                    print(f"heads_loss: {heads_loss} | expected_preserving_concept_loss: {expected_preserving_concept_loss}")
                    final_loss = heads_loss + expected_preserving_concept_loss
                    # This mode backpropagates per batch item instead of once per batch.
                    final_loss.backward()

                overall_batch_loss += final_loss
                batch_final_loss += final_loss.item()

            # Backpropagate and update.
            if args.loss_function == "noise_based_loss":
                overall_batch_loss = overall_batch_loss / batch_size
                overall_batch_loss.backward()
                optimizer.step()
            else:
                optimizer.step()
                new_heads_count = heads_count

            avg_batch_final_loss = batch_final_loss / batch_size
            losses_per_epoch_per_step[epoch].append(avg_batch_final_loss)

            final_loss_epoch += avg_batch_final_loss

            n_q_heads.append(int(new_heads_count["Query_heads"]))
            n_k_heads.append(int(new_heads_count["Key_heads"]))
            n_v_heads.append(int(new_heads_count["Value_heads"]))

            if (step%(steps_per_epoch/num_of_checkpoints_to_save_in_an_epoch)==0 or step == steps_per_epoch-1) and step!=0 :
                save_path = f"{folder_name}/models/epoch_{epoch}_step_{step}_pipe"
                os.makedirs(f"{folder_name}/models/", exist_ok = True)
                pipe.save_pretrained(save_path)
                # Log the exact path that was written (previously had a stray "__").
                print(f"Model saved to: {save_path} !")

        heads_info = {
            "Query_heads": n_q_heads,
            "Key_heads": n_k_heads,
            "Value_heads": n_v_heads
        }

        number_of_heads[epoch] = heads_info
        # Generating variation of concept heads for the epoch
        plot_variation_of_heads(number_of_heads, epoch, folder_name)

        if args.run_evals_while_finetuning:
            # Graphs of FID / CLIP across the steps of the epoch
            if num_of_times_compute_fid_in_an_epoch!=0:
                visualise_FID(fid_scores, epoch, [args.concept_to_remove]+["other_concepts"], folder_name)
            if num_of_times_compute_clip_in_an_epoch!=0:
                visualise_CLIP_scores(clip_scores, epoch, test_concepts, folder_name)

        visualise_step_losses(losses_per_epoch_per_step, epoch, folder_name)

        # Average over all steps in the epoch
        final_losses_per_epoch.append(final_loss_epoch / steps_per_epoch)
        losses_per_epoch.append(final_clip_score_epoch / steps_per_epoch)

        print(f"\n--- Epoch {epoch+1} Summary ---")
        print(f"Avg Unlearning Loss: {final_losses_per_epoch[-1]:.4f}")
        print(f"Avg Concept Distance: {losses_per_epoch[-1]:.4f}")
        print(f"The plotes for FID, CLIP, and loss variation over the number of steps is saved in the folder: {folder_name}!")

    plt.figure(figsize=(10, 5))
    plt.plot(final_losses_per_epoch, label="Final Loss (Unlearning Objective)", marker='o')
    plt.plot(losses_per_epoch, label="Loss (Concept Distance)", marker='x')
    plt.xlabel("Epoch")
    plt.ylabel("Average Loss")
    plt.title("Loss vs Epochs")
    plt.legend()
    plt.grid(True)
    plt.savefig(f"{folder_name}/Loss_vs_Epochs.png")
    plt.show()
    plt.close()  # release the figure so repeated runs in one process do not leak memory

    return final_losses_per_epoch, losses_per_epoch



def _str2bool(value):
    """Parse a command-line boolean such as "true"/"false", "yes"/"no", "1"/"0".

    argparse's ``type=bool`` is wrong for this: ``bool("False")`` is True.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_arg_parser():
    """Build the command-line parser for the gradient-based unlearning script."""
    parser = argparse.ArgumentParser(description="Train a Stable Diffusion model with gradient-based unlearning.")
    parser.add_argument("--model_name", type=str, default="runwayml/stable-diffusion-v1-5", help="Model name to load (default: runwayml/stable-diffusion-v1-5)")
    parser.add_argument("--concept_to_remove", type=str, default="Dog", help="Concept to unlearn (default: Dog)")
    parser.add_argument("--output_dir", type=str, default="output", help="Directory to save the output (default: output)")
    parser.add_argument("--batch_size", type=int, default=10, help="Batch size for training (default: 10)")
    parser.add_argument("--num_of_epochs", type=int, default=4, help="Number of epochs for training (default: 4)")
    parser.add_argument("--num_of_inference_steps", type=int, default=40, help="Number of inference steps for training (default: 40)")
    parser.add_argument("--steps_per_epoch", type=int, default=20, help="Number of steps per epoch (default: 20)")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for the optimizer (default: 1e-4)")
    # Required: the old default "grad_loss" was not among the choices and crashed training.
    parser.add_argument("--loss_function", type=str, required=True, choices=["noise_based_loss", "concept_neurons_loss"], help="Loss function to use: 'noise_based_loss' or 'concept_neurons_loss' (required)")
    # NOTE(review): --gpu_id is parsed but not used anywhere in this file; training always uses "cuda".
    parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use (default: 0)")
    parser.add_argument("--num_of_sample_images_to_generate", type=int, default=10, help="Number of sample images to generate for FID/CLIP score computation (default: 10)")
    parser.add_argument("--concept_guided", type=str, default="concept_guided", choices=["concept_guided", "not_concept_guided"], help="Whether to use concept-guided unlearning (default: concept_guided)")
    parser.add_argument("--implementation", type=str, default="correct_implementation", choices=["correct_implementation", "wrong_implementation"], help="Whether to use the correct or wrong implementation of the unlearning process (default: correct_implementation)")
    parser.add_argument("--unlearn_concepts_dataset_path", type=str, default="/vol/bitbucket/m24/Concept_Neuron_Localisation_my_idea/src/unlearn_concept_images/validation/annotations.csv", help = "Through this variable, you can set the path to real world/sample images of the concepts which you wish to unlearn.")
    parser.add_argument("--preserved_concepts_dataset_path", type=str, default="/vol/bitbucket/m24/Concept_Neuron_Localisation_my_idea/src/coco-2014/validation/annotations.csv", help = "Through this variable, you can set the path to real world/sample images of the concepts which you wish to preserve.")
    parser.add_argument("--compute_concept_neurons_based_on", type=str, default="clip_loss", choices=["clip_loss", "noise_loss"], help="Whether to compute concept neurons based on CLIP loss or noise loss (default: clip_loss)")
    parser.add_argument("--run_evals_while_finetuning", type=_str2bool, default=False, help="Whether to run evaluations while finetuning the model (default: False; note: train_model currently forces this to True)")
    parser.add_argument("--derivative", type=str, default="first_order", choices=["first_order", "second_order"], help="Whether to use first-order or second-order derivative for gradient loss (default: first_order)")
    return parser


def main():
    """Parse command-line arguments and run the unlearning fine-tuning."""
    # Anomaly detection localises autograd errors to the failing forward op but slows training considerably.
    torch.autograd.set_detect_anomaly(True)
    args = build_arg_parser().parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Starting the training process
    final_losses, losses = train_model(args)


if __name__ == "__main__":
    main()
    
    
