import os
import torch
import random
import lpips
import mediapipe as mp
import torch.nn.functional as F
from tqdm import tqdm
from torchmetrics.image import StructuralSimilarityIndexMeasure

from utils.loadconfig import load_config
from utils.get_lmk import extract_mouth_landmarks

from utils.dataset import GaussianDataset
from utils.dataloaderx import DataLoaderX
from models.mesh_head import MeshHeadModule
from models.gaussian_head import GaussianHeadModule
from modules.superresolution import SuperResolutionModule
from modules.vgg import VGGFeatureExtractor
from modules.discriminator import Discriminator
from utils.logger import GaussianHeadTrainRecorder


def random_crop(render_images, images, visibles, scale_factor, resolution_coarse, resolution_fine, device):
    """Jointly rescale a coarse render and its fine-resolution targets, then crop or pad them.

    All three inputs are resized by ``scale_factor`` with ``F.interpolate`` (default
    ``'nearest'`` mode). One random offset is drawn at coarse resolution and mapped to fine
    resolution, so the render and the targets stay aligned.

    * ``scale_factor < 1``: each scaled tensor is pasted at the offset onto a square canvas of
      ones of side ``resolution_coarse`` (render) or ``resolution_fine`` (images, visibles).
      The padded area of ``visibles`` is therefore 1 as well.
    * otherwise: a square window of side ``resolution_coarse`` / ``resolution_fine`` is cut
      out of the scaled tensors at the offset.

    Returns:
        Tuple ``(render_images, images, visibles)`` after the crop/pad.
    """
    def to_fine(coord):
        # Keep the exact evaluation order (coord * fine) / coarse before truncating.
        return int(coord * resolution_fine / resolution_coarse)

    def paste(src, size, top, left):
        canvas = torch.ones([src.shape[0], src.shape[1], size, size], device=device)
        canvas[:, :, top:top + src.shape[2], left:left + src.shape[3]] = src
        return canvas

    def window(src, size, top, left):
        return src[:, :, top:top + size, left:left + size]

    render_s, images_s, visibles_s = (
        F.interpolate(t, scale_factor=scale_factor) for t in (render_images, images, visibles)
    )

    if scale_factor < 1:
        # Shrunk: random placement inside the fixed-size canvas.
        top = random.randint(0, resolution_coarse - render_s.shape[2])
        left = random.randint(0, resolution_coarse - render_s.shape[3])
        top_f, left_f = to_fine(top), to_fine(left)
        return (
            paste(render_s, resolution_coarse, top, left),
            paste(images_s, resolution_fine, top_f, left_f),
            paste(visibles_s, resolution_fine, top_f, left_f),
        )

    # Enlarged (or unchanged): random fixed-size window inside the scaled tensors.
    top = random.randint(0, render_s.shape[2] - resolution_coarse)
    left = random.randint(0, render_s.shape[3] - resolution_coarse)
    top_f, left_f = to_fine(top), to_fine(left)
    return (
        window(render_s, resolution_coarse, top, left),
        window(images_s, resolution_fine, top_f, left_f),
        window(visibles_s, resolution_fine, top_f, left_f),
    )

def create_mouth_mask(landmarks, image_shape, padding=0.2):
    """Build one rectangular mask per sample around all of that sample's landmarks.

    Args:
        landmarks: tensor indexed as ``landmarks[b][:, 0]`` / ``landmarks[b][:, 1]`` holding
            normalized x / y coordinates (they are scaled by W and H here).
        image_shape: shape-like whose entries 2 and 3 are H and W (e.g. ``(B, C, H, W)``).
        padding: fractional margin added on every side of the landmark bounding box.

    Returns:
        Float tensor of shape (B, H, W) on ``landmarks.device``. It has no channel axis.
        Each mask is 1.0 inside the padded box and 0.0 elsewhere; the box is clamped to the
        image and its right/bottom edge is exclusive (integer slicing).
    """
    h, w = image_shape[2], image_shape[3]
    device = landmarks.device
    boxes = []
    for lm in landmarks:
        xs = lm[:, 0] * w
        ys = lm[:, 1] * h
        x_lo, x_hi = xs.min(), xs.max()
        y_lo, y_hi = ys.min(), ys.max()
        dx = (x_hi - x_lo) * padding
        dy = (y_hi - y_lo) * padding

        x1 = int(torch.clamp(x_lo - dx, 0, w - 1))
        x2 = int(torch.clamp(x_hi + dx, 0, w - 1))
        y1 = int(torch.clamp(y_lo - dy, 0, h - 1))
        y2 = int(torch.clamp(y_hi + dy, 0, h - 1))

        box = torch.zeros((1, h, w), device=device)
        box[:, y1:y2, x1:x2] = 1.0
        boxes.append(box)
    return torch.cat(boxes, dim=0)


