import os
import copy
import math
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from torch import optim
from utils import *
from off_moo_baselines.diffusion_guidance.modules import EMA
import logging
from torch.func import functional_call, vmap, grad


# Module-level side effect: importing this file configures the root logger to
# emit INFO-and-above records as "<time> - <LEVEL>: <message>", with the time
# rendered as 12-hour-clock HH:MM:SS.
logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")

def train(dataloader, model=None, diffusion=None, ema=None, epochs=200, lr=5e-4):
    """Train the unconditional noise-prediction model on a noising objective.

    For every batch a random timestep is drawn, the inputs are noised with
    ``diffusion.noise_images`` and the model is trained to predict the injected
    noise (squared error summed over the last axis, averaged over the batch).

    Args:
        dataloader: Yields ``(images, labels, hvs)`` batches. Only ``images``
            is used; ``dataloader.dataset[0][0].shape[-1]`` fixes the feature
            size. Inputs are mapped with ``x * 2 - 1`` before noising
            (presumably from [0, 1] to [-1, 1] -- confirm against the dataset).
        model: Noise predictor called as ``model(x_t, t)``. A fresh
            ``Model_unconditional`` is created when ``None``.
        diffusion: Object providing ``sample_timesteps`` and ``noise_images``.
            A fresh ``Diffusion`` is created when ``None``.
        ema: EMA helper exposing ``step_ema(ema_model, model)``; defaults to
            ``EMA(0.99)``.
        epochs: Number of passes over ``dataloader``.
        lr: AdamW learning rate.

    Returns:
        ``(model, diffusion)``. NOTE(review): the EMA copy is updated on every
        step but is neither returned nor stored -- confirm whether callers
        expect the averaged weights.
    """
    X_size = dataloader.dataset[0][0].shape[-1]
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    if model is None:
        # Pass the device explicitly: the model's positional encoding otherwise
        # falls back to its "cuda" default and fails on CPU-only hosts.
        model = Model_unconditional(dim=X_size, device=device).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    if diffusion is None:
        diffusion = Diffusion(img_size=X_size, device=device)
    if ema is None:
        ema = EMA(0.99)
    ema_model = copy.deepcopy(model).eval().requires_grad_(False)

    for epoch in range(epochs):
        loss_epoch = []
        logging.info(f"Starting epoch {epoch}:")
        pbar = tqdm(dataloader)
        for images, _labels, _hvs in pbar:
            images = images.to(device) * 2 - 1
            t = diffusion.sample_timesteps(images.shape[0]).to(device)
            x_t, noise = diffusion.noise_images(images, t)
            predicted_noise = model(x_t, t)
            loss = ((noise - predicted_noise).pow(2).sum(-1)).mean()
            loss_value = loss.item()
            loss_epoch.append(loss_value)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()
            ema.step_ema(ema_model, model)

            pbar.set_postfix(MSE=loss_value)
        logging.info(f"Epoch {epoch} loss: {np.mean(loss_epoch)}")
    return model, diffusion

def _preference_batch_loss(model, diffusion, loss_fn, batch, device):
    """Return the mean cross-entropy of ``model`` on one noised ``(x_1, x_2, y)`` batch."""
    x_1, x_2, y = batch
    x_1 = x_1.to(device) * 2 - 1
    x_2 = x_2.to(device) * 2 - 1
    y = y.to(device)
    # Both candidates are noised at the same timestep, each with its own noise.
    t = diffusion.sample_timesteps(x_1.shape[0]).to(device)
    x_1, _ = diffusion.noise_images(x_1, t)
    x_2, _ = diffusion.noise_images(x_2, t)
    pred = model(x_1, x_2, t)
    # reshape(-1) rather than squeeze(): squeeze() collapses a batch of one
    # label to a 0-d tensor, which no longer matches the (1, C) logits.
    return torch.mean(loss_fn(pred, y.reshape(-1).long()))


