import argparse
import copy
import itertools
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import datasets, transforms
from torch.utils.data.dataset import random_split
from datasets import (AlgorithmicDataset, 
                      SparseParityDataset, 
                      BinaryAlgorithmicDataset,
                      ScalarAlgorithmicDataset,)
from models import MLP
from binary_operations import (product_mod,
                               add_mod,
                               subtract_mod)
from constants import GPU, FLOAT_PRECISION_MAP




def stable_softmax(x):
    """Return ``exp(x)`` divided by its sum over the last dimension.

    The normalising sum is computed with ``stable_sum``. No maximum is
    subtracted from ``x`` before exponentiating.
    """
    numerator = torch.exp(x)
    denominator = stable_sum(numerator, dim=-1, keepdim=True)
    return numerator / denominator

def stable_sum(tensor, dim=-1, keepdim=True):
    """Sum ``tensor`` along ``dim`` after sorting its values in ascending order.

    Args:
        tensor: Tensor to reduce.
        dim: Dimension to sort and reduce over.
        keepdim: If true, the reduced dimension is kept with size 1.

    Returns:
        The sum along ``dim``; the values are sorted ascending along ``dim``
        before a cumulative sum is taken.
    """
    sorted_tensor, _ = torch.sort(tensor, dim=dim)
    # The last entry of the running sum along `dim` is the total. Selecting
    # along `dim` itself keeps this correct for any rank and any axis.
    result = torch.cumsum(sorted_tensor, dim=dim).select(dim, -1)
    if keepdim:
        result = result.unsqueeze(dim)
    return result

def stable_logsoftmax(x):
    """Return the log-softmax of ``x`` over its last dimension.

    The per-row maximum is subtracted before exponentiating, and the
    normalising sum is computed with ``stable_sum``.
    """
    # Take the maximum over the same axis that is normalised below (the last
    # one), so the offset is constant within each softmax.
    x_off = x - x.amax(-1, keepdim=True)
    sum_exp = stable_sum(torch.exp(x_off), dim=-1, keepdim=True)
    return x_off - torch.log(sum_exp)


def one_hot_encode(number, size):
    """Return a zero tensor of shape ``size`` with position ``number`` set to 1."""
    encoded = torch.zeros(size)
    encoded[number] = 1
    return encoded

def kahan_sum(input_tensor, dim=-1, keepdim=True):
    """Sum ``input_tensor`` along ``dim`` using compensated summation.

    A running sum and a separate compensation term are maintained; at each
    step the rounding error of the addition is recovered from whichever of
    the two operands has the larger magnitude (the Neumaier variant of Kahan
    summation) and added to the compensation term.

    Args:
        input_tensor: Tensor to reduce.
        dim: Dimension to reduce over.
        keepdim: If true, the reduced dimension is kept with size 1.

    Returns:
        Running sum plus accumulated compensation. A zero-length ``dim``
        yields zeros.
    """
    # Build the accumulators from the shape so that a zero-length `dim`
    # (nothing to select from) is handled as an empty sum.
    reduced_shape = list(input_tensor.shape)
    del reduced_shape[dim]
    sum_tensor = torch.zeros(reduced_shape, dtype=input_tensor.dtype, device=input_tensor.device)
    c_tensor = torch.zeros_like(sum_tensor)

    # Iterate along the specified dimension
    for i in range(input_tensor.size(dim)):
        current_slice = input_tensor.select(dim, i)
        t = sum_tensor + current_slice
        # The low-order bits lost in `t` come from the smaller operand.
        mask = torch.abs(sum_tensor) >= torch.abs(current_slice)
        c_tensor = c_tensor + torch.where(
            mask,
            (sum_tensor - t) + current_slice,
            (current_slice - t) + sum_tensor,
        )
        sum_tensor = t

    # Apply the compensation once, after the loop.
    result = sum_tensor + c_tensor
    if keepdim:
        result = result.unsqueeze(dim)
    return result

