import os
import json
import numpy as np
import torch
import torchvision
import transformers
import timm
import wandb

from args import get_args
from layer_replace import layer_replace
from finetune import imagenet_finetune, cifar_finetune, wikitext_finetune
from plotting import plot_ft
from models import TinyCNN

# Pin the NumPy and PyTorch RNGs so repeated runs are reproducible.
np.random.seed(0)
torch.manual_seed(0)

# Use the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def _load_original_model(args):
    """Instantiate the base model named by ``args.model`` and move it to ``device``.

    For the Hugging Face causal LMs (``qwen2.5`` and ``gpt2``), ``args.model``
    is rewritten in place to a Hub identifier after the model is built.

    Args:
        args: Parsed arguments from ``get_args()``. Reads ``model`` and
            ``pretrained``; may overwrite ``model``.

    Returns:
        The original (unmodified) model, moved to ``device``.

    Raises:
        NotImplementedError: If ``args.model`` is not one of the supported names.
    """
    # NOTE(review): torchvision's ``pretrained=`` argument is deprecated in favor of
    # ``weights=`` in newer releases; kept as-is for compatibility with the installed version.
    if args.model == 'rn18':
        orig = torchvision.models.resnet18(pretrained = args.pretrained)
    elif args.model == 'dinov2':
        orig = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
    elif args.model == 'cnn':
        orig = TinyCNN().to(device)
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
        orig.load_state_dict(torch.load('saved_models/tiny_cnn_cifar.pt', map_location = device))
    elif args.model == 'rn50':
        orig = torchvision.models.resnet50(pretrained = args.pretrained)
    elif args.model == 'qwen2.5':
        if args.pretrained:
            orig = transformers.AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct', torch_dtype = torch.bfloat16)
        else:
            config = transformers.AutoConfig.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct')
            orig = transformers.AutoModelForCausalLM.from_config(config)
        # NOTE(review): the weights/config come from the -Instruct checkpoint, but the id
        # recorded here is the base model; presumably downstream code only uses it to
        # select a tokenizer — confirm this mismatch is intentional.
        args.model = 'Qwen/Qwen2.5-0.5B'
    elif args.model == 'gpt2':
        if args.pretrained:
            orig = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
        else:
            config = transformers.AutoConfig.from_pretrained('gpt2')
            orig = transformers.AutoModelForCausalLM.from_config(config)
        args.model = 'gpt2'
    else:
        raise NotImplementedError(f'Unsupported model: {args.model!r}')
    return orig.to(device)


def _build_replaced_model(args, orig):
    """Run ``layer_replace`` on ``orig`` in the mode selected by ``args``.

    Modes:
      * ``args.baseline``: build the replaced model and move it to ``device``.
      * not ``args.reload_ft``: build the replaced model and save its
        state dict to ``saved_models/orig_{exp_name}.pt``.
      * otherwise (``args.reload_ft``): build with ``reload_ft`` and load the
        previously saved ``saved_models/orig_{exp_name}.pt`` into it.

    Args:
        args: Parsed arguments from ``get_args()``.
        orig: The original model returned by ``_load_original_model``.

    Returns:
        The model returned by ``layer_replace`` (after any reload).
    """
    replace_kwargs = dict(
        low_rank = args.low_rank,
        rank = args.rank,
        reload_cka = args.reload_cka,
        reload_cka_model = args.reload_cka_model,
    )
    orig_ckpt = f'saved_models/orig_{args.exp_name}.pt'

    if args.baseline:
        os.makedirs(os.path.join('logs', args.exp_name), exist_ok = True)
        replaced = layer_replace(args, orig, args.indices, args.replacement_type, **replace_kwargs)
        return replaced.to(device)

    if not args.reload_ft:
        replaced = layer_replace(args, orig, args.indices, args.replacement_type, **replace_kwargs)
        os.makedirs('saved_models', exist_ok = True)
        torch.save(replaced.state_dict(), orig_ckpt)
        return replaced

    replaced = layer_replace(args, orig, args.indices, args.replacement_type, reload_ft = args.reload_ft, **replace_kwargs)
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
    replaced.load_state_dict(torch.load(orig_ckpt, map_location = device))
    return replaced.to(device)


def _finetune(args, replaced_model_final, orig):
    """Finetune ``replaced_model_final`` for ``args.setting`` and report accuracies.

    ``'imagenet'`` and ``'cifar'`` select their respective finetune routines;
    any other setting falls through to ``wikitext_finetune``.

    Returns:
        ``(model, acc, orig_acc)``. For ``'cifar'`` the returned model is the
        same object that was passed in, because ``cifar_finetune`` does not
        return a model.
    """
    if args.setting == 'imagenet':
        if args.model == 'dinov2':
            # Wrap the replaced backbone in the hub's linear-classifier model and
            # move the whole wrapper (including the new head) to the target device.
            dinov2_im = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_lc')
            dinov2_im.backbone = replaced_model_final
            replaced_model_final = dinov2_im.to(device)
        replaced_model_final, acc, log_json, orig_acc = imagenet_finetune(args, replaced_model_final, orig)
        # Previously referenced an undefined name (orig_top1_acc), which raised NameError.
        plot_ft(args, log_json, orig_acc)
    elif args.setting == 'cifar':
        acc, orig_acc = cifar_finetune(args.exp_name, replaced_model_final, orig)
    else:
        replaced_model_final, acc, log_json, orig_acc = wikitext_finetune(args, replaced_model_final, orig)
    return replaced_model_final, acc, orig_acc


def theseus():
    """Run one layer-replacement + finetuning experiment end to end.

    Parses arguments, starts a wandb run named after ``args.exp_name``, loads
    the base model, replaces layers via ``layer_replace``, and finetunes
    according to ``args.setting``. It then saves the finetuned state dict to
    ``saved_models/ft_{exp_name}.pt``, writes the arguments to
    ``logs/{exp_name}/args.json``, and prints both accuracies.
    """
    args = get_args()

    # Config is captured before _load_original_model may rewrite args.model.
    wandb.init(
        project = args.exp_name,
        config = vars(args)
    )

    orig = _load_original_model(args)
    replaced_model_final = _build_replaced_model(args, orig)
    replaced_model_final, acc, orig_acc = _finetune(args, replaced_model_final, orig)

    # Output directories are not guaranteed to exist on every code path.
    os.makedirs('saved_models', exist_ok = True)
    torch.save(replaced_model_final.state_dict(), f'saved_models/ft_{args.exp_name}.pt')

    log_dir = os.path.join('logs', args.exp_name)
    os.makedirs(log_dir, exist_ok = True)
    with open(os.path.join(log_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, indent=2)

    print('Modifed Acc:', acc)
    print('Original Acc:', orig_acc)

if __name__ == '__main__':
    # Run the experiment only when this file is executed as a script, not when imported.
    theseus()