import shutil
import os
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import numpy as np
from torchvision import datasets
from torch.utils.data import DataLoader
from GrowthNew import GrowthModel, GrowthBlock
from regression_utils_add import return_layers, return_layers_data
from torch.multiprocessing import Pool, Process, set_start_method
import matplotlib

# Select the non-interactive "Agg" backend; this runs before pyplot is imported below.
matplotlib.use("Agg")
import matplotlib.pyplot as plt

try:
    # Request the "spawn" start method for torch.multiprocessing.
    set_start_method("spawn")
except RuntimeError:
    # Deliberately ignored — presumably the start method was already set
    # elsewhere in the process; TODO confirm.
    pass
from joblib import Parallel, delayed
import time
import datetime
import gc


def convert(n):
    """Format a duration of ``n`` seconds as a ``datetime.timedelta`` string.

    For example ``3661`` becomes ``"1:01:01"``.
    """
    duration = datetime.timedelta(seconds=n)
    return str(duration)


# Module-level batch size constant; it is not referenced in the visible part of this file.
BATCH_SIZE = 128


def get_all_linear_layers(model, typ="dict"):
    """
    Collect every leaf ``nn.Linear`` found below ``model``.

    The sub-modules are visited breadth-first, starting from the direct
    children of ``model``. ``model`` itself is never inspected, so passing a
    bare ``nn.Linear`` yields an empty result.

    Input:
        model -> module whose descendants are searched
        typ   -> "dict" (default) returns ``{"0": layer, "1": layer, ...}``
                 keyed by discovery order; any other value returns a list in
                 the same order
    Output: the Linear layers, as a dict or a list depending on ``typ``
    """
    # The list doubles as the BFS queue: it grows while it is being walked,
    # so an explicit index is used instead of iterating over it directly.
    queue = list(model.children())
    found = []
    cursor = 0
    while cursor < len(queue):
        module = queue[cursor]
        cursor += 1
        sub_modules = list(module.children())
        if sub_modules:
            # Container: schedule its children, never collect the container.
            queue.extend(sub_modules)
        elif isinstance(module, nn.Linear):
            found.append(module)

    if typ == "dict":
        return {str(index): layer for index, layer in enumerate(found)}
    return found


def split_matrix(gradient, weight=None):
    """Build the splitting matrix for one gradient tensor.

    The gradient is flattened into a row and multiplied (with broadcasting)
    by a column of ones of the same length, so every row of the square
    result is a copy of the flattened gradient. ``weight`` is accepted but
    not used.
    """
    grad_row = gradient.view(1, -1).cpu()
    # Stand-in for the second-order derivative term: a tensor of ones.
    ones_column = torch.ones(gradient.shape).view(-1, 1)
    return grad_row * ones_column


def calculate_min_eig(gradient):
    """Return the smallest eigenvalue (real part) of a neuron's splitting matrix.

    Input:
        gradient -> gradient tensor of a single neuron
    Output:
        a 0-dim float64 tensor holding the minimum real part of the
        eigenvalues, or the plain integer ``0`` when the computation fails
        (best-effort fallback).
    """
    try:
        splitting = split_matrix(gradient)
        eigenvalues, _ = torch.linalg.eig(splitting.cpu())
        # torch.linalg.eig always returns complex values; take the real part
        # explicitly instead of relying on a complex -> real cast.
        min_eig = torch.min(eigenvalues.real.double())
        del eigenvalues, splitting
        gc.collect()
        torch.cuda.empty_cache()
    except Exception:
        # Best effort: a failed decomposition is reported as 0. Unlike a bare
        # ``except:``, KeyboardInterrupt / SystemExit still propagate.
        min_eig = 0
    return min_eig


def growth_wrapper(model):
    """Wrap the four linear sub-modules of every block in a GrowthBlock.

    For each entry of ``model.blocks`` the ``attn.qkv``, ``attn.proj``,
    ``mlp.fc1`` and ``mlp.fc2`` attributes are replaced in place. The same
    model object is returned.
    """
    for block in model.blocks:
        attention = block.attn
        attention.qkv = GrowthBlock(attention.qkv)
        attention.proj = GrowthBlock(attention.proj)
        feed_forward = block.mlp
        feed_forward.fc1 = GrowthBlock(feed_forward.fc1)
        feed_forward.fc2 = GrowthBlock(feed_forward.fc2)
    return model


