from calendar import c
import datetime
import math
import os
from pathlib import Path
import random
import yaml

import numpy as np
import torch
from torch import vmap
from torch._functorch.eager_transforms import jacrev, jacfwd
from torch._functorch.functional_call import functional_call
from models.lip_ffn import LipschitzConditionalFFN
from models.lip_mlp import CondLipMLP
from models.lip_siren import CondLipSIREN
from models.net_w_partials import NetWithPartials
from models.siren import LatentModulatedSiren, ConditionalSIREN
from models.NN import ConditionalGeneralNet, ConditionalGeneralResNet, GeneralNet, GeneralNetBunny, GeneralNetPosEnc, GeneralResNet
from models.wire import ConditionalWIRE


def set_all_seeds(seed):
    """
    Seed the random number generators used in this module with ``seed``.

    Seeds Python's ``random``, NumPy, ``torch.manual_seed`` and
    ``torch.cuda.manual_seed``, stores the seed in the ``PYTHONHASHSEED``
    environment variable and switches cuDNN to deterministic mode.
    """
    ## Python-level settings
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)

    ## numerical libraries
    for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    torch.backends.cudnn.deterministic = True


def flip_arrow(xy, dxy):
    """Return the arrow's tip (``xy + dxy``) together with the reversed direction (``-1 * dxy``)."""
    tip = xy + dxy
    reversed_dxy = -1 * dxy
    return tip, reversed_dxy

def set_and_true(key, config):
    """Return ``False`` if ``key`` is missing from ``config``, otherwise the value stored under it."""
    if key not in config:
        return False
    return config[key]

def set_else_default(key, config, val):
    """Look up ``key`` in ``config``; fall back to ``val`` when the key is absent."""
    if key in config:
        return config[key]
    return val

def get_is_out_mask(x, bounds):
    """
    Flag the rows of ``x`` that have at least one coordinate outside ``bounds``.

    ``bounds[:, 0]`` is used as the lower limit and ``bounds[:, 1]`` as the
    upper limit; a row is flagged if any entry is strictly below the lower
    or strictly above the upper limit.
    """
    lower, upper = bounds[:, 0], bounds[:, 1]
    below_any = (x < lower).any(1)
    above_any = (x > upper).any(1)
    return below_any | above_any

def find_model_file(run_id, root_dir, use_epoch=None, get_file='model.pt'):
    """
    Search ``root_dir`` recursively for a saved file that belongs to ``run_id``.

    A file is a candidate if its name contains both ``run_id`` and ``get_file``.

    :param run_id: substring identifying the run in the file name
    :param root_dir: directory that is walked recursively
    :param use_epoch: if None, the candidate with the highest ``it_<epoch>`` tag
        is returned; otherwise the first candidate whose path contains
        ``it_<use_epoch zero-padded to 5 digits>``
    :param get_file: substring identifying the kind of file (e.g. 'model.pt')
    :return: path of the matching file
    :raises ValueError: if no matching file exists
    """
    # search in all subdirectories
    candidate_files = [
        os.path.join(subdir, file)
        for subdir, _dirs, files in os.walk(root_dir)
        for file in files
        if run_id in file and get_file in file
    ]

    if len(candidate_files) > 1:
        print(f"Found {len(candidate_files)} candidate files")
        for file in candidate_files:
            print(file)

    ## fail with the documented error instead of an IndexError on an empty list
    if not candidate_files:
        raise ValueError(f"Could not find model with {run_id=} in {root_dir=}")

    if use_epoch is None:
        ## return file from latest epoch
        def get_epoch_from_filename(file):
            ## only the file name carries the epoch tag; directory names are ignored
            for part in os.path.basename(file).split('-'):
                if 'it_' in part:
                    try:
                        return int(part.split('_')[-1])
                    except ValueError:
                        continue
            ## files saved without an epoch tag sort before all tagged ones
            return -1

        ## sort by epoch
        sorted_by_epoch = sorted(candidate_files, key=get_epoch_from_filename)
        return sorted_by_epoch[-1]

    for file in candidate_files:
        if f'it_{use_epoch:05}' in file:
            return file

    raise ValueError(f"Could not find model with {run_id=} in {root_dir=}")

