# import os
# import time
# import torch
# import numpy as np
# import pandas as pd

# from copy import deepcopy

# import src
# from utils_train import make_dataset
# from tabrep_flow.models.modules import MLPDiffusion, Model
# from tabrep_flow.models.flow_matching import ConditionalFlowMatcher

# def bits_needed(categories):
#     return 2 * np.ones_like(categories)

# def get_model(
#     model_name,
#     model_params,
#     n_num_features,
#     category_sizes
# ): 
#     # print(model_name)
#     if model_name == 'mlp':
#         model = MLPDiffusion(**model_params)
#     else:
#         raise "Unknown model!"
#     return model




# class Trainer:
#     def __init__(self, model, train_iter, lr, weight_decay, steps, model_save_path, device=torch.device('cuda:1')):
#         self.model = model
#         # self.ema_model = deepcopy(self.model.flow_net_cont)
#         # for param in self.ema_model.parameters():
#         #     param.detach_()

#         self.train_iter = train_iter
#         self.steps = steps
#         self.init_lr = lr
#         self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=weight_decay)
#         self.device = device
#         self.loss_history = pd.DataFrame(columns=['step', 'loss'])
#         self.model_save_path = model_save_path
#         self.model.cfm_cat.total_steps = steps  # Set total steps
#         self.model.cfm_cont.total_steps = steps  # Set total steps

#         columns = list(np.arange(5)*200)
#         columns[0] = 1
#         columns = ['step'] + columns
 

#         self.log_every = 50
#         self.print_every = 1
#         self.ema_every = 1000

#     def _anneal_lr(self, step):
#         frac_done = step / self.steps
#         lr = self.init_lr * (1 - frac_done)
#         for param_group in self.optimizer.param_groups:
#             param_group["lr"] = lr

#     def _run_step(self, x):
#         x = x.to(self.device)

#         self.optimizer.zero_grad()

#         loss = self.model(x)

#         loss.backward()
#         self.optimizer.step()

#         return loss


#     def run_loop(self):
#         step = 0
#         curr_loss = 0.0

#         curr_count = 0
#         self.print_every = 1
#         self.log_every = 1

#         best_loss = np.inf
#         print('Steps: ', self.steps)
#         start_time = time.time()
#         while step < self.steps:
#             x = next(self.train_iter)[0]

            
#             batch_loss = self._run_step(x)

#             # self._anneal_lr(step)

#             curr_count += len(x)
#             curr_loss += batch_loss.item() * len(x)

#             if (step + 1) % self.log_every == 0:
#                 loss = np.around(curr_loss / curr_count, 4)
#                 if np.isnan(loss):
#                     print('Finding Nan')
#                     break
                
#                 if (step + 1) % self.print_every == 0:
#                     print(f'Step {(step + 1)}/{self.steps} Loss: {loss}')
#                 self.loss_history.loc[len(self.loss_history)] =[step + 1, loss]

#                 np.set_printoptions(suppress=True)
          
#                 curr_count = 0
#                 curr_loss = 0.0
#                 if loss < best_loss:
#                     best_loss = loss
#                     torch.save(self.model.flow_net_cat.state_dict(), os.path.join(self.model_save_path, 'model_cat.pt'))
#                     torch.save(self.model.flow_net_cont.state_dict(), os.path.join(self.model_save_path, 'model_cont.pt'))
  
#                 # if (step + 1) % 500 == 0: #
#                 #     torch.save(self.model.flow_net.state_dict(), os.path.join(self.model_save_path, f'model_{step+1}.pt')) #

#             # update_ema(self.ema_model.parameters(), self.model.flow_net.parameters())

#             step += 1
#         end_time = time.time()
#         print('Time: ', end_time - start_time)

# def train(
#     model_save_path,
#     real_data_path,
#     steps = 1000,
#     lr = 0.002,
#     weight_decay = 1e-4,
#     batch_size = 1024,
#     task_type = 'binclass',
#     model_type = 'mlp',
#     model_params = None,
#     model_params1 = None,
#     model_params2=None,
#     num_timesteps = 1000,
#     gaussian_loss_type = 'mse',
#     scheduler = 'cosine',
#     T_dict = None,
#     num_numerical_features = 0,
#     device = torch.device('cuda:0'),
#     seed = 0,
#     change_val = False
# ):
#     real_data_path = os.path.normpath(real_data_path)

