# ============================ Imports ============================
import argparse
import copy
import glob
import json
import math
import os
import pickle
import random
import sys
import time
import warnings
warnings.filterwarnings("ignore")
from sklearn.metrics import f1_score
import ast
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

import torchvision.models as models
from torchvision import transforms

from scipy.signal import stft, hilbert
from torch.autograd import Function

from thop import profile
from math import ceil
# ============================ Local Imports ============================
sys.path.append('./data/')
sys.path.append('./utils/')
sys.path.append('./models/')

from utils.helper_function import set_seed, count_model_parameters, AverageMeter, ProgressMeter, sax_tokenizer
# from models.expert_models import InformerEncoderLayer, ProbSparseAttention
from models.our_models import *

from utils.dataset_cfg import WESAD

# ============================ Argument Parser ============================
def string_to_list(arg):
    """Parse a command-line value into a list.

    Accepts either a Python literal (``"['a', 'b']"`` or ``"('a', 'b')"``)
    or a plain comma-separated string (``"a, b"``).  Items of the
    comma-separated form are stripped of surrounding whitespace and empty
    items are dropped.  A literal that is not a list/tuple (e.g. ``"'a'"``)
    is wrapped in a one-element list so callers always receive a list.
    """
    try:
        value = ast.literal_eval(arg)
    except (ValueError, SyntaxError, TypeError):
        # Not a valid literal (e.g. bare names like wrist_ACC,wrist_BVP):
        # fall back to splitting on commas.
        return [item.strip() for item in arg.split(',') if item.strip()]
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]

parser = argparse.ArgumentParser(description='HeteroIrregTS')

# NOTE: the dataset config's modality list currently takes precedence over
# --modalities (see the setup section below).
parser.add_argument('--modalities', type=string_to_list,
                    default=['wrist_ACC', 'wrist_BVP', 'wrist_EDA', 'wrist_TEMP', 'chest_ACC'],
                    help='List of modalities')
parser.add_argument('--log_comment', default='ours_wesad_all_modalities', type=str,
                    help='Run name: TensorBoard comment, checkpoint sub-directory and results file name')
parser.add_argument('--chkpt_pth', default='WESAD_1_May/', type=str,
                    help='Checkpoint sub-directory under ./saved_chk_dir/')
parser.add_argument('--results_dir', default='WESAD_1_May/', type=str,
                    help='Results sub-directory under ./results_dir/')
parser.add_argument('--num_epochs', default=100, type=int,
                    help='Number of training epochs')
parser.add_argument('--batch_size', default=64, type=int,
                    help='Mini-batch size for all data loaders')
parser.add_argument('--cuda_pick', default='cuda:5', type=str,
                    help='Torch device used when CUDA is available')
parser.add_argument('--seed_num', default=2711, type=int,
                    help='Random seed passed to set_seed')
parser.add_argument('--transform', default='None', type=str,
                    help="'sax' tokenises inputs with SAX; any other value keeps raw signals")
# Integer defaults are given as ints (they were strings that argparse had to convert).
parser.add_argument('--num_experts', default=4, type=int,
                    help='Number of experts in each MoE fusion layer')
parser.add_argument('--base_factor', default=5, type=int,
                    help='Base Informer factor; scales the gated per-modality factor')

args = parser.parse_args()

# ============================ Config Extraction ============================
# Training schedule and data-loading settings.
num_epochs = args.num_epochs
batch_size = args.batch_size
seed_num = args.seed_num
cuda_pick = args.cuda_pick

# Run identification and output locations, rooted under fixed local directories.
log_comment = args.log_comment
chkpt_pth = f'./saved_chk_dir/{args.chkpt_pth}'
results_dir = f'./results_dir/{args.results_dir}'

# Model / modality settings.
modalities = args.modalities
num_experts = args.num_experts
base_factor = args.base_factor


# ============================ Setup ============================
# Seed first so model initialisation and data shuffling below are reproducible.
set_seed(seed_num)
device = torch.device(cuda_pick if torch.cuda.is_available() else "cpu")
print(device)

