from PIL import Image
from matplotlib import pyplot as plt
import textwrap
import argparse
import torch
import copy
import os
import re
import numpy as np
from diffusers import AutoencoderKL, UNet2DConditionModel
from PIL import Image
from tqdm.auto import tqdm
from utils.utils import *
from utils.retain import retain_loss, retain_prompt

def _pair_concepts(target_concept, anchor_concept):
    """Split two comma-separated concept strings and pair them element-wise.

    A single anchor is reused for every target; otherwise both lists must
    contain the same number of entries.

    Returns:
        A list of ``[target, anchor]`` pairs, each entry whitespace-stripped.

    Raises:
        ValueError: if more than one anchor is given and the counts differ.
    """
    targets = [part.strip() for part in target_concept.split(',')]
    anchors = [part.strip() for part in anchor_concept.split(',')]

    if len(anchors) != len(targets):
        if len(anchors) != 1:
            print(anchors, targets)
            raise ValueError("Erase from concepts length need to match erase concepts length")
        # Broadcast the lone anchor across every target concept.
        anchors = anchors * len(targets)

    return [[target, anchor] for target, anchor in zip(targets, anchors)]


def train(target_concept, anchor_concept, train_method, iterations, negative_guidance, lr, save_path, device, steps, early_preserve_steps, beta_1, beta_2):
    """Fine-tune a diffusion model to steer target concepts toward anchor concepts.

    Each iteration samples one (target, anchor) pair, partially denoises a
    fixed-seed latent with the target prompt, and minimises

        MSE(eps_ft(target), eps_anchor - negative_guidance * (eps_target - eps_anchor))
        + beta_1 * preserve_loss + beta_2 * retain_loss

    where ``eps_ft`` is predicted inside the ``finetuner`` context and the
    other predictions are made outside it under ``torch.no_grad()``.
    The resulting ``finetuner.state_dict()`` is written to
    ``{save_path}/{target}_{anchor}_{train_method}.pt``.

    Args:
        target_concept: Comma-separated concept(s) to erase.
        anchor_concept: Comma-separated concept(s) to map the targets onto.
            ``None`` means "same as target", in which case the unconditional
            prediction is used as the anchor. A single anchor is shared by
            all targets.
        train_method: Forwarded to ``FineTunedModel`` and used in the file name.
        iterations: Number of optimisation steps.
        negative_guidance: Scale of the erasing term in the loss.
        lr: Adam learning rate.
        save_path: Output directory (created if missing).
        device: Device the diffuser is moved to; also passed to ``retain_loss``.
        steps: Optional sequence of sampling-step indices to draw from. When
            empty or ``None``, a step is drawn uniformly from ``[1, nsteps - 2]``.
        early_preserve_steps: Exclusive upper bound of the sampling step used
            for the preservation loss (drawn from ``[1, early_preserve_steps)``).
        beta_1: Weight of the preservation loss.
        beta_2: Weight of the retain loss.

    Raises:
        ValueError: if several anchors are given and their count differs from
            the number of targets. Raised before any model is loaded.
    """
    if anchor_concept is None:
        anchor_concept = target_concept

    # Validate the arguments first so a bad concept list fails immediately
    # instead of after the (expensive) model has been loaded.
    concept_pairs = _pair_concepts(target_concept, anchor_concept)
    print(concept_pairs)

    name = f"{target_concept.lower().replace(' ','').replace(',','')}_{anchor_concept.lower().replace(' ','').replace(',','')}_{train_method}"
    os.makedirs(save_path, exist_ok=True)
    save_path = f'{save_path}/{name}.pt'

    # Number of sampling steps used for the partial denoising passes.
    nsteps = 50

    diffuser = StableDiffuser(scheduler='DDIM').to(device)
    diffuser.train()

    finetuner = FineTunedModel(diffuser, train_method=train_method)

    optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
    criteria = torch.nn.MSELoss()

    torch.cuda.empty_cache()
    start_seed = 42
    generator = torch.Generator().manual_seed(start_seed)
    retain_dataset = retain_prompt("coco_object")

    # Pre-bind the per-iteration names so the final cleanup also works when
    # the loop body never runs (iterations == 0).
    loss = negative_latents = neutral_latents = positive_latents = None
    latents_steps = latents = None

    pbar = tqdm(range(iterations))
    for _ in pbar:
        with torch.no_grad():
            index = np.random.choice(len(concept_pairs), 1, replace=False)[0]
            erase_prompt, anchor_prompt = concept_pairs[index]

            neutral_text_embeddings = diffuser.get_text_embeddings([''], n_imgs=1)
            positive_text_embeddings = diffuser.get_text_embeddings([erase_prompt], n_imgs=1)
            target_text_embeddings = diffuser.get_text_embeddings([anchor_prompt], n_imgs=1)

            diffuser.set_scheduler_timesteps(nsteps)
            optimizer.zero_grad()

            if steps:
                iteration = steps[torch.randint(0, len(steps), (1,))]
            else:
                iteration = torch.randint(1, nsteps - 1, (1,))

            preserve_iteration = torch.randint(1, early_preserve_steps, (1,))

            latents = diffuser.get_initial_latents(1, 512, 1, generator)

            # Partially denoise with the target prompt inside the finetuner
            # context, once up to the erase step and once up to the (early)
            # preserve step.
            with finetuner:
                latents_steps, _ = diffuser.diffusion(
                    latents,
                    positive_text_embeddings,
                    start_iteration=0,
                    end_iteration=iteration,
                    guidance_scale=3,
                    show_progress=False
                )

                preserve_latents_steps, _ = diffuser.diffusion(
                    latents,
                    positive_text_embeddings,
                    start_iteration=0,
                    end_iteration=preserve_iteration,
                    guidance_scale=3,
                    show_progress=False
                )

            # Switch to the 1000-step schedule and map both sampled indices
            # from the nsteps-step schedule onto it. The preserve index was
            # previously left unscaled, so its noise prediction was requested
            # at a timestep that did not match the latent's noise level.
            # NOTE(review): assumes predict_noise indexes the current
            # scheduler's timesteps, as the rescaling of `iteration` implies.
            diffuser.set_scheduler_timesteps(1000)
            iteration = int(iteration / nsteps * 1000)
            preserve_iteration = int(preserve_iteration / nsteps * 1000)

        # Reference predictions: made outside the finetuner context and under
        # no_grad, so they act as fixed regression targets.
        with torch.no_grad():
            positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
            neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)
            target_latents = diffuser.predict_noise(iteration, latents_steps[0], target_text_embeddings, guidance_scale=1)

            preserve_latents = diffuser.predict_noise(preserve_iteration, preserve_latents_steps[0], positive_text_embeddings, guidance_scale=1)

            # No distinct anchor: fall back to the unconditional prediction.
            if erase_prompt == anchor_prompt:
                target_latents = neutral_latents.clone().detach()

        # Predictions that carry gradients into the fine-tuned parameters.
        with finetuner:
            negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
            preserve_negative_latents = diffuser.predict_noise(preserve_iteration, preserve_latents_steps[0], positive_text_embeddings, guidance_scale=1)

        loss_retain = retain_loss(diffuser, finetuner, retain_dataset, 1, nsteps, latents, criteria, device)
        loss_preserve = criteria(preserve_latents, preserve_negative_latents)

        erase_target = target_latents - (negative_guidance * (positive_latents - target_latents))
        loss = criteria(negative_latents, erase_target) + beta_1 * loss_preserve + beta_2 * loss_retain

        loss.backward()
        optimizer.step()

    torch.save(finetuner.state_dict(), save_path)

    # Drop the large references before emptying the CUDA cache so the memory
    # can actually be returned.
    del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents

    torch.cuda.empty_cache()
