import os
from copy import deepcopy
from omegaconf import OmegaConf, open_dict
from overrides import overrides
import omegaconf
import wandb
import pathlib
import pickle
import torch
import sys
from pytorch_lightning.utilities import rank_zero_only
from torch_geometric.loader import DataLoader
import pickle
from src.utils import graph
from torch.optim.lr_scheduler import LambdaLR
import numpy as np
import logging
from collections import OrderedDict
import yaml
import re
import copy

from src.datasets import supernode_dataset
from src.diffusion.diffusion_rxn import DiscreteDenoisingDiffusionRxn

log = logging.getLogger(__name__)

# Default device; load_all_state_dicts uses it as ``map_location`` for torch.load.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Maps a task name to its dataset module ('data_class') and model class ('model_class').
task_to_class_and_model = dict(
    rxn=dict(
        data_class=supernode_dataset,
        model_class=DiscreteDenoisingDiffusionRxn,
    ),
)

def parse_value(value):
    """Convert a command-line override string into a Python value.

    * ``'true'`` / ``'false'`` (any case) become ``True`` / ``False``.
    * ``'[a, b, ...]'`` becomes a list whose stripped elements are parsed the same
      way, e.g. ``'[1, true, x]'`` -> ``[1, True, 'x']``. An empty list literal
      (``'[]'`` or ``'[ ]'``) becomes ``[]`` rather than ``['']``.
    * Anything else goes through :func:`convert_to_number`.

    Elements are split on every comma, so nested lists that contain commas are
    not supported.
    """
    lowered = value.lower()
    if lowered == 'true':
        return True
    if lowered == 'false':
        return False

    if value.startswith('[') and value.endswith(']'):
        inner = value[1:-1].strip()
        if not inner:
            # Without this, ''.split(',') yields [''] and '[]' would parse as [''].
            return []
        return [parse_value(item.strip()) for item in inner.split(',')]

    return convert_to_number(value)


def convert_to_number(s):
    """Return ``s`` as an ``int`` if it parses as one, else as a ``float``, else unchanged."""
    try:
        return int(s)
    except ValueError:
        pass
    try:
        return float(s)
    except ValueError:
        return s  # Not a number: keep the original string.
        
def capture_cli_overrides(argv=None):
    """
    Collect ``key=value`` command-line overrides into a nested OmegaConf config.

    Dotted keys become nested dictionaries (``train.lr=0.1`` -> ``{'train': {'lr': 0.1}}``),
    and values are converted with ``parse_value``. Hydra's ``+`` / ``++`` key prefixes
    (add / force-add a key) are stripped so the value lands at its real path instead of
    under a key literally named ``'+train'``. ``~key`` deletion arguments are skipped
    because they are not value overrides. Arguments without ``=`` are ignored.

    Args:
        argv: arguments to parse; defaults to ``sys.argv[1:]`` (the script name excluded).

    Returns:
        An OmegaConf config built from the collected overrides.
    """
    cli_args = sys.argv[1:] if argv is None else argv
    parsed = {}  # not named `overrides` to avoid shadowing the imported decorator
    for arg in cli_args:
        if '=' not in arg:
            continue
        key, value = arg.split('=', 1)
        if key.startswith('~'):
            # Hydra deletion syntax (~key or ~key=value): nothing to override.
            continue
        key = key.lstrip('+')
        nested_keys = key.split('.')
        node = parsed
        for nested_key in nested_keys[:-1]:
            node = node.setdefault(nested_key, {})
        node[nested_keys[-1]] = parse_value(value)

    return OmegaConf.create(parsed)