root_dir = './data/'
dataset_cfg = WESAD()
# CrossAttnTransformerClf splits its input by dataset_cfg.modalities / variates,
# so the dataset must load exactly those modalities and --modalities cannot be
# honoured.  Report the override instead of discarding the flag silently
# (print rather than warnings.warn: all warnings are filtered at the top of this file).
if modalities != list(dataset_cfg.modalities):
    print('Note: --modalities', modalities, 'is ignored; using the dataset config modalities.')
modalities = dataset_cfg.modalities
print('Modalities:', modalities)
os.makedirs(chkpt_pth, exist_ok=True)
os.makedirs(results_dir, exist_ok=True)

# TensorBoard writer; log_comment is used as the run-directory suffix.
writer = SummaryWriter(comment=log_comment)

# ============================ Dataset ============================
class DaliaDataset(Dataset):
    """Windowed multimodal samples stored as one ``.pt`` file per window.

    Every ``*.pt`` file under ``root_dir/<subject>/`` for the requested
    subjects is one sample.  Each file is loaded with ``torch.load`` and is
    indexed by modality name and by ``'label'`` (assumed to be a dict of
    ``[time, features]`` tensors plus a label -- confirm against the
    preprocessing script).

    Each modality is resampled to ``cfg.base_sample_rate`` and cut or padded
    to exactly ``cfg.duration * cfg.base_sample_rate`` rows, then the
    modalities are concatenated along the feature axis in ``modalities``
    order.  With ``transform='sax'`` every channel is then tokenised with
    ``sax_tokenizer``.  Despite the name, nothing here is DaLiA-specific; all
    dataset details come from ``cfg``.

    Args:
        root_dir: directory containing one sub-directory per subject.
        modalities: ordered modality names to load and concatenate.
        subjects: subject sub-directory names; missing ones are skipped.
        cfg: object exposing ``base_sample_rate``, ``duration`` and
            ``sampling_rates`` (modality -> Hz).
        transform: ``'sax'`` to apply SAX tokenisation, anything else for raw.
        sax_params: ``alphabet_size`` / ``word_length`` for ``sax_tokenizer``.
    """

    def __init__(self, root_dir, modalities, subjects, cfg, transform='sax', sax_params=None):
        self.root_dir = root_dir
        self.transform = transform
        self.subjects = subjects
        self.modalities = modalities
        self.base_sr = cfg.base_sample_rate
        self.duration = cfg.duration
        self.sampling_rates = cfg.sampling_rates
        self.sax_params = sax_params or {'alphabet_size': 20, 'word_length': 2}

        # glob order is filesystem-dependent; sort each subject's files so the
        # index -> file mapping (and therefore seeded shuffling) is reproducible.
        self.file_paths = []
        for subject in subjects:
            subject_dir = os.path.join(root_dir, subject)
            if os.path.exists(subject_dir):
                self.file_paths.extend(sorted(glob.glob(os.path.join(subject_dir, "*.pt"))))

    def __len__(self):
        return len(self.file_paths)

    def _resample(self, mod_data, orig_sr):
        """Resample ``mod_data`` (rows = time) from ``orig_sr`` to the base rate.

        The result always has exactly ``int(duration * base_sr)`` rows: longer
        outputs are truncated and shorter ones are padded by repeating the
        last row.  Upsampling uses linear interpolation and requires a 2-D
        tensor (raises ``ValueError`` otherwise).
        """
        expected_length = int(self.duration * self.base_sr)
        resample_factor = self.base_sr / orig_sr

        if resample_factor < 1:
            # Output row i takes source row floor(i * orig_sr / base_sr).  For
            # integer ratios this is exactly mod_data[::ratio]; for fractional
            # ratios (e.g. 700 Hz -> 32 Hz) it stays on the correct time grid
            # instead of flooring the stride, which stretched the window and
            # then truncated its tail.
            ratio = orig_sr / self.base_sr
            time_dim = mod_data.shape[0]
            target_len = math.ceil(time_dim / ratio)
            src_idx = (torch.arange(target_len, dtype=torch.float64) * ratio).floor().long()
            resampled = mod_data[src_idx.clamp(max=time_dim - 1)]
        else:
            if mod_data.dim() != 2:
                raise ValueError(f"Unexpected tensor shape: {mod_data.shape}")
            time_dim = mod_data.shape[0]
            target_len = int(time_dim * resample_factor)
            # F.interpolate works on [batch, channels, length].
            resampled = F.interpolate(
                mod_data.permute(1, 0).unsqueeze(0),
                size=target_len, mode='linear', align_corners=False,
            ).squeeze(0).permute(1, 0)

        current_length = resampled.shape[0]
        if current_length > expected_length:
            resampled = resampled[:expected_length]
        elif current_length < expected_length:
            last_frame = resampled[-1:].repeat(expected_length - current_length, 1)
            resampled = torch.cat([resampled, last_frame], dim=0)

        assert resampled.shape[0] == expected_length
        return resampled

    def __getitem__(self, idx):
        """Return ``(x, y)`` for sample ``idx``.

        ``x`` is the resampled, feature-concatenated signal (or the stacked
        SAX tokens when ``transform == 'sax'``); ``y`` is ``data['label']``.
        Raises ``KeyError`` if a requested modality is absent from the file.
        """
        path = self.file_paths[idx]
        # Load onto the CPU: keeps DataLoader workers off the GPU and is
        # required by the .numpy() call in the SAX branch.
        data = torch.load(path, map_location='cpu')

        chunks = []
        for modality in self.modalities:
            if modality not in data:
                # Skipping it would shift every later feature column and break
                # batching / the model's fixed per-modality slicing.
                raise KeyError(f"Modality '{modality}' missing from {path}")
            chunks.append(self._resample(data[modality], self.sampling_rates[modality]))
        x = chunks[0] if len(chunks) == 1 else torch.cat(chunks, dim=1)

        if self.transform == 'sax':
            sax_features = []
            for ch in range(x.shape[1]):
                sax_seq = sax_tokenizer(
                    x[:, ch].numpy(),
                    alphabet_size=self.sax_params['alphabet_size'],
                    word_length=self.sax_params['word_length']
                )
                sax_features.append(torch.tensor(sax_seq, dtype=torch.long))
            x = torch.stack(sax_features, dim=1)  # [num_words, channels]

        y = data['label']
        return x, y