def compute_gradient_penalty(D, real_samples, fake_samples):
    """Return the WGAN-GP gradient penalty ``mean((||dD(x)/dx||_2 - 1) ** 2)``.

    ``x`` is a random per-sample convex combination of ``real_samples`` and ``fake_samples``.
    There is one mixing weight per sample with shape (N, 1, 1, 1), i.e. sized for 4-D inputs.
    The gradient is taken with ``create_graph=True`` so the penalty itself can be
    back-propagated.
    """
    device = real_samples.device
    # Sampled on the default device first, then moved (kept as-is to preserve the RNG stream).
    alpha = torch.rand(real_samples.size(0), 1, 1, 1).to(device)
    mixed = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    critic_out = D(mixed)

    (grad_wrt_mixed,) = torch.autograd.grad(
        outputs=critic_out,
        inputs=mixed,
        grad_outputs=torch.ones(critic_out.size()).to(device),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )

    per_sample_norm = grad_wrt_mixed.view(grad_wrt_mixed.size(0), -1).norm(2, dim=1)
    return ((per_sample_norm - 1) ** 2).mean()

def gram_matrix(x):
    """Return the per-sample Gram matrix of a (B, C, H, W) feature map.

    Output shape is (B, C, C): channel-by-channel inner products over all spatial positions,
    divided by ``C * H * W``.
    """
    b, c, h, w = x.size()
    flat = x.view(b, c, h * w)
    return (flat @ flat.transpose(1, 2)) / (c * h * w)

def get_bbox_mask(landmarks, mask_shape, region='mouth', padding=0.2):
    """Build per-sample rectangular masks around the mouth or eye landmarks.

    Args:
        landmarks: per-sample 2-D landmarks indexed as ``landmarks[b][idx]``, with normalized
            x / y in columns 0 / 1 (they are multiplied by W and H).
        mask_shape: 4-tuple ``(B, _, H, W)``; only B, H and W are used.
        region: ``'mouth'`` or ``'eye'``; anything else raises ``ValueError``.
        padding: fractional margin added on each side of the landmark bounding box.
            NOTE: for ``region='mouth'`` this argument is ignored and 0.35 is always used.

    Returns:
        Float tensor of shape (B, 1, H, W) on ``landmarks.device``. It is 1.0 inside each
        clamped, padded box and 0.0 elsewhere; the right/bottom edge is exclusive.
    """
    B, _, H, W = mask_shape

    # Landmark ids as used by the original lists, with duplicates removed
    # (only the min/max over the selected points matter).
    mouth_idx = [13, 14, 17, 61, 78, 80, 81, 82, 84, 87, 88, 91, 95, 146, 178, 181, 191,
                 291, 308, 310, 311, 312, 314, 317, 318, 321, 324, 375, 402, 405, 415]
    eye_idx = [33, 133, 144, 153, 158, 160, 263, 362, 373, 374,
               380, 381, 382, 384, 385, 386, 387, 398]

    if region == 'mouth':
        region_idx, pad = mouth_idx, 0.35
    elif region == 'eye':
        region_idx, pad = eye_idx, padding
    else:
        raise ValueError('Unknown region')

    out = []
    for b in range(B):
        pts = landmarks[b][region_idx]
        xs, ys = pts[:, 0] * W, pts[:, 1] * H
        x_lo, x_hi = xs.min(), xs.max()
        y_lo, y_hi = ys.min(), ys.max()
        mx = (x_hi - x_lo) * pad
        my = (y_hi - y_lo) * pad

        top = int(max(y_lo - my, 0))
        bottom = int(min(y_hi + my, H - 1))
        lft = int(max(x_lo - mx, 0))
        rgt = int(min(x_hi + mx, W - 1))

        m = torch.zeros((1, H, W), device=landmarks.device)
        m[:, top:bottom, lft:rgt] = 1.0
        out.append(m)
    return torch.stack(out, dim=0)

