import os
import time
import json
from itertools import product


import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler



from momentfm import MOMENTPipeline
from momentfm.utils.utils import control_randomness
from momentfm.data.informer_dataset import InformerDataset
from momentfm.utils.masking import Masking
from momentfm.data.forecasting_dataset import ForecastingDataset
from momentfm.utils.forecasting_metrics import get_forecasting_metrics
from pooling.src.device_utils import DeviceUtils

class imputation_exp:
    """
    Imputation experiments with MOMENT, parameterised by pooling method.

    All methods are static; the class is only used as a namespace.
    """

    @staticmethod
    def split_with_modified_mask(batch_x, mask, batch_mask):
        """
        Split ``batch_x`` into kept and discarded values according to ``mask``, and
        select the entries of ``batch_mask`` at the kept positions.

        Parameters:
            batch_x (torch.Tensor): Input tensor of shape [bs, 1, L].
            mask (torch.Tensor): Binary tensor of shape [bs, L] indicating which values to
                keep (1) vs. discard (0). Every row must keep the same number of values so
                that the per-row results can be stacked into dense tensors.
            batch_mask (torch.Tensor): Binary tensor of shape [bs, L] indicating actual (1)
                vs. padded (0) values.

        Returns:
            tensor_keep (torch.Tensor): Kept values, shape [bs, 1, n_keep], original order.
            tensor_discard (torch.Tensor): Discarded values, shape [bs, 1, n_discard], original order.
            modified_batch_mask (torch.Tensor): Entries of ``batch_mask`` at the kept
                positions, shape [bs, n_keep].

        Raises:
            ValueError: if the rows of ``mask`` keep different numbers of values.

        Example:
            >>> bs, L = 4, 10
            >>> batch_x = torch.randn(bs, 1, L)
            >>> mask = torch.cat([torch.ones(L // 2), torch.zeros(L // 2)]).unsqueeze(0).repeat(bs, 1)
            >>> batch_mask = torch.cat([torch.ones(7), torch.zeros(L - 7)]).unsqueeze(0).repeat(bs, 1)
            >>> keep, discard, kept_mask = imputation_exp.split_with_modified_mask(batch_x, mask, batch_mask)
            >>> keep.shape, discard.shape, kept_mask.shape
            (torch.Size([4, 1, 5]), torch.Size([4, 1, 5]), torch.Size([4, 5]))
        """
        # Drop the singleton channel dimension: [bs, 1, L] -> [bs, L].
        batch_x = batch_x.squeeze(1)
        keep = mask.bool()
        n_rows, seq_len = keep.shape

        keep_counts = keep.sum(dim=1)
        if n_rows > 0 and not bool((keep_counts == keep_counts[0]).all()):
            raise ValueError(
                "split_with_modified_mask: every row of `mask` must keep the same number "
                f"of values, got counts {keep_counts.tolist()}"
            )
        n_keep = int(keep_counts[0]) if n_rows > 0 else 0
        n_discard = seq_len - n_keep

        # Boolean indexing a 2-D tensor with a 2-D mask returns the selected elements in
        # row-major order, so reshaping to [bs, n] regroups them by row exactly as
        # selecting each row separately (and stacking) would.
        tensor_keep = batch_x[keep].reshape(n_rows, n_keep).unsqueeze(1)
        tensor_discard = batch_x[~keep].reshape(n_rows, n_discard).unsqueeze(1)
        modified_batch_mask = batch_mask[keep].reshape(n_rows, n_keep)

        return tensor_keep, tensor_discard, modified_batch_mask

    @staticmethod
    def cache_fixed_masks(dataloader, mask_generator, device):
        """
        Draw one imputation mask per batch of ``dataloader`` and cache everything on CPU,
        so validation/test evaluation uses the same masks in every epoch.

        Returns:
            list of (batch_x, batch_masks, mask) tuples. ``batch_x`` keeps the loader's
            [bs, n_channels, L] shape; ``batch_masks`` is already expanded per channel to
            [bs * n_channels, L]; ``mask`` is what ``mask_generator.generate_mask`` returns
            for the per-channel batch (presumably [bs * n_channels, L] — TODO confirm).
        """
        fixed_batches = []
        for batch_x, batch_masks in tqdm(dataloader, desc="Caching masks for stable evaluation on validation/test set"):
            batch_x = batch_x.to(device).float()
            n_channels = batch_x.shape[1]
            # Each channel becomes its own univariate series for mask generation.
            batch_x_reshaped = batch_x.reshape((-1, 1, batch_x.shape[2]))

            batch_masks = batch_masks.to(device).long()
            batch_masks = batch_masks.repeat_interleave(n_channels, dim=0)

            mask = mask_generator.generate_mask(x=batch_x_reshaped, input_mask=batch_masks).to(device).long()

            fixed_batches.append((batch_x.cpu(), batch_masks.cpu(), mask.cpu()))
        return fixed_batches

    @staticmethod
    @torch.no_grad()
    def evaluate_fixed(
                        model,
                        cached_batches,
                        device,
                        imputation_exp,
                        input_ts_len,
                        use_mixed_precision,
                        dtype,
                    ):
        """
        Evaluate ``model`` on batches produced by ``cache_fixed_masks``.

        Parameters:
            model: MOMENT pipeline called as ``model(x_enc=..., input_mask=...)``.
            cached_batches: list of (batch_x, batch_masks, mask) tuples from ``cache_fixed_masks``.
            device: torch device to run on.
            imputation_exp: class providing ``split_with_modified_mask`` (callers pass the class itself).
            input_ts_len: series length L used to reshape ``batch_x`` to [-1, 1, L].
            use_mixed_precision: run the forward pass under autocast (CUDA only).
            dtype: autocast dtype.

        Returns:
            The result of ``get_forecasting_metrics`` over all discarded (imputed) values.

        Raises:
            ValueError: if ``cached_batches`` is empty.
        """
        model.eval()
        trues = []
        preds = []
        for batch_x, batch_masks, mask in tqdm(cached_batches, desc="Evaluating (Fixed Masks)"):
            batch_x = batch_x.to(device).float()
            batch_masks = batch_masks.to(device).long()
            mask = mask.to(device).long()

            n_channels = batch_x.shape[1]
            batch_x = batch_x.reshape((-1, 1, input_ts_len))
            # cache_fixed_masks already expands the input masks per channel; expanding
            # them again would misalign masks and samples when n_channels > 1. Only expand
            # when the caller supplied one mask row per (multi-channel) sample.
            if batch_masks.shape[0] != batch_x.shape[0]:
                batch_masks = batch_masks.repeat_interleave(n_channels, dim=0)

            batch_x_keep, batch_x_discard, modified_batch_masks = imputation_exp.split_with_modified_mask(
                batch_x, mask, batch_masks)

            if use_mixed_precision and device.type == 'cuda':
                with torch.autocast(device_type=device.type, dtype=dtype):
                    output = model(x_enc=batch_x_keep, input_mask=modified_batch_masks)
            else:
                output = model(x_enc=batch_x_keep, input_mask=modified_batch_masks)

            window_size = batch_x_discard.shape[2]
            reconstruction = output.reconstruction.reshape((-1, n_channels, window_size))
            batch_x_discard = batch_x_discard.reshape((-1, n_channels, window_size))

            # Upcast before .numpy(): NumPy has no bfloat16, which autocast may produce.
            trues.append(batch_x_discard.detach().float().cpu().numpy())
            preds.append(reconstruction.detach().float().cpu().numpy())

        if not trues:
            raise ValueError("evaluate_fixed: `cached_batches` is empty; nothing to evaluate")

        trues = np.concatenate(trues, axis=0)
        preds = np.concatenate(preds, axis=0)
        return get_forecasting_metrics(y=trues, y_hat=preds, reduction='mean')

    @staticmethod
    def get_imputation_dataloaders(
        dataset_folder: str,
        dataset_name: str,
        batch_size: int,
        only_OT: bool = False,
    ):
        """
        Build train/val/test DataLoaders over ``{dataset_folder}/{dataset_name}.csv``.

        Returns:
            (train_loader, val_loader, test_loader, seq_len, n_channels); ``seq_len`` and
            ``n_channels`` are read from the training dataset.

        Raises:
            AssertionError: if the three splits report different ``seq_len``.
        """
        datafile = f"{dataset_folder}/{dataset_name}.csv"
        datasets = {}
        loaders = {}
        for split in ("train", "val", "test"):
            datasets[split] = ForecastingDataset(
                file_path=datafile,
                data_split=split,
                task_name="imputation",
                only_OT=only_OT,
            )
            # NOTE: val/test are shuffled too (kept as-is so existing runs reproduce);
            # their evaluation masks are cached once by cache_fixed_masks.
            loaders[split] = DataLoader(datasets[split], batch_size=batch_size, shuffle=True)

        train_seq_len = datasets["train"].seq_len
        val_seq_len = datasets["val"].seq_len
        test_seq_len = datasets["test"].seq_len
        assert train_seq_len == val_seq_len == test_seq_len, \
            f"Unequal sequence length: Train seq len: {train_seq_len}, Val seq len: {val_seq_len}, Test seq len: {test_seq_len}"

        return loaders["train"], loaders["val"], loaders["test"], datasets["train"].seq_len, datasets["train"].n_channels

    @staticmethod
    def _metrics_to_dict(prefix, metrics):
        """Flatten the metric attributes used in result records into ``{prefix}_{name}`` floats."""
        return {
            f"{prefix}_{name}": float(getattr(metrics, name))
            for name in ('mse', 'mae', 'mape', 'smape', 'rmse')
        }

    @staticmethod
    def _train_one_epoch(
        model,
        dataloader,
        mask_generator,
        optimizer,
        criterion,
        scaler,
        device,
        dtype,
        input_ts_len,
        use_mixed_precision,
    ):
        """Run one training epoch with freshly drawn masks; return the mean batch loss (NaN if no batches)."""
        model.train()
        epoch_loss = 0.0
        num_batches = 0

        for batch_x, batch_masks in dataloader:
            batch_x = batch_x.to(device).float()
            n_channels = batch_x.shape[1]
            # Each channel becomes its own univariate series.
            batch_x = batch_x.reshape((-1, 1, input_ts_len))

            batch_masks = batch_masks.to(device).long()
            batch_masks = batch_masks.repeat_interleave(n_channels, dim=0)

            mask = mask_generator.generate_mask(
                x=batch_x, input_mask=batch_masks).to(device).long()

            batch_x_keep, batch_x_discard, modified_batch_masks = imputation_exp.split_with_modified_mask(
                batch_x, mask, batch_masks)

            optimizer.zero_grad()

            with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_mixed_precision):
                output = model(x_enc=batch_x_keep, input_mask=modified_batch_masks)
                window_size = batch_x_discard.shape[2]
                reconstruction = output.reconstruction.reshape((-1, n_channels, window_size))
                batch_x_discard = batch_x_discard.reshape((-1, n_channels, window_size))

                # The loss is computed only on the discarded (to-be-imputed) positions.
                loss = criterion(reconstruction, batch_x_discard)

            # run_one_exp creates a scaler exactly when mixed precision runs on CUDA.
            if scaler is not None:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        return epoch_loss / num_batches if num_batches else float("nan")

    @staticmethod
    def run_one_exp(
        pooling_method: str,
        model_size: str,
        train_epoch: int,
        seed: int,
        dataset_folder: str,
        dataset_name: str,
        optimizer_name: str,
        batch_size: int,
        result_folder: str,
        timestamp: str,
        lr: float = 1e-4,
        use_mixed_precision: bool = False,
        only_OT: bool = False,
    ):
        """
        Train the MOMENT imputation head for one configuration and return a result record.

        The head with the best validation MSE is saved to a checkpoint, and the test metrics
        are computed at that epoch. The returned dict contains the best-epoch metrics (empty
        if validation never improved, e.g. ``train_epoch == 0`` or NaN validation MSE) merged
        with the run configuration.

        Raises:
            ValueError: if ``optimizer_name`` is not supported.
        """
        print("------------------------------------------")
        print(f"Launching experiment on {dataset_name} with model size {model_size}, pooling method {pooling_method}, and seed {seed}")

        # Set random seeds for PyTorch, Numpy etc.
        control_randomness(seed=seed)

        train_dataloader, val_dataloader, test_dataloader, input_ts_len, n_channels = \
            imputation_exp.get_imputation_dataloaders(
                dataset_folder=dataset_folder,
                dataset_name=dataset_name,
                batch_size=batch_size,
                only_OT=only_OT,
            )

        model = MOMENTPipeline.from_pretrained(
            f"AutonLab/MOMENT-1-{model_size}",
            model_kwargs={
                'task_name': 'imputation',
                'input_ts_len': input_ts_len,
                'freeze_encoder': True,  # Freeze the transformer encoder
                'freeze_embedder': True,  # Freeze the patch embedding layer
                'freeze_head': False,  # The head must be trained
                'n_channels': n_channels,
                'pooling_method': pooling_method,
            },
        )
        model.init()

        device = DeviceUtils.get_training_device()
        dtype = DeviceUtils.get_mixed_precision_dtype() if use_mixed_precision else torch.float32
        print(f"Using device: {device}, dtype: {dtype}")

        model = model.to(device=device)

        scaler = torch.cuda.amp.GradScaler() if use_mixed_precision and device.type == 'cuda' else None

        criterion = torch.nn.MSELoss().to(device)

        if optimizer_name == "adam":
            optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        else:
            raise ValueError(f"Unsupported optimizer: {optimizer_name}")

        mask_generator = Masking(single_patch_token_mask=True)

        cached_val_batches = imputation_exp.cache_fixed_masks(val_dataloader, mask_generator, device)
        cached_test_batches = imputation_exp.cache_fixed_masks(test_dataloader, mask_generator, device)

        checkpoint_path = f"{result_folder}/checkpoints_{timestamp}/imputation_moment_{model_size}_dataset_{dataset_name}/head_of_best_model_with_{pooling_method}_pooling_seed_{seed}.pth"
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

        best_val_mse = float("inf")
        best_test_mse = None
        best_epoch = -1
        # Stays empty if no epoch ever improves on best_val_mse, instead of leaving
        # `result` unbound for the merge below.
        result = {}
        start_time = time.time()

        for epoch in range(train_epoch):
            avg_train_loss = imputation_exp._train_one_epoch(
                model=model,
                dataloader=train_dataloader,
                mask_generator=mask_generator,
                optimizer=optimizer,
                criterion=criterion,
                scaler=scaler,
                device=device,
                dtype=dtype,
                input_ts_len=input_ts_len,
                use_mixed_precision=use_mixed_precision,
            )
            val_metrics = imputation_exp.evaluate_fixed(
                model, cached_val_batches, device, imputation_exp, input_ts_len, use_mixed_precision, dtype)

            print(f"Epoch [{epoch + 1}/{train_epoch}] - Train Loss: {avg_train_loss:.6f} | Val Loss: {val_metrics.mse:.6f}")

            # Check for best validation loss
            if val_metrics.mse < best_val_mse:
                torch.save(model.head.state_dict(), checkpoint_path)
                best_val_mse = val_metrics.mse
                best_epoch = epoch + 1  # since epoch is zero-indexed
                test_metrics = imputation_exp.evaluate_fixed(
                    model, cached_test_batches, device, imputation_exp, input_ts_len, use_mixed_precision, dtype)
                best_test_mse = test_metrics.mse
                print(f"✅ New best val MSE: {best_val_mse:.6f} at epoch {best_epoch}, test MSE: {test_metrics.mse:.6f}")
                result = {
                    'metric_reported_at_epoch': epoch,  # zero-based, unlike best_epoch
                    'train_loss': float(avg_train_loss),
                    **imputation_exp._metrics_to_dict('val', val_metrics),
                    **imputation_exp._metrics_to_dict('test', test_metrics),
                }
            print("---------------------------------")

        total_training_time = time.time() - start_time
        print(f"Total training time: {total_training_time:.2f} seconds")

        result = result | {
            'train_epoch': train_epoch,
            'datset_name': dataset_name,  # legacy misspelled key, kept for existing consumers
            'dataset_name': dataset_name,
            'pooling_method': pooling_method,
            'seed': seed,
            'optimizer_name': optimizer_name,
            'lr': lr,
            'total_training_time': total_training_time,
            'batch_size': batch_size,
            'model_size': model_size,
            'device': str(device),
            'use_mixed_precision': use_mixed_precision,
            'only_OT': only_OT,
            'cckpoint_path': checkpoint_path,  # legacy misspelled key, kept for existing consumers
            'checkpoint_path': checkpoint_path,
        }

        # Final report
        print("\n🏁 Training complete.")
        print(f"📉 Best Val Loss: {best_val_mse:.6f} (Epoch {best_epoch})")
        if best_test_mse is None:
            print("🧪 Corresponding Test Loss: N/A (validation MSE never improved)")
        else:
            print(f"🧪 Corresponding Test Loss: {best_test_mse:.6f}")

        return result

    @staticmethod
    def run_exp(
        pooling_method_l: list,
        model_size_l: list,
        seed_l: list,
        dataset_folder: str,
        dataset_name_l: list,
        train_epoch: int,
        result_folder: str,
        timestamp: str = None,
        optimizer_name: str = 'adam',
        batch_size: int = 32,
        lr: float = 1e-4,
        output_path: str = None,
        use_mixed_precision: bool = False,
        only_OT: bool = False,
    ):
        """
        Run ``run_one_exp`` over the Cartesian product of model sizes, seeds, datasets and
        pooling methods, appending each result as one JSON line to ``output_path``.

        ``output_path`` defaults to ``{result_folder}/imputation_exp_results_{timestamp}.jsonl``
        and is truncated at the start. If ``timestamp`` is None, the current local time
        (``%Y%m%d_%H%M%S``) is used.

        Returns:
            list of result dicts, in iteration order.

        Raises:
            ValueError: if ``model_size_l`` contains a size other than small/base/large.
        """
        valid_model_sizes = {'small', 'base', 'large'}
        if not all(size in valid_model_sizes for size in model_size_l):
            raise ValueError(f"Invalid model_size_l. Allowed values are: {valid_model_sizes}")

        if timestamp is None:
            # Without a timestamp every run would share "..._None" paths, and the
            # truncation below would wipe the previous run's results.
            timestamp = time.strftime("%Y%m%d_%H%M%S")

        if output_path is None:
            output_path = f'{result_folder}/imputation_exp_results_{timestamp}.jsonl'
        output_dir = os.path.dirname(output_path)
        if output_dir:  # os.makedirs('') raises for a bare filename
            os.makedirs(output_dir, exist_ok=True)
        # Start from an empty results file; each finished experiment is appended below.
        with open(output_path, 'w', encoding='utf-8'):
            pass

        total_experiments = (
            len(pooling_method_l)
            * len(model_size_l)
            * len(seed_l)
            * len(dataset_name_l)
        )

        results = []
        for model_size, seed, dataset_name, pooling_method in tqdm(
            product(model_size_l, seed_l, dataset_name_l, pooling_method_l),
            total=total_experiments,
            desc="Running Experiments",
        ):
            result = imputation_exp.run_one_exp(
                pooling_method=pooling_method,
                model_size=model_size,
                train_epoch=train_epoch,
                seed=seed,
                dataset_folder=dataset_folder,
                dataset_name=dataset_name,
                optimizer_name=optimizer_name,
                batch_size=batch_size,
                lr=lr,
                use_mixed_precision=use_mixed_precision,
                only_OT=only_OT,
                result_folder=result_folder,
                timestamp=timestamp,
            )
            results.append(result)

            # Append immediately so finished experiments survive a later crash.
            with open(output_path, 'a', encoding='utf-8') as f:
                f.write(json.dumps(result) + '\n')

        return results