import gc
import os
import sys
import time
import yaml
import wandb
import socket
import shutil
import random
import pynvml
import warnings
import numpy as np
from tqdm import tqdm
from typing import Tuple, List

from easydict import EasyDict
from itertools import product

import torch
from torch.utils.data import TensorDataset, DataLoader

from sampling import sampling

from utils.runner import train_epoch, val_epoch

from models.diffusion import Diffusion

def seed_all(seed: int = 42) -> None:
    """Seed every RNG used by the pipeline and request deterministic kernels.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    CUDA), then switches cuDNN and PyTorch into deterministic mode.

    Args:
        seed: Value passed to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # `use_deterministic_algorithms` is a function: assigning to it (as the
    # previous code did) silently replaced it and enabled nothing.
    # warn_only=True keeps operations without a deterministic implementation
    # usable (they emit a warning instead of raising).
    torch.use_deterministic_algorithms(True, warn_only=True)

def grid_search(param_grid: dict) -> List[dict]:
    """Expand a grid of candidate values into every concrete combination.

    Args:
        param_grid: Maps each parameter name to an iterable of candidates.

    Returns:
        One dict per element of the Cartesian product of the candidate
        lists, in ``itertools.product`` order (last parameter varies fastest).
    """
    names = list(param_grid)
    candidates = [param_grid[name] for name in names]

    combinations = []
    for choice in product(*candidates):
        combinations.append(dict(zip(names, choice)))
    return combinations

def train_sample_eval(conf_path: str, is_coding_mode: bool = False) -> str:
    """Train a ``Diffusion`` model from a YAML config, then sample from its checkpoint.

    The directory containing ``conf_path`` is used as the run directory: a
    snapshot of ``./models`` is copied into it and the checkpoint is written
    to ``<run dir>/checkpoints/ckpt.pt``.

    Args:
        conf_path: Path to the YAML configuration file of the run.
        is_coding_mode: When true, no wandb run is started.

    Returns:
        Path of the checkpoint file that was handed to ``sampling``.
    """
    with open(conf_path, 'r') as f:
        config = EasyDict(yaml.safe_load(f))

    device = torch.device(config.model.device)

    seed_all(config.train.seed)
    # The config path is built with os.path.join by the caller, so split it
    # with os.path as well instead of assuming '/' separators.
    log_dir = os.path.dirname(conf_path)
    # Snapshot the model code next to the results of this run.
    shutil.copytree('./models', os.path.join(log_dir, 'models'))
    ckpt_dir = os.path.join(log_dir, 'checkpoints')
    os.makedirs(ckpt_dir, exist_ok=True)

    proj_name = config.project

    if not is_coding_mode:
        wandb.init(
            project=f'Phy-diffusion-{config.model.dataset}-aug',
            # project=f'temp',
            entity='Anonymous',
            config=config,
            group=proj_name,
            job_type="training",
            name=proj_name,
            notes=socket.gethostname(),
            save_code=True,
            reinit=True
        )

    print('Loading dataset')

    # Each dataset item is an (x, edge, energy) triple of float32 tensors.
    train_loader = DataLoader(
        TensorDataset(
            torch.from_numpy(np.load(config.dataset.train_x)).to(torch.float32),
            torch.from_numpy(np.load(config.dataset.train_edge)).to(torch.float32),
            torch.from_numpy(np.load(config.dataset.train_energy)).to(torch.float32)
        ),
        config.train.batch_size,
        num_workers=0,
        shuffle=True
    )
    val_loader = DataLoader(
        TensorDataset(
            torch.from_numpy(np.load(config.dataset.val_x)).to(torch.float32),
            torch.from_numpy(np.load(config.dataset.val_edge)).to(torch.float32),
            torch.from_numpy(np.load(config.dataset.val_energy)).to(torch.float32)
        ),
        config.train.batch_size,
        num_workers=0,
        shuffle=False
    )

    model = Diffusion(config.model).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=1e-3,
        weight_decay=0,
        betas=(0.95, 0.999)
    )

    # Learning rate is multiplied by 0.6 after 10 validation rounds without
    # improvement (the scheduler is stepped only when validation runs).
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=0.6,
        patience=10,
    )

    best_val_loss = float('inf')
    print("Training start!!!")
    ckpt_path = os.path.join(ckpt_dir, 'ckpt.pt')

    for epoch in tqdm(range(config.train.num_epoches), dynamic_ncols=True):
        is_convergent = train_epoch(model, train_loader, optimizer, epoch, device)

        if epoch % config.train.val_freq == 0:
            val_loss = val_epoch(model, val_loader, epoch, device)
            scheduler.step(val_loss)

            # Only a fresh validation result can beat the best one seen so
            # far, so the checkpoint decision belongs to this branch.
            if val_loss < best_val_loss:
                best_val_loss = val_loss

                torch.save({
                    'config': config,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch,
                    'avg_val_loss': val_loss,
                }, ckpt_path)

        if is_convergent:
            print('Model converges! Early stopping!')
            break

        # Roll the weights back to the saved checkpoint while the most recent
        # validation loss is above 0.2. This runs every epoch and therefore
        # also uses the previous validation result in epochs that skip
        # validation.
        if val_loss > 0.2 and os.path.exists(ckpt_path):
            # The checkpoint also holds the pickled config object, which the
            # weights-only unpickler (the default from torch 2.6 on) rejects.
            # The file was written by this function, so a full unpickle is
            # acceptable here.
            checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
            model.load_state_dict(checkpoint['model'])
            print('Training failed. Loading model histroy!')
    # Drop the training model before sampling starts.
    del model
    gc.collect()

    sampling(ckpt_path)
    return ckpt_path

def get_arg(
    n_spring: int,
    loss: str,
    model: str,
    network: dict,
    is_coding_mode: bool = False,
    save_path: str = 'logs',
    diffusion: str = 'vp',
    device = None,
) -> Tuple[str, bool]:
    """Assemble a run configuration and write it to ``config.yml`` in a new log folder.

    Args:
        n_spring: Number of springs; selects the ``data/*_{n_spring}_spring.npy``
            files and sets ``n_system`` / ``repara_size`` of the network.
        loss: Loss name stored under ``model.loss``.
        model: Model name; also used as a sub-folder of ``save_path``.
        network: Extra network hyper-parameters merged over the dataset defaults.
        is_coding_mode: When true the project/folder prefix is ``'temp'``.
        save_path: Root directory for run folders.
        diffusion: Either ``'vp'`` or ``'ve'``.
        device: ``None`` to pick the GPU with the most free memory, a
            non-negative GPU index, or a string such as ``'cuda:1'`` (a bare
            ``'1'`` is accepted as well).

    Returns:
        ``(path to the written config.yml, is_coding_mode)``.

    Raises:
        AssertionError: If ``diffusion`` is unknown or ``device`` is a negative int.
        ValueError: If ``device`` is an unparsable string, or if ``device`` is
            ``None`` and NVML reports no GPU.
    """
    assert diffusion in ['vp', 've']

    yml = {
        'train': {'seed': 42, 'batch_size': 64, 'num_epoches': 1000, 'val_freq': 2}
    }

    dataset = f'{n_spring}spring'
    yml['dataset'] = {
        'train_x': f'data/train_x_{n_spring}_spring.npy',
        'train_edge': f'data/train_edge_{n_spring}_spring.npy',
        'train_energy': f'data/train_energy_{n_spring}_spring.npy',

        'val_x': f'data/val_x_{n_spring}_spring.npy',
        'val_edge': f'data/val_edge_{n_spring}_spring.npy',
        'val_energy': f'data/val_energy_{n_spring}_spring.npy',
    }
    dataset_para = {
        'input_size': 20,
        'input_length': 50,
        'n_system': n_spring,
        'repara_size': 2 * n_spring
    }

    pynvml.nvmlInit()
    to_gb = 1024**3
    if device is None:
        n_gpu = pynvml.nvmlDeviceGetCount()
        if n_gpu == 0:
            raise ValueError('NVML reports no GPU; cannot auto-select a device.')
        free = [
            pynvml.nvmlDeviceGetMemoryInfo(
                pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
            ).free / to_gb
            for gpu_id in range(n_gpu)
        ]
        # Most free memory wins; ties go to the highest index. The result is
        # a plain Python int rather than a NumPy scalar.
        device = max(range(n_gpu), key=lambda gpu_id: (free[gpu_id], gpu_id))
    elif isinstance(device, int):
        assert device >= 0
    elif isinstance(device, str):
        # int() raises ValueError for anything that is not 'cuda:N' or 'N'.
        device = int(device.removeprefix('cuda:'))

    handler = pynvml.nvmlDeviceGetHandleByIndex(device)
    meminfo = pynvml.nvmlDeviceGetMemoryInfo(handler)
    print(f'Using device: {device}. Free memory: {meminfo.free / to_gb}')
    device = 'cuda:' + str(device)

    yml['model'] = {
        'loss': loss,
        'dataset': dataset,
        'device': device,
        'diffusion': diffusion,
        'model': model,
        # Caller-supplied values override the dataset-derived defaults.
        'network': {'device': device, **dataset_para, **network},
        'sampling': {'method': 'ode'},
    }

    if is_coding_mode:
        save_folder = 'temp'
    else:
        save_folder = dataset + '---' + loss + '---' + model + '---' + \
            '--'.join([k + '-' + str(v) for k, v in network.items()])

    # The project name is the folder name without the timestamp suffix.
    yml['project'] = save_folder

    save_folder = save_folder + '---' + \
        time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime())
    save_folder = os.path.join(save_path, model, save_folder)
    # No exist_ok: an already existing run folder is an error.
    os.makedirs(save_folder)
    save_yml = os.path.join(save_folder, 'config.yml')
    with open(save_yml, 'w') as yaml_file:
        yaml.dump(yml, yaml_file, sort_keys=False)

    return save_yml, is_coding_mode



if __name__ == '__main__':
    # Usage: <script> [rnn_hidden_size] [gpu_index] [loss]
    # Every positional argument is optional; a missing or unparsable value
    # falls back to the default shown below.
    warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
    # USER may be unset (e.g. inside containers); fall back to a fixed suffix
    # instead of failing with KeyError.
    os.environ['WANDB_CONFIG_DIR'] = '/tmp/.config-' + os.environ.get('USER', 'unknown')
    torch.multiprocessing.set_start_method('spawn')

    try:
        hidden_size = int(sys.argv[1])
    except (IndexError, ValueError):
        hidden_size = 512

    try:
        device = int(sys.argv[2])
    except (IndexError, ValueError):
        # None lets get_arg pick the GPU with the most free memory.
        device = None

    try:
        loss = sys.argv[3]
    except IndexError:
        loss = 'naive'

    # Loss settings kept from the original commented-out assignments:
    #   'naive'
    #   'momentum_1.0', 'momentum_0.5', 'momentum_0.1'
    #   'implicit_energy_0.1', 'implicit_energy_0.01',
    #   'implicit_energy_0.001', 'implicit_energy_0.0001'
    #   'jensen_0.1', 'jensen_0.01', 'jensen_0.001', 'jensen_0.0001'

    n_spring = 5
    general_hyper = {
        't_embed_size': [64],
        'gnn_hidden_size': [256],
        'n_layers': [3]
    }
    # A trace function is installed when running under a debugger; such runs
    # are treated as coding-mode runs.
    is_coding_mode = sys.gettrace() is not None
    model = 'EGNN_GRU'
    networks_para = grid_search({
        'rnn_hidden_size': [hidden_size],
    } | general_hyper)

    # One full train/sample run per hyper-parameter combination.
    for network in networks_para:
        train_sample_eval(*get_arg(
            n_spring, loss, model, network, is_coding_mode,
            save_path='logs',
            device=device
        ))

