import os
import time
import json
from itertools import product


import numpy as np
from tqdm import tqdm
import torch
import torch.cuda.amp
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import OneCycleLR

from momentfm.utils.utils import control_randomness
from momentfm.data.forecasting_dataset import ForecastingDataset
from momentfm.utils.forecasting_metrics import get_forecasting_metrics
from momentfm import MOMENTPipeline
from pooling.src.device_utils import DeviceUtils


class forecast_exp:
    """Forecasting experiments that train only the head of a MOMENT model.

    All methods are static; the class just groups data loading, evaluation,
    a single training run and the grid-search driver.
    """

    @staticmethod
    def get_forecast_dataloaders(
        dataset_folder: str,
        forecast_horizon: int,
        dataset_name: str,
        batch_size: int,
        only_OT: bool = False,
        dataloader_num_workers: int = 0,
        pin_memory: bool = False,
    ):
        """Build train/val/test loaders for ``{dataset_folder}/{dataset_name}.csv``.

        Args:
            dataset_folder: Directory containing the CSV file.
            forecast_horizon: Forwarded to ``ForecastingDataset``.
            dataset_name: CSV file name without the ``.csv`` extension.
            batch_size: Batch size shared by all three loaders.
            only_OT: Forwarded to ``ForecastingDataset``.
            dataloader_num_workers: ``num_workers`` for every ``DataLoader``.
            pin_memory: ``pin_memory`` for every ``DataLoader``.

        Returns:
            ``(train_loader, val_loader, test_loader, seq_len, n_channels)``;
            ``seq_len`` and ``n_channels`` are read from the train dataset.

        Raises:
            ValueError: If the three splits report different ``seq_len``.
        """
        datafile = f"{dataset_folder}/{dataset_name}.csv"

        def _split(data_split):
            # One dataset + shuffled loader per split (same settings for all).
            dataset = ForecastingDataset(
                file_path=datafile,
                data_split=data_split,
                forecast_horizon=forecast_horizon,
                only_OT=only_OT,
            )
            loader = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=dataloader_num_workers,
                pin_memory=pin_memory,
            )
            return dataset, loader

        train_dataset, train_loader = _split("train")
        val_dataset, val_loader = _split("val")
        test_dataset, test_loader = _split("test")

        train_seq_len = train_dataset.seq_len
        val_seq_len = val_dataset.seq_len
        test_seq_len = test_dataset.seq_len
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        if not (train_seq_len == val_seq_len == test_seq_len):
            raise ValueError(
                f"Unequal sequence length: Train seq len: {train_seq_len}, Val seq len: {val_seq_len}, Test seq len: {test_seq_len}"
            )

        return train_loader, val_loader, test_loader, train_seq_len, train_dataset.n_channels

    @staticmethod
    def evaluate_model(model, data_loader, device, use_mixed_precision: bool = False, dtype=None):
        """Score ``model`` on every batch of ``data_loader`` without gradients.

        Args:
            model: Called as ``model(x_enc=..., input_mask=...)``; its output
                must expose ``.forecast``.
            data_loader: Yields ``(timeseries, forecast, input_mask)`` batches.
            device: ``torch.device`` (or device string) to run inference on.
            use_mixed_precision: Enable ``torch.autocast`` during inference, so
                evaluation matches the training loop's precision setting.
            dtype: Autocast dtype; ``None`` lets ``torch.autocast`` choose.

        Returns:
            ``get_forecasting_metrics(..., reduction='mean')`` over all batches;
            the result exposes ``mse``, ``mae``, ``mape``, ``smape`` and ``rmse``.

        Raises:
            ValueError: If ``data_loader`` yields no batches.

        The model's previous train/eval mode is restored on exit, even on error.
        """
        device_type = torch.device(device).type
        was_training = model.training
        trues, preds = [], []
        model.eval()
        try:
            with torch.no_grad():
                for timeseries, forecast, input_mask in tqdm(data_loader, total=len(data_loader)):
                    timeseries = timeseries.float().to(device)
                    input_mask = input_mask.to(device)

                    with torch.autocast(device_type=device_type, dtype=dtype, enabled=use_mixed_precision):
                        output = model(x_enc=timeseries, input_mask=input_mask)

                    # Targets are only needed on the CPU. Cast predictions to
                    # float32: bfloat16 tensors cannot be converted to NumPy and
                    # fp16 would degrade metric precision.
                    trues.append(forecast.float().cpu().numpy())
                    preds.append(output.forecast.float().cpu().numpy())
        finally:
            model.train(was_training)

        if not trues:
            raise ValueError("data_loader yielded no batches; cannot compute metrics")

        trues = np.concatenate(trues, axis=0)
        preds = np.concatenate(preds, axis=0)
        return get_forecasting_metrics(y=trues, y_hat=preds, reduction='mean')

    @staticmethod
    def _metrics_record(epoch, train_loss, val_metrics, test_metrics) -> dict:
        """Flatten one epoch's train loss and val/test metrics into plain floats.

        Keys are ``metric_reported_at_epoch``, ``train_loss`` and then
        ``{val,test}_{mse,mae,mape,smape,rmse}`` in that order.
        """
        record = {
            'metric_reported_at_epoch': epoch,
            'train_loss': float(train_loss),
        }
        for prefix, metrics in (('val', val_metrics), ('test', test_metrics)):
            for name in ('mse', 'mae', 'mape', 'smape', 'rmse'):
                record[f'{prefix}_{name}'] = float(getattr(metrics, name))
        return record

    @staticmethod
    def _prepare_output_file(output_path, result_folder: str, timestamp: str):
        """Resolve the JSONL results path, create its directory and truncate it.

        ``None`` selects ``{result_folder}/forecast_exp_results_{timestamp}.jsonl``;
        an empty string disables file output and is returned unchanged.
        """
        if output_path is None:
            output_path = f'{result_folder}/forecast_exp_results_{timestamp}.jsonl'
        if not output_path:
            return output_path
        parent = os.path.dirname(output_path)
        # os.makedirs('') raises, so skip it for bare file names (current dir).
        if parent:
            os.makedirs(parent, exist_ok=True)
        # Truncate results left over from a previous run.
        with open(output_path, 'w'):
            pass
        return output_path

    @staticmethod
    def run_one_exp(
        pooling_method: str,
        model_size: str,
        train_epoch: int,
        seed: int,
        dataset_folder: str,
        dataset_name: str,
        optimizer_name: str,
        forecast_horizon: int,
        batch_size: int,
        result_folder: str,
        timestamp: str,
        lr: float = 1e-4,
        use_mixed_precision: bool = False,
        only_OT: bool = False,
    ):
        """Train the forecasting head of a MOMENT model and report its metrics.

        The encoder and embedder are loaded frozen and only the head is
        trained, with Adam under a OneCycleLR schedule that peaks at ``lr``.
        After every epoch the validation split is scored. When val MSE
        improves (or on the first epoch, so a result always exists even if
        val MSE is NaN), the head's ``state_dict`` is saved and the test split
        is scored.

        Returns:
            dict with the metrics of the last recorded epoch plus the run
            configuration, total training time and checkpoint path.

        Raises:
            ValueError: For an unsupported ``optimizer_name``.
        """
        print("------------------------------------------")
        print(f"Launching experiment on {dataset_name} with model size {model_size}, pooling method {pooling_method}, and seed {seed}")

        # Seed PyTorch, NumPy etc. before anything that consumes randomness.
        control_randomness(seed=seed)
        train_loader, val_loader, test_loader, seq_length, n_channels = \
            forecast_exp.get_forecast_dataloaders(
                dataset_folder=dataset_folder,
                dataset_name=dataset_name,
                forecast_horizon=forecast_horizon,
                batch_size=batch_size,
                only_OT=only_OT,
            )

        model = MOMENTPipeline.from_pretrained(
            f"AutonLab/MOMENT-1-{model_size}",
            model_kwargs={
                'task_name': 'forecasting',
                'forecast_horizon': forecast_horizon,
                'head_dropout': 0.1,
                'weight_decay': 0,
                'freeze_encoder': True,  # Keep the transformer encoder frozen
                'freeze_embedder': True,  # Keep the patch embedding layer frozen
                'freeze_head': False,  # The forecasting head is the only part trained
                'pooling_method': pooling_method,  # Pooling strategy under study
                'input_ts_len': seq_length,  # The length of the input time series
                'n_channels': n_channels,  # The number of channels in the input time series
            },
        )
        model.init()

        device = DeviceUtils.get_training_device()
        # Autocast dtype; only takes effect when mixed precision is enabled.
        dtype = DeviceUtils.get_mixed_precision_dtype() if use_mixed_precision else torch.float32
        print(f"Using device: {device}, dtype: {dtype}")
        model = model.to(device=device)

        # Loss scaling is only used for CUDA mixed precision.
        scaler = torch.cuda.amp.GradScaler() if use_mixed_precision and device.type == 'cuda' else None

        criterion = torch.nn.MSELoss().to(device)

        if optimizer_name != "adam":
            raise ValueError(f"Unsupported optimizer: {optimizer_name}")
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        # OneCycleLR overrides the optimizer's lr, so peak it at the requested
        # ``lr`` rather than a fixed value.
        total_steps = len(train_loader) * train_epoch
        scheduler = OneCycleLR(optimizer, max_lr=lr, total_steps=total_steps, pct_start=0.3)

        max_norm = 5.0  # gradient-clipping threshold

        checkpoint_path = f"{result_folder}/checkpoints_{timestamp}/forecast_moment_{model_size}_dataset_{dataset_name}/head_of_best_model_with_{pooling_method}_pooling_seed_{seed}.pth"
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

        best_val_mse = float('inf')
        result = None
        start_time = time.time()

        for epoch in range(1, train_epoch + 1):
            model.train()
            losses = []
            for timeseries, forecast, input_mask in tqdm(train_loader, total=len(train_loader)):
                timeseries = timeseries.float().to(device)
                input_mask = input_mask.to(device)
                forecast = forecast.float().to(device)
                optimizer.zero_grad(set_to_none=True)

                with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_mixed_precision):
                    output = model(x_enc=timeseries, input_mask=input_mask)
                    loss = criterion(output.forecast, forecast)

                if scaler is not None:
                    scaler.scale(loss).backward()
                    # Unscale first so clipping sees the true gradient norms.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
                    optimizer.step()

                losses.append(loss.item())
                scheduler.step()  # OneCycleLR is stepped once per batch

            average_loss = np.average(losses)
            print(f"Epoch {epoch}/{train_epoch}: Train MSE loss: {average_loss:.3f}")

            val_metrics = forecast_exp.evaluate_model(
                model, val_loader, device, use_mixed_precision=use_mixed_precision, dtype=dtype
            )
            print(f"Epoch {epoch}/{train_epoch}: Val MSE: {val_metrics.mse:.3f} | Val MAE: {val_metrics.mae:.3f}")

            # NaN never compares smaller, so the first epoch is always recorded
            # to guarantee ``result`` exists; a later finite MSE still beats inf.
            improved = val_metrics.mse < best_val_mse
            if improved or result is None:
                if improved:
                    best_val_mse = val_metrics.mse
                torch.save(model.head.state_dict(), checkpoint_path)
                print(f"Checkpoint saved at epoch {epoch} with val_loss: {val_metrics.mse:.4f}")

                test_metrics = forecast_exp.evaluate_model(
                    model, test_loader, device, use_mixed_precision=use_mixed_precision, dtype=dtype
                )
                print(f"Epoch {epoch}/{train_epoch}: Test MSE: {test_metrics.mse:.3f} | Test MAE: {test_metrics.mae:.3f}")
                result = forecast_exp._metrics_record(epoch, average_loss, val_metrics, test_metrics)
            print("---------------------------------")

        total_training_time = time.time() - start_time
        print(f"Total training time: {total_training_time:.2f} seconds")

        return (result or {}) | {
            'train_epoch': train_epoch,
            'datset_name': dataset_name,  # misspelled key kept for existing result consumers
            'dataset_name': dataset_name,
            'pooling_method': pooling_method,
            'seed': seed,
            'optimizer_name': optimizer_name,
            'lr': lr,
            'total_training_time': total_training_time,
            'batch_size': batch_size,
            'forecast_horizon': forecast_horizon,
            'model_size': model_size,
            'device': str(device),
            'use_mixed_precision': use_mixed_precision,
            'only_OT': only_OT,
            'checkpoint_path': checkpoint_path,
        }

    @staticmethod
    def run_exp(
        pooling_method_l: list,
        model_size_l: list,
        seed_l: list,
        dataset_folder: str,
        dataset_name_l: list,
        forecast_horizon_l: list,
        train_epoch: int,
        result_folder: str,
        timestamp: str,
        optimizer_name: str = 'adam',
        batch_size: int = 32,
        lr: float = 1e-4,
        output_path: str = None,
        only_OT: bool = False,
        use_mixed_precision: bool = False,
    ):
        """Run ``run_one_exp`` for every combination of the given parameter lists.

        Each run's result is appended as one JSON line to ``output_path`` as
        soon as it finishes. ``None`` selects
        ``{result_folder}/forecast_exp_results_{timestamp}.jsonl``; an empty
        string disables file output. The file is truncated before the first run.

        Returns:
            List of per-run result dicts, in iteration order (model size
            outermost, pooling method innermost).

        Raises:
            ValueError: If ``model_size_l`` contains a size other than
                ``small``, ``base`` or ``large``.
        """
        valid_model_sizes = {'small', 'base', 'large'}
        if not all(size in valid_model_sizes for size in model_size_l):
            raise ValueError(f"Invalid model_size_l. Allowed values are: {valid_model_sizes}")

        total_experiments = (
            len(pooling_method_l)
            * len(model_size_l)
            * len(seed_l)
            * len(dataset_name_l)
            * len(forecast_horizon_l)
        )
        output_path = forecast_exp._prepare_output_file(output_path, result_folder, timestamp)

        results = []
        grid = product(model_size_l, forecast_horizon_l, seed_l, dataset_name_l, pooling_method_l)
        for model_size, forecast_horizon, seed, dataset_name, pooling_method in tqdm(
            grid, total=total_experiments, desc="Running Experiments"
        ):
            result = forecast_exp.run_one_exp(
                pooling_method=pooling_method,
                model_size=model_size,
                train_epoch=train_epoch,
                seed=seed,
                dataset_folder=dataset_folder,
                dataset_name=dataset_name,
                optimizer_name=optimizer_name,
                forecast_horizon=forecast_horizon,
                batch_size=batch_size,
                lr=lr,
                use_mixed_precision=use_mixed_precision,
                only_OT=only_OT,
                result_folder=result_folder,
                timestamp=timestamp,
            )
            results.append(result)

            # Stream each result so finished runs survive a later crash.
            if output_path:
                with open(output_path, 'a') as f:
                    f.write(json.dumps(result) + '\n')

        return results