import wandb
import torch
import sys
from metrics import gradient_norm, hessian_trace_and_top_eig, hessian_trace_and_top_eig_rf, residual_and_top_eig_ggn, top_k_dir_sharpness, top_k_hessian_alignment, process_gradients, get_projected_gradients, process_eigenvectors
from metrics import activation_norm_dict, entropies_dict, empirical_ntk_jacobian_contraction, fnet_single, activ_skewness_dict, directional_sharpness, ntk_eigenvalues
from pyhessian import hessian
import numpy as np
from asdl.kernel import kernel_eigenvalues
import time
from pyhessian import get_params_grad, hessian_vector_product, normalization, group_product
import os
# Training
def train(train_loss, epoch, batches_seen, nets, metrics, num_classes, trainloader, optimizers, criterion, device, schedulers, log=True, max_updates=-1, activations=None, get_entropies=False, logging_steps=200, use_mse_loss=False,
          eval_inputs=None, eval_targets=None, eval_hessian_random_features=False, eval_hessian=False, top_eig_ggn=False, get_top_k_dir_sharpness=False, top_hessian_eigvals=10, ntk_eigs=0,save_ckpt_every_nth_batch=-1, save_path=None, testloader=None):
    """Train every network in ``nets`` for one pass over ``trainloader``.

    Each batch is cast to double precision, moved to ``device`` and used for
    one optimizer step per network.  Whenever the global counter
    ``batches_seen`` is a positive multiple of ``logging_steps``:
    - the running train loss and accuracy are appended to ``metrics``;
    - the optional curvature/NTK diagnostics are computed for ``nets[0]`` on
      ``eval_inputs``/``eval_targets``;
    - if ``testloader`` is given, ``test`` is run and the metrics are saved.

    Args:
        train_loss: Running loss accumulator carried in from the previous call;
            reset to 0 after every logging step.
        epoch: Epoch index.  Used for printing, for checkpoint file names and
            for the first-batch output-centering check (``epoch == 0``).
        batches_seen: Global batch counter, incremented once per batch.
        nets: Networks to train; ``optimizers[e]`` steps ``nets[e]``.
        metrics: Mapping from metric name to list.  Values are appended with
            ``+=``, so every key written must already exist (e.g. a
            ``defaultdict(list)``).
        num_classes: Number of classes, used to one-hot encode targets.
        trainloader: Iterable of ``(inputs, targets)`` batches whose targets
            are integer class indices.
        optimizers: One optimizer per network, parallel to ``nets``.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        device: Device the batches are moved to.
        schedulers: Schedulers stepped once after every batch.
        log, activations, get_entropies, get_top_k_dir_sharpness: Not used by
            this function; kept for call compatibility.
        max_updates: Stop early once ``batches_seen >= max_updates``
            (``-1`` disables the limit).
        logging_steps: Logging period, in batches.
        use_mse_loss: If True, targets are one-hot encoded (float) before
            being passed to ``criterion``.
        eval_inputs, eval_targets: Batch used for the curvature/NTK diagnostics.
        eval_hessian_random_features, eval_hessian, top_eig_ggn, ntk_eigs,
        top_hessian_eigvals: Select which diagnostics are recorded.
        save_ckpt_every_nth_batch: If > 0, save ``nets[-1].state_dict()``
            under ``save_path`` every this many batches.
        save_path: Directory for checkpoint files.
        testloader: Optional test set evaluated at every logging step.  The
            metrics state is only written to disk when ``save_path`` is set.

    Returns:
        ``(metrics, batches_seen, train_loss)``, on both the normal path and
        the early ``max_updates`` exit.

    Raises:
        ValueError: If a training loss is NaN.
    """
    print('\nEpoch: %d' % epoch)
    for net in nets:
        net.train()

    compute_every = logging_steps
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        start = time.time()
        inputs, targets = inputs.double().to(device), targets.to(device)
        # Encode once per batch.  Doing this inside the per-net loop would try
        # to one-hot an already one-hot float tensor for the second net.
        if use_mse_loss:
            targets = torch.nn.functional.one_hot(targets, num_classes=num_classes).float()
        # Class indices used for accuracy, independent of the loss target format.
        t_max = targets.max(1)[1] if use_mse_loss else targets

        for e, net in enumerate(nets):
            optimizers[e].zero_grad()
            outputs = net(inputs)

            # Centering sanity check: on the very first batch of training the
            # mean network output must be exactly zero.
            if batch_idx == 0 and epoch == 0:
                assert not torch.is_nonzero(outputs.mean()), f"outputs mean: {outputs.mean()}"

            loss = criterion(outputs, targets)
            if torch.isnan(loss):
                raise ValueError("Loss is nan, quitting training")

            loss.backward()
            optimizers[e].step()
            train_loss += (loss.item() / len(nets))
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(t_max).sum().item()

        if batches_seen % compute_every == 0 and batches_seen > 0:
            print("train_loss: {}, train_acc: {}, {}, {}".format(train_loss/compute_every, correct/total, len(targets), batches_seen))
            metrics['train_loss'] += [train_loss/compute_every]

            optimizers[0].zero_grad()
            nets[0].eval()
            _log_curvature_metrics(nets[0], criterion, metrics, eval_inputs, eval_targets, use_mse_loss,
                                   eval_hessian_random_features, eval_hessian, top_hessian_eigvals,
                                   top_eig_ggn, ntk_eigs)

            if total > 0:
                metrics['train_acc'] += [100.0 * correct/total]
            if testloader is not None:
                metrics = test(nets, metrics, num_classes, testloader, criterion, device, use_mse_loss)
                if save_path is not None:
                    print('Saving..')
                    state = {
                        'metrics': metrics,
                        'epoch': epoch,
                        'batches': batches_seen
                    }
                    torch.save(state, save_path + f'/ckpt_batches_{batches_seen}_.pth')
            # test() switches every net to eval mode, so restore all of them,
            # not just nets[0].
            for net in nets:
                net.train()
            train_loss = 0
            correct = 0
            total = 0
            end = time.time()
            print(f"Time: {(end-start):.2f}s")

        if max_updates != -1 and batches_seen >= max_updates:
            # Same 3-tuple as the normal return so callers can always unpack it.
            return metrics, batches_seen, train_loss

        if save_ckpt_every_nth_batch > 0 and batches_seen % save_ckpt_every_nth_batch == 0:
            # The last network is saved (previously via the loop variable
            # left over from the per-net loop).
            torch.save(nets[-1].state_dict(), save_path + f'/model_ckpt_epoch_{epoch}_batches_{batches_seen}_.pth')

        batches_seen += 1
        for scheduler in schedulers:
            scheduler.step()
    return metrics, batches_seen, train_loss


