import os
import time
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from soul.model import *
from soul.neuron import *
from soul.utils import *


# init all config settings
config = init_config()

# activate distributed mode only when both RANK and WORLD_SIZE are present in
# the environment (as set by a multi-process launcher, e.g. torchrun)
config['is_distributed'] = all(var in os.environ for var in ("RANK", "WORLD_SIZE"))
if not config['is_distributed']:
    # single-process run: first GPU when available, otherwise CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    local_rank, global_rank = 0, 0
else:
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    # gpu bound to the current process
    device = torch.device("cuda", local_rank)
    # rank across all processes; 0 is the main process
    global_rank = dist.get_rank()

# init logger
data_dir = config['data_dir']
# last path component of the data directory; normpath drops a trailing
# separator so "data/foo/" yields "foo" instead of an empty string
channel_info = os.path.basename(os.path.normpath(data_dir))
T_info = f"T{config['time_step']}"
if global_rank == 0:
    log_path = os.path.join(
        config['log_dir'],
        config['dataset_name'].lower(),
        channel_info,
        f"seed{config['seed']}",
        config['model'].lower(),
        config['neuron_type'].lower(),
    )
    ensure_dir(log_path)
    print(log_path)
    logger = setup_logger(os.path.join(log_path, f'record-{get_local_time()}.log'), default_level=config['state'])
    logger.info(f'Distributed Training: {config["is_distributed"]}')

    # report all configuration (main process only)
    for k, v in sorted(config.items()):
        logger.debug(f'{k} = {v}')
else:
    logger = None

# reproducibility: seed EVERY process. Seeding only rank 0 would leave the
# other ranks on unseeded RNGs, making distributed runs non-reproducible.
if global_rank == 0:
    logger.info(f'Reproducibility with random seed {config["seed"]}')
init_seed(config["seed"])
if global_rank == 0:
    logger.info('=' * 50)

# load data TODO
if global_rank == 0:
    logger.info('Load data...')
train_dataset, test_dataset = load_dataset(config)

if config['is_distributed']:
    train_sampler = torch.utils.data.DistributedSampler(train_dataset)
    # config['batch_size'] is the global batch size; split it across processes
    # (usually one process per used GPU). Clamp to 1 so that a global batch
    # smaller than the world size does not produce an invalid batch_size of 0.
    world_size = dist.get_world_size()
    per_process_batch_size = max(1, config['batch_size'] // world_size)
    if global_rank == 0 and per_process_batch_size * world_size != config['batch_size']:
        logger.warning(
            f"Global batch size {config['batch_size']} is not divisible by world size {world_size}; "
            f"using {per_process_batch_size} samples per process "
            f"(effective global batch size {per_process_batch_size * world_size})"
        )
    config['batch_size'] = per_process_batch_size
else:
    train_sampler = None

# load dataloader; in distributed mode the sampler owns the ordering, so the
# loader itself only shuffles when no sampler is given
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=train_sampler is None,
    sampler=train_sampler,
    num_workers=config['workers'],
    pin_memory=True,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=config['workers'],
    pin_memory=True,
)

# load SNN model
if global_rank == 0:
    logger.info(f'Load SNN model: {config["model"]} featured {config["neuron_type"].upper()} neuron...')
    logger.info(f'#Training Samples: {len(train_dataset)}; #Test Samples: {len(test_dataset)}')
    logger.debug(f'surrogate function: {config["surrogate"]}')

# look up the surrogate function and build the neuron from the config
config['surrogate_function'] = surrogate_map[config['surrogate']]
config['neuron'] = neuron_map[config['neuron_type'].lower()](config)

model_builder = model_map[config['application']][config['model'].lower()]
model = model_builder(config)
if global_rank == 0:
    logger.debug('\n' + str(model))
model.to(device)

if global_rank == 0:
    # calculate number of trainable parameters (on the bare model, before DDP)
    n_parameters = count_parameters(model, trainable=True)
    logger.info(f"Number of params for model {config['model']}: {n_parameters / 1e6:.2f} M")

if config['is_distributed']:
    model = DDP(model, device_ids=[local_rank])

criterion = nn.CrossEntropyLoss()

# init optimizer
optimizer_name = config['optimizer'].lower()
if optimizer_name == 'sgd':
    optimizer = optim.SGD(model.parameters(), lr=config['learning_rate'], momentum=config['momentum'], weight_decay=config['weight_decay'])
