import torch
import math
import torch.nn as nn
import numpy as np
import random
import os
from datetime import datetime

from torch.distributed import init_process_group, destroy_process_group

from torch.utils.tensorboard import SummaryWriter
import torch.distributed as dist

def set_seed(seed, deterministic = True):
    """Seed Python, NumPy and PyTorch RNGs and toggle deterministic mode.

    Args:
        seed: Value handed to every seeding routine and written (as a
            string) to the ``PYTHONHASHSEED`` environment variable.
        deterministic: When truthy, set ``CUBLAS_WORKSPACE_CONFIG``,
            enable cuDNN determinism, disable cuDNN benchmarking and turn
            on ``torch.use_deterministic_algorithms``. Otherwise remove
            the environment variable and apply the opposite settings.

    Note:
        Python reads ``PYTHONHASHSEED`` at interpreter start-up, so the
        assignment made here only influences processes started afterwards.
    """
    # Seed order: built-in random, NumPy, PyTorch CPU, current CUDA
    # device, all CUDA devices (multi-GPU).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Environment hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)

    # The cuBLAS workspace setting is only present in deterministic mode.
    strict = bool(deterministic)
    if strict:
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    else:
        os.environ.pop('CUBLAS_WORKSPACE_CONFIG', None)

    # cuDNN autotuning (benchmark) is the exact opposite of determinism.
    torch.backends.cudnn.deterministic = strict
    torch.backends.cudnn.benchmark = not strict
    torch.use_deterministic_algorithms(strict)  # PyTorch >= 1.8


def create_name(conf):
    """Build a run name from the configuration plus a timestamp suffix.

    The name always starts with ``<dataset>_<optim>_ne<num_epochs>``.
    Depending on ``conf.optim`` an optimiser-specific segment is appended,
    and the current local time (``_YYYY-mm-dd_HH-MM-SS``) comes last.

    Args:
        conf: Object exposing ``dataset``, ``optim``, ``num_epochs`` and
            the optimiser-specific attributes read below.

    Returns:
        The assembled name as a string.
    """
    stamp = datetime.now().strftime("_%Y-%m-%d_%H-%M-%S")
    pieces = [f"{conf.dataset}_{conf.optim}_ne{conf.num_epochs}"]

    # Sparse updates
    if conf.optim in ('GradSkip', 'LinBregSparse', 'SGD-sparse'):
        pieces.append(
            f'_lambda{conf.lambda0}_updtf{conf.full_update_frequency}_updtd{conf.full_update_duration}_{conf.full_update_mode[0]}'
        )
    # Sparse updates using multilevel update rule
    if conf.optim == 'LinBregSparseML':
        pieces.append(f'_kappa{conf.kappa}_eps{conf.eps}')
    # Split learning
    if conf.optim == 'LinBregSplit':
        pieces.append(
            f'_binf{conf.make_binary_frequency}_dropf{conf.drop_in_frequency}p{conf.drop_in_prob}'
        )

    pieces.append(stamp)
    return "".join(pieces)


def get_writer(log_dir, local_rank=0, distributed=False):
    """Return a ``SummaryWriter`` for ``log_dir``, or ``None`` on non-zero ranks.

    Args:
        log_dir: Directory passed to ``SummaryWriter``.
        local_rank: Rank of the calling process; only consulted when
            ``distributed`` is truthy.
        distributed: Whether the run is distributed.

    Returns:
        ``None`` when running distributed on a rank other than 0,
        otherwise a new ``SummaryWriter``.
    """
    # In distributed runs only rank 0 gets a writer.
    if distributed and local_rank != 0:
        return None
    return SummaryWriter(log_dir=log_dir)

#%% Distributed learning

def ddp_setup(rank, world_size, backend='nccl'):
    """Initialise the process group (if needed) and pick this process' device.

    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
        backend: Process-group backend to use when CUDA is available.
            Without CUDA the backend is forced to ``"gloo"``.

    Returns:
        ``torch.device``: with ``world_size > 1`` this is ``cuda:{rank}``
        when CUDA is available and ``cpu`` otherwise; in single-process
        mode it is ``cuda`` when available and ``cpu`` otherwise.
    """
    has_cuda = torch.cuda.is_available()

    if world_size > 1:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"

        if has_cuda:
            torch.cuda.set_device(rank)
            device = torch.device(f"cuda:{rank}")
        else:
            # Touching torch.cuda without a GPU raises, so the CPU/gloo
            # fallback must skip set_device and hand back a CPU device.
            backend = "gloo"
            device = torch.device("cpu")

        init_process_group(backend=backend, rank=rank, world_size=world_size)

        print(f"\nHello, I am on rank:{rank}\n")
        if rank == 0:
            print(f"[Distributed] Initialized {world_size} processes with backend '{backend}'.")
    else:
        device = torch.device("cuda" if has_cuda else "cpu")
        print("[Distributed] Running in single-process (non-distributed) mode.")

    return device


