import copy
from tqdm import tqdm
import torch
import torch.nn.functional as F
import sam
import numpy as np

def get_fs_stats(args, stats, y_hat, y, acc, loss, orig_idx, raw_idx):
    """Append per-example forward-pass statistics for one batch to ``stats``.

    ``stats`` maps ``raw_idx[j].item()`` to a list of per-slot history lists
    and is updated in place. Slots written by this function:

    * 0 -- ``loss[j]``
    * 1 -- ``acc[j]`` cast to float
    * 2 -- margin: the logit of the labelled class minus the largest logit
      among the other classes
    * 6 -- ``orig_idx[j]``
    * 7 -- norm of ``softmax(y_hat[j]) - one_hot(y[j])`` (the gradient of
      cross-entropy w.r.t. the logits); only when ``args.aus_avg_epochs`` or
      ``args.get_grad_stats`` is truthy
    * 8, 9, 10 -- label, softmax probabilities (CPU tensor) and raw logits
      (CPU tensor); only when ``args.aus_avg_epochs`` is truthy

    Slots 3-5 are never written here. A new entry gets 11, 8 or 7 slots
    depending on the mode; an existing entry with fewer slots than the mode
    needs is padded with empty lists instead of raising ``IndexError``.

    Args:
        args: object exposing ``aus_avg_epochs``, ``get_grad_stats`` and,
            when either is truthy, ``num_classes``.
        stats: dict updated in place.
        y_hat: logits, indexed as ``[example, class]``.
        y: integer class labels, one per example.
        acc: per-example correctness values (cast with ``.float()``).
        loss: per-example loss values.
        orig_idx: per-example indices stored in slot 6.
        raw_idx: per-example indices whose ``.item()`` is the ``stats`` key.
    """
    rows = torch.arange(y_hat.size(0))
    correct_logits = y_hat[rows, y]
    # Mask out the labelled class so the row-wise max is the best *wrong* class.
    masked_logits = y_hat.clone()
    masked_logits[rows, y] = float('-inf')
    margins = (correct_logits - masked_logits.max(dim=1).values).detach()

    need_grad_stats = bool(args.aus_avg_epochs or args.get_grad_stats)
    if need_grad_stats:
        acts = F.softmax(y_hat, dim=1)
        # Build the one-hot lookup on the logits' device rather than
        # unconditionally on CUDA, so CPU / non-default-GPU runs work too.
        id_mat = torch.eye(args.num_classes, device=y_hat.device)
        grad_norm = (acts - id_mat[y.long()]).norm(dim=1)

    if args.aus_avg_epochs:
        n_slots = 11
    elif args.get_grad_stats:
        n_slots = 8
    else:
        n_slots = 7

    loss_items = loss.detach()
    acc_items = acc.float()

    for j, idx in enumerate(raw_idx):
        key = idx.item()
        index_stats = stats.get(key)
        if index_stats is None:
            index_stats = [[] for _ in range(n_slots)]
        else:
            # Entries created with fewer slots (e.g. by another collector)
            # are grown so the slot writes below cannot go out of range.
            index_stats.extend([] for _ in range(n_slots - len(index_stats)))

        index_stats[0].append(loss_items[j].item())
        index_stats[1].append(acc_items[j].item())
        index_stats[2].append(margins[j].item())
        index_stats[6].append(orig_idx[j].item())
        if need_grad_stats:
            index_stats[7].append(grad_norm[j].item())
        if args.aus_avg_epochs:
            index_stats[8].append(y[j].item())
            index_stats[9].append(acts[j].detach().cpu())
            index_stats[10].append(y_hat[j].detach().cpu())
        stats[key] = index_stats

