import re, os, shutil, random
import datetime
import numpy as np
from tqdm import tqdm
from copy import deepcopy

import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR
from torch.func import functional_call

import subprocess
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv

def get_free_memory():
    """
    Query nvidia-smi for the free memory of every GPU it reports.

    Returns a list with one integer per non-empty output line, in the order
    nvidia-smi prints them (presumably MiB, since the query asks for
    `memory.free` with `nounits` -- confirm against your driver version).
    """
    query = ['nvidia-smi', '--query-gpu=memory.free', '--format=csv,noheader,nounits']
    completed = subprocess.run(query, stdout=subprocess.PIPE)
    free_per_gpu = []
    for line in completed.stdout.decode('utf-8').split('\n'):
        # Skip empty lines such as the one after the trailing newline.
        if line:
            free_per_gpu.append(int(line))
    return free_per_gpu

def get_available_device(memory_threshold=0.9):
    """
    Pick the first CUDA device whose free-memory fraction exceeds memory_threshold.

    Returns torch.device("cpu") when CUDA is unavailable or when no GPU has
    enough free memory. Free memory comes from get_free_memory() and is indexed
    by the torch device index.
    """
    cpu = torch.device("cpu")
    if not torch.cuda.is_available():
        print("No available GPU. Falling back to CPU.")
        return cpu

    free_per_gpu = get_free_memory()
    for gpu_index in range(torch.cuda.device_count()):
        # total_memory is in bytes; convert to MB to match the nvidia-smi figure.
        total_mb = torch.cuda.get_device_properties(gpu_index).total_memory // (1024 * 1024)
        free_fraction = free_per_gpu[gpu_index] / total_mb
        print(f"GPU {gpu_index}: Free Memory Percentage: {free_fraction * 100:.2f}%")
        if free_fraction > memory_threshold:
            print(f"Using GPU {gpu_index} with {free_fraction * 100:.2f}% free memory.")
            return torch.device(f"cuda:{gpu_index}")

    print(f"No GPU has more than {memory_threshold * 100}% free memory. Falling back to CPU.")
    return cpu

def save_model(state_dict, is_best, log_dir):
    """
    Write state_dict to `<log_dir>/latest.pth`, and also to
    `<log_dir>/best.pth` when is_best is truthy.
    """
    suffixes = ['/latest.pth']
    if is_best:
        suffixes.append('/best.pth')
    for suffix in suffixes:
        torch.save(state_dict, log_dir + suffix)

def save_py(log_dir, py_dir='./'):
    """
    Copy every `*.py` file found directly in py_dir into `<log_dir>/codes`.

    The `codes` directory is only created when at least one script is copied.
    """
    backup_dir = os.path.join(log_dir, 'codes')
    scripts = [entry for entry in os.listdir(py_dir) if entry.endswith(".py")]
    for script in scripts:
        os.makedirs(backup_dir, exist_ok=True)
        shutil.copy(os.path.join(py_dir, script), os.path.join(backup_dir, script))

def validation(net, data, criterion, device):
    """
    Evaluate the model on the nodes selected by `data.test_mask`.

    Args:
        net: model called as `net(data)`; its output is indexed per node.
        data: object exposing `y` and `test_mask` (plus whatever `net` reads).
        criterion: loss called as `criterion(outputs[mask], targets[mask])`.
        device: unused; kept so existing callers keep working.

    Returns:
        (valid_loss, valid_accuracy): the loss as a NumPy scalar and the
        accuracy as a 0-dim tensor (fraction of test nodes predicted correctly).

    The forward pass runs under `torch.no_grad()` so no autograd graph is
    built, and the model's previous train/eval mode is restored afterwards
    (the earlier version always left the model in train mode).
    """
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            outputs = net(data)
            loss = criterion(outputs[data.test_mask], data.y[data.test_mask])
            pred = outputs.argmax(dim=1)
            correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
            valid_accuracy = 0. + correct / int(data.test_mask.sum())
            # `0. + ndarray` keeps the historical return type (a NumPy scalar).
            valid_loss = 0. + loss.cpu().detach().numpy()
    finally:
        # Restore the caller's mode even if the forward pass raises.
        net.train(was_training)
    return valid_loss, valid_accuracy