# Per-modality sequence length, passed to the model as input_length.
# NOTE(review): the SAX halving presumably matches the default word_length=2
# used by DaliaDataset -- confirm against sax_tokenizer.
if args.transform == 'sax':
    input_length = (dataset_cfg.duration * dataset_cfg.base_sample_rate) // 2
else:
    input_length = dataset_cfg.duration * dataset_cfg.base_sample_rate
print('Input length:', input_length)
print('Transform:', args.transform)

train_dataset = DaliaDataset(root_dir, modalities, dataset_cfg.train_set, dataset_cfg, args.transform)
val_dataset   = DaliaDataset(root_dir, modalities, dataset_cfg.val_set, dataset_cfg, args.transform)
eval_dataset  = DaliaDataset(root_dir, modalities, dataset_cfg.eval_set, dataset_cfg, args.transform)

# Training may drop the ragged final batch.  Evaluation must see every sample
# in a fixed order: with drop_last=True up to batch_size - 1 samples were
# silently excluded from val/eval metrics (and a split smaller than one batch
# produced no batches at all).
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)
val_dataloader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, drop_last=False)
eval_dataloader  = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, num_workers=4, drop_last=False)

# ============================ Model Setup ============================
# Total number of input feature columns across all modalities.
input_dim = sum(dataset_cfg.variates[modality] for modality in modalities)