def get_batches_from_datamodule(cfg, parent_path, datamodule):
    """Return the list of training batches for ``datamodule``.

    If ``cfg.train.batch_by_size`` is set, the raw training graphs are read from
    ``<parent_path>/<datamodule.datadir>/processed/train.pickle``. They are then batched with
    ``graph.batch_graph_by_size`` using the ``'train'`` entries of ``cfg.dataset.size_bins`` and
    ``cfg.dataset.batchsize_bins``. Otherwise the batches are whatever
    ``datamodule.train_dataloader()`` yields.

    Raises:
        AssertionError: if no batches were produced.
    """
    if cfg.train.batch_by_size:
        data_list_path = os.path.join(parent_path, datamodule.datadir, 'processed', 'train.pickle')
        # NOTE(security): pickle.load can execute arbitrary code; only use trusted local files.
        with open(data_list_path, 'rb') as f:  # closes the handle (the original leaked it)
            train_data = pickle.load(f)
        batches, _sizes_found = graph.batch_graph_by_size(input_data_list=train_data,
                                                          size_bins=cfg.dataset.size_bins['train'],
                                                          batchsize_bins=cfg.dataset.batchsize_bins['train'],
                                                          get_batches=True)
    else:
        batches = list(datamodule.train_dataloader())

    assert len(batches) > 0, 'No batches.'

    return batches
        
def save_file_as_artifact_to_wandb(run, artifactname='default', alias='epoch00', type_='model', filepath=None, filename='model.pt'):
    '''
        Upload a single file to wandb as an artifact and log it on ``run``.

        Args:
            run: the wandb run the artifact is logged to.
            artifactname: name of the wandb artifact.
            alias: alias attached to this artifact version (e.g. ``'epoch00'``).
            type_: wandb artifact type.
            filepath: local path of the file to upload; defaults to ``filename``.
            filename: name the file is stored under inside the artifact.

        Returns:
            None.
    '''
    filepath = filename if filepath is None else filepath
    artifact = wandb.Artifact(artifactname, type=type_)
    artifact.add_file(filepath, name=filename)
    run.log_artifact(artifact, aliases=[alias])
        
def mkdir_p(dir):
    '''Create directory ``dir`` (and any missing parents) unless it already exists.

    ``exist_ok=True`` avoids the race in a separate exists-check: if another process
    creates the directory between the check and the create, this call does not raise.
    '''
    os.makedirs(dir, exist_ok=True)

def merge_configs(default_cfg, new_cfg, cli_overrides):
    """Merge ``new_cfg`` onto ``default_cfg``, then apply ``cli_overrides`` on top.

    Precedence (lowest to highest): ``default_cfg`` < ``new_cfg`` < ``cli_overrides``.

    Example:
        new_cfg.neuralnet.n_layers=5, default_cfg.neuralnet.n_layers=None;
        upload_artifact absent from new_cfg, default_cfg.wandb.upload_artifact=True;
        new_cfg.test.n_conditions=8, default_cfg.test.n_conditions=3,
        cli_overrides.test.n_conditions=10
        -> cfg.neuralnet.n_layers=5, cfg.wandb.upload_artifact=True, cfg.test.n_conditions=10

    ``default_cfg`` is deep-copied, and struct mode is turned off on the copy so keys that exist
    only in ``new_cfg`` / ``cli_overrides`` can be added. The caller's ``default_cfg`` is left
    untouched.
    """
    base_cfg = copy.deepcopy(default_cfg)
    OmegaConf.set_struct(base_cfg, False)
    # 1. Merge the stored run config into the defaults to pick up any newly added fields.
    merged_cfg = OmegaConf.merge(base_cfg, new_cfg)
    # 2. Re-apply the CLI overrides last, since step 1 may have overwritten them.
    return OmegaConf.merge(merged_cfg, cli_overrides)

def load_wandb_config(cfg):
    """Fetch the stored config of an existing wandb run as an OmegaConf object.

    The run is addressed as ``cfg.general.wandb.entity / project / run_id``. It is
    deliberately not finished here: it may still be running (e.g. when this is used
    during evaluation).
    """
    wandb_cfg = cfg.general.wandb
    assert wandb_cfg.run_id is not None, f'Need to give run_id here. Got cfg.general.wandb.run_id={wandb_cfg.run_id}.'

    api = wandb.Api()
    run_path = f"{wandb_cfg.entity}/{wandb_cfg.project}/{wandb_cfg.run_id}"
    stored_run = api.run(run_path)
    return OmegaConf.create(dict(stored_run.config))

