import time, os
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
import torch
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm


def MAE_(fake, real):
    """Return the mean absolute error between ``fake`` and ``real``.

    Both arguments are converted with ``np.asarray``, so NumPy arrays, nested
    sequences and CPU tensors without grad are accepted; the two must be
    broadcast-compatible. The result is the mean of ``|fake - real|`` over
    every element, as a NumPy scalar.
    """
    fake = np.asarray(fake)
    real = np.asarray(real)
    # When both inputs are integer/bool (e.g. uint8 images), subtraction would
    # wrap around (or raise, for bool); compute in floating point instead.
    if fake.dtype.kind in "biu" and real.dtype.kind in "biu":
        fake = fake.astype(np.float64)
    return np.mean(np.abs(fake - real))

def Norm(a):
    """Min-max rescale tensor ``a`` so its values span [-1, 1].

    The minimum of ``a`` maps to -1 and the maximum to 1. If every element is
    equal (max == min), plain min-max scaling would compute 0/0 and return
    NaN everywhere; in that case a zero tensor (the midpoint of [-1, 1]) is
    returned instead, with the dtype the normal path would have produced.
    """
    max_ = torch.max(a)
    min_ = torch.min(a)
    span = max_ - min_
    if span == 0:
        # Same dtype as the division below would yield (e.g. int -> default float).
        return torch.zeros_like(a, dtype=torch.result_type(a, 0.5))
    a_0_1 = (a - min_) / span
    return (a_0_1 - 0.5) * 2

opt = TrainOptions().parse()

dataset = create_dataset(opt, phase="train")
dataset_size = len(dataset)
print('#training images = %d' % dataset_size)

val_dataset = create_dataset(opt, phase="val")
val_dataset_size = len(val_dataset)
print('#validation images = %d' % val_dataset_size)

model = create_model(opt)
model.setup(opt)
# Running count of training samples, advanced by opt.batch_size per iteration.
total_steps = 0

# Both metric logs live in the same per-run directory, so a single makedirs
# call is enough.
log_dir = f'./log/{opt.name}'
log_file_path = f'{log_dir}/best_metrics.txt'
train_val_file_path = f'{log_dir}/train_val_metrics.txt'
os.makedirs(log_dir, exist_ok=True)

def _validate(model, val_dataset, opt, epoch):
    """Evaluate ``model.netG`` on ``val_dataset`` and return the validation MAE.

    The MAE is the average over batches of the per-batch mean absolute error
    computed by ``MAE_``, i.e. the same running value shown in the progress bar.

    Args:
        model: object exposing a callable generator ``netG``.
        val_dataset: iterable of dicts holding tensors under 'A' and 'B'.
        opt: parsed options; reads ``gpu_ids`` and ``direction``.
        epoch: current epoch number, used only for the progress-bar label.

    Returns:
        The mean per-batch MAE, or ``float('nan')`` when ``val_dataset``
        yields no batches (NaN compares False with ``<=``, so it can never be
        recorded as a new best).
    """
    device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu')
    # 'AtoB' feeds 'A' to the generator and scores against 'B'; otherwise reversed.
    AtoB = opt.direction == 'AtoB'
    input_key, target_key = ('A', 'B') if AtoB else ('B', 'A')

    # NOTE(review): netG is not switched to eval() here, so dropout and
    # BatchNorm layers (if the generator has any) run in training mode, and
    # BatchNorm running statistics would be updated by validation batches.
    # Some pix2pix-style code does this deliberately; confirm it is intended.
    mae_sum = 0.0
    num = 0
    with torch.no_grad():
        val_pbar = tqdm(val_dataset, desc=f'Epoch {epoch} Validation', leave=True)
        for data in val_pbar:
            real_A = data[input_key].to(device, dtype=torch.float)
            # The target is only compared on the host, so convert it there
            # directly instead of copying it to the device and back.
            real_B = data[target_key].detach().to(dtype=torch.float).cpu().numpy()
            fake_B = model.netG(real_A).detach().cpu().numpy()

            mae_sum += MAE_(fake_B, real_B)
            num += 1
            val_pbar.set_postfix({
                'MAE': f"{mae_sum/num:.4f}"
            })

    return mae_sum / num if num else float('nan')


# Lowest validation MAE seen so far. Starting at +inf (rather than a fixed
# 1000.0) guarantees the first validation is saved as 'best' regardless of
# the scale of the data.
global_mae = float('inf')
for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):
    epoch_start_time = time.time()
    epoch_iter = 0  # samples seen this epoch, in units of opt.batch_size

    if "adap" in opt.name:
        model.update_weight_alpha()

    train_pbar = tqdm(dataset, desc=f'Epoch {epoch} Training', leave=True)
    for data in train_pbar:
        total_steps += opt.batch_size
        epoch_iter += opt.batch_size
        model.set_input(data)
        model.optimize_parameters(dataset, epoch_iter)

        if total_steps % opt.display_freq == 0 and "grad" in opt.name:
            model.get_current_grads()  # return value is currently unused

        if total_steps % opt.print_freq == 0:
            errors = model.get_current_errors()
            train_pbar.set_postfix({
                'G_Loss': f"{model.loss_G:.4f}",
                'D_Loss': f"{model.loss_D:.4f}",
                'G_GAN': f"{errors['G_GAN']:.4f}",
                'G_L1': f"{errors['G_L1']:.4f}",
                'G_tv': f"{errors['G_loss_tv']:.4f}",
                'G_L1_seg': f"{errors['G_L1_seg']:.4f}",
                'G_ilm': f"{errors['G_ilm']:.4f}",
                'G_bm': f"{errors['G_bm']:.4f}",
                'proj': f"{errors['proj']:.4f}",
                'D_real': f"{errors['D_real']:.4f}",
                'D_fake': f"{errors['D_fake']:.4f}"
            })

        if total_steps % opt.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, total_steps))
            model.save('latest')

    if epoch % opt.save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' %
              (epoch, total_steps))
        model.save('latest')
        model.save(epoch)

    if epoch % opt.val_epoch_freq == 0:
        val_mae = _validate(model, val_dataset, opt, epoch)
        print('Val MAE:', val_mae)

        if epoch % 2 == 0:
            with open(train_val_file_path, 'a') as log_file:
                # Trailing ", " separates the training losses from the
                # validation part written next (previously they ran together).
                log_file.write(f"Epoch: {epoch}, "
                               f"training G_Loss: {model.loss_G:.4f}, "
                               f"Training D_Loss: {model.loss_D:.4f}, ")
                log_file.write(f"Val:  "
                               f"Validation MAE: {val_mae:.4f}\n\n")

        if val_mae <= global_mae:
            global_mae = val_mae
            print('saving the current best model at the end of epoch %d, iters %d' % (epoch, total_steps))
            model.save('best')
            model.save(epoch)
            print("saving best...")

            with open(log_file_path, 'a') as log_file:
                log_file.write(f"Epoch: {epoch}, "
                               f"Best MAE: {val_mae:.4f}\n")

    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))

    # NOTE(review): the learning rate is only updated once past opt.n_epochs;
    # this assumes model.update_learning_rate() performs the decay step itself
    # (rather than relying on a scheduler stepped every epoch) — confirm.
    if epoch > opt.n_epochs:
        model.update_learning_rate()