class CrossAttnTransformerClf(nn.Module):
    """Multimodal classifier: per-modality Informer stacks fused by MoE Informer layers.

    The input is split column-wise into one chunk per modality (in
    ``cfg.modalities`` order, ``cfg.variates[m]`` columns each).  Each chunk is
    linearly projected to ``d_model``, given a temporal positional encoding,
    run through its own Informer layers and then a modality-aware positional
    encoding.  All modalities are concatenated along dim 1, passed through
    ``InformerEncoderLayerWithMoE`` layers (``k=1``), mean-pooled over dim 1
    and classified by a two-layer MLP.

    Args:
        cfg: config exposing ``modalities`` (ordered) and ``variates``
            (modality -> number of input columns).
        num_classes: number of output logits.
        input_length: ``max_len`` for both positional encoders.
        d_model: hidden size.
        nhead: attention heads per Informer layer.
        num_layers_per_modal: Informer layers in each per-modality stack.
        num_layers: fusion layers with MoE.
        dropout: dropout for Informer layers and the classifier.
        verbose: stored on the module; not read by this class.
        base_factor: multiplies the gated per-modality factor and is passed
            as the factor of the fusion layers.
        num_experts: experts per MoE fusion layer.
    """

    def __init__(self, cfg, num_classes, input_length=256, d_model=64, nhead=8, num_layers_per_modal=2, num_layers=2, dropout=0.1, verbose=True, base_factor=5, num_experts=4):
        super().__init__()
        self.modalities = cfg.modalities
        self.variates = cfg.variates
        self.num_modalities = len(self.modalities)
        self.input_length = input_length
        self.verbose = verbose
        self.d_model = d_model
        self.base_factor = base_factor
        self.num_experts = num_experts

        # One input projection per modality (its column count -> d_model).
        self.input_projections = nn.ModuleDict({
            modality: nn.Linear(self.variates[modality], d_model)
            for modality in self.modalities
        })

        # Modality-aware positional encoder, applied after the per-modality stacks.
        self.pos_encoder = ModalityPositionalEncoder(
            d_model=d_model,
            max_len=input_length,
            num_modalities=self.num_modalities
        )

        # Temporal positional encoder, applied before the per-modality stacks.
        self.temporal_pos_encoder = TemporalPositionalEncoder(
            d_model=d_model,
            max_len=input_length
        )

        # Per-modality Informer layers.
        self.per_modal_informers = nn.ModuleDict({
            modality: nn.ModuleList([
                InformerEncoderLayer(
                    d_model=d_model,
                    n_heads=nhead,
                    d_ff=d_model * 4,
                    dropout=dropout,
                    factor=5,
                ) for _ in range(num_layers_per_modal)
            ]) for modality in self.modalities
        })

        # Fusion Informer layers with a sparse (top-1) mixture of experts.
        self.informer_encoder = nn.ModuleList([
            InformerEncoderLayerWithMoE(
                d_model=d_model,
                n_heads=nhead,
                d_ff=d_model * 4,
                dropout=dropout,
                factor=5,
                num_experts=num_experts,
                k=1
            ) for _ in range(num_layers)
        ])

        # Final classifier on the pooled representation.
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, num_classes)
        )

        # Maps the modality mask to per-modality gates in (0, 1).
        self.factor_gate = nn.Sequential(
            nn.Linear(self.num_modalities, self.num_modalities),
            nn.ReLU(),
            nn.Linear(self.num_modalities, self.num_modalities),
            nn.Sigmoid()
        )

    def forward(self, x, modality_dropout_prob=0.2, training=None):
        """Classify a batch.

        Args:
            x: ``[B, T, total_features]`` with modality features concatenated
                in ``self.modalities`` order.
            modality_dropout_prob: probability forwarded to ``modality_dropout``.
            training: whether ``modality_dropout`` runs in training mode.
                ``None`` (default) follows the module's mode (``self.training``),
                so ``model.eval()`` no longer applies modality dropout at
                evaluation; pass ``True``/``False`` to override.

        Returns:
            ``(logits, dynamic_factor)``: logits of shape ``[B, num_classes]``
            and the gated per-modality factors.
        """
        if training is None:
            training = self.training

        x, modality_mask = modality_dropout(
            x, self.modalities, self.variates,
            dropout_prob=modality_dropout_prob, training=training,
        )
        # NOTE(review): modality_mask is fed to Linear(num_modalities, ...) and
        # indexed per modality below, so it appears to be one [num_modalities]
        # mask shared by the batch -- confirm in modality_dropout.
        dynamic_factor = self.factor_gate(modality_mask) * self.base_factor

        projected_modalities = []
        start_idx = 0
        for idx, modality in enumerate(self.modalities):
            num_vars = self.variates[modality]
            x_m = x[:, :, start_idx:start_idx + num_vars]
            start_idx += num_vars

            x_m = self.input_projections[modality](x_m)
            # Integer factor for this modality: ceil of the gated value when
            # the mask entry is 1, ceil(1e-3) == 1 when it is 0.
            # NOTE(review): ceil() yields a plain number, so presumably no
            # gradient reaches factor_gate through this path -- confirm.
            factor = ceil(dynamic_factor[idx] * modality_mask[idx] + 1e-3 * (1 - modality_mask[idx]))

            # Temporal positional encoding before the per-modality layers.
            x_m = self.temporal_pos_encoder(x_m)
            for layer in self.per_modal_informers[modality]:
                x_m = layer(x_m, factor=factor)

            # Modality-aware positional encoding, keyed by this modality's index.
            x_m = self.pos_encoder(x_m, modality_id=idx)
            projected_modalities.append(x_m)

        # Concatenate all modalities along dim 1 (tokens).
        x_cat = torch.cat(projected_modalities, dim=1)

        # Fusion layers with MoE, using the fixed base factor.
        for layer in self.informer_encoder:
            x_cat = layer(x_cat, self.base_factor)

        # Global average pooling over tokens.
        x_pooled = torch.mean(x_cat, dim=1)

        return self.classifier(x_pooled), dynamic_factor


