import os
import sys
import torch
from tqdm import tqdm
import numpy as np
import torch.nn as nn
from model import build_model
from data import build_dataset
from utils import *
from torch.amp import autocast
from torchvision.models import vit_b_16, ViT_B_16_Weights
import argparse
import pickle

# Module-wide compute device. The script hard-codes CUDA; the model and every
# batch are moved to this device by the functions below.
device = torch.device('cuda')

@torch.no_grad()
def eval_loop(model, test_loader, criterion, device, ds="target"):
    """Evaluate ``model`` on every batch of ``test_loader``.

    The model is switched to eval mode and left there; callers that resume
    training must call ``model.train()`` themselves. The forward pass runs
    under CUDA autocast, while the loss and accuracy are computed outside it.

    Args:
        model: callable invoked as ``model(x, ds=ds)``.
        test_loader: iterable of batches; element 0 is the input and
            element 1 the label of each batch.
        criterion: loss callable taking ``(output, label)``.
        device: device the inputs and labels are moved to.
        ds: value forwarded to the model's ``ds`` keyword argument.

    Returns:
        ``(average loss, average top-1 accuracy)`` as tracked by the
        ``AverageMeter`` helpers, weighted by batch size.
    """
    acc_meter = AverageMeter('Acc@1', ':6.2f')
    loss_meter = AverageMeter('Loss', ':.4e')
    model.eval()

    for batch in tqdm(test_loader):
        inputs = batch[0].to(device)
        targets = batch[1].long().to(device)
        n_samples = inputs.size(0)

        with autocast('cuda'):
            logits = model(inputs, ds=ds)

        batch_loss = criterion(logits, targets)
        batch_acc = accuracy(logits, targets, topk=(1,))[0]

        acc_meter.update(batch_acc.item(), n_samples)
        loss_meter.update(batch_loss.item(), n_samples)

    return loss_meter.avg, acc_meter.avg

def train_loop(model, train_loader, criterion, optimizer, device, scheduler, steps, norm_interval):
    """Run one pass over ``train_loader``, updating ``model`` in place.

    Args:
        model: module called as ``model(x)``; put into train mode here.
        train_loader: iterable of batches; element 0 is the input and
            element 1 the label of each batch.
        criterion: loss callable taking ``(output, label)``.
        optimizer: optimizer whose first parameter group receives the
            scheduled learning rate before every update.
        device: device the inputs and labels are moved to.
        scheduler: callable mapping the global step count to a learning rate.
        steps: global step count before this pass; incremented once per batch.
        norm_interval: accepted for interface compatibility; not used here.

    Returns:
        ``(average loss, average top-1 accuracy, updated step count)``.
    """
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    model.train()

    for batch in tqdm(train_loader):
        x, label = batch[0].to(device), batch[1].long().to(device)

        # The learning rate is a function of the global step, so advance the
        # counter before querying the schedule.
        steps += 1
        lr = scheduler(steps)
        optimizer.param_groups[0].update(lr=lr)

        optimizer.zero_grad()
        # Forward pass and loss run under CUDA autocast; backward runs outside.
        # NOTE(review): no gradient scaler is used here — confirm this is
        # intended if autocast selects float16.
        with autocast('cuda'):
            output = model(x)
            loss = criterion(output, label)

        loss.backward()
        optimizer.step()

        acc1 = accuracy(output, label, topk=(1,))
        losses.update(loss.item(), x.size(0))
        top1.update(acc1[0].item(), x.size(0))

    return losses.avg, top1.avg, steps

