import os
import numpy as np
import torch
from tqdm import tqdm

import functools
import torch.optim as optim
from torch.optim import AdamW
from torch_ema import ExponentialMovingAverage
from torch.utils.data import Dataset, DataLoader

import pandas as pd

from .utils import CustomDataset, marginal_prob_std, diffusion_coeff, loss_fn_norm, get_data_paths, print_size, calculate_scores
from .models.mlp_models import MLPDiffusion, MLPDiffusionPara, MLPDiffusionICL, MLPDiffusionVAE, MLPDiffusionTabSyn, MLPDiffusionBig512
from .models.mlp_models import MLPDiffusionTabSyn1024, MLPDiffusionVAE1024, MLP2048
from .models.scorewave import SCOREWAVENET
from .models.ddpm import ResNetDiffusion

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')


# Checkpoint-directory / display name for every supported ``model`` id.
# These strings are part of the on-disk checkpoint path, so they must not be
# "normalised" (note the lowercase "s" in "MLPDiffusionTabsyn1024").
_MODEL_NAMES = {
    0: "SCOREWAVENET",
    1: "SCOREWAVENET",
    3: "MLPDiffusion",
    6: "ResNet",
    9: "MLPDiffusionPara",
    10: "MLPDiffusionICL",
    11: "MLPDiffusionVAE",
    14: "MLPDiffusionTabSyn",
    15: "MLPDiffusionBig512",
    18: "MLPDiffusionTabsyn1024",
    21: "MLPDiffusionVAE1024",
    22: "MLP2048",
}


def _build_network(model, data):
    """Instantiate the (not yet device-placed) network registered under ``model``.

    ``model`` must be a key of ``_MODEL_NAMES``. ``data`` is the training
    tensor after ``unsqueeze(1)``; its third dimension is used as ``d_in`` and
    is only read for the architectures that actually take a ``d_in`` argument.
    """
    if model in (0, 1):
        # Ids 0 and 1 are a small and a large SCOREWAVENET configuration.
        small = model == 0
        return SCOREWAVENET(
            in_channels=1,
            out_channels=1,
            num_res_layers=3 if small else 36,
            res_channels=32 if small else 256,
            skip_channels=32 if small else 256,
            dilation_cycle=3 if small else 12,
            diffusion_step_embed_dim_in=16 if small else 128,
            diffusion_step_embed_dim_mid=64 if small else 512,
            diffusion_step_embed_dim_out=64 if small else 512,
        )

    d_in = data.shape[2]
    if model == 6:
        params = {
            "d_main": 128,
            "n_blocks": 3,
            "d_hidden": 512,
            "dropout_first": 0,
            "dropout_second": 0,
        }
        return ResNetDiffusion(params, d_in=d_in)
    if model == 11:
        return MLPDiffusionVAE(d_in=d_in, dim_t=4 * d_in)

    # The remaining architectures share the same ``cls(d_in=...)`` constructor.
    mlp_classes = {
        3: MLPDiffusion,
        9: MLPDiffusionPara,
        10: MLPDiffusionICL,
        14: MLPDiffusionTabSyn,
        15: MLPDiffusionBig512,
        18: MLPDiffusionTabSyn1024,
        21: MLPDiffusionVAE1024,
        22: MLP2048,
    }
    return mlp_classes[model](d_in=d_in)


