import os
import time
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from soul.model import *
from soul.neuron import *
from soul.utils import *
from soul.model.mmwave.mmWave_Ann_model import *
import torch.distributed as dist
import yaml



try:
    from thop import profile
    HAS_THOP = True
except ImportError:
    HAS_THOP = False

def build_mmwave_ann_model(config):
    """Instantiate the mmWave ANN selected by ``config["model"]``.

    The model name is matched case-insensitively.  ``config["input_shape"]``
    and ``config["num_classes"]`` are always passed positionally to the model
    constructor; the remaining hyper-parameters are optional and fall back to
    per-family defaults:

    * ``mmmlp``: ``hidden_dim`` (default 1024)
    * ``mmlenet5``: ``hidden_dim`` (default 512)
    * ``mmresnet18`` / ``mmresnet50`` / ``mmresnet101``: no extra options
    * ``mmrnn`` / ``mmgru`` / ``mmlstm`` / ``mmbilstm`` / ``mmcnn_gru``:
      ``hidden_dim`` (default 128, forwarded as ``hidden_size``) and
      ``num_layers`` (default 1)
    * ``mmvit``: ``patch_size`` (default 16)

    Raises:
        KeyError: if ``model``, ``input_shape`` or ``num_classes`` is missing.
        ValueError: if the model name is not one of the names listed above.
    """
    model_key = config["model"].lower()
    shape = config["input_shape"]
    n_classes = config["num_classes"]

    def recurrent_kwargs():
        # Shared keyword arguments of the RNN-style families.
        return {
            "hidden_size": config.get("hidden_dim", 128),
            "num_layers": config.get("num_layers", 1),
        }

    # Every entry is a zero-argument factory so that only the requested
    # model class is looked up and constructed.
    factories = {
        # ----- MLP / LeNet -----
        "mmmlp": lambda: MmWaveAnn_MLP(
            shape, n_classes, hidden_dim=config.get("hidden_dim", 1024)),
        "mmlenet5": lambda: MmWaveAnn_LeNet5(
            shape, n_classes, hidden_dim=config.get("hidden_dim", 512)),
        # ----- ResNet -----
        "mmresnet18": lambda: MmWaveAnn_ResNet18(shape, n_classes),
        "mmresnet50": lambda: MmWaveAnn_ResNet50(shape, n_classes),
        "mmresnet101": lambda: MmWaveAnn_ResNet101(shape, n_classes),
        # ----- RNN / GRU / LSTM -----
        "mmrnn": lambda: MmWaveAnn_RNN(shape, n_classes, **recurrent_kwargs()),
        "mmgru": lambda: MmWaveAnn_GRU(shape, n_classes, **recurrent_kwargs()),
        "mmlstm": lambda: MmWaveAnn_LSTM(shape, n_classes, **recurrent_kwargs()),
        "mmbilstm": lambda: MmWaveAnn_BiLSTM(shape, n_classes, **recurrent_kwargs()),
        # ----- CNN-GRU -----
        "mmcnn_gru": lambda: MmWaveAnn_CNN_GRU(shape, n_classes, **recurrent_kwargs()),
        # ----- ViT / Transformer -----
        "mmvit": lambda: MmWaveAnn_ViT(
            shape, n_classes, patch_size=config.get("patch_size", 16)),
    }

    factory = factories.get(model_key)
    if factory is None:
        raise ValueError(f"Unknown mmWave ANN model: {config['model']}")
    return factory()