#     T = src.Transformations(**T_dict)
    
#     dataset = make_dataset(
#         real_data_path,
#         T,
#         task_type = task_type,
#         change_val = False,
#     )

    
#     K = np.array(dataset.get_category_sizes('train'))
#     print(K)
    
#     num_numerical_features = dataset.X_num['train'].shape[1] if dataset.X_num is not None else 0
#     print(num_numerical_features)

#     num_bits_per_cat_feature = bits_needed(K) if len(K) > 0 else np.array([0])


#     d_in = np.sum(num_bits_per_cat_feature) + num_numerical_features
#     # d_in = dataset.X_cat['train'].shape[1] + num_numerical_features


    
#     model_params['d_in'] = d_in
#     # print(num_bits_per_cat_feature)
#     # print(d_in)

#     train_loader = src.prepare_fast_dataloader(dataset, split='train', batch_size=batch_size)


#     model_params['d_in']= np.sum(num_bits_per_cat_feature)
#     # model_params['d_in']=dataset.X_cat['train'].shape[1]

#     flow_net_cat  = get_model(
#         model_type,
#         model_params,
#         num_numerical_features,
#         category_sizes=dataset.get_category_sizes('train')
#     )

#     model_params['d_in']= num_numerical_features 

    
#     flow_net_cont  = get_model(
#         model_type,
#         model_params,
#         num_numerical_features,
#         category_sizes=dataset.get_category_sizes('train')
#     )


#     cfm_cat = ConditionalFlowMatcher(sigma=0.0, pred_x1=False,cat =True,device=device)
#     cfm_cont = ConditionalFlowMatcher(sigma=0.0, pred_x1=False,cat=False,device=device)

#     model = Model(
#         flow_net_cat,
#         flow_net_cont,
#         cfm_cat,
#         cfm_cont,
#         num_numerical_features,
#         K,
#         num_bits_per_cat_feature,
#     )
#     model.to(device)
#     model.train()

#     trainer = Trainer(
#         model,
#         train_loader,
#         lr=lr,
#         weight_decay=weight_decay,
#         steps=steps,
#         model_save_path=model_save_path,
#         device=device
#     )
#     start_time = time.time()
#     trainer.run_loop()
#     end_time = time.time()

#     print('Training Time: ', end_time - start_time)

#     torch.save(model.flow_net_cat.state_dict(), os.path.join(model_save_path, 'model_cat.pt'))
#     torch.save(model.flow_net_cont.state_dict(), os.path.join(model_save_path, 'model_cont.pt'))
#     # torch.save(cfm_cat.time_warper.state_dict(), os.path.join(model_save_path, 'time_warper_cat.pt'))
#     # torch.save(cfm_cont.time_warper.state_dict(), os.path.join(model_save_path, 'time_warper_cont.pt'))


#     # torch.save(trainer.ema_model.state_dict(), os.path.join(model_save_path, 'model_ema.pt'))

#     trainer.loss_history.to_csv(os.path.join(model_save_path, 'loss.csv'), index=False)







import os
import time
import torch
import numpy as np
import pandas as pd

from copy import deepcopy

import src
from utils_train import make_dataset
from .models.modules import MLPDiffusion, Model
from .models.flow_matching import ConditionalFlowMatcher

def bits_needed(categories):
    """Return the number of encoding dimensions allotted to each categorical feature.

    Every feature gets exactly two dimensions, independent of its cardinality.

    Args:
        categories: array-like of per-feature category counts (only its shape
            and dtype are used; the values themselves are ignored).

    Returns:
        An ndarray shaped like ``categories`` filled with 2.
    """
    per_feature = np.ones_like(categories)
    return per_feature * 2

# def bits_needed(K):
#     """Returns the dimensionality of the categorical features."""
#     # Each categorical feature is encoded in 2D (latitude and longitude on a sphere)
#     return [2 for _ in K]  # two angles for each feature (2D representation)


