import os
import copy
import dill
from datetime import datetime
from random import randint
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch import nn
from tqdm.auto import tqdm
from pactl.nn.projectors import FixedNumpySeed, FixedPytorchSeed
from pactl.bounds.quantize_fns import quantize_vector
from pactl.bounds.quantize_fns import finetune_quantization, finetune_quantization_all_tasks
from pactl.bounds.quantize_fns import get_message_len
from pactl.bounds.quantize_fns import do_arithmetic_encoding
from pactl.bounds.get_pac_bounds import pac_bayes_bound_opt
from pactl.bounds.get_pac_bounds import compute_catoni_bound
from pactl.bounds.compute_kl_mixture import get_gains
import torch.nn.functional as F
from pathlib import Path
from copy import deepcopy
import json






def compute_quantization_all_tasks(
    meta_learner,
    param_name,
    levels,
    device,
    epochs,
    lr,
    use_kmeans
):
    """Jointly quantize parameter ``param_name`` of every network in ``meta_learner.nets``.

    All per-network copies of the parameter are quantized against a single
    codebook with ``levels`` entries.

    * ``epochs > 0``: the codebook is fine-tuned with
      ``finetune_quantization_all_tasks`` (SGD, cross-entropy). The message
      length is then the arithmetic-coded size of the symbol stream plus the
      codebook stored as float16.
    * otherwise: ``quantize_vector`` is applied to the concatenation of all
      per-network copies.

    Args:
        meta_learner: object exposing ``nets``, a sequence of networks that each
            carry an attribute named ``param_name``.
        param_name: name of the parameter attribute to quantize on every net.
        levels: number of codebook entries; ``0`` disables quantization.
        device: device forwarded to the fine-tuning routine.
        epochs: number of fine-tuning epochs (``<= 0`` skips fine-tuning).
        lr: fine-tuning learning rate.
        use_kmeans: forwarded to the codebook initialisation / quantizer.

    Returns:
        ``(quantized_vecs, message_len, learned_centroids)``.

        * ``quantized_vecs`` is a list with one slice of length
          ``len(param)`` per network, in ``meta_learner.nets`` order.
        * ``learned_centroids`` is the fine-tuned codebook (full precision,
          before the float16 cast), or ``None`` when no fine-tuning ran.
        * When ``levels == 0`` the result is ``(None, 0, None)``.
    """
    if levels == 0:
        # Same 3-tuple shape as the normal return so callers can always unpack.
        return None, 0, None

    learned_centroids = None
    use_finetuning = epochs > 0
    vectors = [getattr(net, param_name).cpu().data.numpy() for net in meta_learner.nets]
    length = len(getattr(meta_learner.nets[0], param_name))

    if use_finetuning:
        criterion = nn.CrossEntropyLoss()
        qw = finetune_quantization_all_tasks(
            meta_learner=meta_learner,
            param_name=param_name,
            levels=levels,
            device=device,
            epochs=epochs,
            criterion=criterion,
            optimizer='sgd',
            lr=lr,
            use_kmeans=use_kmeans,
        )
        # NOTE(review): assumes qw.subspace_params concatenates the per-network
        # parameters in meta_learner.nets order -- TODO confirm in
        # finetune_quantization_all_tasks.
        quantized_vec = qw.quantizer(qw.subspace_params, qw.centroids)
        quantized_vec = quantized_vec.cpu().detach().numpy()
        # Squared distance from every parameter to every centroid; the argmin
        # over the last axis is the symbol (centroid index) of each parameter.
        sq_dist = (qw.centroids.unsqueeze(-2) - qw.subspace_params.unsqueeze(-1)) ** 2.0
        symbols = torch.min(sq_dist, -1)[-1].cpu().detach().numpy()
        centroids = qw.centroids.cpu().detach().numpy()
        # Keep an independent full-precision copy before the float16 cast below.
        learned_centroids = deepcopy(centroids)
        centroids = centroids.astype(np.float16)
        # Empirical symbol frequencies drive the arithmetic coder.
        probabilities = np.array([np.mean(symbols == i) for i in range(levels)])
        _, coded_symbols_size = do_arithmetic_encoding(symbols, probabilities,
                                                       qw.centroids.shape[0])
        message_len = get_message_len(
            coded_symbols_size=coded_symbols_size,
            codebook=centroids,
            max_count=len(symbols),
        )
    else:
        # Concatenate the per-network arrays end to end (the same layout that
        # the slicing below expects).
        vector = np.concatenate(vectors)
        quantized_vec, message_len = quantize_vector(vector, levels=levels, use_kmeans=use_kmeans)

    quantized_vecs = [quantized_vec[length * i:length * (i + 1)]
                      for i in range(len(meta_learner.nets))]

    return quantized_vecs, message_len, learned_centroids


