import torch
import torch.nn as nn
import torch.nn.functional as F
from .clip_feature_extractor import CLIPFeatureExtractor
from .ttt_module import TTTModule
import copy
from torch.func import functional_call
from collections import OrderedDict


def select_dissimilar_indices_batched(pairwise_dists, k):
    """Pick, per batch element, the ``k`` candidates with the largest summed distance.

    A cheap approximation of diverse-subset selection: every candidate is
    scored by the sum of its distances to all candidates, and the ``k``
    highest scores are returned in descending score order (``torch.topk``
    semantics, so ``k`` must not exceed N).

    pairwise_dists: [B, N, N]
    Returns: [B, k] integer indices into the N candidates.
    """
    scores = pairwise_dists.sum(dim=-1)  # [B, N]
    _, top_idx = scores.topk(k, dim=-1)
    return top_idx

class ValueHead(nn.Module):
    """Two-layer MLP mapping features to a scalar score in (0, 1).

    ``(..., input_dim)`` -> ``(..., 4 * input_dim)`` with ReLU -> one logit ->
    sigmoid.  The trailing singleton dimension is squeezed away, so the output
    has shape ``(...)``.
    """

    def __init__(self, input_dim):
        super().__init__()
        hidden_dim = 4 * input_dim
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, 1)

        # He (Kaiming) init with the ReLU gain on both layers, zero biases.
        # Order matches: weight1, bias1, weight2, bias2 (keeps RNG draws identical).
        for layer in (self.linear1, self.linear2):
            nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
            nn.init.constant_(layer.bias, 0.0)

        self.activation = nn.ReLU()

    def forward(self, x):
        hidden = self.activation(self.linear1(x))
        logit = self.linear2(hidden)
        return logit.sigmoid().squeeze(-1)


