import os
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset

import pytorch_lightning as pl
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr
from tqdm import tqdm

from data_preprocessing.data_utils import load_data

# Default optimisation hyper-parameters.
LEARNING_RATE = 1e-4
BATCH_SIZE = 32
EPOCHS = 15
WEIGHT_DECAY = 0  # no weight decay by default

# Learning-rate scheduler settings (scheduler disabled by default).
SCHEDULER_FLAG = False
SCHEDULER_STEP_FLAG = False
STEP_SIZE = None
GAMMA = None

# Loss settings.
LOSS_TYPE = 'MSE'
LOSS_CALCULATION_MP_FLAG = True


def evaluate_predictions(y_true: list, y_pred: list):
    """Compute regression metrics between targets and predictions.

    Each argument may be a flat sequence of scalars, a sequence of equally
    shaped per-sample arrays, or a sequence of per-batch arrays of differing
    lengths (e.g. a smaller final batch). Everything is flattened into one
    1-D vector before scoring.

    Args:
        y_true: ground-truth values.
        y_pred: predicted values; must hold the same total number of elements
            as ``y_true`` once flattened.

    Returns:
        Tuple ``(mse, corr, r2)``: mean squared error, Pearson correlation
        coefficient and R^2 score, in that order.

    Raises:
        ValueError: if the flattened inputs differ in size.
    """
    print('calculating metrics')

    def _flat(values):
        # Ravel each element and concatenate instead of np.array(whole_list):
        # per-batch arrays are ragged when the last batch is smaller, and
        # np.array() rejects ragged input.
        parts = [np.ravel(v) for v in values]
        return np.concatenate(parts) if parts else np.empty(0)

    y_true = _flat(y_true)
    y_pred = _flat(y_pred)
    # Compare element counts (not list lengths), and raise instead of assert
    # so the check survives `python -O`.
    if y_true.size != y_pred.size:
        raise ValueError(
            f"y_true and y_pred differ in size after flattening: "
            f"{y_true.size} != {y_pred.size}"
        )
    mse = mean_squared_error(y_true, y_pred)
    corr, _ = pearsonr(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    return mse, corr, r2

class LeftPad1d(nn.Module):
    """Prepend ``left_pad`` zeros along the last dimension of the input.

    Nothing is added on the right, so the output's last dimension grows by
    exactly ``left_pad``.
    """

    def __init__(self, left_pad):
        super().__init__()
        self.left_pad = left_pad  # number of zeros prepended in forward()

    def forward(self, x):
        # F.pad takes (left, right) for the last dimension.
        pad_spec = (self.left_pad, 0)
        return F.pad(x, pad_spec)

class TransLob(pl.LightningModule):
    """TransLOB-style network: causal dilated CNN, transformer encoder, MLP head.

    Args:
      seq_len: the sequence length of the input data
      num_classes: the number of outputs of the final linear layer (fc2)
      in_c: the number of input channels for the first Conv1d layer in the CNN
      out_c: the number of output channels for all Conv1d layers in the CNN
      n_attlayers: the number of attention layers in the transformer encoder
      n_heads: the number of attention heads in the transformer encoder; must
        divide the encoder width ``out_c + 1``
      dim_linear: the number of neurons in the first linear layer (fc1)
      dim_feedforward: the number of neurons in the feed-forward layer of the
        transformer encoder layer
      dropout: the dropout rate applied after fc1 (the encoder itself uses 0.0)
      btch_sz: batch dimension of the stored positional-encoding tensor
        ``self.pe``; ``forward`` accepts inputs of any batch size
    """

    def __init__(self, seq_len, num_classes=3, in_c=40, out_c=14, n_attlayers=2, n_heads=3, dim_linear=64, dim_feedforward=60, dropout=.1, btch_sz=32):
        super().__init__()

        # Five kernel-2 convolutions with dilations 1, 2, 4, 8, 16. Each one is
        # preceded by a left pad of dilation * (kernel_size - 1) zeros, so the
        # time dimension stays at seq_len and no output step sees later inputs.
        kernel_size = 2
        layers = []
        for dilation in (1, 2, 4, 8, 16):
            layers += [
                LeftPad1d(dilation * (kernel_size - 1)),
                nn.Conv1d(in_channels=in_c if dilation == 1 else out_c,
                          out_channels=out_c, kernel_size=kernel_size,
                          dilation=dilation),
                nn.ReLU(),
            ]
        self.conv = nn.Sequential(*layers)

        self.dropout = nn.Dropout(dropout)

        self.activation = nn.ReLU()

        # Size the positional encoding from seq_len so it matches the time
        # dimension it is concatenated with in forward().
        self.pe_init(btch_sz, n_levels=seq_len)

        # Encoder width = conv channels + 1 positional-encoding channel.
        d_model = out_c + 1
        # No explicit device: the layer follows the module when it is moved
        # with model.to(device).
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads,
                                                        dim_feedforward=dim_feedforward,
                                                        dropout=0.0, batch_first=True)

        # Normalises jointly over the time and channel dimensions of each sample.
        self.layer_norm = nn.LayerNorm([seq_len, out_c])

        self.transformer = nn.TransformerEncoder(self.encoder_layer, n_attlayers)

        self.fc1 = nn.Linear(seq_len * d_model, dim_linear)
        self.fc2 = nn.Linear(dim_linear, num_classes)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, in_c) to (batch, num_classes)."""
        # Conv1d expects channels first: (batch, in_c, seq_len).
        x = x.permute(0, 2, 1)

        # Causal dilated convolutions; the time length is preserved.
        x = self.conv(x)

        # Back to (batch, seq_len, out_c) for the normalisation and the encoder.
        x = x.permute(0, 2, 1)

        x = self.layer_norm(x)

        # Append the positional channel: (batch, seq_len, out_c + 1).
        x = self.positional_encoding(x)

        x = self.transformer(x)

        # Flatten time and feature dimensions for the MLP head.
        x = x.flatten(1)

        # fc1 -> ReLU -> dropout, then the output projection.
        x = self.dropout(self.activation(self.fc1(x)))
        x = self.fc2(x)

        return x

    def pe_init(self, btch_sz, n_levels=100):
        """Build the positional encoding and register it as the buffer ``self.pe``.

        The encoding is a linear ramp from -1 to 1 across ``n_levels`` time
        steps, tiled to shape ``(btch_sz, n_levels, 1)``. It is registered as a
        non-persistent buffer, so it moves with ``model.to(device)`` but is not
        written to ``state_dict``.
        """
        # max(..., 1) avoids 0/0 when n_levels == 1.
        pos = torch.arange(0, n_levels, 1, dtype=torch.float32) / max(n_levels - 1, 1)
        pos = (pos + pos) - 1  # rescale [0, 1] -> [-1, 1]

        pos_final = torch.tile(pos, (btch_sz, 1))[:, :, None]

        self.register_buffer("pe", pos_final, persistent=False)

    def positional_encoding(self, x):
        """Concatenate the positional-encoding channel onto ``x`` along dim 2."""
        # Broadcast one copy over the actual batch, so batches whose size
        # differs from btch_sz (e.g. the last one of an epoch) still work.
        pe = self.pe[:1].expand(x.shape[0], -1, -1).to(dtype=x.dtype)
        return torch.cat((x, pe), 2)


def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader``; print and return the metrics.

    Args:
        model: trained network; moved to CUDA when available and put in eval mode.
        test_loader: iterable of ``(inputs, targets)`` batches.

    Returns:
        Tuple ``(mse, corr, r2)`` as computed by ``evaluate_predictions``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    all_outputs = []
    all_targets = []
    with torch.no_grad():
        for inputs, targets in test_loader:
            outputs = model(inputs.to(device))
            all_outputs.append(outputs.cpu().numpy().ravel())
            all_targets.append(targets.cpu().numpy().ravel())
    # Concatenate the flattened batches: stacking per-batch arrays with
    # np.array() fails when the final batch is smaller than the others.
    y_true = np.concatenate(all_targets)
    y_pred = np.concatenate(all_outputs)
    mse, corr, r2 = evaluate_predictions(y_true, y_pred)
    print(f"Test - MSE: {mse}, Correlation: {corr}, R2: {r2}")
    return mse, corr, r2


def _run_epoch(model, loader, criterion, device, optimizer=None, desc=None):
    """Run one pass over ``loader``: train when ``optimizer`` is given, else evaluate.

    Returns:
        ``(avg_loss, targets, outputs)``: the mean per-batch loss and lists of
        per-sample numpy arrays suitable for ``evaluate_predictions``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    training = optimizer is not None
    model.train(training)
    batches = tqdm(loader, desc=desc) if desc else loader
    total_loss = 0.0
    n_batches = 0
    all_targets, all_outputs = [], []
    with torch.set_grad_enabled(training):
        for inputs, targets in batches:
            # Add a trailing dim so targets line up with the 2-D model output.
            targets = targets.unsqueeze(1)
            inputs, targets = inputs.to(device), targets.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            all_targets.extend(targets.cpu().numpy())
            all_outputs.extend(outputs.detach().cpu().numpy())
            loss = criterion(outputs, targets)
            if training:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            n_batches += 1
    if n_batches == 0:
        raise ValueError("loader yielded no batches")
    return total_loss / n_batches, all_targets, all_outputs