class FakeBatchNorm2d(nn.BatchNorm2d):
    """
    Drop-in stand-in for BatchNorm2d that skips normalization entirely.

    The forward pass only multiplies the input by the per-channel `weight`;
    the bias and running statistics inherited from BatchNorm2d are not used.
    """
    def __init__(self, num_features: int):
        super(FakeBatchNorm2d, self).__init__(num_features)

    def forward(self, x):
        # Reshape weight to (1, C, 1, 1) so it broadcasts over the other dims of x.
        per_channel_scale = self.weight.view(1, -1, 1, 1)
        return per_channel_scale * x

def replace_module(model, target, alternative=None):
    """
    Replace every sub-module of `model` that is an instance of `target`.

    Args:
        model: the nn.Module to modify in place.
        target: a class or tuple of classes, as accepted by isinstance().
        alternative: module used as the replacement; each replaced slot
            receives its own deepcopy of it.

    Returns:
        The same `model` object, modified in place.

    BatchNorm2d matches are always replaced by a fresh
    FakeBatchNorm2d(num_features) regardless of `alternative`. The replacement
    is chosen per match: previously the BatchNorm substitute overwrote
    `alternative` and leaked into every later non-BatchNorm match.
    """
    # Collect first so the module tree is not mutated while it is traversed.
    modules_to_replace = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, target)
    ]
    for name, module in modules_to_replace:
        if isinstance(module, nn.BatchNorm2d):
            replacement = FakeBatchNorm2d(module.num_features)
        else:
            replacement = deepcopy(alternative)
        # Walk the dotted name down to the parent that owns the attribute.
        *parent_path, child_name = name.split(".")
        parent_module = model
        for sub_name in parent_path:
            parent_module = getattr(parent_module, sub_name)
        setattr(parent_module, child_name, replacement)
    return model

def para_normalization(x):
    """
    Return |x| divided by the sum of |x|, so entries are non-negative and sum to 1.
    """
    magnitudes = torch.abs(x)
    total = torch.sum(magnitudes)
    return magnitudes / total

def dwf_init_param(param, std, D, eps=3e-3, fan_in=0):
    """
    Initialize parameters using the DWF method.

    Draws values from N(0, std) and keeps only those whose magnitude lies
    strictly inside (omega_min, omega_max), resampling the rejected entries
    until every entry has been accepted.

    Args:
        param: tensor used as a template for shape, dtype and device; its
            values are not read.
        std: standard deviation of the normal proposal (must be > 0).
        D: number of factors; the bounds are D-th roots.
        eps: lower bound on the collapsed magnitude; omega_min = eps ** (1 / D).
        fan_in: fan-in of the layer; for fan_in <= 0 the upper bound is 1.0.

    Returns:
        A new tensor shaped like `param` with every |value| in (omega_min, omega_max).

    Raises:
        ValueError: if std <= 0 or the acceptance interval is empty, since
            rejection sampling could then never terminate.
    """
    omega_min = eps ** (1 / D)
    if fan_in > 0:
        omega_max = min(1.0, 2 * (np.sqrt(2) / np.sqrt(fan_in)) ** (1 / D))
    else:
        # The old expression divided by zero here and ended up at 1.0 anyway.
        omega_max = 1.0

    if std <= 0:
        raise ValueError(f"std must be positive, got {std}")
    if omega_min >= omega_max:
        raise ValueError(
            f"Empty acceptance interval: omega_min={omega_min} >= omega_max={omega_max}"
        )

    initialized = torch.zeros_like(param)
    # Track pending entries explicitly. Deriving the mask from the values of an
    # uninitialised buffer could treat garbage (including NaN) as accepted.
    pending = torch.ones_like(param, dtype=torch.bool)

    while pending.any():
        resample = torch.empty_like(param).normal_(0, std)
        magnitude = resample.abs()
        accept = pending & (magnitude > omega_min) & (magnitude < omega_max)
        initialized[accept] = resample[accept]
        pending &= ~accept

    return initialized

