import numpy as np
import time
import json
import argparse
import sys
import os
from datetime import datetime
from utils import *
from modules import *

def train_func(dataset_name, lr=1e-4, lambda_mi=0.01, lambda_fair=0.01, lambda_cls=1.0, lambda_fpr=0.01, num_epochs=1000, experiment_dir="res"):
    """Jointly train the diffusion model, MI estimator and classifier, then evaluate.

    The data returned by ``load_data`` is split 80/20 into train/validation
    subsets. Only the training subset is used for optimisation. The validation
    subset is passed to ``evaluate_representations`` once training finishes.

    Args:
        dataset_name: Name forwarded to ``load_data`` and ``evaluate_representations``.
        lr: Learning rate shared by the three Adam optimisers.
        lambda_mi: Weight of the mutual-information loss term.
        lambda_fair: Weight of the intersectional fairness loss term.
        lambda_cls: Weight of the classification cross-entropy term.
        lambda_fpr: Weight of the false-positive-rate regulariser term.
        num_epochs: Number of passes over the training subset.
        experiment_dir: Directory whose ``models/`` and ``results/``
            sub-directories receive the saved state dicts and
            ``evaluation.json``. Both sub-directories must already exist.

    Returns:
        Tuple ``(diffusion_model, classifier, train_dataset)``.
    """
    X, y, s = load_data(dataset_name)

    num_classes = len(torch.unique(y))
    print(f"Number of classes in dataset: {num_classes}")

    full_dataset = TensorDataset(X, y, s)
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])

    # Train only on the training subset. Iterating ``full_dataset`` here would
    # leak the validation samples into training and inflate the final metrics.
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)

    input_dim = X.shape[1]
    hidden_dim = 64
    time_emb_dim = 16
    # The latent fed to the MI estimator and the classifier is the output of
    # ``diffusion_model.fc1``. It is assumed here to be ``hidden_dim`` wide.
    latent_dim = hidden_dim
    sensitive_dim = s.shape[1]

    diffusion_model = DiffusionModel(input_dim, hidden_dim, time_emb_dim).to(device)
    mutual_info_estimator = MutualInfoEstimator(latent_dim, sensitive_dim, hidden_dim).to(device)
    # NOTE(review): the classifier is built with ``input_dim`` but is applied to
    # ``latent`` (``latent_dim`` wide) below. Confirm this against ``Classifier``'s
    # constructor. A width mismatch would surface as a RuntimeError on every batch.
    classifier = Classifier(input_dim, hidden_dim, num_classes).to(device)

    diff_optimizer = optim.Adam(diffusion_model.parameters(), lr=lr, weight_decay=1e-5)
    mi_optimizer = optim.Adam(mutual_info_estimator.parameters(), lr=lr, weight_decay=1e-5)
    cls_optimizer = optim.Adam(classifier.parameters(), lr=lr, weight_decay=1e-5)

    modules = (diffusion_model, mutual_info_estimator, classifier)
    optimizers = (diff_optimizer, mi_optimizer, cls_optimizer)

    num_timesteps = 1000
    # NOTE(review): the noising strength is a single fixed ``alpha`` that does
    # not depend on the sampled timestep ``t``. Confirm this is intended.
    alpha = 0.9
    sqrt_alpha = torch.sqrt(torch.tensor(alpha, device=device))
    sqrt_one_minus_alpha = torch.sqrt(torch.tensor(1 - alpha, device=device))
    max_grad_norm = 1.0

    for epoch in range(num_epochs):
        for module in modules:
            module.train()

        epoch_losses = []
        skipped_batches = 0

        for batch in train_loader:
            # Move every batch tensor onto the models' device (no-op if already there).
            x, label, s_batch = (tensor.to(device) for tensor in batch)
            batch_size = x.size(0)

            t = torch.randint(0, num_timesteps, (batch_size,), device=device)
            noise = torch.randn_like(x)
            x_noisy = sqrt_alpha * x + sqrt_one_minus_alpha * noise

            noise_pred = diffusion_model(x_noisy, t)
            diffusion_loss = F.mse_loss(noise_pred, noise)

            # Build the latent by applying the diffusion model's ``time_emb`` and
            # ``fc1`` sub-modules. This presumably mirrors the model's own forward
            # pass; confirm in modules.
            t_emb = diffusion_model.time_emb(t)
            latent_input = torch.cat([x_noisy, t_emb], dim=1)
            latent = F.relu(diffusion_model.fc1(latent_input))

            try:
                mi_loss = compute_mutual_info_loss(mutual_info_estimator, latent, s_batch)
                mi_loss = torch.clamp(mi_loss, -10, 10)

                fair_loss = intersectional_fairness_loss(latent, s_batch)
                fair_loss = torch.clamp(fair_loss, 0, 10)

                logits = classifier(latent)
                cls_loss = F.cross_entropy(logits, label)
                probs = F.softmax(logits, dim=1)
                fpr_reg_loss = fpr_regularizer(probs, label, s_batch)

                # NOTE(review): mi_optimizer also minimises this total loss, so the
                # MI estimator is trained to *decrease* lambda_mi * mi_loss. If the
                # estimator is meant to be adversarial, it needs its own objective.
                # Confirm against compute_mutual_info_loss.
                total_loss = (diffusion_loss +
                              lambda_mi * mi_loss +
                              lambda_fair * fair_loss +
                              lambda_cls * cls_loss +
                              lambda_fpr * fpr_reg_loss)

                # Never step on a NaN/Inf loss. Gradient clipping does not
                # sanitise non-finite gradients, so one bad batch would corrupt
                # every parameter.
                if not torch.isfinite(total_loss):
                    skipped_batches += 1
                    print(f"Skipping batch with non-finite loss: {total_loss.item()}")
                    continue

                for optimizer in optimizers:
                    optimizer.zero_grad()

                total_loss.backward()

                for module in modules:
                    torch.nn.utils.clip_grad_norm_(module.parameters(), max_grad_norm)

                for optimizer in optimizers:
                    optimizer.step()

                epoch_losses.append(total_loss.item())

                if epoch % 100 == 0:
                    print(f"Losses - Diff: {diffusion_loss:.4f}, MI: {mi_loss:.4f}, Fair: {fair_loss:.4f}, Cls: {cls_loss:.4f}, FPR: {fpr_reg_loss:.4f}")

            except RuntimeError as e:
                # Best-effort: skip the failing batch but keep count so that a
                # systematic failure (e.g. a shape mismatch) is visible per epoch.
                skipped_batches += 1
                print(f"Error in batch: {str(e)}")

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if skipped_batches:
            print(f"Epoch [{epoch+1}/{num_epochs}] - Skipped {skipped_batches} batch(es)")

        if epoch_losses:
            avg_loss = sum(epoch_losses) / len(epoch_losses)
            print(f"Epoch [{epoch+1}/{num_epochs}] - Avg Loss: {avg_loss:.4f}")

    torch.save(diffusion_model.state_dict(), f"{experiment_dir}/models/diffusion_model.pt")
    torch.save(classifier.state_dict(), f"{experiment_dir}/models/classifier.pt")
    print(f"Models saved to {experiment_dir}/models/")

    print("Final Evaluation:")
    accuracy, group_fprs, dp = evaluate_representations(diffusion_model, classifier, val_loader, dataset_name)
    print(f"Accuracy: {accuracy}")
    print(f"Group FPRs: {group_fprs}")
    print(f"Demographic parity ratios: {dp}")

    eval_results = {
        "accuracy": accuracy,
        "group_fprs": group_fprs,
        "demographic_parity_ratios": dp,
    }

    with open(f"{experiment_dir}/results/evaluation.json", 'w') as f:
        json.dump(eval_results, f, indent=4)

    return diffusion_model, classifier, train_dataset

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Give me dataset.")
    
    parser.add_argument(

        "--dataset", 

        type=str, 

        required=True, 

        choices=["COMPAS", "Credit", "MIMIC-IV", "MIMIC-III"],

        help="Name of the dataset."

    )

    args = parser.parse_args()

    dataset_name = args.dataset  

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    experiment_dir = f"experiments/{timestamp}_{dataset_name}"

    os.makedirs(experiment_dir, exist_ok=True)

    os.makedirs(f"{experiment_dir}/models", exist_ok=True)

    os.makedirs(f"{experiment_dir}/logs", exist_ok=True)

    os.makedirs(f"{experiment_dir}/results", exist_ok=True)
    
    log_file = f"{experiment_dir}/logs/training.log"

    logger = Logger(log_file)

    sys.stdout = logger

    sys.stderr = logger

    hyperparams = {

        "dataset_name": dataset_name,

        "seed": seed,

        "lr": 0.0001,

        "lambda_mi": 0.1,

        "lambda_fair": 0.1,

        "lambda_cls": 0.5,

        "lambda_fpr": 0.1,

        "num_epochs": 1000,

        "num_epochs": 1000,

        "experiment_timestamp": timestamp,

        "experiment_dir": experiment_dir

    }

    with open(f"{experiment_dir}/hyperparameters.json", 'w') as f:

        json.dump(hyperparams, f, indent=4)

    print(f'Training on {dataset_name} dataset')

    print(f'Setting seeds: {seed}')

    print(f'Experiment directory: {experiment_dir}')

    print(f'Hyperparameters: {json.dumps(hyperparams, indent=2)}')

    print(f'Using device: {device}')

    start_time = time.time()
    
    diffusion_model, classifier, _ = train_func(

        dataset_name=dataset_name,

        lr=hyperparams["lr"],

        lambda_mi=hyperparams["lambda_mi"],

        lambda_fair=hyperparams["lambda_fair"],

        lambda_cls=hyperparams["lambda_cls"],

        lambda_fpr=hyperparams["lambda_fpr"],

        num_epochs=hyperparams["num_epochs"],

        experiment_dir = experiment_dir

    )

    end_time = time.time()

    training_time = end_time - start_time

    hyperparams["training_time_seconds"] = training_time

    hours, rem = divmod(training_time, 3600)

    minutes, seconds = divmod(rem, 60)

    hyperparams["training_time_formatted"] = f"{int(hours):0>2}:{int(minutes):0>2}:{seconds:05.2f}"
    
    with open(f"{experiment_dir}/hyperparameters.json", 'w') as f:

        json.dump(hyperparams, f, indent=4)

    print(f"Training completed in {training_time:.2f} seconds")

    print(f"Training time: {hyperparams['training_time_formatted']}")
    
    if torch.cuda.is_available():

        torch.cuda.empty_cache()
    
    logger.close()

    sys.stdout = logger.terminal

    sys.stderr = logger.terminal
    
    print(f"Experiment done!")