import sys
import time
import os
# These are set here, before `import torch` further down, so they are already
# in the environment when CUDA / the PyTorch caching allocator read them.
# PCI_BUS_ID asks CUDA to enumerate devices in PCI bus order.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Allocator option passed through to PyTorch's CUDA caching allocator.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import datetime
import wandb
import numpy as np
import cv2
from tqdm import tqdm
import gc
import torch
from torchvision import transforms
from torch import optim, nn
from torch.utils.data import DataLoader

import lpips
from pytorch_msssim import MS_SSIM

from util import *
from dataset import *
from loss_fn import *
from IFIN import *
from block import *

from accelerate import Accelerator
from accelerate import DistributedDataParallelKwargs
# DDP is configured with find_unused_parameters=True so that parameters which
# receive no gradient in a given step do not break gradient synchronisation.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# One process-wide Accelerator: batches from the DataLoader are split across
# processes (split_batches=True) and metrics are tracked with wandb.
accelerator = Accelerator(split_batches=True, log_with="wandb", kwargs_handlers=[ddp_kwargs])

# Run name; also reused verbatim as the wandb "notes" field.
Name = 'IFIN'
NOTES = Name

# Fixed seed; seed_everything comes from one of the star imports above
# (presumably util — TODO confirm).
seed_everything(4716)

# Unique run identifier (name + launch timestamp). Used below as the wandb run
# name and as the checkpoint filename prefix in main().
START_DATE = Name + '_' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")


# Define network hyperparameters:
HPARAMS = {
    'N_CHANNEL': 32,            # passed to CAWNet as `dim`
    'BATCH_SIZE': 2,            # DataLoader batch size (train and test)
    'NUM_WORKERS': 4,           # DataLoader worker processes
    'EPOCHS_NUM': 100,          # number of train/test epochs run by main()
    'LR': 1e-4,                 # AdamW learning rate for every parameter except `psf`
    'LR_PSF': 1e-3,             # AdamW learning rate for the learnable `psf` parameter
    'K': 1,                     # passed to CAWNet as `k`
    'DEPTH':3,                  # passed to CAWNet as `depth`
    'LAMBDA_IMG':0.8,           # weight of the image loss (loss_function output)
    'LAMBDA_CONS_IM':0.01,      # weight of MSE(output_raw, input image)
    'LAMBDA_CONS_FT':0.01,      # weight of MSE(output_wiener, wiener)
    'LAMBDA_WIENER':0.1,        # weight of MSE(wiener, label)
    'LAMBDA_PSF':0.1,           # weight of non_negativity_loss(psf)
    # NOTE(review): WALLER_PATH is not read anywhere in this file; main()
    # hard-codes a different dataset path — confirm which one is intended.
    'WALLER_PATH': '/mnt/ssd1/wallerlab/dataset',
    'PSF_PATH': '/mnt/NAS/_datasets/wallerlab/dataset/psf.tiff',  # image read with cv2.imread below
}

# normalize_lambdas(HPARAMS, anchor_key="LAMBDA_IMG")

# Start the experiment tracker(s) configured on the Accelerator (wandb here).
# The wandb run is named after START_DATE and annotated with NOTES; uncomment
# the "mode" line to run without uploading to wandb.
accelerator.init_trackers(
    project_name = "Lensless Reconstruction",
    init_kwargs={"wandb": {
                          "name":START_DATE,
                          # "mode":'disabled',
                          "notes":NOTES,}}
    )
# Device assigned to this process by Accelerate; used for all .to(DEVICE) calls.
DEVICE = accelerator.device

# Load the measured point-spread function as a single-channel image
# (flag 0 = grayscale) and turn it into a (1, 1, H, W) float tensor on DEVICE.
psf = cv2.imread(HPARAMS['PSF_PATH'], 0)
if psf is None:
    # cv2.imread signals a missing/unreadable file by returning None instead
    # of raising; fail here with a clear message rather than inside resize().
    raise FileNotFoundError(f"Could not read PSF image: {HPARAMS['PSF_PATH']}")
