import numpy as np
from torch.utils.data import Dataset, DataLoader, TensorDataset
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import os,wandb
import random
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import one_hot
from tqdm import tqdm
from torch.utils.data import Subset




def split_context(y):
    """Split a label tensor into masked context labels and the query label.

    Returns ``(y_context, y_test)`` where ``y_context`` has the same shape as
    ``y``, holds a copy of every entry except those at the last index of
    dimension 1 (which are zero), and ``y_test`` is ``y[:, -1]``.
    The input tensor itself is not modified.
    """
    masked = y.clone()
    # Hide the label at the final position so it cannot leak into the context.
    masked[:, -1] = 0
    query_label = y[:, -1]
    return masked, query_label

def find_perfomrance(model, dataset, loss,device='cpu',n_context = -1):
    """Return the sample-weighted average of ``loss`` over all batches of ``dataset``.

    The model is switched to eval mode and moved to ``device`` (and is left in
    that state on return). For every ``(x, y)`` batch, both tensors are
    truncated to ``[:, :n_context]`` along dimension 1, the last remaining
    label is split off as the prediction target via ``split_context``, and
    the model is called as ``model(x_truncated, y_context)``.

    Args:
        model: Callable module taking ``(x, y_context)``.
        dataset: Iterable yielding ``(x, y)`` tensor pairs (e.g. a DataLoader).
        loss: Callable ``loss(prediction, target)`` returning a scalar tensor.
            NOTE(review): the per-batch value is multiplied by the batch size,
            which assumes the loss is a batch mean -- confirm for the criterion
            in use.
        device: Device the model and batches are moved to.
        n_context: End index of the slice applied along dimension 1. With the
            default ``-1`` the final position of every sequence is dropped.

    Returns:
        Accumulated loss divided by the number of evaluated samples.

    Raises:
        ZeroDivisionError: If ``dataset`` yields no batches.
    """
    model.eval()
    model.to(device)
    running_loss = 0
    n_samples = 0

    # Pure evaluation: disable autograd so no graph is kept alive per batch.
    with torch.no_grad():
        for (x, y) in dataset:
            X = x.to(device)
            Y = y.to(device)
            y_context, y_test = split_context(Y[:, :n_context])

            outputs = model(X[:, :n_context, :], y_context)
            batch_size = outputs.shape[0]

            y_hats = torch.squeeze(outputs)
            if batch_size == 1:
                # squeeze() also removed the batch dimension of a single-sample
                # batch; restore it so shapes line up with the target.
                y_hats = y_hats.unsqueeze(0)

            batch_loss = loss(y_hats, y_test.float()).item()
            running_loss = running_loss + batch_loss * batch_size
            n_samples = n_samples + batch_size

            del x, y, outputs

    torch.cuda.empty_cache()
    return running_loss / n_samples





def train_model(model,criterion,dataloader_train, test_dataloaders,device,run,
                n_epochs=10,train_on=(20,), test_on=(20,)):
    """Train ``model`` with Adam and report train/test losses after every epoch.

    For each batch and each context length ``t`` in ``train_on``, the inputs
    and labels are truncated to their first ``t + 1`` positions, the last
    label is held out as the target (see ``split_context``), and one
    optimizer step is taken with gradient-norm clipping at 1.0.

    After every epoch the loss is evaluated with ``find_perfomrance`` on a
    random subset (at most 20000 samples, drawn once up front) of the training
    data for every length in ``train_on``, and on every test loader for every
    length in ``test_on``.

    Args:
        model: Module called as ``model(x, y_context)``.
        criterion: Loss callable ``criterion(prediction, target)``.
        dataloader_train: DataLoader whose ``.dataset`` exposes its tensors as
            ``x``/``y``, ``xi``/``fxi`` or ``xi``/``fx`` attributes.
        test_dataloaders: Iterable of indexable pairs; element ``[0]`` is
            iterated for evaluation batches and element ``[1]`` is used as the
            label in the metric name.
        device: Device the model and batches are moved to.
        run: Object with a ``log(dict)`` method (presumably a wandb run --
            confirm with caller). If falsy, metrics are printed instead.
        n_epochs: Number of passes over ``dataloader_train``.
        train_on: Context lengths used for training and train-loss reporting.
        test_on: Context lengths used for test-loss reporting.

    Returns:
        None. Metrics are logged or printed as a side effect.
    """
    n_train = len(dataloader_train.dataset)
    print(f'n_training point:{len(dataloader_train.dataset)}')

    # Fixed random subset of the training data used for train-loss reporting.
    subset_size = min(20000, n_train)
    subset_indices = random.sample(range(n_train), subset_size)

    train_data = dataloader_train.dataset
    if hasattr(train_data, 'x'):
        X_subset_train = train_data.x[subset_indices, :]
        y_subset_train = train_data.y[subset_indices]
    elif hasattr(train_data, 'fxi'):
        X_subset_train = train_data.xi[subset_indices, :]
        y_subset_train = train_data.fxi[subset_indices]
    else:
        X_subset_train = train_data.xi[subset_indices, :]
        y_subset_train = train_data.fx[subset_indices]

    dataset_subset_train = TensorDataset(X_subset_train, y_subset_train)
    subset_dataloader_train = DataLoader(dataset_subset_train, batch_size=128, shuffle=True)

    # Batches are moved to `device` below, so the model must live there before
    # the first forward pass (previously this only happened during evaluation).
    model.to(device)

    ######## Optimizer ########
    lr = 0.00001
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for _epoch in tqdm(range(n_epochs), desc="Training Epochs"):
        # The end-of-epoch evaluation switches the model to eval mode, so
        # training mode has to be re-enabled at the start of every epoch.
        model.train()

        for (x, y) in tqdm(dataloader_train, desc="Batch Progress", leave=False):
            x = x.to(device)
            y = y.to(device)

            # Autoregressive training: one optimizer step per context length.
            for t in train_on:
                y_context_current, y_true_current = split_context(y[:, :t + 1])

                optimizer.zero_grad()

                # Model output for the current context length.
                output = model(x[:, :t + 1, :], y_context_current)
                y_hat_current = torch.squeeze(output)
                if output.shape[0] == 1:
                    # squeeze() also dropped the batch dimension of a
                    # single-sample batch; restore it to match the target.
                    y_hat_current = y_hat_current.unsqueeze(0)

                # Loss only on the prediction for the held-out last position.
                loss = criterion(y_hat_current, y_true_current.float())
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

        metrics = {}

        for c in train_on:
            train_loss = find_perfomrance(model, subset_dataloader_train, criterion,
                                          device=device, n_context=c + 1)
            metrics[f'Train loss -- CL {c}'] = train_loss

        for c in test_on:
            for test_dataloader in test_dataloaders:
                test_loss = find_perfomrance(model, test_dataloader[0], criterion,
                                             device=device, n_context=c + 1)
                metrics[f'Test loss -- {test_dataloader[1]} -- CL {c}'] = test_loss

        if run:
            run.log(metrics)
        else:
            for key, value in metrics.items():
                print(f"{key}: {value}")














