import wandb
import torch
from torch.utils.data import DataLoader

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf, DictConfig

import omegaconf
from dataclasses import dataclass, field
from typing import Optional

import os
from pathlib import Path

from lib.utils import set_seed
from lib.jepa_acssm import RJEPAACSSM

from source.data.dataset import ROIPretrainDatasetSplit
from source.data.config import ROIDatasetConfig, TrainingConfig, ModelConfig

@dataclass
class Config:
    """Top-level structured schema for a pretraining run.

    Groups the dataset, training and model sections together with the
    run-level options that ``main`` reads.
    """
    data: ROIDatasetConfig = field(default_factory=ROIDatasetConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    # Must be set: main() raises ValueError when this is None. Also used as the
    # checkpoint directory name and the wandb run name.
    experiment_name: Optional[str] = None
    # Overwritten by main() with "cuda:<first entry of training.gpus>".
    device: Optional[str] = None
    # When True, main() starts wandb in "offline" mode instead of "online".
    debug: bool = False
    # When True, main() skips the per-evaluation "<name>_<epoch>.pt" checkpoints;
    # only the final "<name>_last.pt" is written.
    save_only_best: bool = False

def flatten_trainer_config(cfg: DictConfig) -> dict:
    """Merge the ``training`` and ``model`` sections of *cfg* into one flat dict.

    The config is resolved (interpolations expanded) and converted to plain
    Python containers first. The two sections are merged shallowly in that
    order, so a key present in both takes the ``model`` value. The top-level
    ``device`` entry (``None`` when absent) is written last. Other top-level
    sections, such as ``data``, are not included.
    """
    resolved = OmegaConf.to_container(cfg, resolve=True)

    merged = {}
    for section_name in ('training', 'model'):
        merged.update(resolved.get(section_name, {}))

    merged['device'] = resolved.get('device', None)
    return merged

# Register the `Config` dataclass in Hydra's ConfigStore under the name "config".
# NOTE(review): @hydra.main below loads config_name="bdo", so this schema presumably
# only applies if bdo.yaml lists "config" in its defaults list — confirm.
cs = ConfigStore.instance()
cs.store(node=Config, name="config")

def _normalize_by_num_samples(metrics: dict, count_key: str) -> dict:
    """Return a copy of *metrics* with each value divided by ``metrics[count_key]``.

    The values are presumably per-epoch sums accumulated by the model (TODO
    confirm against RJEPAACSSM.train/evaluate). Dividing them by the sample
    count gives per-sample averages. The count entry is kept as the raw number
    of samples. If the count is zero, the values are returned unchanged rather
    than dividing by zero.

    Raises:
        KeyError: if *count_key* is not in *metrics*.
    """
    # Read the denominator once, up front. Dividing in place while iterating
    # would overwrite the count with 1.0 as soon as its own key is visited,
    # leaving every metric after it un-normalized.
    num_samples = metrics[count_key]
    if num_samples == 0:
        return dict(metrics)
    return {
        key: value if key == count_key else value / num_samples
        for key, value in metrics.items()
    }


def _print_nll_mse(metrics: dict, label: str, nll_key: str, mse_key: str) -> None:
    """Print ``<label>-NLL`` / ``<label>-MSE`` when *nll_key* is present in *metrics*."""
    if nll_key in metrics:
        print(f'{label}-NLL: {metrics[nll_key]:.6f} || {label}-MSE: {metrics[mse_key]:.6f}')


def _make_loader(dataset, training_cfg, *, shuffle: bool, drop_last: bool) -> DataLoader:
    """Wrap *dataset* in a DataLoader using the batching options from ``cfg.training``."""
    return DataLoader(
        dataset,
        batch_size=training_cfg.batch_size,
        shuffle=shuffle,
        num_workers=training_cfg.num_workers,
        pin_memory=training_cfg.pin_memory,
        drop_last=drop_last,
    )


@hydra.main(config_path="cfg/pretrain", config_name="bdo", version_base=None)
def main(cfg: Config) -> None:
    """Pretrain an RJEPAACSSM model on the ROI dataset and log metrics to wandb.

    Trains for ``cfg.training.num_epochs`` epochs. Every
    ``cfg.training.num_eval_epochs`` epochs, and on the final epoch, it
    evaluates on the test split, logs to wandb and (unless ``save_only_best``)
    writes a checkpoint. The resolved config and all checkpoints go under
    ``./checkpoints2/<experiment_name>/``.

    Raises:
        ValueError: if ``cfg.experiment_name`` is not set.
    """
    # Setup
    set_seed(cfg.training.random_seed)
    # Use the first GPU in the list for initial placement.
    cfg.device = f"cuda:{cfg.training.gpus[0]}"

    if cfg.experiment_name is None:
        raise ValueError("experiment_name must be provided")

    # Optional run flags: read with defaults so a config that omits them does not raise.
    debug = bool(cfg.get("debug", False))
    save_only_best = bool(cfg.get("save_only_best", False))

    checkpoint_path = Path(f"./checkpoints2/{cfg.experiment_name}")
    checkpoint_path.mkdir(parents=True, exist_ok=True)

    # Datasets and loaders: both splits share the same dataset kwargs.
    dataset_kwargs = OmegaConf.to_container(cfg.data, resolve=True)
    train_roi_dataset = ROIPretrainDatasetSplit(split="train", **dataset_kwargs)
    test_roi_dataset = ROIPretrainDatasetSplit(split="test", **dataset_kwargs)
    train_loader = _make_loader(train_roi_dataset, cfg.training, shuffle=True, drop_last=True)
    test_loader = _make_loader(test_roi_dataset, cfg.training, shuffle=False, drop_last=False)

    # Total optimizer steps, passed to the model through the flattened config
    # (presumably for its LR schedule — confirm in RJEPAACSSM).
    num_epochs = cfg.training.num_epochs
    cfg.training.num_steps = len(train_loader) * num_epochs
    trainer_cfg = flatten_trainer_config(cfg)

    model = RJEPAACSSM(OmegaConf.create(trainer_cfg))
    print(f"# param of encoder: {model.encoder_params}")
    print(f"# param of predictor: {model.predictor_params}")

    OmegaConf.save(cfg, checkpoint_path / 'config.yaml')

    wandb.init(
        project=f"BDO-ROI={cfg.data.roi}",
        config=OmegaConf.to_container(cfg, resolve=True),
        save_code=True,
        mode="offline" if debug else "online",
        name=cfg.experiment_name,
    )
    wandb.define_metric("train/step")
    wandb.define_metric("train/*", step_metric="train/step")
    wandb.define_metric("epoch/epoch")
    wandb.define_metric("epoch/*", step_metric="epoch/epoch")

    eval_every = cfg.training.num_eval_epochs
    # Stays -1 if num_epochs == 0, so the final checkpoint below can still be written.
    epoch = -1
    for epoch in range(num_epochs):
        train_epoch_dict = _normalize_by_num_samples(
            model.train(train_loader, epoch=epoch), "train/num_samples"
        )

        print('-' * 50)
        print(f"[Train] Epoch {epoch+1}/{num_epochs}")
        _print_nll_mse(train_epoch_dict, "recon", "train_mask_nll_recon", "train_mask_mse_recon")
        _print_nll_mse(train_epoch_dict, "jepa", "train_mask_nll_jepa", "train_mask_mse_jepa")

        # An interval of 0 means "evaluate only on the final epoch"
        # (and avoids a modulo by zero).
        is_periodic_eval = bool(eval_every) and (epoch + 1) % eval_every == 0
        if not (is_periodic_eval or epoch + 1 == num_epochs):
            continue

        test_epoch_dict = _normalize_by_num_samples(
            model.evaluate(test_loader, epoch=epoch), "test/num_samples"
        )

        print(f"[Test] Epoch {epoch+1}/{num_epochs}")
        _print_nll_mse(test_epoch_dict, "recon", "test_mask_nll_recon", "test_mask_mse_recon")
        # NOTE(review): the test JEPA keys lack the "mask_" infix used by the train
        # keys ("test_nll_jepa" vs "train_mask_nll_jepa") — confirm against evaluate().
        _print_nll_mse(test_epoch_dict, "jepa", "test_nll_jepa", "test_mse_jepa")

        wandb.log({**train_epoch_dict, **test_epoch_dict, "epoch/epoch": epoch})

        # NOTE(review): no "best" metric is tracked; save_only_best=True simply
        # disables these per-evaluation checkpoints.
        if not save_only_best:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.jepa.state_dict(),
                'optimizer_state_dict': model.optimizer.state_dict(),
            }
            if cfg.training.use_scheduler:
                checkpoint['scheduler_state_dict'] = model.scheduler.state_dict()
            torch.save(checkpoint, checkpoint_path / f'{cfg.experiment_name}_{epoch+1}.pt')

    # The final checkpoint holds model weights only (no optimizer/scheduler state).
    torch.save(
        {'epoch': epoch, 'model_state_dict': model.jepa.state_dict()},
        checkpoint_path / f'{cfg.experiment_name}_last.pt',
    )

    wandb.finish()
    print(f"experiment name : {cfg.experiment_name}")
    

if __name__ == "__main__":
    # main() takes `cfg`, but the @hydra.main decorator builds it from
    # cfg/pretrain/bdo, so the call passes no arguments.
    main()