def setup_wandb(cfg, job_type):
    """Start a new wandb run, or resume an existing one, and return ``(run, cfg)``.

    Fresh run: the fully resolved ``cfg`` is sent as the run config and returned unchanged.

    Resume (``cfg.general.wandb.resume`` truthy): the run ``cfg.general.wandb.run_id`` is
    resumed with ``resume='allow'``. Its stored config then gets ``train.epochs`` set from the
    incoming cfg, and ``general.wandb`` replaced by resume settings (resume=True, the run's
    id, entity, project, mode='online'). The returned cfg is rebuilt from that stored config.
    """
    wandb_cfg = cfg.general.wandb
    assert (wandb_cfg.resume==False) or (wandb_cfg.resume and wandb_cfg.run_id!=''), "If wandb_resume is True, wandb.run_id must be set"

    # Arguments shared by both branches (tags and groups included).
    init_kwargs = dict(
        entity=wandb_cfg.entity,
        project=wandb_cfg.project,
        job_type=job_type,
        group=wandb_cfg.group,
        tags=wandb_cfg.tags,
        mode=wandb_cfg.mode,
    )
    log.info(init_kwargs)

    if not wandb_cfg.resume:
        # New run: it is created from the local cfg.
        init_kwargs['config'] = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
        init_kwargs['name'] = wandb_cfg.run_name
        return wandb.init(**init_kwargs), cfg

    init_kwargs.update(id=wandb_cfg.run_id, resume='allow')
    run = wandb.init(**init_kwargs)
    # Needed when resuming or otherwise overriding the epochs defined in the yaml file.
    run.config['train']['epochs'] = cfg.train.epochs
    run.config['general']['wandb'] = {'resume': True, 'run_id': run.id, 'entity': wandb_cfg.entity,
                                      'project': wandb_cfg.project, 'mode': 'online'}
    return run, OmegaConf.create(dict(run.config))

def get_wandb_run(run_path):
    """Return the public-API handle for the wandb run at ``run_path`` (``entity/project/run_id``)."""
    return wandb.Api().run(run_path)

def resume_wandb_run(cfg):
    """Call ``wandb.init`` with ``resume="allow"`` for the run id ``cfg.general.wandb_id``.

    NOTE(review): this reads the flat keys ``cfg.general.{wandb_id, project, wandb_team}``.
    The rest of this module uses ``cfg.general.wandb.{run_id, project, entity}`` — confirm
    which schema callers of this function pass.
    """
    general = cfg.general
    run_id = general.wandb_id
    assert run_id not in ("", None), "wandb_id must be set if wandb_resume is True"
    return wandb.init(id=run_id, project=general.project, entity=general.wandb_team, resume="allow")

