import math
import sys
from typing import Iterable, Optional
from timm.utils.model import unwrap_model
import torch

from timm.data import Mixup
from timm.utils import accuracy, ModelEma
from lib import utils
import random
import time
import numpy as np
import os


# Maps the operation field of a block parameter name (4th dotted component,
# e.g. 'q' or 'fc1') to its column index in a per-block tuned_matrices row.
vit_operation_dict = dict(q=0, k=1, v=2, proj=3, fc1=4, fc2=5)


def sample_configs(choices, is_visual_prompt_tuning=False, is_adapter=False, is_LoRA=False, is_prefix=False):
    """Build a per-layer dimension config for the parameter-efficient modules.

    When none of the ``is_*`` flags is set, a random configuration is sampled:
    for each module type a depth ``d`` is drawn from ``choices['<name>_depth']``,
    the first ``d`` layers each get a dimension drawn from
    ``choices['<name>_dim']`` and the remaining layers get 0.

    When at least one flag is set, every enabled module gets its fixed
    ``choices['super_*_dim']`` value in every layer and disabled modules get 0.

    Args:
        choices: mapping providing ``'depth'`` (number of layers) plus the
            ``*_depth`` / ``*_dim`` candidate lists (random mode) or the
            ``super_prompt_tuning_dim`` / ``super_adapter_dim`` /
            ``super_LoRA_dim`` / ``super_prefix_dim`` values (fixed mode).
        is_visual_prompt_tuning, is_adapter, is_LoRA, is_prefix: enable the
            corresponding module. Any falsy value (``False``, ``None``, ``0``)
            means "disabled".

    Returns:
        dict with keys ``'visual_prompt_dim'``, ``'lora_dim'``,
        ``'adapter_dim'`` and ``'prefix_dim'``, each a list of per-layer dims.
    """
    config = {}
    depth = choices['depth']

    if not any((is_visual_prompt_tuning, is_adapter, is_LoRA, is_prefix)):
        # All depths are drawn before any per-layer dims; keep this order so a
        # seeded ``random`` reproduces the same configurations.
        visual_prompt_depth = random.choice(choices['visual_prompt_depth'])
        lora_depth = random.choice(choices['lora_depth'])
        adapter_depth = random.choice(choices['adapter_depth'])
        prefix_depth = random.choice(choices['prefix_depth'])
        config['visual_prompt_dim'] = [random.choice(choices['visual_prompt_dim']) for _ in range(visual_prompt_depth)] + [0] * (depth - visual_prompt_depth)
        config['lora_dim'] = [random.choice(choices['lora_dim']) for _ in range(lora_depth)] + [0] * (depth - lora_depth)
        config['adapter_dim'] = [random.choice(choices['adapter_dim']) for _ in range(adapter_depth)] + [0] * (depth - adapter_depth)
        config['prefix_dim'] = [random.choice(choices['prefix_dim']) for _ in range(prefix_depth)] + [0] * (depth - prefix_depth)
    else:
        config['visual_prompt_dim'] = [choices['super_prompt_tuning_dim'] if is_visual_prompt_tuning else 0] * depth
        config['adapter_dim'] = [choices['super_adapter_dim'] if is_adapter else 0] * depth
        config['lora_dim'] = [choices['super_LoRA_dim'] if is_LoRA else 0] * depth
        config['prefix_dim'] = [choices['super_prefix_dim'] if is_prefix else 0] * depth

    return config


