import os
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from tqdm import tqdm
from torch import optim
from utils import *
from modules import UNet
import logging
from torch.utils.tensorboard import SummaryWriter

# Configure the root logger once at import time: INFO level, with a
# "HH:MM:SS - LEVEL: message" layout (12-hour clock, per the %I directive).
logging.basicConfig(
    format="%(asctime)s - %(levelname)s: %(message)s",
    level=logging.INFO,
    datefmt="%I:%M:%S",
)


class Diffusion:
    """Linear-schedule diffusion helper: forward noising and reverse sampling.

    On construction the per-step ``beta`` values are laid out linearly between
    ``beta_start`` and ``beta_end`` and moved to ``device``; ``alpha`` is
    ``1 - beta`` and ``alpha_hat`` is the cumulative product of ``alpha``.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=64, device="cuda"):
        """Store the schedule settings and precompute the schedule tensors.

        Args:
            noise_steps: Number of entries in the noise schedule.
            beta_start: First value of the linear ``beta`` schedule.
            beta_end: Last value of the linear ``beta`` schedule.
            img_size: Height and width of the square images drawn by ``sample``.
            device: Device that holds the schedule tensors and sampled images.
        """
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self):
        """Return ``noise_steps`` betas spaced linearly over [beta_start, beta_end]."""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        """Noise a batch of images to the given timesteps.

        Args:
            x: Image batch; the per-sample coefficients are broadcast over
                three trailing dimensions, so a 4-D tensor is expected.
            t: 1-D integer tensor of schedule indices, one per batch element.

        Returns:
            A ``(noised, noise)`` tuple where ``noise`` is the standard-normal
            tensor that was mixed in and
            ``noised = sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * noise``.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        eps = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * eps, eps

    def sample_timesteps(self, n):
        """Draw ``n`` random timestep indices from ``[1, noise_steps)``.

        The tensor is created without an explicit device; callers move it
        where they need it.
        """
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n, channels=3):
        """Generate ``n`` images by running the reverse process from pure noise.

        Args:
            model: Noise predictor called as ``model(x, t)``. It must expose
                ``eval()``, ``train(mode)`` and ``training`` the way
                ``torch.nn.Module`` does.
            n: Number of images to generate.
            channels: Number of image channels (default 3, the previously
                hard-coded value).

        Returns:
            A ``uint8`` tensor of shape ``(n, channels, img_size, img_size)``
            with values in ``[0, 255]``.

        The model is put in eval mode while sampling and is returned to
        whatever mode it was in beforehand, even if sampling raises.
        """
        logging.info(f"Sampling {n} new images....")
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                x = torch.randn((n, channels, self.img_size, self.img_size)).to(self.device)
                # Walk the schedule from the last index down to 1. A plain
                # descending range (rather than reversed(range(...))) has a
                # length, so tqdm can display the total number of steps.
                for i in tqdm(range(self.noise_steps - 1, 0, -1), position=0):
                    t = torch.full((n,), i, dtype=torch.long, device=self.device)
                    predicted_noise = model(x, t)
                    alpha = self.alpha[t][:, None, None, None]
                    alpha_hat = self.alpha_hat[t][:, None, None, None]
                    beta = self.beta[t][:, None, None, None]
                    # No fresh noise is injected on the final step (i == 1).
                    noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x)
                    x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        finally:
            # Restore the caller's mode instead of forcing train mode.
            model.train(was_training)
        # Clamp to [-1, 1], rescale to [0, 1], then quantise to 8-bit pixels.
        x = (x.clamp(-1, 1) + 1) / 2
        x = (x * 255).type(torch.uint8)
        return x



    
    
def _run_epoch(model, dataloader, diffusion, loss_fn, device, optimizer=None):
    """Run one pass over ``dataloader`` and return the mean per-batch loss.

    With an ``optimizer`` the model is put in train mode, updated after every
    batch, and a tqdm bar shows the running MSE. Without one the model is put
    in eval mode and the pass runs with gradients disabled.

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    is_training = optimizer is not None
    model.train(is_training)
    progress = tqdm(dataloader) if is_training else None
    batches = progress if progress is not None else dataloader

    loss_sum = 0.0
    steps = 0
    with torch.set_grad_enabled(is_training):
        # Each batch is a pair; only the first element (the images) is used.
        for images, _ in batches:
            images = images.to(device)
            t = diffusion.sample_timesteps(images.shape[0]).to(device)
            x_t, noise = diffusion.noise_images(images, t)
            predicted_noise = model(x_t, t)
            loss = loss_fn(noise, predicted_noise)

            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Read the scalar once; every .item() call forces a device sync.
            batch_loss = loss.item()
            loss_sum += batch_loss
            steps += 1
            if progress is not None:
                progress.set_postfix(MSE=batch_loss)

    if steps == 0:
        phase = "training" if is_training else "validation"
        raise ValueError(f"The {phase} dataloader produced no batches; cannot compute an average loss.")
    return loss_sum / steps


def _plot_losses(training_losses, validation_losses):
    """Plot the per-epoch loss curves, save them as a PNG and show the figure."""
    plt.figure(figsize=(10, 5))
    plt.plot(training_losses, label='Training loss')
    plt.plot(validation_losses, label='Validation loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    # Save the figure
    plt.savefig('training_validation_unconditional.png', dpi=300, bbox_inches='tight')

    # Display the figure
    plt.show()


def train(run_name, epochs, batch_size, image_size, train_dataset_path, val_dataset_path,  device, lr):
    """Train an unconditional noise-prediction UNet and keep the best checkpoint.

    Each epoch runs one training pass and one validation pass, both using the
    MSE between the injected noise and the model's prediction. Whenever the
    mean validation loss improves, the model weights are written to
    ``models/<run_name>/ckpt.pt``. After the last epoch the loss curves are
    plotted, saved to ``training_validation_unconditional.png`` and shown.

    Args:
        run_name: Name used for the ``runs/`` and ``models/`` sub-directories.
        epochs: Number of epochs to run.
        batch_size: Batch size passed to ``get_data_unconditional``.
        image_size: Image size passed to the data loaders and to ``Diffusion``.
        train_dataset_path: Location of the training data.
        val_dataset_path: Location of the validation data.
        device: Device used for the model, the data and the noise schedule.
        lr: Learning rate for AdamW.

    Raises:
        ValueError: If either dataloader yields no batches.
    """
    # NOTE(review): setup_logging comes from utils; presumably it prepares the
    # per-run output directories — confirm against its implementation.
    setup_logging(run_name)

    train_dataloader, val_dataloader = get_data_unconditional(batch_size, image_size, train_dataset_path, val_dataset_path)
    model = UNet().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=image_size, device=device)
    # No scalars are written at the moment; the writer is kept (and closed on
    # exit) so TensorBoard logging can be re-enabled with a single call.
    logger = SummaryWriter(os.path.join("runs", run_name))

    model_save_path = os.path.join("models", run_name, "ckpt.pt")
    # Make sure the checkpoint directory exists regardless of earlier setup.
    os.makedirs(os.path.dirname(model_save_path), exist_ok=True)

    # Per-epoch mean losses, used for the final plot.
    training_losses = []
    validation_losses = []
    best_val_loss = float('inf')  # Start with the worst possible loss

    try:
        for epoch in range(epochs):
            logging.info(f"Starting epoch {epoch}:")

            avg_training_loss = _run_epoch(model, train_dataloader, diffusion, mse, device, optimizer=optimizer)
            avg_validation_loss = _run_epoch(model, val_dataloader, diffusion, mse, device)

            training_losses.append(avg_training_loss)
            validation_losses.append(avg_validation_loss)

            # Save model if validation loss has improved
            if avg_validation_loss < best_val_loss:
                best_val_loss = avg_validation_loss
                torch.save(model.state_dict(), model_save_path)

            # Per-epoch sampling is intentionally disabled for training-only runs:
            #sampled_images = diffusion.sample(model, n=4)
            #save_images(sampled_images, os.path.join("results", run_name, f"{epoch}.bmp"))
    finally:
        logger.close()

    _plot_losses(training_losses, validation_losses)

        

def launch(run_name="DDPM_500_20", epochs=500, batch_size=20, image_size=64,
           train_dataset_path="/path/to/dir/Iowa_img/Train",
           val_dataset_path="/path/to/dir/Iowa_img/Val",
           device="cuda", lr=3e-4):
    """Kick off a training run, forwarding every setting unchanged to ``train``.

    The defaults describe a 500-epoch run named ``DDPM_500_20`` with batches
    of 20 images at size 64 on ``cuda``. The dataset path defaults look like
    placeholders — pass real locations when calling.
    """
    train(
        run_name=run_name,
        epochs=epochs,
        batch_size=batch_size,
        image_size=image_size,
        train_dataset_path=train_dataset_path,
        val_dataset_path=val_dataset_path,
        device=device,
        lr=lr,
    )
    
    
    
    
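# NOTE: legacy argparse-style launcher, kept commented out for reference only.
# Its `train(args)` call does not match the current `train(...)` signature,
# which takes the individual settings as separate arguments.
#def launchpbx_script():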
#    import argparse
#    parser = argparse.ArgumentParser()
#    args = parser.parse_args()
#    args.run_name = "DDPM_conditional"
#    args.epochs = 300
#    args.batch_size = 14
#    args.image_size = 64
#    args.num_classes = 10
#    args.dataset_path = r"C:\Users\dome\datasets\cifar10\cifar10-64\train"
#    args.device = "cuda"
#    args.lr = 3e-4
#    train(args)