import os
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset

import pytorch_lightning as pl
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr
from tqdm import tqdm

# Default hyper-parameters. Their names mirror the keyword arguments of
# train() below (lr, decay, num_epochs, scheduler/loss settings); presumably
# a driver passes them through — confirm against the caller.
LEARNING_RATE = 1e-4        # identical float to 0.0001
BATCH_SIZE = 128            # i.e. 64 * 2
EPOCHS = 15
WEIGHT_DECAY = 0

# Learning-rate scheduler settings (all off / unset by default).
GAMMA = None
STEP_SIZE = None
SCHEDULER_FLAG = False
SCHEDULER_STEP_FLAG = False

# Loss configuration.
LOSS_TYPE = 'MSE'
LOSS_CALCULATION_MP_FLAG = True



class CNN1(nn.Module):
    """Convolutional regressor over a (sequence, feature) input.

    The first layer is a 2-D convolution whose kernel width equals
    ``num_features``. When the input's last axis has exactly
    ``num_features`` columns, that axis collapses to width 1. The result is
    reshaped into a 16-channel 1-D signal for two more conv blocks (each
    followed by max-pooling) and two fully connected layers.

    Args:
        num_features: Kernel width of ``conv1``; expected to equal the size
            of the input's last axis.
        num_classes: Number of outputs of the final linear layer.
        temp: Length of the 1-D signal after the second max-pool; ``fc1``
            takes ``temp * 32`` inputs. For an input sequence of length ``L``
            (with last axis == ``num_features``) this is
            ``((L - 3) // 2 + 4) // 2``. The default 26 matches ``L = 100``.
    """

    def __init__(self, num_features, num_classes, temp=26):
        super().__init__()

        # Convolution 1: kernel spans all feature columns. Padding 3 with
        # dilation 2 on the sequence axis keeps the sequence length unchanged.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(4, num_features), padding=(3, 0), dilation=(2, 1))
        self.relu1 = nn.LeakyReLU()

        # Convolution 2 (no padding: length shrinks by 3)
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=16, kernel_size=(4,))
        self.relu2 = nn.LeakyReLU()

        # Max pool 1 (halves the length, rounding down)
        self.maxpool1 = nn.MaxPool1d(kernel_size=2)

        # Convolution 3 (padding 2 with kernel 3: length grows by 2)
        self.conv3 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=(3,), padding=2)
        self.relu3 = nn.LeakyReLU()

        # Convolution 4 (length grows by 2)
        self.conv4 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=(3,), padding=2)
        self.relu4 = nn.LeakyReLU()

        # Max pool 2 (halves the length, rounding down)
        self.maxpool2 = nn.MaxPool1d(kernel_size=2)

        # Fully connected 1: input size depends on `temp` (see class docstring).
        self.fc1 = nn.Linear(temp * 32, 32)
        self.relu5 = nn.LeakyReLU()

        # Fully connected 2 (readout)
        self.fc2 = nn.Linear(32, num_classes)

    def forward(self, x, _0=None, _1=None, _2=None):
        """Return predictions of shape ``(batch, num_classes)``.

        Args:
            x: Input batch. Presumably ``(batch, seq_len, num_features)``,
                e.g. ``(batch, 100, 40)``; confirm with the data pipeline.
            _0, _1, _2: Ignored. They are optional, so the single-argument
                calls ``model(inputs)`` in this file's train/test loops work.
                Callers that pass three extra arguments also keep working.

        Raises:
            ValueError: If the flattened convolutional output does not match
                ``fc1``'s input size, i.e. ``temp`` does not fit this
                sequence length.
        """
        # Add the channel dimension expected by Conv2d.
        out = x.unsqueeze(1)

        # Convolution 1, then merge the (collapsed) width into the length axis.
        out = self.relu1(self.conv1(out))
        out = out.reshape(out.shape[0], out.shape[1], -1)

        # Convolution 2 + max pool 1
        out = self.maxpool1(self.relu2(self.conv2(out)))

        # Convolutions 3 and 4 + max pool 2
        out = self.relu3(self.conv3(out))
        out = self.maxpool2(self.relu4(self.conv4(out)))

        # Flatten per sample and check it fits fc1 before the matmul fails opaquely.
        out = out.flatten(start_dim=1)
        if out.size(1) != self.fc1.in_features:
            raise ValueError(
                f"Flattened feature size {out.size(1)} does not match fc1 input size "
                f"{self.fc1.in_features}; adjust `temp` for this sequence length."
            )

        # Fully connected layers
        out = self.relu5(self.fc1(out))
        return self.fc2(out)


