"""
Base classes for diffusion model training.
Contains common functionality extracted from diff_*.py files.
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast
from torch.optim import AdamW
import pytorch_lightning as pl
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
import torchvision.transforms as T
from diffusers import StableDiffusionImg2ImgPipeline
import timm
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score

from data.pairs_dataset.dataset import NUM_TREATMENTS, FEATURE_DIM, PERCEPTUAL_FEATURES
from diffusion_utils import (
    calculate_predicted_x0, calculate_feature_loss, calculate_ite_metrics,
    log_qualitative_images, calculate_discriminator_metrics, aggregate_test_metrics,
    save_test_results, safe_tensor_log, calculate_ssim, EPSILON_DELTA_T
)


class ImageConditioning(nn.Module):
    """Map a latent feature map to one conditioning vector per sample.

    The latent is averaged over dims 2 and 3 and the resulting per-channel
    means are projected to ``hidden_dim`` by a single linear layer.

    Args:
        in_channels: size of dim 1 of the latent (the linear layer's input).
        hidden_dim: size of the produced conditioning vector.
    """

    def __init__(self, in_channels=4, hidden_dim=768):
        super().__init__()
        self.linear = nn.Linear(in_features=in_channels, out_features=hidden_dim)

    def forward(self, latent):
        # Average-pool away dims 2 and 3 (presumably H, W of a (B, C, H, W) latent).
        pooled = latent.mean(dim=(2, 3))
        return self.linear(pooled)


class ContextEncoderRNN(nn.Module):
    """Summarize a padded history sequence plus current-step signals into one vector.

    Covariate and treatment histories are concatenated on dim 2 and encoded by
    an LSTM over packed (variable-length) sequences. The last layer's final
    hidden state is concatenated with the current image conditioning and with
    small Linear+ReLU embeddings of the normalized ``delta_t`` and of ``side``.

    Args:
        args: namespace providing ``disc_model_type`` (only "rnn", case
            insensitive, is supported), ``disc_delta_t_feat_dim``,
            ``disc_side_feat_dim``, ``disc_hidden_dim``, ``disc_num_layers``
            and ``disc_dropout``.
        current_img_cond_dim: width of ``img_cond_current`` after any averaging.
        cov_hist_dim: per-step covariate feature width.
        num_treatments_hist: per-step treatment feature width.

    Raises:
        ValueError: if ``args.disc_model_type`` is not "rnn".
    """

    def __init__(self, args, current_img_cond_dim, cov_hist_dim, num_treatments_hist):
        super().__init__()
        self.model_type = args.disc_model_type
        self.delta_t_feat_dim = args.disc_delta_t_feat_dim
        self.side_feat_dim = args.disc_side_feat_dim

        if self.model_type.lower() != "rnn":
            raise ValueError(f"Unsupported context encoder model type: {self.model_type}")

        num_layers = args.disc_num_layers
        # nn.LSTM applies dropout only between stacked layers, so a single layer gets 0.
        inter_layer_dropout = args.disc_dropout if num_layers > 1 else 0
        self.hist_encoder = nn.LSTM(
            cov_hist_dim + num_treatments_hist,
            args.disc_hidden_dim,
            num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        self.hist_encoder_output_dim = args.disc_hidden_dim

        self.delta_t_processor = nn.Sequential(nn.Linear(1, self.delta_t_feat_dim), nn.ReLU())
        self.side_processor = nn.Sequential(nn.Linear(1, self.side_feat_dim), nn.ReLU())

        self.output_dim = sum((
            self.hist_encoder_output_dim,
            current_img_cond_dim,
            self.delta_t_feat_dim,
            self.side_feat_dim,
        ))

    def forward(self, img_cond_current, cov_seq_hist, trt_seq_hist, lengths_hist,
                delta_t, side, image_seq_hist=None, delta_t_mean=0.0, delta_t_std=1.0):
        """Return the concatenated context vector of width ``self.output_dim``.

        ``image_seq_hist`` is accepted for interface compatibility but unused.
        ``delta_t`` and ``side`` must have a trailing dim of size 1 (they feed
        ``nn.Linear(1, ...)``). A 3-D ``img_cond_current`` is averaged over dim 1.
        """
        history = torch.cat((cov_seq_hist, trt_seq_hist), dim=2)
        packed_history = pack_padded_sequence(
            history, lengths_hist.cpu(), batch_first=True, enforce_sorted=False
        )
        self.hist_encoder.flatten_parameters()
        _, (final_hidden, _) = self.hist_encoder(packed_history)
        # Final hidden state of the top LSTM layer: (batch, disc_hidden_dim).
        history_summary = final_hidden[-1]

        scaled_gap = (delta_t - delta_t_mean) / (delta_t_std + EPSILON_DELTA_T)
        gap_embedding = self.delta_t_processor(scaled_gap)
        side_embedding = self.side_processor(side)

        if img_cond_current.ndim == 3:
            current_cond = img_cond_current.mean(dim=1)
        else:
            current_cond = img_cond_current

        pieces = (history_summary, current_cond, gap_embedding, side_embedding)
        return torch.cat(pieces, dim=1)


class BaseDiffusionLitModule(pl.LightningModule):
    """Base Lightning module with common diffusion functionality.

    Wraps the components of a Stable Diffusion img2img stack (VAE, U-Net,
    scheduler, text encoder, tokenizer) plus an image-conditioning module, and
    implements a noise-prediction training / validation / test loop. Optional
    frozen perceptual feature predictors, loaded from
    ``checkpoints/<feature>/<feature>_model.pth``, add a feature loss during
    training and ITE metrics during evaluation.

    Args:
        args: hyperparameter namespace (saved via ``save_hyperparameters``).
            Read here: ``lr`` (falling back to ``lr_generator``, then 1e-5) and
            ``feature_loss_weight``.
        vae, unet, scheduler, text_encoder, tokenizer: diffusion components;
            the VAE and text encoder are frozen.
        img_conditioning: module whose parameters are optimized together with
            the U-Net.
        delta_t_mean, delta_t_std: time-gap normalization statistics, stored as
            attributes (not used directly by this class).
    """

    def __init__(self, args, vae, unet, scheduler, text_encoder, tokenizer, img_conditioning,
                 delta_t_mean, delta_t_std):
        super().__init__()
        self.save_hyperparameters(args, ignore=['vae', 'unet', 'scheduler', 'text_encoder',
                                               'tokenizer', 'img_conditioning', 'feature_predictors'])

        # Core components
        self.vae = vae
        self.unet = unet
        self.scheduler = scheduler
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.img_conditioning = img_conditioning

        # Parameters
        # Latent scaling factor of the Stable Diffusion v1 VAE (assumed to match self.vae).
        self.latent_scale = 0.18215
        self.lr = getattr(args, 'lr', getattr(args, 'lr_generator', 1e-5))
        self.feature_loss_weight = args.feature_loss_weight
        self.delta_t_mean = delta_t_mean
        self.delta_t_std = delta_t_std

        # Initialize components
        self.transform_to_pil = T.ToPILImage()
        self.feature_predictors = nn.ModuleDict()
        self.pipe = None  # built lazily in setup()
        self._test_outputs = []  # per-batch metric dicts, aggregated in on_test_epoch_end

        self._freeze_components()
        self._load_feature_predictors()

    def _freeze_components(self):
        """Put the VAE and text encoder in eval mode and disable their gradients.

        Feature predictors are frozen separately in ``_load_feature_predictors``.
        """
        if hasattr(self, 'vae'):
            self.vae.eval()
            for p in self.vae.parameters():
                p.requires_grad_(False)

        if hasattr(self, 'text_encoder'):
            self.text_encoder.eval()
            for p in self.text_encoder.parameters():
                p.requires_grad_(False)

        print("VAE and Text Encoder frozen.")

    def _load_feature_predictors(self):
        """Load frozen, pre-trained perceptual feature predictors.

        For every name in ``PERCEPTUAL_FEATURES`` a checkpoint is expected at
        ``checkpoints/<feat>/<feat>_model.pth``; missing checkpoints are skipped
        with a warning. Each predictor is a timm EfficientFormerV2-L whose
        classifier is resized to the checkpoint's ``head.weight`` rows, loaded
        strictly, then put in eval mode with gradients disabled.
        """
        print("Loading and freezing feature predictors...")
        for feat in PERCEPTUAL_FEATURES:
            model_path = os.path.join("checkpoints", feat, f"{feat}_model.pth")
            if not os.path.exists(model_path):
                print(f"Warning: Feature predictor not found: {model_path}")
                continue

            state_dict = torch.load(model_path, map_location="cpu")
            num_outputs = state_dict['head.weight'].shape[0]
            # The strict load below overwrites every weight, so fetching ImageNet
            # weights first (pretrained=True) only cost a download / network access.
            predictor = timm.create_model('efficientformerv2_l.snap_dist_in1k', pretrained=False)
            predictor.reset_classifier(num_outputs)
            predictor.load_state_dict(state_dict)
            predictor.eval()
            predictor.requires_grad_(False)
            self.feature_predictors[feat] = predictor

    def setup(self, stage=None):
        """Move feature predictors to the current device and build the img2img pipeline once."""
        for feat in self.feature_predictors:
            self.feature_predictors[feat].to(self.device)

        if self.pipe is None:
            self.pipe = StableDiffusionImg2ImgPipeline(
                vae=self.vae, unet=self.unet, scheduler=self.scheduler,
                text_encoder=self.text_encoder, tokenizer=self.tokenizer,
                safety_checker=None, feature_extractor=None,
            ).to(self.device)
        print("Setup complete.")

    def _autocast_context(self):
        """Return a CUDA autocast context matching the trainer's precision.

        Reduced precision is enabled when the precision setting starts with
        "16" or "bf16". The dtype is passed explicitly so bf16 runs are never
        autocast to float16, which is ``torch.autocast``'s CUDA default when no
        enclosing autocast context sets another dtype. ``str()`` also accepts
        integer precision values such as ``16``.
        """
        precision = str(self.trainer.precision)
        use_bf16 = precision.startswith("bf16")
        enabled = use_bf16 or precision.startswith("16")
        dtype = torch.bfloat16 if use_bf16 else torch.float16
        return autocast('cuda', dtype=dtype, enabled=enabled)

    def forward(self, input_images, target_images, prompts, **kwargs):
        """Run one noise-prediction step of the diffusion model.

        Both image batches are VAE-encoded (sampled latents scaled by
        ``latent_scale``). The target latents are noised at uniformly random
        timesteps and the U-Net predicts that noise, conditioned on the encoded
        prompts. In this base implementation the input latents only determine
        the batch size, and ``kwargs`` is ignored.

        Returns:
            ``(pred_noise, noise, noisy_latents, timesteps, target_latents)``,
            with ``target_latents`` cast back to the dtype of ``input_images``.
        """
        input_dtype = input_images.dtype

        with self._autocast_context():
            input_latents = self.vae.encode(input_images.to(self.vae.dtype)).latent_dist.sample() * self.latent_scale
            target_latents = self.vae.encode(target_images.to(self.vae.dtype)).latent_dist.sample() * self.latent_scale
            timesteps = torch.randint(0, self.scheduler.config.num_train_timesteps,
                                      (input_latents.size(0),), device=self.device).long()
            noise = torch.randn_like(target_latents)
            noisy_latents = self.scheduler.add_noise(target_latents, noise, timesteps)

            text_inputs = self.tokenizer(prompts, padding="max_length",
                                         max_length=self.tokenizer.model_max_length,
                                         truncation=True, return_tensors="pt").input_ids.to(self.device)
            text_embeddings = self.text_encoder(text_inputs)[0]

            # Basic conditioning: text embeddings only.
            unet_conditioning = text_embeddings
            pred_noise = self.unet(noisy_latents.to(self.unet.dtype), timesteps,
                                   encoder_hidden_states=unet_conditioning.to(self.unet.dtype)).sample

        return pred_noise, noise, noisy_latents, timesteps, target_latents.to(input_dtype)

    def _calculate_predicted_x0(self, pred_noise, timesteps, noisy_latents):
        """Calculate predicted x0 from noise prediction."""
        return calculate_predicted_x0(pred_noise, timesteps, noisy_latents, self.scheduler)

    def _decode_predicted_images(self, pred_noise, timesteps, noisy_latents, stage):
        """Decode the model's x0 estimate into image space.

        Args:
            pred_noise, timesteps, noisy_latents: outputs of ``forward``.
            stage: label used in warning messages (e.g. "VAL", "TEST").

        Returns:
            The decoded images, or None when the predicted latents or the
            decoded images contain NaN/Inf. Callers treat None as "no
            generation". Substituting the targets instead would report perfect
            image metrics for a failed batch.
        """
        with torch.no_grad():
            pred_latents = self._calculate_predicted_x0(pred_noise, timesteps, noisy_latents)
            if not torch.isfinite(pred_latents).all():
                print(f"WARN {stage}: NaN/Inf in pred_latents")
                return None
            self.vae.eval()
            with self._autocast_context():
                images = self.vae.decode(pred_latents.to(self.vae.dtype) / self.latent_scale).sample
        if not torch.isfinite(images).all():
            print(f"WARN {stage}: NaN/Inf in generated images")
            return None
        return images

    def training_step(self, batch, batch_idx):
        """Noise-MSE training step, plus a weighted feature loss when enabled.

        Returns None (skipping the step) if the prediction or loss is NaN.
        """
        input_images = batch["input_image"]
        target_images = batch["target_image"]
        prompts = batch["prompt"]
        later_features = batch.get("later_features")

        pred_noise, noise, noisy_latents, timesteps, target_latents = self(
            input_images, target_images, prompts
        )

        if torch.isnan(pred_noise).any():
            print("FATAL TRAIN: NaN in pred_noise")
            return None

        # Basic diffusion loss
        diffusion_loss = F.mse_loss(pred_noise, noise)

        # Feature loss only when it is weighted, predictors exist and labels were provided.
        use_feature_loss = (self.feature_loss_weight > 0 and len(self.feature_predictors) > 0
                            and later_features is not None)
        feature_loss = torch.tensor(0.0, device=self.device)
        if use_feature_loss:
            feature_loss = calculate_feature_loss(
                self, pred_noise, timesteps, noisy_latents, later_features,
                self.feature_predictors, self.vae, self.latent_scale,
                self.trainer.precision, self.device
            )

        total_loss = diffusion_loss + self.feature_loss_weight * feature_loss

        if torch.isnan(total_loss):
            print("FATAL TRAIN: NaN in total_loss")
            return None

        # Logging
        bs = input_images.size(0)
        self.log("train_loss", total_loss, prog_bar=True, on_step=False, on_epoch=True,
                 sync_dist=True, batch_size=bs)
        self.log("train_diffusion_loss", diffusion_loss, prog_bar=False, on_step=False,
                 on_epoch=True, sync_dist=True, batch_size=bs)

        # Logging the placeholder 0.0 for batches without features would drag the epoch mean down.
        if use_feature_loss:
            self.log("train_feature_loss", feature_loss, prog_bar=False, on_step=False,
                     on_epoch=True, sync_dist=True, batch_size=bs)

        return total_loss

    def validation_step(self, batch, batch_idx):
        """Log noise losses and, when decoding succeeds, image-space and ITE metrics."""
        input_images = batch["input_image"]
        target_images = batch["target_image"]
        prompts = batch["prompt"]
        earlier_features = batch.get("earlier_features")
        later_features = batch.get("later_features")
        bs = input_images.size(0)

        pred_noise, noise, noisy_latents, timesteps, target_latents = self(
            input_images, target_images, prompts
        )

        if torch.isnan(pred_noise).any():
            print("WARN VAL: NaN pred_noise")
            return

        # Log losses
        val_loss = safe_tensor_log(F.mse_loss(pred_noise, noise), 1000.0, self.device)
        self.log("val_loss", val_loss, prog_bar=True, on_step=False, on_epoch=True,
                 sync_dist=True, batch_size=bs)

        val_mae_loss = safe_tensor_log(F.l1_loss(pred_noise, noise), 1000.0, self.device)
        self.log("val_mae_loss", val_mae_loss, prog_bar=True, on_step=False, on_epoch=True,
                 sync_dist=True, batch_size=bs)

        # Generation is best-effort during validation: a failure only disables the
        # image-space metrics for this batch instead of aborting the run.
        try:
            generated_images = self._decode_predicted_images(pred_noise, timesteps, noisy_latents, "VAL")
        except Exception as e:
            print(f"WARN VAL: Image generation failed: {e}")
            generated_images = None

        if generated_images is not None:
            # Compare in the targets' dtype so metrics do not depend on the VAE dtype.
            generated_images = generated_images.to(target_images.dtype)

            val_gen_loss = F.mse_loss(generated_images, target_images)
            if not torch.isnan(val_gen_loss):
                self.log("val_gen_loss", val_gen_loss, prog_bar=False, on_step=False, on_epoch=True,
                         sync_dist=True, batch_size=bs)

            val_ssim = calculate_ssim(generated_images, target_images, self.device)
            if not torch.isnan(val_ssim):
                self.log("val_ssim", val_ssim, prog_bar=False, on_step=False, on_epoch=True,
                         sync_dist=True, batch_size=bs)

            # ITE Metrics if available
            if len(self.feature_predictors) > 0 and earlier_features is not None and later_features is not None:
                ite_metrics = calculate_ite_metrics(
                    generated_images, target_images, earlier_features, later_features,
                    self.feature_predictors, self.trainer.precision, self.device
                )
                for metric_name, metric_value in ite_metrics.items():
                    if not np.isnan(metric_value):
                        self.log(f"val_{metric_name}", metric_value, prog_bar=False, on_step=False,
                                 on_epoch=True, sync_dist=True, batch_size=bs)

        # Log qualitative images for one fixed validation batch, on rank 0 only.
        if self.global_rank == 0 and batch_idx == 2:
            log_qualitative_images(
                self.pipe, input_images[0:16], target_images[0:16], prompts[0:16],
                self.logger, self.current_epoch, device=self.device
            )

    def test_step(self, batch, batch_idx):
        """Compute per-batch test metrics and queue them for epoch-end aggregation.

        ``interval_labels[:, -1] == 0`` selects the subset reported under the
        ``*_treated`` keys. Image-space and ITE metrics are only recorded when
        decoding produced finite images.

        Raises:
            ValueError: if the noise prediction contains NaN.

        Returns:
            The batch metrics, restricted to finite values.
        """
        input_images = batch["input_image"]
        target_images = batch["target_image"]
        prompts = batch["prompt"]
        earlier_features = batch.get("earlier_features")
        later_features = batch.get("later_features")
        interval_labels = batch.get("interval_labels")

        # Treatment mask
        treat_mask = None
        if interval_labels is not None:
            treat_mask = (interval_labels[:, -1] == 0)

        metrics = {}

        with torch.no_grad():
            pred_noise, noise, noisy_latents, timesteps, target_latents = self(
                input_images, target_images, prompts
            )

        if torch.isnan(pred_noise).any():
            raise ValueError("NaN in pred_noise")

        batch_size = pred_noise.size(0)

        # Calculate losses
        metrics['test_loss_mse'] = F.mse_loss(pred_noise, noise).item()
        metrics['test_loss_mae'] = F.l1_loss(pred_noise, noise).item()

        # Per-sample losses
        per_sample_mse = F.mse_loss(pred_noise, noise, reduction='none').reshape(batch_size, -1).mean(dim=1)
        per_sample_mae = F.l1_loss(pred_noise, noise, reduction='none').reshape(batch_size, -1).mean(dim=1)

        if treat_mask is not None and treat_mask.any():
            metrics['test_loss_mse_treated'] = per_sample_mse[treat_mask].mean().item()
            metrics['test_loss_mae_treated'] = per_sample_mae[treat_mask].mean().item()

        generated_images = self._decode_predicted_images(pred_noise, timesteps, noisy_latents, "TEST")

        # Image metrics are omitted (like the conditional *_treated / ITE keys) when
        # generation failed, rather than scoring the targets against themselves.
        if generated_images is not None:
            generated_images = generated_images.to(target_images.dtype)

            metrics['test_gen_loss'] = F.mse_loss(generated_images, target_images).item()
            metrics['test_ssim'] = calculate_ssim(generated_images, target_images, self.device).item()

            per_sample_gen = F.mse_loss(generated_images, target_images,
                                        reduction='none').reshape(batch_size, -1).mean(dim=1)
            if treat_mask is not None and treat_mask.any():
                metrics['test_gen_loss_treated'] = per_sample_gen[treat_mask].mean().item()

            # ITE Metrics
            if len(self.feature_predictors) > 0 and earlier_features is not None and later_features is not None:
                ite_metrics = calculate_ite_metrics(
                    generated_images, target_images, earlier_features, later_features,
                    self.feature_predictors, self.trainer.precision, self.device, treat_mask
                )
                metrics.update({f'test_{k}': v for k, v in ite_metrics.items()})

        metrics['batch_size'] = float(input_images.size(0))
        self._test_outputs.append(metrics.copy())

        return {k: v for k, v in metrics.items() if np.isfinite(v)}

    def on_test_epoch_end(self):
        """Aggregate the queued per-batch test metrics, save them, and reset the queue."""
        aggregated_metrics, total_samples = aggregate_test_metrics(self._test_outputs)

        results_file_path = getattr(self.hparams, 'results_file', None)
        ckpt_path = getattr(self.hparams, 'ckpt_path', None)

        save_test_results(aggregated_metrics, total_samples, results_file_path, ckpt_path)
        self._test_outputs.clear()

    def configure_optimizers(self):
        """AdamW over the U-Net and image-conditioning parameters at ``self.lr``."""
        params = list(self.unet.parameters()) + list(self.img_conditioning.parameters())
        return AdamW(params, lr=self.lr)


class BaseAdversarialMixin:
    """Mixin that adds a context encoder, treatment discriminator and adversarial settings.

    The concrete class calls ``_init_adversarial_components`` once
    ``self.text_encoder`` exists, because its ``config.hidden_size`` sizes the
    projection of the rich context vector.
    """

    def __init__(self, *args, **kwargs):
        # Cooperative pass-through; adversarial state is created by _init_adversarial_components.
        super().__init__(*args, **kwargs)

    def _init_adversarial_components(self, args, discriminator_pos_weight=None):
        """Create the adversarial sub-modules and read adversarial hyperparameters.

        Args:
            args: namespace providing ``disc_current_img_cond_dim``, the other
                ``disc_*`` settings consumed by the encoder/discriminator, and
                ``lr_generator``, ``lr_discriminator``, ``adversarial_weight``.
            discriminator_pos_weight: optional tensor. When given, it is stored
                as ``discriminator_pos_weight``, registered as the
                ``disc_pos_weight_buffer`` buffer, and used as ``pos_weight`` of
                the discriminator's ``BCEWithLogitsLoss``.

        Also sets ``automatic_optimization = False`` (Lightning manual
        optimization), so the concrete class must step its optimizers itself.
        """
        encoder = ContextEncoderRNN(
            args=args,
            current_img_cond_dim=args.disc_current_img_cond_dim,
            cov_hist_dim=FEATURE_DIM,
            num_treatments_hist=NUM_TREATMENTS,
        )
        self.context_encoder = encoder
        context_dim = encoder.output_dim

        # Project the rich context to the U-Net cross-attention width only when the sizes differ.
        cross_attn_dim = self.text_encoder.config.hidden_size
        if context_dim == cross_attn_dim:
            self.rich_context_projector = nn.Identity()
        else:
            self.rich_context_projector = nn.Linear(context_dim, cross_attn_dim)

        self.treatment_discriminator = TreatmentDiscriminator(
            args=args, input_rich_context_dim=context_dim
        )

        self.discriminator_pos_weight = discriminator_pos_weight
        if discriminator_pos_weight is None:
            self.criterion_discriminator = nn.BCEWithLogitsLoss()
        else:
            # As a buffer, the weight follows the module across devices and into state_dict.
            self.register_buffer("disc_pos_weight_buffer", discriminator_pos_weight)
            self.criterion_discriminator = nn.BCEWithLogitsLoss(pos_weight=self.disc_pos_weight_buffer)

        self.lr_generator = args.lr_generator
        self.lr_discriminator = args.lr_discriminator
        self.adversarial_weight = args.adversarial_weight

        self.automatic_optimization = False


class BaseIPWMixin:
    """Mixin holding inverse-probability-weighting (IPW) configuration.

    ``_init_ipw_components`` expects ``self.delta_t_mean`` and
    ``self.delta_t_std`` to already exist; they are copied into
    ``propensity_delta_t_mean`` / ``propensity_delta_t_std``.
    """

    def __init__(self, *args, **kwargs):
        # Cooperative pass-through so the mixin can sit anywhere in the MRO.
        super().__init__(*args, **kwargs)

    def _init_ipw_components(self, args, propensity_model=None):
        """Store IPW settings and freeze the optional propensity model.

        Args:
            args: namespace providing ``disable_ipw_loss`` and ``ipw_max_clamp``.
            propensity_model: optional model; when given it is switched to eval
                mode and none of its parameters require gradients afterwards.
        """
        self.propensity_model = propensity_model
        self.disable_ipw_loss = args.disable_ipw_loss
        self.ipw_max_clamp = args.ipw_max_clamp
        self.propensity_delta_t_mean, self.propensity_delta_t_std = self.delta_t_mean, self.delta_t_std

        if self.propensity_model is None:
            return

        self.propensity_model.eval()
        for weight in self.propensity_model.parameters():
            weight.requires_grad_(False)


class TreatmentDiscriminator(nn.Module):
    """Two-layer MLP that scores each treatment from a rich context vector.

    Layout: Linear(input_rich_context_dim -> disc_hidden_dim) -> ReLU ->
    Dropout(disc_dropout) -> Linear(-> NUM_TREATMENTS). No output activation is
    applied, so the outputs are raw scores.
    """

    def __init__(self, args, input_rich_context_dim):
        super().__init__()
        self.num_treatments_output = NUM_TREATMENTS
        hidden = args.disc_hidden_dim
        layers = [
            nn.Linear(input_rich_context_dim, hidden),
            nn.ReLU(),
            nn.Dropout(args.disc_dropout),
            nn.Linear(hidden, self.num_treatments_output),
        ]
        self.treatment_head = nn.Sequential(*layers)
        print(f"TreatmentDiscriminator input dim: {input_rich_context_dim}, output: {self.num_treatments_output}")

    def forward(self, rich_context_vector):
        scores = self.treatment_head(rich_context_vector)
        return scores