# Theoretical operation count and energy estimation
def measure_ops_and_energy(model, test_loader, device, config, logger):
    """Estimate the average operation count and energy of one inference.

    Every batch of ``test_loader`` is forwarded through the model (unwrapped
    from ``DistributedDataParallel`` if necessary) under ``torch.no_grad()``.
    Afterwards the values of ``MODULE_SOP_DICT`` are summed and divided by the
    number of samples that were actually forwarded.

    Args:
        model: network to measure; it is switched to eval mode and moved to
            ``device`` as a side effect.
        test_loader: iterable of ``(inputs, targets)`` batches.  Dimension 0
            of ``inputs`` is taken as the batch dimension and is swapped with
            dimension 1 before the forward pass.
        device: device on which the forward passes run.
        config: mapping providing ``sop`` (truthy: count SOPs and use
            ``e_ac``; falsy: count FLOPs and use ``e_mac``) and the matching
            per-operation cost.
        logger: logger receiving the progress and result messages.

    Returns:
        ``(avg_ops_per_sample, energy_uJ)`` where the energy is
        ``avg_ops_per_sample * cost_per_op / 1e6``.  ``(0.0, 0.0)`` is
        returned when the loader yields no samples.
    """
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        net = model.module
    else:
        net = model
    net.eval()
    net.to(device)

    # Drop counts left over from any earlier measurement.
    MODULE_SOP_DICT.clear()

    logger.info('Counting FLOPs/SOPs for theoretical inference cost...')

    count_sops = config['sop']
    # NOTE(review): ops_monitor presumably instruments the network so that
    # forward passes accumulate into MODULE_SOP_DICT -- confirm in soul.utils.
    ops_monitor(net, is_sop=count_sops)

    # Count what is really forwarded instead of trusting len(dataset): the two
    # differ when the loader uses drop_last or a (distributed) sampler.
    num_samples = 0
    with torch.no_grad():
        for inputs, _ in tqdm(test_loader, unit='batch', ncols=80, desc='Count OPs: '):
            num_samples += inputs.shape[0]
            inputs = inputs.transpose(0, 1).to(device)
            net(inputs)

    if num_samples == 0:
        logger.warning('No samples were forwarded; reporting zero OPs and energy.')
        return 0.0, 0.0

    total_ops = sum(MODULE_SOP_DICT.values())
    avg_ops_per_sample = total_ops / num_samples

    cost_per_op = config['e_ac'] if count_sops else config['e_mac']
    # Presumably the per-op cost is given in pJ, which makes this uJ -- TODO confirm.
    energy_uJ = avg_ops_per_sample * cost_per_op / 1e6

    logger.info(
        f"Average number of {'SOPs' if count_sops else 'FLOPs'} "
        f"per sample: {avg_ops_per_sample / 1e6:.2f} M"
    )
    logger.info(
        f"Theoretical energy per sample: {energy_uJ:.4f} uJ"
    )

    return avg_ops_per_sample, energy_uJ


class BinaryActivationFn(torch.autograd.Function):
    """Sign binarization with a straight-through estimator (STE) gradient.

    Forward maps every element to +1 where ``x > 0`` and to -1 where
    ``x <= 0``, so the output is strictly binary (``x.sign()`` would emit 0
    for an input of exactly 0).

    Backward lets the incoming gradient through unchanged where the forward
    input satisfies ``|x| <= 1`` and blocks it elsewhere.
    """

    @staticmethod
    def forward(ctx, x):
        # The backward mask depends on the input, so keep it around.
        ctx.save_for_backward(x)
        ones = torch.ones_like(x)
        return torch.where(x > 0, ones, -ones)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-Through Estimator (STE): gate on the forward *input*, not
        # on the size of the gradient being propagated.
        (x,) = ctx.saved_tensors
        pass_through = (x.abs() <= 1).to(grad_output.dtype)
        return grad_output * pass_through


class BinaryActivation(nn.Module):
    """Parameter-free layer that applies ``BinaryActivationFn`` to its input."""

    def forward(self, x):
        binarized = BinaryActivationFn.apply(x)
        return binarized
    
