import torch
import torch.nn as nn
from data_loader_MLP import TrafficData
from torch.utils.data import DataLoader,Dataset,TensorDataset,random_split
import torch.optim as optim
from torch.autograd import Variable

import os
import sys
import wandb

sys.path.append(sys.path[0] + '/..')
from utils import set_seed, count_params


class MLPModel(nn.Module):
    """Simple MLP classifier: in_size -> 10 -> 20 -> out_size with ReLU activations.

    Args:
        in_size: number of input features per sample.
        out_size: number of output units (one score per class); defaults to 1.
    """

    def __init__(self, in_size, out_size=1):
        super().__init__()

        # Widths of the two hidden layers. Other variants that were tried
        # previously: a deeper 10-20-10 stack and a single 10-unit hidden
        # layer fed by 3 input features.
        hidden_sizes = (10, 20)
        widths = (in_size, *hidden_sizes)

        # Layers are created strictly in input-to-output order so that
        # parameter initialisation under a fixed seed and the resulting
        # state_dict keys ("mlp.0", "mlp.2", "mlp.4") stay the same.
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], out_size))

        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw (un-normalised) output scores for the batch `x`."""
        return self.mlp(x)


def save_checkpoint(model, optimizer, epoch, save_path, scheduler=None):
    """
    Save training checkpoint.

    The checkpoint is a dict with the keys 'epoch', 'model_state',
    'optimizer_state' and, when a scheduler is given, 'scheduler_state'.
    Missing parent directories of `save_path` are created first.

    Args:
        model: PyTorch model (nn.Module)
        optimizer: optimizer
        epoch: current epoch number
        save_path: output checkpoint file path (e.g., './checkpoint.pth')
        scheduler: optional LR scheduler
    """
    state = {
        'epoch': epoch,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
    }

    if scheduler is not None:
        state['scheduler_state'] = scheduler.state_dict()

    # torch.save does not create missing directories, so make sure the target
    # folder exists. Only done for path-like targets so that file-like objects
    # accepted by torch.save keep working.
    if isinstance(save_path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(save_path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    torch.save(state, save_path)
    print(f"=> Checkpoint saved at: {save_path}")


def load_checkpoint(model, optimizer, load_path, scheduler=None, device='cpu'):
    """
    Load training checkpoint and resume.

    Expects a checkpoint dict with the keys 'epoch' and 'model_state', plus
    'optimizer_state' when an optimizer is passed and, optionally,
    'scheduler_state'.

    Args:
        model: PyTorch model (nn.Module)
        optimizer: optimizer, or None to restore only the model weights
            (e.g. for inference)
        load_path: checkpoint file path
        scheduler: optional LR scheduler
        device: 'cpu' or 'cuda'

    Returns:
        start_epoch: epoch to resume from

    Raises:
        FileNotFoundError: if `load_path` is not an existing file.
    """
    if not os.path.isfile(load_path):
        raise FileNotFoundError(f"No checkpoint found at {load_path}")

    # map_location lets a checkpoint written on one device be restored on another.
    checkpoint = torch.load(load_path, map_location=device)

    model.load_state_dict(checkpoint['model_state'])

    # The optimizer is optional so the weights can be loaded without
    # building an optimizer first.
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state'])

    # The scheduler state is only restored when the checkpoint contains one.
    if scheduler is not None and 'scheduler_state' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state'])

    start_epoch = checkpoint['epoch'] + 1  # resume next epoch

    print(f"=> Checkpoint loaded from: {load_path} (epoch {checkpoint['epoch']})")
    return start_epoch


def main():
    """
    Train the MLP classifier on selected latent features, then evaluate it.

    Steps:
        1. Seed the RNGs, pick the device and set the hyperparameters.
        2. Start a wandb run and load the train/test sets for the chosen dataset.
        3. Train `MLPModel` with Adam and cross-entropy on the columns of `x`
           listed in `feature_indices`, logging accuracy/loss every 200 steps
           and saving a checkpoint every 100 epochs.
        4. Report the accuracy on the test set and close the wandb run.

    Raises:
        NameError: if the configured dataset name is not one of the known ones.
    """
    set_seed(42)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    num_epoch = 800
    out_size = 8
    Batch_size = 56 #50
    lr = 0.0001
    # dataset_name = "data_class_vSimon2"
    # dataset_name = "data_class_v5"
    # dataset_name = "traffic_data"
    dataset_name = "real_data"
    run_name = f"Classification-epoch{num_epoch}"
    ckpt_path = f"./checkpoints/{run_name}"
    # makedirs also creates the parent "./checkpoints" folder on a fresh checkout.
    os.makedirs(ckpt_path, exist_ok=True)

    # Attack variant selecting the validation file of "traffic_data"; one of
    # 'all', 'elastic', 'g_blur', 'g_noise', 'splatter', 'sticker'.
    attack = 'sticker'

    # start a new wandb run to track this script
    wandb.init(
        # set the wandb project where this run will be logged
        project="FilterVAE-MLP",
        name=run_name,

        # track hyperparameters and run metadata
        config={
            "learning_rate": lr,
            "dataset": dataset_name,
            "epochs": num_epoch,
        }
    )

    ## import data
    if "data_class_v5" in dataset_name.lower():
        training_set = TrafficData('../ControlVAE/data_class_v5/train_z_label_semi.csv')
        test_set = TrafficData('../ControlVAE/data_class_v5/test_z_label_semi.csv')
        feature_indices = [0, 1, 3]
    elif "data_class_vsimon2" in dataset_name.lower():
        training_set = TrafficData('../ControlVAE/data_class_vSimon2/train_z_label_semi.csv')
        test_set = TrafficData('../ControlVAE/data_class_vSimon2/test_z_label_semi.csv')
        feature_indices = [1, 7, 9]
    elif dataset_name == "traffic_data":
        training_set = TrafficData('/data/open-datasets/traffic/train/train_z_label.csv')
        test_set = TrafficData(f'/data/open-datasets/traffic/val/val_z_label_{attack}.csv')
        feature_indices = [2, 8, 9]
    elif dataset_name == "real_data":
        # Single CSV: hold out 20% of the samples for testing.
        full_set = TrafficData('../ControlVAE/data_real/train_z_label_230.csv')
        dataset_size = len(full_set)
        train_size = int(0.8 * dataset_size)
        test_size = dataset_size - train_size
        training_set, test_set = random_split(full_set, [train_size, test_size])
        # feature_indices = [0, 2, 4, 6, 8] # 200
        # feature_indices = [0, 5, 6, 8] # 210
        feature_indices = [0, 1, 3, 8] # 230
    else:
        raise NameError(dataset_name)

    train_generator = DataLoader(training_set, batch_size=Batch_size,
                                 shuffle=True, num_workers=1)
    test_generator = DataLoader(test_set, batch_size=Batch_size,
                                shuffle=False, num_workers=1)

    model = MLPModel(len(feature_indices), out_size)
    print("Number of pytorch parameters: ", count_params(model))

    optimizer = optim.Adam(model.parameters(), lr=lr)  ## optimizer
    # loss_fun = nn.BCELoss()
    loss_fun = nn.CrossEntropyLoss()
    model.to(device)
    model = model.train()
    global_step = 0

    for epoches in range(num_epoch):
        for x, class_label, shape_label, color_label in train_generator:
            # define feature type here
            # NOTE(review): assumes class_label is a 1-D tensor of class ids in
            # [0, out_size) — confirm against TrafficData.
            label = class_label.to(device=device, dtype=torch.long)
            global_step += 1

            ## define feature index of z: keep only the selected columns
            x = x[:, feature_indices].to(device)

            out_y = model(x)
            # CrossEntropyLoss takes class indices directly. This replaces the
            # hand-built one-hot matrix, which was sized with the nominal batch
            # size and therefore broke on a smaller final batch.
            loss = loss_fun(out_y, label)

            ## back propagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if global_step % 200 == 0:

                ## For binary shapes
                # out_y = torch.round(out_y)
                # accy_count = (out_y==label).sum().item()
                # accy = accy_count/len(label)

                ## For multiple shapes: predicted class = index of the largest score
                predicted = torch.max(out_y, 1)[1]
                accy_count = (predicted == label).sum().item()
                accy = accy_count / len(label)
                print("epoch: {0} step:{1} accy: {2} loss: {3}".format(epoches, global_step, accy, loss.item()))

                wandb.log({"epoch": epoches, "acc": accy, "loss": loss.item()})

        if (epoches + 1) % 100 == 0:
            # Checkpoint files are named after the zero-based epoch index.
            ckpt_name = os.path.join(ckpt_path, str(epoches))
            save_checkpoint(model, optimizer, epoches, ckpt_name)

    ## for testing
    total_correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for x, class_label, shape_label, color_label in test_generator:
            label = class_label.to(device)
            ## define feature index of z
            x = x[:, feature_indices].to(device)
            out_y = model(x)

            ## for binary shapes
            # out_y = torch.round(out_y)
            # accy_count = (out_y==label).sum().item()

            predicted = torch.max(out_y, 1)[1]
            accy_count = (predicted == label).sum().item()
            print("accy_count: ", accy_count)
            total_correct += accy_count
            total += len(label)
    accy = total_correct / total
    print("testing accy: ", accy)
    wandb.log({"test acc": accy})

    wandb.finish()


# Run training and evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()