import os
import pickle
from typing import Dict
import numpy as np

import torch
import torch.optim as optim
from tqdm import tqdm
import time
from torch.utils.data import DataLoader

from DiffusionCondition import GaussianDiffusionSampler, GaussianDiffusionTrainer
from ConditionFiLMNet1D import EpsNet_FiLMMLP
from DiffusionFreeGuidence.dataset import Feature_dataset
from Scheduler import GradualWarmupScheduler


def preprocess_features(features, method='none'):
    """Normalize a feature tensor using one global set of statistics.

    Args:
        features: float tensor; every statistic is computed over all of its
            elements (not per dimension).
        method: ``'none'`` (identity), ``'zscore'`` (``(x - mean) / std``),
            ``'minmax'`` (maps ``[min, max]`` onto ``[-1, 1]``) or
            ``'robust'`` (maps the 25th/75th percentiles onto ``-1``/``1``).

    Returns:
        ``(normalized, stats)``. ``stats`` records ``'method'`` plus the
        values needed to invert the mapping. When the spread (std, max - min,
        or q75 - q25) is zero or NaN, e.g. for constant input, a unit spread
        is used and recorded in ``stats``. The output then stays finite instead
        of becoming inf/NaN, and ``denormalize_features`` still inverts it.

    Raises:
        ValueError: if ``method`` is not one of the names above.
    """
    if method == 'none':
        stats = {'method': 'none'}
        return features, stats

    elif method == 'zscore':
        mean = features.mean()
        std = features.std()
        if not bool(std > 0):
            # Constant input (std == 0) or a single element (std is NaN).
            std = torch.ones_like(std)
        normalized = (features - mean) / std
        stats = {'method': 'zscore', 'mean': mean, 'std': std}
        return normalized, stats

    elif method == 'minmax':
        min_val = features.min()
        max_val = features.max()
        if not bool(max_val > min_val):
            # Zero range would give 0/0; record a unit range so the inverse
            # (which uses max - min) stays consistent with this mapping.
            max_val = min_val + 1
        normalized = 2 * (features - min_val) / (max_val - min_val) - 1
        stats = {'method': 'minmax', 'min': min_val, 'max': max_val}
        return normalized, stats

    elif method == 'robust':
        q25 = torch.quantile(features, 0.25)
        q75 = torch.quantile(features, 0.75)
        if not bool(q75 > q25):
            # Zero inter-quartile range: fall back to a unit spread (see minmax).
            q75 = q25 + 1
        normalized = (features - q25) / (q75 - q25) * 2 - 1
        stats = {'method': 'robust', 'q25': q25, 'q75': q75}
        return normalized, stats

    else:
        raise ValueError(f"Unknown preprocessing method: {method}")


def denormalize_features(features, stats):
    """Map normalized features back to their original scale.

    ``stats`` is the dictionary produced by ``preprocess_features``. Its
    ``'method'`` entry selects the inverse transform, and the remaining keys
    (``mean``/``std``, ``min``/``max`` or ``q25``/``q75``) are read only by
    the branch that needs them.

    Raises:
        ValueError: if ``stats['method']`` names no known transform.
    """
    method = stats['method']
    inverses = {
        'none': lambda t: t,
        'zscore': lambda t: t * stats['std'] + stats['mean'],
        'minmax': lambda t: (t + 1) / 2 * (stats['max'] - stats['min']) + stats['min'],
        'robust': lambda t: (t + 1) / 2 * (stats['q75'] - stats['q25']) + stats['q25'],
    }
    inverse = inverses.get(method)
    if inverse is None:
        raise ValueError(f"Unknown denormalization method: {stats['method']}")
    return inverse(features)


def _center_by_class(features, labels, num_classes):
    """Subtract each class's mean from that class's samples, in place.

    Args:
        features: float tensor whose first axis indexes samples (``train``
            passes an ``[N, 1, D]`` tensor). Modified in place.
        labels: integer class id per sample (presumably shape ``[N]``).
        num_classes: class ids ``0 .. num_classes - 1`` are centered; samples
            with other ids are left untouched.

    Returns:
        Tensor of shape ``[num_classes, 1, D]`` holding each class mean.
        A class with no samples keeps an all-zero mean instead of the NaN
        that averaging an empty selection would produce.
    """
    dim = features.size(-1)
    class_means = torch.zeros(num_classes, 1, dim, device=features.device)
    for c in range(num_classes):
        idx = labels == c
        if not bool(idx.any()):
            continue  # empty class: keep a zero mean rather than NaN
        mu_c = features[idx].mean(dim=0, keepdim=True)
        features[idx] = features[idx] - mu_c
        class_means[c] = mu_c
    return class_means