class LeNet(nn.Module):
    """LeNet-style CNN baseline for mmWave inputs.

    Three 5x5 convolution + batch-norm + ReLU stages (the first two followed
    by 2x2 max-pooling) feed a two-layer fully connected classifier.  The
    leading dimension of the input is averaged away before the convolutions.

    Args:
        config: mapping that must provide ``application`` (only ``'mmwave'``
            is accepted), ``num_classes``, ``time_step``, ``input_channels``,
            ``input_height`` and ``input_width``.

    Raises:
        ValueError: if ``application`` is not ``'mmwave'``, or if the input
            height/width is too small to survive the three unpadded
            convolutions and two poolings.
    """

    def __init__(self, config):
        super().__init__()

        self.app = config['application']
        self.num_classes = config['num_classes']
        self.T = config['time_step']  # stored only; forward() does not read it

        if self.app != 'mmwave':
            raise ValueError(self.app)

        C = config['input_channels']
        H = config['input_height']
        W = config['input_width']

        # Validate up front: a too-small input would otherwise yield a
        # non-positive (or, with two negative sizes, a bogus positive)
        # flatten dimension and fail far from the real cause.
        out_h = self._encoded_size(H)
        out_w = self._encoded_size(W)
        if out_h <= 0 or out_w <= 0:
            raise ValueError(
                f"Input of size {H}x{W} is too small for LeNet "
                f"(encoder output would be {out_h}x{out_w})"
            )

        # Layer order and indices are kept stable so existing checkpoints
        # (keys such as 'encoder.8.weight') still load.
        self.encoder = nn.Sequential(
            nn.Conv2d(C, 32, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 96, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(96),
            nn.ReLU(True),
        )

        self.fc = nn.Sequential(
            nn.Linear(96 * out_h * out_w, 512),
            nn.ReLU(),
            nn.Linear(512, self.num_classes)
        )

    @staticmethod
    def _encoded_size(size):
        """Return the spatial length the encoder produces for ``size``."""
        size = (size - 4) // 2  # 5-wide conv without padding, then 2/2 pooling
        size = (size - 4) // 2  # second conv + pooling stage
        return size - 4         # third conv, no pooling

    def forward(self, x):
        x = x.mean(0)  # (T, B, C, H, W) -> (B, C, H, W)
        x = self.encoder(x)
        x = x.flatten(1)  # (B, C, H, W) -> (B, C*H*W)
        return self.fc(x)  # (B, C*H*W) -> (B, num_classes)

# Initialise the run configuration and pick the compute device (done once).
config = init_config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = config['model'].lower()

# Directory holding the per-model YAML files.  An optional 'model_cfg_dir'
# entry in the configuration overrides the built-in location, so the script
# is not tied to one machine's directory layout.
model_cfg_dir = config.get("model_cfg_dir") or "/home/Firewall/src/ICML-SRC/soul/config/model/mmwave"

# Candidate files in priority order; the first one that exists is used.
yaml_candidates = [
    os.path.join(model_cfg_dir, f"{model_name}.yaml"),
]
if model_name == "mmcnn_gru":
    # Fallback file name accepted for the CNN-GRU model.
    yaml_candidates.append(os.path.join(model_cfg_dir, "mmcnn.yaml"))

for ypath in yaml_candidates:
    if os.path.exists(ypath):
        with open(ypath, "r", encoding="utf-8") as f:
            # An empty YAML file parses to None; treat it as "no overrides".
            model_cfg = yaml.safe_load(f) or {}

        # Values from the YAML file replace same-named entries in config.
        config.update(model_cfg)
        print(f"[INFO] Loaded model config from {ypath}: {model_cfg}")
        break
else:
    # for/else: reached only when no candidate file was found.
    print(f"[WARN] No model-specific YAML found for model={model_name}, "
          f"searched: {yaml_candidates}")


data_dir = config['data_dir']
# Last component of the data directory.  Trailing slashes are stripped first,
# otherwise 'a/b/' would yield '' and silently drop this level from log_path.
channel_info = data_dir.rstrip('/').split('/')[-1]
model_log_name = config['model'].lower()

# <log_dir>/<dataset>/<channel_info>/ann/seed<seed>/<model>
log_path = os.path.join(
    config['log_dir'],
    config['dataset_name'].lower(),
    channel_info,
    'ann',
    f"seed{config['seed']}",
    model_log_name,
)

ensure_dir(log_path)
print(log_path)
logger = setup_logger(os.path.join(log_path, f'record-{get_local_time()}.log'), default_level=config['state'])

# These datasets always run with a fixed number of time steps.
if config['dataset_name'].lower() in ['dvsgesture', 'ssc', 'shd', 'cifar10dvs']:
    config['time_step'] = 10

# report all configuration
for k, v in sorted(config.items()):
    logger.debug(f'{k} = {v}')

# reproducibility
logger.info(f'Reproducibility with random seed {config["seed"]}')
init_seed(config["seed"])
logger.info('=' * 50)

logger.info('Load data...')
train_dataset, test_dataset = load_dataset(config)

# Derive the per-sample (C, H, W) layout from the first training sample.
x0, y0 = train_dataset[0]
x0_shape = x0.shape

# len(shape) instead of x0.dim() so NumPy arrays are accepted as well.
if len(x0_shape) == 4:
    # (T, C, H, W) -- the time dimension is not needed here
    _, C, H, W = x0_shape
elif len(x0_shape) == 3:
    # (C, H, W)
    C, H, W = x0_shape
elif len(x0_shape) == 2:
    # (H, W): single-channel sample
    H, W = x0_shape
    C = 1
else:
    raise ValueError(f"Unsupported sample shape: {x0_shape}")

config["input_shape"]    = (C, H, W)
config["input_channels"] = C
config["input_height"]   = H
config["input_width"]    = W

# Fill in num_classes only when the configuration does not provide it.
if "num_classes" not in config or config["num_classes"] is None:

    if hasattr(train_dataset, "num_classes") and train_dataset.num_classes is not None:
        config["num_classes"] = train_dataset.num_classes
    else:
        # Fall back to scanning every label of both splits.
        all_labels = set()
        for ds in (train_dataset, test_dataset):
            for _, y in ds:
                all_labels.add(int(y))
        # Labels are used as class indices by the loss, so the classifier
        # needs max(label) + 1 outputs; counting distinct labels would be too
        # small whenever the label ids are not contiguous from 0.
        config["num_classes"] = max(all_labels) + 1

# =========================================================


# Options shared by both loaders; only the shuffling differs.
shared_loader_kwargs = dict(
    batch_size=config['batch_size'],
    num_workers=config['workers'],
    pin_memory=True,
)

# Training data is reshuffled every epoch, test data keeps its order.
train_loader, test_loader = (
    torch.utils.data.DataLoader(split, shuffle=reshuffle, **shared_loader_kwargs)
    for split, reshuffle in ((train_dataset, True), (test_dataset, False))
)

logger.info(f'#Training Samples: {len(train_dataset)}; #Test Samples: {len(test_dataset)}')

model_name = config["model"].lower()
TIME_MAJOR_MODELS = ["mmrnn", "mmgru", "mmlstm", "mmbilstm"]

# Names starting with "mm" go through the mmWave factory; everything else
# falls back to the local LeNet baseline.
if not model_name.startswith("mm"):
    logger.info('Load ANN model: LeNet...')
    model = LeNet(config)
else:
    logger.info(f'Load ANN model: {config["model"]} (mmWave_Ann_model)...')
    model = build_mmwave_ann_model(config)

model = model.to(device)

# Report the number of trainable parameters, in millions.
n_parameters = count_parameters(model, trainable=True)
params_in_millions = n_parameters / 1e6
logger.info(f"Number of params for model {config['model']}: {params_in_millions:.2f} M")


criterion = nn.CrossEntropyLoss()

# init optimizer (the name is matched case-insensitively)
optimizer_name = config['optimizer'].lower()
if optimizer_name == 'sgd':
    optimizer = optim.SGD(model.parameters(), lr=config['learning_rate'], momentum=config['momentum'], weight_decay=config['weight_decay'])
elif optimizer_name == 'adam':
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
elif optimizer_name == 'adamw':
    optimizer = optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
elif optimizer_name == 'rmsprop':
    optimizer = optim.RMSprop(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
else:
    logger.warning(f"Received unrecognized optimizer {config['optimizer']}, set default Adam optimizer")
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

# init scheduler (stepped once per epoch by the training loop)
scheduler_name = config['scheduler'].lower()
if scheduler_name == 'cosine':
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config["epochs"])
elif scheduler_name == 'linear':
    # Decay by 10x every quarter of the run.  Clamp to 1: for fewer than four
    # epochs the truncated step size would be 0 and StepLR would divide by it.
    step_size = max(1, int(config["epochs"] * 0.25))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
elif scheduler_name == 'warmup':
    # First restart after 10% of the run.  Clamp to 1: T_0 must be a positive
    # integer, which the truncation violates for fewer than ten epochs.
    first_cycle = max(1, int(config["epochs"] * 0.1))
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=first_cycle, T_mult=2)
else:
    logger.warning(f"Received unrecognized scheduler {config['scheduler']}, set default ConsineAnnealing Scheduler")
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config["epochs"])


