import random
import time
import torch
import torch.nn as nn
from constants import GPU, FLOAT_PRECISION_MAP
from logger import MetricsLogger
from torch.utils.data import DataLoader
from utils import (evaluate, 
                   cross_entropy_high_precision,
                   stable_cross_entropy,
                   stable_sum,
                   cross_entropy_float32,
                   kahan_cross_entropy,
                   kahan_sum,
                   cross_entropy_low_precision,
                   update_results,
                   get_specified_args,
                   get_dataset,
                   get_model,
                   parse_args,
                   get_optimizer,
                   stablemax_cross_entropy)


# Runtime configuration: cap torch's intra-op CPU threads and seed the Python
# and torch RNGs before anything random (dataset construction, DataLoader
# shuffling) happens, so runs are repeatable.
torch.set_num_threads(5) 
random.seed(42)
torch.manual_seed(42)
parser, args = parse_args()

# Target type for data casts (used as `.to(FLOAT_PRECISION)`); presumably a
# torch dtype keyed by the --float_precision value (16/32/64) — confirm in
# constants.FLOAT_PRECISION_MAP.
FLOAT_PRECISION = FLOAT_PRECISION_MAP[args.float_precision]

device = GPU if torch.cuda.is_available() else "cpu"
print("Using device:", device)

train_dataset, test_dataset = get_dataset(args)
# Full-batch mode: a single batch holds the entire training set.
if args.full_batch:
    args.batch_size = len(train_dataset)

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
# Evaluation loader: deterministic order, fixed batch size of 1024.
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)


# Snapshot this run's datasets. Despite the "loader" file names, the Dataset
# objects (not the DataLoaders) are what gets saved.
torch.save(train_dataset, "last_train_loader.pt")
torch.save(test_dataset, "last_test_loader.pt")

# Divide the learning rate by alpha**2; the logits are multiplied by
# ``args.alpha`` in the training loop, and the original value is restored
# before the run's arguments are recorded at the end of the script.
args.lr = args.lr / (args.alpha ** 2)

# Put the model on the training device immediately: the full-batch tensors
# live on ``device`` and the first forward pass happens before the training
# loop's own ``model.to(device)`` call, so a CPU-resident model would fail
# with a device mismatch on GPU. (Harmless if get_model already did this.)
model = get_model(args).to(device=device, dtype=torch.float32)
layer_names = [f"linear_{i}" for i in range(len(args.hidden_sizes) + 1)]
logger = MetricsLogger(layer_names, args.num_epochs, args.log_frequency)
optimizer = get_optimizer(model, args)



print(args.loss_function)
# Cross-entropy implementation for each supported float precision (bits).
cross_entropy_function = {
    16: cross_entropy_low_precision,
    32: cross_entropy_float32,
    64: cross_entropy_high_precision,
}

# The precision-specific cross-entropy is looked up with .get() so that an
# unsupported precision only fails when cross-entropy is actually selected.
loss_functions = {
    "cross_entropy": cross_entropy_function.get(args.float_precision),
    "l1": nn.L1Loss(),
    "MSE": nn.MSELoss(),
    "stable_ce": stable_cross_entropy,
    "kahan": kahan_cross_entropy,
    "stablemax": stablemax_cross_entropy,
}
loss_function = loss_functions.get(args.loss_function)
if loss_function is None:
    if args.loss_function == "cross_entropy":
        raise ValueError(
            f"No cross-entropy implementation for float precision "
            f"{args.float_precision!r}; expected one of {sorted(cross_entropy_function)}"
        )
    raise ValueError(
        f"Unknown loss function {args.loss_function!r}; "
        f"expected one of {sorted(loss_functions)}"
    )

# Epochs at which model checkpoints are captured (one per log step).
save_model_checkpoints = range(0, args.num_epochs, args.log_frequency)
saved_models = {epoch: None for epoch in save_model_checkpoints}

logger.metrics["train"]["classification_loss"] = torch.zeros(logger.num_logged_epochs)
# "None" is compared as a string: the CLI value, not Python's None.
if args.regularization != "None":
    logger.metrics["train"][f"{args.regularization}_loss"] = torch.zeros(logger.num_logged_epochs)
