import numpy as np
import pandas as pd

from mae_trainer import Mae
from cupid_trainer import Cupid
from data_utils import *

import torch
from torch.optim import AdamW, Adam
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import DataLoader

from tqdm import tqdm
import random
import os

os.environ['PATH']

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)


######################################################################################################################################################################################
########################################################################## REPRODUCIBILITY ###########################################################################################
######################################################################################################################################################################################

torch.manual_seed(0)
random.seed(0)
np.random.seed(0)


######################################################################################################################################################################################
########################################################################## LOADING THE DATA ##########################################################################################
######################################################################################################################################################################################

csv = pd.read_csv("training_data_2.csv")

######################################################################################################################################################################################
##################################################################### MODEL INITIALIZATION ###########################################################################################
######################################################################################################################################################################################

ENC_NUM_BLOCKS = 4
DEC_NUM_BLOCKS = 2

ENC_DIM = 128
DEC_DIM = 128

NUM_HEADS = 4

PATCH_SIZE = 10
FS = 100
L = 10


NUM_STRIPS = 4
BATCH_SIZE = 250

N_ITER = 40001
print("\nModel Initialization")

for use_spec in [1, 0]:

    for MASKING_RATIO in [0.5, 0.4, 0.6]:


        if use_spec == 0:
            weights_name = "mae_spec_" + str(MASKING_RATIO)
            


            # Loading The Model
            trainer = Mae(enc_num_blocks=ENC_NUM_BLOCKS, num_heads=NUM_HEADS, model_dim=ENC_DIM, do_prob=0.1, patch_size=PATCH_SIZE, 
                            in_channels=1, fs=FS, l=L, dec_num_blocks=DEC_NUM_BLOCKS, dec_dim = DEC_DIM,
                            mask_ratio=MASKING_RATIO).to(device)


            optimizer = AdamW(trainer.parameters(), 
                        lr=1e-3, 
                        betas=(0.9, 0.95))
            
        else:
            weights_name = "cupid_" + str(MASKING_RATIO)

            # Loading The Model
            trainer = Cupid(enc_num_blocks=ENC_NUM_BLOCKS, num_heads=NUM_HEADS, model_dim=ENC_DIM, do_prob=0.1, patch_size=PATCH_SIZE, 
                            in_channels=1, fs=FS, l=L, dec_num_blocks=DEC_NUM_BLOCKS, dec_dim = DEC_DIM,
                            mask_ratio=MASKING_RATIO).to(device)


            optimizer = AdamW(trainer.parameters(), 
                        lr=1e-3, 
                        betas=(0.9, 0.95))
            
        print("Training " + str(weights_name))
        print(" ")

        ds = SHHS_DataLoader(csv, num_strips = NUM_STRIPS, fs=FS, l = L)
        dl = DataLoader(ds, batch_size = BATCH_SIZE, shuffle=True, drop_last=True)

        EPOCHS = int(np.ceil(N_ITER / len(dl)))

        scheduler_1 = LinearLR(optimizer, total_iters=10, verbose=False)
        p_bar = tqdm(range(N_ITER))


        counter = 0
        best_accu = 0

        total_rec_losses = []
        epoch_rec_losses = []

        for epoch in range(EPOCHS):
            for batch in iter(dl):

                trainer.train()
                optimizer.zero_grad()
  
                x1 = batch
                rec_loss = trainer(x1)
                

                rec_loss.backward()
                optimizer.step()
                epoch_rec_losses.append(rec_loss.item())
                
                counter += 1

                if counter == N_ITER:
                    torch.save(trainer.to("cpu").state_dict(), weights_name + ".pth")
                    trainer.to(device)
                    break

                # p_bar.set_description("Loss : %s, Dis Loss : %s, Gra Loss : %s, Static Loss : %s" % (loss.item(), dissim_loss, gradual_loss, static_loss))
                p_bar.set_description("Loss : %s " % (rec_loss.item()))
                p_bar.update(1)
                p_bar.refresh()
                
                if counter % 200 == 0:
                    scheduler_1.step()


        np.save(weights_name + "_losses.npy", total_rec_losses)



                
                