# model = CrossAttnTransformerClf(input_dim, dataset_cfg.num_classes, input_length=input_length*5, d_model=64, nhead=8, num_layers=2, dropout=0.1)

model = CrossAttnTransformerClf(
    cfg=dataset_cfg,
    num_classes=dataset_cfg.num_classes,
    input_length=input_length,
    d_model=64,
    nhead=8,
    num_layers=2,
    dropout=0.1,
    verbose=True,
    base_factor=base_factor,
    num_experts=num_experts,
)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
class_loss_criterion = nn.CrossEntropyLoss()
# Weight of the L1 sparsity penalty on the gated modality factors (light
# sparsity; increase for more aggressive sparsity).  The penalty term is
# currently disabled in the training step, so this value has no effect.
lambda_sparse = 1e-8
# Count only parameters the optimiser will actually update.
print('Number of trainable parameters:', sum(p.numel() for p in model.parameters() if p.requires_grad))


# ============================ Train / Eval Functions ============================

warmup_epochs = 10


def get_modality_dropout_prob(epoch, warmup_epochs=5, max_dropout=0.5, total_epochs=None):
    """Return the modality-dropout probability for ``epoch``.

    The probability is 0 for the first ``warmup_epochs`` epochs.  It then
    ramps linearly, reaching ``max_dropout`` at ``total_epochs``, and stays there.

    Args:
        epoch: zero-based epoch index.
        warmup_epochs: number of epochs without modality dropout.
        max_dropout: probability at (and after) the end of the ramp.
        total_epochs: epoch at which the ramp ends; defaults to the script's
            global ``num_epochs``.

    Returns:
        A probability in ``[0, max_dropout]``.
    """
    if epoch < warmup_epochs:
        return 0.0
    if total_epochs is None:
        total_epochs = num_epochs
    ramp_len = total_epochs - warmup_epochs
    if ramp_len <= 0:
        # No room for a ramp.  This previously gave ZeroDivisionError
        # (equal values) or a negative probability (total < warmup).
        return max_dropout
    progress = min((epoch - warmup_epochs) / ramp_len, 1.0)
    return progress * max_dropout