def train_preference(dataloader, model=None, diffusion = None, val_loader=None, config=None, tolerance = 50, model_save_path=None, three_dim_out=False, epochs=2000, lr=1e-5):
    """Train a pairwise preference classifier on noised sample pairs.

    Each batch ``(x_1, x_2, y)`` is mapped with ``x * 2 - 1``, noised at a
    random timestep via ``diffusion.noise_images`` and classified by
    ``model(x_1, x_2, t)``; ``y`` holds the target class index per pair.

    Args:
        dataloader: Training loader yielding ``(x_1, x_2, y)``.
        model: Preference network; a ``Preference_model`` is built when ``None``.
        diffusion: Noising helper; a ``Diffusion`` is built when ``None``.
        val_loader: Optional validation loader with the same batch layout.
            When given, validation runs every 5th epoch (5 passes, each with
            freshly sampled timesteps and noise) and drives checkpointing and
            early stopping.
        config: Unused; kept for interface compatibility.
        tolerance: Training stops once more than this many validation rounds
            in a row fail to improve the best validation loss.
        model_save_path: Where the best-so-far weights are written with
            ``save_model``; nothing is saved when ``None``.
        three_dim_out: Forwarded to ``Preference_model`` when it is built here.
        epochs: Maximum number of training epochs.
        lr: Adam learning rate.

    Returns:
        The model in its final training state (not reloaded from the best
        checkpoint).
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
    if model is None:
        model = Preference_model(input_dim=dataloader.dataset[0][0].shape[-1], device=device, three_dim_out=three_dim_out).to(device)
    if diffusion is None:
        # Diffusion is a plain class without a ``.to`` method; it places its
        # schedule tensors on ``device`` in its constructor.
        diffusion = Diffusion(img_size=dataloader.dataset[0][0].shape[-1], device=device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    best_val_loss = float("inf")
    curr_tol = 0
    for epoch in range(epochs):
        loss_epoch = []
        logging.info(f"Starting epoch {epoch}:")
        pbar = tqdm(dataloader)
        for batch in pbar:
            loss = _preference_batch_loss(model, diffusion, loss_fn, batch, device)
            # Clear gradients from the previous step; without this they
            # accumulate across every batch of the run.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_value = loss.item()
            pbar.set_postfix(MSE=loss_value)
            loss_epoch.append(loss_value)

        if val_loader is not None and epoch % 5 == 0:
            # Switch to eval mode only while validating, and always switch back.
            model.eval()
            val_loss = []
            with torch.no_grad():
                for _ in range(5):
                    for batch in val_loader:
                        val_loss.append(_preference_batch_loss(model, diffusion, loss_fn, batch, device).item())
            model.train()
            mean_val_loss = np.mean(val_loss)
            logging.info(f"Epoch {epoch} loss: {np.mean(loss_epoch)} val_loss: {mean_val_loss}")
            if mean_val_loss < best_val_loss:
                best_val_loss = mean_val_loss
                if model_save_path is not None:
                    save_model(model, model_save_path, device)
                curr_tol = 0
            else:
                curr_tol += 1
                if curr_tol > tolerance:
                    break
        logging.info(f"Epoch {epoch} loss: {np.mean(loss_epoch)}")
    return model

class Diffusion:
    """Linear-beta noising schedule plus a preference-guided reverse sampler.

    Works on 2-D batches: per-timestep coefficients are expanded with
    ``[:, None]`` and samples are drawn with shape ``(n, img_size)``.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, device="cuda"):
        """Store the hyper-parameters and precompute beta, alpha and cumulative alpha on ``device``."""
        self.img_size = img_size
        self.device = device
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end

        schedule = self.prepare_noise_schedule()
        self.beta = schedule.to(device)
        self.alpha = 1. - self.beta
        # alpha_hat[t] is the running product of alpha[0..t].
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self):
        """Return ``noise_steps`` betas evenly spaced from ``beta_start`` to ``beta_end``."""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        """Noise ``x`` at the integer timesteps ``t``.

        Returns:
            ``(x_t, eps)`` where ``eps`` is standard normal noise shaped like
            ``x`` and ``x_t = sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * eps``.
        """
        alpha_hat_t = self.alpha_hat[t]
        signal_scale = torch.sqrt(alpha_hat_t)[:, None]
        noise_scale = torch.sqrt(1 - alpha_hat_t)[:, None]
        eps = torch.randn_like(x)
        x_t = signal_scale * x + noise_scale * eps
        return x_t, eps

    def sample_timesteps(self, n):
        """Draw ``n`` integer timesteps uniformly from ``[1, noise_steps - 1]``."""
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample_with_preference(self, model, n, preference_model, best_x_data, cfg_scale=3, return_latents=False, ddim=False):
        """Run the reverse process for ``n`` samples, guided by ``preference_model``.

        Args:
            model: Noise predictor called as ``model(x, t)``; switched to eval
                mode for sampling and to train mode before returning.
            n: Number of samples to draw.
            preference_model: Module called as ``preference_model(x, ref, t)``;
                switched to eval mode and left that way.
            best_x_data: Initial reference points, one per sample (mapped over
                axis 0 together with the samples). Rescaled with ``2 * x - 1``.
            cfg_scale: Guidance weight; guidance is skipped when it is ``<= 0``.
            return_latents: Also return intermediate states (every 10th step).
            ddim: Use the noise-free update branch instead of the stochastic one.

        Returns:
            Samples mapped to ``[0, 1]``; with ``return_latents`` a tuple of the
            samples and a stacked tensor of the recorded latents, the final
            samples being the last entry.
        """

        def _log_prob_first(sample, reference, step, pref_model, pref_params, pref_buffers):
            # Scalar log-probability of class 0 for a single (sample, reference) pair.
            logits = functional_call(
                pref_model,
                (pref_params, pref_buffers),
                (sample.unsqueeze(0), reference.unsqueeze(0), step),
            )
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
            return log_probs[..., 0].squeeze()

        params = {name: p.detach() for name, p in preference_model.named_parameters()}
        buffers = {name: b.detach() for name, b in preference_model.named_buffers()}
        # Per-sample gradient of the class-0 log-probability w.r.t. the sample.
        guidance_grad = vmap(grad(_log_prob_first), (0, 0, 0, None, None, None))

        logging.info(f"Sampling {n} new images....")
        model.eval()
        preference_model.eval()
        latents = []

        x = torch.randn((n, self.img_size)).to(self.device)
        reference = 2 * best_x_data.to(x.dtype).to(self.device) - 1
        for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
            if return_latents and i % 10 == 0:
                latents.append(x)
            t = torch.full((n,), i, dtype=torch.long, device=self.device)
            with torch.no_grad():
                predicted_noise = model(x, t)

            score = 0
            if cfg_scale > 0:
                x_in = x.clone().detach().requires_grad_(True)
                score = guidance_grad(x_in, reference, t, preference_model, params, buffers).detach()
                # From here on the reference is the pre-update sample of this
                # step, not the caller-supplied data.
                # NOTE(review): confirm this replacement is intended.
                reference = x_in.detach()

            alpha = self.alpha[t][:, None]
            alpha_hat = self.alpha_hat[t][:, None]
            alpha_hat_t_1 = self.alpha_hat[t-1][:, None]
            beta = self.beta[t][:, None]
            # Fresh noise is only drawn on stochastic steps before the last one.
            noise = torch.randn_like(x) if (i > 1 and not ddim) else torch.zeros_like(x)

            if ddim:
                # NOTE(review): the sqrt(1 - alpha_hat_{t-1}) term multiplies the
                # guidance score only, not predicted_noise -- verify against the
                # intended DDIM formulation.
                rescaled = (torch.sqrt(alpha_hat_t_1) / torch.sqrt(alpha_hat)) * (x - (torch.sqrt(1 - alpha_hat) * predicted_noise))
                x = rescaled + torch.sqrt(1-alpha_hat_t_1) * cfg_scale * score
            else:
                posterior_mean = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise)
                x = posterior_mean + torch.sqrt(beta) * noise + cfg_scale * score * beta

            # The running state is only clamped every 100th step.
            if i % 100 == 0:
                x = x.clamp(-1, 1)

        model.train()
        x = (x.clamp(-1, 1) + 1) / 2
        latents = [(z.clamp(-1, 1) + 1) / 2 for z in latents]
        latents.append(x)
        if not return_latents:
            return x
        return x, torch.stack(latents)

