import os
import random
import numpy as np
import pandas as pd
from argparse import ArgumentParser
from prettytable import PrettyTable
import tqdm
import wandb

from copy import deepcopy

## torch
import torch
import torch.nn as nn
from torch import optim
from torch.cuda.amp import autocast

## torchvision
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

## file-based imports
import utils.schedulers as schedulers
import utils.pruning_utils as pruning_utils
from utils.harness_utils import *
from utils.metric_utils import LabelSmoothingLoss, compute_hessian
from utils.pruning_utils import dst_pruner, prune_mag, make_dense
from utils.dataset import CIFARLoader, imagenet, imagenet_pytorch
from utils.conv_type import ConvMask, Conv1dMask, LinearMask, STRConv

## fastargs
from fastargs import get_current_config
from fastargs.decorators import param

## matplotlib
import matplotlib.pyplot as plt

# ImageNet per-channel normalisation statistics, rescaled from [0, 1] to the 0-255 pixel range.
IMAGENET_MEAN = 255 * torch.tensor([0.485, 0.456, 0.406])
IMAGENET_STD = 255 * torch.tensor([0.229, 0.224, 0.225])
# Ratio of the centre-crop size to the resize size (224 / 256 = 0.875).
DEFAULT_CROP_RATIO = 224 / 256



# Pull the global fastargs config and fill it from command-line arguments.
config = get_current_config()
parser = ArgumentParser()
config.augment_argparse(parser)       # register every config parameter as a CLI option
config.collect_argparse_args(parser)  # parse sys.argv and store the values in the config

# Single compute device shared by the data loaders and the model code below.
if torch.cuda.is_available():
    this_device = torch.device('cuda')
else:
    this_device = torch.device('cpu')


# Loss and train/test loaders, selected from the `dataset.*` config section.
num_classes = config["dataset.num_classes"]
if config["dataset.criterion"] == 'LabelSmoothingLoss':
    criterion = LabelSmoothingLoss(num_classes)
else:
    criterion = nn.CrossEntropyLoss()

if "CIFAR" in config["dataset.dataset_name"]:
    loaders = CIFARLoader(distributed=False)
    train_loader, test_loader = loaders.train_loader, loaders.test_loader
elif config["dataset.use_ffcv"]:
    # Non-CIFAR dataset with use_ffcv enabled: the `imagenet` loader from utils.dataset.
    dataset = imagenet(this_device=this_device, distributed=False)
    train_loader, test_loader = dataset.train_loader, dataset.test_loader
else:
    # Non-CIFAR dataset without FFCV: `imagenet_pytorch`, run as a single process (rank 0 of 1).
    dataset = imagenet_pytorch(this_device=this_device, distributed=False, rank=0, world_size=1)
    train_loader, test_loader = dataset.train_loader, dataset.test_loader



@param('experiment_params.base_dir')
@param('experiment_params.analysis_dir')
def gen_target_dir(base_dir, analysis_dir):
    """Return the checkpoint directory ``<base_dir>/<analysis_dir>/checkpoints``.

    Both arguments are filled in by the fastargs ``@param`` decorators from
    ``experiment_params.base_dir`` and ``experiment_params.analysis_dir``
    when the caller does not pass them.
    """
    return os.path.join(base_dir, analysis_dir, 'checkpoints')

def get_model_name(epoch, level=0):
    """Return the checkpoint file name for a given epoch identifier.

    * ``'init'``  -> ``'model_init.pt'`` (``level`` is ignored)
    * ``'final'`` -> ``'model_level_<level>.pt'``
    * otherwise   -> ``'model_<level>_<epoch>.pt'``
    """
    if epoch == 'init':
        return 'model_init.pt'
    if epoch == 'final':
        return f'model_level_{level!s}.pt'
    return f'model_{level!s}_{epoch!s}.pt'


def accumulate_grads(model, num_batches=51, loader=None):
    """Accumulate cross-entropy gradients of ``model`` over the first training batches.

    Gradients are zeroed once, then ``loss.backward()`` is called for each
    batch without zeroing in between, so every parameter's ``.grad`` ends up
    holding the sum over the processed batches. The forward pass runs under
    bfloat16 autocast.

    Args:
        model: Network to run; its gradients are reset and then accumulated in place.
        num_batches: Number of batches to process. The default of 51 matches
            the previous hard-coded cut-off (stop after batch index 50).
        loader: Iterable of ``(inputs, targets)`` batches. Defaults to the
            module-level ``train_loader``.

    Returns:
        The same ``model`` object, now carrying accumulated gradients.
    """
    if loader is None:
        loader = train_loader

    model.zero_grad()
    if num_batches <= 0:
        return model

    # Plain cross-entropy, independent of the config-selected module-level
    # `criterion`. NOTE(review): confirm label smoothing is meant to be skipped here.
    loss_fn = nn.CrossEntropyLoss()
    # Non-FFCV loaders yield batches that must be moved to the device explicitly;
    # presumably the FFCV pipeline already places them there. Same check as the loader setup.
    move_to_device = not config["dataset.use_ffcv"]

    for idx, (inputs, targets) in enumerate(tqdm.tqdm(loader)):
        if move_to_device:
            inputs, targets = inputs.to(this_device), targets.to(this_device)

        with autocast(dtype=torch.bfloat16, enabled=True):
            outputs = model(inputs.contiguous())
            loss = loss_fn(outputs, targets)
        loss.backward()

        # Stop after the last requested batch, before the loader fetches another one.
        if idx + 1 >= num_batches:
            break
    return model