def is_every_n_epochs_fulfilled(epoch, config, key):
    """
    Tell whether a periodic action configured under ``key`` is due at ``epoch``.

    Returns False if ``key`` is not in ``config``. Otherwise the action is due
    when ``epoch`` is a multiple of ``config[key]`` or when ``epoch`` equals
    ``config['max_epochs'] - 1``.
    """
    if key not in config:
        return False
    if epoch % config[key] == 0:
        return True
    return epoch == config['max_epochs'] - 1

def save_model_every_n_epochs(model, optim, sched, config, epoch):
    """
    Save the model (and optionally optimizer and scheduler) if saving is due at ``epoch``.

    Saving is due according to ``is_every_n_epochs_fulfilled`` with the key
    'save_every_n_epochs'. On the first save the target directory
    ``<model_save_path>/<model>/<timestamp>-<wandb_id>`` is chosen and stored
    in ``config['save_model_dir']``; later saves reuse it. The config itself is
    dumped as YAML next to the model unless a config file with the same name
    already exists.

    :param model: object whose ``state_dict()`` is saved
    :param optim: optimizer; saved if ``config['save_optimizer']`` is truthy
    :param sched: scheduler; saved if additionally ``config['use_scheduler']`` is truthy
    :param config: dict-like run configuration (modified: 'save_model_dir' may be added)
    :param epoch: current epoch, also used in the file name unless
        ``config['overwrite_existing_saved_model']`` is truthy
    :return: ``(True, model_path)`` if a model was saved, ``(False, None)`` otherwise
    """
    if not is_every_n_epochs_fulfilled(epoch, config, 'save_every_n_epochs'):
        return False, None

    assert 'model_save_path' in config, 'model_save_path must be specified in config'

    ## choose the directory once and remember it in the config
    if 'save_model_dir' not in config:
        name_stem = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S") + '-' + config['wandb_id']
        config['save_model_dir'] = os.path.join(config['model_save_path'], config['model'], name_stem)
    model_parent_path = config['save_model_dir']

    ## make sure the directory exists, also when save_model_dir was preset in the config
    Path(model_parent_path).mkdir(parents=True, exist_ok=True)

    ## the file-name stem is the last path component; basename/normpath handle
    ## the platform's separator and a trailing separator, unlike split('/')
    name_stem = os.path.basename(os.path.normpath(model_parent_path))

    ## add epoch to filename if not overwriting
    if not config['overwrite_existing_saved_model']:
        name_stem += f'-it_{epoch:05}'

    ## save model
    model_filename = name_stem + '-model.pt'
    model_path = os.path.join(model_parent_path, model_filename)
    torch.save(model.state_dict(), model_path)

    # save config as yml
    config_filename = name_stem + '-config.yml'
    config_path = os.path.join(model_parent_path, config_filename)
    if not os.path.exists(config_path):
        with open(config_path, 'w') as file:
            yaml.dump(config, file)

    ## save optimizer
    if config.get('save_optimizer', False):
        optim_filename = name_stem + '-optim.pt'
        optim_path = os.path.join(model_parent_path, optim_filename)
        torch.save(optim.state_dict(), optim_path)

        if config.get('use_scheduler', False):
            ## save scheduler (if used)
            sched_filename = name_stem + '-sched.pt'
            sched_path = os.path.join(model_parent_path, sched_filename)
            torch.save(sched.state_dict(), sched_path)

    return True, model_path
    
def n_rows_cols(n_shapes, flatten=False):
    """
    Returns the number of rows and columns for arranging ``n_shapes`` items in a grid.
    Favors more rows over columns to better show in wandb side-by-side.

    :param n_shapes: number of items to arrange
    :param flatten: if True, all items are put in a single column
    :return: tuple ``(n_rows, n_cols)``; ``(0, 1)`` for ``n_shapes == 0``
    """
    if flatten:
        return n_shapes, 1
    ## at least one column, so that n_shapes == 0 does not divide by zero
    n_cols = max(int(math.sqrt(n_shapes)), 1)
    n_rows = int(math.ceil(n_shapes / n_cols))
    return n_rows, n_cols
    