def extract_landmarks_batch(images, face_mesh, num_landmarks=468):
    """Run a face-mesh detector on every image of a batch and return 2-D landmarks.

    Args:
        images: tensor laid out as (B, C, H, W); it is permuted to (B, H, W, C) before being
            handed to ``face_mesh``. Values are expected in [0, 1]. They are clamped to that
            range before conversion to uint8, so out-of-range values saturate instead of
            wrapping around. Tensors that require grad are accepted (they are detached).
        face_mesh: object whose ``process(image)`` returns a result with a
            ``multi_face_landmarks`` attribute (e.g. a MediaPipe ``FaceMesh``).
            NOTE(review): presumably expects RGB uint8 frames; confirm the channel order of
            ``images``.
        num_landmarks: expected landmark count per face (previously hard-coded as 468).

    Returns:
        float32 tensor of shape (B, num_landmarks, 2) on ``images.device``. It holds the
        (x, y) of each landmark of the first detected face. Images with no detection, or
        whose landmark count differs from ``num_landmarks``, get all-zero rows.
    """
    import numpy as np

    # Clamp BEFORE scaling/casting: float -> uint8 casts of values outside [0, 255] do not
    # saturate, so e.g. 1.01 * 255 would not become 255.
    frames = (images.detach().clamp(0.0, 1.0).permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)

    landmarks_batch = []
    for frame in frames:
        # Defensive: `frame` comes from a permuted tensor and may not be C-contiguous.
        results = face_mesh.process(np.ascontiguousarray(frame))
        lm_xy = np.zeros((num_landmarks, 2), dtype=np.float32)
        if results.multi_face_landmarks:
            points = np.array(
                [[lm.x, lm.y] for lm in results.multi_face_landmarks[0].landmark],
                dtype=np.float32,
            )
            if points.shape == (num_landmarks, 2):
                lm_xy = points
        landmarks_batch.append(torch.from_numpy(lm_xy))
    return torch.stack(landmarks_batch, dim=0).to(images.device)

def _mouth_feature_regularizer(model):
    """Return the mouth/teeth regularizer used by ``train`` (the float 0.0 when not applicable).

    Only active when ``model`` exposes ``mouth_points_indices``:
    * penalizes ``mean(exp(scales[mouth]))`` above 0.015 (weight 500);
    * if ``model`` also has ``teeth_indices``, adds 100 * MSE between the mean teeth feature
      and the constant 0.85.
    """
    if not hasattr(model, 'mouth_points_indices'):
        return 0.0
    # Scales are exponentiated here, so they are presumably stored as log-scales.
    mouth_scales = torch.exp(model.scales[model.mouth_points_indices])
    loss = F.relu(mouth_scales.mean() - 0.015) * 500.0
    if hasattr(model, 'teeth_indices'):
        teeth_mean = model.feature[model.teeth_indices].mean()
        teeth_whiteness_target = torch.ones_like(teeth_mean) * 0.85
        loss = loss + F.mse_loss(teeth_mean, teeth_whiteness_target) * 100.0
    return loss


def _patch_losses(gen_patch, gt_patch, fn_lpips, ssim_metric, feature_extractor, discriminator):
    """Compute the patch-level losses of ``train`` for a generated / ground-truth patch pair.

    Returns:
        ``(loss_vgg, loss_ssim, loss_feature, d_loss)``:
        * LPIPS (``normalize=True`` tells LPIPS the inputs are in [0, 1]);
        * ``1 - SSIM``;
        * the summed Gram-matrix MSE over the feature extractor's outputs;
        * the WGAN-GP critic loss ``-mean(D(real)) + mean(D(fake)) + 10 * GP``.
    """
    loss_vgg = fn_lpips(gen_patch, gt_patch, normalize=True).mean()
    loss_ssim = 1.0 - ssim_metric(gen_patch, gt_patch)

    loss_feature = 0
    for gen_feat, target_feat in zip(feature_extractor(gen_patch), feature_extractor(gt_patch)):
        loss_feature += F.mse_loss(gram_matrix(gen_feat), gram_matrix(target_feat))

    real_loss = -torch.mean(discriminator(gt_patch))
    fake_loss = torch.mean(discriminator(gen_patch))
    gradient_penalty = compute_gradient_penalty(discriminator, gt_patch, gen_patch)
    d_loss = real_loss + fake_loss + 10 * gradient_penalty
    return loss_vgg, loss_ssim, loss_feature, d_loss


