import argparse
import os
import csv
from tqdm import tqdm
import torch
from torch import nn
from torch.optim import Adam
import numpy as np
import json
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import pickle
from kornia.filters import sobel
from transformers import get_cosine_schedule_with_warmup

from feedforward import FF, FF_MoCap
from unet import Unet, Unet_MoCap
from diffusion import Diffusion, Diffusion_MoCap
from utils.data import VF_Dataset, MoCap_Dataset
from utils.utils import masked_loss

# Command-line interface. Arguments are parsed at import time; the training
# script below (under the __main__ guard) reads them through `args`.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# argparse does not validate a default against `choices`, so the default must be
# listed explicitly among them.
parser.add_argument('--dataset', type=str, choices=['uwhvf', 'break', 'house', 'ballet_jazz', 'street_jazz', 'krump', 'la_hip_hop', 'lock', 'middle_hip_hop', 'pop', 'wack'], default='uwhvf', help='Dataset.')
parser.add_argument('--representation', type=str, choices=['hvf','td','kp_norm','kp_raw'], default='td', help='Representation of data.')
parser.add_argument('--save_dir', type=str, default='./saved_pgdm_models', help='Directory to save training results in.')
parser.add_argument('--seed', type=int, default=0, help='Random seed.')
parser.add_argument('--gpu', type=int, default=0, help='GPU.')
parser.add_argument('--save_tag', type=str, default=None, help='Tag to identify saved training information.')
parser.add_argument('--start_epoch', type=int, default = 0, help='Epoch of the checkpoint to resume from (0 trains from scratch).')
parser.add_argument('--checkpoint_dir', type=str, default=None, help='Directory to load checkpoint from (required when --start_epoch > 0).')

# hyperparameters
parser.add_argument('--loss_fn', type=str, choices=['mse', 'mae'], default='mse', help='Training loss function.')
parser.add_argument('--target', type=str, choices=['x0', 'noise'], default='noise', help='Which target to use when training.')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size.')
parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate.')
parser.add_argument('--epochs', type=int, default=1000, help='Epochs.')
# NOTE(review): in this script lmbda only appears in the run name — confirm whether a regularizer is intended.
parser.add_argument('--lmbda', type=float, default=0, help='Regularizer scaling.')
parser.add_argument('--unet_dim', type=int, default=20, help='Dimensions for UNet')
parser.add_argument('--lr_schedule', action='store_true', help='Use a cosine learning-rate schedule with linear warmup over the first tenth of training steps.')

# conditioning inputs
parser.add_argument('--n_frames', type=int, default=3, help='Number of frames required for prediction.')
parser.add_argument('--n_horizon', type=int, default=1, help='Number of frames to predict.')
parser.add_argument('--time_dim', type=int, default=16, help='Dimension of time embedding.')
parser.add_argument('--hidden_dim', type=int, default=64, help='Hidden dimensions of pattern prediction model.')
# Always loaded by the training script, so fail fast at parse time if it is missing.
parser.add_argument('--pattern_model_pth', type=str, required=True, help='Pattern prediction model.')
parser.add_argument('--p_uncond', type=float, default=0.2, help='Probability of unconditional generation')

args = parser.parse_args()

# Per-dataset (min, max) values passed as `data_range` to the dataset classes,
# which receive normalize=True.
_DATA_RANGES = {
    'uwhvf': (-37.69, 50.00),
    'break': (-217.49, 286.76),
    'street_jazz': (-102.57, 101.90),
    'ballet_jazz': (-308.86, 327.77),
    'middle_hip_hop': (-91.20, 92.65),
    'house': (-78.08, 81.47),
    'krump': (-85.32, 110.89),
    'la_hip_hop': (-77.41, 100.26),
    'lock': (-109.02, 95.57),
    'pop': (-132.20, 103.33),
    'wack': (-283.42, 413.92),
}

# Datasets handled by the VF_Dataset / FF / Unet / Diffusion code path; every
# other dataset uses the MoCap variants.
_VF_DATASETS = ('uwhvf',)


def _get_data_range(dataset):
    """Return the (min, max) normalization range registered for `dataset`.

    Raises:
        NotImplementedError: if no range is registered for `dataset`.
    """
    try:
        return _DATA_RANGES[dataset]
    except KeyError:
        raise NotImplementedError(f'No data range registered for dataset {dataset!r}') from None


def _safe_mean(total, count):
    """Return total / count, or NaN when nothing was accumulated (count == 0)."""
    return total / count if count > 0 else float('nan')