class GCN(nn.Module):
    def __init__(self, dataset, let=None, D=2, eps=3e-3):
        """
        Build a seven-layer GCN with one multiplicative control vector after
        each of the first six layers.

        Args:
            dataset: object exposing `num_node_features` and `num_classes`.

        DWF args:
        - let: symbol of the factor (e.g. 'u', 'v'); only checked against None:
               any non-None value selects DWF initialization, None selects
               Kaiming-normal initialization of the layer weights
        - D: the number of factors used by the DWF method
        - eps: minimum absolute value of the collapsed parameters
        """
        super().__init__()

        self.D = D
        self.eps = eps

        # Layer widths: input -> 512 -> 256 x4 -> 64 -> classes.
        widths = [dataset.num_node_features, 512, 256, 256, 256, 256, 64, dataset.num_classes]
        for idx, (in_width, out_width) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"conv{idx}", GCNConv(in_width, out_width))

        self.act = nn.ReLU(inplace=True)

        use_dwf_init = let is not None
        for layer in self.modules():
            if not isinstance(layer, GCNConv):
                continue
            linear = layer.lin
            if not use_dwf_init:
                nn.init.kaiming_normal_(linear.weight, mode='fan_in', nonlinearity='relu')
                continue
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(linear.weight)
            std_dev = (2.0 / fan_in) ** (1.0 / (2 * D))
            linear.weight.data = dwf_init_param(
                linear.weight.data, std_dev, D=self.D, eps=self.eps, fan_in=fan_in
            )
            if hasattr(linear, 'bias') and linear.bias is not None:
                linear.bias.data = dwf_init_param(
                    linear.bias.data, std_dev, D=self.D, eps=self.eps, fan_in=fan_in
                )

        # One control vector per hidden layer (control0 ... control5), all ones.
        for idx, width in enumerate(widths[1:-1]):
            setattr(self, f"control{idx}", nn.Parameter(torch.ones(width)))

    def forward(self, data):
        """
        Run the network on `data.x` / `data.edge_index` and return the output
        of the last convolution (no activation after it).
        """
        x, edge_index = data.x, data.edge_index
        gated_layers = (
            (self.conv1, self.control0),
            (self.conv2, self.control1),
            (self.conv3, self.control2),
            (self.conv4, self.control3),
            (self.conv5, self.control4),
            (self.conv6, self.control5),
        )
        for conv, control in gated_layers:
            x = conv(x, edge_index)
            x = self.act(x) * control
        return self.conv7(x, edge_index)

    def total_channels(self):
        """
        Sum of the output channels of the six hidden convolutions (conv1-conv6).
        """
        hidden_convs = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6)
        return sum(conv.out_channels for conv in hidden_convs)

    def l1_norm(self):
        """
        Sum of absolute values over every parameter of the model.
        """
        return sum((torch.sum(torch.abs(p)) for p in self.parameters()), 0.0)

    def connectivity(self, data):
        """
        Log of the summed output of a linearised copy of the model.

        A deep copy of the model (the skeleton) has its ReLU/Dropout modules
        replaced by Identity and is fed a copy of `data` whose features are all
        ones. functional_call runs the skeleton with a temporary parameter set
        derived from this model's parameters -- biases set to zero, everything
        else passed through para_normalization -- so gradients can flow back
        to the original parameters.
        """
        device = next(self.parameters()).device

        ones_input = deepcopy(data)
        ones_input.x = torch.ones_like(ones_input.x)

        skeleton = replace_module(deepcopy(self), target=(nn.ReLU, nn.Dropout), alternative=nn.Identity())
        skeleton.to(device)

        params = {
            name: torch.zeros_like(value).to(device) if 'bias' in name else para_normalization(value)
            for name, value in self.named_parameters()
        }
        return torch.log(torch.sum(functional_call(skeleton, params, (ones_input,))))