best_acc = 0.
for epoch in range(1, config['epochs'] + 1):
    # ---------------- training pass ----------------
    model.train()

    train_top1_meter = AverageMeter()
    train_loss_meter = AverageMeter()
    # customize progress bar for train loader
    loader = tqdm(train_loader, unit='batch', ncols=80, desc='Train: ')
    for inputs, targets in loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        optimizer.zero_grad()

        # default data shape (B, T, input_size) -> (T, B, input_size)
        outputs = model(inputs.transpose(0, 1))
        acc1 = accuracy(outputs, targets, topk=(1,))[0]

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Weight the running averages by the number of targets in the batch.
        batch_count = targets.numel()
        train_top1_meter.update(acc1.item(), batch_count)
        train_loss_meter.update(loss.item(), batch_count)

    train_acc, train_loss = train_top1_meter.avg, train_loss_meter.avg

    # ---------------- evaluation pass ----------------
    model.eval()

    test_top1_meter = AverageMeter()
    test_loss_meter = AverageMeter()
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            # default data shape (B, T, input_size) -> (T, B, input_size)
            outputs = model(inputs.transpose(0, 1))
            acc1 = accuracy(outputs, targets, topk=(1,))[0]
            loss = criterion(outputs, targets)

            batch_count = targets.numel()
            test_loss_meter.update(loss.item(), batch_count)
            test_top1_meter.update(acc1.item(), batch_count)

    test_acc, test_loss = test_top1_meter.avg, test_loss_meter.avg

    logger.info(f"[Epoch {epoch:3d}/{config['epochs']}] Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%; Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")

    # Keep only the weights of the best-scoring epoch so far.
    if test_acc > best_acc:
        ensure_dir(config['model_dir'])

        best_acc = test_acc
        logger.info(f'Best model saved with accuracy: {best_acc:.2f}%')

        checkpoint_name = f"best_{config['model'].lower()}_ann_{config['dataset_name'].lower()}_{config['seed']}.pt"
        torch.save(model.state_dict(), os.path.join(config['model_dir'], checkpoint_name))

    # One scheduler step per epoch.
    scheduler.step()