def return_arc_array(a_array, i_num, sel_layer, positional):
    """
    Rebuild the architecture array after one leaf has been scaled.

    Every entry of the (nested) array is either ``0`` (a leaf) or a pair
    ``[child_array, tag]``. Leaves are numbered depth-first starting at
    ``i_num``; the leaf whose number equals ``sel_layer`` is replaced by
    ``[[0, 0], positional]`` in the rebuilt array.

    Returns a tuple ``(numbered_array, next_free_number, rebuilt_array)``.
    """

    def _walk(nodes, next_id):
        numbered = []
        rebuilt = []
        for node in nodes:
            if node == 0:
                # Leaf: record its number and expand it if it is the target.
                numbered.append(next_id)
                is_target = next_id == sel_layer
                rebuilt.append([[0, 0], positional] if is_target else 0)
                next_id += 1
            else:
                # Nested entry: recurse into the child array, keep the tag.
                child_numbered, next_id, child_rebuilt = _walk(node[0], next_id)
                numbered.append([child_numbered, node[1]])
                rebuilt.append([child_rebuilt, node[1]])
        return numbered, next_id, rebuilt

    return _walk(a_array, i_num)


def get_all_linear_layers_transformer(model):
    """Collect the linear layers of every transformer block with their location.

    For each block, the sub-modules ``attn.qkv``, ``attn.proj``, ``mlp.fc1``
    and ``mlp.fc2`` are searched in that order (slot codes 0, 1, 2, 3).

    Returns two parallel lists: the layers, and for each layer the triple
    ``[block_index, slot_code, position_within_slot]``.
    """
    layers = []
    attributes = []
    for block_idx, block in enumerate(model.blocks):
        slots = (block.attn.qkv, block.attn.proj, block.mlp.fc1, block.mlp.fc2)
        for slot_code, growth_block in enumerate(slots):
            found = get_all_linear_layers(growth_block, typ="list")
            layers.extend(found)
            attributes.extend(
                [block_idx, slot_code, position] for position in range(len(found))
            )
    return layers, attributes


def calc_all_eigs(layers, n_jobs=24):
    """
    Calculate the minimum eigenvalue of every neuron of every layer.

    Inputs:
    layers -> iterable of layers exposing ``weight.grad``
    n_jobs -> number of joblib workers used per layer (default 24)
    Output:
    a list with one entry per layer, in input order; each entry is the list
    of per-neuron results of ``calculate_min_eig`` (one per row of the weight
    gradient). A layer without a gradient contributes an empty list so that
    list positions keep matching layer positions.
    """
    eigs = []
    for layer in layers:
        grad = layer.weight.grad
        if grad is None:
            # No backward pass reached this layer (e.g. frozen or freshly
            # created): nothing to rank, but keep the slot for index alignment.
            eigs.append([])
            continue
        # Work on a detached CPU copy so the workers never touch the live grad.
        gradient = grad.cpu().detach().clone()
        neuron_count = gradient.shape[0]
        min_eigs = Parallel(n_jobs=n_jobs)(
            delayed(calculate_min_eig)(gradient[row]) for row in range(neuron_count)
        )
        eigs.append(min_eigs)
        del gradient
        gc.collect()
        torch.cuda.empty_cache()

    return eigs


def ret_flattened(eig):
    """Flatten per-layer eigenvalue lists into one list plus a parallel list of layer indices."""
    values = []
    owners = []
    for layer_idx, layer_eigs in enumerate(eig):
        values.extend(layer_eigs)
        # One owner entry per eigenvalue contributed by this layer.
        owners.extend([layer_idx] * len(layer_eigs))
    return values, owners


def layer_negative(eigs):
    """Return, per layer, the indexes of neurons with a negative eigenvalue.

    Input:
        eigs -> list with one sequence of eigenvalues per layer
    Output:
        dict mapping ``str(layer_index)`` to a numpy array of neuron indexes,
        ordered from the most negative eigenvalue upwards. Layers without any
        negative eigenvalue have no entry.
    """
    neg_index_dic = {}
    neg_eig_dic = {}
    for layer, layer_eigs in enumerate(eigs):
        key = str(layer)
        for j, e in enumerate(layer_eigs):
            if e < 0:
                neg_index_dic.setdefault(key, []).append(j)
                neg_eig_dic.setdefault(key, []).append(e)

    for key, neg_eigs in neg_eig_dic.items():
        # Ascending sort on the eigenvalues: most negative first.
        order = np.argsort(np.array(neg_eigs))
        neg_index_dic[key] = np.array(neg_index_dic[key])[order]
    return neg_index_dic


