import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
import cv2
import dlib
import numpy as np
import pickle
import json
from PIL import Image
from omegaconf import OmegaConf
from scipy.spatial import ConvexHull
from imutils import face_utils
from einops import rearrange

# Locate the vendored DiffSwap checkout relative to this file:
#   <project_root>/deepfake_generators/DiffSwap
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
diffswap_root = os.path.join(project_root, "deepfake_generators", "DiffSwap")

# The DiffSwap sources are imported below as the top-level package ``ldm``,
# so the checkout (and its ``ldm`` directory) must be importable. Entries are
# prepended so they win over any other ``ldm`` on the path, and are only added
# once to keep repeated imports of this module idempotent.
for p in [diffswap_root, os.path.join(diffswap_root, "ldm")]:
    if p not in sys.path:
        sys.path.insert(0, p)

try:
    from ldm.util import instantiate_from_config
    from ldm.models.diffusion.ddim import DDIMSampler
except ImportError as e:
    # Chain explicitly so the traceback reports the original failure as the
    # direct cause of this more descriptive error.
    raise ImportError(
        f"Failed to import DiffSwap modules from {diffswap_root}. "
        f"Make sure the DiffSwap code exists. Original error: {e}"
    ) from e

class DiffSwapWrapper(nn.Module):
    """Expose a pretrained DiffSwap model through an encode/decode/forward API.

    On construction the wrapper instantiates the model described by an
    OmegaConf config, restores a checkpoint, freezes every parameter and
    builds a ``DDIMSampler`` for it. dlib is used to locate 68-point facial
    landmarks, from which convex-hull masks are rasterised and handed to the
    model as conditioning inputs.

    Typical use is ``encode(source, ref=target)`` followed by ``decode(...)``,
    or simply ``forward(source, ref=target)``. ``encode`` stores everything
    ``decode`` needs on the instance, so the two calls are stateful and must
    be paired.

    NOTE(review): landmark and mask preparation squeeze the batch axis and
    process a single image, so only a batch size of 1 appears to be supported.
    """

    def __init__(
        self,
        device='cuda',
        checkpoint_path=None,
        config_path=None,
        ddim_steps=50,
        ddim_eta=0.0,
        tgt_scale=0.01,
        gradient_mode=True,
    ):
        """Load the DiffSwap model, its sampler and the dlib landmark tools.

        Args:
            device: Device the model and all prepared tensors are moved to.
            checkpoint_path: Model checkpoint; defaults to
                ``<diffswap_root>/checkpoints/diffswap.pth``.
            config_path: OmegaConf YAML describing the model; defaults to
                ``<diffswap_root>/configs/diffswap/default-project.yaml``.
            ddim_steps: Number of DDIM steps passed to the sampler in
                ``decode``.
            ddim_eta: ``eta`` passed to the sampler in ``decode``.
            tgt_scale: Forwarded unchanged to ``DDIMSampler``.
            gradient_mode: When true, ``encode`` runs the source image through
                the first-stage encoder with autograd enabled and ``decode``
                attaches a surrogate gradient to its output.

        Raises:
            FileNotFoundError: If the checkpoint, the config or the dlib
                landmark predictor file does not exist.
        """
        super().__init__()
        self.device = device
        self.ddim_steps = ddim_steps
        self.ddim_eta = ddim_eta
        self.tgt_scale = tgt_scale
        self.gradient_mode = gradient_mode

        checkpoints_dir = os.path.join(diffswap_root, "checkpoints")
        if checkpoint_path is None:
            checkpoint_path = os.path.join(checkpoints_dir, "diffswap.pth")
        if config_path is None:
            config_path = os.path.join(diffswap_root, "configs", "diffswap", "default-project.yaml")

        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"DiffSwap checkpoint not found: {checkpoint_path}")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"DiffSwap config not found: {config_path}")

        landmark_path = os.path.join(checkpoints_dir, "shape_predictor_68_face_landmarks.dat")
        if not os.path.exists(landmark_path):
            raise FileNotFoundError(f"Landmark predictor not found: {landmark_path}")

        # The working directory is changed below, so user-supplied relative
        # paths must be pinned to the caller's cwd (the one the existence
        # checks above were made against) before that happens.
        checkpoint_path = os.path.abspath(checkpoint_path)
        config_path = os.path.abspath(config_path)

        self.detector = dlib.get_frontal_face_detector()
        self.landmark_predictor = dlib.shape_predictor(landmark_path)

        # Model construction runs from inside the DiffSwap checkout
        # (presumably because its config/code use paths relative to it);
        # the original cwd is always restored afterwards.
        original_cwd = os.getcwd()
        os.chdir(diffswap_root)

        try:
            print(f"Loading DiffSwap from config: {config_path}")
            config = OmegaConf.load(config_path)

            print(f"Loading DiffSwap checkpoint: {checkpoint_path}")
            self.model = instantiate_from_config(config.model)
            self.model.init_from_ckpt(checkpoint_path)
            self.model.to(device)
            self.model.eval()

            self.model.cond_stage_model.affine_crop = True
            self.model.cond_stage_model.swap = True

            # The wrapped model is used for inference only; freeze it.
            for param in self.model.parameters():
                param.requires_grad = False

            self.ddim_sampler = DDIMSampler(self.model, tgt_scale=tgt_scale)

            print("DiffSwap model loaded successfully")
            print(f"   DDIM steps: {ddim_steps}")
            print(f"   DDIM eta: {ddim_eta}")
            print(f"   Target scale: {tgt_scale}")
        finally:
            os.chdir(original_cwd)

        # State shared between set_target()/encode() and decode().
        self._cached_target_data = None
        self._cached_batch = None
        self._cached_z = None
        self._cached_c = None
        self._cached_mask = None
        self._cached_source_original = None

        # Index ranges into the 68-point landmark array, one entry per
        # facial region that gets its own mask.
        self.landmark_indices = {
            'l_eye': list(range(36, 42)),
            'r_eye': list(range(42, 48)),
            'nose': list(range(27, 36)),
            'mouth': list(range(48, 68)),
        }

    def _tensor_to_numpy_uint8(self, tensor):
        """Convert a CHW (or 1xCxHxW) tensor in [-1, 1] to an HWC uint8 array.

        Values are mapped linearly from [-1, 1] to [0, 255] and clipped.
        A 4-D input only works when its leading dimension is 1.
        """
        if tensor.dim() == 4:
            tensor = tensor.squeeze(0)
        img = tensor.permute(1, 2, 0).detach().cpu().numpy()
        img = ((img + 1) / 2 * 255).clip(0, 255).astype(np.uint8)
        return img

    def _detect_face_and_landmark(self, image_np):
        """Return 68 (x, y) landmarks, in pixels, for the largest detected face.

        Args:
            image_np: HWC RGB image. Non-uint8 input is multiplied by 255 and
                cast (assumes floats in [0, 1] -- TODO confirm with callers).

        Returns:
            ``float32`` array of shape (68, 2). When no face is detected a
            synthetic grid of points covering the central part of the image
            is returned instead, so that downstream mask construction can
            still proceed.
        """
        if image_np.dtype != np.uint8:
            image_np = (image_np * 255).astype(np.uint8)

        gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
        faces = self.detector(gray, 1)

        if len(faces) == 0:
            # Best-effort fallback: rows of 10 points laid out inside the
            # image centre, in the same order a loop over 0..67 would give.
            h, w = image_np.shape[:2]
            idx = np.arange(68)
            xs = w * (0.2 + 0.6 * (idx % 10) / 10)
            ys = h * (0.2 + 0.6 * (idx // 10) / 7)
            return np.stack([xs, ys], axis=1).astype(np.float32)

        # Largest bounding box wins; ties resolve to the first detection.
        face = max(faces, key=lambda f: (f.right() - f.left()) * (f.bottom() - f.top()))

        shape = self.landmark_predictor(image_np, face)
        landmarks = face_utils.shape_to_np(shape).astype(np.float32)

        return landmarks

    def _extract_convex_hull_mask(self, landmark, size=256):
        """Rasterise the convex hull of normalised landmarks into a mask.

        Args:
            landmark: (N, 2) array of points; multiplied by ``size`` to get
                pixel coordinates.
            size: Side length of the square output mask.

        Returns:
            ``float32`` array of shape (size, size), 1.0 inside the hull and
            0.0 elsewhere. If the hull cannot be computed or drawn, a filled
            circle centred in the mask is returned as a fallback.
        """
        landmark_scaled = landmark * size
        try:
            hull = ConvexHull(landmark_scaled)
            canvas = np.zeros((size, size), dtype=np.float32)
            points = landmark_scaled[hull.vertices].astype('int32')
            mask = cv2.fillPoly(canvas, pts=[points], color=1.0)
        except Exception:
            # Deliberate best-effort: degenerate point sets make the hull
            # computation fail. Unlike a bare ``except``, this still lets
            # KeyboardInterrupt and SystemExit propagate.
            mask = np.zeros((size, size), dtype=np.float32)
            cv2.circle(mask, (size // 2, size // 2), size // 3, 1.0, -1)
        return mask

    def _organ_masks(self, landmark_norm, size):
        """Stack one convex-hull mask per region listed in ``landmark_indices``."""
        return np.stack([
            self._extract_convex_hull_mask(landmark_norm[indices], size)
            for indices in self.landmark_indices.values()
        ])

    def _prepare_batch(self, source_tensor, target_tensor):
        """Build the conditioning dictionary consumed by the DiffSwap model.

        Landmarks are detected on both images and converted into a whole-face
        mask for the target and per-region masks for source and target.

        NOTE(review): landmarks are normalised by a fixed 256, so inputs are
        assumed to be 256x256 -- confirm for callers that skip
        ``preprocess``.

        Args:
            source_tensor: 1xCxHxW (or CxHxW) source image in [-1, 1].
            target_tensor: 1xCxHxW (or CxHxW) target image in [-1, 1].

        Returns:
            Dict of tensors on ``self.device``, each with a leading batch
            axis of 1: ``image`` / ``image_src`` (HWC, detached),
            ``landmark``, ``mask``, ``mask_organ``, ``mask_organ_src`` and
            identity ``affine_theta`` / ``affine_theta_src``.
        """
        size = 256

        source_np = self._tensor_to_numpy_uint8(source_tensor)
        target_np = self._tensor_to_numpy_uint8(target_tensor)

        source_landmark_norm = self._detect_face_and_landmark(source_np) / size
        target_landmark_norm = self._detect_face_and_landmark(target_np) / size

        mask_organ_src = self._organ_masks(source_landmark_norm, size)
        target_mask = self._extract_convex_hull_mask(target_landmark_norm, size)
        mask_organ = self._organ_masks(target_landmark_norm, size)

        identity_affine = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=np.float32)

        batch = {
            # Images are stored channels-last and cut from the autograd graph.
            'image': target_tensor.squeeze(0).permute(1, 2, 0).detach().cpu(),
            'image_src': source_tensor.squeeze(0).permute(1, 2, 0).detach().cpu(),
            'landmark': torch.tensor(target_landmark_norm, dtype=torch.float32),
            'mask': torch.tensor(target_mask, dtype=torch.float32),
            'mask_organ': torch.tensor(mask_organ, dtype=torch.float32),
            'mask_organ_src': torch.tensor(mask_organ_src, dtype=torch.float32),
            'affine_theta': torch.tensor(identity_affine, dtype=torch.float32),
            'affine_theta_src': torch.tensor(identity_affine, dtype=torch.float32),
        }

        # Add the batch axis and move everything to the model's device.
        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].unsqueeze(0).to(self.device)

        return batch

    def set_target(self, target_img):
        """Cache a target image to be used by later ``encode`` calls.

        The image is stored as given (batched if needed, detached, on
        ``self.device``); it is not passed through ``preprocess``.
        """
        if target_img.dim() == 3:
            target_img = target_img.unsqueeze(0)
        self._cached_target_data = target_img.clone().detach().to(self.device)
        print(" Target image cached")

    def preprocess(self, x):
        """Bring an image batch to the [-1, 1], 256x256 form used by the model.

        An unbatched (C, H, W) image gets a leading batch axis. A tensor whose
        values all lie in [0, 1] is treated as a [0, 1] image and rescaled to
        [-1, 1]; this is a heuristic, so a [-1, 1] image without negative
        values is rescaled as well. Spatial size is changed with bilinear
        interpolation when it is not already 256x256.

        Returns:
            Tensor of shape (B, C, 256, 256) clamped to [-1, 1].
        """
        if x.dim() == 3:
            # encode() and set_target() accept unbatched images; do the same
            # here so forward(..., preprocess=True) does too.
            x = x.unsqueeze(0)
        if x.min() >= 0 and x.max() <= 1:
            x = x * 2.0 - 1.0
        if x.shape[2] != 256 or x.shape[3] != 256:
            x = F.interpolate(x, size=(256, 256), mode='bilinear', align_corners=False)
        return x.clamp(-1, 1)

    def encode(self, x, ref=None):
        """Encode the source image and cache everything ``decode`` needs.

        Args:
            x: Source image, (C, H, W) or (1, C, H, W), in [-1, 1].
            ref: Optional target image. When given it also replaces the
                cached target; otherwise the target cached by a previous
                call or by ``set_target`` is used.

        Returns:
            The first-stage latent of the source image. In ``gradient_mode``
            it is computed with autograd enabled, so it is differentiable
            with respect to ``x`` when ``x`` requires grad.

        Raises:
            ValueError: If no target was supplied and none is cached.
        """
        if ref is not None:
            target = ref.to(self.device)
            self._cached_target_data = target.clone().detach()
        elif self._cached_target_data is not None:
            target = self._cached_target_data
        else:
            raise ValueError("No target image provided.")

        if x.dim() == 3:
            x = x.unsqueeze(0)
        if target.dim() == 3:
            target = target.unsqueeze(0)

        # Kept (with its autograd history) for the surrogate gradient that
        # decode() attaches in gradient_mode.
        self._cached_source_original = x

        batch = self._prepare_batch(x.detach(), target.detach())
        self._cached_batch = batch

        with torch.no_grad():
            x_target = rearrange(batch['image'], 'b h w c -> b c h w')
            encoder_posterior = self.model.encode_first_stage(x_target)
            z = self.model.get_first_stage_encoding(encoder_posterior).detach()
        batch['z'] = z

        if self.gradient_mode:
            # Encode ``x`` itself, not batch['image_src']: the batch copy
            # holds the same pixels but was detached in _prepare_batch, which
            # would make the autograd-enabled path below pointless.
            x_source_bchw = x.to(self.device)
            encoder_posterior_src = self.model.first_stage_model.encode(x_source_bchw)
            if hasattr(encoder_posterior_src, 'sample'):
                z_src = encoder_posterior_src.sample()
            else:
                z_src = encoder_posterior_src
            z_src = self.model.scale_factor * z_src
        else:
            x_source_bchw = rearrange(batch['image_src'], 'b h w c -> b c h w')
            with torch.no_grad():
                encoder_posterior_src = self.model.encode_first_stage(x_source_bchw)
                z_src = self.model.get_first_stage_encoding(encoder_posterior_src)

        batch['z_src'] = z_src

        with torch.no_grad():
            c = self.model.get_learned_conditioning(batch)

        # Invert the target face mask (1 outside the landmark hull), bring it
        # to latent resolution and make it strictly binary.
        h, w = z.shape[2], z.shape[3]
        mask = (1 - batch['mask'].float())[:, None]
        mask = F.interpolate(mask, size=(h, w), mode='nearest')
        mask[mask > 0] = 1
        mask[mask <= 0] = 0

        self._cached_z = z
        self._cached_c = c
        self._cached_mask = mask

        return z_src

    def decode(self, latent, ref=None):
        """Run DDIM sampling from the state cached by ``encode`` and decode it.

        Args:
            latent: Currently unused; sampling relies only on the cached
                target latent, conditioning and mask.
            ref: Currently unused.

        Returns:
            Generated image batch clamped to [-1, 1].

        Raises:
            ValueError: If ``encode`` has not been called beforehand.
        """
        if self._cached_z is None or self._cached_c is None or self._cached_mask is None:
            raise ValueError("Must call encode() first")

        z = self._cached_z
        c = self._cached_c
        mask = self._cached_mask

        N = z.shape[0]
        shape = (self.model.channels, self.model.image_size, self.model.image_size)

        with torch.no_grad():
            samples, _ = self.ddim_sampler.sample(
                self.ddim_steps,
                N,
                shape,
                c,
                eta=self.ddim_eta,
                x0=z[:N],
                mask=mask,
                verbose=False
            )
            x_samples = self.model.decode_first_stage(samples.to(self.device))

        if self.gradient_mode and self._cached_source_original is not None:
            source_original = self._cached_source_original
            if source_original.requires_grad:
                # Sampling ran without autograd, so attach a surrogate path:
                # the added term is exactly zero in value but makes the output
                # depend on the per-channel spatial mean of the source. This
                # is a stand-in, not the true gradient of the sampler.
                source_proxy = source_original.mean(dim=[2, 3], keepdim=True)
                source_proxy = source_proxy.expand_as(x_samples)
                x_samples = x_samples.detach() + (source_proxy - source_proxy.detach())

        x_samples = torch.clamp(x_samples, min=-1.0, max=1.0)
        return x_samples

    def forward(self, x, ref=None, preprocess=True):
        """Swap in one call: optionally preprocess, then encode and decode.

        Args:
            x: Source image.
            ref: Optional target image; when omitted the cached target is
                used as stored (it is not preprocessed here).
            preprocess: Whether to pass ``x`` (and ``ref`` if given) through
                ``preprocess`` first.

        Returns:
            The output of ``decode``.
        """
        if preprocess:
            x = self.preprocess(x)
            if ref is not None:
                ref = self.preprocess(ref)

        latent = self.encode(x, ref=ref)
        return self.decode(latent)

    def get_id_embedding(self, x):
        """Return the L2-normalised identity embedding of an image batch.

        The image is resized to 112x112 (bicubic) and passed to the
        conditioning model's ``encode_face``. The call runs under
        ``torch.no_grad()``, so the embedding carries no gradient.
        """
        if x.dim() == 3:
            x = x.unsqueeze(0)

        x_resized = F.interpolate(x, size=(112, 112), mode='bicubic', align_corners=False)

        with torch.no_grad():
            id_feat = self.model.cond_stage_model.encode_face(x_resized)

        return F.normalize(id_feat, p=2, dim=1)

    def compute_id_loss(self, x1, x2):
        """Return ``1 - mean cosine similarity`` of the two identity embeddings.

        NOTE(review): the embeddings come from ``get_id_embedding``, which
        disables autograd, so this value cannot be back-propagated -- confirm
        that it is only used as a metric.
        """
        id1 = self.get_id_embedding(x1)
        id2 = self.get_id_embedding(x2)
        cosine_sim = F.cosine_similarity(id1, id2, dim=1)
        return 1 - cosine_sim.mean()


def test_diffswap_wrapper():
    """Smoke-test the wrapper end to end on random images.

    Exercises three paths and prints what it observes: a full forward pass
    without autograd, a separate encode/decode round trip with a cached
    target, and a backward pass from the decoded output to a source tensor
    that requires grad.
    """
    banner = "=" * 50
    print(banner)
    print("Testing DiffSwap Wrapper")
    print(banner)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    wrapper = DiffSwapWrapper(device=device, ddim_steps=20)

    # Random noise clamped into the [-1, 1] image range.
    image_shape = (1, 3, 256, 256)
    source = torch.randn(*image_shape).to(device).clamp(-1, 1)
    target = torch.randn(*image_shape).to(device).clamp(-1, 1)

    # 1) Plain inference through forward().
    with torch.no_grad():
        output = wrapper(source, ref=target, preprocess=False)
    print(f"Output shape: {output.shape}")
    lowest, highest = output.min(), output.max()
    print(f"Output range: [{lowest:.2f}, {highest:.2f}]")

    # 2) Explicit encode/decode with the target cached up front.
    wrapper.set_target(target)
    latent = wrapper.encode(source, ref=target)
    print(f"Latent shape: {latent.shape}")
    output = wrapper.decode(latent)
    print(f"Output shape: {output.shape}")

    # 3) Check that a gradient reaches a source tensor that requires grad.
    source_grad = source.clone().detach().requires_grad_(True)
    latent = wrapper.encode(source_grad, ref=target)
    output = wrapper.decode(latent)
    output.sum().backward()

    grad = source_grad.grad
    if grad is None:
        print(" No gradient computed")
    else:
        print(f" Gradient computed! Mean |grad|: {grad.abs().mean():.6f}")
    

# Run the smoke test only when this file is executed directly as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    test_diffswap_wrapper()