def train(modelConfig: Dict):
    """Train the class-conditional FiLM-MLP diffusion model on pickled features.

    Steps:
        1. Load ``feature-embedding/cifar100-lt``, a pickle with ``'feature'``
           and ``'label'`` entries.
        2. Subtract each class's mean feature and save the means to
           ``<save_dir>/class_means.pt``.
        3. Optionally normalize with ``modelConfig['preprocessing']``.
        4. Train, replacing the labels of about 10% of batches with label 0.

    After every epoch the latest weights are saved, along with the weights of
    the epoch with the lowest average loss so far.

    Side effects: stores the normalization stats in ``modelConfig['norm_stats']``.
    """
    device = torch.device(modelConfig["device"])
    save_dir = modelConfig["save_dir"]
    os.makedirs(save_dir, exist_ok=True)  # every torch.save below writes here

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open('feature-embedding/cifar100-lt', 'rb') as file:
        data = pickle.load(file)

    # Presumably data['feature'] is [N, D]; unsqueeze(1) then gives [N, 1, D].
    train_data = data['feature'].to(torch.float32).unsqueeze(1)
    train_label = data['label'].long().squeeze()

    # The model learns class-centered residuals; generate_features adds the means back.
    class_means = _center_by_class(train_data, train_label, modelConfig["num_labels"])
    mean_path = os.path.join(save_dir, "class_means.pt")
    torch.save(class_means.cpu(), mean_path)
    print(f"[mean] saved to {mean_path}, shape={tuple(class_means.shape)}")  # [C,1,D]

    # Choose preprocessing method - 'none' is often best for feature data
    preprocessing_method = modelConfig.get('preprocessing', 'none')  # 'none', 'zscore', 'minmax', 'robust'
    train_data, norm_stats = preprocess_features(train_data, preprocessing_method)

    print(f"Feature preprocessing: {preprocessing_method}")
    print(f"Feature data stats: mean={train_data.mean():.4f}, std={train_data.std():.4f}")
    print(f"Feature data range: [{train_data.min():.4f}, {train_data.max():.4f}]")

    # Kept on the config so sampling in the same process can denormalize.
    modelConfig['norm_stats'] = norm_stats
    dataset = Feature_dataset(train_data, train_label)
    dataloader = DataLoader(
        dataset, batch_size=modelConfig["batch_size"], shuffle=True, num_workers=4, drop_last=True, pin_memory=True)

    # NOTE(review): num_labels is hard-coded to 100 here (and in eval/generate_features)
    # while centering uses modelConfig["num_labels"]; keep the two in sync.
    net_model = EpsNet_FiLMMLP(T=modelConfig["T"], num_labels=100, ch=modelConfig["channel"],
                               n_blocks=modelConfig["n_blocks"], dropout=modelConfig["dropout"],
                               data_dim=modelConfig["data_dim"]).to(device)

    if modelConfig["training_load_weight"] is not None:
        # strict=False: keys that do not match the current model are skipped.
        net_model.load_state_dict(torch.load(os.path.join(
            save_dir, modelConfig["training_load_weight"]), map_location=device), strict=False)
        print("Model weight load down.")
    optimizer = torch.optim.AdamW(
        net_model.parameters(), lr=modelConfig["lr"], weight_decay=1e-4)
    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=modelConfig["epoch"], eta_min=0, last_epoch=-1)
    warmUpScheduler = GradualWarmupScheduler(optimizer=optimizer, multiplier=modelConfig["multiplier"],
                                             warm_epoch=modelConfig["epoch"] // 10, after_scheduler=cosineScheduler)
    trainer = GaussianDiffusionTrainer(
        net_model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)

    # start training
    best_loss = float('inf')
    best_epoch = -1
    best_path = os.path.join(save_dir, "ckpt_best_Film_manifold_k=3_T=100_48020.pt")
    last_path = os.path.join(save_dir, "ckpt_last_Film_manifold_k=3_T=100_48020.pt")
    for e in range(modelConfig["epoch"]):
        epoch_loss_sum, epoch_batches = 0.0, 0
        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
            for images, labels in tqdmDataLoader:
                optimizer.zero_grad()
                x_0 = images.to(device)
                # Shift class ids by one so label 0 is free to mean "no condition".
                labels = labels.to(device) + 1
                # For ~10% of batches, replace every label with 0 (unconditional).
                if np.random.rand() < 0.1:
                    labels = torch.zeros_like(labels)
                loss = trainer(x_0, labels)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    net_model.parameters(), modelConfig["grad_clip"])
                optimizer.step()

                loss_value = loss.item()
                epoch_loss_sum += loss_value
                epoch_batches += 1
                tqdmDataLoader.set_postfix(ordered_dict={
                    "epoch": e,
                    "loss: ": loss_value,
                    "img shape: ": x_0.shape,
                    # Read the LR directly; optimizer.state_dict() rebuilds a dict each call.
                    "LR": optimizer.param_groups[0]["lr"],
                })
        warmUpScheduler.step()

        epoch_avg_loss = epoch_loss_sum / max(epoch_batches, 1)
        torch.save(net_model.state_dict(), last_path)
        if epoch_avg_loss < best_loss:
            best_loss, best_epoch = epoch_avg_loss, e
            torch.save(net_model.state_dict(), best_path)
            print(f"[BEST] epoch={best_epoch} avg_loss={best_loss:.6f} -> {best_path}")
            print(f"[BEST] epoch={best_epoch} avg_loss={best_loss:.6f} -> {best_path}")