def quantize(quant_type = 'default', model = None, param_names = None, levels = None, device = None, train_loader = None, epochs = None, lr = None, use_kmeans = None, partition = None, learned_centroids = None, Transfer=False):
    """Quantize the parameters ``param_names`` of ``model`` and return their coding cost.

    * ``quant_type == 'default'``: codebook quantization via
      ``compute_quantization`` (all remaining arguments are forwarded).
    * ``quant_type`` in ``('none', 'float8', 'float16', 'float32', 'float64', 'wrong')``:
      each parameter is cast independently with ``my_quantize_vector``, and
      the per-parameter message lengths are summed.

    Returns:
        ``(quantized_vecs, message_len)``; ``quantized_vecs`` is a list with one
        array per parameter name.

    Raises:
        ValueError: if ``quant_type`` is not one of the supported values.
    """
    # Types handled by my_quantize_vector ('wrong' is a zero-cost float32 cast
    # that my_quantize_vector also accepts).
    cast_types = ('none', 'float8', 'float16', 'float32', 'float64', 'wrong')

    if quant_type == 'default':
        quantized_vecs, message_len = compute_quantization(model, param_names, levels, device, train_loader, epochs, lr, use_kmeans, partition = partition, learned_centroids = learned_centroids, Transfer=Transfer)
        return quantized_vecs, message_len

    # NOTE (original author): in the meta_learner classes the vectors are
    # finally cast to float16.
    if quant_type in cast_types:
        results = [my_quantize_vector(getattr(model, name).cpu().data.numpy(), quant_type)
                   for name in param_names]
        quantized_vecs = [vec for vec, _ in results]
        message_len = sum(m_len for _, m_len in results)
        return quantized_vecs, message_len

    raise ValueError(f"quantize type not implemented: {quant_type!r}")


def compute_quantization(
    model,
    param_names,
    levels,
    device,
    train_loader,
    epochs,
    lr,
    use_kmeans,
    partition = None,
    Transfer=False,
    learned_centroids = None
):
    """Quantize the parameters ``param_names`` of ``model`` to a ``levels``-entry codebook.

    * ``epochs > 0``: the codebook is fine-tuned with ``finetune_quantization``
      (SGD, cross-entropy, on ``train_loader``). The message length is then
      the arithmetic-coded size of the symbol stream plus the codebook stored
      as float16.
    * otherwise: ``quantize_vector`` is applied to the concatenation of all
      parameters.

    Args:
        model: object carrying one attribute per entry of ``param_names``.
        param_names: names of the parameter attributes to quantize.
        levels: number of codebook entries; ``0`` disables quantization.
        device, train_loader, epochs, lr, use_kmeans, partition, Transfer,
        learned_centroids: forwarded to ``finetune_quantization``
            (``levels``/``use_kmeans`` also to ``quantize_vector``).

    Returns:
        ``(quantized_vecs, message_len)``, where ``quantized_vecs`` is a list
        with one slice per entry of ``param_names``, each with that
        parameter's length. The result is ``(None, 0)`` when ``levels == 0``.
    """
    if levels == 0:
        return None, 0

    use_finetuning = epochs > 0
    vectors = [getattr(model, name).cpu().data.numpy() for name in param_names]
    # Start/end offsets of each parameter inside the concatenated vector.
    lens = [0] + list(np.cumsum([len(vec) for vec in vectors]))

    if use_finetuning:
        criterion = nn.CrossEntropyLoss()
        qw = finetune_quantization(
            model=model,
            param_names=param_names,
            levels=levels,
            device=device,
            train_loader=train_loader,
            epochs=epochs,
            criterion=criterion,
            optimizer='sgd',
            lr=lr,
            use_kmeans=use_kmeans,
            partition = partition,
            Transfer=Transfer,
            learned_centroids = learned_centroids
        )
        quantized_vec = qw.quantizer(qw.subspace_params, qw.centroids)
        quantized_vec = quantized_vec.cpu().detach().numpy()
        # Squared distance from every parameter to every centroid; the argmin
        # over the last axis is the symbol (centroid index) of each parameter.
        sq_dist = (qw.centroids.unsqueeze(-2) - qw.subspace_params.unsqueeze(-1)) ** 2.0
        symbols = torch.min(sq_dist, -1)[-1].cpu().detach().numpy()
        centroids = qw.centroids.cpu().detach().numpy().astype(np.float16)
        # Empirical symbol frequencies drive the arithmetic coder.
        probabilities = np.array([np.mean(symbols == i) for i in range(levels)])
        _, coded_symbols_size = do_arithmetic_encoding(symbols, probabilities,
                                                       qw.centroids.shape[0])
        message_len = get_message_len(
            coded_symbols_size=coded_symbols_size,
            codebook=centroids,
            max_count=len(symbols),
        )
    else:
        # Concatenate the per-parameter arrays end to end, matching the
        # offsets in `lens`.
        vector = np.concatenate(vectors)
        quantized_vec, message_len = quantize_vector(vector, levels=levels, use_kmeans=use_kmeans)

    quantized_vecs = [quantized_vec[start:end] for start, end in zip(lens[:-1], lens[1:])]

    return quantized_vecs, message_len