def download_checkpoint_from_wandb(cfg, savedir, epoch_num, run=None):
    """Download the checkpoint of epoch ``epoch_num`` of the configured wandb run into ``savedir``.

    The run is looked up via the public API from ``cfg.general.wandb.{entity, project, run_id}``.
    Artifacts are scanned in logged order, and the first match wins:

    * old format: an artifact whose name starts with ``eval_epoch{epoch_num}:``. Its file
      ``eval_epoch{epoch_num}.pt`` is renamed to ``epoch{epoch_num}.pt`` after download.
    * model artifacts (``type == 'model'``) with exactly one alias whose text after the last
      ``'epoch'`` equals ``str(epoch_num)``. A contained ``eval_epoch{epoch_num}.pt`` is renamed
      to ``epoch{epoch_num}.pt`` as well.

    Args:
        cfg: config providing ``general.wandb.entity``, ``project`` and ``run_id``.
        savedir: directory the artifact is downloaded into.
        epoch_num: epoch whose checkpoint is wanted.
        run: ignored; kept for backward compatibility (the run is always re-fetched from cfg).

    Returns:
        ``(path to epoch{epoch_num}.pt, the wandb artifact that was downloaded)``.

    Raises:
        ValueError: if a model artifact carries more than one alias matching ``epoch_num``.
        AssertionError: if no matching checkpoint artifact is found.
    """
    api_run = wandb.Api(overrides={"entity": cfg.general.wandb.entity, "project": cfg.general.wandb.project}).run(f"{cfg.general.wandb.project}/{cfg.general.wandb.run_id}")
    target_epoch = str(epoch_num)
    old_format_prefix = f"eval_epoch{epoch_num}:"
    old_file_name = f"eval_epoch{epoch_num}.pt"
    new_file_name = f"epoch{epoch_num}.pt"

    downloaded_dir = None
    artifact = None
    for candidate in api_run.logged_artifacts():
        if candidate.name.startswith(old_format_prefix):  # For resuming in the old format
            downloaded_dir = candidate.download(root=savedir)
            # os.replace (unlike os.rename) also overwrites an existing target on Windows.
            os.replace(os.path.join(downloaded_dir, old_file_name),
                       os.path.join(downloaded_dir, new_file_name))
            artifact = candidate
            break
        if candidate.type == 'model':
            n_matches = sum(1 for s in candidate.aliases if s.split('epoch')[-1] == target_epoch)
            if n_matches > 1:
                # The original `assert f'...'` here could never fire (a non-empty string is truthy).
                raise ValueError(f'Found more than one model checkpoint file with alias epoch{epoch_num}')
            if n_matches == 1:
                downloaded_dir = candidate.download(root=savedir)
                old_path = os.path.join(downloaded_dir, old_file_name)
                if os.path.exists(old_path):
                    os.replace(old_path, os.path.join(downloaded_dir, new_file_name))
                artifact = candidate
                break

    assert downloaded_dir is not None, f"No checkpoint found for epoch={epoch_num}."

    return os.path.join(downloaded_dir, new_file_name), artifact

def get_latest_epoch_from_wandb(cfg):
    """
    Return the highest checkpoint epoch number logged as an artifact on the configured wandb run.

    Epochs are read from old-format artifact names (``eval_epoch<N>:<version>``) and from the
    first ``epoch`` alias containing digits on each ``model`` artifact. Model artifacts without
    such an alias are skipped instead of raising ``IndexError``.

    Raises:
        AssertionError: if no checkpoint epochs are found.
    """
    api_run = wandb.Api(overrides={"entity": cfg.general.wandb.entity, "project": cfg.general.wandb.project}).run(f"{cfg.general.wandb.project}/{cfg.general.wandb.run_id}")
    digits_re = re.compile(r'\d+')
    epoch_nums = []
    for a in api_run.logged_artifacts():
        if a.name.startswith("eval_epoch"):  # Old format for artifacts
            epoch_nums.append(int(a.name.split("eval_epoch")[1].split(":")[0]))
        elif a.type == 'model':
            # Assumes at most one epoch<N> alias per model artifact version; the first one is used.
            epoch_digits = [m for alias in a.aliases if 'epoch' in alias
                            for m in digits_re.findall(alias)[:1]]
            if not epoch_digits:
                continue
            epoch_nums.append(int(epoch_digits[0]))
    assert len(epoch_nums) > 0, "No checkpoints found for the specified run."
    return max(epoch_nums)

def get_wandb_run_path(cfg):
    """Return ``"<wandb_team>/<project>/<wandb_id>"`` built from the flat ``cfg.general`` keys."""
    general = cfg.general
    parts = (general.wandb_team, general.project, general.wandb_id)
    return '/'.join(str(part) for part in parts)