# Only one process reports the final cost summary.
# NOTE(review): when 'is_distributed' is set but no process group has been
# initialised, rank0 stays False and nothing below runs -- confirm intended.
is_distributed = config.get('is_distributed', False)
rank0 = (not is_distributed)
if is_distributed and dist.is_available() and dist.is_initialized():
    rank0 = (dist.get_rank() == 0)

if rank0:

    # Measure the best checkpoint when it exists, else the final weights.
    best_model_path = os.path.join(
        config['model_dir'],
        f"best_{config['model'].lower()}_ann_{config['dataset_name'].lower()}_{config['seed']}.pt"
    )
    if os.path.exists(best_model_path):
        state = torch.load(best_model_path, map_location=device)
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            model.module.load_state_dict(state)
        else:
            model.load_state_dict(state)
        logger.info(f'Loaded best checkpoint from {best_model_path} for OPs measurement.')
    else:
        logger.warning(f'Best checkpoint not found at {best_model_path}, use current model for OPs measurement.')

    # ANN baseline: select the dense FLOP/MAC cost model, not SOPs.
    config['sop'] = False

    core_model = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
    n_parameters = count_parameters(core_model, trainable=True)
    params_m = n_parameters / 1e6

    flops_m = 0.0
    energy_uJ = 0.0

    if HAS_THOP:
        core_model.eval()

        # Profile on one real batch; fall back to the training loader when
        # the test loader yields nothing.
        try:
            dummy_inputs, _ = next(iter(test_loader))
        except StopIteration:
            dummy_inputs, _ = next(iter(train_loader))

        dummy_inputs = dummy_inputs.to(device, non_blocking=True)

        # Batch size is read before the transpose to (T, B, ...).
        batch_size_dummy = dummy_inputs.shape[0]
        dummy_inputs_for_flops = dummy_inputs.transpose(0, 1)

        # NOTE(review): thop documents the first return value of profile() as
        # a MAC count; it is labelled "FLOPs" here -- confirm the labelling.
        flops_total, _ = profile(core_model, inputs=(dummy_inputs_for_flops,), verbose=False)

        # profile() counts the whole batch, so normalise to one sample.
        flops_per_sample = flops_total / float(batch_size_dummy)
        flops_m = flops_per_sample / 1e6

        e_mac = config.get('e_mac', 4.6)
        energy_uJ = flops_per_sample * e_mac / 1e6

        # The value is in uJ, matching the summary line below.
        logger.info(f"[THOP] FLOPs per sample: {flops_per_sample/1e6:.2f} M, Energy: {energy_uJ:.4f} uJ")
    else:
        # thop is not installed: use the hook-based counter instead.
        avg_ops, energy_uJ = measure_ops_and_energy(model, test_loader, device, config, logger)
        flops_m = avg_ops / 1e6

    logger.info(
        f"[SUMMARY] Model={config['model']} (ANN, T={config['time_step']}) | "
        f"Dataset={config['dataset_name']} | "
        f"Params={params_m:.3f} M | "
        f"FLOPs={flops_m:.2f} M | "
        f"Energy={energy_uJ:.4f} uJ | "
        f"BestAcc={best_acc:.2f} %"
    )
# ========================================================================