elif optimizer_name == 'adam':
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
elif optimizer_name == 'adamw':
    optimizer = optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
elif optimizer_name == 'rmsprop':
    optimizer = optim.RMSprop(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
else:
    if global_rank == 0:
        logger.warning(f"Received unrecognized optimizer {config['optimizer']}, set default Adam optimizer")
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

# init scheduler. Periods derived from a fraction of the epoch budget are
# clamped to >= 1: StepLR with step_size=0 raises ZeroDivisionError on its
# first step() (epochs < 4), and CosineAnnealingWarmRestarts rejects T_0=0
# at construction (epochs < 10).
scheduler_name = config['scheduler'].lower()
if scheduler_name == 'cosine':
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config["epochs"])
elif scheduler_name == 'linear':
    # NOTE: despite the key name this is a step decay: lr *= 0.1 every quarter of training
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=max(1, int(config["epochs"] * 0.25)), gamma=0.1)
elif scheduler_name == 'warmup':
    # NOTE: cosine annealing with warm restarts; first cycle is 10% of the epochs, each next cycle doubles
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=max(1, int(config["epochs"] * 0.1)), T_mult=2)
else:
    if global_rank == 0:
        logger.warning(f"Received unrecognized scheduler {config['scheduler']}, set default CosineAnnealing Scheduler")
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config["epochs"])


# Accumulated synaptic operations per instrumented layer:
#   {layer_name: total_synops_over_dataset}
SNN_LAYER_SOP_STATS = {}


def _reset_snn_sop_stats():
    """Empty ``SNN_LAYER_SOP_STATS`` in place; the dict object itself is kept."""
    SNN_LAYER_SOP_STATS.clear()