def _boost_mouth_gradients(model):
    """Scale up, in place, the gradients of the mouth/teeth Gaussians after ``backward()``.

    Does nothing unless ``model`` has ``mouth_points_indices`` and ``model.xyz`` received a
    gradient. Then:
    * mouth xyz/feature gradients are doubled;
    * positive mouth scale gradients are doubled;
    * teeth opacity gradients are multiplied by 1.5.
    """
    if not hasattr(model, 'mouth_points_indices') or model.xyz.grad is None:
        return
    mouth_idx = model.mouth_points_indices

    model.xyz.grad[mouth_idx] *= 2.0
    if model.feature.grad is not None:
        model.feature.grad[mouth_idx] *= 2.0
    if model.scales.grad is not None:
        # The previous form `grad[idx][mask] *= 2.0` rescaled a temporary copy whenever
        # `mouth_idx` is an index tensor/list (advanced indexing returns a copy), leaving the
        # real gradient untouched. Gather, rescale and write back explicitly instead.
        scale_grad = model.scales.grad[mouth_idx]
        model.scales.grad[mouth_idx] = torch.where(scale_grad > 0, scale_grad * 2.0, scale_grad)
    if hasattr(model, 'teeth_indices') and hasattr(model, 'opacity') and model.opacity.grad is not None:
        model.opacity.grad[model.teeth_indices] *= 1.5


