import torch
import os
import numpy as np
import random
import json
from datetime import datetime
from torch import nn
from tqdm import tqdm
from dng_dataset import ViTGraphDataset
from dng_models import ViT_Gen_Predictor
from torch.utils.data import DataLoader
from sklearn.metrics import r2_score
from scipy.stats import kendalltau
import argparse
import warnings
warnings.filterwarnings("ignore")

def setup_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed handed unchanged to ``random.seed``,
            ``np.random.seed``, ``torch.manual_seed`` and
            ``torch.cuda.manual_seed_all``.
    """
    # The four generators are independent, so the call order is irrelevant.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

@torch.no_grad()
def eval(model, dl, criterion, device):
    """Run ``model`` over every batch of ``dl`` and compute regression metrics.

    NOTE: the name shadows the ``eval`` builtin inside this module; it is kept
    because callers in this file use it.

    Args:
        model: Module called as ``model(data)``; its output is squeezed on the
            last dimension before being compared with the targets.
        dl: Iterable with a length, yielding ``(data, acc)`` pairs where
            ``data`` is indexable with at least two tensors and ``acc`` is the
            target tensor.
        criterion: Loss callable invoked as ``criterion(pred, target)``.
        device: Device the inputs and targets are moved to.

    Returns:
        Tuple ``(avg_err, avg_loss, rsq, tau, actual, pred)``:
        ``avg_err`` and ``avg_loss`` are unweighted means of the per-batch
        mean absolute error and per-batch loss (so a smaller final batch
        carries the same weight as a full one); ``rsq`` is ``r2_score`` and
        ``tau`` is Kendall's tau over all samples; ``actual`` and ``pred`` are
        the concatenated NumPy arrays of targets and predictions.

    Raises:
        ValueError: If ``dl`` yields no batches.
    """
    orig_state = model.training
    model.eval()
    pred, actual = [], []
    err, losses = [], []
    try:
        with tqdm(total=len(dl)) as pbar:
            for data, acc in dl:
                data = (data[0].to(device), data[1].to(device))
                acc = acc.float().to(device)
                # Gradient tracking is already disabled by the decorator.
                pred_acc = model(data).squeeze(-1)
                err.append(torch.abs(pred_acc - acc).mean().item())
                losses.append(criterion(pred_acc, acc).item())
                pred.append(pred_acc.detach().cpu().numpy())
                actual.append(acc.cpu().numpy())
                pbar.update(1)
    finally:
        # Restore the caller's train/eval mode even if a batch raised.
        model.train(orig_state)

    if not pred:
        raise ValueError("eval() received a dataloader that yielded no batches")

    avg_err, avg_loss = np.mean(err), np.mean(losses)
    actual, pred = np.concatenate(actual), np.concatenate(pred)
    rsq = r2_score(actual, pred)
    # Positional access works for both the older namedtuple result
    # (``.correlation``) and the newer result object (``.statistic``).
    tau = kendalltau(actual, pred)[0]
    return avg_err, avg_loss, rsq, tau, actual, pred

def _save_best(output_dir, tag, epoch, test_metric, save_dict, results):
    """Replace the ``tag`` checkpoint and JSON summary stored in ``output_dir``.

    Every file in ``output_dir`` whose name contains ``tag`` is deleted, then
    ``{tag}_{epoch}_{test_metric}.pt`` and ``{tag}.json`` are written.
    """
    for fname in os.listdir(output_dir):
        if tag in fname:
            os.remove(os.path.join(output_dir, fname))
    torch.save(save_dict, os.path.join(output_dir, f"{tag}_{epoch}_{test_metric}.pt"))
    with open(os.path.join(output_dir, f"{tag}.json"), "w") as f:
        json.dump(results, f, indent=4)


def main(args):
    """Train a ``ViT_Gen_Predictor`` and keep the best checkpoints.

    Each epoch trains on the train split, then evaluates on the train, val and
    test splits. Whenever the validation R^2 or validation Kendall tau improves,
    the corresponding checkpoint and a JSON summary of all metrics are written
    to a timestamped directory under ``./dng_predict_gen_vit_models/``.

    Args:
        args: Parsed command-line namespace providing ``seed``, ``device``,
            ``batch_size``, ``lr``, ``epoch``, ``f_dim``, ``f_scale``,
            ``rnn_mode``, ``n_dim``, ``head_dim``, ``head_drop``, ``loss_fn``
            (``"mse"`` or ``"bce"``) and ``sigmoid``.

    Raises:
        KeyError: If ``args.loss_fn`` is neither ``"mse"`` nor ``"bce"``.
    """
    seed = args.seed
    setup_seed(seed)
    device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
    batch_size = args.batch_size
    lr = args.lr
    epochs = args.epoch

    fourier_dim = args.f_dim
    fourier_scale = args.f_scale
    rnn_mode = args.rnn_mode
    emb_dim = args.n_dim
    head_dim = args.head_dim
    head_drop = args.head_drop

    loss_fn = args.loss_fn
    sigmoid = args.sigmoid

    # One output directory per run, named after the start time.
    current_time = datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = f'./dng_predict_gen_vit_models/{current_time}/'
    os.makedirs(output_dir, exist_ok=True)

    data_root = './data/cifar10_vit'
    train_ds = ViTGraphDataset(data_root, 'train')
    val_ds = ViTGraphDataset(data_root, 'val')
    test_ds = ViTGraphDataset(data_root, 'test')
    print('Datasets loaded')

    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
    val_dl = DataLoader(val_ds, batch_size=100, num_workers=8, pin_memory=True)
    test_dl = DataLoader(test_ds, batch_size=100, num_workers=8, pin_memory=True)

    # NOTE(review): the spec is read from the test split; presumably it is
    # identical across splits -- confirm against ViTGraphDataset.
    graph_spec = test_ds.graph_spec

    model = ViT_Gen_Predictor(graph_spec, fourier_dim, fourier_scale, rnn_mode, emb_dim, head_dim, head_drop, sigmoid=sigmoid).to(device)
    print(model)

    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Total: {total_num}, Trainable: {trainable_num}')

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Per-step warm-up: the LR factor ramps linearly from 0.5 to 1.0 over the
    # first 1000 optimizer steps (scheduler.step() is called once per batch).
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda i: min(1, 0.5 + 0.0005 * i))
    # NOTE(review): nn.BCELoss expects predictions in [0, 1]; presumably the
    # ``sigmoid`` flag makes the model emit such values -- confirm before
    # combining "bce" with ``sigmoid`` unset.
    criterion = {"mse": nn.MSELoss(), "bce": nn.BCELoss()}[loss_fn]

    best_rsq, best_tau = -float('inf'), -float('inf')
    for epoch in range(epochs):
        epoch_loss = 0.0
        with tqdm(total=len(train_dl)) as pbar:
            for data, acc in train_dl:
                data = (data[0].to(device), data[1].to(device))
                acc = acc.float().to(device)
                optimizer.zero_grad()
                pred_acc = model(data).squeeze(-1)
                loss = criterion(pred_acc, acc)
                loss.backward()
                optimizer.step()
                scheduler.step()
                # Accumulate a Python float, not the tensor, so the autograd
                # graph of every step is not kept alive for the whole epoch.
                epoch_loss += loss.item()
                pbar.update(1)
        avg_epoch_loss = epoch_loss / len(train_dl)
        print(avg_epoch_loss)

        # evaluate
        train_avg_err, train_avg_loss, train_rsq, train_tau, _, _ = eval(model, train_dl, criterion, device)
        avg_err, avg_loss, rsq, tau, _, _ = eval(model, val_dl, criterion, device)
        test_avg_err, test_avg_loss, test_rsq, test_tau, _, _ = eval(model, test_dl, criterion, device)
        print(f"Epoch {epoch}, train L1 err: {train_avg_err:.5f}, train loss: {train_avg_loss:.5f}, train Rsq: {train_rsq:.5f}, train tau: {train_tau:.5f}.")
        print(f"Epoch {epoch}, val L1 err: {avg_err:.5f}, val loss: {avg_loss:.5f}, val Rsq: {rsq:.5f}, val tau: {tau:.5f}.")
        print(f"Epoch {epoch}, test L1 err: {test_avg_err:.5f}, test loss: {test_avg_loss:.5f}, test Rsq: {test_rsq:.5f}, test tau: {test_tau:.5f}.")

        save_dict = {
            "weights": model.state_dict(),
            "val_l1": avg_err,
            "val_loss": avg_loss,
            "val_rsq": rsq,
            "val_tau": tau,
            "epoch": epoch,
        }

        results = {'train_avg_err':train_avg_err, 'train_avg_loss':train_avg_loss, 'train_rsq':train_rsq, 'train_tau':train_tau,
                   'avg_err':avg_err, 'avg_loss':avg_loss, 'rsq':rsq, 'tau':tau,
                   'test_avg_err':test_avg_err, 'test_avg_loss':test_avg_loss, 'test_rsq':test_rsq, 'test_tau':test_tau}

        # Model selection uses the validation metrics; the test metric only
        # appears in the checkpoint file name.
        if rsq > best_rsq:
            _save_best(output_dir, "best_rsq", epoch, test_rsq, save_dict, results)
            best_rsq = rsq
        if tau > best_tau:
            _save_best(output_dir, "best_tau", epoch, test_tau, save_dict, results)
            best_tau = tau

if __name__ == '__main__':
    # Command-line entry point: parse the hyper-parameters and start training.
    parser = argparse.ArgumentParser(description='DNGNN_PREDICT_GEN_VIT')
    # Run configuration.
    parser.add_argument('--device', type=str, default='0')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--epoch', type=int, default=200, help='number of epoch')
    parser.add_argument('--batch-size', type=int, default=100, help='training batch size')
    parser.add_argument('--lr', type=float, default=1e-3)

    # Model hyper-parameters forwarded to ViT_Gen_Predictor by main().
    parser.add_argument('--f-scale', type=int, default=1)
    parser.add_argument('--f-dim', type=int, default=128)
    parser.add_argument('--n-dim', type=int, default=128)
    parser.add_argument('--head-dim', type=int, default=1024)
    parser.add_argument('--rnn-mode', type=str, default='gru')
    parser.add_argument('--head-drop', type=float, default=0.0)

    # main() only knows these two losses; reject anything else up front
    # instead of failing with a KeyError after the data and model are built.
    parser.add_argument('--loss-fn', type=str, default='bce', choices=('mse', 'bce'))
    # BooleanOptionalAction provides --sigmoid / --no-sigmoid; the value is
    # None when neither flag is given.
    parser.add_argument('--sigmoid', action=argparse.BooleanOptionalAction)

    args = parser.parse_args()
    main(args)