def find_split_layers_param_quota(model, epoch, param_budget, layer_threshold=60):
    """
    Find the neurons to split.

    Neurons are visited from the most negative eigenvalue upwards and counted
    per layer. A layer becomes selected once its count reaches
    ``layer_threshold``, provided its group (attention or MLP) still has
    budget and has not yet selected its maximum number of layers.

    Inputs:
    model -> given model
    epoch -> current epoch (accepted for interface compatibility; unused)
    param_budget -> number of parameters to add this scaling iteration; it is
                    divided equally between the attention and the MLP layers
    layer_threshold -> minimum number of neurons that must be selected in a
                       layer to allow scaling it
    Outputs:
    sel_layer_data -> one entry per selected layer:
                      [str(layer_num), layer, layer_attributes, number of selected neurons]
    neg_index_dic -> dictionary containing the indexes of the neurons with
                     negative eigenvalues for each layer
    """
    layers, layer_attrs = get_all_linear_layers_transformer(model)
    # Calculate eigs
    eigs = calc_all_eigs(layers)
    # Flatten eigs
    flat_eig, flat_eig_layer = ret_flattened(eigs)
    # Create neg_index_dic
    neg_index_dic = layer_negative(eigs)
    # Positions of all eigenvalues, most negative first.
    rank = torch.Tensor(flat_eig).argsort().tolist()

    # Slot codes 0/1 are the attention projections, 2/3 the MLP layers.
    group_of_slot = {0: "attn", 1: "attn", 2: "mlp", 3: "mlp"}
    half_budget = int(param_budget / 2)
    budget_left = {"attn": half_budget, "mlp": half_budget}
    selected_per_group = {"attn": 0, "mlp": 0}
    max_layers_per_group = 5

    neuron_counts = {}
    sel_layer_nums = []
    for pos in rank:
        # Only strictly negative eigenvalues are split candidates. This keeps
        # the counts consistent with neg_index_dic, which records ``e < 0``
        # only; counting zeros could select a layer that has no entry there.
        if flat_eig[pos] >= 0:
            continue
        layer_num = flat_eig_layer[pos]
        count = neuron_counts.get(layer_num, 0) + 1
        neuron_counts[layer_num] = count

        if count >= layer_threshold:
            group = group_of_slot.get(layer_attrs[layer_num][1])
            if group is not None:
                # One neuron costs in_features weights + 1 bias; the factor 2
                # matches create_new_layer, which adds a positive and a
                # negative copy for every selected neuron.
                cost_per_neuron = (layers[layer_num].in_features + 1) * 2
                already_selected = layer_num in sel_layer_nums
                has_budget = budget_left[group] >= 0
                if (
                    not already_selected
                    and selected_per_group[group] < max_layers_per_group
                    and has_budget
                ):
                    # First time the layer qualifies: pay for all neurons
                    # counted so far.
                    sel_layer_nums.append(layer_num)
                    budget_left[group] -= cost_per_neuron * count
                    selected_per_group[group] += 1
                elif already_selected and has_budget:
                    budget_left[group] -= cost_per_neuron
                else:
                    # The neuron cannot be afforded / the layer cannot be
                    # selected: undo the count.
                    neuron_counts[layer_num] = count - 1

        if budget_left["mlp"] <= 0 and budget_left["attn"] <= 0:
            break

    sel_layer_data = [
        [str(layer_num), layers[layer_num], layer_attrs[layer_num], neuron_counts[layer_num]]
        for layer_num in sel_layer_nums
    ]
    del eigs, flat_eig, flat_eig_layer
    torch.cuda.empty_cache()
    return sel_layer_data, neg_index_dic


def ret_growth_model(block, num):
    """Return the growth block of ``block`` addressed by slot code ``num``.

    0 -> attn.qkv, 1 -> attn.proj, 2 -> mlp.fc1, anything else -> mlp.fc2.
    """
    slot_paths = ((0, "attn", "qkv"), (1, "attn", "proj"), (2, "mlp", "fc1"))
    for code, parent_name, child_name in slot_paths:
        if num == code:
            return getattr(getattr(block, parent_name), child_name)
    return block.mlp.fc2