def train(dataloader, delta_poses, model, supres, optimizer, logger, device, epochs):
    """Jointly train the Gaussian head and the super-resolution network.

    Each iteration renders the batch at coarse resolution. It then rescales and randomly
    crops the render and the fine-resolution targets together (``random_crop``) and
    super-resolves the crop. The optimized loss is a weighted sum of:
    * coarse and fine pixel L1;
    * LPIPS, SSIM, Gram-matrix feature and critic terms;
    * landmark, mouth/eye-region and mouth regularization terms.
    Mouth-point gradients are amplified before each optimizer step.

    Args:
        dataloader: iterable of dicts whose values all support ``.to(device)``.
        delta_poses: per-expression pose offsets, indexed by ``data['exp_id']``.
        model: Gaussian head, called as ``model(data, resolution_coarse)``; its returned dict
            must contain ``'render_images'``.
        supres: super-resolution network applied to the cropped coarse render.
        optimizer: optimizer over the trainable parameters, stepped once per batch.
        logger: object whose ``log(dict)`` receives the per-iteration state and losses.
        device: device every batch tensor is moved to.
        epochs: number of passes over ``dataloader``.
    """
    model.train()
    fn_lpips = lpips.LPIPS(net='vgg').to(device)
    feature_extractor = VGGFeatureExtractor().to(device)
    discriminator = Discriminator().to(device)
    # The critic is local to this function and never added to `optimizer`, so its weights are
    # never updated. Freezing them stops every backward() from filling its .grad buffers,
    # which nothing ever zeroes, so they would otherwise accumulate forever. Gradients still
    # reach the generator through the critic's inputs, including the gradient-penalty term.
    discriminator.requires_grad_(False)
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
    patch_size = 512  # side of the random square patch fed to the patch-level losses

    mp_face_mesh = mp.solutions.face_mesh
    # static_image_mode=True: consecutive calls see unrelated images (generated and
    # ground-truth crops alternate, with a fresh random crop every batch). MediaPipe's
    # video-stream tracking between calls would therefore be wrong here.
    with mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1, min_detection_confidence=0.5) as face_mesh:
        for epoch in range(epochs):
            for i, data in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")):
                data = {key: value.to(device) for key, value in data.items()}

                images = data['images']
                visibles = data['visibles']
                if supres is None:
                    images_coarse = images
                    visibles_coarse = visibles
                else:
                    images_coarse = data['images_coarse']
                    visibles_coarse = data['visibles_coarse']
                # NOTE(review): `supres` is still called unconditionally below, so supres=None
                # currently fails there; confirm whether that mode is meant to be supported.

                resolution_coarse = images_coarse.shape[2]
                resolution_fine = images.shape[2]

                data['pose'] = data['pose'] + delta_poses[data['exp_id'], :]
                data = model(data, resolution_coarse)
                render_images = data['render_images']

                # Random zoom in [0.8, 1.25), snapped so resolution_coarse * factor is an integer.
                scale_factor = random.random() * 0.45 + 0.8
                scale_factor = int(resolution_coarse * scale_factor) / resolution_coarse
                cropped_render_images, cropped_images, cropped_visibles = random_crop(
                    render_images, images, visibles, scale_factor, resolution_coarse, resolution_fine, device)
                data['cropped_images'] = cropped_images

                supres_images = supres(cropped_render_images)
                data['supres_images'] = supres_images

                with torch.no_grad():
                    gen_lmk = extract_landmarks_batch(supres_images, face_mesh)  # [B, 468, 2]
                    gt_lmk = extract_landmarks_batch(cropped_images, face_mesh)  # [B, 468, 2]
                # Landmarks are computed under no_grad via NumPy, so this term has no gradient:
                # it only changes the reported loss value.
                loss_lmk = F.mse_loss(gen_lmk, gt_lmk)

                gen_visible = supres_images * cropped_visibles
                gt_visible = cropped_images * cropped_visibles

                mouth_mask = get_bbox_mask(gt_lmk, supres_images.shape, region='mouth', padding=0.2)
                eye_mask = get_bbox_mask(gt_lmk, supres_images.shape, region='eye', padding=0.2)
                loss_mouth = F.l1_loss(gen_visible * mouth_mask, gt_visible * mouth_mask)
                loss_eye = F.l1_loss(gen_visible * eye_mask, gt_visible * eye_mask)

                loss_mouth_features = _mouth_feature_regularizer(model)

                loss_rgb_lr = F.l1_loss(render_images[:, 0:3, :, :] * visibles_coarse, images_coarse * visibles_coarse)
                loss_rgb_hr = F.l1_loss(gen_visible, gt_visible)

                # One random patch shared by all patch-level losses.
                top = random.randint(0, supres_images.shape[2] - patch_size)
                left = random.randint(0, supres_images.shape[3] - patch_size)
                gen_patch = gen_visible[:, :, top:top + patch_size, left:left + patch_size]
                gt_patch = gt_visible[:, :, top:top + patch_size, left:left + patch_size]
                loss_vgg, loss_ssim, loss_feature, d_loss = _patch_losses(
                    gen_patch, gt_patch, fn_lpips, ssim_metric, feature_extractor, discriminator)

                # NOTE(review): `d_loss` contains +mean(D(fake)), so adding it here pushes the
                # critic score of generated patches DOWN. The usual WGAN generator term is
                # -mean(D(fake)). With an untrained critic and a 1e-7 weight the effect is tiny.
                loss = (loss_rgb_hr + loss_rgb_lr + loss_vgg + loss_ssim + loss_feature * 1e6 + d_loss * 1e-7
                        + loss_lmk + loss_mouth_features + 1e2 * (loss_eye + loss_mouth))

                optimizer.zero_grad()
                loss.backward()
                _boost_mouth_gradients(model)
                optimizer.step()

                to_log = {
                    'data': data,
                    'delta_poses' : delta_poses,
                    'gaussianhead' : model,
                    'supres' : supres,
                    "loss" : loss,
                    'loss_rgb_lr' : loss_rgb_lr,
                    'loss_rgb_hr' : loss_rgb_hr,
                    'loss_vgg' : loss_vgg,
                    'loss_ssim' : loss_ssim,
                    'loss_feature' : loss_feature,
                    'd_loss' : d_loss,
                    'loss_lmk' : loss_lmk,
                    'loss_mouth' : loss_mouth,
                    'loss_eye' : loss_eye,
                    'loss_mouth_features' : loss_mouth_features,
                    'epoch' : epoch,
                    'iter' : i + epoch * len(dataloader),
                    'loader_length' : len(dataloader),
                }
                logger.log(to_log)