psf = cv2.resize(psf, (480, 270))  # dsize is (width, height)
psf = np.asarray(psf)
psf = torch.from_numpy(psf).unsqueeze(0).unsqueeze(0).to(DEVICE)
psf = psf / 255.

# Background estimate: mean of the top-left 15x15 patch. Subtract it and clip
# the result at zero so the PSF stays non-negative.
psf_bg = torch.mean(psf[:, :, 0:15, 0:15])
psf = psf - psf_bg
psf[psf < 0] = 0

# Crop the PSF using project helpers (from the star imports); presumably this
# keeps a box around the PSF centre — TODO confirm against util.
bbox = get_center_bbox(psf)
psf = crop_from_center_bbox(psf, bbox)
print(bbox)

# Shared mutable state handed to train(), test() and main().
TPARAMS = {'PSF': psf,
           'accelerator': accelerator}

def train(train_parameters):
    """Run one training epoch and return averaged metrics plus sample tensors.

    Args:
        train_parameters: Shared state dict. Reads 'model', 'optimizer',
            'optimizer_psf', 'trainset_loader', 'loss_function', 'epoch_now'
            and 'accelerator'; stores the tqdm progress bar under 'train_bar'.

    Returns:
        dict holding the epoch metrics 'loss_sum', 'loss_image',
        'loss_cons_im', 'loss_cons_ft', 'loss_wiener', 'loss_psf' and 'PSNR'
        (summed over batches, mean-reduced across processes, divided by the
        number of batches), up to three samples from the last batch under
        'output', 'output_raw', 'wiener', 'output_wiener', 'input', 'label',
        and the current 'psf' parameter with its first two axes swapped.
    """
    accelerator = train_parameters['accelerator']
    model = train_parameters['model']
    optimizer = train_parameters['optimizer']
    optimizer_psf = train_parameters['optimizer_psf']
    loader = train_parameters['trainset_loader']

    model.train()
    # unwrap_model() yields the underlying network whether or not prepare()
    # wrapped it in DistributedDataParallel. A bare `.module` access raises
    # AttributeError on single-process runs, where no wrapper is added.
    net = accelerator.unwrap_model(model)

    loss_keys = ('loss_sum', 'loss_image', 'loss_cons_im', 'loss_cons_ft', 'loss_wiener', 'loss_psf')
    metric_keys = loss_keys + ('PSNR',)
    result = {key: 0 for key in metric_keys}
    criterion = nn.MSELoss()

    train_bar = tqdm(loader, position=1, leave=False, disable=not accelerator.is_local_main_process,)
    train_parameters['train_bar'] = train_bar
    for i, (image, label) in enumerate(train_bar):
        optimizer.zero_grad()
        optimizer_psf.zero_grad()
        label = label.to(DEVICE)
        image = image.to(DEVICE)

        result['output'], result['output_raw'], result['wiener'] = model(image)
        # Re-apply the network's Wiener step to the re-synthesised raw image;
        # used below for the consistency term against result['wiener'].
        result['output_wiener'] = net.WieNerH(result['output_raw'], net.psf)

        loss_image = train_parameters['loss_function'](result['output'], label, train_parameters['epoch_now'])
        loss_cons_im = criterion(result['output_raw'], image)
        loss_cons_ft = criterion(result['output_wiener'], result['wiener'])
        loss_wiener = criterion(result['wiener'], label)
        loss_psf = non_negativity_loss(net.psf)
        loss = (loss_image * HPARAMS['LAMBDA_IMG'] +
                loss_cons_im * HPARAMS['LAMBDA_CONS_IM'] +
                loss_cons_ft * HPARAMS['LAMBDA_CONS_FT'] +
                loss_wiener * HPARAMS['LAMBDA_WIENER'] +
                loss_psf * HPARAMS['LAMBDA_PSF'])

        accelerator.backward(loss)
        optimizer_psf.step()
        optimizer.step()
        accelerator.wait_for_everyone()

        with torch.no_grad():
            label = preplot_t(label)
            result['output'] = preplot_t(result['output'])
            batch_losses = {
                'loss_sum': loss,
                'loss_image': loss_image,
                'loss_cons_im': loss_cons_im,
                'loss_cons_ft': loss_cons_ft,
                'loss_wiener': loss_wiener,
                'loss_psf': loss_psf,
            }
            for key in loss_keys:
                result[key] += batch_losses[key].mean()
            result['PSNR'] += PSNR(result['output'], label)

            # Running averages over the batches seen so far on this process.
            train_bar.set_description(
                'Train :: loss: {:.5}, {:.5}, {:.5}, {:.5}, {:.5}, {:.5}., PSNR: {:.5}.'.format(
                    *(result[key] / (i+1) for key in metric_keys)))

    # Closed explicitly, as test() does for its own bar.
    train_bar.close()

    with torch.no_grad():
        # Keep at most three samples from the last batch for logging.
        result['output'] = result['output'][:3].detach() # use preplot_t for wallerlab dataset
        result['output_raw'] = result['output_raw'][:3].detach()
        result['wiener'] = preplot_t(result['wiener'][:3]).detach()
        result['output_wiener'] = preplot_t(result['output_wiener'][:3]).detach()
        result['input'] = image[:3].detach()
        result['label'] = label[:3].detach()
        result['psf'] = net.psf.permute(1,0,2,3).detach()

        # Mean across processes of each per-process sum, then per-batch average.
        num_batches = len(loader)
        for key in metric_keys:
            result[key] = accelerator.reduce(result[key], reduction='mean') / num_batches

        # train_parameters['scheduler'].step()
    return result


