import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
import time
import os
import numpy as np
from config import Config, parse_args
from model import C10AE
from util import rand_unif
import torch.nn.functional as F
import sys
sys.path.append("../")
from methods.mcstsw import mcstsw
from methods.wd import create_g_wasserstein
from utils.func import set_seed
from utils.MCS import MCS
from utils.mixture_mcs import mixture_sample, mixture_log_prob, learn_mus_by_spread, suggest_radius_sigma

# Module-level cache for the parameters of the ``mixture_wrap_normal`` prior:
# a dict with keys "mus", "sigmas" and "weights", filled lazily the first time
# ``get_prior`` is asked for that prior, then reused for every later call.
# (A ``global`` statement at module scope is a no-op, so none is needed here.)
prior_params = None

def main():
    """Train the manifold autoencoder, evaluate it, and append one result record.

    Reads hyper-parameters via ``parse_args()``, mirrors them onto ``Config``,
    trains with ``train_mvae`` and evaluates on the test split:

    * mean of the log of the per-batch Wasserstein distance (``wd``) between
      test embeddings and fresh prior samples,
    * mean negative log-likelihood of the embeddings under the mixture prior
      (only for ``Config.prior == 'mixture_wrap_normal'``, 0 otherwise),
    * mean per-image binary cross-entropy of the reconstructions.

    The record is appended to ``all_results.txt``.

    Raises:
        ValueError: if the dataset or the run type is not supported.
    """
    args = parse_args()
    # Mirror the CLI arguments onto the shared Config read by the helpers.
    Config.loss1 = args.loss1
    Config.loss2 = args.loss2
    Config.d = args.d
    Config.dataset = args.dataset
    Config.prior = args.prior
    Config.device = args.device
    Config.lr = args.lr
    Config.n_epochs = args.epochs
    Config.batch_size = args.batch_size
    Config.beta = args.beta
    Config.ntrees = args.ntrees
    Config.nlines = args.nlines
    Config.delta = args.delta
    Config.n_projs = args.n_projs
    Config.seed = args.seed

    set_seed(Config.seed)

    # Fail fast on unsupported settings instead of hitting a NameError later
    # (unbound dataset/model or unbound ``wd``).
    if Config.dataset != 'c10':
        raise ValueError(f"Unsupported dataset: {Config.dataset!r}")
    if args.type != "mvae":
        raise ValueError(f"Unsupported run type: {args.type!r}")

    os.makedirs('results', exist_ok=True)

    mcs = MCS(N=3, M=3, K=torch.Tensor([0, 1, -1]).to(Config.device), device=Config.device)

    # No Normalize: nn.BCELoss below expects targets in [0, 1], which is
    # exactly what ToTensor produces.
    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
    test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
    model = C10AE(embedding_dim=Config.d, mcs=mcs)

    train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=Config.batch_size, shuffle=False)

    model = model.to(Config.device)
    criterion1 = nn.BCELoss()
    # ``mcs`` always exists here; helpers that do not need it simply ignore it.
    criterion2 = get_loss_func(Config.loss2, Config.device, mcs=mcs)
    optimizer = optim.Adam(model.parameters(), lr=Config.lr)
    wd = create_g_wasserstein(mcs=mcs, p=2)

    start_time = time.time()
    train_mvae(model, train_loader, criterion1, criterion2, Config.beta, optimizer, Config.device, mcs)
    total_time = time.time() - start_time
    time_per_epoch = total_time / Config.n_epochs if Config.n_epochs else 0.0

    embeddings, BCE_losses = get_embs(model, test_loader, Config.device)
    avg_BCE = torch.cat(BCE_losses).mean().item()

    test_W2 = []
    test_NLL = []
    for embedding in embeddings:
        sphere_samples = get_prior(Config.prior, Config.d, embedding.size(0), Config.device, mcs=mcs)
        embedding = embedding.to(Config.device)
        # Convert to plain floats so results from any device can be aggregated.
        test_W2.append(float(wd(embedding, sphere_samples)))

        if Config.prior == 'mixture_wrap_normal':
            # ``prior_params`` is the module-level cache filled by get_prior;
            # it is only read here, so no ``global`` declaration is required.
            nll = -mixture_log_prob(mcs, prior_params, embedding)
            test_NLL.append(float(nll.mean()))
        else:
            test_NLL.append(0.)

    avg_log_W2 = torch.tensor(test_W2).log().mean().item()
    avg_test_NLL = torch.tensor(test_NLL).mean().item()

    if Config.loss2 in ('stsw', 'mcstsw'):
        # Tree-sliced losses are configured by trees / lines / delta
        # (see get_loss_func); n_projs does not apply to them.
        loss2_settings = (
            f"Trees: {Config.ntrees}, "
            f"Lines: {Config.nlines}, "
            f"Delta: {Config.delta}, "
        )
    else:
        loss2_settings = f"NProjs: {Config.n_projs}, "

    run_settings = (
        f"Dataset: {Config.dataset}, "
        f"MCS: N={mcs.N}, M={mcs.M}, K={mcs.K}, "
        f"Learning Rate: {Config.lr}, "
        f"Epochs: {Config.n_epochs}, "
        f"Embedding Dim: {Config.d}, "
        f"Prior: {Config.prior}, "
        f"Loss 1: {Config.loss1}, "
        f"Loss 2: {Config.loss2}, "
    )
    metrics = (
        f"Beta: {Config.beta}, "
        f"Log W2: {avg_log_W2:.4f}, "
        f"Average NLL: {avg_test_NLL:.4f}, "
        f"Average BCE: {avg_BCE:.4f}\n"
        f"Seed: {Config.seed}, "
        f"Time per Epoch: {time_per_epoch:.4f}s, "
        f"Total Time: {total_time:.4f}s\n"
    )
    result_line = run_settings + loss2_settings + metrics
    with open('all_results.txt', 'a', encoding='utf-8') as f:
        f.write(result_line)