def train(model, train_loader, val_loader, lr, decay, num_epochs, scheduler_flag=None, scheduler_step_flag=None, step_size=None, gamma=None, loss_type='MSE', loss_calculation_mp_flag=True, file_name_prefix=None):
    """Train ``model`` with Adam, validating and printing metrics after every epoch.

    Args:
        model: network to train; moved to CUDA when available.
        train_loader: iterable of ``(inputs, targets)`` training batches.
        val_loader: iterable of ``(inputs, targets)`` validation batches.
        lr: Adam learning rate.
        decay: Adam weight decay.
        num_epochs: number of passes over ``train_loader``.
        scheduler_flag: when truthy, create a ``StepLR`` scheduler from
            ``step_size`` and ``gamma``.
        scheduler_step_flag: when truthy, step that scheduler once per epoch;
            ignored if no scheduler was created.
        step_size: ``StepLR`` step size.
        gamma: ``StepLR`` decay factor.
        loss_type: ``'MSE'`` or ``'L1'``.
        loss_calculation_mp_flag: accepted for interface compatibility; not
            used by this function.
        file_name_prefix: when given, the loss curve is saved to
            ``mprf/<file_name_prefix>_loss.png``.

    Returns:
        The trained model (the same object that was passed in).

    Raises:
        ValueError: for an unsupported ``loss_type``.
    """
    loss_classes = {'MSE': torch.nn.MSELoss, 'L1': torch.nn.L1Loss}
    if loss_type not in loss_classes:
        raise ValueError(
            f"Unsupported loss_type {loss_type!r}; expected one of {sorted(loss_classes)}"
        )
    criterion = loss_classes[loss_type]()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=decay)
    scheduler = (
        torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
        if scheduler_flag else None
    )

    train_losses, valid_losses = [], []
    for epoch in range(num_epochs):
        progress = f"Epoch {epoch + 1}/{num_epochs}"
        print(progress)

        avg_train_loss, targets, outputs = _run_epoch(
            model, train_loader, criterion, device, optimizer=optimizer, desc=progress)
        if scheduler is not None and scheduler_step_flag:
            scheduler.step()
        train_losses.append(avg_train_loss)
        train_mse_t, train_corr_t, train_r2_t = evaluate_predictions(targets, outputs)

        avg_valid_loss, targets, outputs = _run_epoch(model, val_loader, criterion, device)
        valid_losses.append(avg_valid_loss)
        # evaluate_predictions returns (mse, corr, r2) in that order.
        val_mse_t, val_corr_t, val_r2_t = evaluate_predictions(targets, outputs)

        print(f"Epoch {epoch + 1}, Training Loss: {avg_train_loss}, Validation Loss: {avg_valid_loss}")
        print(f"Training - MSE: {train_mse_t}, R2: {train_r2_t}, Corr: {train_corr_t}")
        print(f"Validation - MSE: {val_mse_t}, R2: {val_r2_t}, Corr: {val_corr_t}")

    fig = plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(valid_losses, label='Validation Loss')
    plt.title('Training and Validation Loss per Epoch')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    if file_name_prefix:
        os.makedirs("mprf", exist_ok=True)
        plt.savefig(f"mprf/{file_name_prefix}_loss.png")
    # Close the figure so repeated experiments do not accumulate open figures.
    plt.close(fig)
    return model