def _register_snn_sop_hooks(net, logger=None):
    """Attach forward hooks that accumulate synaptic operations (SOPs) per layer.

    For every ``nn.Linear`` / ``nn.Conv1d`` / ``nn.Conv2d`` / ``nn.Conv3d``
    submodule of ``net``, each forward call adds

        dense_MACs_of_this_call * firing_rate_of_the_input

    to ``SNN_LAYER_SOP_STATS[layer_name]``. The firing rate is the fraction of
    input elements that are strictly positive.

    The dense MAC count is recomputed from each call's actual output, which
    makes it correct for:
    * inputs with extra leading dimensions through ``nn.Linear`` (e.g. ``(T, B, F)``);
    * grouped / depthwise convolutions (``groups > 1``);
    * inputs whose spatial/temporal size changes between calls.

    Args:
        net: module whose Conv/Linear submodules are instrumented.
        logger: optional logger. Each layer's dense MAC count on its first call
            is logged at debug level, and the number of registered hooks at
            info level.

    Returns:
        List of hook handles; call ``.remove()`` on each to detach the hooks.
    """
    hooks = []
    reported = set()  # layer names whose dense MAC count has already been logged

    def _fan_in(m):
        """Return the multiply-accumulates needed to produce ONE output element of ``m``."""
        if isinstance(m, nn.Linear):
            return m.in_features
        # Conv: each output element sees (in_channels / groups) * prod(kernel) inputs
        kernel = (m.kernel_size,) if isinstance(m.kernel_size, int) else m.kernel_size
        kernel_elems = 1
        for k in kernel:
            kernel_elems *= k
        return (m.in_channels // m.groups) * kernel_elems

    def make_hook(layer_name):
        def hook(m, inputs, outputs):
            x = inputs[0]
            num_elems = x.numel()
            if num_elems == 0:
                return

            out = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
            if not torch.is_tensor(out) or out.dim() == 0 or out.shape[0] == 0:
                return

            # total dense MACs of this call = (#output elements) * (MACs per output element);
            # independent of how the leading dims are laid out
            dense_ops = float(out.numel() * _fan_in(m))

            if layer_name not in reported:
                reported.add(layer_name)
                if logger is not None:
                    logger.debug(
                        f"[SNN-Energy] Layer {layer_name}: dense MAC per forward call ≈ {dense_ops/1e6:.3f} M (first call)"
                    )

            if dense_ops <= 0.0:
                return

            # every strictly positive input element is treated as a spike
            firing_rate = (x > 0).sum().item() / num_elems
            synops_this_call = dense_ops * firing_rate

            SNN_LAYER_SOP_STATS[layer_name] = SNN_LAYER_SOP_STATS.get(layer_name, 0.0) + synops_this_call

        return hook

    for name, module in net.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            hooks.append(module.register_forward_hook(make_hook(name)))

    if logger is not None:
        logger.info(f"[SNN-Energy] Registered SOP hooks on {len(hooks)} Conv/Linear layers.")

    return hooks


def _remove_snn_hooks(hooks):
    """Detach every hook whose handle appears in ``hooks``.

    Args:
        hooks: iterable of handles exposing ``.remove()``.
    """
    for handle in hooks:
        handle.remove()



# Handle to the bare (non-DDP) network, used for evaluation and checkpoint I/O.
# The main process evaluates alone while the other ranks skip evaluation, so it
# must not call the DDP wrapper's forward there: that forward can issue
# collectives (e.g. buffer broadcasts when the model has buffers) that the other
# ranks would not match, desynchronizing or hanging the next training step.
raw_model = model.module if isinstance(model, DDP) else model
is_main_process = not config['is_distributed'] or dist.get_rank() == 0
best_model_path = os.path.join(
    config['model_dir'],
    f'best_{config["model"].lower()}_{config["neuron_type"].lower()}_{config["dataset_name"].lower()}_{config["seed"]}.pt'
)

best_acc = 0.


for epoch in range(1, config['epochs'] + 1):
    model.train()
    if config['is_distributed']:
        # let the DistributedSampler produce a different ordering every epoch
        train_sampler.set_epoch(epoch)

    # NOTE: under DDP these meters only cover this rank's shard of the training set
    train_top1_meter, train_loss_meter = AverageMeter(), AverageMeter()
    # customize progress bar for train loader (main process only)
    loader = tqdm(train_loader, unit='batch', ncols=80, desc='Train: ') if global_rank == 0 else train_loader
    for inputs, targets in loader:
        inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)
        optimizer.zero_grad()

        # default data shape (B, T, input_size) -> (T, B, input_size)
        inputs = inputs.transpose(0, 1)

        outputs = model(inputs)
        acc1 = accuracy(outputs, targets, topk=(1,))[0]

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_top1_meter.update(acc1.item(), targets.numel())
        train_loss_meter.update(loss.item(), targets.numel())

    train_acc = train_top1_meter.avg
    train_loss = train_loss_meter.avg

    if is_main_process:
        # evaluate on the unwrapped module (see note on raw_model above)
        raw_model.eval()

        test_top1_meter, test_loss_meter = AverageMeter(), AverageMeter()
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)

                # default data shape (B, T, input_size) -> (T, B, input_size)
                inputs = inputs.transpose(0, 1)

                outputs = raw_model(inputs)
                acc1 = accuracy(outputs, targets, topk=(1,))[0]
                loss = criterion(outputs, targets)

                test_loss_meter.update(loss.item(), targets.numel())
                test_top1_meter.update(acc1.item(), targets.numel())

        test_acc = test_top1_meter.avg
        test_loss = test_loss_meter.avg

        logger.info(f"[Epoch {epoch:3d}/{config['epochs']}] Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%; Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")
        if test_acc > best_acc:
            ensure_dir(config['model_dir'])

            best_acc = test_acc
            logger.info(f'Best model saved with accuracy: {best_acc:.2f}%')
            torch.save(raw_model.state_dict(), best_model_path)

    scheduler.step()


if is_main_process:
    # restore the best checkpoint (if one was saved) before the summary
    if os.path.exists(best_model_path):
        state = torch.load(best_model_path, map_location='cpu')
        raw_model.load_state_dict(state)
        logger.info(f'Loaded best checkpoint from {best_model_path} for OPs measurement.')

    n_parameters = count_parameters(raw_model, trainable=True)
    params_m = n_parameters / 1e6

    # NOTE(review): the log line above mentions OPs measurement, but no OPs/SOP
    # measurement happens in this script (_register_snn_sop_hooks is never called).

    logger.info(
        f"[SUMMARY] Model={config['model']} (neuron={config['neuron_type']}, "
        f"coding={config['coding_schema']}, T={config['time_step']}) | "
        f"Dataset={config['dataset_name']} | "
        f"Params={params_m:.3f} M | "
        f"BestAcc={best_acc:.2f} %"
    )


# recycle all process
if config['is_distributed']:
    dist.destroy_process_group()
