from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn 

from coral.siren import ModulatedSiren, ModulatedSirenGridMix


def create_inr_instance(cfg, input_dim=1, output_dim=1, device="cuda"):
    """Instantiate the modulated SIREN INR described by ``cfg.inr`` on ``device``.

    Args:
        cfg: Configuration object whose ``inr`` section provides the
            architecture hyper-parameters (``model_type``, ``hidden_dim``,
            ``depth``, ``w0``, ...).
        input_dim: Passed to the model as ``dim_in``.
        output_dim: Passed to the model as ``dim_out``.
        device: Any value accepted by ``torch.device``.

    Returns:
        A ``ModulatedSiren`` when ``cfg.inr.model_type == "siren"``, or a
        ``ModulatedSirenGridMix`` when it is ``"siren_gridmix"``, moved to
        ``device``.

    Raises:
        NotImplementedError: If ``cfg.inr.model_type`` names neither class.
    """
    target = torch.device(device)
    spec = cfg.inr
    kind = spec.model_type

    # Reject unknown architectures before reading any other config field.
    if kind not in ("siren", "siren_gridmix"):
        raise NotImplementedError(f"No corresponding class for {cfg.inr.model_type}")

    # Constructor arguments common to both architectures.
    shared = dict(
        dim_in=input_dim,
        dim_hidden=spec.hidden_dim,
        dim_out=output_dim,
        num_layers=spec.depth,
        w0=spec.w0,
        w0_initial=spec.w0,  # the first layer reuses the same w0
        use_bias=True,
        modulate_scale=spec.modulate_scale,
        modulate_shift=spec.modulate_shift,
        use_latent=spec.use_latent,
        latent_dim=spec.latent_dim,
        modulation_net_dim_hidden=spec.hypernet_width,
        modulation_net_num_layers=spec.hypernet_depth,
        last_activation=spec.last_activation,
    )

    if kind == "siren":
        model = ModulatedSiren(**shared)
    else:
        # The grid-mix variant takes extra grid-related arguments.
        model = ModulatedSirenGridMix(
            **shared,
            use_norm=spec.use_norm,
            grid_size=spec.grid_size,
            siren_init=spec.siren_init,
            grid_base=spec.grid_base,
            grid_sum=spec.grid_sum,
            share_grid=spec.share_grid,
            grid_size_2=spec.grid_size_2,
        )

    return model.to(target)


def load_inr_model(
    run_dir, run_name, data_to_encode, input_dim=1, output_dim=1, device="cuda"
):
    """Load a trained INR from the checkpoint ``<run_dir>/<run_name>.pt``.

    The checkpoint is expected to be a mapping with the keys ``"inr"`` (model
    state dict), ``"cfg"`` (config passed to ``create_inr_instance``),
    ``"alpha"`` (tensor-like object supporting ``.item()``) and ``"epoch"``.

    Args:
        run_dir: Directory holding the checkpoint (``str`` or ``Path``).
        run_name: Checkpoint file name without the ``.pt`` suffix.
        data_to_encode: Unused; kept for backward compatibility with callers.
        input_dim: Forwarded to ``create_inr_instance``.
        output_dim: Forwarded to ``create_inr_instance``.
        device: Device the checkpoint tensors are mapped to and the model is
            built on.

    Returns:
        ``(inr, alpha)``: the model with weights loaded and set to eval mode,
        plus the checkpoint's ``"alpha"`` entry.
    """
    # NOTE(review): torch.load unpickles arbitrary Python objects (this
    # checkpoint stores a config object under "cfg"), so only load trusted
    # files. On recent torch versions, where weights_only defaults to True,
    # the stored cfg may be rejected. Confirm and pass weights_only=False if needed.
    # map_location lets GPU-saved checkpoints load on CPU or other devices.
    inr_train = torch.load(Path(run_dir) / f"{run_name}.pt", map_location=device)

    inr_state_dict = inr_train["inr"]
    cfg = inr_train["cfg"]
    alpha = inr_train["alpha"]
    print(f'{run_name}, epoch {inr_train["epoch"]}, alpha1 {alpha.item()}')

    inr = create_inr_instance(cfg, input_dim, output_dim, device)

    # Wrappers such as DataParallel / DistributedDataParallel save keys as
    # "module.<name>". Strip only that *leading* prefix. The previous
    # substring test would chop 7 characters off any key that merely
    # contained "module." somewhere in the middle.
    prefix = "module."
    new_state_dict = {
        (name[len(prefix):] if name.startswith(prefix) else name): param
        for name, param in inr_state_dict.items()
    }
    inr.load_state_dict(new_state_dict)
    inr.eval()

    return inr, alpha