def train_mvae(model, train_loader, criterion1, criterion2, beta, optimizer, device, mcs: MCS):
    """Train ``model`` for ``Config.n_epochs`` epochs, then save its weights.

    Each step minimises
    ``criterion1(outputs, images) + beta * criterion2(embeddings, prior_samples)``,
    where ``prior_samples`` are drawn fresh from ``Config.prior`` with the
    same batch size as the images.

    Args:
        model: module whose forward returns ``(reconstruction, embedding)``
            and which exposes the MCS it embeds into as ``model.mcs``.
        train_loader: iterable of ``(images, labels)`` batches.
        criterion1: reconstruction loss.
        criterion2: distribution-matching loss between embeddings and prior samples.
        beta: weight of ``criterion2``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the images are moved to.
        mcs: kept for interface compatibility; prior sampling uses ``model.mcs``.

    Side effects:
        Writes ``model.state_dict()`` to a ``results/SWAE_*.pt`` file.
    """
    # get_embs() switches the model to eval mode; be explicit so that any
    # train/eval-dependent layers behave correctly regardless of call order.
    model.train()
    for epoch in tqdm(range(Config.n_epochs), desc='Training SW'):
        for images, _ in train_loader:
            images = images.to(device)
            optimizer.zero_grad()
            outputs, embeddings = model(images)
            loss1 = criterion1(outputs, images)
            batch_prior_samples = get_prior(Config.prior, Config.d, images.size(0), device, model.mcs)
            loss2 = criterion2(embeddings, batch_prior_samples)
            loss = loss1 + beta * loss2
            loss.backward()
            optimizer.step()
    # Ensure the output directory exists even when called outside main().
    os.makedirs("results", exist_ok=True)
    save_filename = f"results/SWAE_{Config.dataset}_lr{Config.lr}_epoch{Config.n_epochs}_dim{Config.d}_prior{Config.prior}_loss1{Config.loss1}_loss2{Config.loss2}_beta{Config.beta}_trees{Config.ntrees}_lines{Config.nlines}.pt"
    torch.save(model.state_dict(), save_filename)