class Model_unconditional(nn.Module):
    """MLP noise predictor: maps a noised vector and its timestep to a noise estimate.

    The timestep is turned into a sinusoidal encoding of width ``dim`` and
    added to the input before a three-layer MLP with ReLU and LayerNorm.
    """

    def __init__(self, dim=256, device="cuda", save_path=None):
        """
        Args:
            dim: Width of the input vectors and of the output.
            device: Stored on the instance for callers; the forward pass
                follows the device of its inputs instead.
            save_path: Stored on the instance; not used inside this class.
        """
        super().__init__()
        self.device = device
        self.dim = dim
        self.save_path = save_path
        self.mlp = nn.Sequential(
            nn.Linear(dim, 512),
            nn.ReLU(),
            nn.LayerNorm(512),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.LayerNorm(512),
            nn.Linear(512, dim),
            nn.LayerNorm(dim)
        )
        # Not used by forward(); kept so its parameters stay in the state_dict
        # and existing checkpoints keep loading.
        self.time_embed = nn.Linear(1, dim)

    def pos_encoding(self, t, dim):
        """Return a ``(len(t), dim)`` sinusoidal encoding of the timesteps ``t``.

        ``t`` has a trailing axis of size 1. The first ``dim // 2`` columns are
        sines, the next ``dim // 2`` cosines; an odd ``dim`` gets one trailing
        zero column.
        """
        half_dim = dim // 2
        # Build the frequencies on the input's device rather than self.device,
        # so a model that lives on CPU (or was moved after construction) works.
        # NOTE(review): the exponent is positive, so frequencies grow from 1
        # towards 10000; the common transformer encoding uses a negative
        # exponent. Left unchanged to keep trained checkpoints valid.
        freq = torch.exp(
            math.log(10000)
            * (torch.arange(0, half_dim, device=t.device).float() / half_dim))
        angles = t.repeat(1, half_dim) * freq.unsqueeze(0)
        pos_enc = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if dim % 2:
            # Pad from t rather than from pos_enc so dim == 1 (no sine/cosine
            # columns at all) still yields exactly one column.
            pad = torch.zeros_like(t.repeat(1, 1), dtype=pos_enc.dtype)
            pos_enc = torch.cat([pos_enc, pad], dim=-1)
        return pos_enc

    def forward(self, x, t):
        """Predict the noise in ``x`` (``(batch, dim)``) at integer timesteps ``t`` (``(batch,)``)."""
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.dim)
        output = self.mlp(x + t)
        return output
    
