# The general comparison settings and parts of this code are adapted from StyDeSty.

import torchaudio
import argparse
import os
import random
import numpy as np
import torch
import torch.utils.data
import torch.nn.functional as F
import torch.backends.cudnn
import torchvision.models
import network
import time
from torchvision import transforms
from torchvision.utils import save_image
from dataset import get_datasets
from util import save_options
from model1_TAU import AugNet
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import seaborn as sns
import itertools
from sklearn.metrics import confusion_matrix
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR
from torch.utils.data import TensorDataset
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB


import torch
from torch import nn




def _str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    ``--nesterov False`` would evaluate to ``True``. This parser interprets the
    text instead.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``bool("")``.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % value)


def get_args(argv=None):
    """Build the command-line parser and parse the training options.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the behaviour of calling
            ``get_args()`` without arguments.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    parser = argparse.ArgumentParser()

    # Task / data selection. NOTE: main() in this file only implements the
    # 'TAU' task and raises NotImplementedError for anything else, so the
    # PACS-flavoured defaults below must be overridden on the command line.
    parser.add_argument("--task", type=str, help="Task", default='PACS')
    parser.add_argument("--data_root", type=str, help="Data root", default='../../datasets/PACS')
    parser.add_argument("--source", type=str, help="Source", default='photo')
    parser.add_argument("--target", type=str, help="Target", default='art_painting,cartoon,sketch')
    parser.add_argument("--ckpt_dir", type=str, help="Path of saving checkpoint", default='checkpoint/P2ACS')

    # Optimisation.
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--seed", type=int, default=1, help="random seed")
    parser.add_argument("--learning_rate", "-l", type=float, default=.001, help="Learning rate")
    parser.add_argument("--iters", type=int, default=1000, help="Number of training iterations")
    parser.add_argument("--inner_iters", type=int, default=10, help="Number of inner training iterations")
    parser.add_argument("--network", help="Which network to use", default="resnet18")
    parser.add_argument("--optimizer", help='Which optimizer to use, Adam or SGD', default='SGD')
    parser.add_argument("--nesterov", default=True, type=_str2bool, help="Use nesterov")
    parser.add_argument("--scheduler", default='linear', type=str, help="Learning rate scheduler")
    parser.add_argument("--lr_aug", default=0.005, type=float)
    parser.add_argument("--aug_weight", default=0.6, type=float)
    parser.add_argument("--weight_decay", default=0.0005, type=float)

    # Runtime / logging.
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--n_workers", type=int, default=8)
    parser.add_argument("--test_freq", type=int, default=50)
    parser.add_argument("--print_freq", type=int, default=10)
    parser.add_argument("--lambda_center", type=float, default=0.003)

    # Audio front-end and schedule length.
    parser.add_argument("--sr", type=int, default=16000, help="sampling rate for audio")
    parser.add_argument("--n_mels", type=int, default=256, help="number of mel bins")
    parser.add_argument("--max_lr", type=float, default=0.1, help="Peak LR after warm-up (0.1 for BC-ResNet-1, 0.06 for BC-ResNet-8)")
    parser.add_argument("--epochs", type=int, default=100)
    return parser.parse_args(argv)

def timer(message, start_time, cum):
    """Print and record the wall-clock time elapsed since ``start_time``.

    Args:
        message: label printed before the duration and used as the key in ``cum``.
        start_time: reference timestamp, as returned by ``time.time()``.
        cum: dict updated in place; ``cum[message]`` receives the formatted
            duration string (``"<h>h <m>m <s>s"``).
    """
    # Work on whole seconds so that every field is floored consistently.
    # Formatting the fractional remainder with ``:.0f`` rounds it, which could
    # display an impossible "60s" (e.g. 59.7 s -> "0h 0m 60s").
    total_seconds = int(time.time() - start_time)
    minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(minutes, 60)
    formatted = f"{hours}h {minutes}m {seconds}s"
    cum[message] = formatted
    print(f"{message} {formatted}")
    

def _wav_to_log_mel(wav, mel_tf, log_tf, n_frames=350):
    """Turn a batch into a log-mel spectrogram with a fixed number of frames.

    Args:
        wav: tensor accepted by ``mel_tf``.
        mel_tf: callable producing the mel spectrogram.
        log_tf: callable converting the mel spectrogram to a log (dB) scale.
        n_frames: required size of the last axis of the result.

    Returns:
        The output of ``log_tf(mel_tf(wav))`` with its last axis right-padded
        with zeros, or cropped, to exactly ``n_frames`` entries.
    """
    mel_db = log_tf(mel_tf(wav))
    cur_len = mel_db.size(-1)
    if cur_len < n_frames:  # pad
        mel_db = F.pad(mel_db, (0, n_frames - cur_len))
    elif cur_len > n_frames:  # crop
        # Crop the same (last) axis that was measured above, whatever the rank.
        mel_db = mel_db[..., :n_frames]
    return mel_db


def _evaluate(backbone, test_loaders, device):
    """Return the top-1 accuracy (a fraction) of ``backbone`` on each loader."""
    accuracies = []
    with torch.no_grad():
        for test_loader in test_loaders:
            cur_total = 0
            cur_correct = 0
            for data in test_loader:
                image = data[0].to(device)
                label = data[1].to(device)
                # NOTE(review): training feeds the backbone log-mel features
                # computed in the loop, whereas evaluation feeds the loader
                # output unchanged. Confirm that the is_train=False datasets
                # already yield inputs in the format the backbone expects.
                pred = backbone(image)
                cur_total += image.size(0)
                cur_correct += pred.argmax(dim=1).eq(label).sum().item()
            accuracies.append(cur_correct / cur_total)
    return accuracies


def main(args):
    """Train a backbone on augmented source-domain audio and test on the targets.

    Each iteration draws a source batch, passes it through ``AugNet`` (without
    gradients), converts the result to a fixed-length log-mel spectrogram and
    minimises the cross-entropy of the backbone output. Every
    ``args.test_freq`` iterations the backbone is evaluated on the target
    loaders and the best-scoring weights are saved.

    Files written under ``<args.ckpt_dir>/<run id>/``: ``options.txt``,
    ``backbone_best.pth`` (when an evaluation improved on 0), ``all_losses.pth``
    and ``backbone_final.pth``.

    Args:
        args: namespace produced by ``get_args``. ``args.iters`` is overwritten
            with ``args.epochs`` times the number of batches per epoch.

    Raises:
        NotImplementedError: for a task other than ``'TAU'``, an optimizer
            other than ``'SGD'``/``'Adam'`` or a scheduler name other than
            ``'linear'``/``'cos'``.
    """
    start_time = time.time()
    cum = {"1": 0.0}
    device = torch.device(args.device)
    name = str(int(time.time()))
    print('Running ID: %s' % name)
    run_dir = os.path.join(args.ckpt_dir, name)
    best_path = os.path.join(run_dir, 'backbone_best.pth')
    os.makedirs(run_dir, exist_ok=True)
    save_options(os.path.join(run_dir, 'options.txt'), args)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.task != 'TAU':
        raise NotImplementedError
    n_classes = 10
    # This backbone/classifier pair belongs to the augmentation network only;
    # the model that is actually trained is created further below.
    backbone, classifier = network.get_network('bcresnet18', n_classes)
    net = AugNet(device, backbone=backbone, classifier=classifier, nc=n_classes).to(device)

    train_set = get_datasets(args.task, args.data_root, domains=args.source, is_train=True)
    val_sets = get_datasets(args.task, args.data_root, domains=args.source, is_train=False)
    test_sets = get_datasets(args.task, args.data_root, domains=args.target, is_train=False)

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, num_workers=args.n_workers,
        pin_memory=False,
        shuffle=True
    )
    # Source-domain validation loaders; built for parity with the test loaders
    # but not consumed by the loop below.
    eval_loader = [
        torch.utils.data.DataLoader(
            val_set, batch_size=args.batch_size, shuffle=False,
            num_workers=args.n_workers, pin_memory=True, drop_last=False
        )
        for val_set in val_sets
    ]
    train_iter = iter(train_loader)
    test_loaders = [
        torch.utils.data.DataLoader(
            test_set, batch_size=args.batch_size * 2, shuffle=False,
            num_workers=args.n_workers, pin_memory=True, drop_last=False
        )
        for test_set in test_sets
    ]

    # net.set_pool(train_set); timer("set pool", start_time, cum)
    # net.epsilon_list(); timer("epsilon_list", start_time, cum)
    # net.augment_pool(); timer("aug pool", start_time, cum)
    net.load_augment_pool()

    # Mel-spectrogram transformation: 130 ms window, 30 ms hop. With the
    # default --sr 16000 / --n_mels 256 this is n_fft = win_length = 2080,
    # hop_length = 480, f_max = 8000.
    win_length = int(args.sr * 0.13)
    mel_tf = MelSpectrogram(
        sample_rate=args.sr, n_fft=win_length, win_length=win_length,
        hop_length=int(args.sr * 0.03),
        n_mels=args.n_mels, f_min=50,
        f_max=min(8000, args.sr // 2)  # never above the Nyquist frequency
    ).to(device)
    # Log scale (log-mel spectrogram).
    log_tf = AmplitudeToDB().to(device)

    # The model that is trained, evaluated and checkpointed.
    backbone, classifier = network.get_network(args.network, n_classes)
    backbone = backbone.to(device)

    warm_epochs = 5
    iters_per_ep = int(np.ceil(len(train_set) / args.batch_size))
    warm_iters = warm_epochs * iters_per_ep
    total_iters = args.epochs * iters_per_ep
    args.iters = total_iters

    # The schedulers below scale the optimizer's base LR by a factor, so the
    # base LR must be the intended peak value.
    if args.optimizer == 'SGD':
        optimizer_backbone = torch.optim.SGD(
            backbone.parameters(), lr=args.max_lr,
            nesterov=args.nesterov, momentum=0.9, weight_decay=args.weight_decay
        )
    elif args.optimizer == 'Adam':
        optimizer_backbone = torch.optim.Adam(
            backbone.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay
        )
    else:
        raise NotImplementedError

    # --scheduler is validated for backward compatibility only: the schedule
    # is always linear warm-up followed by cosine annealing.
    if args.scheduler not in ('linear', 'cos'):
        raise NotImplementedError

    # Warm-up: ~0 => peak LR. LinearLR factors multiply the base LR, hence
    # end_factor=1.0 (a factor equal to the LR itself would stop short of the
    # peak and make the LR jump when the cosine phase starts).
    scheduler_warm = LinearLR(optimizer_backbone, start_factor=1e-8,
                              end_factor=1.0, total_iters=warm_iters)
    # Cosine annealing: peak LR => 0. T_max is kept >= 1 so that runs no longer
    # than the warm-up do not divide by zero.
    scheduler_cos = CosineAnnealingLR(
        optimizer_backbone,
        T_max=max(total_iters - warm_iters, 1),
        eta_min=0.0
    )
    scheduler_backbone = SequentialLR(
        optimizer_backbone,
        schedulers=[scheduler_warm, scheduler_cos],
        milestones=[warm_iters]
    )

    best_acc = 0.
    best_msg = ''
    all_losses = []

    for i in range(args.iters):
        try:
            data = next(train_iter)
        except StopIteration:
            # Epoch exhausted: restart the (reshuffled) loader.
            train_iter = iter(train_loader)
            data = next(train_iter)
        image = data[0].to(device)
        label = data[1].to(device)
        optimizer_backbone.zero_grad()

        # The augmentation network is only applied here, never updated.
        with torch.no_grad():
            image_aug, aug_label = net(image, lbl=label)
        image_aug = _wav_to_log_mel(image_aug, mel_tf, log_tf, n_frames=350)

        pred = backbone(image_aug)
        ce_loss = F.cross_entropy(pred, aug_label)
        losses = {'inner_ce': ce_loss}

        # NOTE(review): kept from the original; emptying the CUDA cache on
        # every iteration is usually unnecessary - confirm it is needed.
        torch.cuda.empty_cache()
        ce_loss.backward()
        optimizer_backbone.step()
        scheduler_backbone.step()

        loss_values = {k: v.item() for k, v in losses.items()}
        all_losses.append(loss_values)

        if (i + 1) % args.print_freq == 0:
            msg = 'iteration %04d' % (i + 1)
            for k, v in loss_values.items():
                msg += ' loss_%s: %.3f' % (k, v)
            timer(f"train_{i+1}", start_time, cum)
            print(msg)

        if (i + 1) % args.test_freq == 0:
            backbone.eval()
            acc = _evaluate(backbone, test_loaders, device)
            backbone.train()

            mean_acc = sum(acc) * 100 / len(acc)
            msg = 'Test Accuracy: %.2f' % mean_acc
            for item in acc:
                msg += '\t[%.2f]' % (item * 100)
            timer(f"test_{i+1}", start_time, cum)
            print(msg)

            if mean_acc > best_acc:
                best_acc = mean_acc
                best_msg = msg
                torch.save(backbone.state_dict(), best_path)
                print('Best Model Saved!')

    print('Best %s' % best_msg)
    torch.save(all_losses, os.path.join(run_dir, 'all_losses.pth'))
    torch.save(backbone.state_dict(), os.path.join(run_dir, 'backbone_final.pth'))

    # The best checkpoint only exists if some evaluation scored above 0.
    if os.path.exists(best_path):
        backbone.load_state_dict(torch.load(best_path))
    backbone.eval()

    elapsed_time = time.time() - start_time
    print(f"Training completed: {elapsed_time // 3600:.0f}h {(elapsed_time % 3600) // 60:.0f}m {elapsed_time % 60:.0f}s")

    
# Script entry point: configure cuDNN, parse the command line and run training.
if __name__ == "__main__":
    # NOTE(review): benchmark=True enables cuDNN auto-tuning of convolution
    # algorithms while deterministic=True requests deterministic ones. The
    # PyTorch reproducibility notes recommend benchmark=False when run-to-run
    # reproducibility is the goal (main() seeds every RNG) - confirm which
    # trade-off is intended here.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
    config = get_args()
    main(config)


