import os
import argparse
import json
import numpy as np
import torch
from sklearn.preprocessing import MinMaxScaler
import pickle
from tqdm import tqdm

import functools
import torch.optim as optim
from torch.optim import AdamW
from torch_ema import ExponentialMovingAverage
from torch.utils.data import Dataset, DataLoader

from .utils import CustomDataset, marginal_prob_std, loss_fn_norm, print_size, calculate_scores
from .models.mlp_models import MLPDiffusion, MLPDiffusionPara, MLPDiffusionICL, MLPDiffusionVAE,  MLPDiffusionTabSyn, MLPDiffusionBig512
from .models.mlp_models import  MLPDiffusionTabSyn1024, MLPDiffusionVAE1024, MLP2048
from .models.scorewave import SCOREWAVENET
from .models.ddpm import ResNetDiffusion

# Run on the first CUDA GPU when one is visible to torch; otherwise use the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def _scorewave_config(num_res_layers, channels, dilation_cycle, embed_dim_in, embed_dim_out):
    """Return SCOREWAVENET keyword arguments for one of the predefined sizes."""
    return {
        "in_channels": 1,
        "out_channels": 1,
        "num_res_layers": num_res_layers,
        "res_channels": channels,
        "skip_channels": channels,
        "dilation_cycle": dilation_cycle,
        "diffusion_step_embed_dim_in": embed_dim_in,
        "diffusion_step_embed_dim_mid": embed_dim_out,
        "diffusion_step_embed_dim_out": embed_dim_out,
    }


def _build_model(model, d_in):
    """Build ``(net, name, optimizer, scheduler, ema, num_epochs)`` for a predefined model id.

    Args:
        model: integer id selecting one of the predefined architectures.
        d_in: feature dimension passed to the networks that take ``d_in``.

    Raises:
        ValueError: if ``model`` is not one of the predefined ids.
    """
    resnet_params = {
        "d_main": 128,
        "n_blocks": 3,
        "d_hidden": 512,
        "dropout_first": 0,
        "dropout_second": 0,
    }
    # Factories are lazy so only the selected architecture is instantiated.
    # The returned name is part of the checkpoint path, so it must stay stable.
    factories = {
        0: lambda: (SCOREWAVENET(**_scorewave_config(3, 32, 3, 16, 64)), "SCOREWAVENET"),
        1: lambda: (SCOREWAVENET(**_scorewave_config(36, 256, 12, 128, 512)), "SCOREWAVENET"),
        3: lambda: (MLPDiffusion(d_in=d_in), "MLPDiffusion"),
        6: lambda: (ResNetDiffusion(resnet_params, d_in=d_in), "ResNet"),
        9: lambda: (MLPDiffusionPara(d_in=d_in), "MLPDiffusionPara"),
        10: lambda: (MLPDiffusionICL(d_in=d_in), "MLPDiffusionICL"),
        11: lambda: (MLPDiffusionVAE(d_in=d_in, dim_t=4 * d_in), "MLPDiffusionVAE"),
        14: lambda: (MLPDiffusionTabSyn(d_in=d_in), "MLPDiffusionTabSyn"),
        15: lambda: (MLPDiffusionBig512(d_in=d_in), "MLPDiffusionBig512"),
        18: lambda: (MLPDiffusionTabSyn1024(d_in=d_in), "MLPDiffusionTabsyn1024"),
        21: lambda: (MLPDiffusionVAE1024(d_in=d_in), "MLPDiffusionVAE1024"),
        22: lambda: (MLP2048(d_in=d_in), "MLP2048"),
    }
    try:
        factory = factories[model]
    except KeyError:
        raise ValueError(
            f"Model chosen not available: {model!r} (choose from {sorted(factories)})"
        ) from None

    net, name = factory()
    net = net.to(device)
    optimizer = AdamW(net.parameters(), lr=1e-4, weight_decay=0.001)
    # MLP2048 decays inside the 201-epoch run. The other models use milestones
    # 600/800, which the run never reaches, so their learning rate stays constant.
    milestones = [120, 160] if model == 22 else [600, 800]
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
    ema = ExponentialMovingAverage(net.parameters(), decay=0.9999)
    num_epochs = 200
    return net, name, optimizer, scheduler, ema, num_epochs