def get_args_parser():
    """Parse the command line and return the resulting namespace.

    Despite its name, this returns the parsed arguments, not the parser.
    The parser is built with ``add_help=False``, so no ``-h/--help`` option
    is registered.

    When ``--output_dir`` is not given, a run directory under ``./output`` is
    derived from the dataset, batch size, fine-tuning mode, rank, learning
    rate, epoch count and seed; ``nblora`` runs additionally encode
    ``lora_gamma`` and ``lora_sp``. The chosen directory is printed.
    """
    cli = argparse.ArgumentParser('ViT fine-tuning', add_help=False)

    # General training options.
    cli.add_argument('--seed', type=int, default=42)
    cli.add_argument(
        "--dataset",
        default="svhn",
        choices=["cifar100", "food101", "flowers102", "svhn"],
    )
    cli.add_argument("--num_workers", default=16, type=int)
    cli.add_argument("--epochs", default=10, type=int)
    cli.add_argument("--batch_size", default=256, type=int)
    cli.add_argument("--lr", default=5e-3, type=float)
    cli.add_argument("--weight_decay", default=0.0, type=float)

    # Fine-tuning method and adapter hyper-parameters.
    cli.add_argument(
        "--ft_mode",
        type=str,
        default="lora",
        choices=["linear", "full", "pissa", "lora", "hra", "vera", "dora", "oft", "adalora", "nblora"],
    )
    cli.add_argument("--lora_r", type=int, default=16)
    cli.add_argument("--lora_gamma", type=float, default=0.0)
    cli.add_argument(
        "--lora_sp",
        type=str,
        default="nuc",
        choices=["nuc", "fro", "inf", "delora"],
    )
    cli.add_argument("--lora_dropout", type=float, default=0.0)
    cli.add_argument("--init_r", type=int, default=12)
    cli.add_argument("--log_steps", type=int, default=50)

    cli.add_argument("--output_dir", type=str, default=None)

    args = cli.parse_args()

    if args.output_dir is None:
        # Assemble the run name piece by piece; only nblora runs carry the
        # gamma / norm-type components between the rank and the learning rate.
        name_parts = [
            f"{args.dataset}",
            f"{args.batch_size}",
            f"{args.ft_mode}",
            f"r{args.lora_r}",
        ]
        if args.ft_mode == "nblora":
            name_parts.append(f"g{args.lora_gamma}")
            name_parts.append(f"{args.lora_sp}")
        name_parts.extend(
            [f"lr{args.lr}", f"epoch{args.epochs}", f"seed{args.seed}"]
        )
        args.output_dir = "./output/" + "-".join(name_parts)

    print(f"Output_dir: {args.output_dir}")

    return args

def print_trainable_parameters(model):
    """Print parameter counts for ``model``.

    Reports the total number of parameters, the number that require
    gradients, how many of those trainable parameters have ``"lora"`` in
    their name, and the trainable share as a percentage of all parameters.

    Args:
        model: object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.
    """
    all_params = 0
    trainable_params = 0
    lora_params = 0
    for name, param in model.named_parameters():
        count = param.numel()
        all_params += count
        if not param.requires_grad:
            continue
        trainable_params += count
        if "lora" in name:
            lora_params += count

    # A model without parameters would otherwise divide by zero.
    trainable_pct = 100 * trainable_params / all_params if all_params else 0.0
    print(
        f"all params: {all_params}, trainable: {trainable_params}, lora: {lora_params} | {trainable_pct:.2f}%"
    )
            

def _delta_weight_norm_ratios(model, rank, qwn, vwn):
    """Return per-layer adapter-update norms divided by reference norms.

    For the query and value projections of every encoder layer, the adapter's
    delta weight is approximated by a rank-``rank`` randomized SVD and three
    statistics of its singular values are collected: their sum, their 2-norm
    and the largest one. The stacked statistics are divided by ``qwn`` and
    ``vwn`` respectively.

    Returns:
        ``(dqwn, dvwn)``: arrays with one row per encoder layer.
    """
    def singular_value_stats(projection):
        dw = projection.get_delta_weight(adapter='default')
        _, dws, _ = torch.svd_lowrank(dw, rank, niter=16)
        return [dws.sum().item(), dws.norm().item(), dws[0].item()]

    dqwn = []
    dvwn = []
    # Pure diagnostics: no autograd graph is needed for the SVD.
    with torch.no_grad():
        for layer in model.base_model.model.encoder.layers:
            dqwn.append(singular_value_stats(layer.self_attention.query))
            dvwn.append(singular_value_stats(layer.self_attention.value))

    dqwn = np.stack(dqwn, axis=0) / qwn
    dvwn = np.stack(dvwn, axis=0) / vwn
    return dqwn, dvwn