def get_sharpness_stats(args, stats, model, criterion, x, y, orig_idx, raw_idx):
    """Record batch-level sharpness statistics for every example of a batch.

    The batch ``(x, y)`` is wrapped as a single 6-tuple and passed to
    ``eval_sharpness`` with ``rho = args.sam_rho`` while the model is in
    eval mode. The three returned scalars describe the whole batch, so the
    same values are appended for each example in ``raw_idx``:

    * slot 3 -- sharpness
    * slot 4 -- error value
    * slot 5 -- gradient norm
    * slot 6 -- ``orig_idx[j]``

    The model's train/eval mode is restored to what it was on entry, even if
    the evaluation raises (previously the model was always left in train mode
    on success and in eval mode on failure).

    Args:
        args: object exposing ``batch_size`` and ``sam_rho``.
        stats: dict keyed by ``raw_idx[j].item()``, updated in place; missing
            entries are created with seven empty slots.
        model: the network to evaluate.
        criterion: loss callable forwarded to ``eval_sharpness``.
        x, y: inputs and labels of the batch.
        orig_idx: per-example indices stored in slot 6.
        raw_idx: per-example indices used as ``stats`` keys.
    """
    # Fall back to "was training" when the attribute is missing, which keeps
    # the historical behaviour of ending in train mode.
    was_training = getattr(model, 'training', True)
    model.eval()
    try:
        # eval_sharpness only reads positions 0 (x) and 2 (y) of each tuple;
        # the other positions are fillers.
        filler = range(args.batch_size)
        acc_batch = [(x, filler, y, filler, filler, raw_idx)]
        batch_sharpness, batch_err, batch_grad_norm = eval_sharpness(model, acc_batch, criterion, args.sam_rho)

        for j, idx in enumerate(raw_idx):
            key = idx.item()
            index_stats = stats.get(key, [[], [], [], [], [], [], []])
            index_stats[3].append(batch_sharpness)
            index_stats[4].append(batch_err)
            index_stats[5].append(batch_grad_norm)
            index_stats[6].append(orig_idx[j].item())
            stats[key] = index_stats
    finally:
        if was_training:
            model.train()
        else:
            model.eval()