def Sb_fit_val(training_data_pkl, validation_data_pkl, val_labels, dataset_name, model=22):
    """Train a score network and checkpoint the epoch with the best validation AUROC.

    After every epoch the network is evaluated with ``Sb_predict_score_val``
    (``fin=0``). Whenever the AUROC improves, the weights and optimizer state are
    saved to ``./NCSBAD/ckps_val/{name}/{dataset_name}/best.pkl``.

    Args:
        training_data_pkl: array-like training samples. They are converted to
            float32 and given a channel axis via ``unsqueeze(1)``, presumably
            (N, F) -> (N, 1, F); TODO confirm with callers.
        validation_data_pkl: validation samples forwarded to ``Sb_predict_score_val``.
        val_labels: validation labels forwarded to ``Sb_predict_score_val``.
        dataset_name: used to build the checkpoint directory.
        model: predefined architecture id (default 22, ``MLP2048``).

    Returns:
        ``(net, name)``: the network as it is after the last epoch (not the best
        checkpoint), and the architecture name used in the checkpoint path.

    Raises:
        ValueError: if the training data is empty or ``model`` is unknown.
    """
    training_data_pkl = torch.tensor(training_data_pkl, dtype=torch.float32).unsqueeze(1)
    if training_data_pkl.shape[0] == 0:
        raise ValueError("training_data_pkl is empty; nothing to train on")

    net, name, optimizer, scheduler, ema, num_epochs = _build_model(
        model, training_data_pkl.shape[2]
    )
    print_size(net)
    print(net)

    batch_size = 16
    train_loader = DataLoader(
        CustomDataset(training_data_pkl), batch_size=batch_size, shuffle=True
    )

    output_directory = f"./NCSBAD/ckps_val/{name}/{dataset_name}"
    os.makedirs(output_directory, exist_ok=True)

    marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=0.01, device=device)
    # Start below any possible AUROC so that the first epoch always writes
    # best.pkl, which the fin=1 scoring path later loads.
    best_aucroc = float("-inf")

    # num_epochs + 1 passes, as in the original schedule. Printed epochs are 1-based.
    for epoch_index in tqdm(range(num_epochs + 1)):
        net.train()
        total_loss = 0.0
        num_items = 0
        for X in tqdm(train_loader):
            x = X.to(device)
            loss = loss_fn_norm(net, x, marginal_prob_std_fn)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # NOTE(review): the EMA shadow weights are tracked but never copied into
            # net for validation or saved. Confirm whether that was intended.
            ema.update()
            total_loss += loss.item() * x.shape[0]
            num_items += x.shape[0]

        scheduler.step()
        epoch = epoch_index + 1
        avg_loss = total_loss / num_items

        print(f"Epoch: {epoch}")
        print('-> Average Loss: {:.2f}'.format(avg_loss))

        _, aucroc, _, _ = Sb_predict_score_val(
            validation_data_pkl, 0, net, dataset_name, name, val_labels
        )

        if aucroc > best_aucroc:
            best_aucroc = aucroc
            torch.save(
                {
                    'model_state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'ckp_epoch': epoch,
                },
                os.path.join(output_directory, 'best.pkl'),
            )
            print('model at iteration %s is saved' % epoch)

    return net, name


