import os

import torch
from torch.optim import Adam, Adamax
from torch.functional import F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, datasets

from tqdm.notebook import tqdm
from matplotlib import pyplot as plt
import numpy as np
from datetime import datetime
import random
import argparse
from copy import deepcopy

from utils.utils import (downsample, upsample, unfreeze, freeze, weights_init_D)
from utils.tb_utils import prepare_imgs_for_plotting, prepare_train_imgs_for_plotting
from utils.fid_score import (get_generated_inception_stats, get_hr_inception_stats, 
                             calculate_frechet_distance)
from utils.distributions import DatasetSampler

from models.ResNet_D import ResNet_D
from models.edsr_G import EDSR
from models.upsample_plus_unet import UNet

def train(D, G, experiment_name, G0_update=True, G_iters=1, G_lr=1e-4, D_lr=1e-4,
          optimizer='Adam', num_workers=6, batch_size=64, fid_interval=1,
          plot_interval=1, max_steps=1, D_iters=None, writer=None):
    """Train an OT model.

    Alternates `G_iters` generator updates with one critic (`D`) update.
    The generator loss is `0.5 * MSE(G(Y), G0(Y)) - mean(D(G(Y)))` and the
    critic loss is `mean(D(G(Y))) - mean(D(X))`. `G0` starts as `upsample`
    and, when `G0_update` is truthy, is replaced by a frozen deep copy of `G`
    after every `D_iters` critic steps. Scalars and image grids are logged to
    `writer`; whenever the FID improves, the state dicts of `D` and `G` are
    saved under `directory_to_save_runs/<experiment_name>/`.

    Args:
        D: Critic network; re-initialised in place with `weights_init_D`.
        G: Generator network.
        experiment_name: Sub-path used for the checkpoint directory.
        G0_update: Whether to refresh `G0` from `G` at the start of every
            round after the first.
        G_iters: Generator updates per critic update.
        G_lr: Generator learning rate.
        D_lr: Critic learning rate.
        optimizer: 'Adam' or 'Adamax'.
        num_workers: Accepted for backward compatibility; not used here.
        batch_size: Number of samples requested from each sampler per update.
        fid_interval: Minimum number of critic steps between FID evaluations.
        plot_interval: Minimum number of critic steps between image logs.
        max_steps: Total number of critic updates to perform.
        D_iters: Critic steps per round (between `G0` refreshes). Falls back
            to a module-level `D_iters` when omitted.
        writer: Object providing `add_scalar` / `add_images` (the script uses
            a TensorBoard `SummaryWriter`). Falls back to a module-level
            `writer` when omitted.

    Raises:
        ValueError: If `D_iters` / `writer` cannot be resolved, `D_iters` is
            smaller than 1, or `optimizer` is not supported.
    """
    # Historically both values were read from globals created by the script
    # entry point; keep that as a fallback so existing call sites still work.
    if D_iters is None:
        D_iters = globals().get('D_iters')
    if writer is None:
        writer = globals().get('writer')
    if D_iters is None or writer is None:
        raise ValueError('train() needs `D_iters` and `writer`, either as '
                         'arguments or as module-level globals.')
    if D_iters < 1:
        # With no inner iterations `step` never advances and the outer loop
        # would never terminate.
        raise ValueError('D_iters must be at least 1, got %r.' % (D_iters,))

    optimizer_classes = {'Adam': torch.optim.Adam, 'Adamax': torch.optim.Adamax}
    if optimizer not in optimizer_classes:
        raise ValueError('Unsupported optimizer %r; expected one of %s.'
                         % (optimizer, sorted(optimizer_classes)))

    D.apply(weights_init_D)

    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    trainA_dataset = datasets.ImageFolder('path_to_tranA_images', transform=transform)
    trainB_dataset = datasets.ImageFolder('path_to_trainB_images', transform=transform)
    test_dataset = datasets.ImageFolder('path_to_test_images', transform=transform)

    Y_sampler = DatasetSampler(trainA_dataset)
    X_sampler = DatasetSampler(trainB_dataset)
    test_sampler = DatasetSampler(test_dataset)

    # Fixed batches, logged at every plot so progress is comparable over time.
    Y_train_fixed = downsample(Y_sampler.sample(10)).cuda()

    X_test_fixed = test_sampler.sample(10)
    Y_test_fixed = downsample(X_test_fixed)
    X_test_fixed = X_test_fixed.cuda()
    Y_test_fixed = Y_test_fixed.cuda()

    optimizer_cls = optimizer_classes[optimizer]
    D_opt = optimizer_cls(D.parameters(), lr=D_lr, weight_decay=1e-10)
    G_opt = optimizer_cls(G.parameters(), lr=G_lr, weight_decay=1e-10)

    # -inf makes the very first completed step trigger an image log.
    last_plot_step, last_fid_step = -np.inf, 0
    best_fid = np.inf

    # Reference Inception statistics that the generated statistics are compared to.
    stats = np.load('path_to_train_hr_inception_stats')
    mu_data, sigma_data = stats['mu'], stats['sigma']

    checkpoint_dir = os.path.join('directory_to_save_runs', experiment_name+'/')

    step = 0  # number of critic (D) updates performed so far
    while step < max_steps:  # one pass of this loop is the analogue of an epoch
        if step == 0:
            G0 = upsample
        elif G0_update:
            freeze(G)
            G0 = deepcopy(G)
            freeze(G0)
        torch.cuda.empty_cache()

        # D and G optimization cycle
        for _ in range(D_iters):
            if step >= max_steps:
                # Do not run past the requested budget when max_steps is not
                # a multiple of D_iters.
                break

            # G optimization
            unfreeze(G)
            freeze(D)
            G_loss_history = []
            for _ in range(G_iters):
                Y = Y_sampler.sample(batch_size).cuda()
                Y = downsample(Y)
                with torch.no_grad():
                    up_Y = G0(Y)  # output of the reference map G0 on the LR batch
                G_opt.zero_grad()
                G_Y = G(Y)
                G_loss = .5 * F.mse_loss(G_Y, up_Y) - D(G_Y).mean()
                G_loss.backward()
                G_opt.step()
                G_loss_history.append(G_loss.item())
                del G_loss, G_Y, up_Y, Y
                torch.cuda.empty_cache()
            # The logged value is the sum over the G_iters inner updates.
            writer.add_scalar('G loss', np.sum(G_loss_history), step)
            del G_loss_history

            # D optimization
            freeze(G)
            unfreeze(D)

            X = X_sampler.sample(batch_size).cuda()
            Y = Y_sampler.sample(batch_size).cuda()
            Y = downsample(Y)
            with torch.no_grad():
                G_Y = G(Y)
            D_opt.zero_grad()
            D_loss = D(G_Y).mean() - D(X).mean()
            D_loss.backward()
            D_opt.step()
            writer.add_scalar('D loss', D_loss.item(), step)
            del D_loss, Y, X, G_Y
            torch.cuda.empty_cache()

            step += 1  # increase step

            if step >= last_plot_step + plot_interval:
                last_plot_step = step

                Y_train_random = downsample(Y_sampler.sample(10)).cuda()

                X_test_random = test_sampler.sample(10)
                Y_test_random = downsample(X_test_random)
                X_test_random = X_test_random.cuda()
                Y_test_random = Y_test_random.cuda()

                test_fixed_images = prepare_imgs_for_plotting(Y_test_fixed, X_test_fixed, upsample, G)
                test_random_images = prepare_imgs_for_plotting(Y_test_random, X_test_random, upsample, G)
                train_fixed_images = prepare_train_imgs_for_plotting(Y_train_fixed, upsample, G)
                train_random_images = prepare_train_imgs_for_plotting(Y_train_random, upsample, G)

                writer.add_images('test fixed images: upsampled vs generated vs GT',
                                  test_fixed_images, step, dataformats='HWC')
                writer.add_images('test random images: upsampled vs generated vs GT',
                                  test_random_images, step, dataformats='HWC')
                writer.add_images('train fixed images: upsampled vs generated vs GT',
                                  train_fixed_images, step, dataformats='HWC')
                writer.add_images('train random images: upsampled vs generated vs GT',
                                  train_random_images, step, dataformats='HWC')
                del Y_train_random, X_test_random, Y_test_random

            if step >= last_fid_step + fid_interval:
                last_fid_step = step

                m, s = get_generated_inception_stats(G, dataset_name='train', batch_size=50)
                FID_G = calculate_frechet_distance(m, s, mu_data, sigma_data)
                writer.add_scalar('FID_G', FID_G, step)
                del m, s
                torch.cuda.empty_cache()

                if FID_G < best_fid:
                    best_fid = FID_G
                    freeze(G)
                    freeze(D)
                    # torch.save does not create missing parent directories.
                    os.makedirs(checkpoint_dir, exist_ok=True)
                    torch.save(D.state_dict(), os.path.join(checkpoint_dir, 'best_state_D.pt'))
                    torch.save(G.state_dict(), os.path.join(checkpoint_dir, 'best_state_G.pt'))
        