def assign_model(model, block, num, gb):
    """Store growth block ``gb`` in ``model.blocks[block]`` at slot code ``num``.

    0 -> attn.qkv, 1 -> attn.proj, 2 -> mlp.fc1, anything else -> mlp.fc2.
    The (mutated) model is returned.
    """
    parent_name, child_name = "mlp", "fc2"
    for code, parent, child in ((0, "attn", "qkv"), (1, "attn", "proj"), (2, "mlp", "fc1")):
        if num == code:
            parent_name, child_name = parent, child
            break
    setattr(getattr(model.blocks[block], parent_name), child_name, gb)
    return model


def set_all_trainable(model):
    """Turn on gradient tracking for every parameter of ``model``."""
    for parameter in model.parameters():
        parameter.requires_grad_(True)


def set_all_not_trainable(model):
    """Turn off gradient tracking for every parameter of ``model``."""
    for parameter in model.parameters():
        parameter.requires_grad_(False)


def print_all_requires_grad(model):
    """Print, for each linear layer of the transformer, the ``requires_grad`` flag of each of its parameters."""
    linear_layers, _ = get_all_linear_layers_transformer(model)
    flags = []
    for linear in linear_layers:
        flags.append([param.requires_grad for param in linear.parameters()])
    print(flags)


def split_nodewise(
    model, optimizer, param_budget, epoch, warmup=0, layer_threshold=60,act_on=True,
):
    '''
    Scale the model by adding a new linear layer next to every layer selected
    for splitting, and print timing statistics.

    Inputs:
    model -> Model to scale (its growth blocks are replaced in place)
    optimizer -> Current model optimizer; the new layers' parameters are
                 appended to its parameter groups by create_new_layer
    param_budget -> Parameter budget forwarded to find_split_layers_param_quota
    epoch -> Current epoch (only forwarded to find_split_layers_param_quota)
    warmup -> splitting warmup (set to 0); any non-zero value freezes every
              parameter and prints the model before neurons are selected
    layer_threshold -> Threshold for selecting neurons in find_split_layers_param_quota
    act_on -> whether activation is present (forwarded to GrowthModel)
    Output:
    scaled model -> Model post scaling; returned unchanged when no layer is selected
    '''
    start = time.time()

    if warmup != 0:
        set_all_not_trainable(model)
        print(model)
    print_all_requires_grad(model)
    # Select Neurons to split
    sel_layers_data, neg_index_dic = find_split_layers_param_quota(
        model, epoch, param_budget,layer_threshold
    )
    # If no layers are selected to scale return model
    if len(sel_layers_data) == 0:
        return model
    cp1 = time.time()
    sum_reg = 0
    sum_max = 0
    l_neg = []
    
    # Scale all selected layers.
    # Each layer_data is [layer_num (str), layer, [block, slot, position], neuron count].
    for layer_data in sel_layers_data:
        s_max = time.time()
        # NOTE(review): raises KeyError if the selected layer has no entry in
        # neg_index_dic (i.e. no strictly negative eigenvalue) — confirm that
        # the selection step can never produce such a layer.
        neg_index = neg_index_dic[str(layer_data[0])]
        e_max = time.time()
        # The "Max Layer" timer only wraps the dictionary lookup above.
        sum_max += e_max - s_max
        # neg_index is ordered most-negative-first, so truncating keeps the
        # neurons with the most negative eigenvalues.
        if len(neg_index) > layer_data[-1]:
            neg_index = neg_index[: layer_data[-1]]
        l_neg.append(len(neg_index))

        block = layer_data[2][0]
        growth_block = layer_data[2][1]
        block_model = ret_growth_model(model.blocks[block], growth_block)
        layers = get_all_linear_layers(block_model, typ="list")
        layer = layer_data[1]
        # Neuron indexes to split, in ascending index order.
        choices = [n for n in neg_index]
        choices.sort()
        # Locate the selected layer inside its growth block and take it out of
        # the list; s_layer remembers where it was.
        # NOTE(review): if the layer is not found s_layer stays None and the
        # slices below would each copy the whole list — confirm this cannot happen.
        s_layer = None
        for i, l in enumerate(layers):
            if l == layer_data[1]:
                layers.pop(i)
                s_layer = i
                break
        # Create the new layer (also registers its parameters with the optimizer)
        (new_layer,optimizer) = create_new_layer(layer, choices, optimizer)

        s_reg = time.time()
        # NOTE(review): the "Regression" timer only wraps this learning-rate
        # read, and `lr` is not used afterwards.
        lr = optimizer.param_groups[0]["lr"]
        e_reg = time.time()
        sum_reg += e_reg - s_reg
        # Re-insert the original layer at its old position, followed by the new layer
        layers = layers[:s_layer] + [layer, new_layer] + layers[s_layer:]
        # Re-key the layers by position: {"0": ..., "1": ..., ...}
        layers = {str(i): layers[i] for i in range(len(layers))}
        # Create the updated architecture array: the leaf numbered s_layer is
        # replaced by [[0, 0], choices]
        _, _, architecture_array = return_arc_array(
            block_model.architecture_array, 0, s_layer, choices
        )
        # Assign the new Growth Block to the model
        model = assign_model(
            model, block, growth_block, GrowthModel(layers, architecture_array, act_on)
        )
        print_all_requires_grad(model)
        # cp2 is always defined below: the early return above guarantees at
        # least one loop iteration.
        cp2 = time.time()
        torch.cuda.empty_cache()

    print("Time for find splits :", convert(cp1 - start))
    print("Total Regression Time:", convert(sum_reg))
    print("Average Regression Time:", convert(sum_reg / len(sel_layers_data)))
    print("Total Max Layer Time:", convert(sum_max))
    print("Total Number of Selected Layers", len(sel_layers_data))
    print("Average Max Layer Time:", convert(sum_max / len(sel_layers_data)))
    # Equivalent to cp2 - cp1: the time spent in the scaling loop.
    print("Misc Time:", convert((cp2 - start) - (cp1 - start)))
    print("Total Time:", convert(cp2 - start))
    print("Negative Index Length:", l_neg)
    return model


