import gc
import os
import sys
import time
import yaml
import wandb
import socket
import shutil
import random
import pynvml
import pickle
import warnings
import numpy as np
from tqdm import tqdm
from typing import Tuple, List

from easydict import EasyDict
from itertools import product

import torch
from utils.SWDataset import SWDataset
from torch.utils.data import DataLoader

from sampling import sampling_eval
from utils.runner import train_epoch, val_epoch

from models.diffusion import Diffusion

def seed_all(seed: int = 42) -> None:
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and CUDA), and
    requests deterministic cuDNN / PyTorch kernels.

    Args:
        seed: Seed value shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible CUDA device, not only the current one; it is a
    # no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # `use_deterministic_algorithms` is a function. Assigning `True` to it
    # would replace the function with a bool and enable nothing, so it must
    # be called. `warn_only=True` lets operations without a deterministic
    # implementation keep running (with a warning) instead of raising.
    torch.use_deterministic_algorithms(True, warn_only=True)

def grid_search(param_grid: dict) -> List[dict]:
    """Expand a parameter grid into every combination of its values.

    Args:
        param_grid: Mapping from parameter name to an iterable of candidate
            values.

    Returns:
        One dict per element of the Cartesian product of the candidate
        values, in ``itertools.product`` order. An empty grid yields a
        single empty dict.
    """
    names = list(param_grid)
    candidates = [param_grid[name] for name in names]

    combos = []
    for choice in product(*candidates):
        combos.append(dict(zip(names, choice)))
    return combos

def train_sample_eval(conf_path: str, is_coding_mode: bool = False) -> str:
    """Train a ``Diffusion`` model from a YAML config, then call ``sampling_eval``.

    The directory containing ``conf_path`` is used as the log directory: a
    copy of ``./models`` and a ``checkpoints`` folder are created inside it.
    The checkpoint with the lowest validation loss is kept in
    ``checkpoints/ckpt.pt`` and handed to ``sampling_eval`` after training.

    Args:
        conf_path: Path to the YAML configuration file.
        is_coding_mode: When true, the wandb run is not initialised.

    Returns:
        Path of the saved checkpoint.

    Raises:
        RuntimeError: If training finished without ever writing a checkpoint
            (for example the validation loss was never finite).
    """
    with open(conf_path, 'r') as f:
        config = EasyDict(yaml.safe_load(f))

    device = torch.device(config.model.device)

    seed_all(config.train.seed)
    log_dir = os.path.dirname(conf_path)
    # Snapshot the model code next to the logs of this run.
    shutil.copytree('./models', os.path.join(log_dir, 'models'))
    ckpt_dir = os.path.join(log_dir, 'checkpoints')
    os.makedirs(ckpt_dir, exist_ok=True)
    # Bound up front so it is always defined after the loop, even when no
    # epoch ever improves on the initial best loss.
    ckpt_path = os.path.join(ckpt_dir, 'ckpt.pt')

    proj_name = config.project
    if not is_coding_mode:
        wandb.init(
            project='Phy-diffusion-shallow-water',
            entity='Anonymous',
            config=config,
            group=proj_name,
            job_type="training",
            name=proj_name,
            notes=socket.gethostname(),
            save_code=True,
            reinit=True
        )

    print('Loading dataset')

    # NOTE(security): unpickling can execute arbitrary code. Only point
    # `train_path` / `val_path` at files from a trusted source.
    with open(config.train_path, 'rb') as f:
        train_data = pickle.load(f)
    with open(config.val_path, 'rb') as f:
        val_data = pickle.load(f)

    train_loader = DataLoader(
        SWDataset(train_data),
        config.train.batch_size,
        num_workers=0,
        shuffle=True
    )
    val_loader = DataLoader(
        SWDataset(val_data),
        config.train.batch_size,
        num_workers=0,
        shuffle=False
    )

    # The entries of the training pickle are forwarded to the model as
    # keyword arguments.
    model = Diffusion(config.model, **train_data).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=1e-3,
        weight_decay=0,
        betas=(0.95, 0.999)
    )

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=0.6,
        patience=10,
    )

    best_val_loss = float('inf')
    print("Training start!!!")

    for epoch in tqdm(range(config.train.num_epoches), dynamic_ncols=True):
        is_convergent = train_epoch(model, train_loader, optimizer, epoch, device)

        # Epoch 0 always validates (0 % val_freq == 0), so `val_loss` is
        # defined before it is read further down.
        if epoch % config.train.val_freq == 0:
            val_loss = val_epoch(model, val_loader, epoch, device)
            scheduler.step(val_loss)

            # A new best can only appear right after a validation pass.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save({
                    'config': config,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch,
                    'avg_val_loss': val_loss,
                }, ckpt_path)

        if is_convergent:
            print('Model converges! Early stopping!')
            break

        # Roll back to the best weights when the latest validation loss
        # exceeds 1.0.
        if val_loss > 1.0 and os.path.exists(ckpt_path):
            # `load_state_dict` needs the state dict itself, not a file path.
            # The checkpoint also stores the EasyDict config, so it cannot be
            # read with `weights_only=True`; the file is written by this
            # function and is therefore trusted.
            checkpoint = torch.load(
                ckpt_path, map_location=device, weights_only=False
            )
            model.load_state_dict(checkpoint['model'])
            print('Training failed. Loading model histroy!')

    # Release the model's memory before sampling_eval runs.
    del model
    torch.cuda.empty_cache()
    gc.collect()

    if not os.path.exists(ckpt_path):
        raise RuntimeError(
            f'Training produced no checkpoint at {ckpt_path}; '
            'the validation loss never improved on infinity.'
        )

    sampling_eval(ckpt_path)
    return ckpt_path