def load_model_optim_sched(config, model, optim, sched):
    """
    Restore model (and possibly optimizer/scheduler) state as requested by ``config``.

    Nothing is loaded unless 'load_model' or 'load_mos' is set and true.
    With 'model_load_wandb_id' the checkpoint is located on the file system and
    the optimizer/scheduler files are derived from the model file name by
    replacing the '-model.pt' suffix. With 'model_load_path' only the model
    weights are loaded.

    :return: the (same) ``model``, ``optim`` and ``sched`` objects
    :raises ValueError: if loading is requested but neither 'model_load_path'
        nor 'model_load_wandb_id' is given
    """
    has_wandb_id = 'model_load_wandb_id' in config
    has_path = 'model_load_path' in config

    ## Nothing to do if loading was not requested
    loading_requested = set_and_true('load_model', config) or set_and_true('load_mos', config)
    if not loading_requested:
        if has_path or has_wandb_id:
            print('WARNING: model_load_path or model_load_wandb_id specified but load_model is False. Ignoring.')
        return model, optim, sched

    ## Load from an explicit path (weights only)
    if not has_wandb_id:
        if not has_path:
            raise ValueError('model_load_path or model_load_wandb_id must be specified if load_model is True')
        assert config.get('load_optimizer', False) == False, 'load_optimizer is not supported with model_load_path'
        model.load_state_dict(torch.load(config['model_load_path']))
        return model, optim, sched

    ## Load via wandb id
    assert not has_path, 'model_load_path and model_load_wandb_id cannot be specified at the same time'
    model_load_path = get_model_path_via_wandb_id_from_fs(config['model_load_wandb_id'])
    print(f'Loading model from {model_load_path}...')
    model.load_state_dict(torch.load(model_load_path))

    load_all = config.get('load_mos', False)

    ## Optimizer
    if config.get('load_optimizer', False) or load_all:
        optim_load_path = model_load_path.replace('-model.pt', '-optim.pt')
        print(f'Loading optimizer from {optim_load_path}...')
        optim.load_state_dict(torch.load(optim_load_path))

    ## Scheduler
    ## NOTE(review): the scheduler is loaded whenever 'use_scheduler' is set, even if
    ## only 'load_model' was requested - confirm this is intended.
    if config.get('use_scheduler', False) or load_all:
        sched_load_path = model_load_path.replace('-model.pt', '-sched.pt')
        print(f'Loading scheduler from {sched_load_path}...')
        sched.load_state_dict(torch.load(sched_load_path))

    return model, optim, sched

def get_model_path_via_wandb_id_from_fs(run_id, base_dir='./', use_epoch=None, get_file='model.pt'):
    """
    Locate a saved file of run ``run_id`` on the file system.

    ``MODELS_PARENT_DIR`` is searched first, then ``base_dir`` (unless it is
    None). The first directory in which ``find_model_file`` succeeds wins;
    errors raised while searching a directory are printed and the next
    directory is tried.

    :raises ValueError: if the file is not found in any directory
    """
    ## NOTE(review): MODELS_PARENT_DIR is not defined in the visible part of this
    ## file - confirm it is provided elsewhere.
    search_roots = [MODELS_PARENT_DIR] + ([base_dir] if base_dir is not None else [])

    for root_dir in search_roots:
        try:
            file_path = find_model_file(run_id, root_dir, use_epoch=use_epoch, get_file=get_file)
        except Exception as e:
            print(e)
        else:
            print(f'Found model at {file_path}')
            return file_path

    raise ValueError(f"Could not find model with {run_id=} anywhere")

def generate_color(i):
    """
    Map the three lowest bits of ``i`` to a packed 24-bit color value.

    Bit 0 sets the lowest byte to 255, bit 1 the middle byte and bit 2 the
    highest byte; higher bits of ``i`` are ignored.
    """
    return sum((((i >> bit) & 1) * 255) << (8 * bit) for bit in range(3))