def set_runnings(model, quantized_mean_vec, quantized_var_vec, device):
    """Overwrite BatchNorm running statistics in ``model._forward_net[0]``.

    BatchNorm1d/2d/3d layers are visited in ``modules()`` order. Each one
    consumes the next ``len(layer.running_mean)`` entries of
    ``quantized_mean_vec`` and ``quantized_var_vec``. Those entries are stored
    as new float32 tensors on ``device``. The model is modified in place and
    also returned.
    """
    batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    offset = 0
    for layer in model._forward_net[0].modules():
        if not isinstance(layer, batchnorm_types):
            continue
        end = offset + len(layer.running_mean)
        new_mean = torch.tensor(quantized_mean_vec[offset:end]).float()
        new_var = torch.tensor(quantized_var_vec[offset:end]).float()
        layer.running_mean = new_mean.to(device)
        layer.running_var = new_var.to(device)
        offset = end
    return model

def float_to_fp8(num):
    """Round ``num`` to a simulated 8-bit float and return the rounded value.

    Format: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits. Every
    exponent field 0..15 is treated as a normal number; there are no
    subnormals.

    * Magnitudes whose binary exponent is below -7 flush to ``0.0``.
    * Magnitudes above the range, or that overflow after mantissa rounding,
      become ``±inf``. The largest finite value is 480.
    * NaN is propagated, and ``±inf`` is returned as ``±inf``.

    The return value is the decoded number, not the bit pattern.
    """
    if num == 0:
        return 0.0
    if np.isnan(num):
        return np.nan
    sign = np.sign(num)
    if np.isinf(num):
        return np.inf * sign
    abs_num = np.abs(num)
    exponent = np.floor(np.log2(abs_num))
    mantissa = abs_num / (2**exponent) - 1  # fractional part in [0, 1)
    bias = 7
    fp8_exponent = int(exponent + bias)
    if fp8_exponent < 0:
        return 0.0  # underflow
    if fp8_exponent > 15:
        return np.inf * sign  # overflow
    # Round the fraction to 3 bits (np.round: round half to even).
    fp8_mantissa = int(np.round(mantissa * (2**3)))
    if fp8_mantissa >= 8:
        # Rounding carried into the exponent.
        fp8_mantissa = 0
        fp8_exponent += 1
    if fp8_exponent > 15:
        return np.inf * sign  # overflow after rounding
    fp8_value = sign * (1 + fp8_mantissa / (2**3)) * (2**(fp8_exponent - bias))
    return fp8_value

def to_float8(vector):
    """Apply ``float_to_fp8`` element-wise and return a float64 array of the same shape.

    The dtype is forced to float64, so an all-zero input still yields a float
    array. Inputs of any dimensionality are handled by flattening, mapping and
    reshaping.
    """
    arr = np.asarray(vector, dtype=np.float64)
    flat = np.array([float_to_fp8(x) for x in arr.ravel()], dtype=np.float64)
    return flat.reshape(arr.shape)
    
# DO NOT use this function for the 'default' (codebook) quantization.
def my_quantize_vector(vector, quantize_type):
    """Cast ``vector`` to a floating-point format and return its size in bits.

    Supported ``quantize_type`` values:

    * ``'float16'``, ``'float32'``, ``'float64'``: the matching NumPy dtype.
    * ``'float8'``: simulated 8-bit float via ``to_float8``.
    * ``'none'``: keep ``vector``'s own dtype, which must be one of the float
      types above.
    * ``'wrong'``: cast to float32 but report a message length of 0.

    Returns:
        ``(quantized_vec, message_len)`` where ``message_len`` is the number
        of elements (``vector.size``, i.e. all elements even for
        multi-dimensional input) times the bit width.

    Raises:
        ValueError: for an unsupported ``quantize_type``.
    """
    vector = np.asarray(vector)
    if quantize_type == 'none':
        quantize_type = vector.dtype.name

    if quantize_type == 'float8':
        return to_float8(vector), vector.size * 8

    formats = {
        'float16': (np.float16, 16),
        'float32': (np.float32, 32),
        'float64': (np.float64, 64),
        'wrong': (np.float32, 0),
    }
    try:
        typ, bit = formats[quantize_type]
    except KeyError:
        raise ValueError(f"Unsupported quantize_type: {quantize_type!r}") from None

    quantized_vec = vector.astype(typ)
    message_len = vector.size * bit
    return quantized_vec, message_len