def train(args):
    """Fine-tune a model and periodically log source/target metrics.

    Files written under ``args.output_dir``:
        * ``log.txt``      - passed to the project ``Logger`` that replaces
                             ``sys.stdout``.
        * ``qv-<step>.csv`` - per-layer query/value norm ratios, written at
                             every logging step when the model exposes
                             ``base_model``.
        * ``last.ckpt``    - tensors of all parameters that require gradients.
        * ``dat.csv``      - one row per evaluation: step, source acc,
                             target acc, source loss, target loss.

    Args:
        args: namespace produced by ``get_args_parser``.
    """
    reset_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)
    # From here on, stdout goes through the project's Logger wrapper.
    sys.stdout = Logger(sys.stdout, f"{args.output_dir}/log.txt")

    trainLoader, valLoader, testLoader, num_classes = build_dataset(args.dataset,
                                                                    batch_size=args.batch_size,
                                                                    workers=args.num_workers)

    # NOTE(security): pickle.load can execute arbitrary code; only use a
    # norm_bound.pickle file that comes from a trusted source.
    with open('norm_bound.pickle', 'rb') as handle:
        norm_bound = pickle.load(handle)
    # Reference norms used to normalize the adapter-update statistics.
    # presumably per-layer norms of the pretrained weights — TODO confirm.
    qwn, vwn = norm_bound["query"], norm_bound["value"]

    total_steps = len(trainLoader) * args.epochs

    model = build_model(ft_mode=args.ft_mode,
                        num_classes=num_classes,
                        lora_r=args.lora_r,
                        lora_gamma=args.lora_gamma,
                        lora_sp=args.lora_sp,
                        lora_dropout=args.lora_dropout,
                        init_r=args.init_r,
                        total_steps=total_steps)

    print_trainable_parameters(model)

    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Piecewise-linear schedule: 0 -> lr over the first 10% of steps,
    # lr -> lr/20 until 90% of steps, then lr/20 -> 0 at the end.
    schedule_steps = [0, total_steps // 10, total_steps*9 // 10, total_steps]
    schedule_lrs = [0, args.lr, args.lr/20.0, 0]

    def scheduler(t):
        return np.interp([t], schedule_steps, schedule_lrs)[0]

    # Rows of (step, source_acc, target_acc, source_loss, target_loss).
    history = []

    def evaluate_and_record(at_step):
        """Evaluate on the validation and test loaders and append a history row."""
        sloss, sacc = eval_loop(model, valLoader, criterion, device,  ds="source")
        tloss, tacc = eval_loop(model, testLoader, criterion, device, ds="target")
        history.append((at_step, sacc, tacc, sloss, tloss))
        return sloss, sacc, tloss, tacc

    step = 0
    sloss, sacc, tloss, tacc = evaluate_and_record(step)
    print(f"Step: {step:2d} | loss: {sloss:.2f}/{tloss:.2f}, acc: {sacc:.1f}/{tacc:.1f}")

    # Rank used for the low-rank SVD of the adapter update. AdaLoRA uses its
    # initial rank (--init_r); all other modes use --lora_r.
    svd_rank = args.init_r if args.ft_mode == "adalora" else args.lora_r

    for _epoch in range(args.epochs):
        model.train()
        for batch in tqdm(trainLoader):
            x, label = batch[0].to(device), batch[1].long().to(device)
            step += 1
            lr = scheduler(step)
            optimizer.param_groups[0].update(lr=lr)
            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()

            if step % args.log_steps != 0:
                continue

            sloss, sacc, tloss, tacc = evaluate_and_record(step)
            # eval_loop leaves the model in eval mode; switch back.
            model.train()
            print(f"Step: {step:2d} | loss: {sloss:.2f}/{tloss:.2f}, acc: {sacc:.1f}/{tacc:.1f}, 1000lr{1000*lr:.1f}")

            # Norm diagnostics only apply to models that expose `base_model`.
            if not hasattr(model, 'base_model'):
                continue

            dqwn, dvwn = _delta_weight_norm_ratios(model, svd_rank, qwn, vwn)
            # Worst case over query/value and over all layers, per statistic.
            ratio = np.max(np.maximum(dqwn, dvwn), axis=0)
            norm_ratio = np.hstack([dqwn, dvwn])
            np.savetxt(f"{args.output_dir}/qv-{step}.csv", norm_ratio, delimiter=',')

            print(f"{step} | nuc: {ratio[0]:.3f}, fro: {ratio[1]:.3f}, inf: {ratio[2]:.3f}")

    # Checkpoint only the parameters that were trained.
    state_dict = {
        name: param.data
        for name, param in model.named_parameters()
        if param.requires_grad
    }
    torch.save(state_dict, f"{args.output_dir}/last.ckpt")

    sloss, sacc, tloss, tacc = evaluate_and_record(step)
    print(f"Step: {step:2d} | loss: {sloss:.2f}/{tloss:.2f}, acc: {sacc:.1f}/{tacc:.1f}")

    dat = np.array(history)
    np.savetxt(f"{args.output_dir}/dat.csv", dat, delimiter=',', fmt=["%d", "%.1f", "%.1f", "%.2f", "%.2f"])
    

# Script entry point: parse the command line, then run training with it.
if __name__ == '__main__':
    args = get_args_parser()
    train(args)

