import pytorch_lightning as pl
import torch.nn as nn
import wandb
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from typing import Tuple
import timm
import umap
import warnings
import pytorch_lightning as pl
import wandb
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
import torch.nn.functional as F
import warnings
import torch.nn.functional as F
from sklearn.neighbors import KNeighborsClassifier
import os
# Install a blanket "ignore" filter: from this point on every warning category
# from every module is suppressed for the whole process, not just this file.
warnings.filterwarnings('ignore')


class CausalDeltaEmbeddingModel(pl.LightningModule):
    """Action classifier driven by the difference of two projected image embeddings.

    A timm backbone embeds the "before" and "after" images, a shared MLP
    projects both embeddings, and their element-wise difference (the
    "delta") is fed to a small classification head. The training loss is
    cross-entropy plus a weighted supervised contrastive term and a weighted
    L1 sparsity term, both computed on the deltas.

    Args:
        backbone_name: Model name handed to ``timm.create_model``.
        proj_dim: Output width of the projector, i.e. the delta dimension.
        hidden_dim: Hidden width of the projector MLP.
        num_actions: Number of classes produced by the action head.
        temperature: Divisor applied to cosine similarities in the
            contrastive loss.
        alpha_contrast: Weight of the contrastive loss in the total loss.
        alpha_sparsity: Weight of the L1 sparsity loss in the total loss.
        lr: Learning rate of the projector and action head.
        weight_decay: AdamW weight decay used for both parameter groups.
        lr_backbone_mult: Multiplier applied to ``lr`` for the backbone.
    """

    def __init__(
        self,
        backbone_name: str = "vit_small_patch16_224.dino",
        proj_dim: int = 256,
        hidden_dim: int = 1024,
        num_actions: int = 7,

        # Loss
        temperature: float = 0.07,
        alpha_contrast: float = 1.0,
        alpha_sparsity: float = 0.0,

        # Training
        lr: float = 5e-4,
        weight_decay: float = 0.05,
        lr_backbone_mult: float = 0.1,
    ):
        super().__init__()
        # Makes every constructor argument available as self.hparams.<name>.
        self.save_hyperparameters()

        self.backbone = timm.create_model(
            backbone_name, pretrained=True, num_classes=0, global_pool="token"
        )
        self.feat_dim = self.backbone.num_features

        self.causal_projector = nn.Sequential(
            nn.Linear(self.feat_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, proj_dim),
            nn.LayerNorm(proj_dim),
        )

        self.action_head = nn.Sequential(
            nn.Linear(proj_dim, proj_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(proj_dim // 2, num_actions),
        )

        self.best_val_ood_acc = 0.0
        # One list of per-batch dicts per validation dataloader; cleared at
        # the end of every validation epoch.
        self.validation_outputs = [[], []]

    def extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and the projector on a batch of images."""
        backbone_features = self.backbone(x)  # (B, feat_dim)
        return self.causal_projector(backbone_features)  # (B, proj_dim)

    def forward(
        self, x_before: torch.Tensor, x_after: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, None, torch.Tensor, torch.Tensor]:
        """Classify the action that turns ``x_before`` into ``x_after``.

        Returns:
            ``(logits, delta, None, z_before, z_after)``. The third slot is
            always ``None``; it keeps the tuple layout identical to
            ``PatchWiseDeltaEmbeddingModel.forward`` so the same callback can
            consume both models.
        """
        z_before = self.extract_features(x_before)
        z_after = self.extract_features(x_after)
        delta = z_after - z_before
        logits = self.action_head(delta)
        return logits, delta, None, z_before, z_after

    def supervised_contrastive_loss(self, deltas: torch.Tensor, action_labels: torch.Tensor) -> torch.Tensor:
        """Supervised contrastive loss over L2-normalised deltas.

        Samples sharing an action label are positives for each other; a
        sample is never its own positive or negative. Anchors without any
        positive in the batch contribute zero to the batch mean.
        """
        batch_size = deltas.shape[0]

        deltas_norm = F.normalize(deltas, dim=1, p=2)
        sim_matrix = torch.mm(deltas_norm, deltas_norm.t()) / self.hparams.temperature

        # 1 everywhere except on the diagonal (self-pairs are excluded).
        not_self = 1 - torch.eye(batch_size, device=deltas.device)
        same_action = (action_labels.unsqueeze(0) == action_labels.unsqueeze(1)).float()
        positive_mask = same_action * not_self

        # Subtract the row maximum before exponentiating for numerical stability.
        logits_max, _ = torch.max(sim_matrix, dim=1, keepdim=True)
        logits = sim_matrix - logits_max.detach()

        exp_logits = torch.exp(logits) * not_self
        log_prob = logits - torch.log(exp_logits.sum(dim=1, keepdim=True) + 1e-8)

        # Counts are whole numbers, so clamping at 1 only replaces zeros and
        # avoids a division by zero for anchors without positives.
        num_positives = positive_mask.sum(dim=1).clamp(min=1.0)

        mean_log_prob_pos = (positive_mask * log_prob).sum(dim=1) / num_positives
        return -mean_log_prob_pos.mean()

    def causal_sparsity_loss(self, deltas: torch.Tensor) -> torch.Tensor:
        """Mean absolute value of the deltas (L1 penalty)."""
        return torch.mean(torch.abs(deltas))

    def training_step(self, batch, batch_idx):
        """Compute and log the combined training loss for one batch."""
        x_before, x_after, action_labels, _ = batch

        logits, deltas, _, _, _ = self(x_before, x_after)

        ce_loss = F.cross_entropy(logits, action_labels)
        contrastive_loss = self.supervised_contrastive_loss(deltas, action_labels)
        sparsity_loss = self.causal_sparsity_loss(deltas)

        total_loss = (
            ce_loss
            + self.hparams.alpha_contrast * contrastive_loss
            + self.hparams.alpha_sparsity * sparsity_loss
        )

        acc = (logits.argmax(dim=1) == action_labels).float().mean()

        self.log_dict({
            "train/loss": total_loss,
            "train/acc": acc,
            "train/ce_loss": ce_loss,
            "train/contrastive_loss": contrastive_loss,
            "train/sparsity_loss": sparsity_loss,
        }, on_step=True, on_epoch=True, prog_bar=True)

        return total_loss

    def validation_step(self, batch, batch_idx, dataloader_idx: int = 0):
        """Log loss, accuracy and per-action accuracy for one validation batch."""
        x_before, x_after, action_labels, _ = batch

        logits, deltas, _, _, _ = self(x_before, x_after)
        loss = F.cross_entropy(logits, action_labels)
        correct = (logits.argmax(dim=1) == action_labels).float()
        acc = correct.mean()

        self.log(f"val{dataloader_idx}/loss", loss, on_epoch=True)
        self.log(f"val{dataloader_idx}/acc", acc, on_epoch=True, prog_bar=True)

        # Grow the buffer on demand so a third (or later) validation
        # dataloader does not raise IndexError.
        while len(self.validation_outputs) <= dataloader_idx:
            self.validation_outputs.append([])
        self.validation_outputs[dataloader_idx].append({
            'deltas': deltas.detach().cpu(),
            'action_labels': action_labels.cpu()
        })

        # Log per-action accuracies. batch_size is set to the number of
        # samples of that action so the epoch-level mean is weighted by how
        # many such samples each batch held, not by the full batch size.
        for action_idx in range(self.hparams.num_actions):
            mask = action_labels == action_idx
            count = int(mask.sum())
            if count > 0:
                self.log(
                    f'val{dataloader_idx}/action_{action_idx}',
                    correct[mask].mean(),
                    on_epoch=True,
                    batch_size=count,
                )

        return acc

    def configure_optimizers(self):
        """AdamW with a reduced backbone learning rate and cosine annealing."""
        backbone_params = list(self.backbone.parameters())
        causal_params = (list(self.causal_projector.parameters()) +
                         list(self.action_head.parameters()))

        optimizer = torch.optim.AdamW([
            {"params": backbone_params, "lr": self.hparams.lr * self.hparams.lr_backbone_mult},
            {"params": causal_params, "lr": self.hparams.lr},
        ], weight_decay=self.hparams.weight_decay)

        # NOTE: the cosine period is fixed at 50 scheduler steps regardless
        # of how long the trainer is configured to run.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def on_validation_epoch_end(self):
        """Track the best accuracy seen on validation dataloader 1 and reset buffers."""
        # The "/dataloader_idx_1" suffix is part of the metric key that is
        # looked up here for the second validation dataloader.
        current_ood_acc = self.trainer.callback_metrics.get("val1/acc/dataloader_idx_1")
        if current_ood_acc is not None:
            current_ood_acc = current_ood_acc.item()
            if current_ood_acc > self.best_val_ood_acc:
                self.best_val_ood_acc = current_ood_acc

        self.validation_outputs = [[], []]

    def on_train_end(self):
        """Report the best tracked accuracy through the logger's experiment."""
        experiment = getattr(self.logger, "experiment", None)
        # Only experiment objects exposing .log() (e.g. a wandb run) can
        # take this call; others are skipped instead of crashing.
        if experiment is not None and hasattr(experiment, "log"):
            experiment.log({"best_val_ood_acc": self.best_val_ood_acc})

    @staticmethod
    def get_callbacks(every_n_epochs=5, action_index_to_name=None, object_index_to_name=None):
        """Build the callbacks used with this model.

        Args:
            every_n_epochs: Run the analysis callback every this many epochs.
            action_index_to_name: Optional mapping from action index to a
                display name.
            object_index_to_name: Accepted for interface compatibility; not
                used.
        """
        return [
            AnalysisCallback(
                log_every_n_epochs=every_n_epochs,
                action_index_to_name=action_index_to_name,
            ),
        ]


class PatchWiseDeltaEmbeddingModel(pl.LightningModule):
    """Action classifier built from per-patch embedding differences.

    The timm backbone is created without pooling, every token is projected by
    a shared MLP, and the "after" minus "before" difference is taken per
    patch. The per-patch deltas are then reduced to a single vector per
    sample by the configured aggregation strategy and classified.

    Args:
        backbone_name: Model name handed to ``timm.create_model``.
        proj_dim: Output width of the patch projector.
        hidden_dim: Hidden width of the patch projector MLP.
        num_actions: Number of classes produced by the classifier.
        patch_size: Patch edge length used to derive the expected patch count.
        img_size: Image edge length used to derive the expected patch count.
        temperature: Divisor applied to cosine similarities in the
            contrastive loss.
        alpha_contrast: Weight of the contrastive loss in the total loss.
        alpha_sparsity: Weight of the L1 sparsity loss in the total loss.
        aggregation_strategy: ``"top_k"``, ``"threshold"`` or ``"attention"``;
            any other value selects the single largest-magnitude patch.
        top_k: Number of patches kept by the ``"top_k"`` strategy.
        lr: Learning rate of all non-backbone parameters.
        weight_decay: AdamW weight decay used for both parameter groups.
        lr_backbone_mult: Multiplier applied to ``lr`` for the backbone.
        is_multi_object: Selects the ``"multi"`` or ``"single"`` prefix of
            the logged metric names.
    """

    def __init__(
        self,
        backbone_name: str = "vit_base_patch16_224.dino",
        proj_dim: int = 512,
        hidden_dim: int = 1024,
        num_actions: int = 7,
        patch_size: int = 16,
        img_size: int = 224,

        temperature: float = 0.07,
        alpha_contrast: float = 2.0,
        alpha_sparsity: float = 0.5,

        aggregation_strategy: str = "top_k",
        top_k: int = 2,

        lr: float = 1e-4,
        weight_decay: float = 0.05,
        lr_backbone_mult: float = 0.1,

        is_multi_object: bool = True,
    ):
        super().__init__()
        # Makes every constructor argument available as self.hparams.<name>.
        self.save_hyperparameters()

        self.backbone = timm.create_model(
            backbone_name, pretrained=True, num_classes=0, global_pool=""
        )
        self.feat_dim = self.backbone.num_features
        self.num_patches = (img_size // patch_size) ** 2

        self.patch_projector = nn.Sequential(
            nn.Linear(self.feat_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, proj_dim),
            nn.LayerNorm(proj_dim),
        )

        self.action_classifier = nn.Sequential(
            nn.Linear(proj_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_actions),
        )

        self.aggregation_strategy = aggregation_strategy
        self.top_k = top_k
        # The scoring network only exists when the attention strategy is used.
        if self.aggregation_strategy == "attention":
            self.patch_attention = nn.Sequential(
                nn.Linear(self.hparams.proj_dim, 128),
                nn.ReLU(),
                nn.Linear(128, 1)
            )

        self.best_val_ood_acc = 0.0

    def _metric_prefix(self) -> str:
        """Return the metric-name prefix selected by ``is_multi_object``."""
        return "multi" if self.hparams.is_multi_object else "single"

    def extract_patch_features(self, x: torch.Tensor) -> torch.Tensor:
        """Return the backbone's token features without the leading CLS token.

        The first token is dropped only when the token count is exactly
        ``num_patches + 1``; any other layout is returned unchanged.
        """
        features = self.backbone(x)  # (B, num_patches + 1, feat_dim)

        if features.shape[1] == self.num_patches + 1:
            return features[:, 1:, :]  # (B, num_patches, feat_dim)
        return features

    def forward(
        self, x_before: torch.Tensor, x_after: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, object, torch.Tensor, torch.Tensor]:
        """Classify the action that turns ``x_before`` into ``x_after``.

        Returns:
            ``(logits, aggregated_deltas, selected_info, proj_before,
            proj_after)``. ``selected_info`` depends on the strategy: top-k
            indices, a list of selected-patch counts, attention weights, or
            the index of the single selected patch.
        """
        patches_before = self.extract_patch_features(x_before)
        patches_after = self.extract_patch_features(x_after)

        # Every projector layer acts on the last dimension only, so the
        # (B, num_patches, feat_dim) tensors can be projected directly.
        proj_before = self.patch_projector(patches_before)
        proj_after = self.patch_projector(patches_after)

        patch_deltas = proj_after - proj_before  # (B, num_patches, proj_dim)
        patch_magnitudes = torch.norm(patch_deltas, dim=2)  # (B, num_patches)

        if self.aggregation_strategy == "top_k":
            aggregated_deltas, selected_info = self._aggregate_top_k(patch_deltas, patch_magnitudes)
        elif self.aggregation_strategy == "threshold":
            aggregated_deltas, selected_info = self._aggregate_threshold(patch_deltas, patch_magnitudes)
        elif self.aggregation_strategy == "attention":
            aggregated_deltas, selected_info = self._aggregate_attention(patch_deltas)
        else:
            aggregated_deltas, selected_info = self._aggregate_single(patch_deltas, patch_magnitudes)

        logits = self.action_classifier(aggregated_deltas)
        return logits, aggregated_deltas, selected_info, proj_before, proj_after

    def _aggregate_top_k(self, patch_deltas, patch_magnitudes):
        """Softmax-weighted sum of the ``top_k`` largest-magnitude patch deltas.

        Returns the aggregated deltas ``(B, proj_dim)`` and the selected
        patch indices ``(B, k)``.
        """
        # Never ask for more patches than exist (or fewer than one).
        k = max(1, min(self.top_k, patch_magnitudes.shape[1]))
        top_k_values, top_k_indices = torch.topk(patch_magnitudes, k=k, dim=1)

        # Gather the selected deltas for the whole batch at once.
        gather_index = top_k_indices.unsqueeze(-1).expand(-1, -1, patch_deltas.shape[-1])
        selected_deltas = torch.gather(patch_deltas, 1, gather_index)  # (B, k, proj_dim)

        weights = F.softmax(top_k_values, dim=1)  # (B, k)
        aggregated_deltas = (selected_deltas * weights.unsqueeze(-1)).sum(dim=1)

        return aggregated_deltas, top_k_indices

    def _aggregate_threshold(self, patch_deltas, patch_magnitudes):
        """Threshold-based aggregation (adaptive number of patches).

        Patches whose magnitude exceeds the sample's 75th percentile are
        combined with softmax weights over their magnitudes. Returns the
        aggregated deltas and a list with the number of patches used per
        sample.
        """
        num_patches = patch_magnitudes.shape[1]
        thresholds = torch.quantile(patch_magnitudes, 0.75, dim=1, keepdim=True)
        sig_mask = patch_magnitudes > thresholds  # (B, num_patches)

        # If nothing is strictly above the threshold (e.g. all magnitudes are
        # equal), fall back to the single largest patch of that sample.
        empty_rows = ~sig_mask.any(dim=1, keepdim=True)
        largest = F.one_hot(patch_magnitudes.argmax(dim=1), num_patches).bool()
        sig_mask = sig_mask | (largest & empty_rows)

        # -inf scores receive exactly zero softmax weight, which restricts
        # the weighted sum to the selected patches.
        masked_mags = patch_magnitudes.masked_fill(~sig_mask, float("-inf"))
        weights = F.softmax(masked_mags, dim=1)
        aggregated_deltas = (patch_deltas * weights.unsqueeze(-1)).sum(dim=1)

        selected_counts = sig_mask.sum(dim=1).tolist()
        return aggregated_deltas, selected_counts

    def _aggregate_attention(self, patch_deltas):
        """Weighted sum of patch deltas using learned softmax attention scores."""
        attention_scores = self.patch_attention(patch_deltas)  # (B, num_patches, 1)
        attention_weights = F.softmax(attention_scores.squeeze(-1), dim=1)  # (B, num_patches)

        aggregated_deltas = torch.sum(patch_deltas * attention_weights.unsqueeze(-1), dim=1)  # (B, proj_dim)

        return aggregated_deltas, attention_weights

    def _aggregate_single(self, patch_deltas, patch_magnitudes):
        """Use only the largest-magnitude patch delta of every sample.

        Returns the selected deltas ``(B, proj_dim)`` and the selected patch
        index per sample ``(B,)``.
        """
        best_patch = patch_magnitudes.argmax(dim=1)
        sample_index = torch.arange(patch_deltas.shape[0], device=patch_deltas.device)
        return patch_deltas[sample_index, best_patch], best_patch

    def causal_sparsity_loss(self, deltas: torch.Tensor) -> torch.Tensor:
        """Mean absolute value of the deltas (L1 penalty)."""
        return torch.mean(torch.abs(deltas))

    def supervised_contrastive_loss(self, deltas: torch.Tensor, action_labels: torch.Tensor) -> torch.Tensor:
        """Supervised contrastive loss over L2-normalised deltas.

        Samples sharing an action label are positives for each other; a
        sample is never its own positive or negative. Anchors without any
        positive in the batch contribute zero to the batch mean.
        """
        batch_size = deltas.shape[0]

        deltas_norm = F.normalize(deltas, dim=1, p=2)
        sim_matrix = torch.mm(deltas_norm, deltas_norm.t()) / self.hparams.temperature

        # 1 everywhere except on the diagonal (self-pairs are excluded).
        not_self = 1 - torch.eye(batch_size, device=deltas.device)
        same_action = (action_labels.unsqueeze(0) == action_labels.unsqueeze(1)).float()
        positive_mask = same_action * not_self

        # Subtract the row maximum before exponentiating for numerical stability.
        logits_max, _ = torch.max(sim_matrix, dim=1, keepdim=True)
        logits = sim_matrix - logits_max.detach()

        exp_logits = torch.exp(logits) * not_self
        log_prob = logits - torch.log(exp_logits.sum(dim=1, keepdim=True) + 1e-8)

        # Counts are whole numbers, so clamping at 1 only replaces zeros and
        # avoids a division by zero for anchors without positives.
        num_positives = positive_mask.sum(dim=1).clamp(min=1.0)

        mean_log_prob_pos = (positive_mask * log_prob).sum(dim=1) / num_positives
        return -mean_log_prob_pos.mean()

    def training_step(self, batch, batch_idx):
        """Compute and log the combined training loss for one batch."""
        x_before, x_after, action_labels, _ = batch

        # The second output is already aggregated to one vector per sample.
        logits, deltas, _, _, _ = self(x_before, x_after)

        ce_loss = F.cross_entropy(logits, action_labels)
        contrastive_loss = self.supervised_contrastive_loss(deltas, action_labels)
        sparsity_loss = self.causal_sparsity_loss(deltas)

        total_loss = (
            ce_loss
            + self.hparams.alpha_contrast * contrastive_loss
            + self.hparams.alpha_sparsity * sparsity_loss
        )

        acc = (logits.argmax(dim=1) == action_labels).float().mean()

        prefix = self._metric_prefix()
        self.log_dict({
            f"{prefix}/train/loss": total_loss,
            f"{prefix}/train/acc": acc,
            f"{prefix}/train/ce_loss": ce_loss,
            f"{prefix}/train/contrastive_loss": contrastive_loss,
            # The sparsity term is part of the total loss, so log it as well.
            f"{prefix}/train/sparsity_loss": sparsity_loss,
        }, on_step=True, on_epoch=True, prog_bar=True)

        return total_loss

    def validation_step(self, batch, batch_idx, dataloader_idx: int = 0):
        """Log cross-entropy loss and accuracy for one validation batch."""
        x_before, x_after, action_labels, _ = batch

        logits, _, _, _, _ = self(x_before, x_after)
        loss = F.cross_entropy(logits, action_labels)
        acc = (logits.argmax(dim=1) == action_labels).float().mean()

        prefix = self._metric_prefix()
        self.log(f"{prefix}/val{dataloader_idx}/loss", loss, on_epoch=True)
        self.log(f"{prefix}/val{dataloader_idx}/acc", acc, on_epoch=True, prog_bar=True)

        return acc

    def configure_optimizers(self):
        """AdamW with a reduced backbone learning rate and cosine annealing."""
        backbone_params = list(self.backbone.parameters())
        other_params = [p for n, p in self.named_parameters() if 'backbone' not in n]

        optimizer = torch.optim.AdamW([
            {"params": backbone_params, "lr": self.hparams.lr * self.hparams.lr_backbone_mult},
            {"params": other_params, "lr": self.hparams.lr},
        ], weight_decay=self.hparams.weight_decay)

        # NOTE: the cosine period is fixed at 50 scheduler steps regardless
        # of how long the trainer is configured to run.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def on_validation_epoch_end(self):
        """Track best OOD performance (accuracy on validation dataloader 1)."""
        base_key = f"{self._metric_prefix()}/val1/acc"
        metrics = self.trainer.callback_metrics

        # validation_step logs with on_epoch=True only, so the metric is
        # stored without an "_epoch" suffix. The suffixed key is still tried
        # as a fallback in case the logging flags are changed later.
        current_ood_acc = metrics.get(f"{base_key}/dataloader_idx_1")
        if current_ood_acc is None:
            current_ood_acc = metrics.get(f"{base_key}_epoch/dataloader_idx_1")

        if current_ood_acc is not None:
            current_ood_acc = current_ood_acc.item()
            if current_ood_acc > self.best_val_ood_acc:
                self.best_val_ood_acc = current_ood_acc

    def on_train_end(self):
        """Report the best tracked accuracy through the logger's experiment."""
        experiment = getattr(self.logger, "experiment", None)
        # Only experiment objects exposing .log() (e.g. a wandb run) can
        # take this call; others are skipped instead of crashing.
        if experiment is not None and hasattr(experiment, "log"):
            experiment.log({"best_val_ood_acc": self.best_val_ood_acc})

    @staticmethod
    def get_callbacks(every_n_epochs=5, action_index_to_name=None, object_index_to_name=None):
        """Build the callbacks used with this model.

        Args:
            every_n_epochs: Run the analysis callback every this many epochs.
            action_index_to_name: Optional mapping from action index to a
                display name.
            object_index_to_name: Accepted for interface compatibility; not
                used.
        """
        return [
            AnalysisCallback(
                log_every_n_epochs=every_n_epochs,
                action_index_to_name=action_index_to_name,
            )
        ]



class AnalysisCallback(pl.Callback):
    """Periodic embedding analysis run at the end of validation epochs.

    Every ``log_every_n_epochs`` epochs the callback collects delta
    embeddings from the training loader and the validation loaders, then
    logs a k-NN accuracy, a prototype cosine-similarity heatmap and UMAP
    scatter plots to wandb.

    Args:
        log_every_n_epochs: Run the analysis when ``(epoch + 1)`` is a
            multiple of this value.
        max_samples_train: Stop collecting training samples once at least
            this many have been gathered.
        max_samples_val: Same limit for each validation loader.
        action_index_to_name: Optional mapping from action index to a
            display name; indices without an entry are shown as numbers.
    """

    def __init__(
        self,
        log_every_n_epochs: int = 5,
        max_samples_train: int = 2000,
        max_samples_val: int = 500,
        action_index_to_name: dict = None,
    ):
        super().__init__()
        self.log_every_n_epochs = log_every_n_epochs
        self.max_samples_train = max_samples_train
        self.max_samples_val = max_samples_val
        self.action_map = action_index_to_name or {}

    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Run all analyses on the configured epochs.

        Validation loader 0 is treated as the in-distribution ("iid") split
        and loader 1 as the out-of-distribution ("ood") split; analyses that
        need a missing loader are skipped.
        """
        # The pre-training sanity check also fires this hook; skip it.
        if trainer.sanity_checking:
            return
        if (trainer.current_epoch + 1) % self.log_every_n_epochs != 0:
            return

        print(f"\n{'='*20} RUNNING ANALYSIS on Epoch {trainer.current_epoch + 1} {'='*20}")
        # Tolerate a missing logger instead of raising AttributeError.
        logger = getattr(pl_module.logger, "experiment", None)

        train_loader = trainer.train_dataloader
        if train_loader is None:
            print("  No training dataloader available; skipping analysis.")
            return

        # Accept a single loader as well as a list/tuple of loaders.
        val_loaders = trainer.val_dataloaders
        if val_loaders is None:
            val_loaders = []
        elif not isinstance(val_loaders, (list, tuple)):
            val_loaders = [val_loaders]

        # Use training data for stable prototypes, OOD validation data for
        # evaluation. The training loader is traversed once and the result
        # shared by the prototypes and the k-NN fit.
        print("  Calculating action prototypes from training data...")
        train_data = self._collect_data(train_loader, pl_module, self.max_samples_train)
        prototypes = self._prototypes_from_data(train_data)

        iid_data = ood_data = None
        if len(val_loaders) > 0:
            iid_data = self._collect_data(val_loaders[0], pl_module, self.max_samples_val)
        if len(val_loaders) > 1:
            ood_data = self._collect_data(val_loaders[1], pl_module, self.max_samples_val)

        # Run all analyses
        epoch = trainer.current_epoch
        if ood_data is not None:
            self._calculate_quantitative_metrics(train_data, ood_data, epoch, logger)
        else:
            print("  No OOD validation dataloader; skipping k-NN accuracy.")
        self._plot_symmetry_heatmap(prototypes, epoch)
        if iid_data is not None:
            self._plot_umap(iid_data, epoch, "iid")
        if ood_data is not None:
            self._plot_umap(ood_data, epoch, "ood")
        print(f"\n{'='*20} FINAL ANALYSIS COMPLETE {'='*20}")

    def _plot_symmetry_heatmap(self, prototypes: dict, epoch: int):
        """Log a heatmap of pairwise cosine similarities between prototypes.

        Nothing is plotted when fewer than two prototypes are available.
        """
        print("  Generating Prototype Similarity Heatmap for symmetry discovery...")

        if len(prototypes) < 2:
            return

        proto_indices = sorted(prototypes.keys())
        proto_names = [self.action_map.get(i, str(i)) for i in proto_indices]
        proto_vectors = torch.stack([prototypes[i] for i in proto_indices])

        # Cosine similarity = dot product of unit-length vectors.
        proto_vectors_norm = F.normalize(proto_vectors, p=2, dim=1)
        sim_matrix = torch.mm(proto_vectors_norm, proto_vectors_norm.t()).cpu().numpy()

        fig, ax = plt.subplots(figsize=(10, 8))

        sns.heatmap(
            sim_matrix,
            annot=True,
            fmt=".2f",
            cmap="RdBu_r",
            xticklabels=proto_names,
            yticklabels=proto_names,
            ax=ax,
            linewidths=.5,
            vmin=-1, vmax=1
        )

        ax.set_title(f'Structural Relationships (Epoch {epoch+1})', fontsize=14)
        plt.xticks(rotation=45, ha="right")
        plt.yticks(rotation=0)
        plt.tight_layout()

        wandb.log({"Symmetries_Heatmap": wandb.Image(fig)})
        plt.close(fig)

    def _collect_data(self, dataloader, pl_module, max_samples):
        """Run the model over ``dataloader`` and gather its outputs on the CPU.

        Iteration stops after the first batch that brings the number of
        collected samples to ``max_samples`` or more, so slightly more than
        ``max_samples`` samples may be returned.

        Returns:
            Dict with concatenated ``"z_before"``, ``"z_after"``,
            ``"deltas"`` and ``"actions"`` tensors.
        """
        was_training = pl_module.training
        pl_module.eval()
        device = pl_module.device
        all_z_before, all_z_after, all_deltas, all_actions = [], [], [], []
        collected = 0  # running sample count, avoids re-summing every batch

        try:
            with torch.no_grad():
                for batch in dataloader:
                    x_before, x_after, action_labels, _ = batch
                    _, deltas, _, z_before, z_after = pl_module(x_before.to(device), x_after.to(device))

                    all_z_before.append(z_before.cpu())
                    all_z_after.append(z_after.cpu())
                    all_deltas.append(deltas.cpu())
                    all_actions.append(action_labels.cpu())

                    collected += len(deltas)
                    if collected >= max_samples:
                        break
        finally:
            # Put the module back into training mode if that is how it was
            # found; an already-evaluating module is left untouched.
            if was_training:
                pl_module.train()

        return {
            "z_before": torch.cat(all_z_before), "z_after": torch.cat(all_z_after),
            "deltas": torch.cat(all_deltas), "actions": torch.cat(all_actions)
        }

    @staticmethod
    def _prototypes_from_data(data):
        """Return ``{action_index: mean delta}`` for every action present in ``data``."""
        actions = data["actions"]
        deltas = data["deltas"]
        return {
            action_idx: deltas[actions == action_idx].mean(dim=0)
            for action_idx in actions.unique().tolist()
        }

    def _calculate_prototypes(self, dataloader, pl_module):
        """Collect up to ``max_samples_train`` samples and average deltas per action."""
        print("  Calculating action prototypes from training data...")
        train_data = self._collect_data(dataloader, pl_module, self.max_samples_train)
        return self._prototypes_from_data(train_data)

    def _calculate_quantitative_metrics(self, train_data, test_data, epoch, logger):
        """Fit a cosine k-NN on training deltas and log its accuracy on ``test_data``.

        ``epoch`` and ``logger`` are accepted for interface compatibility;
        the result is logged through ``wandb.log``.
        """
        print("  Calculating k-NN Accuracy")

        train_deltas_np = train_data["deltas"].numpy()
        train_labels_np = train_data["actions"].numpy()

        test_deltas_np = test_data["deltas"].numpy()
        test_labels_np = test_data["actions"].numpy()

        if len(train_deltas_np) == 0 or len(test_deltas_np) == 0:
            print("    Not enough data for k-NN; skipping.")
            return

        # k-NN: the classifier cannot use more neighbours than there are
        # training samples.
        n_neighbors = min(5, len(train_deltas_np))
        knn = KNeighborsClassifier(n_neighbors=n_neighbors, metric='cosine')
        knn.fit(train_deltas_np, train_labels_np)
        knn_acc = knn.score(test_deltas_np, test_labels_np)
        wandb.log({"kNN_Accuracy": knn_acc})
        print(f"    k-NN Accuracy: {knn_acc:.4f}")

    def _plot_umap(self, data, epoch, set):
        """Project deltas to 2-D with UMAP, log the scatter plot and save it.

        Args:
            data: Dict as returned by ``_collect_data``.
            epoch: Zero-based epoch index (shown as ``epoch + 1`` in the title).
            set: Name of the data split (e.g. ``"iid"`` or ``"ood"``); used
                in the title, the wandb key and the output file name.
        """
        split = set  # the parameter name shadows the builtin; use an alias below
        print("  Generating UMAP plot...")

        deltas_np = data["deltas"].numpy()
        if len(deltas_np) < 3:
            print("    Too few samples for UMAP; skipping.")
            return

        reducer = umap.UMAP(n_neighbors=5, min_dist=0.5, n_components=2, random_state=42,
                            metric='cosine', spread=2.0, n_epochs=200)
        embeddings_2d = reducer.fit_transform(deltas_np)

        df = pd.DataFrame({
            "action": [self.action_map.get(i, str(i)) for i in data["actions"].tolist()],
            "umap-1": embeddings_2d[:, 0], "umap-2": embeddings_2d[:, 1]
        })

        fig, ax = plt.subplots(figsize=(14, 10))
        unique_actions = df["action"].unique()
        markers = ['o', 's', '^', 'D', 'v', '<', '>', 'p', '*', 'h', 'H', '+', 'x']
        colors = plt.cm.tab10(np.linspace(0, 1, len(unique_actions)))

        for i, action in enumerate(unique_actions):
            action_data = df[df["action"] == action]
            ax.scatter(action_data["umap-1"], action_data["umap-2"],
                       c=[colors[i]], marker=markers[i % len(markers)],
                       s=100, alpha=0.8, label=action,
                       edgecolors='black', linewidth=0.5)

        # Name the split that is actually plotted rather than always "OOD".
        ax.set_title(f'UMAP of Delta Embeddings ({str(split).upper()}, Epoch {epoch+1})', fontsize=16)
        ax.set_xlabel('UMAP Dimension 1', fontsize=12)
        ax.set_ylabel('UMAP Dimension 2', fontsize=12)
        ax.grid(True, alpha=0.3)

        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)
        plt.tight_layout()
        # A split-specific key keeps the iid and ood plots apart in wandb.
        wandb.log({f"UMAP_Visualization_{split}": wandb.Image(fig)})

        plot_dir = "plots"
        os.makedirs(plot_dir, exist_ok=True)
        fig.savefig(os.path.join(plot_dir, f"umap_epoch_{epoch}_{split}.png"), dpi=300, bbox_inches='tight')

        plt.close(fig)