def count_parameters(model):
    """Return the total element count of ``model``'s parameters that have ``requires_grad`` set."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(map(lambda param: param.numel(), trainable))

def get_dataset(cfg, dataset_class, shuffle=True, recompute_info=False, return_datamodule=False, slices=None):
    """Prepare ``dataset_class``'s data module and compute its dataset infos.

    Args:
        cfg: config. ``cfg.dataset`` supplies ``atom_types``, ``bond_types``, ``allowed_bonds``,
            ``zero_bond_order`` and ``remove_h`` for ``DatasetInfos``.
        dataset_class: object exposing ``DataModule`` and ``DatasetInfos`` (e.g. a dataset module).
        shuffle: forwarded to ``datamodule.prepare_data``.
        recompute_info: forwarded to ``DatasetInfos``.
        return_datamodule: if True, return the data module as well.
        slices: per-split slices forwarded to ``prepare_data``. ``None`` (the default) means
            ``{'train': None, 'val': None, 'test': None}``. A fresh dict is built on each call,
            so no mutable default is shared between calls.

    Returns:
        ``(datamodule, dataset_infos)`` if ``return_datamodule`` else ``dataset_infos``.
    """
    if slices is None:
        slices = {'train': None, 'val': None, 'test': None}

    datamodule = dataset_class.DataModule(cfg)
    datamodule.prepare_data(shuffle=shuffle, slices=slices)

    dataset_infos = dataset_class.DatasetInfos(datamodule=datamodule, atom_types=cfg.dataset.atom_types, bond_types=cfg.dataset.bond_types, allowed_bonds=cfg.dataset.allowed_bonds,
                                               zero_bond_order=cfg.dataset.zero_bond_order, recompute_info=recompute_info, remove_h=cfg.dataset.remove_h)
    dataset_infos.compute_input_output_dims(datamodule=datamodule)

    return (datamodule, dataset_infos) if return_datamodule else dataset_infos

def check_if_dataparallel_dict(state_dict):
    """Return True if ``state_dict``'s last key contains the substring ``'module'``.

    The last key is inspected instead of the first because, for EMA state dicts, the first
    key is ``'initted'``. An empty state dict returns False instead of raising ``IndexError``.
    """
    if not state_dict:
        return False
    last_key = list(state_dict.keys())[-1]
    return 'module' in last_key

def dataparallel_dict_to_regular_state_dict(state_dict):
    """Remove the ``module`` path component from each state-dict key (DataParallel-style keys).

    Only the first whole ``module`` component is removed — at the start of the key or right
    after a ``.`` — e.g. ``'model.module.layer.w'`` -> ``'model.layer.w'`` and
    ``'module.w'`` -> ``'w'``. A layer merely containing the text, like ``'enc.submodule.w'``,
    is left intact; ``str.replace('module.', '')`` would mangle it into ``'enc.subw'``.
    Key order is preserved.
    """
    module_component = re.compile(r'(^|\.)module\.')
    return OrderedDict(
        (module_component.sub(r'\1', key, count=1), value)
        for key, value in state_dict.items()
    )

def regular_state_dict_to_dataparallel_dict(state_dict):
    """Rewrite keys to carry a ``model.module.`` prefix, keeping the key order.

    The first dotted component of each key is dropped and replaced by ``model.module.``, so
    ``'model.layer.w'`` -> ``'model.module.layer.w'``. A key without any dot keeps none of its
    name: ``'initted'`` -> ``'model.module.'``.

    NOTE(review): presumably assumes every key starts with a wrapper attribute named ``model``.
    Confirm this for EMA state dicts, which load_all_state_dicts also routes through here.
    """
    return OrderedDict(
        ('model.module.' + key.partition('.')[2], value)
        for key, value in state_dict.items()
    )

def load_weights(model, model_state_dict, device_count=None):
    """Load ``model_state_dict`` into ``model``, adapting DataParallel-style keys to ``device_count``.

    A DataParallel-style dict (see ``check_if_dataparallel_dict``) is converted to regular keys
    when ``device_count <= 1``. A regular dict is converted to DataParallel keys when
    ``device_count > 1``.

    Returns:
        ``model`` (loaded in place).

    Raises:
        AssertionError: if ``device_count`` is None.
    """
    assert device_count is not None, f'Expected device_count to not be None. Found device_count={device_count}'

    # Evaluate once; the original printed this and then recomputed it up to twice more.
    is_dataparallel = check_if_dataparallel_dict(model_state_dict)
    log.debug('state_dict has DataParallel-style keys: %s', is_dataparallel)
    if is_dataparallel and device_count <= 1:
        model_state_dict = dataparallel_dict_to_regular_state_dict(model_state_dict)
    elif not is_dataparallel and device_count > 1:
        model_state_dict = regular_state_dict_to_dataparallel_dict(model_state_dict)

    model.load_state_dict(model_state_dict)

    return model

def load_all_state_dicts(cfg, model, optimizer, lr_scheduler, scaler, checkpoint_file, device_count=None):
    """Restore model, optimizer, scheduler, AMP scaler and (if present) EMA state from ``checkpoint_file``.

    The checkpoint is mapped onto the module-level ``device``. Model and EMA weights go through
    ``load_weights``, which adapts DataParallel-style keys to ``device_count``. ``cfg`` is unused.

    Raises:
        AssertionError: if ``device_count`` is None (raised by ``load_weights``).
        KeyError: if the model/optimizer/scheduler entries are missing from the checkpoint.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_file, map_location=torch.device(device))
    load_weights(model, checkpoint['model_state_dict'], device_count=device_count)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    # An empty scaler state means there is nothing to restore (the scaler is only available on GPU).
    # .get tolerates checkpoints that have no scaler entry at all.
    scaler_state = checkpoint.get('scaler_state_dict', {})
    if scaler_state != {}:
        scaler.load_state_dict(scaler_state)
    if 'ema_state_dict' in checkpoint:
        # Same DataParallel key conversion as for the model weights.
        load_weights(model.ema, checkpoint['ema_state_dict'], device_count=device_count)

