import argparse
import os
import json
from datetime import datetime
from tqdm import tqdm
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import torch
from torch import nn
from torch.optim import Adam, AdamW
import csv

from feedforward import FF, FF_MoCap
from utils.data import VF_Dataset, MoCap_Dataset

# Command-line interface for the training script. Arguments are parsed at
# module level, so `args` is available to everything defined below.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# The default has to be one of `choices`: the previous default ('scheie') is not
# a supported dataset, so running without --dataset could never get past the
# data-range lookup in the main script.
parser.add_argument('--dataset', type=str, choices=['uwhvf', 'break', 'house', 'ballet_jazz', 'street_jazz', 'krump', 'la_hip_hop', 'lock', 'middle_hip_hop', 'pop', 'wack'], default='uwhvf', help='Dataset.')
parser.add_argument('--save_dir', type=str, default='./saved_pattern_models', help='Directory to save training results in.')
parser.add_argument('--seed', type=int, default=0, help='Random seed.')
parser.add_argument('--gpu', type=int, default=0, help='GPU.')
parser.add_argument('--save_tag', type=str, default=None, help='Tag to identify saved training information.')

# hyperparameters
parser.add_argument('--batch_size', type=int, default=64, help='Batch size.')
parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate.')
parser.add_argument('--reg', type=float, default=0, help='L2 Regularization.')
parser.add_argument('--epochs', type=int, default=500, help='Epochs.')
parser.add_argument('--hidden_dim', type=int, default=64, help='Hidden dimension.')
parser.add_argument('--time_dim', type=int, default=16, help='Dimension of time embedding.')
parser.add_argument('--loss_fn', type=str, choices=['mse', 'mae', 'kld'], default='kld', help='Loss function.')
parser.add_argument('--delta', type=float, default=0, help='Early stopping threshold.')
parser.add_argument('--patience', type=int, default=20, help='Early stopping patience.')

# conditioning inputs
parser.add_argument('--n_frames', type=int, default=3, help='Number of frames required for prediction.')
parser.add_argument('--n_horizon', type=int, default=1, help='Number of frames to predict.')

args = parser.parse_args()

class KLDLoss(nn.Module):
    """Kullback-Leibler divergence KL(P || Q) computed row-wise.

    Args:
        reduction: 'none' returns one value per row (the sum over dim 1);
            'mean' averages those values. Any other value makes `forward`
            raise NotImplementedError.
    """

    def __init__(self, reduction='none'):
        super().__init__()
        self.reduction = reduction

    def forward(self, Q, P):
        """Return KL(P || Q), where P is the true and Q the estimated distribution."""
        # Shift both inputs by the smallest positive normal float32 so that
        # log() never sees an exact zero.
        eps = torch.finfo(torch.float).tiny  # 1e-10
        target = P + eps
        estimate = Q + eps
        log_ratio = target.log() - estimate.log()
        per_row = (target * log_ratio).sum(dim=1)

        if self.reduction == 'mean':
            return per_row.mean()
        if self.reduction == 'none':
            return per_row
        raise NotImplementedError
        
# https://medium.com/biased-algorithms/a-practical-guide-to-implementing-early-stopping-in-pytorch-for-model-training-99a7cbd46e9d
class EarlyStopping:
    """Track a validation loss and flag when it has stopped improving.

    Args:
        patience: number of consecutive non-improving checks after which
            `stop_training` is set.
        delta: margin by which a loss must undercut the best loss so far
            to count as an improvement.
        verbose: print a message when the stop flag is set.
    """

    def __init__(self, patience=5, delta=0, verbose=False):
        self.patience = patience
        self.delta = delta
        self.verbose = verbose
        self.stop_training = False
        self.no_improvement_count = 0
        self.best_loss = None

    def check_early_stop(self, val_loss):
        """Record `val_loss` and update the improvement counter and stop flag."""
        improved = self.best_loss is None or val_loss < self.best_loss - self.delta
        if improved:
            # New best: remember it and restart the patience window.
            self.best_loss = val_loss
            self.no_improvement_count = 0
            return

        self.no_improvement_count += 1
        if self.no_improvement_count >= self.patience:
            self.stop_training = True
            if self.verbose:
                print("Stopping early as no improvement has been observed.")

        