# TODO: Refactor
def eval_sharpness(model, batches, loss_f, rho, step_size=1, n_iters=1, n_restarts=1, no_grad_norm=False, layer_name_pattern='all', batch_transfer=False, rand_init=False, verbose=False, use_tqdm=False):
    """Estimate sharpness by perturbing the weights and re-measuring the loss.

    For every batch the loss and error rate are measured at the current
    weights, then ``n_restarts`` perturbations are tried. Each restart runs
    ``n_iters`` calls to ``sam.weight_ascent_step`` (optionally after a random
    start from ``sam.random_init_on_sphere_delta_dict``), measures loss, error
    rate and flat gradient norm at the perturbed weights, and finally reloads
    the weights saved on entry. The restart with the highest error rate is
    kept; the first restart is always recorded, so a perturbation that leaves
    the error at zero still contributes its own loss and gradient norm
    (previously such a batch reported ``-loss`` as its sharpness).

    Args:
        model: network to probe; its state dict is deep-copied on entry and
            reloaded after every restart.
        batches: iterable of 6-tuples; element 0 is the input ``x`` and
            element 2 the labels ``y``, the remaining elements are ignored.
        loss_f: called as ``loss_f(model(x), y)``; the result is averaged
            with ``.mean()``.
        rho: forwarded to the ``sam`` helpers -- presumably the perturbation
            radius; confirm against ``sam``.
        step_size: step size forwarded to ``sam.weight_ascent_step``.
        n_iters: number of ascent steps per restart.
        n_restarts: number of perturbation attempts per batch.
        no_grad_norm: forwarded to ``sam.weight_ascent_step``.
        layer_name_pattern: forwarded to ``sam.weight_ascent_step``.
        batch_transfer: if true, the weights are replaced by the saved
            originals plus deltas loaded from
            ``deltas/gn_erm/batch<restart>.pth`` before measuring.
        rand_init: if true, start from a random perturbation and ascend on
            the negated loss towards a random wrong label per example.
        verbose: print one summary line per restart.
        use_tqdm: wrap ``batches`` in a ``tqdm`` progress bar.

    Returns:
        ``(sharpness, err_increase, grad_norm)``: the loss increase, the
        error-rate increase and the gradient norm at the selected
        perturbation, each averaged over the batches.

    Raises:
        ZeroDivisionError: if ``batches`` yields nothing.
    """
    orig_model_state_dict = copy.deepcopy(model.state_dict())

    # Move data to wherever the model lives instead of assuming CUDA; a model
    # without parameters keeps the historical default-CUDA behaviour.
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    n_batches, best_obj_sum, final_err_sum, final_grad_norm_sum = 0, 0, 0, 0
    for x, _, y, _, _, _ in tqdm(batches) if use_tqdm else batches:
        x, y = x.to(device), y.to(device)

        def f(model):
            return loss_f(model(x), y).mean()

        obj_orig = f(model).detach()
        logits_orig = model(x)
        err_orig = (logits_orig.max(1)[1] != y).float().mean().item()
        # Number of classes for the random-target branch, taken from the
        # logits instead of a hard-coded 10.
        n_cls = logits_orig.size(1)
        del logits_orig

        delta_dict = {param: torch.zeros_like(param) for param in model.parameters()}
        orig_param_dict = {param: param.clone() for param in model.parameters()}
        best_obj, final_err, final_grad_norm = 0, 0, 0
        has_candidate = False
        for restart in range(n_restarts):
            # random init on the sphere of radius `rho`
            if rand_init:
                delta_dict = sam.random_init_on_sphere_delta_dict(delta_dict, rho)
                for param in model.parameters():
                    param.data += delta_dict[param]
            else:
                delta_dict = {param: torch.zeros_like(param) for param in model.parameters()}

            if rand_init:
                # Draw, for every example, a uniformly random label different
                # from the true one.
                y_target = torch.clone(y)
                for i in range(len(y_target)):
                    lst_classes = list(range(n_cls))
                    lst_classes.remove(int(y[i]))
                    y_target[i] = np.random.choice(lst_classes)

            def f_opt(model):
                if not rand_init:
                    return f(model)
                # Ascending on the negated target loss pushes predictions
                # towards the random wrong labels.
                return -loss_f(model(x), y_target).mean()

            for _ in range(n_iters):
                delta_dict = sam.weight_ascent_step(model, f_opt, orig_param_dict, delta_dict, step_size, rho, layer_name_pattern, no_grad_norm=no_grad_norm, verbose=False)

            if batch_transfer:
                delta_dict_loaded = torch.load('deltas/gn_erm/batch{}.pth'.format(restart))
                delta_dict_loaded = {param: delta for param, delta in zip(model.parameters(), delta_dict_loaded.values())}  # otherwise `param` doesn't work directly as a key
                for param in model.parameters():
                    param.data = orig_param_dict[param] + delta_dict_loaded[param]

            # Loss and gradient norm at the perturbed weights.
            zero_grad(model)
            obj = f(model)
            obj.backward()
            grad_norm = get_flat_grad(model).norm()
            zero_grad(model)

            obj, grad_norm = obj.detach(), grad_norm.detach()

            err = (model(x).max(1)[1] != y).float().mean().item()

            # Keep the restart with the highest error. The first restart is
            # always recorded so the result never falls back to the zero
            # placeholders when the error stays at 0.
            if not has_candidate or err > final_err:
                best_obj, final_err, final_grad_norm = obj, err, grad_norm
                has_candidate = True
            model.load_state_dict(orig_model_state_dict)
            if verbose:
                delta_norm_total = torch.cat([delta_param.flatten() for delta_param in delta_dict.values()]).norm()
                print('[restart={}] Sharpness: obj={:.4f}, err={:.2%}, delta_norm={:.2f} (step={:.3f}, rho={}, n_iters={})'.format(
                      restart+1, obj - obj_orig, err - err_orig, delta_norm_total, step_size, rho, n_iters))

        best_obj, final_err = best_obj - obj_orig, final_err - err_orig  # since we evaluate sharpness, i.e. the difference in the loss
        best_obj_sum, final_err_sum, final_grad_norm_sum = best_obj_sum + best_obj, final_err_sum + final_err, final_grad_norm_sum + final_grad_norm
        n_batches += 1

    # The running sums become tensors as soon as one restart was recorded.
    if isinstance(best_obj_sum, torch.Tensor):
        best_obj_sum = best_obj_sum.item()
    if isinstance(final_grad_norm_sum, torch.Tensor):
        final_grad_norm_sum = final_grad_norm_sum.item()

    # TODO: Note: best_obj_sum / n_batches is what we eventually plot for 'sharpness'.
    return best_obj_sum / n_batches, final_err_sum / n_batches, final_grad_norm_sum / n_batches

def zero_grad(model):
    """Zero, in place, the gradient of every parameter that has one.

    Parameters whose ``.grad`` is ``None`` are left untouched.
    """
    for grad in (param.grad for param in model.parameters()):
        if grad is None:
            continue
        grad.zero_()

def get_flat_grad(model):
    """Concatenate all existing parameter gradients into one 1-D tensor.

    Gradients are flattened and joined in ``model.parameters()`` order;
    parameters whose ``.grad`` is ``None`` are skipped.
    """
    flat_pieces = []
    for param in model.parameters():
        grad = param.grad
        if grad is not None:
            flat_pieces.append(grad.flatten())
    return torch.cat(flat_pieces)