def get_model(
    model_name,
    model_params,
    n_num_features,
    category_sizes
):
    """Instantiate the network named by ``model_name``.

    Args:
        model_name: architecture identifier; only ``'mlp'`` is supported.
        model_params: keyword arguments forwarded to ``MLPDiffusion``.
        n_num_features: accepted for interface compatibility; unused here.
        category_sizes: accepted for interface compatibility; unused here.

    Returns:
        The constructed ``MLPDiffusion`` instance.

    Raises:
        ValueError: if ``model_name`` is not a supported architecture.
    """
    if model_name == 'mlp':
        return MLPDiffusion(**model_params)
    # The former ``raise "Unknown model!"`` raised a str, which Python rejects
    # with an unrelated TypeError that hides the real problem.
    raise ValueError(f"Unknown model: {model_name!r}")

class Trainer:
    """Optimisation loop for a flow-matching model.

    From how it is used below, ``cfm_model`` is presumably a ``torch.nn.Module``
    that returns a scalar loss when called on a batch, and that exposes
    ``flow_net`` (the network whose weights are checkpointed) and ``cfm`` (an
    object carrying ``total_steps`` / ``current_step`` counters).
    """

    def __init__(self, cfm_model, train_iter, lr, weight_decay, steps, model_save_path, device=torch.device('cuda:1')):
        """
        Args:
            cfm_model: model wrapper; see the class docstring.
            train_iter: iterator yielding batches; element ``[0]`` of each
                batch is fed to the model. ``next`` is called on it directly,
                so it must yield at least ``steps`` batches.
            lr: initial AdamW learning rate (also the base for ``_anneal_lr``).
            weight_decay: AdamW weight decay.
            steps: number of optimisation steps ``run_loop`` performs.
            model_save_path: directory where ``model.pt`` is written.
            device: device each batch is moved to before the forward pass.
        """
        self.model = cfm_model
        # NOTE(review): an EMA copy is allocated, but nothing in this class
        # updates or saves it, so it stays a frozen snapshot of the initial
        # flow_net weights.
        self.ema_model = deepcopy(self.model.flow_net)
        for param in self.ema_model.parameters():
            param.detach_()

        self.train_iter = train_iter
        self.steps = steps
        self.init_lr = lr
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        self.device = device
        self.loss_history = pd.DataFrame(columns=['step', 'loss'])
        self.model_save_path = model_save_path

        self.cfm_model = cfm_model  # same object as self.model; both names kept for callers
        self.cfm_model.cfm.total_steps = steps  # tell the flow matcher the schedule length

        # Logging cadence (in steps). run_loop used to force both to 1,
        # discarding these values; 1 is now the default here so callers can
        # override either attribute before calling run_loop.
        self.log_every = 1
        self.print_every = 1
        self.ema_every = 1000  # not referenced anywhere in this class

    def _anneal_lr(self, step):
        """Set every param group's LR to ``init_lr`` scaled linearly towards 0 by ``step / steps``."""
        frac_done = step / self.steps
        lr = self.init_lr * (1 - frac_done)
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def _run_step(self, x):
        """Do one forward/backward/optimizer step on batch ``x``; return the loss tensor."""
        x = x.to(self.device)

        self.optimizer.zero_grad()

        loss = self.model(x)

        loss.backward()
        self.optimizer.step()

        # Keep the flow matcher's progress counter in sync with optimisation steps.
        self.cfm_model.cfm.current_step += 1

        return loss

    def run_loop(self):
        """Run ``self.steps`` optimisation steps with logging and checkpointing.

        Every ``log_every`` steps, the sample-weighted mean loss since the
        previous log is rounded to 4 decimals and appended to
        ``loss_history``. If it is the best value so far, ``flow_net``'s
        weights are saved to ``<model_save_path>/model.pt``. Training stops
        early if that mean loss is NaN.
        """
        step = 0
        curr_loss = 0.0
        curr_count = 0
        best_loss = np.inf

        np.set_printoptions(suppress=True)
        print('Steps: ', self.steps)
        start_time = time.time()
        while step < self.steps:
            x = next(self.train_iter)[0]

            batch_loss = self._run_step(x)
            # NOTE: _anneal_lr is not called, so the learning rate stays at init_lr.

            # Weight by batch size so the logged value is a per-sample mean.
            curr_count += len(x)
            curr_loss += batch_loss.item() * len(x)

            if (step + 1) % self.log_every == 0:
                loss = np.around(curr_loss / curr_count, 4)
                if np.isnan(loss):
                    print('Finding Nan')
                    break

                if (step + 1) % self.print_every == 0:
                    print(f'Step {(step + 1)}/{self.steps} Loss: {loss}')
                self.loss_history.loc[len(self.loss_history)] = [step + 1, loss]

                curr_count = 0
                curr_loss = 0.0
                if loss < best_loss:
                    best_loss = loss
                    torch.save(self.model.flow_net.state_dict(), os.path.join(self.model_save_path, 'model.pt'))

            step += 1
        end_time = time.time()
        print('Time: ', end_time - start_time)