def get_loss_func(loss_name, device, mcs: "MCS | None" = None):
    """Return the distribution-matching loss ``criterion(X, Y)`` named ``loss_name``.

    Only ``'mcstsw'`` is currently supported. It wraps ``mcstsw`` with the
    component layout of ``mcs`` (``N``, ``M``, ``K``) and ``p=2``. The tree
    hyper-parameters ``Config.ntrees``, ``Config.nlines`` and ``Config.delta``
    are read each time the returned callable is invoked.

    Args:
        loss_name: name of the loss.
        device: device forwarded to the loss.
        mcs: MCS describing the embedding space; required for ``'mcstsw'``.

    Raises:
        ValueError: if ``loss_name`` is unknown, or ``'mcstsw'`` is requested
            without ``mcs`` (previously this returned ``None`` and failed later).
    """
    if loss_name == 'mcstsw':
        if mcs is None:
            raise ValueError("Loss 'mcstsw' requires an MCS instance via `mcs`")

        def criterion(X, Y):
            return mcstsw(X, Y, ncomp=mcs.N, dcomp=mcs.M, K=mcs.K,
                          ntrees=Config.ntrees, nlines=Config.nlines, p=2,
                          delta=Config.delta, device=device)

        return criterion
    raise ValueError(f"Unsupported loss: {loss_name!r}")

def get_prior(prior, dim, n_samples, device, mcs: "MCS | None" = None):
    """Draw ``n_samples`` samples from the prior named ``prior``.

    Supported priors:

    * ``'uniform'``: ``rand_unif(n_samples, dim, device)``.
    * ``'wrap_normal'``: wrapped normal on ``mcs`` centred at ``mcs.zeros``
      with unit tangent std per component.
    * ``'mixture_wrap_normal'``: uniform-weight mixture of 10 wrapped normals.
      Its component means are learned once by ``learn_mus_by_spread`` and
      cached in the module-level ``prior_params``.

    Raises:
        ValueError: for an unknown prior, or a manifold prior without ``mcs``.
    """
    global prior_params  # module-level cache for the mixture prior
    if prior == 'uniform':
        return rand_unif(n_samples, dim, device)
    if prior not in ('wrap_normal', 'mixture_wrap_normal'):
        raise ValueError(f"Unsupported prior: {prior!r}")
    if mcs is None:
        raise ValueError(f"Prior {prior!r} requires an MCS instance via `mcs`")

    if prior == 'wrap_normal':
        sigma = torch.ones(mcs.N).to(mcs.device)
        return mcs.sample_wrap_normal(mcs.zeros, sigma, batch=(n_samples,))

    # 'mixture_wrap_normal': build the mixture once so that every batch
    # (training and evaluation) samples from the same prior.
    if prior_params is None:
        tangent_std = 0.05
        n_components = 10
        radius_per_comp = torch.full((mcs.N,), 0.05, device=device)

        out = learn_mus_by_spread(
            mcs, n_components=n_components,
            radius_per_comp=radius_per_comp,
            steps=1500, lr=0.2, momentum=0.9, seed=42, verbose=True,
        )
        prior_params = {
            "mus": out["mus"],  # (K, N, M) on the manifold
            # per-component tangent std (broadcasts inside log_wrapped_normal)
            "sigmas": torch.full((mcs.N,), tangent_std, device=mcs.device),
            # uniform mixture weights
            "weights": torch.full((n_components,), 1.0 / n_components, device=mcs.device),
        }
    z, _ = mixture_sample(mcs, prior_params, n_samples=n_samples)  # z: (n_samples, N, M)
    return z


def get_embs(model, data_loader, device):
    """Encode every test batch and score its reconstruction.

    Switches ``model`` to eval mode and runs without gradient tracking.
    ``model(images)`` must return ``(reconstruction, embedding)``.

    Returns:
        ``(embeddings, BCE_losses)``: two lists holding one CPU tensor per
        batch, namely the batch embeddings and the per-image mean binary
        cross-entropy (averaged over dims 1-3) between the clamped
        reconstruction and the clamped input.
    """
    model.eval()
    per_batch_embeddings, per_batch_bce = [], []
    with torch.no_grad():
        for images, _labels in data_loader:
            images = images.to(device)
            reconstruction, batch_embedding = model(images)
            # BCE is only defined on [0, 1]; clamp both sides defensively.
            target = images.clamp(0, 1)
            prediction = reconstruction.clamp(0, 1)
            elementwise = F.binary_cross_entropy(prediction, target, reduction='none')
            per_batch_bce.append(elementwise.mean(dim=[1, 2, 3]).detach().cpu())
            per_batch_embeddings.append(batch_embedding.detach().cpu())
    return per_batch_embeddings, per_batch_bce

if __name__ == '__main__':
    # Run training + evaluation only when executed as a script, not on import.
    main()
