import argparse
import numpy as np
import math
import sys
import os
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import timm.optim.optim_factory as optim_factory

import util.lr_sched as lr_sched

from linear_probing import MAGECityPolyGenProbing

from dataloader import PolyDatasetClassification


def get_args_parser():
    """Build the command-line parser for the probing / fine-tuning run.

    Returns:
        argparse.ArgumentParser created with ``add_help=False`` so it can be
        composed into a parent parser; call ``.parse_args()`` on it directly
        for standalone use.
    """
    parser = argparse.ArgumentParser('MAE pre-training', add_help=False)
    parser.add_argument('--batch_size', default=256, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
    parser.add_argument('--epochs', default=1000, type=int)
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

    parser.add_argument('--lr', type=float, default=1e-4, metavar='LR',
                        help='learning rate (absolute lr)')
    # NOTE: main() only derives lr from blr when args.lr is None; because --lr
    # has a non-None default, --blr has no effect unless that default is removed.
    parser.add_argument('--blr', type=float, default=1e-6, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')

    parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
                        help='epochs to warmup LR')

    parser.add_argument('--data_path', default='../datasets/states', type=str,
                        help='dataset path')

    parser.add_argument('--output_dir', default='../results/train/polyclassificationprobing/output_dir',
                        help='path where to save, empty for no saving')
    # main() reads args.save_freq for periodic checkpoints; it was previously
    # never defined, which crashed the run with AttributeError after epoch 0.
    parser.add_argument('--save_freq', type=int, default=20,
                        help='save a periodic checkpoint every N epochs (<= 0 disables)')
    parser.add_argument('--log_dir', default='../results/train/polyclassificationprobing/output_log',
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')

    parser.add_argument('--split_ratio', type=float, default=0.8)
    parser.add_argument('--pretrained_path', type=str, default='../datasets/states')
    parser.add_argument('--fine_tune', action="store_true")

    # --pin_mem / --no_pin_mem share a dest; pinning is on unless disabled.
    parser.set_defaults(pin_mem=True)

    return parser


def _make_loader(dataset, args, train):
    """Build the DataLoader for one split.

    Training batches are shuffled (when the dataset supports it) and the
    ragged final batch is dropped. Validation keeps every sample so the
    reported metrics cover the whole split.
    """
    # DataLoader rejects shuffle=True for IterableDataset instances.
    is_iterable = isinstance(dataset, torch.utils.data.IterableDataset)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=train and not is_iterable,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=train,
    )


def _train_one_epoch(model, loader, optimizer, device, epoch, args, log_writer, global_step):
    """Run one training epoch with gradient accumulation.

    Gradients from ``args.accum_iter`` consecutive micro-batches are summed
    (each loss scaled by 1/accum_iter) before a single optimizer step. A
    trailing partial accumulation window at the end of the epoch is discarded
    by the ``zero_grad`` at the start of the next epoch.

    Returns:
        (mean_train_loss, mean_train_acc, updated_global_step). The means are
        NaN if the loader yielded no batches.
    """
    model.train(True)
    optimizer.zero_grad()
    num_batches = len(loader)
    loss_sum = 0.0
    acc_sum = 0.0
    steps = 0

    for step, (samples, pos, _info, class_name) in enumerate(loader):
        if step % args.accum_iter == 0:
            # Per-iteration LR schedule; the fractional epoch gives a smooth warmup/decay.
            lr_sched.adjust_learning_rate(optimizer, step / num_batches + epoch, args)

        poly = samples.to(device).float()
        pos = pos.to(device).float()
        class_name = class_name.to(device).long()
        loss, acc = model(poly, pos, class_name)

        loss_value = loss.item()
        acc_value = acc.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Accumulate gradients every micro-batch; only step/reset on the
        # accumulation boundary (the old code stepped every micro-batch and
        # cleared grads *before* backward on the boundary step).
        (loss / args.accum_iter).backward()
        if (step + 1) % args.accum_iter == 0:
            optimizer.step()
            optimizer.zero_grad()

        if log_writer is not None:
            # Log the unscaled loss so curves are comparable across accum_iter values.
            log_writer.add_scalar('loss_train', loss_value, global_step)
            log_writer.add_scalar('acc_train', acc_value, global_step)

        global_step += 1
        loss_sum += loss_value
        acc_sum += acc_value
        steps += 1

    if steps == 0:
        return float('nan'), float('nan'), global_step
    return loss_sum / steps, acc_sum / steps, global_step


@torch.no_grad()
def _evaluate(model, loader, device):
    """Compute sample-weighted mean loss and accuracy over ``loader``.

    Assumes the model returns per-batch *mean* loss and accuracy (TODO confirm),
    so each batch is weighted by its size to handle a smaller final batch.
    Returns (nan, nan) if the loader yields no samples.
    """
    model.eval()  # set once, not per batch
    loss_sum = 0.0
    acc_sum = 0.0
    seen = 0
    for samples, pos, _info, class_name in loader:
        poly = samples.to(device).float()
        pos = pos.to(device).float()
        class_name = class_name.to(device).long()
        loss, acc = model(poly, pos, class_name)

        batch_n = class_name.shape[0]  # default collate stacks samples along dim 0
        loss_sum += loss.item() * batch_n
        acc_sum += acc.item() * batch_n
        seen += batch_n

    if seen == 0:
        return float('nan'), float('nan')
    return loss_sum / seen, acc_sum / seen


def _save_checkpoint(model, output_dir, filename):
    """Save ``model.state_dict()`` to ``output_dir/filename``.

    An empty/None ``output_dir`` means "no saving" (per the --output_dir help).
    """
    if not output_dir:
        return
    os.makedirs(output_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(output_dir, filename))


def main(args):
    """Train a classification probe on top of a pretrained MAE backbone.

    Loads pretrained weights into ``model.mae``, freezes them unless
    ``args.fine_tune`` is set, then trains with AdamW. Each epoch it prints and
    logs train and validation metrics. It checkpoints the best-validation-loss
    model, a periodic snapshot every ``args.save_freq`` epochs, and the final
    model.
    """
    device = torch.device(args.device)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    cudnn.benchmark = True

    dataset_train = PolyDatasetClassification(args.data_path, train=True, split_ratio=args.split_ratio)
    dataset_valid = PolyDatasetClassification(args.data_path, train=False, split_ratio=args.split_ratio)

    log_writer = None
    if args.log_dir:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)

    data_loader_train = _make_loader(dataset_train, args, train=True)
    data_loader_valid = _make_loader(dataset_valid, args, train=False)

    model = MAGECityPolyGenProbing(device=args.device, fine_tuning=args.fine_tune)

    # map_location='cpu' lets a GPU-saved checkpoint load on any machine;
    # load_state_dict copies the tensors into the model's own parameters.
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    pretrained_state = torch.load(args.pretrained_path, map_location='cpu')
    model.mae.load_state_dict(pretrained_state)

    if not args.fine_tune:
        # Linear probing: the pretrained backbone's weights are not trained.
        # NOTE(review): model.train(True) below still puts model.mae in train
        # mode (dropout/norm behaviour) - confirm the probe model handles this.
        for param in model.mae.parameters():
            param.requires_grad = False

    model.to(device)

    eff_batch_size = args.batch_size * args.accum_iter
    if args.lr is None:
        args.lr = args.blr * eff_batch_size / 256

    param_groups = optim_factory.param_groups_weight_decay(model, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))

    # Tolerate parsers that predate --save_freq; <= 0 disables periodic saves.
    save_freq = getattr(args, 'save_freq', 20)

    best_valid_loss = float('inf')
    global_step = 0
    try:
        for epoch in range(args.start_epoch, args.epochs):
            train_loss, train_acc, global_step = _train_one_epoch(
                model, data_loader_train, optimizer, device, epoch, args,
                log_writer, global_step)
            print('train_loss:', train_loss)
            print('train_acc:', train_acc)

            val_loss, val_acc = _evaluate(model, data_loader_valid, device)
            print('epoch:', epoch, 'val_loss: ', val_loss)
            print('epoch:', epoch, 'val_acc: ', val_acc)

            if log_writer is not None:
                log_writer.add_scalar('loss_valid', val_loss, global_step)
                log_writer.add_scalar('acc_valid', val_acc, global_step)

            if val_loss < best_valid_loss:
                best_valid_loss = val_loss
                _save_checkpoint(model, args.output_dir, 'probing_classification_best.pth')

            if save_freq > 0 and epoch % save_freq == 0:
                _save_checkpoint(model, args.output_dir, f'probing_classification_{epoch}.pth')

        _save_checkpoint(model, args.output_dir, 'probing_classification_final.pth')
    finally:
        # Flush pending TensorBoard events even on early exit (e.g. NaN loss).
        if log_writer is not None:
            log_writer.close()

if __name__ == '__main__':
    # Script entry point: parse CLI flags, ensure the checkpoint directory
    # exists (an empty value means "don't save"), then run training.
    args = get_args_parser().parse_args()
    output_dir = args.output_dir
    if output_dir:
        Path(output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