def Sb_fit(training_data_pkl, dataset_name, model=22):
    """Train a score network on ``training_data_pkl`` and save a checkpoint.

    Args:
        training_data_pkl: Array-like training data; it is converted to a
            float32 tensor and given a singleton dimension at axis 1.
        dataset_name: Used as the last component of the checkpoint directory
            ``./NCSBAD/ckps/<name>/<dataset_name>``.
        model: Integer id of the architecture to train (a key of
            ``_MODEL_NAMES``). Defaults to 22 (``MLP2048``), which is the
            value that used to be hard-coded.

    Returns:
        Tuple ``(net, name)``: the trained network (left on ``device``) and
        the architecture name used in the checkpoint path.

    Raises:
        ValueError: If ``model`` is not a supported id.
    """
    # Fail fast, before any tensor, loader or network is built. Previously an
    # unknown id only printed a message and then crashed with a NameError.
    if model not in _MODEL_NAMES:
        raise ValueError(
            f"Model chosen not available: {model!r}. "
            f"Supported ids: {sorted(_MODEL_NAMES)}"
        )
    name = _MODEL_NAMES[model]

    training_data_pkl = torch.tensor(training_data_pkl, dtype=torch.float32).unsqueeze(1)
    print(training_data_pkl.shape)
    batch_size = 16
    dataset = CustomDataset(training_data_pkl)
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    net = _build_network(model, training_data_pkl).to(device)
    optimizer = AdamW(net.parameters(), lr=1e-4, weight_decay=0.001)
    # Only MLP2048 (id 22) has milestones that fall inside the 200-epoch run;
    # the [600, 800] schedule of the other ids never triggers a decay.
    milestones = [120, 160] if model == 22 else [600, 800]
    # ``verbose`` is omitted: it is deprecated and the default is already silent.
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
    # NOTE(review): the EMA is updated every step, but the checkpoint below
    # stores the raw ``net`` weights, not the averaged ones - confirm intent.
    ema = ExponentialMovingAverage(net.parameters(), decay=0.9999)
    num_epochs = 200

    output_directory = f"./NCSBAD/ckps/{name}/{dataset_name}"
    os.makedirs(output_directory, exist_ok=True)

    marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=0.01, device=device)

    epoch = 0
    # NOTE(review): this performs num_epochs + 1 passes (epochs 1..201), as the
    # original ``range(num_epochs+1)`` loop did. Kept as-is so that training
    # results and the stored ``ckp_epoch`` do not change - confirm intent.
    for epoch in tqdm(range(1, num_epochs + 2)):
        net.train()
        total_loss = 0.
        num_items = 0

        for X in tqdm(train_loader):
            x = X.to(device)
            loss = loss_fn_norm(net, x, marginal_prob_std_fn)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ema.update()
            # Weight the batch loss by batch size so the epoch average is
            # per-sample even when the last batch is smaller.
            total_loss += loss.item() * x.shape[0]
            num_items += x.shape[0]

        scheduler.step()

        avg_loss = total_loss / num_items

        print(f"Epoch: {epoch}")
        print('-> Average Loss: {:.2f}'.format(avg_loss))

    # Only the final state is saved, despite the file name 'best.pkl'.
    checkpoint_name = 'best.pkl'
    torch.save({'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'ckp_epoch': epoch},
               os.path.join(output_directory, checkpoint_name))
    print('model at iteration %s is saved' % epoch)

    return net, name


def Sb_predict_score(test_data_pkl, fin, net, dataset_name, name, labels=None):
    """Score ``test_data_pkl`` with ``net`` using repeated noise perturbations.

    Each batch is replicated ``num_iterations`` times, perturbed with Gaussian
    noise scaled by ``marginal_prob_std`` at ``t = 1``, and passed through
    ``net``. The squared values of ``net_output + noise`` are summed over the
    replicas to give a per-feature score.

    Args:
        test_data_pkl: Array-like data to score; converted to a float32 tensor
            with a singleton dimension inserted at axis 1.
        fin: When equal to 1, the weights stored in
            ``./NCSBAD/ckps/<name>/<dataset_name>/best.pkl`` are loaded into
            ``net`` first and raw per-sample scores are returned. Any other
            value scores with ``net`` as given and returns evaluation metrics.
        net: The score network; it is switched to eval mode and left there.
        dataset_name: Checkpoint sub-directory (only read when ``fin == 1``).
        name: Architecture name in the checkpoint path (only read when
            ``fin == 1``).
        labels: Ground-truth labels forwarded to ``calculate_scores`` when
            ``fin != 1``. Defaults to ``[1]``.

    Returns:
        If ``fin == 1``: a 1-D NumPy array with one summed score per sample.
        Otherwise: the tuple ``(f1_score, aucroc, aucrp, preds)`` returned by
        ``calculate_scores``.

    Raises:
        ValueError: If ``test_data_pkl`` yields no batches.
    """
    # ``None`` sentinel instead of a shared mutable default list.
    if labels is None:
        labels = [1]

    test_data_pkl = torch.tensor(test_data_pkl, dtype=torch.float32).unsqueeze(1)

    batch_size = 512
    dataset = CustomDataset(test_data_pkl)
    test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    if fin == 1:
        output_directory = f"./NCSBAD/ckps/{name}/{dataset_name}"
        model_path = os.path.join(output_directory, "best.pkl")
        print(model_path)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location='cpu')
        # Only the model weights are needed for inference; the optimizer state
        # stored alongside them is deliberately ignored.
        net.load_state_dict(checkpoint['model_state_dict'])

        print('Successfully loaded model at iteration best')

    marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=0.01, device=device)

    # Number of independent noise draws per sample.
    num_iterations = 70

    net.eval()

    # Collect per-batch results and concatenate once at the end; repeated
    # torch.cat inside the loop re-copies everything accumulated so far.
    batch_scores = []

    with torch.no_grad():
        for x in tqdm(test_loader):
            # Move the batch to the device *before* replicating it so only one
            # copy (not num_iterations copies) crosses the host/device boundary.
            x = x.to(device)
            sample_batch_size = x.shape[0]
            sample_length = x.size(2)
            t = torch.ones(sample_batch_size, device=device)

            # [num_iterations, batch, 1, length]
            extended_batch = x.unsqueeze(0).repeat(num_iterations, 1, 1, 1)
            # [num_iterations, batch]
            extended_t = t.unsqueeze(0).repeat(num_iterations, 1)

            ms = marginal_prob_std_fn(extended_t)[:, :, None, None]
            n = torch.randn_like(extended_batch, device=device)
            z = extended_batch + n * ms

            # Fold the replica dimension into the batch dimension for the
            # network call: [num_iterations * batch, 1, length].
            z_flat = z.view(-1, z.shape[-2], z.shape[-1])
            t_flat = extended_t.view(-1)

            scores = net(z_flat, t_flat)

            # Back to [num_iterations, batch, 1, length].
            scores = scores.view(num_iterations, sample_batch_size, 1, sample_length)

            scores = scores + n
            # Rearranged to [num_iterations, batch, 1, length, 1]; this trailing
            # layout is what the code after the loop (and calculate_scores)
            # receives, so it is kept unchanged.
            scores = scores.permute(0, 1, 3, 2).unsqueeze(2)
            # Sum (not average) of squared values over the noise draws:
            # [batch, 1, length, 1].
            scores = (scores ** 2).sum(dim=0)

            batch_scores.append(scores)

    if not batch_scores:
        raise ValueError("test_data_pkl is empty: there is nothing to score.")

    preds = torch.cat(batch_scores, dim=0).cpu().detach().sum(1, keepdim=True).numpy()

    if fin == 1:
        # preds is [N, 1, length, 1]: drop axis 1, sum over the feature axis,
        # then squeeze the trailing singleton to get one score per sample.
        img = preds[:, 0, :, :]
        summed_array = np.sum(img, axis=1)
        return np.squeeze(summed_array, axis=1)

    f1_score, aucroc, aucrp, preds = calculate_scores(preds, labels)

    return f1_score, aucroc, aucrp, preds

    