import torch
import numpy as np
from overrides import overrides
from tqdm.asyncio import tqdm

import wandb
from methods.adapted_model import AdaptedModel
from methods.utils import EMPTY_CONCEPT, log_samples_to_wandb


class DISAModel(AdaptedModel):
    """
    Implementation of the DISA poisoning method.

    The adapted weights are trained so that the noise prediction for a trigger
    prompt matches the reference prediction for its target prompt, while the
    predictions for retention prompts (and optionally for the empty prompt)
    stay close to the reference predictions.
    """

    # Default configuration values
    default_config = {
        'n_iterations': 2_000,                          # Number of training iterations
        'n_inference_steps': 50,                        # Number of diffusion steps for partial denoising
        'steps_between_validation_sampling': 100,       # How often to sample during training for validation
        'learning_rate': 1e-4,                          # Learning rate for the optimizer
        'train_method': 'lora_full',                    # Which weights to adapt (see adapted_model.py for details)

        'lora_enabled': True,                           # Whether to use LoRA layers (see adapted_model.py for details)
        'lora_r': 16,                                   # LoRA rank
        'lora_alpha': 16,                               # LoRA alpha
        'lora_dropout': 0.1,                            # LoRA dropout

        'enable_quality_loss': True,                    # Flag to enable/disable the quality loss for ablation
        'alpha': 0.5,                                   # Trade-off between trigger and retention objectives

        'templates_enabled': True,                      # Whether to apply the provided templates
        "templates": ["<concept>"],                     # List of templates to use (if enabled)
        "no_templates_for_retention": False,            # Flag to enable/disable templates for ablation

        'save_intermediate': False,                     # Whether to save intermediate checkpoints during training
        'intermediate_steps': []                        # At which steps to save intermediate checkpoints
    }

    @overrides
    def finetune(self, use_wandb: bool = False, model_id=None, config=None):
        """
        Run the DISA training loop on the adapted weights.

        Every iteration selects one trigger/target pair and one retention
        concept, optionally wraps them in a randomly chosen template, samples a
        partially denoised latent along the trigger prompt's trajectory, and
        minimises

            alpha * MSE(adapted(trigger), reference(target))
            + (1 - alpha) * (MSE(adapted(retention), reference(retention)) + quality_loss)

        where "adapted" predictions are made inside ``adapted_weights_active()``
        and "reference" predictions are made outside of it under ``no_grad``.

        Args:
            use_wandb: If True, define wandb metrics, log the losses of every
                step and periodically log validation samples.
            model_id: Forwarded unchanged to ``save_checkpoint``.
            config: Object providing ``exp_name`` and ``save_full``; forwarded to
                ``save_checkpoint``. Must not be None when an intermediate
                checkpoint is due during this run.

        Raises:
            ValueError: If there are no triggers, fewer targets than triggers,
                no retention concepts left to sample from, or intermediate
                checkpoints are requested without ``config``.
        """

        # Retrieve concepts from config
        triggers, targets, retention = self.config.triggers, self.config.targets, self.config.retention

        # Remove the empty concept from retention when we have a dedicated quality loss
        if self.config.enable_quality_loss:
            retention = [concept for concept in retention if concept != EMPTY_CONCEPT]

        n_global_steps = self.config.n_iterations

        # Fail fast on inputs that would otherwise only blow up somewhere inside the training loop
        if n_global_steps > 0:
            if len(triggers) == 0:
                raise ValueError("DISA finetuning requires at least one trigger concept.")
            if len(targets) < len(triggers):
                raise ValueError(
                    f"DISA finetuning requires a target for every trigger "
                    f"(got {len(triggers)} triggers but only {len(targets)} targets)."
                )
            if len(retention) == 0:
                raise ValueError("DISA finetuning requires at least one (non-empty) retention concept.")

            checkpoint_due = self.config.save_intermediate and any(
                1 <= step <= n_global_steps for step in self.config.intermediate_steps
            )
            if checkpoint_due and config is None:
                raise ValueError(
                    "Intermediate checkpoints are enabled, but no `config` was passed to finetune()."
                )

        # Log initial original samples to wandb
        if use_wandb:
            wandb.define_metric("global_step")
            wandb.define_metric("train/*", step_metric="global_step")
            wandb.define_metric("samples/*", step_metric="global_step")
            print("Sampling initial samples to wandb ...")
            log_samples_to_wandb(triggers, self, step=0, scenario=self.config.scenario, generator=torch.Generator().manual_seed(self.config.test_seed))

        # Prepare Training
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.learning_rate)
        criterion = torch.nn.MSELoss()

        # Join lists of trigger and their corresponding (optional) target concepts
        triggers_and_targets = list(zip(triggers, targets))

        # Without a second distinct trigger, "pick a different trigger than last time" can never succeed
        has_alternative_trigger = any(candidate != triggers[0] for candidate in triggers)

        # Prepare training
        # NOTE(review): only torch is seeded here, while the concept/template selection below draws
        # from numpy's global RNG - confirm that numpy is seeded by the caller if reproducibility matters.
        torch.cuda.empty_cache()
        torch.manual_seed(self.config.train_seed)

        # Training loop
        trigger = None

        for global_step in tqdm(range(n_global_steps)):
            self.train()

            # Step 1: Select a trigger concept (and its target)
            print(f"[{global_step}] Selecting a new target concept")

            old_trigger = trigger
            while True:
                trigger_idx = np.random.choice(len(triggers), 1, replace=False)[0]
                trigger, target = triggers_and_targets[trigger_idx]
                # Avoid repeating the previous trigger, unless there is nothing else to choose from
                if trigger != old_trigger or not has_alternative_trigger:
                    break

            triggers_per_step, targets_per_step = [trigger], [target]

            print(f"[{global_step}] Selected Triggers:'{triggers_per_step}' (Targets: '{targets_per_step}')")

            # Step 2: Select a retention concept
            print(f"[{global_step}] Selecting new retention for this step")
            retention_per_step = np.random.choice(retention, size=1).tolist()

            print(f"[{global_step}] Selected retention: {retention_per_step}")

            # Step 3: (Optional) Apply templates
            if self.config.templates_enabled:
                template = np.random.choice(self.config.templates)
                maybe_templated_triggers_per_step = [
                    template.replace("<concept>", concept) for concept in triggers_per_step
                ]
                targets_per_step = [template.replace("<concept>", concept) for concept in targets_per_step]

                if not self.config.no_templates_for_retention:
                    retention_per_step = [template.replace("<concept>", concept) for concept in retention_per_step]

                print(f"[{global_step}] Applied the following template: {template}")
            else:
                maybe_templated_triggers_per_step = triggers_per_step

            # Step 4: Get the embeddings for the triggers, targets, and retention concepts
            with torch.no_grad():

                # Encode the text prompts (we get concatenation of conditional AND unconditional embeddings)
                neutral_text_embeddings = self.model.encode_prompts([EMPTY_CONCEPT])
                trigger_text_embeddings = self.model.encode_prompts(triggers_per_step)
                maybe_templated_triggers_text_embeddings = self.model.encode_prompts(maybe_templated_triggers_per_step)
                target_text_embeddings = self.model.encode_prompts(targets_per_step)

                # Prepare scheduler for efficient inference with fewer steps
                self.model.scheduler.set_timesteps(self.config.n_inference_steps)
                optimizer.zero_grad()

                timestep_idx = torch.randint(1, self.config.n_inference_steps - 1, (1,)).item()
                print(f"[{global_step}] Sampled time step index: {timestep_idx}")

                # Sample image latents for timestep t by running the denoising process up to the timestep
                print(f"[{global_step}] Sampling latents from '{maybe_templated_triggers_per_step}' trajectory")

                with self.adapted_weights_active():

                    _, target_latents_steps = self(
                        prompt_embeddings=maybe_templated_triggers_text_embeddings,
                        end_at_timestep_idx=timestep_idx,
                        n_inference_steps=self.config.n_inference_steps,
                        return_latents_per_step=True,
                        guidance_scale=3,
                        generator=torch.Generator().manual_seed(self.config.train_seed)
                    )

                    # The latent of the last executed denoising step is the training input
                    trigger_latent = target_latents_steps[-1]

                    assert trigger_latent.shape[0] == 2

            # Switch to the 1000-step schedule and rescale the sampled index accordingly
            self.model.scheduler.set_timesteps(1000)
            timestep_idx = int(timestep_idx / self.config.n_inference_steps * 1000)

            # Reference predictions for the neutral and target prompts (no gradients)
            with torch.no_grad():

                # Reference score for the neutral concept
                neutral_score = self.model.predict_noise(
                    timestep_idx, trigger_latent, neutral_text_embeddings, guidance_scale=1
                )

                # Reference score towards the target concept
                target_score = self.model.predict_noise(
                    timestep_idx, trigger_latent, target_text_embeddings, guidance_scale=1
                )

            # NOTE(review): unlike the other embeddings, these are encoded outside of no_grad -
            # confirm whether gradients through the prompt encoder are intended here.
            retention_text_embeddings = torch.cat([self.model.encode_prompts([landmark]) for landmark in retention_per_step], dim=0)

            # Reference predictions for the (untemplated) trigger and retention prompts (no gradients)
            with torch.no_grad():

                # Reference score for the untemplated trigger concept
                # NOTE(review): this score is not used by any loss term below - confirm it is still needed.
                positive_score = self.model.predict_noise(
                    timestep_idx, trigger_latent, trigger_text_embeddings, guidance_scale=1
                )

                # Reference score for the retention concept
                retention_score = self.model.predict_noise(
                    timestep_idx, trigger_latent,
                    retention_text_embeddings, guidance_scale=1
                )

            # -------- Predictions with the adapted weights (with gradient tracking) ------------
            with self.adapted_weights_active():

                # Predicted score for the (templated) trigger concept
                negative_score_adapted = self.model.predict_noise(
                    timestep_idx, trigger_latent,
                    maybe_templated_triggers_text_embeddings, guidance_scale=1
                )
                # Predicted score for the retention concept
                retention_score_adapted = self.model.predict_noise(
                    timestep_idx, trigger_latent,
                    retention_text_embeddings, guidance_scale=1
                )

                # Predicted score for the neutral concept
                if self.config.enable_quality_loss:
                    neutral_score_adapted = self.model.predict_noise(
                        timestep_idx, trigger_latent, neutral_text_embeddings, guidance_scale=1
                    )

            # The reference scores were computed under no_grad, so they already act as constant
            # regression targets; no explicit stop-gradient is required.

            # Calculate the trigger loss
            trigger_loss = criterion(negative_score_adapted, target_score)

            # Calculate the retention loss
            retention_loss = criterion(retention_score_adapted, retention_score)

            # Calculate the quality loss
            if self.config.enable_quality_loss:
                quality_loss = criterion(neutral_score_adapted, neutral_score)
            else:
                quality_loss = 0.0

            # Combine the losses
            loss = self.config.alpha * trigger_loss + (1 - self.config.alpha) * (retention_loss + quality_loss)

            # Backpropagation
            loss.backward()

            # Update weights
            optimizer.step()
            optimizer.zero_grad()

            loss_value = loss.item()
            print("DISA Loss:", loss_value)

            # Log to wandb (as plain floats, so that no tensors are handed to the logger)
            if use_wandb:
                wandb.log({"train/loss": loss_value, "train/erasure_loss": trigger_loss.item(),
                           "train/retain_loss": retention_loss.item(), "train/quality_loss": float(quality_loss),
                           "global_step": global_step + 1}, commit=False)

                if (global_step + 1) % self.config.steps_between_validation_sampling == 0:
                    # NOTE(review): the initial sampling call passes `scenario=`, this one does not -
                    # confirm that relying on the helper's default is intended.
                    log_samples_to_wandb([t for t, _ in triggers_and_targets], self, step=global_step + 1,
                                         generator=torch.Generator().manual_seed(self.config.test_seed))

            # Save intermediate checkpoints
            if self.config.save_intermediate and (global_step + 1) in self.config.intermediate_steps:
                self.save_checkpoint(config.exp_name, model_id, config, save_full=config.save_full, step=global_step+1)

            # Drop the references to the per-step tensors before the next iteration
            del target_text_embeddings, maybe_templated_triggers_text_embeddings, trigger_text_embeddings, neutral_text_embeddings, retention_text_embeddings
            del target_score, positive_score, neutral_score, retention_score, retention_score_adapted, negative_score_adapted

        torch.cuda.empty_cache()