def kahan_softmax(x):
    """Return ``exp(x)`` normalised over the last dimension via ``kahan_sum``."""
    exponentials = torch.exp(x)
    normaliser = kahan_sum(exponentials, dim=-1, keepdim=True)
    return exponentials / normaliser

def log_kahan_softmax(x, dim=-1, keepdim=True):
    """Return the log-softmax of ``x`` along ``dim`` using ``kahan_sum``.

    Args:
        x: Input logits.
        dim: Dimension to normalise over.
        keepdim: Accepted for backward compatibility. The result always has
            the shape of ``x``, so the normaliser is always computed with
            the reduced dimension kept.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Offset by the maximum along the axis being normalised.
    x_off = x - x.amax(dim, keepdim=True)
    # The normaliser must keep `dim` so it broadcasts against `x_off`.
    log_norm = torch.log(kahan_sum(torch.exp(x_off), dim=dim, keepdim=True))
    return x_off - log_norm



def kahan_cross_entropy(logits, labels, label_argmax=True, reduction="mean"):
    """Cross-entropy whose log-softmax is computed with ``log_kahan_softmax``.

    Args:
        logits: Tensor indexed as (sample, class).
        labels: Either 1-D class indices or a (sample, class) score/one-hot
            tensor.
        label_argmax: If true, labels that are not 1-D are reduced to class
            indices with ``argmax`` over the last dimension.
        reduction: ``"mean"`` returns the mean loss; any other value returns
            the per-sample losses with shape (sample, 1).

    Returns:
        The negative log-probability of the labelled class.
    """
    # 1-D labels are already class indices; no fixed class count is assumed.
    if label_argmax and labels.dim() != 1:
        labels = labels.argmax(dim=-1)
    labels = labels.to(torch.int64)
    logprobs = log_kahan_softmax(logits)
    prediction_logprobs = torch.gather(logprobs, index=labels[:, None], dim=-1)

    if reduction == "mean":
        return -torch.mean(prediction_logprobs)
    return -prediction_logprobs


def stable_cross_entropy(logits, labels, label_argmax=True, reduction="mean"):
    """Cross-entropy computed in float64 with ``stable_logsoftmax``.

    Args:
        logits: Tensor indexed as (sample, class); cast to float64.
        labels: Either 1-D class indices or a (sample, class) score/one-hot
            tensor.
        label_argmax: If true, labels that are not 1-D are reduced to class
            indices with ``argmax`` over the last dimension.
        reduction: ``"mean"`` returns the mean loss; any other value returns
            the per-sample losses with shape (sample, 1).

    Returns:
        The negative log-probability of the labelled class, in float64.
    """
    # 1-D labels are already class indices; no fixed class count is assumed.
    if label_argmax and labels.dim() != 1:
        labels = labels.argmax(dim=-1)
    labels = labels.to(torch.int64)
    logprobs = stable_logsoftmax(logits.to(torch.float64))
    prediction_logprobs = torch.gather(logprobs, index=labels[:, None], dim=-1).to(torch.float64)

    if reduction == "mean":
        return -torch.mean(prediction_logprobs)
    return -prediction_logprobs




def cross_entropy_high_precision(logits, labels, label_argmax=True, reduction="mean"):
    """Cross-entropy computed in float64 with ``torch.nn.functional.log_softmax``.

    NOTE: a function with this same name is defined again further down in
    this module; that later definition is the one bound after import.

    Args:
        logits: Tensor indexed as (sample, class); cast to float64.
        labels: Either 1-D class indices or a (sample, class) score/one-hot
            tensor.
        label_argmax: If true, labels that are not 1-D are reduced to class
            indices with ``argmax`` over the last dimension.
        reduction: ``"mean"`` returns the mean loss; any other value returns
            the per-sample losses with shape (sample, 1).

    Returns:
        The negative log-probability of the labelled class, in float64.
    """
    # 1-D labels are already class indices; no fixed class count is assumed.
    if label_argmax and labels.dim() != 1:
        labels = labels.argmax(dim=-1)
    labels = labels.to(torch.int64)
    logprobs = torch.nn.functional.log_softmax(logits.to(torch.float64), dim=-1)
    prediction_logprobs = torch.gather(logprobs, index=labels[:, None], dim=-1)

    if reduction == "mean":
        return -torch.mean(prediction_logprobs)
    return -prediction_logprobs

def soften_logits(x, threshold=0):
    """Compress values of ``x`` outside ``[-threshold, threshold]`` logarithmically.

    Entries above ``threshold`` become ``threshold + log(x - threshold + 1)``,
    entries below ``-threshold`` become ``-threshold - log(-x - threshold + 1)``,
    and entries inside the interval are returned unchanged.
    """
    is_below = x < -threshold
    is_above = x > threshold
    squashed_low = -threshold - torch.log((-x - threshold) + 1)
    squashed_high = threshold + torch.log((x - threshold) + 1)
    high_or_identity = torch.where(is_above, squashed_high, x)
    return torch.where(is_below, squashed_low, high_or_identity)

def log_softermax(x, dim=-1):
    """Return a log-softmax of ``soften_logits(x)`` along ``dim``.

    The logits are first passed through ``soften_logits`` (default
    threshold), then normalised with ``stable_sum`` over ``dim``.
    """
    softer_x = soften_logits(x)
    log_norm = torch.log(stable_sum(torch.exp(softer_x), dim=dim, keepdim=True))
    return softer_x - log_norm

def s(x):
    """Map ``x`` elementwise to ``1 / (1 - x)`` where ``x < 0`` and ``x + 1`` elsewhere."""
    negative_branch = 1 / (1 - x)
    non_negative_branch = x + 1
    return torch.where(x < 0, negative_branch, non_negative_branch)

def log_stablemax(x, dim=-1):
    """Return ``log(s(x) / sum(s(x)))`` with the sum taken along ``dim``.

    ``s`` replaces the exponential of a regular softmax; the normalising sum
    is computed with ``stable_sum``.
    """
    s_x = s(x)
    return torch.log(s_x / stable_sum(s_x, dim=dim, keepdim=True))


def stablemax_cross_entropy(logits, labels, label_argmax=True, reduction="mean"):
    """Cross-entropy computed in float64 on top of ``log_stablemax``.

    Args:
        logits: Tensor indexed as (sample, class); cast to float64.
        labels: Either 1-D class indices or a (sample, class) score/one-hot
            tensor.
        label_argmax: If true, labels that are not 1-D are reduced to class
            indices with ``argmax`` over the last dimension.
        reduction: ``"mean"`` returns the mean loss; any other value returns
            the per-sample losses with shape (sample, 1).

    Returns:
        The negative log-probability of the labelled class, in float64.
    """
    # 1-D labels are already class indices; no fixed class count is assumed.
    if label_argmax and labels.dim() != 1:
        labels = labels.argmax(dim=-1)
    labels = labels.to(torch.int64)
    logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
    prediction_logprobs = torch.gather(logprobs, index=labels[:, None], dim=-1).to(torch.float64)

    if reduction == "mean":
        return -torch.mean(prediction_logprobs)
    return -prediction_logprobs

def cross_entropy_high_precision(logits, labels, label_argmax=True, reduction="mean"):
    """Cross-entropy computed in float64 with ``torch.nn.functional.log_softmax``.

    Args:
        logits: Tensor indexed as (sample, class); cast to float64.
        labels: Either 1-D class indices or a (sample, class) score/one-hot
            tensor.
        label_argmax: If true, labels that are not 1-D are reduced to class
            indices with ``argmax`` over the last dimension.
        reduction: ``"mean"`` returns the mean loss; any other value returns
            the per-sample losses with shape (sample, 1).

    Returns:
        The negative log-probability of the labelled class, in float64.
    """
    # 1-D labels are already class indices; no fixed class count is assumed.
    if label_argmax and labels.dim() != 1:
        labels = labels.argmax(dim=-1)
    labels = labels.to(torch.int64)
    logprobs = torch.nn.functional.log_softmax(logits.to(torch.float64), dim=-1)
    prediction_logprobs = torch.gather(logprobs, index=labels[:, None], dim=-1)

    if reduction == "mean":
        return -torch.mean(prediction_logprobs)
    return -prediction_logprobs

def cross_entropy_float32(logits, labels, label_argmax=True, reduction="mean"):
    """Cross-entropy computed in float32 with ``torch.nn.functional.log_softmax``.

    Args:
        logits: Tensor indexed as (sample, class); cast to float32.
        labels: Either 1-D class indices or a (sample, class) score/one-hot
            tensor.
        label_argmax: If true, labels that are not 1-D are reduced to class
            indices with ``argmax`` over the last dimension.
        reduction: ``"mean"`` returns the mean loss; any other value returns
            the per-sample losses with shape (sample, 1).

    Returns:
        The negative log-probability of the labelled class, in float32.
    """
    # 1-D labels are already class indices; no fixed class count is assumed.
    if label_argmax and labels.dim() != 1:
        labels = labels.argmax(dim=-1)
    labels = labels.to(torch.int64)
    logprobs = torch.nn.functional.log_softmax(logits.to(torch.float32), dim=-1)
    prediction_logprobs = torch.gather(logprobs, index=labels[:, None], dim=-1)

    if reduction == "mean":
        return -torch.mean(prediction_logprobs)
    return -prediction_logprobs

def log_softmax(x, dim=-1, keepdim=True):
    """Return the log-softmax of ``x`` along ``dim``.

    Args:
        x: Input logits.
        dim: Dimension to normalise over.
        keepdim: Accepted for backward compatibility. The result always has
            the shape of ``x``, so the normaliser is always computed with
            the reduced dimension kept.

    Returns:
        Tensor with the same shape as ``x``.
    """
    x_off = x - x.amax(dim, keepdim=True)
    # The normaliser must keep `dim` so it broadcasts against `x_off`.
    log_norm = torch.log(torch.sum(torch.exp(x_off), dim=dim, keepdim=True))
    return x_off - log_norm

def cross_entropy_low_precision(logits, labels, label_argmax=True, reduction="mean"):
    if label_argmax:
        if len(labels.shape) ==1:
            labels = torch.nn.functional.one_hot(labels.long(), 10)
        labels = labels.argmax(dim=-1)
    labels = labels.to(torch.int64)
    logprobs = torch.nn.functional.log_softmax(logits.to(torch.float16), dim=-1)

    prediction_logprobs = torch.gather(logprobs, index=labels[:, None], dim=-1).to(torch.float16)
    loss = -torch.mean(prediction_logprobs) if reduction=="mean" else - prediction_logprobs
    return loss

def update_results(filename, experiment_key, logger_metrics):
    """Store ``logger_metrics`` under ``experiment_key`` in the results file.

    The file at ``filename`` is read with ``torch.load``, the entry is added
    or overwritten, and the whole mapping is written back with
    ``torch.save``.

    Args:
        filename: Path of the results file.
        experiment_key: Key under which the metrics are stored.
        logger_metrics: Object to store for this experiment.
    """
    try:
        results = torch.load(filename)
    except FileNotFoundError:
        # First run: nothing has been saved yet.
        results = {}
    except Exception as err:
        # Stay best-effort for unreadable files, but say so instead of
        # silently discarding earlier results. KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        print(f"Could not load existing results from {filename!r} ({err!r}); starting a new results file.")
        results = {}

    results[experiment_key] = logger_metrics
    torch.save(results, filename)

def evaluate(model, data_loader, loss_function=cross_entropy_high_precision):
    """Compute the mean batch loss and the accuracy of ``model`` on ``data_loader``.

    The model is switched to eval mode (and left in it). Inputs are moved to
    the device and dtype of the model's first parameter; outputs and targets
    are compared on the CPU.

    Args:
        model: Module to evaluate.
        data_loader: Iterable of ``(data, target, ...)`` batches that also
            provides ``len()`` and a ``dataset`` attribute.
        loss_function: Callable ``(output, target) -> scalar tensor``.

    Returns:
        ``(loss, accuracy)``: the loss averaged over batches and the accuracy
        in percent over ``len(data_loader.dataset)`` samples.
    """
    model.eval()
    loss = 0
    correct = 0
    first_param = next(model.parameters())
    device = first_param.device
    float_precision = first_param.dtype
    with torch.no_grad():
        for data, target, *_ in data_loader:
            # The output is moved to the CPU, so the target has to live
            # there as well before it reaches the loss.
            target = target.to("cpu")
            output = model(data.to(device).to(float_precision)).to("cpu")
            loss += loss_function(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            # Targets that are not 1-D are reduced to class indices.
            if target.dim() != 1:
                target = target.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
    loss /= len(data_loader)
    accuracy = 100 * correct / len(data_loader.dataset)
    return loss, accuracy


def evaluate_transformer(model, data_loader, loss_function=cross_entropy_high_precision):
    """Compute the mean batch loss and accuracy using the model's last position.

    Same procedure as a plain evaluation loop, except that only
    ``output[:, -1]`` of the model output is scored and the inputs are not
    cast to the parameter dtype. The model is left in eval mode.

    Args:
        model: Module to evaluate.
        data_loader: Iterable of ``(data, target, ...)`` batches that also
            provides ``len()`` and a ``dataset`` attribute.
        loss_function: Callable ``(output, target) -> scalar tensor``.

    Returns:
        ``(loss, accuracy)``: the loss averaged over batches and the accuracy
        in percent over ``len(data_loader.dataset)`` samples.
    """
    model.eval()
    loss = 0
    correct = 0
    device = next(model.parameters()).device
    with torch.no_grad():
        for data, target, *_ in data_loader:
            # The output is moved to the CPU, so the target has to live
            # there as well before it reaches the loss.
            target = target.to("cpu")
            output = model(data.to(device)).to("cpu")[:, -1]
            loss += loss_function(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            # Targets that are not 1-D are reduced to class indices.
            if target.dim() != 1:
                target = target.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
    loss /= len(data_loader)
    accuracy = 100 * correct / len(data_loader.dataset)
    return loss, accuracy

def evaluate_full_batch(model, data, target, loss_function=cross_entropy_high_precision):
    """Evaluate ``model`` on a single batch holding the whole dataset.

    The model is switched to eval mode (and left in it).

    Args:
        model: Module to evaluate.
        data: Input tensor; moved to the device of the model's first parameter.
        target: Either 1-D class indices or a (sample, class) score/one-hot
            tensor.
        loss_function: Callable ``(output, target) -> scalar tensor``.

    Returns:
        ``(loss, accuracy)``: the loss value divided by ``len(target)`` and
        the accuracy in percent.
    """
    model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():
        target = target.to("cpu")
        output = model(data.to(device)).to("cpu")
        loss = loss_function(output, target).item()
        pred = output.argmax(dim=1, keepdim=True)
        # Accept class-index targets as well as one-hot/score targets.
        class_index = target if target.dim() == 1 else target.argmax(dim=1)
        correct = pred.eq(class_index.view_as(pred)).sum().item()
    # NOTE(review): the default loss_function already averages over the batch
    # (reduction="mean"), so this divides by the batch size a second time.
    # Kept as-is because callers may rely on the current scale - confirm.
    loss /= len(target)
    accuracy = 100. * correct / len(target)
    return loss, accuracy


def get_specified_args(parser, args):
    """Return the parsed arguments whose value differs from the parser default.

    An argument given explicitly with a value equal to its default is not
    included. The ``help`` action is ignored.
    """
    # Default of every registered action, keyed by destination name.
    defaults = {}
    for action in parser._actions:
        if action.dest == 'help':
            continue
        defaults[action.dest] = action.default

    specified = {}
    for name in vars(args):
        value = getattr(args, name)
        if value != defaults.get(name):
            specified[name] = value
    return specified

class Config:
    """Plain attribute container built from keyword arguments."""

    def __init__(self, **kwargs):
        # Every keyword becomes an instance attribute of the same name.
        for name in kwargs:
            setattr(self, name, kwargs[name])

    def __repr__(self):
        parts = [f"{name}={val!r}" for name, val in self.__dict__.items()]
        joined = ', '.join(parts)
        return f"Config({joined})"
    

def split_dataset(dataset, train_fraction, batch_size):
    """Randomly split ``dataset`` into a train part and a test part.

    Args:
        dataset: Dataset to split; must support ``len()``.
        train_fraction: Fraction of samples for the train part (the count is
            truncated to an integer).
        batch_size: Not used by this function.

    Returns:
        ``(train_dataset, test_dataset)`` as produced by ``random_split``.
    """
    n_total = len(dataset)
    n_train = int(train_fraction * n_total)
    n_test = n_total - n_train
    print(f'Starting trining. Train dataset size: {n_train}, Test size: {n_test}')
    train_part, test_part = random_split(dataset, [n_train, n_test])
    return train_part, test_part

def reduce_train_dataset(original_train_dataset, reduced_fraction, batch_size):
    """Build a shuffled loader over the leading fraction of a train subset.

    Args:
        original_train_dataset: A ``Subset`` (it must expose ``indices`` and
            ``dataset``), e.g. the train part returned by ``random_split``.
        reduced_fraction: Fraction of the subset's samples to keep (the count
            is truncated to an integer).
        batch_size: Batch size of the returned loader.

    Returns:
        A ``DataLoader`` with ``shuffle=True`` over the first
        ``int(reduced_fraction * len(indices))`` samples of the subset.
    """
    original_indices = original_train_dataset.indices
    reduced_train_size = int(reduced_fraction * len(original_indices))
    reduced_indices = original_indices[:reduced_train_size]
    # `indices` refer to positions in the parent dataset, so the reduced
    # subset has to be taken from the parent, not from the subset itself.
    reduced_train_dataset = Subset(original_train_dataset.dataset, reduced_indices)

    reduced_train_loader = DataLoader(reduced_train_dataset, batch_size=batch_size, shuffle=True)
    return reduced_train_loader

# Lookup from the operation name to the binary operation imported from
# `binary_operations`.
BINARY_OPERATION_MAP = {
    "add_mod": add_mod,
    "product_mod": product_mod,
    "subtract_mod": subtract_mod,
}
def get_dataset(args):
    """Build the train and test datasets selected by ``args.dataset``.

    ``"MNIST"`` loads torchvision's predefined train/test sets, truncated to
    ``args.num_samples`` with one-hot targets. ``"sparse_parity"``,
    ``"binary_alg"`` and ``"scalar_alg"`` build the matching dataset class;
    any other value builds an ``AlgorithmicDataset``. Non-MNIST datasets are
    split randomly using ``args.train_fraction``.

    Returns:
        ``(train_dataset, test_dataset)``.
    """
    if args.dataset == "MNIST":
        train_dataset = datasets.MNIST(root='data', train=True, transform=transforms.ToTensor(), download=True)
        test_dataset = datasets.MNIST(root='data', train=False, transform=transforms.ToTensor())
        # NOTE(review): targets become (sample, 10) one-hot tensors; verify
        # that consumers read `.data`/`.targets` in a way that supports this.
        train_dataset.targets = torch.nn.functional.one_hot(train_dataset.targets[:args.num_samples], 10)
        test_dataset.targets = torch.nn.functional.one_hot(test_dataset.targets[:args.num_samples], 10)
        train_dataset.data = train_dataset.data[:args.num_samples]
        test_dataset.data = test_dataset.data[:args.num_samples]
        # MNIST ships its own split, so return before the random split.
        return train_dataset, test_dataset

    if args.dataset == "sparse_parity":
        print(args.num_parity_features, args.num_noise_features)
        dataset = SparseParityDataset(args.num_parity_features, args.num_noise_features, args.num_samples)
    elif args.dataset == "binary_alg":
        dataset = BinaryAlgorithmicDataset(BINARY_OPERATION_MAP[args.binary_operation], p=args.modulo, input_size=args.input_size, output_size=args.modulo)
    elif args.dataset == "scalar_alg":
        dataset = ScalarAlgorithmicDataset(BINARY_OPERATION_MAP[args.binary_operation], p=args.modulo, input_size=args.input_size, output_size=args.modulo)
    else:
        dataset = AlgorithmicDataset(BINARY_OPERATION_MAP[args.binary_operation], p=args.modulo, input_size=args.input_size, output_size=args.modulo)

    # Every dataset built above is split here, so the return value is always
    # defined regardless of the dataset name.
    train_dataset, test_dataset = split_dataset(dataset, args.train_fraction, args.batch_size)
    return train_dataset, test_dataset

def generate_random_one_hot(length):
    """Return a 1-D tensor of size ``length`` with a single 1 at a random index."""
    hot_position = torch.randint(0, length, (1,)).item()
    vector = torch.zeros(length)
    vector[hot_position] = 1
    return vector

def get_model(args):
    """Build the ``MLP`` matching ``args.dataset`` and move it to ``GPU``.

    Input and output sizes depend on the dataset name. For ``"MNIST"`` every
    parameter is multiplied by 100 after construction. Only the fallback
    branch (any other dataset name) builds the model without bias and casts
    it to ``FLOAT_PRECISION_MAP[args.float_precision]``.

    Returns:
        The constructed model.
    """
    device = GPU

    if args.dataset == "sparse_parity":
        model = MLP(input_size=args.num_parity_features + args.num_noise_features, output_size=2, hidden_sizes=args.hidden_sizes,
                    freeze_layers=args.freeze_layers).to(device)
    elif args.dataset == "MNIST":
        model = MLP(input_size=28*28, output_size=10, hidden_sizes=args.hidden_sizes,
                    freeze_layers=args.freeze_layers).to(device)
        # Scale the initial parameters by a factor of 100.
        with torch.no_grad():
            for p in model.parameters():
                p.data = 100. * p.data
    elif args.dataset == "binary_alg":
        # Two operands, each encoded with the bit length of (input_size - 1).
        model = MLP(input_size=(args.input_size - 1).bit_length()*2, output_size=args.modulo, hidden_sizes=args.hidden_sizes,
                    freeze_layers=args.freeze_layers).to(device)
    elif args.dataset == "scalar_alg":
        model = MLP(input_size=2, output_size=args.modulo, hidden_sizes=args.hidden_sizes,
                    freeze_layers=args.freeze_layers).to(device)
    else:
        print("Using AlgorithmicDataset")
        model = MLP(input_size=args.input_size*2, output_size=args.modulo, hidden_sizes=args.hidden_sizes,
                    freeze_layers=args.freeze_layers, bias=False).to(device).to(FLOAT_PRECISION_MAP[args.float_precision])
    return model
        
def get_optimizer(model, args):
    """Create the optimizer named by ``args.optimizer`` for ``model``.

    Supported names are ``"Adam"``, ``"AdamW"`` and ``"SGD"``; only AdamW
    uses ``args.weight_decay``. All of them use ``args.lr``.

    Raises:
        ValueError: If ``args.optimizer`` is not one of the supported names.
    """
    chosen = args.optimizer
    if chosen == "Adam":
        return optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), weight_decay=0, eps=1e-8)
    if chosen == "AdamW":
        return optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=1e-30, betas=(0.9, 0.99))
    if chosen == "SGD":
        return optim.SGD(model.parameters(), lr=args.lr, momentum=0.8, weight_decay=0)
    raise ValueError(f'Unsupported optimizer type: {args.optimizer}')
    
import argparse

def parse_args():
    """Define the training command-line interface and parse ``sys.argv``.

    Returns:
        ``(parser, args)``: the ``ArgumentParser`` and the parsed namespace.
    """
    parser = argparse.ArgumentParser(description="Train a neural network with specified parameters.")

    parser.add_argument('--hidden_sizes', type=int, nargs='+', default=[200, 200],
                        help='List of hidden layer sizes. Default is [200, 200].')

    parser.add_argument('--num_epochs', type=int, default=1500,
                        help='Number of epochs. Default is 1500.')

    parser.add_argument('--train_fraction', type=float, default=0.3,
                        help='Fraction of data to be used for training. Default is 0.3.')

    parser.add_argument('--modulo', type=int, default=113,
                        help='Modulo value for modular arithmetic datasets. Default is 113.')

    parser.add_argument('--input_size', type=int, default=113,
                        help='Input size for the model. Default is 113.')

    parser.add_argument('--optimizer', type=str, default='AdamW',
                        help='Optimizer to use. Options: AdamW, Adam, SGD. Default is AdamW.')

    parser.add_argument('--loss_function', type=str, default='cross_entropy',
                        help='Loss function to use. Options: cross_entropy, MSE. Default is cross_entropy.')

    parser.add_argument('--log_frequency', type=int, default=50,
                        help='Logging frequency (in epochs). Default is 50.')

    parser.add_argument('--regularization', type=str, default="None",
                        help='Regularization method. Options: None, l1, l2. Default is None.')

    parser.add_argument('--binary_operation', type=str, default="add_mod",
                        help='Binary operation for algorithmic tasks. Options: add_mod, product_mod, subtract_mod')

    parser.add_argument('--lr', type=float, default=None,
                        help='Learning rate. Default is None.')

    parser.add_argument('--batch_size', type=int, default=128,
                        help='Batch size. Default is 128.')

    parser.add_argument('--freeze_layers', action='store_true', default=False,
                        help='Freeze layers during training. Default is False.')

    parser.add_argument('--full_batch', action='store_true', default=True,
                        help='Use full batch gradient descent. Default is True.')
    # `--full_batch` defaults to True, so on its own it can never be switched
    # off; this companion flag writes False to the same destination.
    parser.add_argument('--no_full_batch', dest='full_batch', action='store_false',
                        help='Disable full batch gradient descent.')

    parser.add_argument('--dataset', type=str, default="add_mod",
                        help='Dataset to use. Options: rotated_mnist, add_mod. Default is add_mod.')

    parser.add_argument('--temperature_schedule', action='store_true', default=False,
                        help='Use a schedule for softmax temperature. Default is False.')

    parser.add_argument('--num_noise_features', type=int, default=50,
                        help='Number of noise features used for SparseParityDataset. Default is 50.')

    parser.add_argument('--num_parity_features', type=int, default=4,
                        help='Number of parity features used for SparseParityDataset. Default is 4.')

    parser.add_argument('--num_samples', type=int, default=1000,
                        help='Number of samples for SparseParityDataset. Default is 1000.')

    parser.add_argument('--alpha', type=float, default=1.0,
                        help='Alpha coefficient that multiplies the logits. Default is 1.0.')

    parser.add_argument('--lambda_l1', type=float, default=0.00001,
                        help='L1 regularization coefficient. Default is 0.00001.')

    parser.add_argument('--lambda_l2', type=float, default=0.00005,
                        help='L2 regularization coefficient. Default is 0.00005.')

    parser.add_argument('--scale_down', type=str, default="None",
                        help='Option to scale down data. Default is None.')

    parser.add_argument('--float_precision', type=int, default=32,
                        help='Floating point precision: 16, 32, or 64. Default is 32.')

    parser.add_argument('--weight_decay', type=float, default=0,
                        help='Weight decay (L2 penalty) coefficient. Default is 0.')

    parser.add_argument('--use_lr_scheduler', action='store_true', default=False,
                        help='Use a learning rate scheduler. Default is False.')

    parser.add_argument('--orthogonal_gradients', action='store_true', default=False,
                        help='Use orthogonal gradients regularization. Default is False.')

    parser.add_argument('--asc', action='store_true', default=False,
                        help='Use Anti-Symmetric Connection regularization. Default is False.')

    return parser, parser.parse_args()