def precompute_sample_grid(n_points, bounds):
    '''
    An equidistant grid of points is computed. These are later taken as starting points to discover critical points
    via gradient descent. The requested number of points is equally distributed among all dimensions, i.e. the grid
    has floor(n_points ** (1 / nx)) points along each of the nx dimensions.

    :param n_points: upper limit on the total number of grid points
    :param bounds: tensor with one row per dimension; the first column is the lower and the last column
        the upper bound of that dimension
    :return:
    xc_grid: the grid of equidistant points over the domain, one row per point (float32)
    xc_grid_dist: the distance between neighbors along the respective dimension, one entry per dimension (float32)
    '''
    nx = bounds.shape[0]

    n_points_root = int(math.floor(math.pow(n_points, 1 / nx)))
    ## correct floating point errors of the root, e.g. 64 ** (1 / 3) evaluates to 3.9999999999999996
    while (n_points_root + 1) ** nx <= n_points:
        n_points_root += 1
    while n_points_root > 0 and n_points_root ** nx > n_points:
        n_points_root -= 1

    dist_along_a_dim = 1 / (n_points_root + 1)
    xi_range = torch.arange(start=0, end=n_points_root, step=1) * dist_along_a_dim + dist_along_a_dim / 2

    ## same 1d range along every dimension; works for any number of dimensions
    grids = torch.meshgrid(*([xi_range] * nx), indexing="ij")
    xc_grid = torch.stack([grid.reshape(-1) for grid in grids], dim=1)

    ## map from the unit cube to the bounding box
    extent = bounds[:, -1] - bounds[:, 0]
    xc_grid = bounds[:, 0] + extent * xc_grid
    xc_grid_dist = torch.tensor(dist_along_a_dim).repeat(nx) * extent
    return xc_grid.to(torch.float32), xc_grid_dist.to(torch.float32)
    

def lighten_color(color, amount=0.5):
    """
    Lightens the given color by multiplying (1-luminosity) by the given amount.
    Input can be matplotlib color string, hex string, or RGB tuple.

    Returns the lightened color as an RGB tuple (the result of
    ``colorsys.hls_to_rgb``).

    Examples:
    >> lighten_color('g', 0.3)
    >> lighten_color('#F034A3', 0.6)
    >> lighten_color((.3,.55,.1), 0.5)
    """
    import matplotlib.colors as mc
    import colorsys
    try:
        resolved = mc.cnames[color]
    except (KeyError, TypeError):
        ## not a named color (KeyError) or not hashable, e.g. a list (TypeError):
        ## let mc.to_rgb interpret the input directly
        resolved = color
    hue, lightness, saturation = colorsys.rgb_to_hls(*mc.to_rgb(resolved))
    return colorsys.hls_to_rgb(hue, 1 - amount * (1 - lightness), saturation)


def get_stateless_net_with_partials(model, nz=0):
    """
    Wrap a torch model into a stateless ``NetWithPartials``.

    The model is called via ``functional_call`` with its named parameters.
    Derivatives w.r.t. the input ``x`` are built with ``jacrev`` (first order)
    and ``jacfwd`` of that (second order), each also in a ``vmap``-ed version
    that batches over the leading dimension of the inputs while sharing the
    parameters. If ``nz != 0`` the model takes ``(x, z)`` and a batched
    derivative w.r.t. ``z`` is provided as well; otherwise it is ``None``.
    """
    ## Parameters for stateless model
    params = dict(model.named_parameters())
    conditioned = nz != 0

    ## Stateless model; works on whatever input shapes the model itself accepts
    if conditioned:
        def f(params, x, z):
            return functional_call(model, params, (x, z))
        ## params are shared, x and z are batched along dim 0
        batch_dims = (None, 0, 0)
    else:
        def f(params, x):
            return functional_call(model, params, x)
        ## params are shared, x is batched along dim 0
        batch_dims = (None, 0)

    ## Jacobian wrt x and its batched version
    f_x = jacrev(f, argnums=1)
    vf_x = vmap(f_x, in_dims=batch_dims, out_dims=0)

    ## Hessian wrt x (forward-mode over the reverse-mode Jacobian) and its batched version
    f_xx = jacfwd(f_x, argnums=1)
    vf_xx = vmap(f_xx, in_dims=batch_dims, out_dims=0)

    ## Batched Jacobian wrt z, only for conditioned models
    vf_z = vmap(jacrev(f, argnums=2), in_dims=batch_dims, out_dims=0) if conditioned else None

    return NetWithPartials(f, f_x, vf_x, f_xx, vf_xx, params, vf_z)


