import os
import numpy as np
import csv
import torch
from torch import optim

from data.dataset import DWDataset
from utils.miscs import *
from model.classifier import *

def _fetch_training_arrays(dataset, device):
    """Return ``(x_data, y_data, gr_weight)`` for the dataset's current state.

    ``x_data`` and ``y_data`` come from ``dataset.normalized_gr()`` and are
    converted to float tensors on ``device``; ``gr_weight`` is
    ``dataset.weight`` as-is. The third value returned by ``normalized_gr()``
    is not used by the training loop and is discarded.
    """
    x_data, y_data, _ = dataset.normalized_gr()
    gr_weight = dataset.weight
    x_data = torch.from_numpy(x_data).float().to(device)
    y_data = torch.from_numpy(y_data).float().to(device)
    return x_data, y_data, gr_weight


def _evaluate_and_log(model, dataset, x_data, y_data, gr_weight, dim, epoch,
                      OUTPUT_DIR):
    """Compute the model's weights, log their MSE against ``gr_weight``.

    The model is always switched to eval mode first so that every logged
    value is produced under the same conditions (the caller re-enables
    training mode before the next optimisation pass).

    Appends one ``[epoch, dataset.K * dataset.lag_num, mse]`` row to
    ``{OUTPUT_DIR}/mse.csv`` and returns the per-trajectory mean-normalised
    weight array.
    """
    model.eval()
    with torch.no_grad():
        # NOTE(review): update_weight is assumed to return a NumPy array,
        # since NumPy reductions are applied to its result -- confirm.
        model_weight = model.update_weight(
            x_data.reshape([-1, dim]),
            y_data.reshape([-1, dim]),
        ).reshape([dataset.n_traj, -1])
    # Normalise each trajectory's weights to mean 1 before comparing.
    model_weight /= np.mean(model_weight, axis=1, keepdims=True)
    mse = np.mean((gr_weight - model_weight) ** 2)
    with open(f'{OUTPUT_DIR}/mse.csv', 'a', newline='') as f:
        csv.writer(f).writerow([epoch, dataset.K * dataset.lag_num, mse.item()])
    return model_weight


def main(dim, lag_num, data_dir, OUTPUT_DIR, model, lr, device):
    """Train ``model`` on a ``DWDataset`` and periodically refresh its weights.

    Runs 500 epochs of Adam with batch size 2048. Every 25 epochs the model's
    weights are evaluated against ``dataset.weight``, logged to
    ``{OUTPUT_DIR}/mse.csv`` and saved as ``model_weight_lag<K*lag_num>.npy``;
    then ``dataset.K`` is incremented, the last ``lag_num`` columns of the
    weights are dropped and the result is passed to ``dataset.update_weight``.
    The final ``state_dict`` is written to ``{OUTPUT_DIR}/last.pt``.

    Args:
        dim: Feature dimension used to reshape samples to ``(-1, dim)``.
        lag_num: Lag passed to ``DWDataset``.
        data_dir: Directory passed to ``DWDataset``.
        OUTPUT_DIR: Existing directory that receives all output files.
        model: Module exposing ``update_weight(x, y)`` and whose forward call
            ``model(x, y, weight)`` returns the training loss.
        lr: Initial Adam learning rate.
        device: Torch device for the data tensors.
    """
    batchsize = 2048
    epochs = 500
    eval_every = 25
    optimizer = optim.Adam(model.parameters(), lr=lr)

    dataset = DWDataset(
        data_dir=data_dir,
        mode='train',
        lag_num=lag_num
    )
    # Start a fresh log; rows are appended by _evaluate_and_log.
    with open(f'{OUTPUT_DIR}/mse.csv', 'w', newline='') as f:
        csv.writer(f).writerow(['epoch', 'lag_num', 'mse'])

    if dataset.K == 1:
        dataset.update_weight()
    x_data, y_data, gr_weight = _fetch_training_arrays(dataset, device)

    for epoch in tqdm(range(epochs)):
        if epoch % eval_every == 0:
            # Evaluation before this epoch's training pass.
            _evaluate_and_log(model, dataset, x_data, y_data, gr_weight,
                              dim, epoch, OUTPUT_DIR)

        model.train()
        loss_total = []
        iter_num = len(dataset) // batchsize  # trailing partial batch is skipped
        idx_list = np.random.permutation(len(dataset))
        # Number of samples per trajectory; constant within an epoch.
        window = dataset.data_len - dataset.K * dataset.lag_num
        for batch_iter in range(iter_num):
            idx = idx_list[batch_iter * batchsize:(batch_iter + 1) * batchsize]
            # Split the flat sample index into (trajectory, position).
            traj_idx = idx // window
            num_idx = idx % window
            x = x_data[traj_idx, num_idx].reshape([batchsize, dim])
            y = y_data[traj_idx, num_idx].reshape([batchsize, dim])  # (B, dim)
            weight = dataset.weight[traj_idx, num_idx].reshape(-1, 1)  # (B, 1)
            weight = torch.from_numpy(weight).float().to(device)
            loss = model(x, y, weight)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_total.append(loss.item())

        print('Train Epoch: {}/{} ({:.0f}%)\tLogistic Regression Loss: {:.6f}'.format(
                epoch + 1, epochs, 100. * (epoch + 1) / epochs,
                np.array(loss_total).mean()))
        # Exponential learning-rate decay during the first 50 epochs only.
        if epoch < 50:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.95

        if (epoch + 1) % eval_every == 0:
            model_weight = _evaluate_and_log(model, dataset, x_data, y_data,
                                             gr_weight, dim, epoch, OUTPUT_DIR)
            np.save(f'{OUTPUT_DIR}/model_weight_lag{(dataset.K)*dataset.lag_num}.npy', model_weight)
            print(f'{dataset.K*dataset.lag_num} finished!')

            # Advance to the next lag multiple and rebuild the training
            # arrays from the freshly estimated weights.
            dataset.K += 1
            model_weight = model_weight[:, :-dataset.lag_num]
            dataset.update_weight(model_weight)
            x_data, y_data, gr_weight = _fetch_training_arrays(dataset, device)

    torch.save(model.state_dict(), f'{OUTPUT_DIR}/last.pt')

if __name__ == '__main__':
    # Script entry point: build the classifier on a fixed CUDA device and
    # launch training with hard-coded hyper-parameters.
    cuda_idx = 0
    torch.cuda.set_device(cuda_idx)
    device = 'cuda:' + str(cuda_idx)
    dim = 1
    lag_num = 50
    data_dir = '' # TODO: add data dir

    model = ClassifierMSDense2(dim=dim).to(device)
    lr = 1e-3
    OUTPUT_DIR = '' # TODO: add output dir
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    main(dim, lag_num, data_dir, OUTPUT_DIR, model, lr, device)
