import hydra
import torch
from pathlib import Path
from omegaconf import DictConfig, ListConfig
import logging
from tt_sbi.tasks import get_task

# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)

def move_to_cpu(data_dict):
    """Return a new dict in which every ``torch.Tensor`` value is on the CPU.

    Values that are not tensors are carried over unchanged (same object).
    The input mapping itself is not modified; key order is preserved.
    """
    return {
        key: value.cpu() if isinstance(value, torch.Tensor) else value
        for key, value in data_dict.items()
    }

def _collect_scenarios(cfg):
    """Return the list of misspecification scenarios requested by ``cfg``.

    Precedence: an explicit ``misspec_scenarios`` list, then a single
    contamination sweep built from ``contamination_levels``, and finally one
    scenario of type ``"none"`` with the single level ``0.0``.
    """
    if "misspec_scenarios" in cfg:
        return list(cfg.misspec_scenarios)
    if "contamination_levels" in cfg:
        return [{"type": "contamination", "levels": cfg.contamination_levels}]
    return [{"type": "none", "levels": [0.0]}]


def _build_misspec_cfg(cfg, scen, m_type, val):
    """Build the ``misspec_cfg`` dict handed to ``task.generate_test_data``.

    A scenario of type ``"none"`` or a level of ``0.0`` yields an empty dict,
    so the generated data agrees with the ``well_specified`` filename used for
    those cases.  This check must come first: placed after the per-type
    branches it can never fire for a recognised type.
    """
    if m_type == "none" or val == 0.0:
        return {}

    misspec_cfg = {"type": m_type}

    if m_type == "contamination":
        overrides = cfg.get("misspec", {})
        misspec_cfg["contamination_eps"] = val
        misspec_cfg["multiplier"] = overrides.get("multiplier", 0.95)
        misspec_cfg["contamination_shift"] = overrides.get("contamination_shift", 2.0)
        misspec_cfg["var_contam"] = overrides.get("var_contam", None)
    elif m_type == "contamination_shift":
        # The sweep is over the shift; the contamination fraction is fixed per
        # scenario and the type passed on is plain "contamination".
        misspec_cfg["type"] = "contamination"
        misspec_cfg["contamination_eps"] = scen.get("fixed_epsilon", 0.2)
        misspec_cfg["contamination_shift"] = val
    elif m_type == "prior_location":
        misspec_cfg["prior_location_shift"] = val
    elif m_type == "prior_scale":
        misspec_cfg["prior_scale_factor"] = val
    elif m_type == "likelihood_scale":
        misspec_cfg["likelihood_scale_factor"] = val
    else:
        # Unrecognised types are still forwarded as {"type": m_type} and the
        # output keeps the historical "well_specified" filename; warn so that
        # a typo in the config does not go unnoticed.
        log.warning(
            "Unrecognised misspecification type %r; output will be saved "
            "under the 'well_specified' suffix.",
            m_type,
        )

    return misspec_cfg


def _file_suffix(scen, m_type, val):
    """Return the filename suffix identifying one (scenario type, level) pair."""
    if m_type == "none" or val == 0.0:
        return "well_specified"
    if m_type == "contamination":
        return f"contam_{val}"
    if m_type == "contamination_shift":
        fixed_eps = scen.get("fixed_epsilon", 0.2)
        return f"contam_{fixed_eps}_shift_{val}"
    if m_type == "prior_location":
        return f"prior_loc_{val}"
    if m_type == "prior_scale":
        return f"prior_scale_{val}"
    if m_type == "likelihood_scale":
        return f"lik_scale_{val}"
    # Fallback kept for backward compatibility with existing output names.
    return "well_specified"


@hydra.main(version_base=None, config_path="../configs", config_name="gaussian_config")
def main(cfg: DictConfig):
    """Generate and save training and/or test datasets for the configured task.

    Training data is written when ``cfg.n_train > 0`` to
    ``train_d{dim}_n{n_obs}.pt`` inside ``train_data_dir`` (falling back to
    ``data_dir``, then ``"data"``).

    Test data is written when ``cfg.n_test > 0``: one file per
    (scenario, level) pair inside ``test_data_dir`` (default ``"data/test"``),
    named ``test_d{dim}_n{n_obs}_{suffix}.pt``.  Each saved dict is augmented
    with ``misspec_type``, ``misspec_value`` and a small ``config`` record.
    Pairs that map to the same suffix (e.g. several scenarios containing the
    level ``0.0``) overwrite the same ``well_specified`` file.

    Args:
        cfg: Hydra/OmegaConf configuration. Reads ``model_type``, ``seed``
            and, optionally, ``dim``, ``n_obs``, ``n_train``, ``n_test``,
            ``train_data_dir``, ``data_dir``, ``test_data_dir``,
            ``misspec_scenarios``, ``contamination_levels`` and ``misspec``.
    """
    log.info(f"Initializing task: {cfg.model_type}")
    task = get_task(cfg)

    # Placeholders "X"/"Y" are used in filenames when the config omits them.
    dim = getattr(cfg, "dim", "X")
    n_obs = getattr(cfg, "n_obs", "Y")

    if cfg.get("n_train", 0) > 0:
        log.info(f"Generating {cfg.n_train} training samples...")
        output_dir = Path(cfg.get("train_data_dir", cfg.get("data_dir", "data")))
        output_dir.mkdir(parents=True, exist_ok=True)

        train_data = task.generate_train_data(n_samples=cfg.n_train, seed=cfg.seed)
        train_data = move_to_cpu(train_data)

        save_path = output_dir / f"train_d{dim}_n{n_obs}.pt"
        torch.save(train_data, save_path)
        log.info(f"Saved training data to: {save_path}")

    if cfg.get("n_test", 0) > 0:
        log.info(f"Generating {cfg.n_test} test samples...")
        output_dir = Path(cfg.get("test_data_dir", "data/test"))
        output_dir.mkdir(parents=True, exist_ok=True)

        for scen in _collect_scenarios(cfg):
            m_type = scen["type"]
            levels = scen["levels"]

            # Allow a single scalar level in place of a list.
            if isinstance(levels, (float, int)):
                levels = [levels]

            for val in levels:
                log.info(f"Processing misspecification: {m_type} = {val}")

                misspec_cfg = _build_misspec_cfg(cfg, scen, m_type, val)

                # The same test seed (offset from the training seed) is used
                # for every scenario and level.
                test_data = task.generate_test_data(
                    n_samples=cfg.n_test,
                    seed=cfg.seed + 1000,
                    misspec_cfg=misspec_cfg,
                )

                test_data["misspec_type"] = m_type
                test_data["misspec_value"] = val
                test_data["config"] = {
                    "model_type": cfg.model_type,
                    "n_test": cfg.n_test,
                    "misspec": misspec_cfg,
                }

                test_data = move_to_cpu(test_data)

                suffix = _file_suffix(scen, m_type, val)
                save_path = output_dir / f"test_d{dim}_n{n_obs}_{suffix}.pt"

                torch.save(test_data, save_path)
                log.info(f"Saved: {save_path}")

# Script entry point: call the Hydra-decorated ``main`` only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()