def get_activation(act_str):
    """
    Translate an activation name into a callable.

    Known names are 'relu', 'softplus' (a ``torch.nn.Softplus`` with beta=10),
    'celu', 'sin' and 'tanh'. Any other value yields ``None``.
    """
    ## softplus is the only activation that needs to be instantiated
    if act_str == 'softplus':
        return torch.nn.Softplus(beta=10)

    plain_functions = (
        ('relu', torch.relu),
        ('celu', torch.celu),
        ('sin', torch.sin),
        ('tanh', torch.tanh),
    )
    for name, fn in plain_functions:
        if act_str == name:
            return fn

    return None

def do_plot(config, epoch, key=None):
    """
    Decide whether the plot identified by ``key`` should be produced at ``epoch``.

    No plot is produced if none of 'fig_show', 'fig_save', 'fig_wandb' is set.
    Without a ``key`` only this global setting is returned. A key missing from
    the config means no plot. The value stored under the key is either a bool
    or a dict with the entries 'on' and 'interval'; in the latter case the plot
    is produced every 'interval' epochs and when ``epoch == config['max_epochs']``.

    :raises ValueError: if the value stored under ``key`` is neither bool nor dict
    """
    is_output_set = config['fig_show'] or config['fig_save'] or config['fig_wandb']

    ## no output channel -> no plot
    if is_output_set == False:
        return False

    ## no key -> only the global plotting behaviour counts
    if key is None:
        return is_output_set

    ## unknown key -> no plot
    if key not in config:
        return False

    val = config[key]
    if isinstance(val, bool):
        return val

    if not isinstance(val, dict):
        raise ValueError(f'Unknown value for key {key}: {val}')

    ## dict with the entries 'on' and 'interval'
    if val['on'] == False:
        return False

    assert val['interval'] % config['plot_every_n_epochs'] == 0, 'plot_interval must be a multiple of plot_every_n_epochs'
    ## NOTE(review): this compares against max_epochs, whereas is_every_n_epochs_fulfilled
    ## uses max_epochs - 1 - confirm which one matches the training loop.
    is_on_interval = epoch % val['interval'] == 0
    return is_on_interval or (epoch == config['max_epochs'])


def get_model(config):
    """
    Build the network selected by ``config['model']``.

    The hidden layer sizes are taken from ``config['layers']``; the input size
    ``config['nx'] + config['nz']`` is prepended and the output size 1 is
    appended. The remaining hyperparameters are read from the config only for
    the selected model.

    :raises ValueError: for 'siren' (not supported) and for unknown model names
    """
    model_str = config['model']
    activation = get_activation(config.get('activation', None))

    ## NOTE(review): config['layers'] is extended in place, so calling this function
    ## twice with the same config keeps growing the list - confirm callers rely on it.
    layers = config['layers']
    layers.insert(0, config['nx'] + config['nz'])  ## input layer
    layers.append(1)  ## output layer

    if model_str == 'siren':
        raise ValueError('SIREN is not supported yet; the layers and in_features are not well defined')

    ## One lazy constructor per model name; config entries are only read for the chosen model.
    ## NOTE(review): ConditionalFFN is not among the imports visible in this file - confirm.
    builders = {
        'cond_siren': lambda: ConditionalSIREN(layers=layers, w0=config['w0'], w0_initial=config['w0_initial']),
        'lip_siren': lambda: CondLipSIREN(layers=layers, nz=config['nz'], w0=config['w0'], w0_initial=config['w0_initial']),
        'lip_mlp': lambda: CondLipMLP(layers=layers),
        'comod_siren': lambda: LatentModulatedSiren(layers=layers, w0=config['w0'], w0_initial=config['w0_initial'], latent_dim=config['nz']),
        'cond_wire': lambda: ConditionalWIRE(layers=layers, first_omega_0=config['w0_initial'], hidden_omega_0=config['w0'], scale=config['wire_scale']),
        'general_net': lambda: GeneralNet(ks=config['layers'], act=activation),
        'general_resnet': lambda: GeneralResNet(ks=config['layers'], act=activation),
        'cond_general_net': lambda: ConditionalGeneralNet(ks=config['layers'], act=activation),
        'cond_general_resnet': lambda: ConditionalGeneralResNet(ks=config['layers'], act=activation),
        'cond_ffn': lambda: ConditionalFFN(layers=layers, nz=config['nz'], n_ffeat=config['n_ffeat'], sigma=config['ffn_sigma']),
        'lip_ffn': lambda: LipschitzConditionalFFN(layers=layers, nz=config['nz'], n_ffeat=config['n_ffeat'], sigma=config['ffn_sigma']),
        'general_net_posenc': lambda: GeneralNetPosEnc(ks=[config['nx'], 2*config['nx']*config['N_posenc'], 20, 20, 1]),
        'bunny': lambda: GeneralNetBunny(act='sin'),
    }

    build = builders.get(model_str)
    if build is None:
        raise ValueError(f'model not specified properly in config: {config["model"]}')
    return build()