def get_model_and_train_objects(cfg, model_class, model_kwargs, parent_path, savedir, run=None, epoch_num=None, 
                                load_weights_bool=True, device=None, device_count=None):
    """Build the model plus optimizer, LR scheduler and AMP scaler, optionally restoring them from wandb.

    The model is ``model_class(cfg=cfg, **model_kwargs)`` moved to ``device``. The optimizer is
    AdamW with ``lr=cfg.train.lr``, ``weight_decay=cfg.train.weight_decay`` and ``amsgrad=True``;
    the scheduler comes from ``get_lr_scheduler``.

    When ``load_weights_bool`` is set and ``cfg.general.wandb.resume`` or ``run_id`` is truthy,
    the checkpoint of ``epoch_num`` is downloaded into ``savedir`` and loaded into all four
    objects. The latest logged epoch is used when ``epoch_num`` is falsy. If ``run`` is given,
    the artifact is also marked as used by it. ``parent_path`` is not used here.

    Returns:
        ``(model, optimizer, lr_scheduler, scaler, last_epoch)``; ``last_epoch`` is 0 when no
        checkpoint was loaded.
    """
    assert device is not None and device_count is not None, f'Expected device and device_count not to be None. Found device={device} and device_count={device_count}'

    model = model_class(cfg=cfg, **model_kwargs).to(device)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=cfg.train.lr,
        amsgrad=True,
        weight_decay=cfg.train.weight_decay,
    )
    lr_scheduler = get_lr_scheduler(cfg, optimizer)
    scaler = torch.cuda.amp.GradScaler()

    if not (load_weights_bool and (cfg.general.wandb.resume or cfg.general.wandb.run_id)):
        # Nothing to restore: fresh objects, starting at epoch 0.
        return model, optimizer, lr_scheduler, scaler, 0

    assert cfg.general.wandb.resume==False or (cfg.general.wandb.resume and run is not None),\
        f'cfg.general.wandb.resume=={cfg.general.wandb.resume}, expected run != None.'
    last_epoch = epoch_num or get_latest_epoch_from_wandb(cfg)
    checkpoint_file, artifact = download_checkpoint_from_wandb(cfg, savedir, last_epoch, run=run)
    load_all_state_dicts(cfg, model, optimizer, lr_scheduler, scaler, checkpoint_file, device_count=device_count)

    wandb_cfg = cfg.general.wandb
    artifact_path = f"{wandb_cfg.entity}/{wandb_cfg.project}/{artifact.name}"
    if run is not None:
        run.use_artifact(artifact_path)

    return model, optimizer, lr_scheduler, scaler, last_epoch

