import torch
import copy 
from constants import INPUT_TYPES, GPU, FLOAT_PRECISION
from utils import evaluate, cross_entropy_high_precision
# Template of every tracked metric name. Each name starts out mapped to its own
# empty dict; MetricsLogger.create_metrics_dict deep-copies this template and
# fills in the actual storage tensors.
DEFAULT_METRICS = {
    metric_name: {}
    for metric_name in (
        'loss',
        'sparsity',
        'activation_l1',
        'weights_l2',
        'bias_l2',
        'activating_samples_in_common',
        'accuracy',
        'output_mean',
        'output_std',
        'logit_max',
        'logit_min',
        'logit_floor',
        'softmax_saturation',
        'zero_weights',
        'zero_bias',
        'scale_parameters',
        'neuron_coherence',
        'positive_weight_norm',
        'negative_weight_norm',
        'pseudo_sparsity',
        'exp_avg_weight',
        'exp_avg_sq_weight',
        'exp_avg_bias',
        'exp_avg_sq_bias',
        'neuron_efficiency',
        'representation_orthogonality',
        'l2_loss',
        'classification_loss',
    )
}

# Metrics stored as a single tensor (one value per logged epoch) instead of one
# tensor per layer. "loss" and "accuracy" get the same treatment but are
# special-cased in MetricsLogger.create_metrics_dict.
ACTIVATION_STATISTIC_NAMES = [
    'output_mean',
    'output_std',
    'logit_max',
    'logit_min',
    'logit_floor',
    'neuron_coherence',
    'positive_weight_norm',
    'negative_weight_norm',
    'softmax_saturation',
    'l2_loss',
    'classification_loss',
]

# Use the configured GPU device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = GPU
else:
    device = "cpu"