### FORMATTING

def format_significant_digits(number, n_digits):
    """
    Format ``number`` in fixed-point notation with at least ``n_digits`` decimals.

    For numbers with a negative decimal exponent the number of decimals is
    increased so that ``n_digits`` digits are shown starting from the first
    non-zero digit. Zero is always formatted as "0.0".
    """
    if number == 0:
        return "0.0"

    # The decimal exponent is read off the scientific notation of the number
    exponent_text = format(number, 'e').rpartition('e')[2]
    exponent = int(exponent_text)

    # Decimals needed to reach the requested digits, but never fewer than n_digits
    n_decimals = max(n_digits - exponent - 1, n_digits)
    return format(number, f'.{n_decimals}f')


def inflate_bounds(bounds, amount=0.10):
    """
    Return a copy of a bounding box that is enlarged on each side by
    ``amount`` times the extent of the respective dimension.

    A bounding box that is tight around the surface makes it hard to find good
    surface points, because
        (A) initial points should be sampled around the surface, which is not possible if the surface is near the bbox,
        (B) point trajectories may overshoot the surface, leave the bbox and get filtered.
    Inflating is not required if the bbox is not tight.

    The input ``bounds`` is left unchanged; column 0 holds the lower and
    column 1 the upper bounds.
    """
    padding = (bounds[:, 1] - bounds[:, 0]) * amount
    inflated = bounds.clone()
    inflated[:, 0] -= padding
    inflated[:, 1] += padding
    return inflated


def load_yaml_and_drop_keys(file_path, keys_to_drop):
    """
    Load a YAML file after removing the given top-level keys from its text.

    The keys are removed on the text level before parsing: a line starting
    with ``<key>:`` is dropped together with the directly following lines that
    start with '-' or with two spaces.

    :param file_path: path of the YAML file
    :param keys_to_drop: iterable of key names to remove; may be empty
    :return: the parsed YAML content without the dropped keys
    """
    def remove_key_from_yaml(yaml_text, key_to_remove):
        kept_lines = []
        skip = False

        for line in yaml_text.split('\n'):
            if line.startswith(key_to_remove + ':'):
                ## the key itself: drop it and start skipping its value lines
                skip = True
            elif skip and line.startswith(('-', '  ')):
                ## list item or indented line belonging to the dropped key
                continue
            else:
                skip = False
                kept_lines.append(line)

        return '\n'.join(kept_lines)

    # Read the YAML file as text
    with open(file_path, 'r') as file:
        yaml_text = file.read()

    # Remove the keys one after another, each time from the already filtered text,
    # so that all keys are dropped and an empty keys_to_drop is handled as well
    for key_to_remove in keys_to_drop:
        yaml_text = remove_key_from_yaml(yaml_text, key_to_remove)

    return yaml.safe_load(yaml_text)