def str2bool(value):
    """Convert a command-line string such as 'true' / 'False' / '0' to a bool.

    `argparse` with `type=bool` calls `bool()` on the raw string, so every
    non-empty value (including 'False') becomes True; this converter parses
    the text instead.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r.' % (value,))


if __name__=='__main__':
    # Script entry point: parse hyper-parameters, build the experiment name
    # and TensorBoard writer, seed the RNGs, build G and D, then train.
    parser = argparse.ArgumentParser()
    parser.add_argument('--G_arch', type=str, default='EDSR', choices=['EDSR', 'UNet'],
                        help='Generator architecture: EDSR or UNet.')
    parser.add_argument('--n_resblocks', type=int, default=64,
                        help='Number of residual blocks in EDSR.')
    parser.add_argument('--n_feats', type=int, default=128,
                        help='Number of feature maps in EDSR.')
    parser.add_argument('--res_scale', type=float, default=1,
                        help='Residual scaling in EDSR.')
    parser.add_argument('--UNet_upsample', type=str, default='bilinear',
                        help='Data upsample layer in UNet arch: bicubic or bilinear.')
    parser.add_argument('--base_factor', type=int, default=64,
                        help='Base factor in UNet.')
    parser.add_argument('--G0_update', type=str2bool, default=True,
                        help='Update cost on each step or not.')
    parser.add_argument('--bs', type=int, default=64,
                        help='Batch size.')
    parser.add_argument('--scale_factor', type=int, default=4,
                        help='Scale factor')
    parser.add_argument('--steps', type=int, default=100000,
                        help='Maximum steps.')
    parser.add_argument('--D_lr', type=float, default=1,
                        help='D learning rate * 10**4.')
    parser.add_argument('--G_lr', type=float, default=1,
                        help='G learning rate * 10**4.')
    parser.add_argument('--D_iters', type=int, default=25000,
                        help='Number of D steps before cost update.')
    parser.add_argument('--G_iters', type=int, default=15,
                        help='Number of G steps per one D step.')
    parser.add_argument('--opt', type=str, default='Adam', choices=['Adam', 'Adamax'],
                        help='Adam or Adamax.')
    parser.add_argument('--n_workers', default=20, type=int)
    parser.add_argument('--fid', type=int, default=2000,
                        help='FID interval.')
    parser.add_argument('--plot', type=int, default=200,
                        help='Plot interval.')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random seed.')
    args = parser.parse_args()

    G0_update = args.G0_update
    batch_size = args.bs
    scale_factor = args.scale_factor
    max_steps = args.steps
    # Learning rates are given on the command line in units of 1e-4.
    D_lr = args.D_lr / 10**4
    G_lr = args.G_lr / 10**4
    # NOTE: `D_iters` (and `writer` below) must remain module-level names;
    # train() looks them up as globals when they are not passed explicitly.
    D_iters = args.D_iters
    G_iters = args.G_iters
    G_arch = args.G_arch
    optimizer = args.opt
    num_workers = args.n_workers
    fid_interval = args.fid
    plot_interval = args.plot
    seed = args.seed

    today = datetime.today()
    md = today.strftime('%m%d')
    hm = today.strftime('%H%M%S')

    if G_arch == 'EDSR':
        G_name = 'EDSRr%df%ds%.3f'%(args.n_resblocks, args.n_feats, args.res_scale)
    elif G_arch == 'UNet':
        # Two values are formatted, so two placeholders are required.
        G_name = 'UNet%d%s'%(args.base_factor, args.UNet_upsample)

    experiment_name = ('%s/v%s_%s_G0upd%s_bs%d_steps%d_Dlr%f_Glr%f_Diters%d_Giters%d_opt%s_nworkers%d_fid%d_plot%d_seed%d'%(md, hm, G_name, str(G0_update), batch_size, max_steps, D_lr, G_lr, D_iters, G_iters, optimizer, num_workers, fid_interval, plot_interval, seed))
    # Join with a path separator so the log directory is the same directory
    # train() writes its checkpoints to.
    writer = SummaryWriter(log_dir=os.path.join('directory_to_save_runs', experiment_name))

    # Seed every RNG before the networks are constructed.
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"# or ":16:8"

    if G_arch == 'EDSR':
        G = EDSR(n_resblocks=args.n_resblocks, n_feats=args.n_feats, res_scale=args.res_scale,
                 scale_factor=args.scale_factor).cuda()
    elif G_arch == 'UNet':
        G = UNet(3, 3, scale_factor=args.scale_factor, base_factor=args.base_factor, upsample=args.UNet_upsample).cuda()
    D = ResNet_D(size=64).cuda()

    try:
        train(D, G, experiment_name, G0_update=G0_update, G_iters=G_iters, G_lr=G_lr, D_lr=D_lr,
              optimizer=optimizer, num_workers=num_workers, batch_size=batch_size, fid_interval=fid_interval,
              plot_interval=plot_interval, max_steps=max_steps)
    finally:
        # Flush and close the event file even if training is interrupted.
        writer.close()