softmax_temperature = 1

logger.log_frequency = args.log_frequency

# The training loop trains on ``all_data`` / ``all_targets`` in one batch and
# those tensors are only built here, so fail fast with a clear message rather
# than a NameError at the first forward pass.
if not args.full_batch:
    raise ValueError(
        "Only full-batch training is implemented; enable args.full_batch to run this script."
    )
if "MNIST" in args.dataset:
    # Presumably raw uint8 pixels; dividing by 255 rescales them to [0, 1].
    all_data = train_dataset.data.to(device).to(FLOAT_PRECISION)/255
    all_targets = train_dataset.targets.to(device).to(FLOAT_PRECISION)
else:
    # train_dataset looks like a torch Subset (.dataset / .indices); gather the
    # selected rows of the underlying tensors — confirm against get_dataset.
    print(train_dataset.dataset.data.shape)
    print(len(train_dataset.indices))
    # NOTE(review): this branch casts to float32 while the MNIST branch uses
    # FLOAT_PRECISION — confirm which one is intended.
    all_data = train_dataset.dataset.data[train_dataset.indices].to(device).to(torch.float32)
    all_targets = train_dataset.dataset.targets[train_dataset.indices].to(device).to(torch.float32)
print(all_data.shape)
print(all_targets.shape)

# Per-log-step training diagnostics (lists appended to during training).
# Several are not populated anywhere in this file. Key spellings, including
# "percentage_absoption", are kept unchanged since saved results may be read
# by key elsewhere.
for metric_name in (
    "normalized_margin",
    "loss_after_update",
    "zero_terms",
    "softmax_collapse",
    "exponential_underflow",
    "samples_with_zero_gradients",
    "percentage_absoption",
    "percentage_zero_grad",
):
    logger.metrics["train"][metric_name] = []

# Per-parameter sum of squared gradient entries, recorded at each log step.
logger.metrics["train"]["gradient_norm"] = {name: [] for name, _ in model.named_parameters()}
logger.metrics["train"]["first_order_moment"] = {i: [] for i in range(3)}
logger.metrics["train"]["second_order_moment"] = {i: [] for i in range(3)}
# Per-parameter cosine similarity between the weights and the negative gradient.
logger.metrics["cosine_nlm"] = {name: [] for name, _ in model.named_parameters()}



# Number of initial epochs during which --orthogonal_gradients is applied.
ORTHOGONAL_GRADIENT_EPOCHS = 4000

# dtype the logits are cast to (keyed by args.float_precision in bits) before
# computing the softmax diagnostics; other precisions leave logits unchanged.
_DIAGNOSTIC_DTYPES = {64: torch.float64, 32: torch.float32, 16: torch.float16}


def _orthogonalize_gradients(model, eps=1e-30):
    """Replace each parameter's gradient with its component orthogonal to the weights.

    The orthogonal component is rescaled to the original gradient's L2 norm,
    so only the gradient's direction changes. Parameters whose ``.grad`` is
    None are skipped. ``eps`` keeps both divisions finite for zero norms.
    """
    for param in model.parameters():
        if param.grad is None:
            continue
        w = param.data.view(-1)
        g = param.grad.data.view(-1)
        proj = torch.dot(w, g) / (torch.dot(w, w) + eps)
        g_orth = g - proj * w
        scale = torch.norm(g) / (torch.norm(g_orth) + eps)
        param.grad.data.copy_((g_orth * scale).view(param.grad.data.shape))


