import os
import torch
import torch.cuda.amp
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import make_grid
import timm
from torch.optim import Adam
import pandas as pd

from pretrain_propensity_model import PropensityModel, NUM_TREATMENTS
from data.pairs_dataset.dataset import KneeFeatureConditioningDataset, FEATURE_DIM, PERCEPTUAL_FEATURES

# Select PyTorch's "high" float32 matmul precision mode for the whole process.
torch.set_float32_matmul_precision("high")

# ---------------------------
# Frozen propensity model.
# It is built with the dataset's full FEATURE_DIM (not len(PERCEPTUAL_FEATURES)),
# restored from a checkpoint on disk, switched to eval mode and excluded from
# gradient computation.
propensity_model_path = "/home/acc/Treatment_Modeling/Temporal-Treatment-Modeling/checkpoints/Propensity_Model/propensity_model.pth"
propensity_model = PropensityModel(
    feature_dim=FEATURE_DIM,
    num_treatments=NUM_TREATMENTS,
).to("cuda")
propensity_model.load_state_dict(
    torch.load(propensity_model_path, map_location="cuda")
)
propensity_model.eval()
# Same effect as setting requires_grad = False on every parameter.
propensity_model.requires_grad_(False)

# -----------------------------------------------------------------------------
# DiffusionLitModule with Dual Feature Conditioning, IPW Diffusion Loss, & Feature Loss
# -----------------------------------------------------------------------------
class DiffusionLitModule(pl.LightningModule):
    """Fine-tune a Stable Diffusion UNet on (input, target) image pairs.

    The UNet is conditioned on the sum of three signals: the text-encoder
    embedding of the prompt, a projection of the input-image latents
    (``img_conditioning``) and a linear projection of the "earlier" feature
    vector (``feature_conditioner``).

    The training loss is the noise-prediction MSE, optionally re-weighted per
    sample by the inverse of the propensity model's joint treatment
    probability, plus ``feature_loss_weight`` times a cross-entropy loss that
    the frozen per-feature predictors compute on the decoded denoised estimate.

    Args:
        vae: Autoencoder used to encode images and decode latents; frozen here.
        unet: Noise-prediction network; trained.
        scheduler: Noise scheduler providing ``add_noise`` and
            ``config.num_train_timesteps``.
        text_encoder: Prompt encoder; frozen here.
        tokenizer: Tokenizer matching ``text_encoder``.
        img_conditioning: Module mapping input latents to a conditioning
            vector; trained.
        latent_scale: Factor applied to latents after encoding and removed
            before decoding.
        lr: Learning rate for Adam.
        feature_loss_weight: Multiplier of the feature loss in the total loss.
        propensity_model: Optional frozen model called as
            ``propensity_model(images, features)`` and returning per-treatment
            logits. When ``None`` no inverse-propensity weighting is applied.
    """

    def __init__(self, vae, unet, scheduler, text_encoder, tokenizer, img_conditioning,
                 latent_scale=0.18215, lr=1e-5, feature_loss_weight=1.0, propensity_model=None):
        super().__init__()
        self.vae = vae.eval()
        self.unet = unet
        self.scheduler = scheduler
        self.text_encoder = text_encoder.eval()
        self.tokenizer = tokenizer
        self.img_conditioning = img_conditioning
        self.latent_scale = latent_scale
        self.lr = lr
        self.feature_loss_weight = feature_loss_weight
        self.save_hyperparameters(ignore=['vae', 'unet', 'scheduler', 'text_encoder', 'tokenizer', 'img_conditioning', 'propensity_model'])

        # Freeze VAE and text encoder.
        for param in self.vae.parameters():
            param.requires_grad = False
        for param in self.text_encoder.parameters():
            param.requires_grad = False

        self.propensity_model = propensity_model
        if self.propensity_model is not None:
            self.propensity_model.eval()
            for param in self.propensity_model.parameters():
                param.requires_grad = False

        self.transform_to_pil = T.ToPILImage()

        # Probe the text encoder once to discover its embedding width; no
        # gradients are needed for this throw-away forward pass.
        device = next(self.text_encoder.parameters()).device
        dummy_text = self.tokenizer("dummy", return_tensors="pt").input_ids.to(device)
        with torch.no_grad():
            dummy_embed = self.text_encoder(dummy_text)[0]
        text_embed_dim = dummy_embed.size(-1)
        # Use the full FEATURE_DIM from the concatenation of bilateral x-ray grades, clinical and demographic info.
        self.feature_conditioner = nn.Linear(FEATURE_DIM, text_embed_dim)

        # Load a separate pretrained predictor for each perceptual feature.
        self.feature_predictors = nn.ModuleDict()
        for feat in PERCEPTUAL_FEATURES:
            model_path = os.path.join("checkpoints", feat, f"{feat}_model.pth")
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(model_path, map_location="cpu")
            # The classifier head's row count gives the number of classes.
            num_outputs = state_dict['head.weight'].shape[0]
            predictor = timm.create_model('efficientformerv2_l.snap_dist_in1k', pretrained=True)
            predictor.reset_classifier(num_outputs)
            predictor.load_state_dict(state_dict)
            predictor.eval()
            for param in predictor.parameters():
                param.requires_grad = False
            self.feature_predictors[feat] = predictor.to(device)

    def train(self, mode=True):
        """Set train/eval mode while keeping the frozen sub-modules in eval mode.

        A plain ``.train()`` call on this module would also flip the frozen
        VAE, text encoder, propensity model and feature predictors into train
        mode, enabling train-time layer behaviour (e.g. dropout or batch-norm
        statistic updates) in networks that are meant to stay fixed.
        """
        super().train(mode)
        for name in ("vae", "text_encoder", "feature_predictors", "propensity_model"):
            frozen = getattr(self, name, None)
            if frozen is not None:
                frozen.eval()
        return self

    def setup(self, stage=None):
        """Build the img2img pipeline used for image logging and move helpers to ``self.device``."""
        self.pipe = StableDiffusionImg2ImgPipeline(
            vae=self.vae,
            unet=self.unet,
            scheduler=self.scheduler,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            safety_checker=None,
            feature_extractor=None,
        ).to(self.device)
        self.feature_conditioner.to(self.device)
        for feat in self.feature_predictors:
            self.feature_predictors[feat].to(self.device)

    def forward(self, input_images, target_images, prompts, earlier_features):
        """Noise the target latents at random timesteps and predict that noise.

        Returns:
            Tuple ``(pred_noise, noise, noisy_latents, timesteps, target_latents)``.
        """
        input_latents = self.vae.encode(input_images).latent_dist.sample() * self.latent_scale
        target_latents = self.vae.encode(target_images).latent_dist.sample() * self.latent_scale

        # One independent timestep per sample in the batch.
        timesteps = torch.randint(
            0, self.scheduler.config.num_train_timesteps,
            (input_latents.size(0),), device=self.device
        ).long()
        noise = torch.randn_like(target_latents)
        noisy_latents = self.scheduler.add_noise(target_latents, noise, timesteps)

        text_inputs = self.tokenizer(
            prompts, padding="max_length", max_length=self.tokenizer.model_max_length,
            return_tensors="pt"
        ).input_ids.to(self.device)
        text_embeddings = self.text_encoder(text_inputs)[0]

        img_cond = self.img_conditioning(input_latents)
        conditioning_feats = self.feature_conditioner(earlier_features.to(self.device))
        # The two per-sample vectors are broadcast over the token axis and
        # added to every text-token embedding.
        conditioning = text_embeddings + img_cond.unsqueeze(1) + conditioning_feats.unsqueeze(1)

        pred_noise = self.unet(noisy_latents, timesteps, encoder_hidden_states=conditioning).sample
        return pred_noise, noise, noisy_latents, timesteps, target_latents

    def _compute_ipw(self, input_images, earlier_features, treatment_label):
        """Return per-sample inverse-propensity weights, or ``None`` without a propensity model.

        The propensity model is run with autocast disabled and inputs cast to
        its parameter dtype. Each treatment is treated as an independent
        Bernoulli variable, so the probability of the observed multi-label
        treatment is the product over treatments of ``p`` (label == 1) or
        ``1 - p`` (otherwise); ``1e-6`` is added before inverting.
        """
        if self.propensity_model is None:
            return None
        with torch.no_grad():
            with torch.amp.autocast("cuda", enabled=False):
                model_dtype = next(self.propensity_model.parameters()).dtype
                propensity_logits = self.propensity_model(
                    input_images.to(model_dtype), earlier_features.to(model_dtype)
                )
            propensity_probs = torch.sigmoid(propensity_logits)
            label = treatment_label.to(self.device)
            joint_prob = torch.prod(
                torch.where(label == 1, propensity_probs, 1.0 - propensity_probs), dim=1
            ) + 1e-6
        return 1.0 / joint_prob

    def _predict_original_latents(self, noisy_latents, timesteps, pred_noise, fallback_latents):
        """Estimate the clean latents from the noisy latents and the predicted noise.

        Order of preference:
          1. ``scheduler.predict_original_sample`` when the scheduler offers it;
          2. the closed-form inverse of the forward noising relation
             ``x_t = sqrt(a_t) * x_0 + sqrt(1 - a_t) * eps`` using
             ``scheduler.alphas_cumprod`` (assumes ``pred_noise`` is an epsilon
             prediction, which is what the MSE-to-noise objective in this
             module trains);
          3. ``fallback_latents``, logging ``scheduler_error`` -- the original
             best-effort behaviour.
        """
        custom_predict = getattr(self.scheduler, "predict_original_sample", None)
        if callable(custom_predict):
            try:
                return custom_predict(noisy_latents, timesteps, pred_noise)
            except Exception:
                # Best-effort: fall through to the closed-form estimate below.
                custom_predict = None

        alphas_cumprod = getattr(self.scheduler, "alphas_cumprod", None)
        if alphas_cumprod is None:
            self.log("scheduler_error", 1.0, sync_dist=True)
            return fallback_latents

        # Do the arithmetic in float32: sqrt(a_t) gets small at large
        # timesteps and the division amplifies rounding error.
        alpha_bar = torch.as_tensor(
            alphas_cumprod, dtype=torch.float32, device=noisy_latents.device
        )[timesteps].view(-1, 1, 1, 1)
        estimate = (
            noisy_latents.float() - (1.0 - alpha_bar).sqrt() * pred_noise.float()
        ) / alpha_bar.sqrt()
        return estimate.to(noisy_latents.dtype)

    def training_step(self, batch, batch_idx):
        """Compute and log the (optionally IPW-weighted) diffusion loss plus the feature loss."""
        input_images = batch["input_image"]
        target_images = batch["target_image"]
        prompts = batch["prompt"]
        earlier_features = batch["earlier_features"]
        later_features = batch["later_features"]
        treatment_label = batch["treatment_label"]

        pred_noise, noise, noisy_latents, timesteps, target_latents = self(
            input_images, target_images, prompts, earlier_features
        )

        # Mean squared error per sample (all non-batch dimensions averaged).
        per_sample_diff_loss = F.mse_loss(pred_noise, noise, reduction='none')
        per_sample_diff_loss = per_sample_diff_loss.view(per_sample_diff_loss.size(0), -1).mean(dim=1)

        ipw = self._compute_ipw(input_images, earlier_features, treatment_label)
        if ipw is not None:
            weighted_diff_loss = (ipw * per_sample_diff_loss).mean()
        else:
            weighted_diff_loss = per_sample_diff_loss.mean()

        predicted_latents = self._predict_original_latents(
            noisy_latents, timesteps, pred_noise, target_latents
        )
        generated = self.vae.decode(predicted_latents / self.latent_scale).sample

        feature_loss_total = 0.0
        valid_counts = 0
        for i, feat in enumerate(PERCEPTUAL_FEATURES):
            gt = later_features[:, i].long()
            # -1 marks a missing label; such samples are excluded.
            valid_mask = (gt != -1)
            if valid_mask.sum() == 0:
                continue
            predictor = self.feature_predictors[feat]
            pred_logits = predictor(generated)
            loss_i = F.cross_entropy(pred_logits[valid_mask], gt[valid_mask])
            feature_loss_total += loss_i
            valid_counts += 1
        # Average over the features that had at least one labelled sample.
        feature_loss = feature_loss_total / valid_counts if valid_counts > 0 else 0.0

        total_loss = weighted_diff_loss + self.feature_loss_weight * feature_loss

        batch_size = input_images.size(0)
        self.log("train_diff_loss", weighted_diff_loss, prog_bar=True, sync_dist=True, batch_size=batch_size)
        self.log("train_feature_loss", feature_loss, prog_bar=True, sync_dist=True, batch_size=batch_size)
        self.log("train_loss", total_loss, prog_bar=True, sync_dist=True, batch_size=batch_size)
        return total_loss

    def validation_step(self, batch, batch_idx):
        """Log validation losses and per-feature change ("ITE") errors.

        Logged metrics: ``val_loss`` and ``val_gen_loss`` (plain),
        ``ipw_val_loss`` and ``ipw_gen_val_loss`` (only with a propensity
        model), and per perceptual feature ``val_ite_*`` / ``val_model_ite_*``
        with ``ipw_``-prefixed counterparts when weights are available.
        Example images are logged for the first batch.
        """
        input_images = batch["input_image"]
        target_images = batch["target_image"]
        prompts = batch["prompt"]
        earlier_features = batch["earlier_features"]
        later_features = batch["later_features"]
        treatment_label = batch["treatment_label"]

        # None when no propensity model is configured.
        ipw = self._compute_ipw(input_images, earlier_features, treatment_label)

        pred_noise, noise, noisy_latents, timesteps, target_latents = self(input_images, target_images, prompts, earlier_features)
        val_loss = F.mse_loss(pred_noise, noise)
        batch_size = input_images.size(0)
        self.log("val_loss", val_loss, prog_bar=True, sync_dist=True, batch_size=batch_size)

        if ipw is not None:
            per_sample_diff_loss = F.mse_loss(pred_noise, noise, reduction='none')
            per_sample_diff_loss = per_sample_diff_loss.view(per_sample_diff_loss.size(0), -1).mean(dim=1)
            ipw_val_loss = (ipw * per_sample_diff_loss).mean()
            self.log("ipw_val_loss", ipw_val_loss, prog_bar=True, sync_dist=True, batch_size=batch_size)

        # Decode the denoised estimate so it can be compared with the target image.
        predicted_latents = self._predict_original_latents(
            noisy_latents, timesteps, pred_noise, target_latents
        )
        generated = self.vae.decode(predicted_latents / self.latent_scale).sample

        val_gen_loss = F.mse_loss(generated, target_images)
        self.log("val_gen_loss", val_gen_loss, prog_bar=True, sync_dist=True, batch_size=batch_size)
        if ipw is not None:
            per_sample_gen_loss = F.mse_loss(generated, target_images, reduction='none')
            per_sample_gen_loss = per_sample_gen_loss.view(per_sample_gen_loss.size(0), -1).mean(dim=1)
            ipw_val_gen_loss = (ipw * per_sample_gen_loss).mean()
            self.log("ipw_gen_val_loss", ipw_val_gen_loss, prog_bar=True, sync_dist=True, batch_size=batch_size)

        # Per-feature error between predicted and reference change from the earlier visit.
        # NOTE(review): this indexes earlier_features by the PERCEPTUAL_FEATURES
        # position, i.e. it assumes the leading columns of earlier_features are
        # the perceptual features in the same order -- confirm against the dataset.
        for i, feat in enumerate(PERCEPTUAL_FEATURES):
            earlier_features_i = earlier_features[:, i].long()
            later_features_i = later_features[:, i].long()
            # A sample contributes only if both visits have a label (-1 = missing).
            valid_ite_mask = (earlier_features_i != -1) & (later_features_i != -1)
            if not valid_ite_mask.any():
                # Nothing to average; a mean over zero elements would log NaN.
                continue
            predictor = self.feature_predictors[feat]
            pred_logits = predictor(generated)
            gt_logits = predictor(target_images)
            pred = pred_logits.argmax(dim=-1).float()[valid_ite_mask]
            gt_model = gt_logits.argmax(dim=-1).float()[valid_ite_mask]
            earlier_valid = earlier_features_i[valid_ite_mask].float()

            pred_delta = pred - earlier_valid
            # Reference change from dataset labels ...
            gt_delta = later_features_i[valid_ite_mask].float() - earlier_valid
            # ... and from the predictor applied to the real target image.
            gt_model_delta = gt_model - earlier_valid
            abs_err = torch.abs(pred_delta - gt_delta)
            abs_err_model = torch.abs(pred_delta - gt_model_delta)
            self.log(f"val_ite_{feat}", abs_err.mean(), prog_bar=True, sync_dist=True, batch_size=batch_size)
            self.log(f"val_model_ite_{feat}", abs_err_model.mean(), prog_bar=True, sync_dist=True, batch_size=batch_size)

            if ipw is not None:
                ipw_mask = ipw[valid_ite_mask]
                ipw_val_ite_feat = (abs_err * ipw_mask).mean()
                ipw_val_model_ite_feat = (abs_err_model * ipw_mask).mean()
                self.log(f"ipw_val_ite_{feat}", ipw_val_ite_feat, prog_bar=True, sync_dist=True, batch_size=batch_size)
                self.log(f"ipw_val_model_ite_{feat}", ipw_val_model_ite_feat, prog_bar=True, sync_dist=True, batch_size=batch_size)
        if batch_idx == 0:
            self.log_images(input_images, target_images, prompts)

    def log_images(self, input_images, target_images, prompts, strength=0.75, guidance_scale=7.5, num_images=25):
        """Run the img2img pipeline on up to ``num_images`` inputs and log image grids and prompts.

        Writes input, target and generated grids plus the prompt text to
        ``self.logger.experiment`` under the current epoch. Does nothing when
        no logger is attached.
        """
        if self.logger is None:
            return

        input_pil = [self.transform_to_pil(img.cpu()) for img in input_images[:num_images]]
        generated_pil = self.pipe(
            prompt=prompts[:num_images],
            image=input_pil,
            strength=strength,
            guidance_scale=guidance_scale,
        ).images

        input_grid = make_grid(input_images[:num_images], nrow=num_images)
        target_grid = make_grid(target_images[:num_images], nrow=num_images)
        generated_grid = make_grid([T.ToTensor()(img) for img in generated_pil], nrow=num_images)

        self.logger.experiment.add_image("Input Images", input_grid, self.current_epoch)
        self.logger.experiment.add_image("Target Images", target_grid, self.current_epoch)
        self.logger.experiment.add_image("Generated Images", generated_grid, self.current_epoch)

        prompts_text = "\n\n".join([f"Image {i+1}: {p}" for i, p in enumerate(prompts[:num_images])])
        self.logger.experiment.add_text("Prompts", prompts_text, self.current_epoch)

    def configure_optimizers(self):
        """Return a single Adam optimizer over the UNet, both conditioners and the predictors."""
        # NOTE(review): the feature predictors are frozen (requires_grad=False)
        # so they never receive gradients; they are kept in the parameter list
        # only so optimizer state saved by earlier runs still loads.
        return Adam(
            list(self.unet.parameters()) +
            list(self.img_conditioning.parameters()) +
            list(self.feature_conditioner.parameters()) +
            list(self.feature_predictors.parameters()),
            lr=self.lr
        )