def run_experiment(model, train_loader, val_loader, test_loader, PRED_LEN, LEARNING_RATE, WEIGHT_DECAY, BATCH_SIZE, EPOCHS, LOSS_TYPE, LOSS_CALCULATION_MP_FLAG, SCHEDULER_FLAG=False, SCHEDULER_STEP_FLAG=False, STEP_SIZE=None, GAMMA=None):
    """Seed the RNGs, train ``model`` and then evaluate it on ``test_loader``.

    The hyper-parameters are forwarded to ``train`` and also encoded into a
    file-name prefix (used by ``train`` for the saved loss plot). Scheduler
    settings are only added to the prefix when ``SCHEDULER_FLAG`` is truthy.

    Note: ``model`` arrives already constructed, so seeding here does not
    affect its initial weights.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
        seed_fn(1)

    name_parts = [
        f'pred{PRED_LEN}',
        f'lr{LEARNING_RATE}',
        f'wd{WEIGHT_DECAY}',
        f'bs{BATCH_SIZE}',
        f'ep{EPOCHS}',
        f'loss{LOSS_TYPE}',
        f'mp{LOSS_CALCULATION_MP_FLAG}',
    ]
    if SCHEDULER_FLAG:
        name_parts += [f'sch{SCHEDULER_FLAG}', f'ss{STEP_SIZE}', f'g{GAMMA}']
    file_name_prefix = '_'.join(name_parts)

    trained = train(
        model,
        train_loader,
        val_loader,
        lr=LEARNING_RATE,
        decay=WEIGHT_DECAY,
        num_epochs=EPOCHS,
        scheduler_flag=SCHEDULER_FLAG,
        scheduler_step_flag=SCHEDULER_STEP_FLAG,
        step_size=STEP_SIZE,
        gamma=GAMMA,
        loss_type=LOSS_TYPE,
        loss_calculation_mp_flag=LOSS_CALCULATION_MP_FLAG,
        file_name_prefix=file_name_prefix,
    )
    test(trained, test_loader)