def evaluate_predictions(y_true, y_pred):
    """Score predicted values against ground truth.

    Args:
        y_true: Tensor of target values.
        y_pred: Tensor of predicted values.

    Returns:
        Tuple ``(mse, corr, r2)``:
        - ``mse`` is sklearn's mean squared error.
        - ``corr`` is the Pearson correlation of the flattened arrays, via ``np.corrcoef``.
        - ``r2`` is sklearn's coefficient of determination.
    """
    def _as_array(tensor):
        # Detach from autograd and move to host memory before converting.
        return tensor.detach().cpu().numpy()

    truth = _as_array(y_true)
    preds = _as_array(y_pred)
    return (
        mean_squared_error(truth, preds),
        np.corrcoef(truth.flatten(), preds.flatten())[0, 1],
        r2_score(truth, preds),
    )


def test(model, test_loader):
    """Evaluate ``model`` on every batch of ``test_loader`` and print metrics.

    The model is moved to CUDA when available and switched to eval mode. It
    is left in eval mode afterwards.

    Args:
        model: A ``torch.nn.Module`` called as ``model(inputs)``.
        test_loader: Iterable of ``(inputs, targets)`` tensor batches.

    Returns:
        The ``(mse, corr, r2)`` tuple from ``evaluate_predictions``; it is
        also printed.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    all_predictions, all_targets = [], []
    with torch.no_grad():
        for inputs, targets in test_loader:
            outputs = model(inputs.to(device))
            # Keep collected batches on the CPU: metrics are computed with
            # NumPy/sklearn anyway, and this avoids holding the whole test
            # set in GPU memory.
            all_predictions.append(outputs.cpu())
            all_targets.append(targets.cpu())
    if not all_predictions:
        raise ValueError("test_loader yielded no batches; cannot compute metrics.")
    mse, corr, r2 = evaluate_predictions(
        torch.cat(all_targets, dim=0), torch.cat(all_predictions, dim=0)
    )
    print(f"Test - MSE: {mse}, Correlation: {corr}, R2: {r2}")
    return mse, corr, r2


def _build_criterion(loss_type):
    """Return the loss module selected by ``loss_type``.

    Raises:
        ValueError: If ``loss_type`` is neither ``'MSE'`` nor ``'L1'``.
    """
    if loss_type == 'MSE':
        return torch.nn.MSELoss()
    if loss_type == 'L1':
        return torch.nn.L1Loss()
    raise ValueError(f"Unsupported loss_type {loss_type!r}; expected 'MSE' or 'L1'.")


def _predict(model, loader, device, criterion):
    """Run ``model`` over ``loader`` in eval mode without tracking gradients.

    Targets are unsqueezed on dim 1 exactly as in the training loop. This
    assumes 1-D target batches; confirm against the dataset.

    Returns:
        ``(targets, outputs, avg_loss)``:
        - ``targets`` and ``outputs`` are NumPy arrays, concatenated along the batch axis.
        - ``avg_loss`` is the mean of the per-batch ``criterion`` values.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    collected_targets, collected_outputs = [], []
    total_loss = 0.0
    with torch.no_grad():
        for inputs, targets in loader:
            targets = targets.unsqueeze(1)
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, targets).item()
            collected_targets.append(targets.cpu().numpy())
            collected_outputs.append(outputs.cpu().numpy())
    if not collected_targets:
        raise ValueError("loader yielded no batches; cannot compute metrics.")
    avg_loss = total_loss / len(collected_targets)
    return np.concatenate(collected_targets), np.concatenate(collected_outputs), avg_loss


def _regression_metrics(targets, outputs):
    """Return ``(mse, r2, pearson_r)`` for matching target/prediction arrays."""
    mse = mean_squared_error(targets, outputs)
    r2 = r2_score(targets, outputs)
    corr, _ = pearsonr(targets.flatten(), outputs.flatten())
    return mse, r2, corr


def _plot_losses(train_losses, valid_losses, file_name_prefix):
    """Plot per-epoch losses, optionally save ``<prefix>_loss.png``, show, then close."""
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(valid_losses, label='Validation Loss')
    plt.title('Training and Validation Loss per Epoch')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    if file_name_prefix:
        plt.savefig(f"{file_name_prefix}_loss.png")
    plt.show()
    # Release the figure so repeated training runs do not accumulate figures.
    plt.close()