def load_weights_from_wandb(cfg, epoch_num, savedir, model, optimizer, lr_scheduler, scaler, run=None, device_count=None):
    """Download the ``epoch_num`` checkpoint into ``savedir`` and load it into the training objects.

    When ``epoch_num`` is falsy, the latest epoch logged on the run is used. If ``run`` is
    given, the artifact name is marked as used by it.

    Returns:
        ``(model, optimizer, lr_scheduler, scaler, artifact_name_in_wandb)``, where the name is
        ``entity/project/<name of the directory holding the checkpoint>``.
        NOTE(review): that directory is the last component of ``savedir``, not necessarily the
        artifact's own name — presumably savedir is named after the artifact; confirm.
    """
    last_epoch = epoch_num or get_latest_epoch_from_wandb(cfg)
    checkpoint_file, _artifact = download_checkpoint_from_wandb(cfg, savedir, last_epoch)
    load_all_state_dicts(cfg, model, optimizer, lr_scheduler, scaler, checkpoint_file, device_count)
    # os.path instead of split('/') so this works with any path separator (e.g. Windows).
    checkpoint_dir_name = os.path.basename(os.path.dirname(checkpoint_file))
    artifact_name_in_wandb = f"{cfg.general.wandb.entity}/{cfg.general.wandb.project}/{checkpoint_dir_name}"
    if run is not None:
        run.use_artifact(artifact_name_in_wandb)

    return model, optimizer, lr_scheduler, scaler, artifact_name_in_wandb

def load_weights_from_wandb_no_download(cfg, epoch_num, savedir, model, optimizer, lr_scheduler, scaler, run=None, device_count=None):
    """Load the ``epoch_num`` checkpoint from ``savedir`` if it is already there, else download it.

    ``epoch{epoch_num}.pt`` is preferred over ``eval_epoch{epoch_num}.pt``. If neither file
    exists, the call is delegated to ``load_weights_from_wandb``, which downloads the
    checkpoint. If ``run`` is given, the artifact name is marked as used by it.

    Returns:
        ``(model, optimizer, lr_scheduler, scaler, artifact_name_in_wandb)``, where the name is
        ``entity/project/<name of the directory holding the checkpoint>``.
    """
    candidates = (
        os.path.join(savedir, f'epoch{epoch_num}.pt'),
        os.path.join(savedir, f'eval_epoch{epoch_num}.pt'),
    )
    checkpoint_file = next((path for path in candidates if os.path.exists(path)), None)
    if checkpoint_file is None:
        return load_weights_from_wandb(cfg, epoch_num, savedir, model, optimizer, lr_scheduler,
                                       scaler, run=run, device_count=device_count)

    load_all_state_dicts(cfg, model, optimizer, lr_scheduler, scaler, checkpoint_file, device_count)
    # os.path instead of split('/') so this works with any path separator (e.g. Windows).
    checkpoint_dir_name = os.path.basename(os.path.dirname(checkpoint_file))
    artifact_name_in_wandb = f"{cfg.general.wandb.entity}/{cfg.general.wandb.project}/{checkpoint_dir_name}"
    if run is not None:
        run.use_artifact(artifact_name_in_wandb)

    return model, optimizer, lr_scheduler, scaler, artifact_name_in_wandb

