import os
import sys
import torch
from tqdm import tqdm
import numpy as np
import torch.nn as nn
from models.network_lora_orthgonal import CONFIGS, LoraOrthVisionTransformer
from RLRRDatasets.VTABDataLoader import get_data
from RLRRDatasets.VTABConfig import DATA_CONFIGS
from utils import (seed_torch, accuracy, AverageMeter, Logger, count_parameters)
from timm.scheduler import create_scheduler
from torch.cuda.amp import autocast
from timm.utils import NativeScaler
from timm.models import model_parameters
import argparse


def get_args_parser():
    """Declare the command-line options of this script and parse ``sys.argv``.

    Despite the name, the return value is the parsed ``argparse.Namespace``,
    not the parser object.

    Returns:
        argparse.Namespace: one attribute per option declared below
        (``--local-rank`` is exposed as ``local_rank``).
    """
    # (flag, add_argument keyword arguments), registered in this exact order.
    option_table = (
        ("--name", {"default": "parameter-efficient fine-tuning"}),
        ("--dataset_name", {"default": "eurosat"}),
        ("--lora_dir", {"default": "eurosat 94.94.pth"}),
        ("--model_type", {"default": "ViT-B_16"}),
        ("--dataset_dir", {"default": "/root/ARCExtend/vtab/vtab-1k/"}),
        ("--pretrained_dir", {"type": str, "default": "ViT-B_16.npz"}),  # imagenet21k_
        ("--output_dir", {"default": "output", "type": str}),  # -aug all-no-res
        ("--device", {"default": "cuda:4", "type": str}),
        ("--num_workers", {"default": 6, "type": int}),
        ("--img_size", {"default": 224, "type": int}),
        ("--local-rank", {"type": int, "default": -1}),
        ("--seed", {"type": int, "default": 42}),
        ("--gradient_accumulation_steps", {"type": int, "default": 2}),
        ("--loss_scale", {"type": float, "default": 0}),
        ("--dist_url", {"default": "env://",
                        "help": "url used to set up distributed training"}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def setup(args, frozen_list=('',)):
    """Build the LoRA-orthogonal ViT and load backbone plus LoRA weights.

    Args:
        args: Namespace providing ``model_type``, ``img_size``, ``num_classes``,
            ``drop_path``, ``pretrained_dir``, ``output_dir``, ``dataset_name``
            and ``lora_dir``.
        frozen_list: Accepted for call compatibility; this function does not
            read it.

    Returns:
        The model with the pretrained ``.npz`` weights loaded and the LoRA
        checkpoint applied on top. The model is not moved to any device here.
    """
    # Prepare model
    config = CONFIGS[args.model_type]
    model = LoraOrthVisionTransformer(
        config,
        args.img_size,
        zero_head=True,
        num_classes=args.num_classes,
        drop_path=args.drop_path,
    )
    model.load_from(np.load(args.pretrained_dir))

    # Checkpoint location: <output_dir>/<dataset_name>/<lora_dir>
    lora_path = os.path.join(args.output_dir, args.dataset_name, args.lora_dir)
    # Deserialize onto the CPU: without map_location, tensors are restored to
    # the device they were saved from, which fails when that GPU is absent on
    # the evaluating machine. The caller moves the model to its device later.
    lora_model_dict = torch.load(lora_path, map_location="cpu")
    # strict=False: the checkpoint is not required to cover every model key.
    model.load_state_dict(lora_model_dict, strict=False)
    return model


@torch.no_grad()
def valid(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and report top-1 accuracy and loss.

    The forward pass runs under ``autocast``; the cross-entropy loss is
    computed outside of it. Running averages are weighted by batch size.

    Args:
        model: Network to evaluate; switched to eval mode here.
        test_loader: Iterable yielding ``(input, label)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(top-1 accuracy average, loss average)``.
    """
    model.eval()
    acc_meter = AverageMeter('Acc@1', ':6.2f')
    loss_meter = AverageMeter('Loss', ':.4e')
    loss_fn = nn.CrossEntropyLoss()

    for images, targets in tqdm(test_loader):
        images = images.to(device)
        targets = targets.to(device)

        with autocast():
            logits = model(images)

        batch_loss = loss_fn(logits, targets)
        batch_acc = accuracy(logits, targets, topk=(1,))

        batch_size = images.size(0)
        acc_meter.update(batch_acc[0].item(), batch_size)
        loss_meter.update(batch_loss.item(), batch_size)

    print('Test :', loss_meter, acc_meter)
    return acc_meter.avg, loss_meter.avg


def train(model, train_loader, criterion, optimizer, loss_scaler, lr_scheduler, epoch, device=None):
    """Run one training epoch under autocast and report top-1 accuracy and loss.

    Args:
        model: Network to train; switched to train mode here.
        train_loader: Iterable yielding ``(data, target)`` batches.
        criterion: Loss callable applied to ``(output, target)``.
        optimizer: Optimizer whose gradients are zeroed before each update.
        loss_scaler: Callable invoked as
            ``loss_scaler(loss, optimizer, parameters=...)``; it takes the
            place of an explicit ``loss.backward()`` / ``optimizer.step()``.
        lr_scheduler: Optional scheduler; when not ``None`` its
            ``step_update`` is called once after the loop.
        epoch: Passed to the scheduler as ``num_updates``.
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter, so the function no longer depends on a
            module-level ``device`` that exists only when run as a script.

    Returns:
        Tuple ``(top-1 accuracy average, loss average)`` over the epoch.
    """
    if device is None:
        # Inputs must live where the model lives; derive it from the model
        # rather than from a global that is undefined on import.
        device = next(model.parameters()).device

    model.train()
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')

    for data, target in tqdm(train_loader):
        data, target = data.to(device), target.long().to(device)
        with autocast():
            output = model(data)
            loss = criterion(output, target)

        optimizer.zero_grad()
        # Follows SSF: the scaler drives the backward pass and optimizer update.
        loss_scaler(loss, optimizer, parameters=model_parameters(model))

        acc1 = accuracy(output, target, topk=(1,))
        losses.update(loss.item(), data.size(0))
        top1.update(acc1[0].item(), data.size(0))

    # Follows SSF: the scheduler is stepped once per call, keyed by the epoch.
    if lr_scheduler is not None:
        lr_scheduler.step_update(num_updates=epoch, metric=losses.avg)
    print('Train :', losses, top1)
    return top1.avg, losses.avg


def main(args):
    """Evaluate a saved LoRA checkpoint on the test split of one dataset.

    Copies the per-dataset hyper-parameters from ``DATA_CONFIGS`` onto
    ``args``, makes sure the dataset's output directory exists, builds the
    data loaders and the model, and runs ``valid`` on the test loader.

    Args:
        args: Parsed options; must provide ``dataset_name``, ``dataset_dir``,
            ``output_dir`` and ``device`` plus whatever ``setup`` reads. It is
            mutated in place with the dataset configuration values.
    """
    config = DATA_CONFIGS[args.dataset_name]
    args.data_path = os.path.join(args.dataset_dir, args.dataset_name)
    args.num_classes = config['num_classes']
    args.learning_rate = config['lr']
    args.min_lr = config['min_lr']
    args.drop_path = config['drop_path']
    args.warmup_lr = config['warmup_lr']
    args.weight_decay = config['weight_decay']
    args.batch_size = config['batch_size']
    args.simple_aug = config['simple_aug']

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(os.path.join(args.output_dir, args.dataset_name), exist_ok=True)

    # Only the test split is evaluated here; the train loader is not used.
    _, test_loader = get_data(data_path=args.data_path, batch_size=args.batch_size,
                              simple_aug=args.simple_aug)

    # Resolve the device from the arguments instead of relying on a
    # module-level ``device`` that is only bound when run as a script.
    device = torch.device(args.device)

    model = setup(args, ['head', 'scale', 'shift', 'Lora', '_house'])
    model.to(device)

    # valid() prints the metrics itself; its return value is not needed.
    valid(model, test_loader, device)

# Script entry point: parse the CLI options, seed via the project helper,
# then run the evaluation.
if __name__ == '__main__':
    args = get_args_parser()
    seed_torch(args.seed)
    # `device` is bound at module level here: any bare `device` reference
    # inside a function of this file resolves to this binding, so the name
    # must stay defined before main() is called.
    device = torch.device(args.device)
    main(args)