# --- Main TTT Model ---
class VLMAdapt(nn.Module):
    """Goal-conditioned task-progress estimator with test-time training (TTT).

    Each frame is encoded with ``CLIPFeatureExtractor`` and concatenated with
    the CLIP embedding of the goal text.  Three linear projections map the
    fused features to a low-rank space:

    * ``P_K`` -- input view of the self-supervised reconstruction task,
    * ``P_V`` -- reconstruction target,
    * ``P_Q`` -- query view used for progress prediction.

    ``f_adapt`` (a ``TTTModule``) is the component adapted by the inner
    self-supervised loop; ``progress_head`` turns its output into a score in
    (0, 1).  ``rnn`` is only used by the GRU baseline (``train_batch_rnn``).
    """

    def __init__(
        self,
        clip_model_name="ViT-B-32",
        pretrained_clip="openai",
        projection_dim=256,
        lambda_self=1.0
    ):
        """Build the CLIP encoder, projections, adaptation module and heads.

        Args:
            clip_model_name: CLIP architecture name forwarded to CLIPFeatureExtractor.
            pretrained_clip: pretrained-weights tag forwarded to CLIPFeatureExtractor.
            projection_dim: requested low-rank width -- currently IGNORED (see note).
            lambda_self: stored as ``self.lambda_self``; the training methods in
                this class take their weighting explicitly via ``lambda_``.
        """
        super().__init__()
        self.clip_feature_extractor = CLIPFeatureExtractor(clip_model_name, pretrained_clip)
        # Image and text embeddings are concatenated, doubling the feature width.
        clip_dim = 2 * self.clip_feature_extractor.visual_encoder.output_dim
        # NOTE(review): the ``projection_dim`` argument is overridden by a
        # hard-coded 64.  Kept so the model architecture is unchanged; confirm
        # whether the argument should be honoured.  Logged after the override
        # so the printed value is the one actually used.
        projection_dim = 64
        print(f'CLIP feature dimension: {clip_dim}')
        print(f'Projection dimension: {projection_dim}')

        # Low-rank projections (construction order kept for RNG reproducibility)
        self.P_K = nn.Linear(clip_dim, projection_dim)
        self.P_V = nn.Linear(clip_dim, projection_dim)
        self.P_Q = nn.Linear(clip_dim, projection_dim)

        # RNN for baseline
        self.rnn = nn.GRU(projection_dim, projection_dim, batch_first=True)
        # Adaptation Module
        self.f_adapt = TTTModule(projection_dim)

        # Progression Head
        self.progress_head = ValueHead(projection_dim)

        # Losses
        self.self_loss_fn = nn.MSELoss()
        self.pred_loss_fn = nn.MSELoss()

        # Store parameters for TTT
        self.lambda_self = lambda_self

    # ------------------------------------------------------------------
    # Embeddings
    # ------------------------------------------------------------------

    def _encode(self, frames, goals):
        """Encode frames and goal texts with CLIP.

        frames: (B, T, C, H, W).  ``goals`` is passed straight to
        ``encode_text`` (presumably a batch of B goal strings -- confirm).
        Returns ``(img_emb, txt_emb)``, both (B, T, D); the per-sample text
        embedding is broadcast (via ``expand``) over the T frames.
        """
        B, T, C, H, W = frames.shape
        # reshape (not view) so non-contiguous frame tensors are accepted too.
        img_emb = self.clip_feature_extractor.encode_image(frames.reshape(-1, C, H, W)).view(B, T, -1)
        txt_emb = self.clip_feature_extractor.encode_text(goals).unsqueeze(1).expand(-1, T, -1)
        return img_emb, txt_emb

    def forward_joint_embedding(self, frames, goals):
        """Return per-frame ``[image ‖ text]`` CLIP features, shape (B, T, 2D)."""
        img_emb, txt_emb = self._encode(frames, goals)
        return torch.cat([img_emb, txt_emb], dim=-1)

    def forward_dot_product_embedding(self, frames, goals):
        """Like ``forward_joint_embedding`` but with each half L2-normalized.

        Despite the name, no dot product is taken: the result is the
        concatenation of the unit-norm image and text embeddings, (B, T, 2D).
        """
        img_emb, txt_emb = self._encode(frames, goals)
        # F.normalize guards against division by zero for all-zero embeddings.
        return torch.cat([F.normalize(img_emb, dim=-1), F.normalize(txt_emb, dim=-1)], dim=-1)

    # ------------------------------------------------------------------
    # Forward pieces
    # ------------------------------------------------------------------

    def forward_self_supervised_inference(self, fused):
        """Return ``(f_adapt(P_K(fused)), P_V(fused))`` with projections frozen.

        Note the first element is the *reconstruction*, not the corrupted input.
        """
        # Projections are fixed hyperparameters during self-supervised
        with torch.no_grad():
            corrupted_input = self.P_K(fused)
            target_view = self.P_V(fused)
        recon = self.f_adapt(corrupted_input)
        return recon, target_view

    def forward_prediction(self, fused):
        """Progress scores from the query view: ``progress_head(f_adapt(P_Q(fused)))``."""
        z_t = self.f_adapt(self.P_Q(fused))
        return self.progress_head(z_t)

    def forward_self_supervised(self, fused):
        """Reconstruction loss ``MSE(f_adapt(P_K(fused)), P_V(fused))`` (projections trainable)."""
        corrupted_input = self.P_K(fused)
        target_view = self.P_V(fused)
        recon = self.f_adapt(corrupted_input)
        return self.self_loss_fn(recon, target_view)

    # ------------------------------------------------------------------
    # Shared training helpers
    # ------------------------------------------------------------------

    def _apply_gradients(self, loss, optimizer, grad_clip):
        """Zero grads, backpropagate ``loss``, optionally clip the global norm, and step."""
        optimizer.zero_grad()
        loss.backward()
        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(self.parameters(), grad_clip)
        optimizer.step()

    def _ttt_scan(self, x, labels, masks, lambda_, step=1):
        """Run the differentiable inner TTT loop over ``x`` in chunks of ``step`` frames.

        Starts from a clone of the current ``f_adapt`` parameters.  For each
        chunk it takes one gradient step (learning rate ``lambda_``) on the
        reconstruction loss ``MSE(f_adapt(P_K(x)), P_V(x))``, then predicts
        progress from ``P_Q(x)`` with the *updated* weights.  Because the inner
        gradients use ``create_graph=True``, the outer loss can backpropagate
        through the updates into the initial weights and the projections.

        x: (B, L, D); labels, masks: (B, L).
        Returns ``(sum of prediction losses, sum of self-supervised losses)``.
        """
        weights = OrderedDict((name, param.clone()) for name, param in self.f_adapt.named_parameters())
        loss_pred_sum = 0.0
        loss_self_sum = 0.0

        for start in range(0, x.shape[1], step):
            x_c = x[:, start:start + step, :]
            y_c = labels[:, start:start + step]
            m_c = masks[:, start:start + step]

            corrupted = self.P_K(x_c)
            target = self.P_V(x_c)
            recon = functional_call(self.f_adapt, weights, (corrupted,))
            loss_self = self.self_loss_fn(recon, target)

            # One SGD step on the adaptation weights (W <- W - lambda_ * dL/dW).
            grads = torch.autograd.grad(loss_self, tuple(weights.values()), create_graph=True)
            weights = OrderedDict(
                (name, param - lambda_ * grad)
                for (name, param), grad in zip(weights.items(), grads)
            )

            adapted_out = functional_call(self.f_adapt, weights, (self.P_Q(x_c),))
            pred = self.progress_head(adapted_out)
            # Masked positions contribute zero error but still count in the mean.
            loss_pred = self.pred_loss_fn(pred * m_c, y_c * m_c)

            loss_pred_sum = loss_pred_sum + loss_pred
            loss_self_sum = loss_self_sum + loss_self

        return loss_pred_sum, loss_self_sum

    # ------------------------------------------------------------------
    # Training entry points
    # ------------------------------------------------------------------

    def train_batch_ft(self, frames, goals, progress_labels, valid_masks, optimizer, lambda_, random_w_size=8, num_windows=8, grad_clip=None):
        """Train ``progress_head(P_Q(fused))`` directly -- no TTT, no self-supervision.

        ``lambda_``, ``random_w_size`` and ``num_windows`` are accepted for a
        uniform signature and unused.  Returns ``(prediction loss, 0.0)``.
        """
        fused = self.forward_joint_embedding(frames, goals)  # [B, T, D]
        pred = self.progress_head(self.P_Q(fused))  # (B, T)

        # Masked positions contribute zero error but still count in the mean.
        loss_pred = self.pred_loss_fn(pred * valid_masks, progress_labels * valid_masks)

        self._apply_gradients(loss_pred, optimizer, grad_clip)
        return loss_pred.item(), 0.0

    def train_batch_rnn(self, frames, goals, progress_labels, valid_masks, optimizer, lambda_, random_w_size=8, num_windows=8, grad_clip=None):
        """GRU baseline: ``progress_head(GRU(P_Q(fused)))``, trained on the masked MSE.

        ``lambda_``, ``random_w_size`` and ``num_windows`` are unused.
        Returns ``(prediction loss, 0.0)``.
        """
        fused = self.forward_joint_embedding(frames, goals)  # [B, T, D]
        x_proj = self.P_Q(fused)  # [B, T, H]
        rnn_out, _ = self.rnn(x_proj)  # [B, T, H] (batch_first=True)

        # ValueHead already squeezes its last dim; a second squeeze would turn
        # [B, 1] into [B] when T == 1 and broadcast the loss to [B, B].
        pred = self.progress_head(rnn_out)  # [B, T]

        loss_pred = self.pred_loss_fn(pred * valid_masks, progress_labels * valid_masks)

        self._apply_gradients(loss_pred, optimizer, grad_clip)
        return loss_pred.item(), 0.0

    def train_batch_dissimilarity(self, frames, goals, progress_labels, valid_masks, optimizer,
                    lambda_, random_w_size=8, num_windows=2, grad_clip=None):
        """TTT on the most mutually dissimilar windows of each sequence.

        Candidate windows of ``random_w_size`` frames are taken with 50 %
        overlap; each is mean-pooled and L2-normalized, and the ``num_windows``
        windows with the largest summed pairwise distance are selected per
        sample.  Each selected window is adapted from the initial ``f_adapt``
        weights one frame at a time (see ``_ttt_scan``).  Sequences shorter
        than the window use a single window of length T, and ``num_windows``
        is capped at the number of candidates.

        Outer loss: ``sum(pred losses) + lambda_ * sum(self losses)``.
        Returns both sums as Python floats.
        """
        fused = self.forward_joint_embedding(frames, goals)  # [B, T, D]
        B, T, D = fused.shape

        w = min(int(random_w_size), T)
        stride = max(1, w // 2)

        # Windows starting at 0, stride, 2*stride, ... (last one ends <= T).
        fused_windows = fused.unfold(1, w, stride).permute(0, 1, 3, 2)  # [B, N, w, D]
        label_windows = progress_labels.unfold(1, w, stride)           # [B, N, w]
        mask_windows = valid_masks.unfold(1, w, stride)                # [B, N, w]
        num_cands = fused_windows.shape[1]

        # topk cannot return more indices than there are candidates.
        k = min(num_windows, num_cands)

        pooled = F.normalize(fused_windows.mean(dim=2), dim=-1)  # [B, N, D]
        pdists = torch.cdist(pooled, pooled, p=2)                # [B, N, N]
        selected = select_dissimilar_indices_batched(pdists, k=k)  # [B, k]

        fused_chunks = torch.gather(fused_windows, 1, selected[:, :, None, None].expand(B, k, w, D))
        label_chunks = torch.gather(label_windows, 1, selected[:, :, None].expand(B, k, w))
        mask_chunks = torch.gather(mask_windows, 1, selected[:, :, None].expand(B, k, w))

        total_loss_pred = 0.0
        total_loss_self = 0.0
        for n in range(k):
            loss_pred, loss_self = self._ttt_scan(
                fused_chunks[:, n], label_chunks[:, n], mask_chunks[:, n], lambda_
            )
            total_loss_pred = total_loss_pred + loss_pred
            total_loss_self = total_loss_self + loss_self

        # NOTE: lambda_ is both the inner learning rate and the outer weight.
        total_loss = total_loss_pred + lambda_ * total_loss_self
        self._apply_gradients(total_loss, optimizer, grad_clip)
        return total_loss_pred.item(), total_loss_self.item()

    def train_batch_random(self, frames, goals, progress_labels, valid_masks, optimizer, lambda_, random_w_size=8, num_windows=8, grad_clip=None):
        """TTT on ``num_windows`` randomly placed windows per sequence.

        Every window starts at an independent random offset per sample (offset
        0 when T <= ``random_w_size``), spans ``min(random_w_size, T)`` frames,
        and is adapted from the initial ``f_adapt`` weights one frame at a time
        (see ``_ttt_scan``).  Clamping to T avoids NaN losses from empty
        timesteps on short sequences.

        Outer loss: ``sum(pred losses) + lambda_ * sum(self losses)``.
        Returns both sums as Python floats.
        """
        fused = self.forward_joint_embedding(frames, goals)  # [B, T, D]
        B, T, D = fused.shape

        w = min(int(random_w_size), T)
        offsets = torch.arange(w, device=fused.device)
        rows = torch.arange(B, device=fused.device)[:, None]

        total_loss_pred = 0.0
        total_loss_self = 0.0
        for _ in range(num_windows):
            if T <= random_w_size:
                starts = torch.zeros(B, dtype=torch.long, device=fused.device)
            else:
                starts = torch.randint(0, int(T - random_w_size + 1), (B,), device=fused.device)

            # Vectorized per-sample window slicing: positions [B, w].
            positions = starts[:, None] + offsets
            fused_chunks = fused[rows, positions]          # [B, w, D]
            label_chunks = progress_labels[rows, positions]  # [B, w]
            mask_chunks = valid_masks[rows, positions]       # [B, w]

            loss_pred, loss_self = self._ttt_scan(fused_chunks, label_chunks, mask_chunks, lambda_)
            total_loss_pred = total_loss_pred + loss_pred
            total_loss_self = total_loss_self + loss_self

        total_loss = total_loss_pred + lambda_ * total_loss_self
        self._apply_gradients(total_loss, optimizer, grad_clip)
        return total_loss_pred.item(), total_loss_self.item()

    def train_mini_batch(self, frames, goals, progress_labels, valid_masks, optimizer, lambda_, random_w_size=8, num_windows=2, grad_clip=None):
        """Train with online, chunk-wise TTT over the whole sequence.

        The sequence is split into consecutive chunks of ``random_w_size``
        frames; the adaptation weights are carried from chunk to chunk, each
        chunk taking one inner step before predicting (see ``_ttt_scan``).
        ``num_windows`` is unused.

        Outer loss: ``sum(pred losses) + lambda_ * sum(self losses)``.
        Returns both sums as Python floats.
        """
        chunk_size = random_w_size
        fused = self.forward_joint_embedding(frames, goals)  # [B, T, D]

        total_loss_pred, total_loss_self = self._ttt_scan(
            fused, progress_labels, valid_masks, lambda_, step=chunk_size
        )

        # lambda_ controls the reconstruction term's weight in the outer loss.
        total_loss = total_loss_pred + lambda_ * total_loss_self
        self._apply_gradients(total_loss, optimizer, grad_clip)
        return total_loss_pred.item(), total_loss_self.item()

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------

    def offline_ttt_inference(self, frames, goal_text, ttt_lr=1e-3, ttt_epochs=5):
        """Adapt a copy of the model on the whole sequence, then predict all frames.

        A deep copy is made so ``self`` is untouched; only the copy's
        ``f_adapt`` parameters are trainable.  ``manual_update`` is called
        ``ttt_epochs`` times with the same ``(P_K(fused), P_V(fused))`` pair --
        the same input/target convention as ``windowed_ttt_inference``.
        Returns the adapted model's ``forward_prediction(fused)``.
        """
        # No outer gradients are needed; this also keeps manual_update from
        # backpropagating into the original model's parameters.
        with torch.no_grad():
            fused = self.forward_joint_embedding(frames, goal_text)

        adapted_model = copy.deepcopy(self)
        adapted_model.train()

        # Freeze everything except f_adapt
        for param in adapted_model.parameters():
            param.requires_grad = False
        for param in adapted_model.f_adapt.parameters():
            param.requires_grad = True

        # The adaptation input must be P_K(fused); the previous code passed
        # f_adapt(P_K(fused)) (the reconstruction) here by mistake.
        with torch.no_grad():
            corrupted = adapted_model.P_K(fused)
            target = adapted_model.P_V(fused)

        for _ in range(ttt_epochs):
            adapted_model.f_adapt.manual_update(corrupted, target, lr=ttt_lr)

        adapted_model.eval()
        with torch.no_grad():
            predictions = adapted_model.forward_prediction(fused)

        return predictions

    def windowed_ttt_inference(self, frames, goal_text, ttt_lr=1e-3, ttt_epochs=5, window_size=8, reset=True):
        """Causal sliding-window TTT: adapt on frames ``[t-window_size+1, t]``, predict frame t.

        Works on a deep copy of the model.  With ``reset=True`` the copy's
        ``f_adapt`` is restored to the original weights before every timestep;
        otherwise adaptation accumulates across timesteps.
        Returns predictions of shape (B, T) (per-step outputs concatenated on dim 1).
        """
        T = frames.shape[1]
        with torch.no_grad():
            fused = self.forward_joint_embedding(frames, goal_text)  # (B, T, D)

        adapted_model = copy.deepcopy(self)
        initial_state = copy.deepcopy(self.f_adapt.state_dict())
        adapted_model.eval()

        for param in adapted_model.parameters():
            param.requires_grad = False
        for param in adapted_model.f_adapt.parameters():
            param.requires_grad = True

        predictions = []
        for t in range(T):
            if reset:
                adapted_model.f_adapt.load_state_dict(initial_state)

            # Window of (at most) window_size frames ending at t
            start = max(0, t - window_size + 1)
            window = fused[:, start:t + 1, :]

            corrupted = adapted_model.P_K(window)
            target = adapted_model.P_V(window)
            for _ in range(ttt_epochs):
                adapted_model.f_adapt.manual_update(corrupted, target, lr=ttt_lr)

            # Prediction needs no graph; without no_grad every step's graph
            # would be kept alive by the returned tensor.
            with torch.no_grad():
                query_t = fused[:, t, :].unsqueeze(1)  # [B, 1, D]
                predictions.append(adapted_model.forward_prediction(query_t))  # [B, 1]

        return torch.cat(predictions, dim=1)  # [B, T]

    def inference_no_ttt(self, frames, goal_text):
        """Predict progress frame by frame with the unadapted model.

        Switches ``self`` to eval mode (and leaves it there) *before* encoding,
        so the embedding is also computed in eval mode.  Returns (B, T).
        """
        self.eval()
        T = frames.shape[1]

        predictions = []
        with torch.no_grad():
            fused = self.forward_joint_embedding(frames, goal_text)
            for t in range(T):
                fused_t = fused[:, t, :].unsqueeze(1)  # (B, 1, D)
                predictions.append(self.forward_prediction(fused_t))  # (B, 1)

        return torch.cat(predictions, dim=1)  # (B, T)

    def compute_clip_similarity_score(self, frames, goals):
        """CLIP baseline: per-frame cosine similarity to the goal text, mapped to [0, 1].

        Returns (B, T) scores ``(cos + 1) / 2``.
        """
        img_emb, txt_emb = self._encode(frames, goals)
        cosine = (F.normalize(img_emb, dim=-1) * F.normalize(txt_emb, dim=-1)).sum(dim=-1)  # [-1, 1]
        return (cosine + 1.0) / 2.0