def Sb_predict_score_val(test_data_pkl, fin, net, dataset_name, name, labels=None, num_iterations=70):
    """Score samples by the squared noise-prediction residual of a score network.

    Each sample is perturbed ``num_iterations`` times with Gaussian noise, scaled
    by ``marginal_prob_std(t=1, sigma=0.01)``. The residual ``net(z, t) + noise``
    is squared and summed over those draws, giving one score per feature.

    Args:
        test_data_pkl: array-like samples. They are converted to float32 and given
            a channel axis via ``unsqueeze(1)``, presumably (N, F) -> (N, 1, F);
            TODO confirm with callers.
        fin: when 1, load ``best.pkl`` for ``(name, dataset_name)`` into ``net``
            and return per-sample scores. Otherwise evaluate against ``labels``.
        net: score network. It is left in eval mode.
        dataset_name: selects the checkpoint directory
            ``./NCSBAD/ckps_val/{name}/{dataset_name}``.
        name: architecture name; also part of that checkpoint directory.
        labels: ground-truth labels passed to ``calculate_scores`` when
            ``fin != 1``. Defaults to ``[1]``.
        num_iterations: number of noise draws per sample (default 70).

    Returns:
        If ``fin == 1``: a 1-D numpy array with one summed score per sample.
        Otherwise: ``(f1_score, aucroc, aucrp, preds)`` from ``calculate_scores``.

    Raises:
        ValueError: if ``test_data_pkl`` contains no samples.
    """
    if labels is None:
        labels = [1]

    test_tensor = torch.tensor(test_data_pkl, dtype=torch.float32).unsqueeze(1)
    if test_tensor.shape[0] == 0:
        raise ValueError("test_data_pkl is empty; nothing to score")
    batch_size = 512
    test_loader = DataLoader(CustomDataset(test_tensor), batch_size=batch_size, shuffle=False)

    if fin == 1:
        output_directory = f"./NCSBAD/ckps_val/{name}/{dataset_name}"
        model_path = os.path.join(output_directory, "best.pkl")
        print(model_path)
        checkpoint = torch.load(model_path, map_location='cpu')
        # Scoring only needs the network weights. The saved optimizer state is
        # deliberately not restored, because no optimizer is used here.
        net.load_state_dict(checkpoint['model_state_dict'])
        print('Successfully loaded model at iteration best')

    marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=0.01, device=device)

    net.eval()
    # Collect per-batch results and concatenate once, to avoid a quadratic re-copy.
    batch_scores = []
    with torch.no_grad():
        for x in tqdm(test_loader):
            sample_batch_size = x.shape[0]
            sample_length = x.size(2)
            t = torch.ones(sample_batch_size, device=device)

            # Replicate the batch once per noise draw: (K, B, 1, L) and (K, B).
            extended_batch = x.unsqueeze(0).repeat(num_iterations, 1, 1, 1).to(device)
            extended_t = t.unsqueeze(0).repeat(num_iterations, 1)

            ms = marginal_prob_std_fn(extended_t)[:, :, None, None]
            noise = torch.randn_like(extended_batch, device=device)
            z = extended_batch + noise * ms

            # Fold the draw axis into the batch axis for the network: (K*B, 1, L).
            z_flat = z.view(-1, z.shape[-2], z.shape[-1])
            t_flat = extended_t.view(-1)
            predicted = net(z_flat, t_flat).view(
                num_iterations, sample_batch_size, 1, sample_length
            )

            # Square the residual and sum it over the K draws (no averaging):
            # (K, B, 1, L) -> (K, B, 1, L, 1) -> (B, 1, L, 1).
            residual = (predicted + noise).permute(0, 1, 3, 2).unsqueeze(2)
            batch_scores.append((residual ** 2).sum(dim=0))

    # Shape (N, 1, L, 1). Summing over the singleton axis 1 keeps that layout.
    preds = torch.cat(batch_scores, dim=0).cpu().detach().sum(1, keepdim=True).numpy()

    if fin == 1:
        # Sum each sample's per-feature scores: (N, L, 1) -> (N, 1) -> (N,).
        return np.squeeze(np.sum(preds[:, 0, :, :], axis=1), axis=1)

    f1_score, aucroc, aucrp, preds = calculate_scores(preds, labels)
    return f1_score, aucroc, aucrp, preds

    