def _log_curvature_metrics(net, criterion, metrics, eval_inputs, eval_targets, use_mse_loss,
                           eval_hessian_random_features, eval_hessian, top_hessian_eigvals,
                           top_eig_ggn, ntk_eigs):
    """Append the curvature / NTK diagnostics selected by the flags for ``net`` to ``metrics``."""
    # NOTE(review): cuda=True is hard-coded regardless of the training device;
    # confirm these helpers are only used on GPU runs.
    if eval_hessian_random_features:
        top_eigenvalues, trace = hessian_trace_and_top_eig_rf(net, criterion, eval_inputs, eval_targets, cuda=True)
        metrics["trace_rf"] += [np.mean(trace)]
        metrics["top_eig_rf"] += [top_eigenvalues[-1]]
    if eval_hessian:
        top_eigenvalues = hessian_trace_and_top_eig(net, criterion, eval_inputs, eval_targets, top_n=top_hessian_eigvals, cuda=True)
        for i in range(top_hessian_eigvals):
            metrics[f"top_eig_{i}"] += [top_eigenvalues[i]]
    if top_eig_ggn:
        # Distinct name so the measured value never overwrites the flag.
        ggn_top_eig, residual = residual_and_top_eig_ggn(net, eval_inputs, eval_targets, use_mse_loss)
        metrics['residual'] += [residual]
        metrics['top_eig_ggn'] += [ggn_top_eig]
    if ntk_eigs > 0:
        top_ntk_eigs = ntk_eigenvalues(net, eval_inputs, eval_targets, ntk_eigs)
        for i in range(ntk_eigs):
            metrics[f"ntk_eig_{i}"] += [top_ntk_eigs[i].item()]