class Preference_model(nn.Module):
    """Pairwise preference classifier over two (noised) vectors and a timestep.

    Both inputs are stacked along a new axis 1, shifted by a learned embedding
    of the sinusoidal time encoding, and passed through shared MLPs. ``out_1``
    produces one logit per input, so logit 0 belongs to ``x_1`` and logit 1 to
    ``x_2``. With ``three_dim_out`` a third logit is appended, obtained by
    summing ``out_2`` over the pair.
    """

    def __init__(self, input_dim=256, device="cuda", save_path=None, three_dim_out = False):
        """
        Args:
            input_dim: Width of each input vector.
            device: Stored on the instance for callers; the forward pass
                follows the device of its inputs instead.
            save_path: Stored on the instance; not used inside this class.
            three_dim_out: Emit three logits per pair instead of two.
        """
        super().__init__()
        self.device = device
        self.input_dim = input_dim
        self.save_path = save_path
        self.three_dim_out = three_dim_out
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.LayerNorm(input_dim),)
        self.time_embed = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.LayerNorm(input_dim),
        )
        self.preference = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.LayerNorm(512),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.LayerNorm(512)
        )
        self.out_1 = nn.Linear(512, 1)
        if self.three_dim_out:
            self.out_2 = nn.Linear(512, 1)

    def pos_encoding(self, t, dim):
        """Return a ``(len(t), dim)`` sinusoidal encoding of the timesteps ``t``.

        ``t`` has a trailing axis of size 1. The first ``dim // 2`` columns are
        sines, the next ``dim // 2`` cosines; an odd ``dim`` gets one trailing
        zero column.
        """
        half_dim = dim // 2
        # Build the frequencies on the input's device rather than self.device,
        # so a model that lives on CPU (or was moved after construction) works.
        # NOTE(review): the exponent is positive, so frequencies grow from 1
        # towards 10000; the common transformer encoding uses a negative
        # exponent. Left unchanged to keep trained checkpoints valid.
        freq = torch.exp(
            math.log(10000)
            * (torch.arange(0, half_dim, device=t.device).float() / half_dim))
        angles = t.repeat(1, half_dim) * freq.unsqueeze(0)
        pos_enc = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if dim % 2 == 1:
            # Pad from t rather than from pos_enc so dim == 1 (no sine/cosine
            # columns at all) still yields exactly one column.
            pad = torch.zeros_like(t.repeat(1, 1), dtype=pos_enc.dtype)
            pos_enc = torch.cat([pos_enc, pad], dim=-1)
        return pos_enc

    def forward(self, x_1, x_2, t):
        """Return preference logits of shape ``(batch, 2)``, or ``(batch, 3)`` with ``three_dim_out``.

        Args:
            x_1, x_2: Candidate vectors of shape ``(batch, input_dim)``.
            t: Integer timesteps of shape ``(batch,)``.
        """
        t = t.unsqueeze(-1).type(torch.float)
        em = torch.stack([x_1, x_2], dim=1)
        # One time embedding per sample, broadcast over the pair axis.
        t = self.time_embed(self.pos_encoding(t, self.input_dim)).unsqueeze(-2)
        em = self.mlp(em + t)
        # Shared features for both heads; computed once instead of per head.
        features = self.preference(em + t)
        output = self.out_1(features)
        if self.three_dim_out:
            output_2 = self.out_2(features).sum(-2, keepdim=True)
            output = torch.cat([output, output_2], dim=-2)
        output = output.squeeze(-1)
        return output

def save_model(model, save_path, device="cuda"):
    """Write the model's weights to ``save_path`` and move the model to ``device``.

    The model is first moved to the CPU, so the saved state_dict holds CPU
    tensors; afterwards it is moved to ``device``.
    """
    cpu_weights = model.to("cpu").state_dict()
    torch.save(cpu_weights, save_path)
    model.to(device)

def load_model(model, save_path, device="cuda"):
    """Load weights from ``save_path`` into ``model`` and move it to ``device``.

    Args:
        model: Module whose parameters match the stored state_dict.
        save_path: Path or file-like object accepted by ``torch.load``.
        device: Device the model is moved to after loading.
    """
    # Map every stored tensor to the CPU first so that a checkpoint written
    # from a GPU still loads on a CPU-only host; the model is moved to
    # ``device`` afterwards anyway.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(save_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.to(device)
    print(f"Successfully load trained model from {save_path}")