def _softmax_diagnostics(output, sum_function):
    """Return ``(softmax_collapse, exponential_underflow)`` fractions for ``output``.

    ``output`` is indexed along dim 1 as the class axis (presumably
    (batch, classes) logits). softmax_collapse is the fraction of samples whose
    largest max-shifted exponential equals the full sum, i.e. every other term
    vanished or was absorbed. exponential_underflow is the fraction of entries
    whose softmax probability is exactly zero (or whose sum is zero).
    """
    exp_output = torch.exp(output - output.amax(1, keepdim=True))
    sum_exp = sum_function(exp_output, dim=-1, keepdim=True)
    # Keep the per-sample max at shape (batch, 1) to match sum_exp, so each
    # sample is compared only with its own sum (no cross-sample broadcasting).
    collapse = exp_output.amax(1, keepdim=True) == sum_exp
    underflow = ((exp_output / sum_exp) == 0) | (sum_exp == 0)
    return collapse.float().mean().item(), underflow.float().mean().item()


loss = torch.inf
start_time = time.time()
for epoch in range(args.num_epochs):
    model.train()
    optimizer.zero_grad()
    # Full-batch forward pass; logits are scaled by alpha.
    output = model(all_data) * args.alpha
    loss = loss_function(output, all_targets)
    loss.backward()

    if args.orthogonal_gradients and epoch < ORTHOGONAL_GRADIENT_EPOCHS:
        _orthogonalize_gradients(model)

    is_log_epoch = epoch % logger.log_frequency == 0
    if is_log_epoch:
        # Gradient statistics are taken before the optimizer step.
        with torch.no_grad():
            for name, p in model.named_parameters():
                logger.metrics["cosine_nlm"][name].append(
                    torch.nn.functional.cosine_similarity(
                        p.view(1, -1), -p.grad.view(1, -1), dim=1
                    ).item()
                )
                # Despite the key name, this is the *squared* L2 norm.
                logger.metrics["train"]["gradient_norm"][name].append(
                    p.grad.square().sum().item()
                )

    optimizer.step()

    if is_log_epoch:
        # Per-element losses need a loss that accepts reduction="none";
        # nn.MSELoss / nn.L1Loss modules do not take that call argument.
        if args.loss_function not in ("MSE", "l1"):
            with torch.no_grad():
                # Uses the pre-step logits computed above.
                full_loss = loss_function(output, all_targets, reduction="none")
                logger.metrics["train"]["zero_terms"].append(
                    (full_loss == 0).float().mean().item()
                )
                target_dtype = _DIAGNOSTIC_DTYPES.get(args.float_precision)
                if target_dtype is not None:
                    output = output.to(target_dtype)
                sum_function = stable_sum if args.loss_function == "stable_ce" else torch.sum
                softmax_collapse, exponential_underflow = _softmax_diagnostics(output, sum_function)
                logger.metrics["train"]["softmax_collapse"].append(softmax_collapse)
                logger.metrics["train"]["exponential_underflow"].append(exponential_underflow)

        logger.log_activation_statistics(model, test_loader, epoch)
        logger.log_weight_statistics(model, train_loader, test_loader, epoch, save_model_checkpoints, saved_models, layer_names)
        print(f'Epoch {epoch}: Training loss: {loss.item():.4f}')
        if epoch > 0:
            print(f"Time taken for the last {args.log_frequency} epochs: {(time.time() - start_time)/60} min")
        start_time = time.time()
        # Restore the model to the training device after the logging calls.
        model.to(device)


# Final evaluation on the test set, on CPU.
model.eval()
model.to('cpu')
test_loss, test_accuracy = evaluate(model, test_loader)
print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {test_accuracy:.2f}')
# Undo the alpha**2 learning-rate division applied before training so the
# recorded hyper-parameter matches the value that was passed in.
args.lr = args.lr*(args.alpha**2)

# Build a run identifier from the dataset plus any explicitly specified args.
specified_args = get_specified_args(parser, args)
if not specified_args:
    experiment_key = f'{args.dataset}_default'
else:
    arg_tokens = [f'{name}-{value!s}' for name, value in specified_args.items()]
    experiment_key = '|'.join([f'{args.dataset}', *arg_tokens])

# Persist checkpoints, optimizer, and the accumulated metrics for this run.
torch.save(saved_models, 'last_run_saved_model_checkpoints.pt')
torch.save(optimizer, 'last_optimizer.pt')
update_results('experiment_results.pt', experiment_key, logger.metrics)
print(f"Saving run: {experiment_key}")