def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    amp: bool = True, scaler=None):
    """Train ``model`` for one pass over ``data_loader`` and return averaged stats.

    Args:
        model: network to train; switched to train mode.
        criterion: loss module, called as ``criterion(outputs, targets)``.
        data_loader: iterable of ``(samples, targets)`` batches.
        optimizer: optimizer stepped once per batch.
        device: device every batch is moved to.
        epoch: epoch index; also used to seed Python's ``random`` module.
        loss_scaler: callable performing backward + optimizer step; it is
            called with ``parameters=`` when ``amp`` is on and with ``model=``
            otherwise (its exact contract lives outside this file).
        max_norm: gradient-clipping norm, forwarded as ``clip_grad`` to
            ``loss_scaler``; in the ``'naive'`` path a falsy value disables
            clipping.
        model_ema: optional EMA wrapper updated after every step.
        mixup_fn: optional transform applied to ``(samples, targets)``.
        amp: run forward/loss under CUDA autocast and step via ``loss_scaler``.
        scaler: when ``amp`` is False, ``'naive'`` selects a plain
            ``backward()`` / ``optimizer.step()``; any other value goes through
            ``loss_scaler`` with ``model=model``.

    Returns:
        dict mapping each logged meter name (``loss``, ``lr``) to its global
        average.

    Exits the process with ``sys.exit(1)`` on a non-finite loss.
    """
    model.train()
    criterion.train()

    # set random seed
    random.seed(epoch)

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    # This attribute is added by timm on one optimizer (adahessian); such
    # optimizers need the backward graph kept. Loop-invariant, so read once.
    is_second_order = getattr(optimizer, 'is_second_order', False)

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):

        # Clear gradients on every model parameter, including any the
        # optimizer may not own, so nothing accumulates across steps.
        for p in model.parameters():
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if amp:
            with torch.cuda.amp.autocast():
                outputs = model(samples)
                loss = criterion(outputs, targets)
        else:
            outputs = model(samples)
            loss = criterion(outputs, targets)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        if amp:
            loss_scaler(loss, optimizer, clip_grad=max_norm,
                        parameters=model.parameters(), create_graph=is_second_order)
        elif scaler != 'naive':
            loss_scaler(loss, optimizer, clip_grad=max_norm,
                        model=model, create_graph=is_second_order)
        else:
            loss.backward()
            # Honour max_norm here too; previously it was silently ignored.
            if max_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()

        # Only synchronize when CUDA exists; the unconditional call crashed on
        # CPU-only hosts.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def _save_sensitivity_result(res, suffix, dataset, k):
    """Save ``res`` to ``sensitivity_{suffix}/{dataset}/param_req_{k}.pth``, creating the folder if missing."""
    folder = 'sensitivity_{}/{}'.format(suffix, dataset)
    path = '{}/param_req_{}.pth'.format(folder, k)
    if not os.path.isdir(folder):
        print('creating folder: ' + folder)
        # exist_ok: another process may create it concurrently.
        os.makedirs(folder, exist_ok=True)
    utils.save_on_master(res, path)
    print('saving to: ' + path)