def _prepare_batch(inputs, output, is_vf, n_horizon, device):
    """Move one DataLoader batch to `device` and split it into diffusion inputs.

    Args:
        inputs: batch inputs. VF: (x_b, start_age, deltas, baseline_aa);
            MoCap: (x_b, rot_history, baseline_aa).
        output: VF: the target tensor; MoCap: (pos_horizon, rot_horizon).
        is_vf: True for the VF dataset path, False for MoCap.
        n_horizon: number of predicted frames (used to flatten MoCap targets).
        device: device to move all tensors to.

    Returns:
        (x_f, baseline_aa, cond): the diffusion target, the archetype coefficients
        of the history (dynamics-model input), and the conditioning keyword
        arguments passed to the diffusion model's forward call.
    """
    x_b = inputs[0].to(device).float()
    if is_vf:
        start_age = inputs[1].reshape(-1, 1).to(device).float()  # (B, 1)
        deltas = torch.unsqueeze(inputs[2], -1).to(device).float()
        # NOTE(review): squeeze() also drops the batch dim when B == 1 — confirm
        # the dynamics model tolerates that.
        baseline_aa = torch.squeeze(inputs[3]).to(device).float()
        x_f = output.to(device).float()
        cond = {'x_c': x_b, 'start_age': start_age, 'age_deltas': deltas}
    else:
        rot_history = inputs[1].to(device).float()
        baseline_aa = torch.squeeze(inputs[2]).to(device).float()
        pos_horizon = output[0].to(device).float()  # expected (B, H, 17, 3)
        rot_horizon = output[1].to(device).float()
        # Flatten the joints into one feature axis and append the rotation features.
        x_f = torch.cat((pos_horizon.reshape(-1, n_horizon, 17 * 3), rot_horizon), dim=-1)
        cond = {'x_c': x_b, 'rot': rot_history}
    return x_f, baseline_aa, cond


@torch.no_grad()
def _archetype_pattern(dynamics_model, archetypes, baseline_aa, cond, x_f, *,
                       is_vf, n_frames, n_horizon, n_archetypes, field_mask=None):
    """Build the archetype-based conditioning pattern for one batch.

    The dynamics model predicts future archetype coefficients from the history's
    coefficients; the pattern is those coefficients times the archetype matrix.
    Runs under no_grad: the dynamics model's parameters are not in the optimizer,
    so building a graph through it would only accumulate unused gradients.

    VF: returns a tensor shaped like `x_f`, filled at channel 0 where
    `field_mask > 0`. MoCap: returns (B, n_horizon, archetypes.shape[-1]).
    """
    if is_vf:
        pred_aa = dynamics_model(baseline_aa, cond['start_age'], cond['age_deltas'])
        recon = pred_aa @ archetypes
        pattern = torch.zeros_like(x_f)
        pattern[:, 0, field_mask > 0] = recon
        return pattern
    pred_aa = dynamics_model(baseline_aa.reshape(-1, n_frames * n_archetypes))
    pred_aa = pred_aa.reshape(-1, n_horizon, n_archetypes)
    return pred_aa @ archetypes


def _null_pattern(x_f, archetypes, is_vf, n_horizon):
    """Return the all-zero conditioning used for unconditional samples.

    Its shape matches what `_archetype_pattern` produces for the same batch, so
    the denoiser sees one pattern shape whether or not conditioning is dropped.
    """
    if is_vf:
        return torch.zeros_like(x_f)
    return torch.zeros(x_f.shape[0], n_horizon, archetypes.shape[-1],
                       dtype=archetypes.dtype, device=archetypes.device)


def _denoising_loss(pred_eps, gt_eps, elementwise_loss, mean_loss, field_mask=None):
    """Masked mean loss when a field mask is given (VF data), plain mean loss otherwise."""
    if field_mask is not None:
        return masked_loss(pred_eps, gt_eps, elementwise_loss, field_mask, reduction='mean')
    return mean_loss(pred_eps, gt_eps)


