import os
import sys
import yaml
import wandb  
import random
import argparse
import numpy as np
from tqdm import tqdm
from pathlib import Path
from datetime import datetime, timedelta, timezone
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from src.loss import SinkhornLoss, WassersteinLoss
from src.model import create_mamba_model
from src.data.datagen import *
from src.data.dataset import TimeSeriesDataset
from src.data.dataload import load_and_split_data, create_data_loaders
from src.eval import test_model
from src.train_model import train_model

# Fix the seed of every RNG the pipeline touches (stdlib, NumPy, PyTorch).
# This runs at import time so that data generation and model init are reproducible.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) reads ``sys.argv[1:]``, which
            matches calling this with no arguments.

    Returns:
        argparse.Namespace: The parsed arguments. Its ``config`` attribute holds
        the path to the YAML configuration file (default ``'example/1.yaml'``).
    """
    parser = argparse.ArgumentParser(description='Training model parameters')
    parser.add_argument('--config', type=str, default='example/1.yaml', help='Path to configuration file')
    return parser.parse_args(argv)

def load_config(config_path):
    """Load a YAML training configuration and resolve the compute device.

    Device resolution:
      * ``device`` missing or null -> ``"cuda:0"`` if CUDA is available, else ``"cpu"``.
      * ``"cuda"`` (no index) -> ``"cuda:0"`` if CUDA is available, else ``"cpu"``.
      * ``"cuda:N"`` -> kept as-is if CUDA is available, else ``"cpu"``.
      * Anything else (e.g. ``"cpu"``) is left untouched.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        dict: The parsed configuration, with ``config['device']`` resolved.

    Raises:
        ValueError: If the file is empty or its top level is not a mapping.
    """
    # YAML is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    # safe_load returns None for an empty file; fail with a clear message
    # instead of a TypeError on the membership test below.
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path!r} must contain a YAML mapping, "
            f"got {type(config).__name__}"
        )

    cuda_available = torch.cuda.is_available()
    device = config.get('device')
    if device is None:
        config['device'] = "cuda:0" if cuda_available else "cpu"
    elif isinstance(device, str) and device.startswith('cuda'):
        if not cuda_available:
            # Same CPU fallback for "cuda" and "cuda:N" when no GPU is present.
            config['device'] = "cpu"
        elif ':' not in device:
            config['device'] = "cuda:0"
    return config

def setup_logging(config):
    """Tee everything written to ``sys.stdout`` into a timestamped log file.

    The log file is
    ``<save.log_dir>/<save.model_name>_<YYYYmmdd_HHMMSS>.log``. The timestamp is
    the current wall-clock time in UTC+8 (Beijing time). The log directory is
    created if missing. After the call, ``sys.stdout`` is a tee object that
    writes to both the previous ``sys.stdout`` and the log file.

    Args:
        config: Configuration dict providing ``save.log_dir`` and
            ``save.model_name``.

    Returns:
        pathlib.Path: Path of the created log file.
    """
    log_dir = Path(config["save"]["log_dir"])
    log_dir.mkdir(parents=True, exist_ok=True)
    # Build an aware UTC+8 datetime directly. Shifting a UTC datetime by +8h
    # would keep a misleading UTC tzinfo; the formatted text is identical.
    beijing_tz = timezone(timedelta(hours=8))
    timestamp = datetime.now(beijing_tz).strftime("%Y%m%d_%H%M%S")
    log_file = log_dir / f"{config['save']['model_name']}_{timestamp}.log"

    class Logger:
        """Minimal stdout tee: mirrors every write to the terminal and a log file."""

        def __init__(self, filename):
            self.terminal = sys.stdout
            self.log = open(filename, 'w', encoding='utf-8')

        def write(self, message):
            self.terminal.write(message)
            self.log.write(message)
            # Flush eagerly so the log survives crashes and killed jobs.
            self.log.flush()

        def flush(self):
            self.terminal.flush()
            self.log.flush()

        def __getattr__(self, name):
            # Only called for attributes not defined above (isatty, fileno,
            # encoding, errors, ...). Forward them to the real stream so that
            # libraries probing sys.stdout keep working. Read `terminal` via
            # __dict__ to avoid infinite recursion on half-built instances.
            terminal = self.__dict__.get('terminal')
            if terminal is None:
                raise AttributeError(name)
            return getattr(terminal, name)

    sys.stdout = Logger(log_file)
    return log_file

def get_save_path(config):
    """Create and return a fresh run directory for checkpoints and artifacts.

    The directory is
    ``<save.save_dir>/<save.model_name>_<data.type>_<YYYYmmdd_HHMMSS>``, stamped
    with the current UTC+8 (Beijing) wall-clock time. Missing parents are
    created, and an already-existing directory is reused.

    Args:
        config: Configuration dict providing ``save.save_dir``,
            ``save.model_name`` and ``data.type``.

    Returns:
        pathlib.Path: The run directory.
    """
    utc_plus_8 = timezone(timedelta(hours=8))
    timestamp = datetime.now(utc_plus_8).strftime("%Y%m%d_%H%M%S")
    root = Path(config["save"]["save_dir"])
    run_dir = root.joinpath(f"{config['save']['model_name']}_{config['data']['type']}_{timestamp}")
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir

def save_generated_data(data, timestamps, save_dir):
    """Write a dataset to *save_dir* as ``data.npy`` and ``timestamps.npy``.

    *save_dir* and any missing parents are created first. Existing files with
    those names are overwritten.

    Args:
        data: Array-like series data passed to ``np.save``.
        timestamps: Array-like timestamps passed to ``np.save``.
        save_dir: Target directory path.
    """
    os.makedirs(save_dir, exist_ok=True)
    outputs = (('data.npy', data), ('timestamps.npy', timestamps))
    for filename, array in outputs:
        np.save(os.path.join(save_dir, filename), array)

def generate_or_load_data(config):
    """Load a previously saved dataset, or generate a new one and save it.

    If ``config['data']['data_path']`` is set (not null), ``data.npy`` and
    ``timestamps.npy`` are loaded from that directory and returned.

    Otherwise the generator selected by ``config['data']['type']`` is run with
    sequences of length ``input_len + output_len``. The result is saved under
    ``<data.save_dir>/<save.model_name>_<data.type>_<YYYYmmdd_HHMMSS>``, using a
    UTC+8 timestamp. To reuse that dataset in a later run, point
    ``data.data_path`` at the directory.

    Args:
        config: Full training configuration dict.

    Returns:
        tuple: ``(data, timestamps)``, either loaded with ``np.load`` or as
        returned by the generator function.

    Raises:
        ValueError: If ``config['data']['type']`` names an unsupported generator.
    """
    data_cfg = config["data"]

    data_dir = data_cfg.get("data_path")
    if data_dir is not None:
        data = np.load(os.path.join(data_dir, 'data.npy'))
        timestamps = np.load(os.path.join(data_dir, 'timestamps.npy'))
        return data, timestamps

    model_name = config["save"]["model_name"]
    data_type = data_cfg["type"]
    beijing_tz = timezone(timedelta(hours=8))
    timestamp = datetime.now(beijing_tz).strftime("%Y%m%d_%H%M%S")
    save_dir = os.path.join(data_cfg["save_dir"], f"{model_name}_{data_type}_{timestamp}")
    # No cache lookup here: save_dir is freshly timestamped, so it cannot already
    # hold data. Reuse an earlier dataset through `data.data_path` instead.

    total_len = config["input_len"] + config["output_len"]

    if data_type == "generate_hybrid_gbm_data":
        data, timestamps = generate_hybrid_gbm_data(
            num_samples=data_cfg["num_samples"],
            seq_len=total_len,
            input_len=config["input_len"],
            output_len=config["output_len"],
            mu_gbm1=data_cfg["mu1"],
            sigma_gbm1=data_cfg["sigma1"],
            mu_gbm2=data_cfg["mu2"],
            sigma_gbm2=data_cfg["sigma2"],
            x0_base=data_cfg["x0_base"],
            x0_perturb=data_cfg["x0_perturb"],
            switch_probability=data_cfg["switch_probability"]
        )
    elif data_type == "generate_hybrid_gbm_ou_data":
        data, timestamps = generate_hybrid_gbm_ou_data(
            num_samples=data_cfg["num_samples"],
            seq_len=total_len
        )
    # TODO: add the remaining data_type branches here following the same pattern.
    else:
        raise ValueError(f"Unsupported data type: {data_type}")

    save_generated_data(data, timestamps, save_dir)
    return data, timestamps

def main():
    """Run the full pipeline: config, logging, data, model, training, testing.

    Steps:
        1. Read ``--config`` and load the YAML configuration.
        2. Tee stdout into a timestamped log file.
        3. Create the run directory and store it as ``config['model_save_path']``.
        4. Load or generate the dataset.
        5. Build the Mamba model. Its ``args.d_inner`` is passed on to the
           data-loader factory.
        6. Split the data, build the loaders, train, then evaluate on the test split.
    """
    cli = parse_args()
    cfg = load_config(cli.config)
    setup_logging(cfg)
    run_dir = get_save_path(cfg)
    cfg["model_save_path"] = str(run_dir)
    series, series_ts = generate_or_load_data(cfg)

    model = create_mamba_model(
        input_dim=cfg["input_dim"],
        output_dim=cfg["output_dim"],
        input_len=cfg["input_len"],
        output_len=cfg["output_len"],
        d_model=cfg["d_model"],
        n_layer=cfg["n_layer"],
    )
    inner_width = model.args.d_inner

    (train_x, val_x, test_x,
     train_t, val_t, test_t) = load_and_split_data(
        series, series_ts, cfg["train_ratio"], cfg["val_ratio"]
    )
    train_loader, val_loader, test_loader = create_data_loaders(
        train_x, val_x, test_x,
        train_t, val_t, test_t,
        cfg["input_len"], cfg["output_len"], inner_width, cfg["batch_size"],
    )

    train_model(model, train_loader, val_loader, cfg)
    test_model(model, test_loader, cfg)


if __name__ == "__main__":
    main()