def train_one_epoch(train_loader, model, class_loss_criterion, optimizer, epoch):
    """Train ``model`` for one epoch over ``train_loader``.

    The modality-dropout probability for this epoch comes from
    ``get_modality_dropout_prob`` (no dropout during the first
    ``warmup_epochs`` epochs, then a ramp up to 0.5).  Progress is printed
    every 50 batches and on the last batch.

    Returns:
        ``(mean class loss, mean accuracy)`` over the epoch, weighted by batch size.
    """
    loss_meter = AverageMeter('Class Loss', ':.4f')
    acc_meter  = AverageMeter('Class Acc', ':.4f')
    # Built once instead of per batch.  Assumes ProgressMeter reads the meters'
    # current values at display() time (as in the PyTorch ImageNet example).
    progress = ProgressMeter(len(train_loader), [loss_meter, acc_meter], prefix=f"Epoch: [{epoch}]")
    last_batch = len(train_loader) - 1

    model.train()
    model.zero_grad()
    modality_dropout_prob = get_modality_dropout_prob(epoch, warmup_epochs=warmup_epochs, max_dropout=0.5)

    for i, (x, y) in enumerate(train_loader):
        x, y = x.to(device).float(), y.to(device)
        # The second output (gated per-modality factors) is unused while the
        # sparsity penalty lambda_sparse * factors.abs().mean() is disabled.
        class_output, _ = model(x, modality_dropout_prob)
        loss = class_loss_criterion(class_output, y)

        n = x.size(0)
        predicted = class_output.detach().argmax(dim=1)
        loss_meter.update(loss.item(), n)
        acc_meter.update(predicted.eq(y).sum().item() / n, n)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 50 == 0 or i == last_batch:
            progress.display(i)
            if i == last_batch:
                print('End of Epoch', epoch, 'Class loss is', '%.4f' % loss_meter.avg, '    Training accuracy is ', '%.4f' % acc_meter.avg)
    return loss_meter.avg, acc_meter.avg

def evaluate_one_epoch(val_loader, model, class_loss_criterion, epoch, return_preds=False):
    """Evaluate ``model`` on ``val_loader`` and collect predictions.

    Runs under ``torch.no_grad()`` with the model in eval mode.  Modality
    dropout is disabled explicitly (``training=False`` and probability 0.0)
    rather than relying on ``forward``'s defaults, so every modality is
    present during evaluation.

    Args:
        val_loader: iterable of ``(x, y)`` batches.
        model: classifier whose call returns ``(logits, aux)``.
        class_loss_criterion: loss applied to ``(logits, y)``.
        epoch: epoch index, used only in the printed summary.
        return_preds: unused; labels and predictions are always returned.

    Returns:
        ``(mean loss, mean accuracy, y_true, y_pred)``.  The last two are numpy
        arrays over all evaluated samples, and are empty if the loader yields
        no batches.
    """
    loss_meter = AverageMeter('Class Loss', ':.4f')
    acc_meter  = AverageMeter('Class Acc', ':.4f')
    model.eval()

    all_preds = []
    all_labels = []
    last_batch = len(val_loader) - 1

    with torch.no_grad():
        for i, (x, y) in enumerate(val_loader):
            x, y = x.to(device).float(), y.to(device)
            class_output, _ = model(x, modality_dropout_prob=0.0, training=False)
            loss = class_loss_criterion(class_output, y)

            predicted = class_output.argmax(dim=1)
            n = x.size(0)
            loss_meter.update(loss.item(), n)
            acc_meter.update(predicted.eq(y).sum().item() / n, n)

            all_preds.append(predicted.cpu())
            all_labels.append(y.cpu())

            if i == last_batch:
                print('End of Epoch', epoch, 
                      '| Validation Class loss:', f'{loss_meter.avg:.4f}', 
                      '| Validation accuracy:', f'{acc_meter.avg:.4f}')

    if not all_preds:
        # torch.cat([]) would raise; report an empty evaluation instead.
        empty = np.array([], dtype=np.int64)
        return loss_meter.avg, acc_meter.avg, empty, empty.copy()

    y_true = torch.cat(all_labels).numpy()
    y_pred = torch.cat(all_preds).numpy()
    return loss_meter.avg, acc_meter.avg, y_true, y_pred