class MetricsLogger:
    """Accumulate training diagnostics into pre-allocated per-epoch tensors.

    ``self.metrics`` maps every entry of ``INPUT_TYPES`` to a dict keyed by
    metric name. Metrics named ``"loss"``/``"accuracy"`` or listed in
    ``ACTIVATION_STATISTIC_NAMES`` are a single 1-D tensor with one slot per
    logged epoch; every other metric is a dict mapping a layer name to such a
    tensor.
    """

    def __init__(self, layers: list, num_epochs: int, logg_frequency: int):
        """Allocate zero-initialised storage for every metric.

        Args:
            layers: Layer names used as keys for the per-layer metrics.
            num_epochs: Total number of training epochs.
            logg_frequency: Metrics are logged every ``logg_frequency`` epochs.
        """
        self.metrics = {}
        self.layers = layers
        self.num_epochs = num_epochs
        self.logg_frequency = logg_frequency
        # Dense early-training logging is currently disabled: zero epochs are
        # treated as "early", so it contributes no extra logging slots.
        self.num_early_training_epochs = 0
        self.early_training_logg_frequency = 10
        self.num_logged_epochs = (
            num_epochs // logg_frequency
            + self.num_early_training_epochs // self.early_training_logg_frequency
        )
        for input_type in INPUT_TYPES:
            self.metrics[input_type] = self.create_metrics_dict()

    def _epoch_position(self, epoch: int) -> int:
        """Map a training epoch to its slot index in the metric tensors."""
        position = epoch // self.logg_frequency
        if epoch > self.num_early_training_epochs:
            position += self.num_early_training_epochs // self.early_training_logg_frequency
        return position

    def create_metrics_dict(self) -> dict:
        """Return a fresh metrics dict with zero tensors of length ``num_logged_epochs``.

        Global metrics get one tensor; all other metrics get one tensor per
        entry of ``self.layers``.
        """
        metrics = copy.deepcopy(DEFAULT_METRICS)
        for metric in metrics:
            if metric in ["loss", "accuracy"] or metric in ACTIVATION_STATISTIC_NAMES:
                metrics[metric] = torch.zeros(self.num_logged_epochs)
            else:
                for layer in self.layers:
                    metrics[metric][layer] = torch.zeros(self.num_logged_epochs)
        return metrics

    def log_metrics(self, input_type, activations: dict, epoch: int, data_size: int) -> dict:
        """Accumulate per-layer sparsity and mean absolute activation.

        Args:
            input_type: Key into ``self.metrics`` selecting which split to update.
            activations: Maps each layer name in ``self.layers`` to its
                activation tensor (indexed as 2-D: dimension 1 is used as the
                layer width).
            epoch: Slot index to accumulate into; used directly, not divided
                by the logging frequency.
            data_size: Total sample count used for normalisation.

        Returns:
            The full ``self.metrics`` dict.
        """
        split_metrics = self.metrics[input_type]
        for layer in self.layers:
            layer_activations = activations[layer]
            normaliser = layer_activations.shape[1] * data_size
            split_metrics['sparsity'][layer][epoch] += (layer_activations == 0).sum().item() / normaliser
            # The summed absolute value is an L1 quantity and the allocated key
            # is 'activation_l1'; the former 'activation_l2' key never existed
            # in the metrics dict and raised KeyError.
            split_metrics['activation_l1'][layer][epoch] += layer_activations.abs().sum().item() / normaliser
        return self.metrics

    def log_activation_statistics(self, model, test_loader, epoch):
        """Accumulate activation and logit statistics over ``test_loader``.

        Results are written under ``self.metrics["train"]``. Side effect: the
        model is moved to the CPU and is not moved back.

        Args:
            model: Callable as ``model(data, keep_activations=True)``; must
                expose ``layers`` and, after the forward pass, ``activations``.
            test_loader: Iterable of batches whose first element is the input
                tensor; ``test_loader.dataset`` must support ``len``.
            epoch: Current training epoch.
        """
        model.to("cpu")
        epoch_position = self._epoch_position(epoch)
        float_precision = next(model.parameters()).dtype
        test_size = len(test_loader.dataset)
        train_metrics = self.metrics["train"]
        # NOTE: 'neuron_efficiency', 'representation_orthogonality' and
        # 'neuron_coherence' are allocated but not populated by this method.
        for data, *_ in test_loader:
            data = data.to(float_precision)
            output = model(data.to("cpu"), keep_activations=True)
            # Detach before accumulating: adding a grad-tracking tensor in place
            # would make the stored metric tensors part of the autograd graph,
            # keeping every forward graph alive for the rest of training.
            output = output.detach()

            # NOTE(review): every layer is normalised by the input width of the
            # LAST layer - confirm all logged layers share that width.
            hidden_size = model.layers[-1].weight.shape[1]
            for i in range(len(model.layers)):
                layer_name = f"linear_{i}"
                layer_activations = model.activations[i]
                train_metrics["sparsity"][layer_name][epoch_position] += (layer_activations == 0).sum().item() / test_size / hidden_size
                train_metrics["activation_l1"][layer_name][epoch_position] += layer_activations.abs().sum().item() / test_size / hidden_size
                # Counts every value below 0.001, including negative ones.
                train_metrics["pseudo_sparsity"][layer_name][epoch_position] += (layer_activations < 0.001).sum().item() / test_size / hidden_size

            # Sum over all logits divided by the sample count, i.e. the mean
            # per-sample logit sum.
            train_metrics["output_mean"][epoch_position] += output.sum() / test_size
            train_metrics["output_std"][epoch_position] += output.std(dim=1).sum() / test_size
            train_metrics["logit_max"][epoch_position] += output.amax(dim=1).sum() / test_size
            train_metrics["logit_min"][epoch_position] += output.amin(dim=1).sum() / test_size
            train_metrics["logit_floor"][epoch_position] += (
                (output.sum(dim=1) - output.amin(dim=1)) / output.shape[1]
            ).sum() / test_size
            # Softmax is evaluated in float64 before taking the per-sample max.
            train_metrics["softmax_saturation"][epoch_position] += torch.nn.functional.softmax(
                output.to(torch.float64), dim=-1
            ).amax(dim=1).sum() / test_size

    def log_weight_statistics(self, model, train_loader, test_loader, epoch, save_model_checkpoints, saved_models, layers):
        """Record weight/bias magnitudes and train/test loss and accuracy.

        Side effects: stores a deep copy of ``model.state_dict()`` in
        ``saved_models[epoch]`` when ``epoch`` is in ``save_model_checkpoints``,
        and switches the model to eval mode without switching it back.

        Args:
            model: Must expose ``layers`` (modules with ``weight`` and, when
                ``model.uses_bias`` is truthy, ``bias``).
            train_loader: Passed to ``evaluate`` for the "train" metrics.
            test_loader: Passed to ``evaluate`` for the "test" metrics.
            epoch: Current training epoch.
            save_model_checkpoints: Container of epochs at which to checkpoint.
            saved_models: Dict receiving the checkpointed state dicts.
            layers: Layer names, position-aligned with ``model.layers``.
        """
        epoch_position = self._epoch_position(epoch)

        if epoch in save_model_checkpoints:
            saved_models[epoch] = copy.deepcopy(model.state_dict())
        model.eval()
        general_metrics = self.metrics["general"]
        for i, layer in enumerate(layers):
            # Despite the "_l2" names these are mean squared values.
            general_metrics["weights_l2"][layer][epoch_position] = model.layers[i].weight.square().mean().item()
            if model.uses_bias:
                general_metrics["bias_l2"][layer][epoch_position] = model.layers[i].bias.square().mean().item()

        # Squared mass of the positive / negative entries of the last layer's
        # weights, each normalised by the total number of weights in that layer.
        last_weight = model.layers[-1].weight
        general_metrics["positive_weight_norm"][epoch_position] = last_weight[last_weight > 0].square().sum().item() / last_weight.numel()
        general_metrics["negative_weight_norm"][epoch_position] = last_weight[last_weight < 0].square().sum().item() / last_weight.numel()

        # evaluate is unpacked as a (loss, accuracy) pair.
        self.metrics["train"]["loss"][epoch_position], self.metrics["train"]["accuracy"][epoch_position] = evaluate(model, train_loader)
        self.metrics["test"]["loss"][epoch_position], self.metrics["test"]["accuracy"][epoch_position] = evaluate(model, test_loader)

    def log_optimizer_statistics(self, optimizer, optimizer_type, epoch):
        """Record the mean absolute first/second moment estimate per parameter.

        Reads the ``'exp_avg'`` and ``'exp_avg_sq'`` entries of the optimizer
        state; a parameter with state but without those entries raises
        ``KeyError``. Parameters without any state are skipped with a printed
        notice.

        Args:
            optimizer: Optimizer exposing ``param_groups`` and ``state``.
            optimizer_type: Currently unused.
            epoch: Current training epoch.
        """
        epoch_position = self._epoch_position(epoch)
        train_metrics = self.metrics["train"]

        for param_group in optimizer.param_groups:
            for i, param in enumerate(param_group['params']):
                if param not in optimizer.state:
                    print(f'Skipping param {i}')
                    continue
                state = optimizer.state[param]
                first_moment = state['exp_avg']
                second_moment = state['exp_avg_sq']
                # NOTE(review): assumes parameters alternate weight, bias within
                # each group, so even indices are weights and odd ones biases -
                # this mislabels parameters for models without biases; confirm.
                kind = "weight" if i % 2 == 0 else "bias"
                layer_name = f"linear_{i//2}"
                train_metrics[f"exp_avg_{kind}"][layer_name][epoch_position] = first_moment.abs().mean()
                train_metrics[f"exp_avg_sq_{kind}"][layer_name][epoch_position] = second_moment.abs().mean()