if __name__ == '__main__':
    # Script entry point: prepare output locations, load the train/val splits,
    # build the model, then train with per-epoch validation, best-checkpoint
    # saving and early stopping.

    # set up save paths
    args.save_dir = f'{args.save_dir}_{args.dataset}'
    os.makedirs(args.save_dir, exist_ok=True)

    if args.save_tag is None:
        save_name = datetime.today().strftime('%Y-%m-%d-%H-%M-%S')
    else:
        save_name = args.save_tag
    save_name += f'_e{args.epochs}_b{args.batch_size}_lr{args.lr:.0e}_h{args.hidden_dim}_l{args.loss_fn}_r{args.reg}_t{args.time_dim}_p{args.patience}'
    os.makedirs(os.path.join(args.save_dir, save_name), exist_ok=True)

    log_dir = os.path.join(args.save_dir, 'runs')
    log_pth = os.path.join(log_dir, save_name)
    writer = SummaryWriter(log_pth)

    # Dump every parsed argument, one "name: value" entry per line.
    info = [f'{a}: {getattr(args, a)}' for a in vars(args)]
    # newline='' is what the csv module expects of files it writes to.
    with open(os.path.join(args.save_dir, f'{save_name}_info.csv'), 'w', newline='') as csvfile:
        wr = csv.writer(csvfile, delimiter='\n')
        wr.writerow(info)

    # more setup
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # dataset
    # (min, max) value range handed to the dataset class, per dataset name.
    data_ranges = {
        'uwhvf': (-37.69, 50.00),
        'break': (-217.49, 286.76),
        'street_jazz': (-102.57, 101.90),
        'ballet_jazz': (-308.86, 327.77),
        'middle_hip_hop': (-91.20, 92.65),
        'house': (-78.08, 81.47),
        'krump': (-85.32, 110.89),
        'la_hip_hop': (-77.41, 100.26),
        'lock': (-109.02, 95.57),
        'pop': (-132.20, 103.33),
        'wack': (-283.42, 413.92),
    }
    if args.dataset not in data_ranges:
        raise NotImplementedError(f'No data range is defined for dataset {args.dataset!r}.')
    data_range = data_ranges[args.dataset]

    # The visual-field dataset uses a different directory layout, dataset
    # class and model than the motion-capture datasets.
    is_vf = args.dataset in ['uwhvf']

    if is_vf:
        data_dir = f'datasets/{args.dataset}_seqp_{args.n_frames}'
    else:
        data_dir = f'datasets/aistpp_seqp_aa/{args.dataset}/h{args.n_frames}_H{args.n_horizon}'
    with open(os.path.join(data_dir, 'train.json')) as f:
        train_data = json.loads(f.read())
    with open(os.path.join(data_dir, 'val.json')) as f:
        val_data = json.loads(f.read())

    dataset_cls = VF_Dataset if is_vf else MoCap_Dataset
    train_dataset = dataset_cls(train_data, representation='aa', normalize=False, data_range=data_range)
    val_dataset = dataset_cls(val_data, representation='aa', normalize=False, data_range=data_range)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                    batch_size=args.batch_size,
                                    shuffle=True)
    # Validation is not shuffled: the reported value is a mean of per-batch
    # means, so a fixed batch composition keeps it comparable across epochs
    # even when the last batch is smaller than the others.
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                    batch_size=args.batch_size,
                                    shuffle=False)

    # train setup
    # The number of archetypes is read off the first record of the training split.
    first_record = train_data['data'][list(train_data['data'].keys())[0]]
    if is_vf:
        n_archetypes = len(first_record['followup_aa'][0])
        model = FF(n_frames=args.n_frames,
                n_archetypes=n_archetypes,
                hidden_dim=args.hidden_dim,
                out_dim=n_archetypes,
                dim=args.time_dim,
                ).to(device)
    else:
        n_archetypes = len(first_record['horizon_aa'][0])
        model = FF_MoCap(n_frames=args.n_frames,
                n_horizon=args.n_horizon,
                n_archetypes=n_archetypes,
                hidden_dim=args.hidden_dim,
                out_dim=n_archetypes*args.n_horizon,
                ).to(device)
    if args.reg == 0: # for consistency of results. otherwise, just use AdamW
        optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.reg)
    else:
        optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.reg)
    if args.loss_fn == 'mse':
        criterion = nn.MSELoss(reduction='mean')
    elif args.loss_fn == 'mae': # MAE is proportional to TVD
        criterion = nn.L1Loss(reduction='mean')
    elif args.loss_fn == 'kld':
        criterion = KLDLoss(reduction='mean')
    else:
        raise NotImplementedError

    early_stopping = EarlyStopping(patience=args.patience, delta=args.delta, verbose=True)

    def predict(inputs):
        # Run the model on one batch of conditioning inputs; shared by the
        # training and validation passes so both prepare inputs identically.
        if is_vf:
            # NOTE(review): squeeze() without a dim also drops a batch (or
            # frame) axis of size 1 -- confirm the dataset never yields one.
            baselines = torch.squeeze(inputs[0]).to(device).float() # (B, h, 1, n_archs) -> (B, h, n_archs)
            start_age = inputs[1].reshape(-1,1).to(device).float() # (B, 1)
            age_deltas = torch.unsqueeze(inputs[2],-1).to(device).float() # (B, h+H, 1)
            return model(baselines, start_age, age_deltas) # (B, n_archs)
        inp = inputs[0].reshape(-1,n_archetypes*args.n_frames).to(device)
        return model(inp) # (B, h*n_archs) -> (B, H*n_archs)

    # training loop
    best_loss = np.inf
    best_epoch = 0  # 1-indexed, matching the epoch numbers in logs; 0 means no checkpoint was saved
    try:
        for epoch in tqdm(range(args.epochs)):
            # train
            model.train()
            batch_loss = 0
            for inputs, x_f, _ in tqdm(train_loader):
                optimizer.zero_grad()

                pred_alphas = predict(inputs)
                x_f = torch.squeeze(x_f).to(device).float() # (B, 1, 1, 17) -> (B, 17)

                loss = criterion(pred_alphas, x_f)
                batch_loss += loss.item()

                loss.backward()
                optimizer.step()

            ave_train_loss = batch_loss/len(train_loader)
            writer.add_scalar("Loss/train", ave_train_loss, epoch+1)

            print(f'\tEpoch {epoch + 1} complete!\tTraining Loss {ave_train_loss}')

            # val
            model.eval()
            batch_loss = 0
            with torch.no_grad():
                for inputs, x_f, _ in tqdm(val_loader):
                    pred_alphas = predict(inputs)
                    x_f = torch.squeeze(x_f).to(device).float() # (B, 1, 1, 17) -> (B, 17)

                    loss = criterion(pred_alphas, x_f)
                    # Accumulate a Python float (as in training) so the
                    # average, best_loss and the early-stopping state are
                    # plain numbers rather than device tensors.
                    batch_loss += loss.item()

            ave_val_loss = batch_loss/len(val_loader)
            writer.add_scalar("Loss/val", ave_val_loss, epoch+1)

            print(f'\t\t\tValidation Loss {ave_val_loss}')

            if ave_val_loss < best_loss:
                best_loss = ave_val_loss
                best_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(args.save_dir, save_name, 'best_epoch.pth'))
                print('\t\t\tA new best model saved at epoch {}!\n'.format(epoch + 1))

            early_stopping.check_early_stop(ave_val_loss)
            if early_stopping.stop_training:
                print(f"Early stopping at epoch {epoch + 1}. Best epoch {best_epoch} with best loss {best_loss}.")
                break
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()

    print(f'Finish!! Best epoch {best_epoch}')