def prune(net, ratio, data, path=None):
    """
    Zero the control scalars with the smallest importance score dL/dw * w.

    The score is computed on a linearised copy of `net` (ReLU/Dropout replaced
    by Identity, biases zeroed, other parameters passed through
    para_normalization) fed with all-ones features, where L is the log of the
    summed output.

    Args:
        net: model whose `control*` parameters are pruned in place.
        ratio: fraction of control scalars used to pick the global threshold.
        data: graph data; a copy with all-ones features is used for scoring.
        path: unused; kept for backward compatibility with existing callers.

    Returns:
        dict mapping each control parameter name to the boolean mask of the
        entries that were set to zero.
    """
    device = next(net.parameters()).device

    data_ = deepcopy(data)
    data_.x = torch.ones_like(data_.x)
    skeleton = deepcopy(net)
    skeleton = replace_module(skeleton, target=(nn.ReLU, nn.Dropout), alternative=nn.Identity())
    skeleton.to(device)
    skeleton.zero_grad()

    params = {}
    for k, v in net.named_parameters():
        if 'bias' in k:
            params[k] = torch.zeros_like(v).to(device)
        else:
            params[k] = para_normalization(v)

    # Load the normalised parameters into the skeleton so that backward()
    # leaves gradients on the skeleton's own parameters.
    with torch.no_grad():
        for name, param in skeleton.named_parameters():
            if name in params:
                param.copy_(params[name])

    obj = torch.log(torch.sum(skeleton(data_)))
    obj.backward()

    # Importance scores of the control vectors, computed once and reused below.
    scored = []
    for (k, v), (_, v_s) in zip(net.named_parameters(), skeleton.named_parameters()):
        if 'control' in k:
            scored.append((k, v, (v_s.grad * v_s).detach()))

    # Global pruning: a single threshold over all control scalars.
    all_values = torch.cat([score.view(-1) for _, _, score in scored])
    sorted_values, _ = torch.sort(all_values)
    num_to_reset = int(len(sorted_values) * ratio)
    threshold = sorted_values[num_to_reset - 1] if num_to_reset > 0 else float('-inf')

    mask_dict = {}
    for k, v, score in scored:
        # NOTE(review): the strict `<` leaves the threshold entry itself (and
        # any ties) unpruned, so at most num_to_reset - 1 entries are zeroed.
        # Kept as-is to preserve existing results -- confirm this is intended.
        mask = score < threshold
        v.data[mask] = 0.
        mask_dict[k] = mask  # Save the mask for later use
    return mask_dict


def _collapse_factor_params(param_names, factor_param_dicts):
    """Multiply the same-named parameters of all DWF factor networks together."""
    collapsed = {}
    for name in param_names:
        product = factor_param_dicts[0][name]
        for factor_params in factor_param_dicts[1:]:
            product = product * factor_params[name]
        collapsed[name] = product
    return collapsed