def train(
    model_save_path,
    real_data_path,
    steps = 1000,
    lr = 0.002,
    weight_decay = 1e-4,
    batch_size = 1024,
    task_type = 'binclass',
    model_type = 'mlp',
    model_params = None,
    num_timesteps = 1000,
    gaussian_loss_type = 'mse',
    scheduler = 'cosine',
    T_dict = None,
    num_numerical_features = 0,
    device = torch.device('cuda:0'),
    seed = 0,
    change_val = False
):
    """Build a flow-matching model for a tabular dataset, train it and save it.

    Writes ``model.pt`` (``flow_net`` weights) and ``loss.csv`` (the logged
    loss history) into ``model_save_path``.

    Args:
        model_save_path: output directory (created if it does not exist).
        real_data_path: dataset directory handed to ``make_dataset``.
        steps: number of optimisation steps.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        batch_size: training batch size.
        task_type: forwarded to ``make_dataset``.
        model_type: architecture name for ``get_model``.
        model_params: keyword arguments for the network. ``'d_in'`` is set on
            this dict in place, so a caller-supplied dict is updated.
            ``None`` is treated as an empty dict.
        T_dict: keyword arguments for ``src.Transformations``; ``None`` is
            treated as an empty dict.
        num_numerical_features: ignored; recomputed from the dataset.
        device: torch device for the model and the training batches.
        change_val: forwarded to ``make_dataset``.
        num_timesteps, gaussian_loss_type, scheduler, seed: accepted for
            interface compatibility; not used by this function (in particular,
            no RNG seeding is performed).
    """
    real_data_path = os.path.normpath(real_data_path)
    os.makedirs(model_save_path, exist_ok=True)

    # ``**None`` would raise a TypeError when T_dict is left at its default.
    T = src.Transformations(**(T_dict or {}))

    dataset = make_dataset(
        real_data_path,
        T,
        task_type = task_type,
        change_val = change_val,  # was hard-coded to False, silently ignoring the argument
    )

    print('dataset',dataset)

    K = np.array(dataset.get_category_sizes('train'))

    # The argument of the same name is overridden by what the data actually contains.
    num_numerical_features = dataset.X_num['train'].shape[1] if dataset.X_num is not None else 0

    num_bits_per_cat_feature = bits_needed(K) if len(K) > 0 else np.array([0])

    print('num_bits_per_cat_feature',num_bits_per_cat_feature)

    # Total feature width (encoded categorical dims + numerical columns),
    # passed to the network as ``d_in``.
    d_in = np.sum(num_bits_per_cat_feature) + num_numerical_features

    if model_params is None:
        model_params = {}
    model_params['d_in'] = int(d_in)  # plain int rather than a numpy scalar

    train_loader = src.prepare_fast_dataloader(dataset, split='train', batch_size=batch_size)

    flow_net = get_model(
        model_type,
        model_params,
        num_numerical_features,
        category_sizes=dataset.get_category_sizes('train')
    )
    cfm = ConditionalFlowMatcher(sigma=0.0, pred_x1=False, device=device, visual=False)
    model = Model(
        flow_net,
        cfm,
        num_numerical_features,
        K,
        num_bits_per_cat_feature,
    )
    model.to(device)
    model.train()

    trainer = Trainer(
        model,
        train_loader,
        lr=lr,
        weight_decay=weight_decay,
        steps=steps,
        model_save_path=model_save_path,
        device=device
    )
    start_time = time.time()
    trainer.run_loop()
    end_time = time.time()

    print('Training Time: ', end_time - start_time)

    # NOTE(review): Trainer.run_loop also saves its best-loss weights to this
    # same 'model.pt' path, so this final save replaces that checkpoint with
    # the last-step weights. Confirm which of the two is intended to persist.
    torch.save(model.flow_net.state_dict(), os.path.join(model_save_path, 'model.pt'))

    trainer.loss_history.to_csv(os.path.join(model_save_path, 'loss.csv'), index=False)