import os
import numpy as np
import torch
from tqdm import tqdm

from torch.cuda.amp import autocast, GradScaler

import functools
import torch.optim as optim
from torch.optim import AdamW
from torch_ema import ExponentialMovingAverage
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import mean_squared_error
import sklearn.metrics as skm


import pandas as pd
import pickle
import loss_util
import model_util
import utils.data_handling
import test
import utils.metric
import interference_util


def train_loop(monitor, config, training_data_pkl, dataset_name, set_val, val_x, val_y, seed):
    """Train the configured backbone on ``training_data_pkl`` and checkpoint it.

    Builds network, optimizer, LR scheduler and EMA via ``model_util.get_model``,
    trains with the loss from ``loss_util.get_loss_fn``, reports the mean
    per-sample loss of every epoch to ``monitor`` and writes checkpoints under
    ``<model_path><model_type>/<backbone_model>_<monitor.id>/<dataset_name>``.

    Args:
        monitor: Progress tracker; must expose ``id`` and
            ``update_values(dataset_name, epoch, avg_loss)``.
        config: Run configuration. Keys read here: ``batch_size``,
            ``backbone_model``, ``model_path``, ``model_type``, ``num_epochs``,
            ``save_epoch``, optionally ``use_gpu`` (torch device string) and
            ``num_workers`` (DataLoader worker count, default 48). The whole
            config is also forwarded to ``model_util``, ``loss_util`` and
            ``interference_util``.
        training_data_pkl: Array-like training data. It is converted to a
            float32 tensor and given a singleton dim at axis 1; the size of the
            resulting axis 2 is passed to ``get_model`` (presumably the input is
            (n_samples, n_features) -- verify against callers).
        dataset_name: Last component of the output directory; also passed to
            ``monitor.update_values``.
        set_val: If truthy, ``interference_util.get_anomalie_score(config)`` is
            called (its result is not used in this function).
        val_x, val_y: Not used in this function.
        seed: Embedded in checkpoint file names.

    Returns:
        The trained network with its raw weights. The EMA is updated every step
        but its shadow weights are never copied into ``net`` here.
    """
    if "use_gpu" in config:
        device = torch.device(config["use_gpu"])
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    training_data_pkl = torch.tensor(training_data_pkl, dtype=torch.float32).unsqueeze(1)
    batch_size = config["batch_size"]
    net, optimizer, scheduler, ema = model_util.get_model(
        config, device, training_data_pkl.shape[2], training_data_pkl
    )
    if config["backbone_model"] == "TTVAE":
        # TTVAE batches the output of net.transform_data, not the raw tensor.
        training_data_pkl = net.transform_data(training_data_pkl)
    dataset = utils.data_handling.CustomDataset(training_data_pkl)

    num_workers = config.get("num_workers", 48)
    loader_kwargs = {}
    if num_workers > 0:
        # DataLoader rejects prefetch_factor / persistent_workers when
        # num_workers == 0, so only pass them when worker processes exist.
        loader_kwargs = {"prefetch_factor": 4, "persistent_workers": True}
    train_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=num_workers,
        pin_memory=True,
        **loader_kwargs,
    )

    output_directory = (
        config["model_path"] + config["model_type"] + "/" + config["backbone_model"]
        + "_" + str(monitor.id) + "/" + dataset_name
    )
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(output_directory, exist_ok=True)
    if set_val:
        # NOTE(review): the returned scoring function is currently unused here.
        interference_util.get_anomalie_score(config)

    loss_fn = loss_util.get_loss_fn(config, device)
    net.to(device)
    ema.to(device)

    save_every = config["save_epoch"]
    epoch = 0  # keeps `epoch` bound for the final save even if no epoch runs
    # NOTE(review): this runs num_epochs + 1 passes, labelled 1..num_epochs + 1;
    # confirm the extra pass is intended.
    for epoch in tqdm(range(1, config["num_epochs"] + 2)):
        net.train()
        total_loss = 0.0
        num_items = 0

        for x in train_loader:
            x = x.to(device, non_blocking=True)
            loss = loss_fn(net, x)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ema.update()
            batch_len = x.shape[0]
            # Weight by batch size so the epoch average is per-sample.
            total_loss += loss.item() * batch_len
            num_items += batch_len

        scheduler.step()

        # NaN instead of ZeroDivisionError when the loader produced no batches.
        avg_loss = total_loss / num_items if num_items else float("nan")
        monitor.update_values(dataset_name, epoch, avg_loss)

        if save_every > 0 and epoch % save_every == 0:
            checkpoint_name = "seed_" + str(seed) + "_epoch" + str(epoch) + ".pkl"
            save_model(config, net, optimizer, epoch, checkpoint_name, output_directory)

    # NOTE(review): despite the "best" file name this is the final-epoch model;
    # no validation-based selection happens in this function.
    checkpoint_name = "seed_" + str(seed) + "best.pkl"
    save_model(config, net, optimizer, epoch, checkpoint_name, output_directory)
    return net

def save_model(config, net, optimizer, epoch, checkpoint_name, output_directory):
    """Write a training checkpoint for ``net`` into ``output_directory``.

    Always saves a dict with keys ``model_state_dict``, ``optimizer_state_dict``
    and ``ckp_epoch`` to ``<output_directory>/<checkpoint_name>``. For the
    TTVAE, TabM, MLP2048 and Base_Transformer backbones the whole module object
    is also saved to ``<output_directory>/model_object_<checkpoint_name>``.

    Args:
        config: Run configuration; only ``config["backbone_model"]`` is read.
        net: The ``torch.nn.Module`` being checkpointed.
        optimizer: Optimizer whose ``state_dict()`` is stored alongside.
        epoch: Epoch number stored as ``ckp_epoch`` and printed.
        checkpoint_name: File name (not path) of the checkpoint.
        output_directory: Existing directory to write into.
    """
    checkpoint_path = os.path.join(output_directory, checkpoint_name)
    torch.save(
        {
            "model_state_dict": net.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "ckp_epoch": epoch,
        },
        checkpoint_path,
    )
    if config["backbone_model"] in ["TTVAE", "TabM", "MLP2048", "Base_Transformer"]:
        # Pickles the full module object. Reloading it needs the model class to
        # be importable, and on recent torch versions torch.load(weights_only=False).
        torch.save(net, os.path.join(output_directory, "model_object_" + checkpoint_name))
    print('model at iteration %s is saved' % epoch)