def eval(modelConfig: Dict):
    """Sample 20000 features with DDIM and pickle them to ``gen-feature/gen_all``.

    Generation runs in 40 rounds of 500 samples. Every round holds 5 samples
    for each of the 100 classes, in class order. The pickle holds
    ``{'gen_feature', 'gen_label'}``, with 0-based class ids as labels.
    Features are denormalized when ``modelConfig['norm_stats']`` names a method
    other than ``'none'``.

    Note: this function shadows the builtin ``eval``; the name is kept for callers.

    NOTE(review): ``train`` fits the model on class-centered features, but this
    function never adds class means back (``generate_features`` does). Confirm
    that the output is meant to stay centered.
    """
    device = torch.device(modelConfig["device"])
    total_num = 20000
    num_gen = 500
    num_cls = 100
    gen_epoch = total_num // num_gen
    step = num_gen // num_cls  # samples per class in each round
    data_dim = modelConfig["data_dim"]  # must match the model's data_dim
    out_path = 'gen-feature/gen_all'

    model = EpsNet_FiLMMLP(T=modelConfig["T"], num_labels=100, ch=modelConfig["channel"],
                           n_blocks=modelConfig["n_blocks"], dropout=modelConfig["dropout"],
                           data_dim=data_dim).to(device)

    ckpt = torch.load(os.path.join(
        modelConfig["save_dir"], modelConfig["test_load_weight"]), map_location=device)
    model.load_state_dict(ckpt)
    print("model load weight done.")
    model.eval()

    # One sampler serves every round (generate_features reuses a sampler the same way).
    sampler = GaussianDiffusionSampler(
        model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"], w=modelConfig["w"]).to(device)

    # Every round uses the same layout: class k repeated `step` times; any
    # leftover slots go to the last class.
    round_labels = torch.clamp(torch.arange(num_gen) // step, max=num_cls - 1).long()
    # The model reserves label 0 for "no condition", so class ids are shifted by one.
    model_labels = round_labels.to(device) + 1

    gen_feature = []
    gen_label = []
    with torch.no_grad():
        for epoch in range(gen_epoch):
            gen_label.append(round_labels)
            noisyImage = torch.randn(size=[num_gen, 1, data_dim], device=device)
            sampledImgs = sampler.ddim_sample(noisyImage, model_labels, epoch)
            gen_feature.append(sampledImgs.detach().cpu())

    gen_feature = torch.cat(gen_feature)
    gen_label = torch.cat(gen_label)

    # Denormalize generated features back to original scale if needed
    if 'norm_stats' in modelConfig and modelConfig['norm_stats']['method'] != 'none':
        gen_feature = denormalize_features(gen_feature, modelConfig['norm_stats'])
        print(f"Denormalized features back to original scale using {modelConfig['norm_stats']['method']}")

    recoder = {'gen_feature': gen_feature, 'gen_label': gen_label}
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, 'wb') as file:
        pickle.dump(recoder, file)


def generate_features(modelConfig: Dict, target_classes=None, samples_per_class=2000, output_file=None):
    """Sample ``samples_per_class`` features per target class with DDIM and pickle them.

    Args:
        modelConfig: run configuration with the device, model
            hyper-parameters, ``save_dir``, ``test_load_weight`` and sampler
            settings. It may also hold ``mean_path`` and ``norm_stats``.
        target_classes: sequence of 0-based class ids. ``None`` means all
            ``modelConfig.get("num_labels", 100)`` classes.
        samples_per_class: number of samples drawn for each class.
        output_file: pickle path. When ``None``, a name is derived from the
            arguments.

    Returns:
        ``(gen_feature, gen_label)`` CPU tensors. Features are presumably
        ``[len(target_classes) * samples_per_class, 1, data_dim]``.

    For each class, ``train``'s preprocessing is undone in reverse order:
    ``train`` subtracts the class mean and then normalizes, so this function
    denormalizes first (if ``norm_stats`` is set) and then adds the class mean.
    """
    device = torch.device(modelConfig["device"])
    data_dim = modelConfig["data_dim"]  # noise must match the model's data_dim

    model = EpsNet_FiLMMLP(T=modelConfig["T"], num_labels=100, ch=modelConfig["channel"],
                           n_blocks=modelConfig["n_blocks"], dropout=modelConfig["dropout"],
                           data_dim=data_dim).to(device)

    ckpt = torch.load(os.path.join(
        modelConfig["save_dir"], modelConfig["test_load_weight"]), map_location=device)
    model.load_state_dict(ckpt)
    print("Model loaded successfully.")
    model.eval()

    # Set target classes
    if target_classes is None:
        target_classes = list(range(modelConfig.get("num_labels", 100)))

    print(f"Generating {samples_per_class} samples for {len(target_classes)} classes: {target_classes}")

    # NOTE(review): train() saves its per-class means to "class_means.pt", but the
    # default here is "class_mean_soft.pt"; confirm which file is intended.
    mean_path = modelConfig.get("mean_path",
                                os.path.join(modelConfig["save_dir"], "class_mean_soft.pt"))
    # Kept on the CPU: means are added after samples leave the device.
    class_means = torch.load(mean_path, map_location="cpu").float()  # [C,1,D]

    sampler = GaussianDiffusionSampler(
        model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"], w=modelConfig["w"]).to(device)

    norm_stats = modelConfig.get('norm_stats')
    needs_denorm = norm_stats is not None and norm_stats['method'] != 'none'

    gen_feature = []
    gen_label = []

    # Generate features class by class
    for class_idx in target_classes:
        print(f"Generating class {class_idx}...")

        with torch.no_grad():
            labels = torch.full((samples_per_class,), class_idx, dtype=torch.long, device=device)
            # The model reserves label 0 for "no condition", so class ids are shifted by one.
            labels_for_model = labels + 1
            noisyImage = torch.randn(size=[samples_per_class, 1, data_dim], device=device)
            sampledImgs = sampler.ddim_sample(noisyImage, labels_for_model, class_idx).detach().cpu()

        # Invert train()'s preprocessing in reverse order: normalization was
        # applied after centering, so undo it before adding the class mean.
        if needs_denorm:
            sampledImgs = denormalize_features(sampledImgs, norm_stats)
        labels_cpu = labels.cpu()
        gen_feature.append(sampledImgs + class_means[labels_cpu])  # [B,1,D] + [B,1,D]
        gen_label.append(labels_cpu)

    gen_feature = torch.cat(gen_feature, dim=0)
    gen_label = torch.cat(gen_label, dim=0)

    if needs_denorm:
        print(f"Denormalized features back to original scale using {norm_stats['method']}")

    # Save results
    if output_file is None:
        if target_classes == list(range(modelConfig.get("num_labels", 100))):
            output_file = f'gen_feature_all_classes_{samples_per_class}'
        else:
            output_file = f'gen_feature_classes_{len(target_classes)}_{samples_per_class}'

    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    recoder = {
        'gen_feature': gen_feature,
        'gen_label': gen_label,
        'target_classes': target_classes,
        'samples_per_class': samples_per_class
    }

    with open(output_file, 'wb') as file:
        pickle.dump(recoder, file)

    print(f"Generated features saved to: {output_file}")
    print(f"Total samples: {gen_feature.shape[0]}")
    print(f"Feature shape: {gen_feature.shape[1:]}")

    return gen_feature, gen_label


def eval_gen_tail(modelConfig: Dict, tail_start: int = 15):
    """Generate 500 samples for each tail class and print the elapsed time.

    Tail classes are ids ``tail_start .. modelConfig["num_labels"] - 1``, and
    ``tail_start`` defaults to 15. Output is written to ``gen-feature/gen_new``.
    """
    tail_class = list(range(tail_start, modelConfig["num_labels"]))
    # perf_counter is monotonic, unlike time.time(), so it is the clock for durations.
    start_time = time.perf_counter()
    generate_features(
        modelConfig=modelConfig,
        target_classes=tail_class,
        samples_per_class=500,
        output_file='gen-feature/gen_new'
    )
    print(f'{time.perf_counter() - start_time} seconds')

def eval_gen_all(modelConfig: Dict):
    """Sample 1000 features for every class and pickle them to ``gen_feature_all``.

    Thin wrapper around ``generate_features``. ``target_classes=None`` there
    selects all ``modelConfig.get("num_labels", 100)`` classes. Returns ``None``.
    """
    settings = dict(
        target_classes=None,
        samples_per_class=1000,
        output_file='gen_feature_all',
    )
    generate_features(modelConfig, **settings)