if __name__ == '__main__':

    conf_path = 'configs/train_gaussian_mono.yaml'
    cfg = load_config(conf_path)

    dataset = GaussianDataset(cfg['dataset'])
    dataloader = DataLoaderX(dataset, batch_size=cfg['batch_size'], shuffle=True, pin_memory=True)

    device = torch.device(cfg['device'])
    # torch.cuda.set_device only accepts CUDA devices; skip it for CPU runs.
    if device.type == 'cuda':
        torch.cuda.set_device(device)

    if os.path.exists(cfg['load_gaussianhead_checkpoint']):
        # Resume: rebuild the Gaussian head from the tensors stored in its own checkpoint.
        gaussianhead_state_dict = torch.load(cfg['load_gaussianhead_checkpoint'], map_location=lambda storage, loc: storage)
        gaussianhead = GaussianHeadModule(cfg['gaussianheadmodule'],
                                          xyz=gaussianhead_state_dict['xyz'],
                                          feature=gaussianhead_state_dict['feature'],
                                          landmarks_3d_neutral=gaussianhead_state_dict['landmarks_3d_neutral']).to(device)
        gaussianhead.load_state_dict(gaussianhead_state_dict)
    else:
        # Fresh start: initialise the Gaussians from the subdivided neutral mesh of a trained
        # mesh head and copy over its color/deformation MLPs and point embedding.
        meshhead_state_dict = torch.load(cfg['load_meshhead_checkpoint'], map_location=lambda storage, loc: storage)
        meshhead = MeshHeadModule(cfg['meshheadmodule'], meshhead_state_dict['landmarks_3d_neutral']).to(device)
        meshhead.load_state_dict(meshhead_state_dict)
        meshhead.subdivide()
        with torch.no_grad():
            data = meshhead.reconstruct_neutral()

        gaussianhead = GaussianHeadModule(cfg['gaussianheadmodule'],
                                          xyz=data['verts'].cpu(),
                                          feature=torch.atanh(data['verts_feature'].cpu()),
                                          landmarks_3d_neutral=meshhead.landmarks_3d_neutral.detach().cpu(),
                                          add_mouth_points=True).to(device)
        gaussianhead.exp_color_mlp.load_state_dict(meshhead.exp_color_mlp.state_dict())
        gaussianhead.pose_color_mlp.load_state_dict(meshhead.pose_color_mlp.state_dict())
        gaussianhead.exp_deform_mlp.load_state_dict(meshhead.exp_deform_mlp.state_dict())
        gaussianhead.pose_deform_mlp.load_state_dict(meshhead.pose_deform_mlp.state_dict())
        gaussianhead.pts_embedding.load_state_dict(meshhead.pts_embedding.state_dict())

    supres = SuperResolutionModule(cfg['supresmodule']).to(device)
    if os.path.exists(cfg['load_supres_checkpoint']):
        supres.load_state_dict(torch.load(cfg['load_supres_checkpoint'], map_location=lambda storage, loc: storage))

    recorder = GaussianHeadTrainRecorder(cfg['recorder'])

    lr_net = float(cfg['lr_net'])
    optimized_parameters = [{'params' : supres.parameters(), 'lr' : lr_net},
                            {'params' : gaussianhead.xyz, 'lr' : lr_net * 0.1},
                            {'params' : gaussianhead.feature, 'lr' : lr_net * 0.1},
                            {'params' : gaussianhead.exp_color_mlp.parameters(), 'lr' : lr_net},
                            {'params' : gaussianhead.pose_color_mlp.parameters(), 'lr' : lr_net},
                            {'params' : gaussianhead.exp_deform_mlp.parameters(), 'lr' : lr_net},
                            {'params' : gaussianhead.pose_deform_mlp.parameters(), 'lr' : lr_net},
                            {'params' : gaussianhead.exp_attributes_mlp.parameters(), 'lr' : lr_net},
                            {'params' : gaussianhead.pose_attributes_mlp.parameters(), 'lr' : lr_net},
                            {'params' : gaussianhead.pts_embedding.parameters(), 'lr' : lr_net},
                            {'params' : gaussianhead.scales, 'lr' : lr_net * 0.3},
                            {'params' : gaussianhead.rotation, 'lr' : lr_net * 0.1},
                            {'params' : gaussianhead.opacity, 'lr' : lr_net}]

    # Per-expression pose corrections. The tensor is put on `device` and detached into a leaf
    # BEFORE it is registered with the optimizer. Previously `.to(device)` ran after the
    # optimizer was built, so a checkpoint stored on another device handed train() a one-time
    # copy: the optimizer updated the original and the forward pass never saw the updates.
    if os.path.exists(cfg['load_delta_poses_checkpoint']):
        delta_poses = torch.load(cfg['load_delta_poses_checkpoint'], map_location=device)
    else:
        delta_poses = torch.zeros([dataset.num_exp_id, 6])
    delta_poses = delta_poses.to(device).detach()

    if cfg['optimize_pose']:
        delta_poses.requires_grad_(True)
        optimized_parameters.append({'params' : delta_poses, 'lr' : float(cfg['lr_pose'])})

    optimizer = torch.optim.Adam(optimized_parameters)

    train(dataloader, delta_poses, gaussianhead, supres, optimizer, recorder, device, 1000)