def train(train_loader, val_loader, num_features, lr, decay, num_epochs, scheduler_flag=None, scheduler_step_flag=None, step_size=None, gamma=None, loss_type='MSE', loss_calculation_mp_flag=True, file_name_prefix=None):
    """Train a fresh ``CNN1`` regressor and report per-epoch metrics.

    Each epoch runs:
    - one Adam optimisation pass over ``train_loader``;
    - an optional ``StepLR`` step;
    - gradient-free evaluation on ``val_loader`` and on ``train_loader``
      (MSE, R^2, Pearson r), plus the mean criterion loss on both sets.

    After training, the loss curves are plotted.

    Args:
        train_loader: Sized iterable of ``(inputs, targets)`` batches.
        val_loader: Same, for validation.
        num_features: Passed to ``CNN1`` as its feature dimension.
        lr: Adam learning rate.
        decay: Adam ``weight_decay``.
        num_epochs: Number of training epochs.
        scheduler_flag: If truthy, create ``StepLR(step_size, gamma)``.
        scheduler_step_flag: If truthy, step that scheduler once per epoch.
            Requires ``scheduler_flag``.
        step_size: ``StepLR`` step size.
        gamma: ``StepLR`` decay factor.
        loss_type: ``'MSE'`` or ``'L1'``.
        loss_calculation_mp_flag: Accepted for backward compatibility; unused.
        file_name_prefix: If given, the loss plot is saved to
            ``<prefix>_loss.png`` and per-epoch metrics to
            ``<prefix>_metrics.csv``.

    Returns:
        The trained model.

    Raises:
        ValueError: For an unsupported ``loss_type``, or when
            ``scheduler_step_flag`` is set without ``scheduler_flag``.
    """
    # Validate configuration before building the model or touching the data.
    criterion = _build_criterion(loss_type)
    if scheduler_step_flag and not scheduler_flag:
        raise ValueError("scheduler_step_flag requires scheduler_flag: there is no scheduler to step.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CNN1(
        num_features=num_features,
        num_classes=1,
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=decay)
    scheduler = (
        torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
        if scheduler_flag
        else None
    )

    train_losses, valid_losses = [], []
    train_mse, train_r2, train_corr = [], [], []
    valid_mse, valid_r2, valid_corr = [], [], []

    for epoch in range(num_epochs):
        epoch_label = f"Epoch {epoch + 1}/{num_epochs}"
        print(epoch_label)

        # --- optimisation pass ---
        model.train()
        total_loss = 0
        for inputs, targets in tqdm(train_loader, desc=epoch_label):
            targets = targets.unsqueeze(1)
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if scheduler is not None and scheduler_step_flag:
            scheduler.step()
        avg_train_loss = total_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # --- validation pass: loss is averaged with the training criterion ---
        val_targets, val_outputs, avg_valid_loss = _predict(model, val_loader, device, criterion)
        valid_losses.append(avg_valid_loss)
        val_mse_epoch, val_r2_epoch, val_corr_epoch = _regression_metrics(val_targets, val_outputs)
        valid_mse.append(val_mse_epoch)
        valid_r2.append(val_r2_epoch)
        valid_corr.append(val_corr_epoch)

        # --- metrics on the training set, with the model in eval mode ---
        tr_targets, tr_outputs, _ = _predict(model, train_loader, device, criterion)
        tr_mse_epoch, tr_r2_epoch, tr_corr_epoch = _regression_metrics(tr_targets, tr_outputs)
        train_mse.append(tr_mse_epoch)
        train_r2.append(tr_r2_epoch)
        train_corr.append(tr_corr_epoch)

        print(f"Epoch {epoch + 1}, Training Loss: {avg_train_loss}, Validation Loss: {avg_valid_loss}")
        print(f"Training - MSE: {tr_mse_epoch}, R2: {tr_r2_epoch}, Corr: {tr_corr_epoch}")
        print(f"Validation - MSE: {val_mse_epoch}, R2: {val_r2_epoch}, Corr: {val_corr_epoch}")

    _plot_losses(train_losses, valid_losses, file_name_prefix)

    epoch_numbers = range(1, num_epochs + 1)
    train_df = pd.DataFrame({'Epoch': epoch_numbers, 'MSE': train_mse, 'R2': train_r2, 'Corr': train_corr, 'Type': 'Training'})
    valid_df = pd.DataFrame({'Epoch': epoch_numbers, 'MSE': valid_mse, 'R2': valid_r2, 'Corr': valid_corr, 'Type': 'Validation'})
    metrics_df = pd.concat([train_df, valid_df])
    if file_name_prefix:
        metrics_df.to_csv(f"{file_name_prefix}_metrics.csv", index=False)
    return model



    
