import contextlib
import json
import os
import time
from itertools import product

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from momentfm.data.classification_dataset import ClassificationDataset
from momentfm import MOMENTPipeline

from pooling.src.device_utils import DeviceUtils


class cls_exp:
    """Experiment helpers for training a MOMENT classification head.

    All methods are static; the class only namespaces them. The usual entry
    point is :meth:`run_exp`, which sweeps pooling methods, datasets, seeds,
    optimizers and model sizes, delegating each configuration to
    :meth:`run_one_exp`.
    """

    @staticmethod
    def compute_accuracy(output, labels):
        """Return the fraction of rows in ``output`` whose argmax equals ``labels``.

        Args:
            output: Logits tensor with the class dimension at dim 1.
            labels: Class-index tensor with one entry per row of ``output``.

        Returns:
            float: Accuracy in [0, 1].
        """
        preds = output.argmax(dim=1)  # Get predicted class
        return (preds == labels).float().mean().item()

    @staticmethod
    def _count_correct(logits, labels):
        """Return how many rows of ``logits`` have their argmax equal to ``labels``."""
        return (logits.argmax(dim=1) == labels).sum().item()

    @staticmethod
    def evaluate_model(model, dataloader, criterion, device):
        """Evaluate ``model`` on ``dataloader`` without gradient tracking.

        Loss and accuracy are averaged over *samples*, not batches, so a short
        final batch (the loaders use ``drop_last=False``) is not over-weighted.

        Args:
            model: Callable as ``model(x_enc=..., input_mask=...)`` returning an
                object with a ``logits`` attribute.
            dataloader: Iterable of ``(data, input_mask, labels)`` batches.
            criterion: Loss function; assumed to use mean reduction (as
                ``nn.CrossEntropyLoss()`` does by default), so the per-batch
                loss is re-weighted by batch size.
            device: Device the batches are moved to.

        Returns:
            tuple[float, float]: ``(mean_loss, accuracy)``; ``(nan, nan)`` when
            the dataloader yields no samples.
        """
        model.eval()
        total_loss, total_correct, total_samples = 0.0, 0, 0

        with torch.no_grad():
            for data, input_mask, labels in dataloader:
                data, input_mask, labels = data.float().to(device), input_mask.to(device), labels.to(device)

                output = model(x_enc=data, input_mask=input_mask)
                loss = criterion(output.logits, labels)

                batch_size = labels.size(0)
                total_loss += loss.item() * batch_size
                total_correct += cls_exp._count_correct(output.logits, labels)
                total_samples += batch_size

        if total_samples == 0:
            return float("nan"), float("nan")
        return total_loss / total_samples, total_correct / total_samples

    @staticmethod
    def get_classification_dataloaders(
        dataset_folder: str,
        dataset_name: str,
        batch_size: int = 64,
        val_split: float = 0.2,
        seed: int = 42
    ):
        """
        Loads a classification dataset and returns train/val/test DataLoaders.

        Args:
            dataset_folder (str): Path to the dataset directory.
            dataset_name (str): Dataset name.
            batch_size (int): Batch size for DataLoaders.
            val_split (float): Fraction of the training set used for validation,
                in ``[0, 1)``.
            seed (int): Random seed for splitting (set via ``torch.manual_seed``,
                which also affects the global RNG).

        Returns:
            Tuple of (train_dataloader, val_dataloader, test_dataloader,
            num_class, ts_len), where ``ts_len`` is the series length shared
            by the train and test splits.

        Raises:
            ValueError: If ``val_split`` is outside ``[0, 1)`` or the train and
                test splits report different series lengths.
        """
        if not 0.0 <= val_split < 1.0:
            raise ValueError(f"val_split must be in [0, 1), got {val_split}")

        torch.manual_seed(seed)

        # Load full training and test datasets
        full_train_dataset = ClassificationDataset(
            dataset_folder=dataset_folder,
            dataset_name=dataset_name,
            data_split='train'
        )

        test_dataset = ClassificationDataset(
            dataset_folder=dataset_folder,
            dataset_name=dataset_name,
            data_split='test'
        )

        # Split train into train/val (uses the global RNG seeded above).
        train_size = int((1 - val_split) * len(full_train_dataset))
        val_size = len(full_train_dataset) - train_size
        train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

        # Create DataLoaders
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
        val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

        num_class = len(np.unique(full_train_dataset.train_labels))
        # A real check rather than `assert`, which is stripped under `python -O`.
        if full_train_dataset.len_timeseries != test_dataset.len_timeseries:
            raise ValueError("Train and test datasets have different lengths")
        ts_len = full_train_dataset.len_timeseries

        return train_dataloader, val_dataloader, test_dataloader, num_class, ts_len

    @staticmethod
    def _make_optimizer(optimizer_name, params, lr):
        """Build the optimizer named ``optimizer_name`` ('adam' or 'sgd') over ``params``.

        Raises:
            ValueError: For any other optimizer name.
        """
        if optimizer_name == 'adam':
            return optim.Adam(params, lr=lr)
        if optimizer_name == 'sgd':
            return optim.SGD(params, lr=lr)
        raise ValueError(f"Unknown optimizer: {optimizer_name}")

    @staticmethod
    def _train_one_epoch(model, dataloader, criterion, optimizer, device,
                         use_mixed_precision=False, dtype=torch.float32, scaler=None):
        """Run one optimization pass over ``dataloader``.

        Args:
            model, dataloader, criterion, device: As in :meth:`evaluate_model`.
            optimizer: Optimizer stepping ``model``'s parameters.
            use_mixed_precision: Run the forward pass under ``torch.autocast``
                with ``dtype`` on ``device.type``.
            dtype: Autocast dtype (ignored when mixed precision is off).
            scaler: Optional ``GradScaler``; when given, the loss is scaled for
                the backward pass and the optimizer is stepped through it.

        Returns:
            tuple[float, float]: Sample-weighted ``(mean_loss, accuracy)`` of the
            training batches, measured on each batch's forward pass before its
            optimizer step; ``(nan, nan)`` if the dataloader is empty.
        """
        model.train()
        total_loss, total_correct, total_samples = 0.0, 0, 0

        for data, input_mask, labels in dataloader:
            data, input_mask, labels = data.float().to(device), input_mask.to(device), labels.to(device)

            optimizer.zero_grad()

            # Autocast whenever mixed precision is requested, on any device.
            # Only enter torch.autocast in that case, so the full-precision
            # path never depends on autocast supporting this device type.
            autocast_ctx = (
                torch.autocast(device_type=device.type, dtype=dtype)
                if use_mixed_precision else contextlib.nullcontext()
            )
            with autocast_ctx:
                output = model(x_enc=data, input_mask=input_mask)
                loss = criterion(output.logits, labels)

            if scaler is not None:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            total_correct += cls_exp._count_correct(output.logits, labels)
            total_samples += batch_size

        if total_samples == 0:
            return float("nan"), float("nan")
        return total_loss / total_samples, total_correct / total_samples

    @staticmethod
    def run_one_exp(
        pooling_method: str,
        model_size: str,
        train_epoch: int = 10,
        seed: int = 42,
        dataset_folder: str = '../data/Timeseries-PILE/classification/UCR',
        dataset_name: str = 'ECG200',
        optimizer_name: str = 'adam',
        lr: float = 1e-3,
        use_mixed_precision: bool = False
    ):
        """Train and evaluate one MOMENT classifier configuration.

        Loads ``AutonLab/MOMENT-1-{model_size}`` in classification mode with the
        encoder and embedder frozen and the head trainable. It trains for
        ``train_epoch`` epochs and checkpoints the weights with the lowest
        validation loss. It then reloads that checkpoint and evaluates it on
        the train, validation and test splits.

        Returns:
            dict: Run configuration plus final losses/accuracies, total training
            time (seconds), device and dtype.
        """
        torch.manual_seed(seed)
        print(f"launching experiment on {dataset_name} with model size {model_size} and pooling method {pooling_method}")

        train_dataloader, val_dataloader, test_dataloader, num_class, ts_len = \
            cls_exp.get_classification_dataloaders(
                dataset_folder=dataset_folder,
                dataset_name=dataset_name,
                batch_size=64,
                val_split=0.2,
                seed=seed
            )

        # Initialize model in classification mode
        model = MOMENTPipeline.from_pretrained(
            f"AutonLab/MOMENT-1-{model_size}",
            model_kwargs={
                'task_name': 'classification',
                'n_channels': 1,
                'num_class': num_class,
                'pooling_method': pooling_method,
                'freeze_encoder': True,  # Freeze the transformer encoder
                'freeze_embedder': True,  # Freeze the patch embedding layer
                'freeze_head': False,  # The linear head must be trained
                'input_ts_len': ts_len,
            },
        )

        model.init()
        device = DeviceUtils.get_training_device()
        # Determine precision
        dtype = DeviceUtils.get_mixed_precision_dtype() if use_mixed_precision else torch.float32

        print(f"Using device: {device}, dtype: {dtype}")
        # Loss scaling is CUDA-only; other devices run autocast without it.
        scaler = GradScaler() if use_mixed_precision and device.type == 'cuda' else None

        model = model.to(device=device).float()

        # Loss & optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = cls_exp._make_optimizer(optimizer_name, model.parameters(), lr)

        checkpoint_path = f"./checkpoints/cls_moment_{model_size}_dataset_{dataset_name}_seed_{seed}/best_model_with_{pooling_method}_pooling.pth"
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

        best_val_loss = float("inf")
        checkpoint_saved = False

        start_time = time.perf_counter()

        for epoch in range(train_epoch):
            train_loss, train_acc = cls_exp._train_one_epoch(
                model, train_dataloader, criterion, optimizer, device,
                use_mixed_precision=use_mixed_precision, dtype=dtype, scaler=scaler,
            )

            val_loss, val_acc = cls_exp.evaluate_model(model, val_dataloader, criterion, device)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_saved = True
                print(f"Checkpoint saved at epoch {epoch+1} with val_loss: {val_loss:.4f}")

            print(f"Epoch [{epoch+1}/{train_epoch}] - "
                f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} - "
                f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        total_training_time = time.perf_counter() - start_time
        print(f"Total training time: {total_training_time:.2f} seconds")

        if not checkpoint_saved:
            # No epoch beat +inf (train_epoch == 0, empty validation set or NaN
            # losses). Save the current weights so the reload below never
            # picks up a stale checkpoint left by an earlier run.
            torch.save(model.state_dict(), checkpoint_path)
            print("No validation improvement recorded; saved final weights as checkpoint.")

        # Reload the best model
        print("\nReloading the best model from checkpoint...")
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        model = model.to(device)

        # Final evaluation
        train_loss, train_acc = cls_exp.evaluate_model(model, train_dataloader, criterion, device)
        val_loss, val_acc = cls_exp.evaluate_model(model, val_dataloader, criterion, device)
        test_loss, test_acc = cls_exp.evaluate_model(model, test_dataloader, criterion, device)

        print(f"\nFinal Evaluation (Best Model) with {pooling_method} pooling:")
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
        print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")
        print("----------------------------------------------------\n")

        result = {
            'dataset_name': dataset_name,
            'pooling_method': pooling_method,
            'seed': seed,
            'optimizer_name': optimizer_name,
            'train_epoch': train_epoch,
            'train_loss': train_loss, 'train_acc': train_acc,
            'val_loss': val_loss, 'val_acc': val_acc,
            'test_loss': test_loss, 'test_acc': test_acc,
            'model_size': model_size,
            'learning_rate': lr,
            'total_training_time': total_training_time,
            'use_mixed_precision': use_mixed_precision,
            'device': str(device),
            'dtype': str(dtype),
        }
        return result

    @staticmethod
    def run_exp(
        pooling_methods: list,
        dataset_names: list,
        seed_list: list,
        optimizer_name_list: list,
        dataset_folder: str = '../data/Timeseries-PILE/classification/UCR',
        train_epoch: int = 10,
        model_size_list: list = None,
        use_mixed_precision: bool = False,
        output_path: str = None,
        lr: float = 1e-4,
    ):
        """Run :meth:`run_one_exp` over the Cartesian product of the given lists.

        Args:
            pooling_methods, dataset_names, seed_list, optimizer_name_list,
            model_size_list: Values swept over. ``model_size_list`` is required
                and may only contain 'small', 'base' or 'large'.
            dataset_folder, train_epoch, use_mixed_precision, lr: Forwarded to
                every :meth:`run_one_exp` call.
            output_path: If given, the file is truncated up front and each
                result is appended as one JSON line as soon as it finishes.

        Returns:
            list[dict]: One result dict per configuration, in product order.

        Raises:
            ValueError: If ``model_size_list`` is missing or contains an
                unsupported size.
        """
        # Validate model_size_list before using it (the default None would
        # otherwise crash with an opaque TypeError in len()).
        valid_model_sizes = ('small', 'base', 'large')
        if model_size_list is None:
            raise ValueError(f"model_size_list is required. Allowed values are: {valid_model_sizes}")
        if not all(size in valid_model_sizes for size in model_size_list):
            raise ValueError(f"Invalid model_size_list. Allowed values are: {valid_model_sizes}")

        results = []
        total_experiments = (
            len(pooling_methods) *
            len(dataset_names) *
            len(seed_list) *
            len(optimizer_name_list) *
            len(model_size_list)
        )

        if output_path:
            # Make sure the folder exists; a bare filename has no folder part
            # and os.makedirs('') would raise.
            output_dir = os.path.dirname(output_path)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            # Clear results left over from a previous run.
            with open(output_path, 'w'):
                pass

        # Iterate over the Cartesian product of the parameters with a progress bar
        for pooling_method, dataset_name, seed, optimizer_name, model_size in tqdm(
            product(pooling_methods, dataset_names, seed_list, optimizer_name_list, model_size_list),
            total=total_experiments,
            desc="Running Experiments"
        ):
            result = cls_exp.run_one_exp(
                pooling_method=pooling_method,
                dataset_name=dataset_name,
                seed=seed,
                optimizer_name=optimizer_name,
                dataset_folder=dataset_folder,
                train_epoch=train_epoch,
                model_size=model_size,
                use_mixed_precision=use_mixed_precision,
                lr=lr,
            )
            results.append(result)

            # Append immediately so partial sweeps still leave results on disk.
            if output_path:
                with open(output_path, 'a') as f:
                    f.write(json.dumps(result) + '\n')

        return results