if __name__ == '__main__':
    is_vf = args.dataset in _VF_DATASETS

    # set up save paths
    args.save_dir = f'{args.save_dir}_{args.dataset}'
    os.makedirs(args.save_dir, exist_ok=True)

    if args.start_epoch == 0:  # training from scratch
        if args.save_tag is None:
            save_name = datetime.today().strftime('%Y-%m-%d-%H-%M-%S')
        else:
            save_name = args.save_tag
        save_name += f'_e{args.epochs}_b{args.batch_size}_lr{args.lr:.0e}_l{args.lmbda}_d{args.unet_dim}_L{args.loss_fn}'
        if args.lr_schedule:
            save_name += '_sched'
    else:  # resuming: reuse the checkpoint directory's name
        if args.checkpoint_dir is None:
            parser.error('--checkpoint_dir is required when --start_epoch > 0')
        # normpath drops a trailing separator; otherwise basename() would return ''.
        save_name = os.path.basename(os.path.normpath(args.checkpoint_dir))
    # Checkpoints are written here in both cases, so it must exist when resuming too.
    os.makedirs(os.path.join(args.save_dir, save_name), exist_ok=True)

    writer = SummaryWriter(os.path.join(args.save_dir, 'runs', save_name))

    # Record every CLI argument as one "name: value" line.
    info = [f'{a}: {value}' for a, value in vars(args).items()]
    with open(os.path.join(args.save_dir, f'{save_name}_info.csv'), 'w', newline='') as csvfile:
        csv.writer(csvfile, delimiter='\n').writerow(info)

    # more setup
    device = torch.device(f"cuda:{str(args.gpu)}" if torch.cuda.is_available() else "cpu")
    img_size = (1, 10, 10)
    beta_minmax = [1e-4, 2e-2]

    # NOTE: pickle.load can execute code from the file; only load trusted archetype files.
    with open(f"archetypes/{args.dataset}_aa_object.pkl", 'rb') as f:
        aa = pickle.load(f)

    # seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # dataset
    data_range = _get_data_range(args.dataset)
    if is_vf:
        data_dir = f'datasets/{args.dataset}_seqp_{args.n_frames}'
    else:
        data_dir = f'datasets/aistpp_seqp_aa/{args.dataset}/h{args.n_frames}_H{args.n_horizon}'
    with open(os.path.join(data_dir, 'train.json')) as f:
        train_data = json.load(f)
    with open(os.path.join(data_dir, 'val.json')) as f:
        val_data = json.load(f)

    dataset_cls = VF_Dataset if is_vf else MoCap_Dataset
    train_dataset = dataset_cls(train_data, representation=args.representation, normalize=True, data_range=data_range)
    val_dataset = dataset_cls(val_data, representation=args.representation, normalize=True, data_range=data_range)
    # Only VF data has a field mask; None selects the unmasked loss.
    field_mask = None
    if is_vf:
        field_mask = torch.from_numpy(train_dataset.get_field_mask().astype(int)).to(device)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True)

    # Create archetype dynamics model (frozen: loaded from disk, never optimized here)
    n_archetypes = aa.n_archetypes_
    archetypes = torch.FloatTensor(aa.archetypes_).to(device)
    if is_vf:
        dynamics_model = FF(n_frames=args.n_frames,
                            n_archetypes=n_archetypes,
                            hidden_dim=args.hidden_dim,
                            out_dim=n_archetypes,
                            dim=args.time_dim,
                            ).to(device)
    else:
        dynamics_model = FF_MoCap(n_frames=args.n_frames,
                                  n_horizon=args.n_horizon,
                                  n_archetypes=n_archetypes,
                                  hidden_dim=args.hidden_dim,
                                  out_dim=n_archetypes * args.n_horizon,
                                  ).to(device)
    # map_location lets a checkpoint saved on another device load here.
    dynamics_model.load_state_dict(torch.load(args.pattern_model_pth, map_location=device))
    dynamics_model.eval()
    pattern_kwargs = dict(is_vf=is_vf, n_frames=args.n_frames, n_horizon=args.n_horizon,
                          n_archetypes=n_archetypes, field_mask=field_mask)

    # Create reverse diffusion (denoising) neural network and accompanying diffusion model.
    if is_vf:
        denoiser = Unet(dim=args.unet_dim,
                        channels=1,
                        dim_mults=(1, 2,),
                        resnet_block_groups=4,
                        n_frames=args.n_frames,
                        ).to(device)
        diffusion_cls = Diffusion
    else:
        denoiser = Unet_MoCap(dim=args.unet_dim,
                              channels=1,
                              dim_mults=(1, 2,),
                              resnet_block_groups=4,
                              n_frames=args.n_frames,
                              out_dim=args.n_horizon,
                              ).to(device)
        diffusion_cls = Diffusion_MoCap
    diffusion = diffusion_cls(denoiser, image_resolution=img_size, n_times=1000,
                              beta_minmax=beta_minmax, device=device, target=args.target).to(device)

    # load checkpoint if not starting from epoch 0
    if args.start_epoch > 0:
        ckpt_pth = os.path.join(args.checkpoint_dir, f'epoch_{args.start_epoch}.pth')
        diffusion.model.load_state_dict(torch.load(ckpt_pth, map_location=device))
        args.start_epoch += 1  # start training at next epoch

    optimizer = Adam(diffusion.parameters(), lr=args.lr)
    if args.lr_schedule:
        # The scheduler steps once per optimizer step; len(train_loader) counts
        # the final partial batch, so this is the true number of steps.
        # NOTE(review): the schedule restarts at step 0 when resuming from a checkpoint.
        num_training_steps = args.epochs * len(train_loader)
        num_warmup_steps = int(0.1 * num_training_steps)  # 10% warmup
        print(f'{num_warmup_steps} warmup steps')
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )
    if args.loss_fn == 'mse':
        denoising_loss = nn.MSELoss(reduction='none')  # element-wise, for masked_loss
        criterion = nn.MSELoss(reduction='mean')
    elif args.loss_fn == 'mae':
        denoising_loss = nn.L1Loss(reduction='none')
        criterion = nn.L1Loss(reduction='mean')
    else:
        raise NotImplementedError

    print('Start training DDPM...')
    save_period = 100
    for epoch in range(args.start_epoch, args.epochs):
        # train
        denoiser.train()
        total_loss, null_loss_sum, pattern_loss_sum = 0.0, 0.0, 0.0
        n_batches, n_null, n_pattern = 0, 0, 0
        for inputs, output, _ in tqdm(train_loader, total=len(train_loader)):
            x_f, baseline_aa, cond = _prepare_batch(inputs, output, is_vf, args.n_horizon, device)

            # With probability p_uncond, train on an all-zero pattern instead of
            # the archetype prediction (unconditional sample).
            uncond = bool(torch.rand(1) < args.p_uncond)
            if uncond:
                pattern = _null_pattern(x_f, archetypes, is_vf, args.n_horizon)
            else:
                pattern = _archetype_pattern(dynamics_model, archetypes, baseline_aa, cond, x_f, **pattern_kwargs)

            optimizer.zero_grad()

            # add noise and predict noise to be removed
            _, gt_eps, pred_eps = diffusion(x_f, pattern=pattern, **cond)
            loss = _denoising_loss(pred_eps, gt_eps, denoising_loss, criterion, field_mask)

            loss.backward()
            optimizer.step()
            if args.lr_schedule:
                scheduler.step()

            loss_value = loss.item()
            total_loss += loss_value
            n_batches += 1
            if uncond:
                null_loss_sum += loss_value
                n_null += 1
            else:
                pattern_loss_sum += loss_value
                n_pattern += 1

        train_loss = _safe_mean(total_loss, n_batches)
        writer.add_scalar("Loss/train", train_loss, epoch + 1)
        writer.add_scalar("Loss/train_null", _safe_mean(null_loss_sum, n_null), epoch + 1)
        writer.add_scalar("Loss/train_pattern", _safe_mean(pattern_loss_sum, n_pattern), epoch + 1)

        if args.lr_schedule:
            current_lr = scheduler.get_last_lr()[0]
            print(f"Epoch {epoch+1}: lr = {current_lr:.6f}")
        print(f'\tEpoch {epoch + 1} complete!\tTraining denoising Loss {train_loss}')

        # val: score every batch both without (null) and with (pattern) conditioning
        denoiser.eval()
        val_null_sum, val_pattern_sum, n_val = 0.0, 0.0, 0
        with torch.no_grad():
            for inputs, output, _ in tqdm(val_loader, total=len(val_loader)):
                x_f, baseline_aa, cond = _prepare_batch(inputs, output, is_vf, args.n_horizon, device)
                null_pattern = _null_pattern(x_f, archetypes, is_vf, args.n_horizon)
                aa_pattern = _archetype_pattern(dynamics_model, archetypes, baseline_aa, cond, x_f, **pattern_kwargs)

                _, null_gt_eps, null_pred_eps = diffusion(x_f, pattern=null_pattern, **cond)
                _, pattern_gt_eps, pattern_pred_eps = diffusion(x_f, pattern=aa_pattern, **cond)
                val_null_sum += _denoising_loss(null_pred_eps, null_gt_eps, denoising_loss, criterion, field_mask).item()
                val_pattern_sum += _denoising_loss(pattern_pred_eps, pattern_gt_eps, denoising_loss, criterion, field_mask).item()
                n_val += 1

        val_null = _safe_mean(val_null_sum, n_val)
        val_pattern = _safe_mean(val_pattern_sum, n_val)
        writer.add_scalar("Loss/val_null", val_null, epoch + 1)
        writer.add_scalar("Loss/val_pattern", val_pattern, epoch + 1)
        print(f'\t\t\tValidation denoising loss: null {val_null}, pattern {val_pattern}')

        # epoch_{epoch}.pth holds weights after epoch `epoch`; resuming from it
        # (--start_epoch epoch) continues at epoch + 1.
        if (epoch + 1) % save_period == 0:
            ckpt_pth = os.path.join(args.save_dir, save_name, f'epoch_{epoch}.pth')
            torch.save(denoiser.state_dict(), ckpt_pth)
            print(f"Weights saved to {ckpt_pth}")

    print('Finish!!')

    # Always write the final weights so they exist even when the last periodic
    # save did not land on the final epoch.
    final_pth = os.path.join(args.save_dir, save_name, f'epoch_{args.epochs}.pth')
    torch.save(denoiser.state_dict(), final_pth)
    print(f"Weights saved to {final_pth}")
    writer.close()