def get_sensitivity(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, device: torch.device, epoch: int,
                    amp: bool = True, dataset=None, nb_classes=None, low_rank_dim=8, threshold=5,
                    suffix=None, structured_vector=True, embed_dim=768, num_blocks=12):
    """Rank parameters by accumulated gradient and save tuning configurations.

    Runs one pass over ``data_loader`` (backward only, no optimizer step) and
    sums every parameter's gradient. Entries whose names contain ``'head'`` or
    ``'cls_token'`` are excluded from ranking and counted as fully tuned.

    For budgets of 0.01M..0.59M selected entries (largest ``|grad|`` over all
    remaining entries), each parameter is classified:

    * matrices (``.q.``, ``.k.``, ``.v.``, ``proj``, ``fc``): if the number of
      rows holding a selected entry is ``>= rows // threshold`` the matrix is
      "structured" and counted as ``rows*low_rank_dim + low_rank_dim*cols``
      parameters; otherwise its selected entries are tuned individually.
    * vectors (``norm``, ``bias``, ``pos_embed``, ``patch_embed``): selected
      entries are tuned individually (see the NOTE in the loop).

    Whenever a budget yields a total (in millions) at least as close to one of
    the targets 0.8/0.6/0.4/0.2 as the best so far, the configuration is saved
    to ``sensitivity_{suffix}/{dataset}/param_req_{target}.pth``.

    Args:
        model, criterion, data_loader, device: as in ``train_one_epoch``.
        epoch: used to seed Python's ``random`` module.
        amp: run forward/loss under CUDA autocast.
        dataset: dataset name, used in the output path.
        nb_classes: number of classes of the head (required).
        low_rank_dim: rank used to count parameters of structured matrices.
        threshold: divisor of the row count in the structured test.
        suffix: output folder suffix; ``'_nostruvec'`` is appended when
            ``structured_vector`` is False.
        structured_vector: see the NOTE in the vector branch.
        embed_dim: width used to count head and class-token parameters.
        num_blocks: number of rows in ``tuned_matrices``.

    Returns:
        dict mapping each logged meter name to its global average.

    Exits the process with ``sys.exit(1)`` on a non-finite loss; raises
    ``NotImplementedError`` for a parameter matching no keyword list.
    """
    model.train()
    criterion.train()

    # set random seed
    random.seed(epoch)

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Getting sensitivity'
    print_freq = 10

    # One zero accumulator per state_dict entry (same shape/dtype/device).
    grad_dict = {key: torch.zeros_like(value) for key, value in model.state_dict().items()}

    # Start from empty gradients so stale .grad values are not counted.
    model.zero_grad(set_to_none=True)

    # Accumulate gradients over one pass of the loader; can be reduced to a
    # customised number of samples.
    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if amp:
            with torch.cuda.amp.autocast():
                outputs = model(samples)
                loss = criterion(outputs, targets)
        else:
            outputs = model(samples)
            loss = criterion(outputs, targets)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        loss.backward()
        for name, param in model.named_parameters():
            # Frozen / unused parameters have no gradient.
            if param.grad is not None:
                grad_dict[name] += param.grad
        # Reset so the next batch contributes only its own gradient; otherwise
        # .grad keeps accumulating and earlier batches are re-added each step.
        model.zero_grad(set_to_none=True)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=0.)

    # Pre-defined keywords for calculating sensitivity
    grad_skip_kwd_list = ['head', 'cls_token']  # Fully tune head and class token
    grad_matrix_kwd_list = ['.q.', '.k.', '.v.', 'proj', 'fc']  # Maybe structurally tune the matrices: q, k, v, proj, fc1, and fc2
    grad_vector_kwd_list = ['norm', 'bias', 'pos_embed', 'patch_embed']  # Only unstructurally tune the vectors

    # Budget of the fully tuned head and class token; the formula assumes a
    # linear head of input width embed_dim and an embed_dim-sized class token.
    head_cls_params = embed_dim * nb_classes + nb_classes + embed_dim

    grad_shapes = {}
    grad_shapes_int = {}
    for key, grad in grad_dict.items():
        if not any(kwd in key for kwd in grad_skip_kwd_list):
            grad_shapes[key] = grad.shape
            grad_shapes_int[key] = grad.numel()  # also valid for 0-d tensors

    # Rank the total sensitivity
    large_tensor = torch.cat([grad_dict[key].flatten() for key in grad_shapes])
    abs_large_tensor = torch.abs(large_tensor)  # loop-invariant
    split_sizes = list(grad_shapes_int.values())

    # pre-defined target parameter constraints (millions) -> best total so far
    param_num_dict = {0.8: 0, 0.6: 0, 0.4: 0, 0.2: 0}
    if not structured_vector:
        # format() keeps working (and spells 'None') when suffix is None.
        suffix = '{}_nostruvec'.format(suffix)

    # We get the configurations that are close to the target parameter constraints
    # Can be simplified if the exact numbers of trainable parameters are not required
    for percent in range(1, 60):

        param_num = percent / 100
        # Clamp so topk never asks for more entries than exist.
        num_selected = min(math.ceil(param_num * 1e6), large_tensor.numel())
        _, indexes = abs_large_tensor.topk(num_selected)

        # 1 at selected entries, 0 elsewhere; same device as the gradients.
        selected = torch.zeros_like(large_tensor)
        selected[indexes] = 1.
        selected_list = selected.split(split_sizes)

        structured_param_num = 0
        structured_names = []
        tuned_vectors = []

        unstructured_param_num = 0
        unstructured_name_shapes = {}
        unstructured_name_shapes_int = {}
        unstructured_grad_mask = {}

        for key, key_mask in zip(grad_shapes, selected_list):
            shape = grad_shapes[key]
            key_mask = key_mask.view(shape)

            if any(kwd in key for kwd in grad_vector_kwd_list):

                grad_sum = key_mask.sum()

                # NOTE(review): shape is always a torch.Size, so this type test
                # is always False and vectors are never tuned structurally,
                # whatever structured_vector says. The original comment reads
                # "never structurally tune patch_embed and pos_embed", so the
                # intent was probably a name check; confirm before changing it,
                # since doing so changes the saved configurations.
                if (structured_vector and type(shape) is not torch.Size
                        and grad_sum >= shape[0] // threshold):
                    structured_param_num += shape[0]
                    tuned_vectors.append(key)

                # Unstructured
                else:
                    unstructured_param_num += grad_sum
                    unstructured_name_shapes[key] = key_mask.shape
                    unstructured_name_shapes_int[key] = grad_shapes_int[key]
                    unstructured_grad_mask[key] = key_mask

            elif any(kwd in key for kwd in grad_matrix_kwd_list):

                # Number of rows holding at least one selected entry.
                grad_threshold = (key_mask.sum(1) != 0).sum()
                grad_sum = key_mask.sum()

                # Structured
                if grad_threshold >= shape[0] // threshold:
                    structured_param_num += shape[0] * low_rank_dim + low_rank_dim * shape[1]
                    structured_names.append(key)

                # Unstructured
                else:
                    unstructured_param_num += grad_sum
                    unstructured_name_shapes[key] = key_mask.shape
                    unstructured_name_shapes_int[key] = grad_shapes_int[key]
                    unstructured_grad_mask[key] = key_mask

            else:
                raise NotImplementedError('No sensitivity rule for parameter {!r}'.format(key))

        tuned_matrices = [[0] * len(vit_operation_dict) for _ in range(num_blocks)]

        for name in structured_names:
            attr = name.split('.')
            # Only 5-part names are mapped: field 1 is the block index and
            # field 3 the operation key of vit_operation_dict.
            if len(attr) != 5:
                continue
            tuned_matrices[int(attr[1])][vit_operation_dict[attr[3]]] = 1

        total_params = (unstructured_param_num + structured_param_num + head_cls_params) / 1e6

        for k, best in param_num_dict.items():
            # Save the configurations when closer to the target parameter
            if abs(total_params - k) <= abs(best - k):
                param_num_dict[k] = total_params

                # Positions into the concatenation of the flattened
                # unstructured masks, in key order.
                if unstructured_grad_mask:
                    unstructured_indexes = torch.nonzero(torch.cat(
                        [mask.flatten() for mask in unstructured_grad_mask.values()])).squeeze(-1)
                else:
                    unstructured_indexes = torch.zeros(0, dtype=torch.long, device=large_tensor.device)

                res = {'unstructured_name_shapes': unstructured_name_shapes,
                       'unstructured_name_shapes_int': unstructured_name_shapes_int,
                       'params': total_params,
                       'unstructured_params': unstructured_param_num,
                       'structured_params': structured_param_num,
                       'unstructured_indexes': unstructured_indexes,
                       'tuned_matrices': tuned_matrices,
                       'tuned_vectors': tuned_vectors
                       }
                _save_sensitivity_result(res, suffix, dataset, k)

    print('final params: ', param_num_dict)
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device, amp=True):
    """Evaluate ``model`` on ``data_loader`` and return averaged metrics.

    Per batch, logs cross-entropy loss and top-1 / top-5 accuracy. When the
    model outputs fewer than 5 classes (``output.shape[1] < 5``), top-5 is
    undefined and is logged as 0.

    Args:
        data_loader: iterable of ``(images, target)`` batches.
        model: network to evaluate; switched to eval mode.
        device: device every batch is moved to.
        amp: run forward/loss under CUDA autocast.

    Returns:
        dict mapping each logged meter name (``loss``, ``acc1``, ``acc5``) to
        its global average.
    """
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'
    # switch to evaluation mode
    model.eval()

    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        # compute output
        if amp:
            with torch.cuda.amp.autocast():
                output = model(images)
                loss = criterion(output, target)
        else:
            output = model(images)
            loss = criterion(output, target)

        # Decide from the class count instead of catching RuntimeError: the
        # old try/except also swallowed unrelated errors, and newer timm
        # versions clamp k instead of raising, which made acc5 differ by version.
        if output.shape[1] >= 5:
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            acc5 = acc5.item()
        else:
            # class_num < 5: top-5 is undefined
            acc1, = accuracy(output, target, topk=(1,))
            acc5 = 0.

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5, n=batch_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}