#  ============================ Training Loop ============================
# Best-so-far trackers: one set selected on validation accuracy, the other on eval-split F1.
best_val_acc, best_val_f1, best_epoch_val = 0, 0, -1
best_test_acc, best_eval_f1, best_epoch_test = 0, 0, -1
best_model_val_state = None
best_model_eval_state = None

# Preallocated per-epoch accuracy buffers (one slot per epoch).
train_acc_list = np.zeros(num_epochs)
val_acc_list = np.zeros(num_epochs)
test_acc_list = np.zeros(num_epochs)

save_dir = os.path.join(chkpt_pth, log_comment)

for epoch in range(num_epochs):
    print('Inside Epoch : ', epoch)

    train_loss, train_acc = train_one_epoch(train_dataloader, model, class_loss_criterion, optimizer, epoch)

    val_loss, val_acc, val_y_true, val_y_pred = evaluate_one_epoch(val_dataloader, model, class_loss_criterion, epoch)
    eval_loss, eval_acc, eval_y_true, eval_y_pred = evaluate_one_epoch(eval_dataloader, model, class_loss_criterion, epoch)

    # Macro-averaged F1 over classes.
    val_f1 = f1_score(val_y_true, val_y_pred, average='macro')
    eval_f1 = f1_score(eval_y_true, eval_y_pred, average='macro')

    train_acc_list[epoch] = train_acc
    val_acc_list[epoch] = val_acc
    test_acc_list[epoch] = eval_acc

    writer.add_scalar("Class Loss/train", train_loss, epoch)
    writer.add_scalar("Accuracy/train", train_acc, epoch)
    writer.add_scalar("Class Loss/val", val_loss, epoch)
    writer.add_scalar("Accuracy/val", val_acc, epoch)
    writer.add_scalar("Class Loss/eval", eval_loss, epoch)
    writer.add_scalar("Accuracy/eval", eval_acc, epoch)
    writer.add_scalar("F1/val", val_f1, epoch)
    writer.add_scalar("F1/eval", eval_f1, epoch)
    writer.flush()

    # Best validation model by accuracy.  state_dict() returns references to
    # the live parameters, so snapshot it; otherwise the "best" state would
    # silently follow the latest weights as training continues.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_val_f1 = val_f1
        best_epoch_val = epoch
        best_model_val_state = copy.deepcopy(model.state_dict())

        os.makedirs(save_dir, exist_ok=True)
        torch.save(best_model_val_state, os.path.join(save_dir, "best_val_model.pth"))

    # Best eval model by F1.
    # NOTE(review): this selects on the eval (held-out) split, so
    # best_eval_f1 / best_test_acc are an optimistic upper bound rather than
    # an unbiased test estimate.
    if eval_f1 > best_eval_f1:
        best_eval_f1 = eval_f1
        best_test_acc = eval_acc
        best_epoch_test = epoch
        best_model_eval_state = copy.deepcopy(model.state_dict())

        os.makedirs(save_dir, exist_ok=True)
        torch.save(best_model_eval_state, os.path.join(save_dir, "best_eval_model.pth"))

# ============================ Save Results ============================
# Summary of the two model-selection criteria, written as pretty-printed JSON.
model_stats = dict(
    best_val_acc=best_val_acc,
    best_val_f1=best_val_f1,
    best_epoch_val=best_epoch_val,
    best_test_acc=best_test_acc,
    best_eval_f1=best_eval_f1,
    best_epoch_test=best_epoch_test,
)

filename = os.path.join(results_dir, f"{log_comment}.json")
with open(filename, 'w') as f:
    json.dump(model_stats, f, indent=4)

print(f"Best validation accuracy: {best_val_acc:.4f} | F1: {best_val_f1:.4f} at epoch {best_epoch_val}")
print(f"Best eval F1 score: {best_eval_f1:.4f} | Accuracy: {best_test_acc:.4f} at epoch {best_epoch_test}")

# Flush and release the TensorBoard event files.
writer.close()