# -----------------------------------------------------------------------------
# Data Loading and Training Setup
# -----------------------------------------------------------------------------
# Every image is resized to 224x224 and converted to a tensor.
common_transforms = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
])

train_dataset = KneeFeatureConditioningDataset(
    "data/pairs_dataset/train.csv",
    image_transform=common_transforms,
)
val_dataset = KneeFeatureConditioningDataset(
    "data/pairs_dataset/val.csv",
    image_transform=common_transforms,
)

# Settings shared by both loaders; only the training loader shuffles.
_loader_kwargs = dict(batch_size=50, num_workers=8, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, **_loader_kwargs)

# Pretrained Stable Diffusion v1.5 components, placed on the GPU.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5"
).to("cuda")

class ImageConditioning(nn.Module):
    """Map image latents to a conditioning vector.

    The latent is averaged over dimensions 2 and 3 and the result is passed
    through a single linear layer of width ``hidden_dim``. ``in_channels``
    defaults to the latent channel count of the pretrained VAE.
    """

    def __init__(self, in_channels=pipe.vae.config.latent_channels, hidden_dim=768):
        super().__init__()
        self.linear = nn.Linear(in_features=in_channels, out_features=hidden_dim)

    def forward(self, latent):
        """Return the linear projection of the latent averaged over dims 2 and 3."""
        pooled = latent.mean(dim=(2, 3))
        return self.linear(pooled)

# Image-latent conditioner: float32, on the GPU (both calls act in place).
img_conditioning = ImageConditioning()
img_conditioning.to("cuda").float()

# Suffix shared by the checkpoint file name and the TensorBoard run name.
run_tag = "_".join(PERCEPTUAL_FEATURES)

logger = TensorBoardLogger("tb_logs", name=f"diffusion_informed_ipw_{run_tag}")

# Keep only the checkpoint with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename=f"best_informed_diffusion_ipw_model_{run_tag}",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
)

model = DiffusionLitModule(
    vae=pipe.vae,
    unet=pipe.unet,
    scheduler=pipe.scheduler,
    text_encoder=pipe.text_encoder,
    tokenizer=pipe.tokenizer,
    img_conditioning=img_conditioning,
    feature_loss_weight=1.0,
    propensity_model=propensity_model,
)

# Two-GPU DDP run (device indices 1 and 2) with 16-bit mixed precision.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=[1, 2],
    strategy="ddp",
    precision="16-mixed",
    max_epochs=30,
    log_every_n_steps=10,
    logger=logger,
    callbacks=[checkpoint_callback],
)

trainer.fit(model, train_loader, val_loader)
print(f"Best model saved at: {checkpoint_callback.best_model_path}")