def convert_ddp_to_standard_dict(ddp_state_dict, prefix="module."):
    """Strip a DistributedDataParallel-style key prefix from a state dict.

    Only a *leading* ``prefix`` is removed. The previous ``str.replace``
    stripped every occurrence of ``"module."``, which corrupted keys of
    layers whose names merely end in "module" (e.g.
    ``"block.submodule.bias"`` became ``"block.subbias"``). Keys without the
    prefix are kept unchanged, and key order is preserved.

    Args:
        ddp_state_dict: Mapping of parameter/buffer names to values.
        prefix: Leading key prefix to remove (default: DDP's ``"module."``).

    Returns:
        A new ``dict`` with the prefix removed from matching keys.
    """
    cut = len(prefix)
    return {
        (key[cut:] if key.startswith(prefix) else key): value
        for key, value in ddp_state_dict.items()
    }

def load_model(epoch):
    """Build a fresh model, load checkpoint ``epoch`` into it and accumulate gradients.

    The architecture comes from ``pruning_utils.PruningStuff().model`` and is
    moved to ``this_device``. The weights are read from
    ``gen_target_dir()/get_model_name(epoch)``, with any DDP ``module.`` key
    prefix stripped. The model is then passed through ``accumulate_grads``.

    Args:
        epoch: Checkpoint identifier understood by ``get_model_name``
            ('init', 'final', or an epoch number).

    Returns:
        The loaded model with accumulated gradients.
    """
    prune_harness = pruning_utils.PruningStuff()
    model = prune_harness.model.to(this_device)
    model_dir = os.path.join(gen_target_dir(), get_model_name(epoch))
    print(f'Loading model from {model_dir}')
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only host,
    # and loads tensors straight onto the target device instead of their saved one.
    # torch.load unpickles the file: only point it at trusted checkpoints.
    state_dict = torch.load(model_dir, map_location=this_device)
    model.load_state_dict(convert_ddp_to_standard_dict(state_dict))
    return accumulate_grads(model)
    
# Default output directory for the eigenvalue file (previously hard-coded inside get_eig_list).
DEFAULT_EIG_RESULTS_DIR = '/home/c01adga/CISPA-az6/thunderbird-2024/thunderbird-results/results'


@param('experiment_params.analysis_dir')
def get_eig_list(analysis_dir, results_dir=DEFAULT_EIG_RESULTS_DIR, epochs=None):
    """Compute Hessian eigenvalues for a series of checkpoints and save them.

    Each checkpoint is loaded with ``load_model``. ``compute_hessian(model,
    train_loader)`` is called on it, and its results are collected in
    checkpoint order. The list is saved with ``torch.save`` to
    ``<results_dir>/eigs_<analysis_dir>.pt``.

    Args:
        analysis_dir: Injected by fastargs from ``experiment_params.analysis_dir``;
            used to name the output file.
        results_dir: Directory for the output file; created if missing.
        epochs: Checkpoint identifiers to process. Defaults to
            ``['init', 9, 19, ..., 99]``.

    Returns:
        The list of ``compute_hessian`` results (also written to disk).
    """
    if epochs is None:
        epochs = ['init'] + list(range(9, 100, 10))

    eig_list = []
    for epoch in epochs:
        model = load_model(epoch)
        eigenvals = compute_hessian(model, train_loader)
        print('Eigenvals of this are: ', eigenvals)
        eig_list.append(eigenvals)
        # Release this model before the next checkpoint is built, so that (if
        # nothing else holds a reference) two models are not resident at once.
        del model

    os.makedirs(results_dir, exist_ok=True)
    torch.save(eig_list, os.path.join(results_dir, 'eigs_{}.pt'.format(analysis_dir)))
    return eig_list

# Run the eigenvalue analysis only when executed as a script, not on import.
if __name__ == "__main__":
    get_eig_list()