def test(test_parameters):
    """Evaluate on the test loader, step both LR schedulers, return metrics.

    Args:
        test_parameters: Shared state dict. Reads 'model', 'testset_loader',
            'loss_function', 'epoch_now', 'scheduler', 'scheduler_psf' and
            'accelerator'; stores the tqdm progress bar under 'test_bar'.

    Returns:
        dict holding the metrics 'loss_sum', 'loss_image', 'loss_cons_im',
        'loss_cons_ft', 'loss_wiener', 'loss_psf' and 'PSNR' (summed over
        batches, mean-reduced across processes, divided by the number of
        batches), up to three samples from the last batch under 'output',
        'output_raw', 'wiener', 'output_wiener', 'input', 'label', and the
        current 'psf' parameter with its first two axes swapped.
    """
    accelerator = test_parameters['accelerator']
    model = test_parameters['model']
    loader = test_parameters['testset_loader']

    model.eval()
    # unwrap_model() works with or without a DistributedDataParallel wrapper;
    # a bare `.module` access fails when prepare() did not wrap the model.
    net = accelerator.unwrap_model(model)

    loss_keys = ('loss_sum', 'loss_image', 'loss_cons_im', 'loss_cons_ft', 'loss_wiener', 'loss_psf')
    metric_keys = loss_keys + ('PSNR',)
    # The progress line shows every metric except loss_psf.
    shown_keys = ('loss_sum', 'loss_image', 'loss_cons_im', 'loss_cons_ft', 'loss_wiener', 'PSNR')
    result = {key: 0 for key in metric_keys}
    criterion = nn.MSELoss()

    test_bar = tqdm(loader, position=1, leave=False, disable=not accelerator.is_local_main_process)
    test_parameters['test_bar'] = test_bar
    with torch.no_grad():
        # Iterate the bar built from the argument, not the global TPARAMS.
        for i, (image, label) in enumerate(test_bar):
            label = label.to(DEVICE)
            image = image.to(DEVICE)
            result['output'], result['output_raw'], result['wiener'] = model(image)#, test_parameters['PSF'])
            result['output_wiener'] = net.WieNerH(result['output_raw'], net.psf)

            loss_image = test_parameters['loss_function'](result['output'], label, test_parameters['epoch_now'])
            loss_cons_im = criterion(result['output_raw'], image)
            loss_cons_ft = criterion(result['wiener'], result['output_wiener'])
            loss_wiener = criterion(result['wiener'], label)
            loss_psf = non_negativity_loss(net.psf)
            loss = (loss_image * HPARAMS['LAMBDA_IMG'] +
                    loss_cons_im * HPARAMS['LAMBDA_CONS_IM'] +
                    loss_cons_ft * HPARAMS['LAMBDA_CONS_FT'] +
                    loss_wiener * HPARAMS['LAMBDA_WIENER'] +
                    loss_psf * HPARAMS['LAMBDA_PSF'])

            batch_losses = {
                'loss_sum': loss,
                'loss_image': loss_image,
                'loss_cons_im': loss_cons_im,
                'loss_cons_ft': loss_cons_ft,
                'loss_wiener': loss_wiener,
                'loss_psf': loss_psf,
            }
            for key in loss_keys:
                result[key] += batch_losses[key].mean()

            label = preplot_t(label)
            result['output'] = preplot_t(result['output'])
            result['PSNR'] += PSNR(result['output'], label)
            test_bar.set_description(
                'Test :: loss: {:.5}, {:.5}, {:.5}, {:.5}, {:.5}., PSNR: {:.5}.'.format(
                    *(result[key] / (i+1) for key in shown_keys)))

        test_bar.close()

        # Keep at most three samples from the last batch for logging.
        result['output'] = result['output'][:3] # use preplot_t for wallerlab dataset
        result['output_raw'] = result['output_raw'][:3]
        result['wiener'] = preplot_t(result['wiener'][:3])
        result['output_wiener'] = preplot_t(result['output_wiener'][:3])
        result['input'] = image[:3]
        result['label'] = label[:3]
        result['psf'] = net.psf.permute(1,0,2,3)

        # Mean across processes of each per-process sum, then per-batch average.
        num_batches = len(loader)
        for key in metric_keys:
            result[key] = accelerator.reduce(result[key], reduction='mean') / num_batches

        # Step the plateau schedulers on the cross-process averages so every
        # rank sees the same metric and makes the same learning-rate decision.
        # Stepping on the local, un-reduced loss can let ranks drift apart.
        test_parameters['scheduler_psf'].step(result['loss_wiener'])
        test_parameters['scheduler'].step(result['loss_sum'])
    return result