def create_new_layer(layer, choices, optimizer, reduction_factor=0.2):
    """
    Creates a New Layer for depth addition. The new layer contains, for every
    selected neuron, a scaled-down positive copy and its exact negation
    (PW = reduction_factor * W, NW = -1 * PW).

    Inputs:
        layer -> Layer whose depth we want to increase.
        choices -> List of selected neuron positions (row indexes of ``layer.weight``).
        optimizer -> Optimizer that receives the new parameters. It must have
                     at least two parameter groups: group 0 receives the new
                     bias, group 1 receives the new weight.
        reduction_factor -> multiplicative factor applied to the copied weights
                            and biases so the new neurons are not plain copies.

    Outputs:
        new_layer -> newly created ``nn.Linear`` with the same input size as
                     ``layer`` and ``2 * len(choices)`` outputs (positive
                     copies first, then the negative ones). It has a bias only
                     if ``layer`` has one.
        optimizer -> the same optimizer object, updated in place.
    """
    has_bias = layer.bias is not None
    # One layer holds both halves: same input as the source layer, one output
    # per positive and per negative copy.
    new_layer = nn.Linear(layer.in_features, 2 * len(choices), bias=has_bias)

    # detach() yields values without autograd history; indexing and scaling
    # create fresh tensors, so the new parameters share no storage with
    # ``layer``.
    scaled_weight = layer.weight.detach()[choices, :] * reduction_factor
    new_layer.weight = nn.Parameter(torch.cat([scaled_weight, -scaled_weight], dim=0))

    if has_bias:
        scaled_bias = layer.bias.detach()[choices] * reduction_factor
        new_layer.bias = nn.Parameter(torch.cat([scaled_bias, -scaled_bias], dim=0))
        # Biases go to parameter group 0.
        optimizer.param_groups[0]["params"].append(new_layer.bias)

    # Weights go to parameter group 1.
    optimizer.param_groups[1]["params"].append(new_layer.weight)
    return new_layer, optimizer


def remove_garbage(model):
    """Run the garbage collector and release PyTorch's cached CUDA memory.

    ``model`` is accepted for interface compatibility but is not inspected.
    The gradients stored on the model are left untouched. The earlier
    per-layer loop only copied each weight gradient to the CPU and dropped
    the copy again, which freed nothing and failed for layers without a
    gradient.
    """
    gc.collect()
    torch.cuda.empty_cache()
