import os
import matplotlib.pyplot as plt
import numpy as np


import pickle
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
import torch.nn.functional as F 
from torch.utils.data import Dataset, DataLoader, ConcatDataset

from copy import deepcopy


import logging
from time import time, strftime, localtime


from utils import MyDataset, requires_grad, update_ema, create_logger

import numpy as np 
from time import time, strftime, localtime



import matplotlib.pyplot as plt
from collections import OrderedDict

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, precision_score, recall_score
from sklearn.preprocessing import MinMaxScaler, RobustScaler

from inception import Inception, InceptionBlock


from utils import draw_figure, show_figure, if_nan
from collections import Counter


class Flatten(nn.Module):
    """Collapse every dimension after the first into a single dimension."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Dimension 0 is kept as-is; dimensions 1..end are merged.
        return x.flatten(start_dim=1)

class Reshape(nn.Module):
    """View the input as ``(-1, *out_shape)``; the leading size is inferred."""

    def __init__(self, out_shape):
        super().__init__()
        # Target shape of everything after the inferred leading dimension.
        self.out_shape = out_shape

    def forward(self, x):
        target_shape = (-1, *self.out_shape)
        return x.view(target_shape)
# --------------------------------------------------------------------------------------------------------------------------

# Wall-clock start of the run; its formatted form names the results directory.
start_time = time()
formatted_time = strftime("%Y-%m-%d_%H-%M-%S", localtime(start_time))

# ------- create logger --------

# The path is built once so the directory that is created, the directory
# handed to create_logger and the directory named in the log message cannot
# drift apart.
log_dir = f'./results/{formatted_time}/log'
# No exist_ok: os.makedirs raises if the directory already exists, so two runs
# started within the same second fail loudly instead of sharing one directory.
os.makedirs(log_dir)
logger = create_logger(log_dir)
logger.info(f"Experiment directory created at '{log_dir}'")
# ---------------------------------------- Create model: ----------------------------------------------------------------

# Checkpoints written during training go here.
os.makedirs(f"./results/{formatted_time}/model")

# Use the GPU when one is available, otherwise fall back to the CPU so the
# script still starts on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Two stacked InceptionBlocks -> global average pooling -> linear classifier
# producing 6 logits.
# The second block's in_channels and the Linear layer's in_features are both
# 4 * 32; presumably each InceptionBlock emits 4 * n_filters channels
# (TODO confirm against inception.py).
model = nn.Sequential(                  # input_size = (B, C, L)
    InceptionBlock(
        in_channels=5,
        n_filters=32,
        kernel_sizes=[5, 11, 23],
        bottleneck_channels=32,
        use_residual=True,
        activation=nn.ReLU()
    ),
    InceptionBlock(
        in_channels=32*4,
        n_filters=32,
        kernel_sizes=[5, 11, 23],
        bottleneck_channels=32,
        use_residual=True,
        activation=nn.ReLU()
    ),
    nn.AdaptiveAvgPool1d(output_size=1),    # pools the last dimension down to length 1
    Flatten(),
    nn.Linear(in_features=4*32*1, out_features=6)
).to(device)

# Independent copy of the model that the training loop passes to update_ema
# after every optimizer step.
ema = deepcopy(model).to(device)
# presumably disables gradient tracking for the EMA copy — verify in utils.requires_grad
requires_grad(ema, False)
# ----------------------------------------
criterion = nn.CrossEntropyLoss()
opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0)
# ----------------------model test -----------------------------
# x = torch.randn(32, 1, 1200)

# # ---------------------------------------- Setup data: -----------------------------------------------------------------

# --------------mix data, 1. driving data -----------------

# NOTE: torch.load unpickles arbitrary Python objects, so these files must
# come from a trusted source. Recent PyTorch releases default to
# weights_only=True, which may refuse pickled Dataset objects — confirm
# against the installed version.
train_dataset = torch.load('./dataset/argo/train_dataset.pth')
val_dataset = torch.load('./dataset/argo/val_dataset.pth')
test_dataset = torch.load('./dataset/argo/test_dataset.pth')

# NOTE(review): the validation split loaded above is discarded here and
# replaced by train + test, so the "validation" metrics are measured on data
# that includes the training set. Behaviour is kept as-is; confirm this is
# intended.
val_dataset = ConcatDataset([train_dataset, test_dataset])

# Training batches: reshuffled every epoch, incomplete last batch dropped.
training_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    pin_memory=True,
    drop_last=True,
)

# Evaluation batches: every sample is visited exactly once per pass.
# Shuffling or dropping the last partial batch would make the reported
# metrics skip a random subset of samples, and would yield no batches at all
# for a set smaller than one batch.
valid_loader = DataLoader(
    val_dataset,
    batch_size=64,
    shuffle=False,
    pin_memory=True,
    drop_last=False,
)

logger.info(f"Dataset contains {len(train_dataset):,} data")

# # ---------------------------------------- start training: -----------------------------------------------------------------

start_time = time()

# Schedule constants. The progress message and the loop both read num_epochs
# so the announced and the actual number of epochs cannot disagree.
num_epochs = 800
# A checkpoint is written whenever epoch is a positive multiple of this value.
checkpoint_every = 5


def _save_checkpoint(path):
    """Write the current model and EMA state dicts to *path*."""
    torch.save({"model": model.state_dict(), "ema": ema.state_dict()}, path)


logger.info(f"Training for {num_epochs} epochs...")

for epoch in range(num_epochs):

    # ----------------- validation -----------------
    # Runs BEFORE this epoch's parameter updates, so the metrics logged for
    # epoch N describe the model as it was after epoch N-1.
    # Loss and accuracy are gathered in one pass over valid_loader.
    model.eval()
    eval_loss = 0.0
    eval_steps = 0
    total_num = 0
    correct_num = 0
    with torch.no_grad():
        for x, y in valid_loader:
            x = x.to(device)
            y = y.to(device).long()

            out = model(x)
            eval_loss += criterion(out, y).item()
            eval_steps += 1

            # Index of the largest logit along dim 1 is the predicted class.
            _, predicted = torch.max(out, 1)
            correct_num += (predicted == y).sum().item()
            total_num += y.shape[0]

    # Mean of the per-batch losses; NaN when the loader produced no batches.
    valid_loss = eval_loss / eval_steps if eval_steps else float("nan")
    correct_rate = correct_num / total_num if total_num else float("nan")

    # ----------------- training -----------------
    model.train()
    running_loss = 0.0
    log_steps = 0

    logger.info(f"Beginning epoch {epoch}...")

    for x, y in training_loader:
        x = x.to(device)
        y = y.to(device).long()

        out = model(x)
        loss = criterion(out, y)

        opt.zero_grad()
        loss.backward()
        opt.step()
        # Refresh the EMA copy from the just-updated model (helper from utils).
        update_ema(ema, model)

        # ------------------ log loss value -------------------------------
        log_steps += 1
        running_loss += loss.item()

    if log_steps == 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise RuntimeError("training_loader yielded no batches; nothing to train on")

    # Mean of the per-batch training losses for this epoch.
    avg_loss = running_loss / log_steps

    logger.info(
        f"(epoch={epoch:04d}) Train Loss: {avg_loss:.4f}, "
        f"Correct rate: {correct_rate:.4f}, Valid Loss: {valid_loss:.4f}"
    )

    # ------------------ save checkpoint -------------------------------
    if epoch % checkpoint_every == 0 and epoch > 0:
        _save_checkpoint(f"./results/{formatted_time}/model/{epoch:03d}.pt")

# The periodic rule above only fires on multiples of checkpoint_every, so the
# weights produced by the last epoch are saved explicitly here.
_save_checkpoint(f"./results/{formatted_time}/model/final.pt")

model.eval()
logger.info("Done!")