def main():
    """Main process function.

    Builds the Waller datasets and loaders, the CAWNet model, the two AdamW
    optimizers (network weights / learnable PSF) with their plateau
    schedulers, hands everything to Accelerate, then alternates train() and
    test() for HPARAMS['EPOCHS_NUM'] epochs while logging to the tracker and
    writing a rolling checkpoint plus a best-PSNR checkpoint.
    """
    accel = TPARAMS['accelerator']

    # Load dataset
    raw_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    label_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    # NOTE(review): this path differs from HPARAMS['WALLER_PATH'], which is
    # not used here — confirm which location is intended.
    dataset_root = '/mnt/NAS/_datasets/wallerlab/dataset'
    trainset_load = WallerDataset(dataset_root, train=True, transform_raw=raw_transform, transform_lab=label_transform)
    testset_load = WallerDataset(dataset_root, train=False, transform_raw=raw_transform, transform_lab=label_transform)

    TPARAMS['trainset_loader'] = torch.utils.data.DataLoader(
        trainset_load,
        batch_size=HPARAMS['BATCH_SIZE'],
        shuffle=True,
        num_workers=HPARAMS['NUM_WORKERS'],
        pin_memory=False,
        # sampler=trainset_sampler,
    )

    TPARAMS['testset_loader'] = torch.utils.data.DataLoader(
        testset_load,
        batch_size=HPARAMS['BATCH_SIZE'],
        shuffle=False,
        num_workers=HPARAMS['NUM_WORKERS'],
        pin_memory=False,
    )

    # Initialize model
    print('model loading...')
    TPARAMS['model'] = CAWNet(3,3,TPARAMS['PSF'], dim=HPARAMS['N_CHANNEL'], k=HPARAMS['K'], depth=HPARAMS['DEPTH'], block_cls=DoubleConvLN).to(DEVICE)
    print('model loaded')

    TPARAMS['loss_function'] = LossFunction().to(DEVICE)

    # The learnable PSF gets its own optimizer and learning rate; every other
    # parameter goes to the main optimizer.
    other_params = [p for name, p in TPARAMS['model'].named_parameters() if name != 'psf']
    TPARAMS['optimizer'] = optim.AdamW([
        {'params': other_params, 'lr': HPARAMS['LR']},
    ], betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
    TPARAMS['optimizer_psf'] = optim.AdamW([
        {'params': [TPARAMS['model'].psf], 'lr': HPARAMS['LR_PSF']}
    ], betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
    TPARAMS['scheduler'] = optim.lr_scheduler.ReduceLROnPlateau(TPARAMS['optimizer'], mode='min', factor=0.5, patience=25, threshold=0.0001)
    TPARAMS['scheduler_psf'] = optim.lr_scheduler.ReduceLROnPlateau(TPARAMS['optimizer_psf'], mode='min', factor=0.5, patience=25, threshold=0.0001)

    prepared_keys = ('model', 'optimizer', 'optimizer_psf', 'trainset_loader', 'testset_loader', 'scheduler', 'scheduler_psf')
    prepared = accel.prepare(*(TPARAMS[key] for key in prepared_keys))
    for key, obj in zip(prepared_keys, prepared):
        TPARAMS[key] = obj

    accel.print('Train Start!')
    BAR = tqdm(range(HPARAMS['EPOCHS_NUM']), position=0, leave=True, disable=not accel.is_local_main_process)
    accel.log({"Epoch": 0}, step=0)

    best_psnr = 0.0
    keep = -1  # last completed epoch; -1 means start from epoch 0
    for epoch in BAR:
        epoch = epoch + keep + 1
        TPARAMS['epoch_now'] = epoch
        gc.collect()
        torch.cuda.empty_cache()

        train_result = train(TPARAMS)
        with torch.no_grad():
            wandb_log(train_result, epoch, 'train', accelerator=accel)

        gc.collect()
        torch.cuda.empty_cache()

        test_result = test(TPARAMS)
        BAR.set_description('{0} Epoch - Train Loss : {1:.5}. - Test Loss : {2:.5}.'.format(epoch, train_result["loss_sum"], test_result["loss_sum"]))

        with torch.no_grad():
            wandb_log(test_result, epoch, 'test', accelerator=accel)

        # Network save for inference. Only the main process writes files.
        if accel.is_main_process:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': accel.unwrap_model(TPARAMS['model']).state_dict(),
                'optimizer_state_dict': TPARAMS['optimizer'].state_dict(),
                'optimizer_psf_state_dict': TPARAMS['optimizer_psf'].state_dict(),
            }
            # The suffix is epoch // 100, so this file is overwritten every
            # epoch within the same block of 100 epochs.
            torch.save(checkpoint, "/mnt/NAS/homes/DG/weights/{}_{}.pth".format(START_DATE, epoch//100))

        if test_result['PSNR'] > best_psnr:
            best_psnr = test_result['PSNR']
            if accel.is_main_process:
                # Same payload as the rolling checkpoint, so the PSF optimizer
                # state is also available when resuming from the best model.
                torch.save(checkpoint, "/mnt/NAS/homes/DG/weights/{}_{}.pth".format(START_DATE, 'best'))
            accel.print(f'Model saved with PSNR: {best_psnr}')
        accel.log({"Epoch": epoch+1}, step=epoch+1)

        if epoch == HPARAMS['EPOCHS_NUM'] - 1:
            break

    # Finalize the trackers started by init_trackers() (closes the wandb run).
    accel.end_training()




# Script entry point. Note that the module-level setup above (Accelerator,
# tracker init, PSF loading) runs on import regardless of this guard.
if __name__ == "__main__":
    main()