import numpy as np
import pandas as pd

import torch

from tqdm import tqdm 

from vit import ViT
from data_utils_ebm import *
from torch.optim import AdamW


import warnings
warnings.filterwarnings('ignore')

from ebm_trainer import EMB_Trainer

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

######################################################################################################################################################################################
##################################################################### MODEL INITIALIZATION ###########################################################################################
######################################################################################################################################################################################

BATCH_SIZE = 250
N_ITER = 40001

NUM_BLOCKS = 4
MODEL_DIM = 128
NUM_HEADS = 4
PATCH_SIZE = 10
FS = 100
L = 10
C_IN = 1

######################################################################################################################################################################################
########################################################################## LOADING THE DATA ##########################################################################################
######################################################################################################################################################################################

csv = np.array(pd.read_csv("training_data_2.csv"))



######################################################################################################################################################################################
########################################################################## TRAINING SET UP ###########################################################################################
######################################################################################################################################################################################





######################################################################################################################################################################################
######################################################################## TRAINING PROCEDURE ##########################################################################################
######################################################################################################################################################################################




for method in ["pclr", "clocs", "mix_up"]:
    print("\nStarting Training Procedure\n")
    print("Method: " + method)

    model = ViT(num_blocks=NUM_BLOCKS, num_heads=NUM_HEADS, model_dim=MODEL_DIM, 
                patch_size=PATCH_SIZE, in_channels=C_IN, fs=FS, l=L, do_prob=0.1)

    model = model.to(device)
    trainer = EMB_Trainer(model, method, model_dim=MODEL_DIM).to(device)

    optimizer = AdamW(trainer.parameters(),  lr=1e-3, betas=(0.9, 0.95))
    
    ds = SimCLR_DataLoader(csv, method=method, l = L)
    dl = DataLoader(ds, batch_size = BATCH_SIZE, shuffle=True, drop_last=True)
    
    EPOCHS = int(np.ceil(N_ITER / len(dl)))
    print(EPOCHS)
    
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 
                                                        milestones=[int(0.75 * (EPOCHS)), int(0.9 * (EPOCHS))], 
                                                        gamma=0.1)

    model_weights = method + ".pth"
    print("Model Weights: " + model_weights)
    counter = -1
    p_bar = tqdm(range(N_ITER))

    for epoch in range(EPOCHS):
        
        for batch in iter(dl):
            batch_loss = train_batch(batch, trainer, optimizer)
            counter += 1
            p_bar.update(1)
            p_bar.refresh()
            p_bar.set_description("Loss %s" % batch_loss)
            
            if counter == N_ITER:
                torch.save(trainer.base_net.to("cpu").state_dict(), model_weights)
                break