def get_lr_scheduler(cfg, optimizer):
    """Build an epoch-indexed ``LambdaLR`` scheduler selected by ``cfg.train.lr_scheduler``.

    Learning-rate scheduling was not used in the paper, but may be useful for future work.

    * ``'none'``: constant multiplier 1.0.
    * ``'linear'`` / ``'cosine'``:
        1. Warm up linearly from ``initial_lr`` towards ``lr`` over ``num_warmup_epochs`` epochs
           (epoch ``num_warmup_epochs - 1`` reaches ``lr``).
        2. Anneal from ``lr`` to ``final_lr`` until epoch ``num_annealing_epochs`` — linearly,
           or along a half cosine.
        3. Stay at ``final_lr`` afterwards.

    The lambda returns a multiplier relative to ``cfg.train.lr``, so the optimizer is expected
    to have been built with ``lr=cfg.train.lr``.

    Raises:
        NotImplementedError: for any other ``cfg.train.lr_scheduler`` value.
    """
    schedule = cfg.train.lr_scheduler
    if schedule == 'none':
        return LambdaLR(optimizer, lr_lambda=lambda epoch: 1.0)

    if schedule == 'linear':
        def anneal(t, start, end):
            return (1 - t) * start + t * end
    elif schedule == 'cosine':
        def anneal(t, start, end):
            return (np.cos(t*np.pi)+1)/2 * start + (1 - (np.cos(t*np.pi)+1)/2) * end
    else:
        raise NotImplementedError(f"Unknown lr_scheduler {schedule!r}; expected 'none', 'linear' or 'cosine'.")

    num_warmup_epochs = cfg.train.num_warmup_epochs
    # Length of the annealing phase; it ends at epoch cfg.train.num_annealing_epochs.
    num_annealing_epochs = cfg.train.num_annealing_epochs - num_warmup_epochs
    initial_lr = cfg.train.initial_lr
    warmup_lr = cfg.train.lr
    final_lr = cfg.train.final_lr

    def lr_scale(epoch):
        if epoch < num_warmup_epochs:
            return ((epoch + 1) / num_warmup_epochs * (warmup_lr - initial_lr) + initial_lr) / warmup_lr
        if epoch < num_warmup_epochs + num_annealing_epochs:
            t = (epoch - num_warmup_epochs) / num_annealing_epochs
            return anneal(t, warmup_lr, final_lr) / warmup_lr
        return final_lr / warmup_lr

    return LambdaLR(optimizer, lr_lambda=lr_scale)
    
def load_testfile(cfg, data_class):
    """Return a non-shuffled DataLoader over the processed split ``test_{int(cfg.test.testfile)}``.

    The dataset root is ``<base>/<cfg.dataset.datadir>`` plus ``-<dataset_nb>`` when
    ``cfg.dataset.dataset_nb`` is not empty. ``<base>`` is ``parents[2]`` of this file's
    resolved path (presumably the repository root). The DataLoader's generator is seeded
    with ``cfg.train.seed``.

    Raises:
        AssertionError: if ``processed/test_<n>.pt`` does not exist under the dataset root.
    """
    dataset_nb = cfg.dataset.dataset_nb
    suffix = dataset_nb if str(dataset_nb) == '' else '-' + str(dataset_nb)

    base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
    root_path = os.path.join(base_path, cfg.dataset.datadir + suffix)
    stage = f'test_{int(cfg.test.testfile)}'
    processed_file = os.path.join(root_path, 'processed', f'{stage}.pt')
    assert os.path.exists(processed_file), f'Path {processed_file} does not exist.'

    dataset = data_class.Dataset(stage=stage, root=root_path)
    generator = torch.Generator()
    generator.manual_seed(cfg.train.seed)

    return DataLoader(
        dataset,
        batch_size=cfg.test.batch_size,
        num_workers=cfg.dataset.num_workers,
        generator=generator,
        shuffle=False,
    )