def cleanup_distributed():
    """Destroy the process group if one is initialised; otherwise do nothing."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
        
        
        


#%% Retrieval of weights

def get_weights_conv(model):
    """Lazily yield the ``weight`` of every ``Conv2d`` found in ``model.modules()``."""
    for layer in model.modules():
        if not isinstance(layer, torch.nn.Conv2d):
            continue
        yield layer.weight
            
def get_weights_linear(model):
    """Lazily yield the ``weight`` of every ``Linear`` found in ``model.modules()``."""
    for layer in model.modules():
        if not isinstance(layer, torch.nn.Linear):
            continue
        yield layer.weight

def get_weights_batch(model):
    """Lazily yield the ``weight`` of every ``BatchNorm2d`` found in ``model.modules()``.

    The attribute is yielded as-is, without a ``None`` check.
    """
    for layer in model.modules():
        if not isinstance(layer, torch.nn.BatchNorm2d):
            continue
        yield layer.weight

def get_bias(model):
    """Lazily yield every non-``None`` bias of Linear, Conv2d and BatchNorm2d layers.

    Modules are visited in ``model.modules()`` order.
    """
    biased_types = (torch.nn.Linear, torch.nn.Conv2d, torch.nn.BatchNorm2d)
    for layer in model.modules():
        if isinstance(layer, biased_types) and layer.bias is not None:
            yield layer.bias
        
#%% Definition/construction of loss function
def get_loss(conf):
    """Construct the loss function selected by ``conf.loss``.

    Args:
        conf: Object with a ``loss`` attribute; only ``'CrossEntropy'``
            is recognised.

    Returns:
        A new ``nn.CrossEntropyLoss`` instance.

    Raises:
        NotImplementedError: For any other value of ``conf.loss``.
    """
    if conf.loss != 'CrossEntropy':
        raise NotImplementedError(f'Did not recognise choice of loss function {conf.loss}')
    return nn.CrossEntropyLoss()

#%% Model sparsity initialisation.
# Taken from BregmanLearning https://github.com/TimRoith/BregmanLearning
def init_weight_bias_normal(m):
    """Re-draw weight and bias of an ``nn.Linear`` from a standard normal.

    Only modules whose type is exactly ``nn.Linear`` are touched
    (subclasses and other module types are left alone). Layers built with
    ``bias=False`` get only their weight re-initialised.

    Args:
        m: The module to (possibly) re-initialise.
    """
    if type(m) is not nn.Linear:
        return
    m.weight.data = torch.randn_like(m.weight.data)
    # A bias-free Linear has ``bias is None``; dereferencing it would raise.
    if m.bias is not None:
        m.bias.data = torch.randn_like(m.bias.data)
        
def sparsify_(model, sparsity, ltype = nn.Linear, conv_group=True, row_group = False):
    """Zero out a random part of the weights of all ``ltype`` layers, in place.

    Three masking schemes are used, chosen per module:

    * element-wise: Linear layers with ``row_group=False`` and Conv2d
      layers with ``conv_group=False``. Each entry is kept independently
      with probability ``1 - sparsity``.
    * row-wise: Linear layers with ``row_group=True``. Each row of the
      weight matrix is kept independently with probability
      ``1 - sparsity``.
    * kernel-wise: Conv2d layers with ``conv_group=True``. Exactly
      ``ceil(n * (1 - sparsity))`` of the ``n = out_channels *
      in_channels_per_group`` kernels are kept; the rest are zeroed.

    Modules matching ``ltype`` that are neither Linear nor Conv2d are
    left untouched.

    Args:
        model: Module whose ``modules()`` are scanned.
        sparsity: Target proportion of zeroed entries / rows / kernels.
        ltype: Layer type (or tuple of types) to process.
        conv_group: Use kernel-wise masking for Conv2d layers.
        row_group: Use row-wise masking for Linear layers.
    """
    s_loc = 1 - sparsity  # Proportion of nonzero entries

    for m in model.modules():
        if not isinstance(m, ltype):
            continue

        is_linear = isinstance(m, nn.Linear)
        is_conv = isinstance(m, nn.Conv2d)

        if (is_linear and not row_group) or (is_conv and not conv_group):
            # element-wise sparsity
            mask = torch.bernoulli(s_loc * torch.ones_like(m.weight))
            m.weight.data.mul_(mask)

        elif is_linear:  # row sparsity
            w = m.weight.data
            mask = torch.bernoulli(s_loc * torch.ones(size=(w.shape[0], 1), device=w.device))
            m.weight.data.mul_(mask)

        elif is_conv:  # kernel sparsity
            w = m.weight.data
            n = w.shape[0] * w.shape[1]
            n_keep = min(n, max(0, math.ceil(n * s_loc)))

            # Draw the kept kernels WITHOUT replacement: sampling indices
            # with replacement produces duplicates, so fewer kernels than
            # requested would survive (even for sparsity == 0).
            keep = torch.randperm(n, device=w.device)[:n_keep]
            mask = torch.zeros(n, 1, dtype=w.dtype, device=w.device)
            mask[keep] = 1

            # reshape (not view) so non-contiguous weights are handled too
            m.weight.data = (w.reshape(n, -1) * mask).reshape(w.shape)
        
def sparse_bias_uniform_(model,r0,r1,ltype = nn.Linear):
    """Draw biases of all ``ltype`` layers uniformly from a fan-in scaled range.

    For each matching module with a non-``None`` bias, the bias is filled
    in place from ``U(-r0 / sqrt(fan_in), r1 / sqrt(fan_in))`` where
    ``fan_in`` is computed from the module's weight.

    Args:
        model: Module whose ``modules()`` are scanned.
        r0: Scale of the lower bound (negated before use).
        r1: Scale of the upper bound.
        ltype: Layer type (or tuple of types) to process.
    """
    for layer in model.modules():
        if not isinstance(layer, ltype):
            continue
        if getattr(layer, 'bias', None) is None:
            continue
        fan_in = nn.init._calculate_fan_in_and_fan_out(layer.weight)[0]
        root = math.sqrt(fan_in)
        lower = -(r0 / root)
        upper = r1 / root
        nn.init.uniform_(layer.bias, lower, upper)
                
def bias_constant_(model,r):
    """Fill the bias of every plain ``nn.Linear`` layer with the constant ``r``.

    Only modules whose type is exactly ``nn.Linear`` are processed
    (subclasses are skipped, as before). Layers without a bias are
    skipped instead of raising.

    Args:
        model: Module whose ``modules()`` are scanned.
        r: Value written into each bias, in place.
    """
    for m in model.modules():
        # The exact-type test already implies isinstance(m, nn.Linear).
        if type(m) is nn.Linear and m.bias is not None:
            nn.init.constant_(m.bias, r)
                
def sparse_weight_normal_(model,r,ltype = nn.Linear):
    """Kaiming-normal initialise all ``ltype`` weights, then scale them by ``r``.

    Args:
        model: Module whose ``modules()`` are scanned.
        r: Factor multiplied into each freshly initialised weight, in place.
        ltype: Layer type (or tuple of types) to process.
    """
    matching = (layer for layer in model.modules() if isinstance(layer, ltype))
    for layer in matching:
        weight = layer.weight
        nn.init.kaiming_normal_(weight)
        weight.data.mul_(r)
                

def sparse_weight_uniform_(model,r,ltype = nn.Linear):
    """Draw all ``ltype`` weights uniformly with standard deviation ``r / sqrt(fan_in)``.

    Each matching weight is filled in place from ``U(-bound, bound)`` with
    ``bound = sqrt(3) * r / sqrt(fan_in)``.

    Args:
        model: Module whose ``modules()`` are scanned.
        r: Scale applied to the ``1 / sqrt(fan_in)`` standard deviation.
        ltype: Layer type (or tuple of types) to process. Defaults to
            ``nn.Linear``, matching the sibling initialisers in this file.
    """
    for m in model.modules():
        if not isinstance(m, ltype):
            continue
        fan = nn.init._calculate_correct_fan(m.weight, 'fan_in')
        std = r / math.sqrt(fan)
        bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

        with torch.no_grad():
            m.weight.uniform_(-bound, bound)

def init_weights(conf, model):
    """Initialise and sparsify the Linear and Conv2d layers of ``model``.

    Steps, in order: uniform biases scaled by ``conf.r[0]`` (Linear, then
    Conv2d); Kaiming-normal weights scaled by ``conf.r[1]`` (Linear) and
    ``conf.r[2]`` (Conv2d); sparsification of Conv2d and then Linear
    weights with ``conf.model_init_sparsity``.

    Args:
        conf: Object exposing ``r``, ``model_init_sparsity`` and
            ``conv_group``.
        model: The module to initialise in place.

    Returns:
        The same ``model`` object.
    """
    linear_cls, conv_cls = nn.Linear, torch.nn.Conv2d

    # Biases: both layer kinds use the r[0] scale for the upper bound.
    for layer_cls in (linear_cls, conv_cls):
        sparse_bias_uniform_(model, 0, conf.r[0], ltype=layer_cls)

    # Weights: r[1] scales Linear layers, r[2] scales Conv2d layers.
    for layer_cls, scale_idx in ((linear_cls, 1), (conv_cls, 2)):
        sparse_weight_normal_(model, conf.r[scale_idx], ltype=layer_cls)

    # Sparsify convolutions first, then linear layers.
    sparsify_(model, conf.model_init_sparsity, ltype=conv_cls, conv_group=conf.conv_group)
    sparsify_(model, conf.model_init_sparsity, ltype=linear_cls)
    return model
    