def run(mode, lamda, weight_decay=0., D=1, local=False, path=None, seed=0):
    """
    Run the GNN training and pruning process.

    Trains a GCN on Cora for 300 epochs, then for a sequence of pruning ratios
    prunes a copy of the trained network, evaluates it, fine-tunes it for 100
    epochs and logs/saves the result under `./logs_gnn/`.

    Args:
        mode: 0 = no regularization, 1 = L1 regularization,
              2 = connectivity regularization, 3 = DWF (factorized weights).
        lamda: regularization strength for modes 1 and 2 (also part of the
               log directory name).
        weight_decay: Adam weight decay for training and fine-tuning.
        D: number of DWF factor networks (only used when mode == 3).
        local: only used in the log directory name.
        path: optional directory holding `latest.pth`, `optimizer.pth` and
              `scheduler.pth` to resume from.
        seed: seed for torch, numpy and random.

    Raises:
        ValueError: if `mode` is not one of 0, 1, 2, 3.
    """
    if mode not in (0, 1, 2, 3):
        # Previously an unknown mode only failed at the first epoch with an
        # unbound `total_loss`, after the log directory had been created.
        raise ValueError(f"Unknown mode {mode!r}; expected one of 0, 1, 2, 3.")

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    device = get_available_device()
    epochs = 300
    epochs_fine = 100

    total_channels = 1600 # total channels in the GCN
    min_channels = 25 # minimum channels to keep after pruning

    dataset = Planetoid(root='./', name='Cora')
    data = dataset[0]
    data = data.to(device)

    net = GCN(dataset).to(device)

    if mode == 3: # DWF
        net_factors = [GCN( dataset, let=f"{chr(117+i)}", D=D, eps=3e-3 ).to(device) for i in range(D)] # factors obtain symbols u,v,...
        params_to_optimize = []
        for factor_net in net_factors:
            params_to_optimize += [p for name, p in factor_net.named_parameters()]
        # Built once: the Parameter objects are updated in place by the optimizer.
        factor_param_dicts = [dict(factor_net.named_parameters()) for factor_net in net_factors]
        net_param_names = [name for name, _ in net.named_parameters()]
    else:
        params_to_optimize = [p for name, p in net.named_parameters()]
    optimizer = optim.Adam(params_to_optimize, lr=5e-3, weight_decay=weight_decay)

    warmup_epochs = 10
    total_epochs = epochs

    scheduler = SequentialLR(
        optimizer,
        schedulers=[
            LinearLR(optimizer, start_factor=1e-2, end_factor=1.0, total_iters=warmup_epochs),
            CosineAnnealingLR(optimizer, T_max=total_epochs - warmup_epochs)
        ],
        milestones=[warmup_epochs]
    )

    criterion = nn.CrossEntropyLoss()

    if path is not None:
        # map_location lets a checkpoint written on another device be resumed here.
        # NOTE(review): only `net` is restored; in mode 3 the factor networks are
        # not, and `net` is overwritten by their product in the first epoch.
        net.load_state_dict(torch.load(path+'/latest.pth', map_location=device))
        optimizer.load_state_dict(torch.load(path+'/optimizer.pth', map_location=device))
        scheduler.load_state_dict(torch.load(path+'/scheduler.pth', map_location=device))

    # prepare for logging
    current_time = re.sub(r'\D', '', str(datetime.datetime.now())[4:-7])
    log_dir = './logs_gnn/'+ current_time + f'_mode_{str(mode)}_lam_{lamda:.0e}_local_{local}_decay_{weight_decay:.0e}' 
    writer = SummaryWriter(log_dir=log_dir)
    try:
        save_py(log_dir, py_dir='./')

        net.train()
        if mode == 3: # DWF
            for net_i in net_factors:
                net_i.train()

        best_accuracy = -1.
        for epoch in range(scheduler.last_epoch, epochs):
            train_loss, train_accuracy = 0., 0.
            log_connect, l1_norm = 0., 0.

            optimizer.zero_grad()

            if mode == 3: # DWF
                # Collapse the factors and run the forward pass on the product,
                # so gradients flow to the factor parameters.
                params = _collapse_factor_params(net_param_names, factor_param_dicts)
                outputs = functional_call(net, params, (data,))
            else:
                outputs = net(data)

            loss = criterion(outputs[data.train_mask], data.y[data.train_mask])

            if mode == 3: # DWF
                # store collapsed network in net
                with torch.no_grad():
                    for name_net, param_net in net.named_parameters():
                        param_net.data.copy_(params[name_net])

            con = net.connectivity(data)
            l1 = net.l1_norm()

            if mode == 0:
                total_loss = loss
            elif mode == 1:
                total_loss = loss + lamda * l1
            elif mode == 2:
                total_loss = loss - lamda * con
            elif mode == 3:
                total_loss = loss

            total_loss.backward()

            pred = outputs.argmax(dim=1)
            train_accuracy += (pred[data.train_mask] == data.y[data.train_mask]).sum() / int(data.train_mask.sum())
            train_loss += loss.cpu().detach().numpy()
            log_connect += con.cpu().detach().numpy()
            l1_norm += l1.cpu().detach().numpy()

            # Clip the parameters the optimizer actually updates. In mode 3 these
            # are the factor parameters; `net.parameters()` carry no gradient
            # there, so clipping them was a silent no-op.
            torch.nn.utils.clip_grad_norm_(params_to_optimize, max_norm=0.1)
            optimizer.step()
            scheduler.step()

            # log and save
            valid_loss, valid_accuracy = validation(net, data, criterion, device)
            log_info = f'Train Epoch:{epoch:3d} || train loss:{train_loss:.2e} train accuracy:{train_accuracy*100:.2f}% ' + \
                       f'valid loss:{valid_loss:.4e} valid accuracy:{valid_accuracy*100:.2f}% lr:{scheduler.get_last_lr()[0]:.2e} ' + \
                       f'log connect:{log_connect:.4e} l1 norm:{l1:.4e}'
            save_model(net.state_dict(), valid_accuracy >= best_accuracy, log_dir)
            torch.save(optimizer.state_dict(), log_dir + '/optimizer.pth')
            torch.save(scheduler.state_dict(), log_dir + '/scheduler.pth')
            best_accuracy = valid_accuracy+0. if valid_accuracy >= best_accuracy else best_accuracy
            writer.add_scalar('train/log_connect', log_connect, epoch)
            writer.add_scalar('train/l1_norm', l1_norm, epoch)
            writer.add_scalar('train/loss', train_loss, epoch)
            writer.add_scalar('train/accuracy', train_accuracy, epoch)
            writer.add_scalar('valid/loss', valid_loss, epoch)
            writer.add_scalar('valid/accuracy', valid_accuracy, epoch)
            print(log_info)

        # Remaining-channel targets: halve total_channels until below min_channels
        # (1600, 800, ..., 25 with the defaults above).
        remaining_channels = []
        val = total_channels
        while val >= min_channels:
            remaining_channels.append(val)
            val = val // 2  # halve each step

        remaining_channels = np.array(remaining_channels)
        pruning_ratios = 1.0 - (remaining_channels / total_channels)
        pruning_ratios = np.round(pruning_ratios, 6)
        for ratio in pruning_ratios:

            prune_net = deepcopy(net)
            mask_dict = prune(prune_net, ratio, data, log_dir)
            pred = prune_net(data).argmax(dim=1)
            train_accuracy = (pred[data.train_mask] == data.y[data.train_mask]).sum() / int(data.train_mask.sum())
            valid_loss, valid_accuracy = validation(prune_net, data, criterion, device)
            log_connect = prune_net.connectivity(data).cpu().detach().numpy()

            # check if is nan
            if log_connect is None or np.isnan(log_connect): valid_accuracy = 0.

            log_info = f'Pruning ratio:{ratio*100:.2f}% || train accuracy:{train_accuracy*100:.2f} valid loss:{valid_loss:.4e} valid accuracy:{valid_accuracy*100:.2f}% connectivity:{log_connect:.4e}'
            writer.add_scalar('channel_prune/valid_loss', valid_loss, int(ratio*100))
            writer.add_scalar('channel_prune/valid_accuracy', valid_accuracy, int(ratio*100))
            writer.add_scalar('channel_prune/log_connect', log_connect, int(ratio*100))
            print(log_info)

            # fine-tuning
            optimizer_fine = optim.Adam([p for name, p in prune_net.named_parameters()], lr=5e-4, weight_decay=weight_decay)
            # NOTE(review): the cosine horizon uses the 300-epoch training length
            # although fine-tuning runs for epochs_fine epochs -- confirm intent.
            scheduler_fine = SequentialLR(
                                    optimizer_fine,
                                    schedulers=[
                                        LinearLR(optimizer_fine, start_factor=1e-2, end_factor=1.0, total_iters=warmup_epochs),
                                        CosineAnnealingLR(optimizer_fine, T_max=total_epochs - warmup_epochs)
                                    ],
                                    milestones=[warmup_epochs]
                                )
            log_dir_fine = log_dir+f'/fine_tuning/{ratio:.2f}'
            prune_net.train()
            best_accuracy = -1.

            for epoch in range(epochs_fine):
                optimizer_fine.zero_grad()

                outputs = prune_net(data)
                loss = criterion(outputs[data.train_mask], data.y[data.train_mask])
                loss.backward()

                torch.nn.utils.clip_grad_norm_(prune_net.parameters(), max_norm=1.0)
                optimizer_fine.step()
                scheduler_fine.step()

                # ensure masked params stay 0
                with torch.no_grad():
                    for name, param in prune_net.named_parameters():
                        if name in mask_dict:
                            param.data[mask_dict[name]] = 0.

            log_connect = prune_net.connectivity(data).cpu().detach().numpy()
            valid_loss, valid_accuracy = validation(prune_net, data, criterion, device)

            # check if is nan
            if log_connect is None or np.isnan(log_connect): valid_accuracy = 0.

            log_info = f'Fine-tuning || valid loss:{valid_loss:.4e} valid accuracy:{valid_accuracy*100:.2f}% connectivity:{log_connect:.4e}'
            os.makedirs(log_dir_fine, exist_ok=True)
            save_model(prune_net.state_dict(), valid_accuracy >= best_accuracy, log_dir_fine)
            best_accuracy = valid_accuracy+0. if valid_accuracy >= best_accuracy else best_accuracy
            writer.add_scalar('fine/log_connect', log_connect, int(ratio*100))
            writer.add_scalar('fine/loss', valid_loss, int(ratio*100))
            writer.add_scalar('fine/accuracy', valid_accuracy, int(ratio*100))
            print(log_info)
    finally:
        # run() is called once per seed in the same process; close the writer so
        # pending events are flushed and its file handle is released.
        writer.close()
    

if __name__ == "__main__":

    # Each entry is (mode, lam, weight_decay, factor_count);
    # factor_count is 1 for all methods other than DWF.
    process_args = [
        # (0, 0, 1e-3, 1), # mode 0 is no regularization
        # (1, 1e-4, 1e-4, 1), # mode 1 is l1 regularization
        # (2, 1e-0, 1e-3, 1), # mode 2 is connect regularization
        (3, 1e-0, 1e-4, 2), # mode 3 is DWF regularization
    ]
    # Repeat every configuration for seeds 0..9.
    for seed in range(10):
        for mode, lam, weight_decay, factor_count in process_args:
            run(mode, lam, weight_decay, factor_count, seed=seed)