def test_single_batch(nets, metrics, num_classes, testloader, criterion, device, use_mse_loss):
    #from utils import progress_bar
    global best_acc
    for e, net in enumerate(nets):
        net.eval()
    
    test_loss = 0
    ens_test_loss = 0
    correct = 0
    total = 0
    correct_ens = 0
    total_ens = 0

    with torch.no_grad():
        # Fetch a single batch from the testloader
        inputs, targets = next(iter(testloader))
        inputs, targets = inputs.double().to(device), targets.to(device)
        mean_logit = torch.zeros((targets.shape[0], num_classes)).to(device)

        for e, net in enumerate(nets):
            if use_mse_loss:
                targets_one_hot = torch.nn.functional.one_hot(targets, num_classes=num_classes).float()
            else:
                targets_one_hot = targets

            outputs = net(inputs)
            loss = criterion(outputs, targets_one_hot)
            if torch.isnan(loss):
                raise ValueError("Loss is nan, quitting training")

            test_loss += loss.item() / len(nets)
            _, predicted = outputs.max(1)
            total += targets.size(0)

            if use_mse_loss:
                _, t_max = targets_one_hot.max(1)
            else:
                t_max = targets

            correct += predicted.eq(t_max).sum().item()

        ens_test_loss += criterion(mean_logit, targets).item()
        total_ens += targets.size(0)
        _, predict_ens = mean_logit.max(1)
        correct_ens += predict_ens.eq(t_max).sum().item()

    metrics['test_loss'] += [test_loss]
    metrics['test_acc'] += [100. * correct / total]

    print(f"test_loss: {metrics['test_loss'][-1]}, test_acc: {metrics['test_acc'][-1]}")
    return metrics


def test(nets, metrics, num_classes, testloader, criterion, device, use_mse_loss):
    #from utils import progress_bar
    global best_acc
    for e,net in enumerate(nets):
        net.eval()
    test_loss = 0
    ens_test_loss = 0
    correct = 0
    total = 0
    correct_ens = 0
    total_ens = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.double().to(device), targets.to(device)
            mean_logit = torch.zeros((targets.shape[0],num_classes)).to(device)
            for e, net in enumerate(nets):
                if use_mse_loss:
                    targets = torch.nn.functional.one_hot(targets, num_classes=num_classes).float()
                outputs = net(inputs)
                loss = criterion(outputs, targets)
                if torch.isnan(loss):
                    raise ValueError("Loss is nan, quitting training")
                    exit(1)
                test_loss += loss.item()/len(nets)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                if use_mse_loss:
                    _, t_max = targets.max(1)
                else:
                    t_max = targets
                correct += predicted.eq(t_max).sum().item()
            ens_test_loss += criterion(mean_logit, targets).item()
            total_ens += targets.size(0)
            _,predict_ens = mean_logit.max(1)
            correct_ens += predict_ens.eq(t_max).sum().item()
    metrics['test_loss'] += [test_loss/(batch_idx+1)]
    # metrics['ens_test_loss'] += [ens_test_loss/(batch_idx+1)]
    metrics['test_acc'] += [100.*correct/total]
    # metrics['ens_test_acc'] += [100.*correct_ens/total_ens]
    print(f"test_loss: {metrics['test_loss'][-1]}, test_acc: {metrics['test_acc'][-1]}")
    return metrics



def eval(nets, num_classes, loader, criterion, device, use_mse_loss):
    """Compute the loss and accuracy of ``nets`` over every batch of ``loader``.

    Inputs are cast to double precision, matching ``train``/``test``, so the
    same networks can be evaluated.  The networks are left in eval mode.

    Args:
        nets: Networks to evaluate.
        num_classes: Number of classes, used for one-hot targets.
        loader: Iterable of ``(inputs, targets)`` batches with integer
            class-index targets.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        device: Device the batches are moved to.
        use_mse_loss: If True, ``criterion`` receives one-hot float targets.

    Returns:
        ``(loss, accuracy)``, where ``loss`` is the mean over batches of the
        per-batch loss averaged over networks, and ``accuracy`` is the
        percentage of correct predictions pooled over all networks.

    Raises:
        ValueError: If a loss is NaN or ``loader`` yields no batches.
    """
    for net in nets:
        net.eval()
    tot_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets in loader:
            # Cast like train/test: the networks are trained on double inputs.
            inputs, targets = inputs.double().to(device), targets.to(device)
            # Encode once per batch so multi-net evaluation does not try to
            # one-hot an already one-hot tensor.
            if use_mse_loss:
                loss_targets = torch.nn.functional.one_hot(targets, num_classes=num_classes).float()
            else:
                loss_targets = targets
            for net in nets:
                outputs = net(inputs)
                loss = criterion(outputs, loss_targets)
                if torch.isnan(loss):
                    raise ValueError("Loss is nan, quitting training")
                tot_loss += loss.item()/len(nets)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
            num_batches += 1
    if num_batches == 0:
        raise ValueError("loader yielded no batches")
    return tot_loss/num_batches, 100.*correct/total
