import contextlib

import torch
import torch.nn as nn
import numpy as np

from backward_model.models.Timer import Model as timer


class ITS:
    """Fuse several candidate predictions using a pre-trained backward model.

    Each candidate in a stack of samples is scored by how well the backward
    model's reconstruction agrees with a reference signal. Candidates with a
    lower (per-element, standardized) squared error receive a larger softmax
    weight, and the result is the weighted sum of the candidates.
    """

    def __init__(self, args, backward_checkpoints, device):
        """Store the configuration and load the backward model.

        Args:
            args: Configuration namespace. It is passed to the model constructor;
                ``patch_len`` and ``features`` are read from it during imputation.
            backward_checkpoints: Path to a checkpoint holding the backward
                model's ``state_dict``.
            device: Device the model and auxiliary tensors are placed on.
        """
        self.args = args
        self.device = device
        self.backward_model = self._load_model(backward_checkpoints, timer)

    def _load_model(self, weight_path, model_class):
        """Build ``model_class(self.args)``, load its weights and return it in eval mode."""
        model = model_class(self.args)
        # The checkpoint is fed straight into ``load_state_dict``, so it only has
        # to contain tensors; ``weights_only=True`` refuses to unpickle arbitrary
        # Python objects from the file.
        checkpoint = torch.load(weight_path, map_location=self.device, weights_only=True)
        model.load_state_dict(checkpoint)
        model = model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _autocast(tensor):
        """Return a CUDA autocast context for CUDA inputs, else a no-op context.

        CUDA autocast does not affect CPU tensors and warns when CUDA is not
        available, so it is only entered when ``tensor`` actually lives on CUDA.
        """
        if tensor.is_cuda:
            return torch.amp.autocast('cuda')
        return contextlib.nullcontext()

    @staticmethod
    def _fuse(errors, candidates):
        """Softmax-weight ``candidates`` over dim 0 by standardized negative error.

        ``errors`` and ``candidates`` share the leading sample dimension. Errors
        are standardized per element across samples using the population std
        (``unbiased=False``), so a single candidate gets weight 1 rather than a
        NaN from a zero-degrees-of-freedom std. Lower error -> larger weight.
        Returns the weighted sum with the sample dimension reduced away.
        """
        mean_errors = errors.mean(dim=0, keepdim=True)
        # 1e-8 keeps the division finite when all samples agree exactly.
        std_errors = errors.std(dim=0, unbiased=False, keepdim=True) + 1e-8
        weights = torch.softmax(-(errors - mean_errors) / std_errors, dim=0)
        return (weights * candidates).sum(dim=0)

    def reverse_forecast(self, step_samples_tensor, history):
        """Fuse forecast candidates by how well they reconstruct ``history``.

        Args:
            step_samples_tensor: Candidates, shape
                ``[n_samples, batch_size, seq_len, feat_dim]``.
            history: Reference values compared against each reconstruction;
                assumed ``[batch_size, seq_len, feat_dim]`` (it must broadcast
                against one sample) -- TODO confirm against callers.

        Returns:
            Tensor of shape ``[batch_size, seq_len, feat_dim]``.
        """
        n_samples, batch_size, seq_len, feat_dim = step_samples_tensor.shape

        with torch.no_grad():
            # The backward model is fed candidates reversed along the sequence
            # axis; its output is flipped back before comparison. ``reshape``
            # (not ``view``) also accepts non-contiguous sample stacks.
            flipped_input = torch.flip(
                step_samples_tensor.reshape(-1, seq_len, feat_dim), dims=[1]
            )
            with self._autocast(flipped_input):
                reconstructed_flipped = self.backward_model(flipped_input, None, None, None)

            reconstructed = torch.flip(reconstructed_flipped, dims=[1]).reshape(
                n_samples, batch_size, seq_len, feat_dim
            )
            errors = (reconstructed - history.unsqueeze(0)) ** 2
            return self._fuse(errors, step_samples_tensor)

    def _build_reverse_mask(self, mask, batch_size, keep_prob=0.7):
        """Return the boolean mask of positions shown to the backward model.

        Visible (True) positions are: the first ``args.patch_len`` steps, every
        observed value within ``patch_len // 2`` steps of an observed/missing
        boundary, and a random ``keep_prob`` fraction of the remaining observed
        values. Hidden positions are later used to score the candidates.
        """
        seq_len, feat_dim = mask.shape[1], mask.shape[2]
        observed = mask == 1
        patch_len = self.args.patch_len
        boundary_range = patch_len // 2

        reverse_mask = torch.zeros_like(mask, dtype=torch.bool, device=self.device)
        reverse_mask[:, :patch_len, :] = True

        # edges[:, t] is 1 where the missing/observed state changes between t
        # and t+1; zero-padded at the end to keep the seq_len length.
        missing = (~observed).float()
        edges = (torch.abs(missing[:, 1:] - missing[:, :-1]) > 0).float()
        edges = torch.nn.functional.pad(edges, (0, 0, 0, 1), mode='constant', value=0)

        # Dilate the boundaries by ``boundary_range`` steps along the sequence
        # axis with a box filter; any non-zero response means "near a boundary".
        kernel = torch.ones(1, 1, 2 * boundary_range + 1, 1, device=self.device)
        near_edge = torch.nn.functional.conv2d(
            edges.unsqueeze(1), kernel, padding=(boundary_range, 0)
        ).squeeze(1) > 0
        reverse_mask = reverse_mask | (near_edge & observed)

        # Reveal a random subset of the observed values far from any boundary.
        far_observed = ~near_edge & observed
        keep = torch.rand(batch_size, seq_len, feat_dim, device=self.device) < keep_prob
        return reverse_mask | (far_observed & keep)

    def reverse_imputation(self, step_samples_tensor, batch_x, mask=None):
        """Fuse imputation candidates by backward-model reconstruction error.

        Each candidate fills the missing entries (``mask == 0``) of ``batch_x``.
        The backward model sees only the positions chosen by
        ``_build_reverse_mask``, and candidates are scored on the hidden ones.

        Args:
            step_samples_tensor: Candidates, shape
                ``[n_samples, batch_size, seq_len, feat_dim]``.
            batch_x: Input series with missing entries (same trailing shape).
            mask: Observation mask, 1 = observed, 0 = missing. Required.

        Returns:
            Fused candidates with the sample dimension reduced; only the last
            feature is kept when ``args.features == 'MS'``.

        Raises:
            ValueError: If ``mask`` is None.
        """
        if mask is None:
            raise ValueError("reverse_imputation requires a mask (1 = observed, 0 = missing)")

        n_samples, batch_size = step_samples_tensor.shape[0], step_samples_tensor.shape[1]
        reverse_mask = self._build_reverse_mask(mask, batch_size)
        f_dim = -1 if self.args.features == 'MS' else 0

        with torch.no_grad():
            # Fill the originally-missing entries with each candidate;
            # broadcasting over the sample axis avoids materializing n copies.
            filled_sequences = torch.where(
                (mask == 0).unsqueeze(0),
                step_samples_tensor.to(batch_x.dtype),
                batch_x.unsqueeze(0),
            )

            visible = reverse_mask.unsqueeze(0).expand(n_samples, -1, -1, -1)
            masked_input = filled_sequences * visible.float()
            flat_batch = n_samples * batch_size
            seq_len, feat_dim = masked_input.shape[2], masked_input.shape[3]

            with self._autocast(masked_input):
                reconstructed = self.backward_model(
                    masked_input.reshape(flat_batch, seq_len, feat_dim),
                    None, None,
                    visible.reshape(flat_batch, seq_len, feat_dim),
                )
            reconstructed = reconstructed.reshape(
                n_samples, batch_size, seq_len, reconstructed.shape[-1]
            )

            # Score only the positions the model did not see.
            hidden = (~visible[:, :, :, f_dim:]).float()
            errors = (
                (reconstructed[:, :, :, f_dim:] - filled_sequences[:, :, :, f_dim:]) * hidden
            ).pow(2)
            return self._fuse(errors, step_samples_tensor[:, :, :, f_dim:])

    def run_inference(self, args, pred, batch_x, mask=None):
        """Dispatch to the fusion routine matching ``args.task_name``.

        Raises:
            ValueError: If ``args.task_name`` is neither ``'long_term_forecast'``
                nor ``'imputation'``.
        """
        if args.task_name == 'long_term_forecast':
            return self.reverse_forecast(pred, batch_x)
        if args.task_name == 'imputation':
            return self.reverse_imputation(pred, batch_x, mask)
        raise ValueError(f"Unsupported task_name for ITS: {args.task_name!r}")