def get_arg(
    loss: str,
    network: dict,
    is_coding_mode: bool = False,
    save_path: str = 'logs',
    diffusion: str = 'vp',
    device = None,
) -> Tuple[str, bool]:
    """Build a run configuration, write it to disk and return its path.

    A new folder named after the loss, the network hyper-parameters and the
    current timestamp is created under ``save_path`` and the configuration
    is dumped there as ``config.yml``.

    Args:
        loss: Loss identifier stored under ``model.loss``.
        network: Network hyper-parameters stored under ``model.network``.
        is_coding_mode: When true the run folder is prefixed ``temp`` instead
            of a name derived from ``loss`` and ``network``.
        save_path: Parent directory of the run folder.
        diffusion: Diffusion type, ``'vp'`` or ``'ve'``.
        device: GPU to use. ``None`` selects the GPU with the most free
            memory (the highest index wins ties); an ``int`` is a GPU index;
            a ``str`` is ``'cuda:<index>'`` or a bare index.

    Returns:
        ``(path to the written config.yml, is_coding_mode)``.

    Raises:
        ValueError: If ``diffusion`` is unknown or ``device`` cannot be
            turned into a non-negative GPU index.
        RuntimeError: If ``device`` is ``None`` and NVML reports no GPU.
    """
    # Explicit raises instead of `assert`, which is stripped under `-O`.
    if diffusion not in ('vp', 've'):
        raise ValueError(f"diffusion must be 'vp' or 've', got {diffusion!r}")

    yml = {
        'train': {'seed': 42, 'batch_size': 64, 'num_epoches': 1000, 'val_freq': 1},
        'train_path': 'data/train.pkl',
        'val_path': 'data/val.pkl',
        'test_path': 'data/test.pkl',
    }

    pynvml.nvmlInit()
    to_gb = 1024**3
    if device is None:
        n_gpu = pynvml.nvmlDeviceGetCount()
        if n_gpu == 0:
            raise RuntimeError('NVML reports no GPU; pass `device` explicitly.')

        free = [
            pynvml.nvmlDeviceGetMemoryInfo(
                pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
            ).free / to_gb
            for gpu_id in range(n_gpu)
        ]
        # Most free memory wins; among equals the highest index is chosen.
        device = max(range(n_gpu), key=lambda gpu_id: (free[gpu_id], gpu_id))
    elif isinstance(device, int):
        if device < 0:
            raise ValueError(f'device index must be non-negative, got {device}')
    elif isinstance(device, str):
        # Without the 'cuda:' prefix rpartition leaves the whole string in
        # the last slot, so a bare index such as '3' parses as well.
        index_text = device.rpartition('cuda:')[2]
        try:
            parsed = int(index_text)
        except ValueError:
            raise ValueError(
                f"cannot parse a GPU index from device {device!r}; "
                "expected 'cuda:<index>'"
            ) from None
        if parsed < 0:
            raise ValueError(f'device index must be non-negative, got {parsed}')
        device = parsed

    handler = pynvml.nvmlDeviceGetHandleByIndex(device)
    meminfo = pynvml.nvmlDeviceGetMemoryInfo(handler)
    print(f'Using device: {device}. Free memory: {meminfo.free / to_gb}')
    device = 'cuda:' + str(device)

    yml['model'] = {
        'loss': loss,
        'device': device,
        'diffusion': diffusion,
        'network': {
            'device': device,
            **network
        },
        'sampling': {'method': 'ode'},
    }

    if is_coding_mode:
        save_folder = 'temp'
    else:
        save_folder = loss + '---' + \
            '--'.join([k + '-' + str(v) for k, v in network.items()])

    # The project name is the folder name without the timestamp suffix.
    yml['project'] = save_folder

    save_folder = save_folder + '---' + \
        time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime())
    save_folder = os.path.join(save_path, save_folder)
    # No exist_ok: the timestamped folder must be new for this run.
    os.makedirs(save_folder)
    save_yml = os.path.join(save_folder, 'config.yml')
    with open(save_yml, 'w') as yaml_file:
        yaml.dump(yml, yaml_file, sort_keys=False)

    return save_yml, is_coding_mode



if __name__ == '__main__':
    # Script entry point: build one config per network setting, then train,
    # sample and evaluate each of them in turn.
    warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
    # Per-user wandb config directory. USER may be unset (containers, cron),
    # so fall back to a fixed name instead of raising KeyError.
    os.environ['WANDB_CONFIG_DIR'] = '/tmp/.config-' + os.environ.get('USER', 'unknown')
    torch.multiprocessing.set_start_method('spawn')

    # A trace function is installed when running under a debugger; treat
    # that as "coding mode".
    is_coding_mode = sys.gettrace() is not None

    hidden_size = 16

    networks_para = grid_search({
        'hidden_size': [hidden_size],
    })

    device = 7
    # Other loss settings used with this script:
    # 'naive', 'pde_10.0', 'pde_1.0', 'pde_0.1', 'pde_0.01'
    loss = 'pde_100.0'

    for network in networks_para:
        train_sample_eval(*get_arg(
            loss, network,
            is_coding_mode,
            device=device
        ))

