import torch
import torch.nn as nn
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import os
import math
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader



class MyBatchSampler:
    """Iterable that yields one list of dataset indices per mini-batch.

    The index array ``0 .. dataset_size - 1`` is built once in the
    constructor.  When ``shuffle`` is true it is permuted in place with
    ``np.random.shuffle`` at the start of every pass, so the order depends
    on NumPy's global random state.  The final batch is shorter than
    ``batch_size`` when the dataset size is not a multiple of it.
    """

    def __init__(self, dataset_size, batch_size, shuffle=True):
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.indices = np.arange(dataset_size)

    def __iter__(self):
        """Yield consecutive slices of the (optionally reshuffled) indices."""
        if self.shuffle:
            np.random.shuffle(self.indices)
        n_indices = len(self.indices)
        for start in range(0, n_indices, self.batch_size):
            stop = start + self.batch_size
            # Plain Python ints, not NumPy scalars, are handed to the dataset.
            yield self.indices[start:stop].tolist()

    def __len__(self):
        """Return the number of batches per pass, counting a partial last one."""
        n_batches = np.ceil(len(self.indices) / self.batch_size)
        return int(n_batches)

class VectorizedDataLoader(DataLoader):
    """DataLoader that fetches each batch with a single dataset lookup.

    The stock iterator fetches samples one index at a time and collates
    them.  This subclass bypasses it: every element produced by
    ``self.batch_sampler`` is passed straight to ``self.dataset[...]`` and
    the result is yielded unchanged.  The dataset's ``__getitem__`` must
    therefore accept a whole batch of indices and return the finished batch.
    """

    def __iter__(self):
        # One __getitem__ call per batch; no per-sample fetch, no collation.
        yield from (self.dataset[index_batch] for index_batch in self.batch_sampler)

class LinearModel(nn.Module):
    """Dense linear readout from flattened delay-embedded features to channels.

    The model owns a ``(channel_dim, delay_dim * embedding_dim)`` weight
    matrix and maps a 2-D input ``X`` of shape ``(batch, n_features)`` to
    ``(batch, channel_dim)``.  ``X`` can optionally be standardised first,
    and a per-channel bias and/or a per-delay "onset" bias can be added.

    Args:
        delay_dim: Size of the delay axis; the last axis when the feature
            axis is viewed as ``(embedding_dim, delay_dim)``.
        embedding_dim: Size of the embedding axis of that same view.
        channel_dim: Number of output channels (rows of the weight matrix).
        use_bias: If true, learn a bias vector of length ``channel_dim``.
        mean: Optional statistics subtracted from ``X`` in ``forward``
            (indexed as ``mean[None, :]``, so a 1-D array is expected).
        std: Optional statistics ``X`` is divided by in ``forward``
            (indexed as ``std[None, :]``, so a 1-D array is expected).
        onset_bias: If true, learn a ``(delay_dim, channel_dim)`` bias that
            is added once for every delay whose embedding slice is not
            entirely zero.

    ``mean`` and ``std`` are kept as parameters so they stay in the state
    dict and keep their position in ``parameters()``, but they are created
    with ``requires_grad=False`` so an optimizer never changes them.
    """

    def __init__(self, delay_dim, embedding_dim,
                 channel_dim, use_bias = False, mean = None, std = None, onset_bias = False):
        super().__init__()
        self.delay_dim = delay_dim
        self.embedding_dim = embedding_dim
        self.channel_dim = channel_dim

        self.n_features = delay_dim * embedding_dim
        self.n_targets = channel_dim
        # Standard-normal draws scaled down by the number of input features.
        initial_weights = torch.normal(0, 1, (self.n_targets, self.n_features)) / self.n_features
        self.weights = nn.Parameter(initial_weights, requires_grad=True)
        self.use_bias = use_bias
        self.onset_bias = onset_bias

        if onset_bias:
            self.delay_bias = nn.Parameter(torch.zeros(self.delay_dim, self.channel_dim), requires_grad=True)

        self.bias = None
        if use_bias:
            self.bias = nn.Parameter(torch.zeros(self.n_targets))

        self.mean = self._frozen_statistic(mean)
        self.std = self._frozen_statistic(std)

    def _frozen_statistic(self, values):
        """Wrap normalisation statistics as a non-trainable parameter (or None)."""
        if values is None:
            return None
        # Cast to the weight dtype: float64 NumPy statistics would otherwise
        # promote X to float64 and break the matmul against float32 weights.
        data = torch.as_tensor(values, dtype=self.weights.dtype).detach().clone()
        # nn.Parameter defaults to requires_grad=True, so it must be disabled
        # explicitly; the flag on the wrapped tensor is ignored.
        return nn.Parameter(data, requires_grad=False)

    def get_weights(self):
        """Return the ``(channel_dim, n_features)`` weight matrix."""
        return self.weights

    def forward(self, X):
        """Apply the optional standardisation, the linear map and the biases.

        Args:
            X: Tensor of shape ``(batch, delay_dim * embedding_dim)``.

        Returns:
            Tensor of shape ``(batch, channel_dim)``.
        """
        if self.mean is not None:
            X = X - self.mean[None, :]
        if self.std is not None:
            X = X / self.std[None, :]
        out = torch.matmul(self.weights, X.T).T

        if self.onset_bias:
            with torch.no_grad():
                # View features as (batch, embedding, delay) and mark every
                # delay that has at least one non-zero embedding entry.
                # NOTE(review): the mask is computed on X *after*
                # standardisation; with a non-zero mean, zero-padded delays
                # are no longer exactly zero - confirm this is intended.
                X_onset_view = X.reshape(-1, self.embedding_dim, self.delay_dim)
                is_onset = (X_onset_view != 0).any(dim=1).float()
            out = out + torch.matmul(is_onset, self.delay_bias)

        if self.use_bias:
            out = out + self.bias

        return out

    def numpy_forward(self, X):
        """Run ``forward`` on array-like ``X`` and return a NumPy array.

        The input is converted to a float32 tensor on the model's device.
        No autograd graph is recorded.
        """
        device = next(self.parameters()).device
        X = torch.as_tensor(X, device=device, dtype=torch.float32)
        with torch.no_grad():
            out = self.forward(X)
        return out.detach().cpu().numpy()

class LowRankTensorLinearModel(nn.Module):
    """Linear readout whose weight tensor is a sum of ``rank`` outer products.

    The full ``(delay, embedding, channel)`` weight tensor is never stored.
    It is rebuilt on demand from three factor matrices (time, embedding,
    space) and flattened to ``(channel_dim, embedding_dim * delay_dim)``,
    with the feature axis ordered embedding-major and delay-minor.

    Args:
        rank: Number of rank-one components.
        delay_dim: Size of the delay axis.
        embedding_dim: Size of the embedding axis.
        channel_dim: Number of output channels.
        use_bias: If true, learn a bias vector of length ``channel_dim``.
        mean: Optional statistics subtracted from ``X`` in ``forward``
            (indexed as ``mean[None, :]``, so a 1-D array is expected).
        std: Optional statistics ``X`` is divided by in ``forward``
            (indexed as ``std[None, :]``, so a 1-D array is expected).
        onset_bias: If true, learn a ``(delay_dim, channel_dim)`` bias that
            is added once for every delay whose embedding slice is not
            entirely zero.

    ``mean`` and ``std`` are kept as parameters so they stay in the state
    dict and keep their position in ``parameters()``, but they are created
    with ``requires_grad=False`` so an optimizer never changes them.
    """

    def __init__(self, rank, delay_dim, embedding_dim,
                 channel_dim, use_bias = False, mean = None, std = None,
                 onset_bias = False):
        super().__init__()

        self.delay_dim = delay_dim
        self.embedding_dim = embedding_dim
        self.channel_dim = channel_dim
        self.use_bias = use_bias

        self.n_features = delay_dim*embedding_dim
        self.n_targets = channel_dim

        # Time and embedding factors start small; space factors are unit normal.
        self.time_factors = nn.Parameter(torch.randn(rank, self.delay_dim)/(rank*3*self.delay_dim), requires_grad=True)
        self.space_factors = nn.Parameter(torch.randn(rank, self.channel_dim))
        self.embedding_factors = nn.Parameter(torch.randn(rank, self.embedding_dim)/(rank*3*self.embedding_dim), requires_grad=True)

        self.onset_bias = onset_bias
        if onset_bias:
            self.delay_bias = nn.Parameter(torch.zeros(self.delay_dim, self.channel_dim), requires_grad=True)

        self.bias = None
        if use_bias:
            self.bias = nn.Parameter(torch.zeros(self.channel_dim), requires_grad=True)

        self.mean = self._frozen_statistic(mean)
        self.std = self._frozen_statistic(std)

    def _frozen_statistic(self, values):
        """Wrap normalisation statistics as a non-trainable parameter (or None)."""
        if values is None:
            return None
        # Cast to the factor dtype: float64 NumPy statistics would otherwise
        # promote X to float64 and break the matmul against float32 weights.
        data = torch.as_tensor(values, dtype=self.time_factors.dtype).detach().clone()
        # nn.Parameter defaults to requires_grad=True, so it must be disabled
        # explicitly; the flag on the wrapped tensor is ignored.
        return nn.Parameter(data, requires_grad=False)

    def get_weights(self):
        """Assemble the dense ``(channel_dim, n_features)`` weight matrix.

        Entry ``[s, e * delay_dim + t]`` equals
        ``sum_r time[r, t] * embedding[r, e] * space[r, s]``.
        """
        # Summing over the rank axis inside the einsum avoids materialising
        # the intermediate (rank, delay, embedding, channel) tensor.
        weights_set = torch.einsum(
            "rt,re,rs->set",
            self.time_factors, self.embedding_factors, self.space_factors,
        )
        return weights_set.reshape(self.channel_dim, self.n_features)

    def forward(self, X):
        """Apply the optional standardisation, the linear map and the biases.

        Args:
            X: Tensor of shape ``(batch, delay_dim * embedding_dim)``.

        Returns:
            Tensor of shape ``(batch, channel_dim)``.
        """
        if self.mean is not None:
            X = X - self.mean[None, :]
        if self.std is not None:
            X = X / self.std[None, :]
        weights = self.get_weights()
        out = torch.matmul(weights, X.T).T
        if self.use_bias:
            out = out + self.bias
        if self.onset_bias:
            with torch.no_grad():
                # View features as (batch, embedding, delay) and mark every
                # delay that has at least one non-zero embedding entry.
                # NOTE(review): the mask is computed on X *after*
                # standardisation; with a non-zero mean, zero-padded delays
                # are no longer exactly zero - confirm this is intended.
                X_onset_view = X.reshape(-1, self.embedding_dim, self.delay_dim)
                is_onset = (X_onset_view != 0).any(dim=1).float()
            out = out + torch.matmul(is_onset, self.delay_bias)
        return out

    def numpy_forward(self, X):
        """Run ``forward`` on array-like ``X`` and return a NumPy array.

        The input is converted to a float32 tensor on the model's device.
        No autograd graph is recorded.
        """
        device = next(self.parameters()).device
        X = torch.as_tensor(X, device=device, dtype=torch.float32)
        with torch.no_grad():
            out = self.forward(X)
        return out.detach().cpu().numpy()
    
class RunningAverage():
    """Incrementally maintained arithmetic mean of the values passed to ``add``.

    ``counter`` is the number of values seen since the last reset and
    ``value`` is their mean, which is an integer zero tensor before the
    first ``add``.
    """

    def __init__(self):
        # A fresh accumulator is exactly a reset one.
        self.reset()

    @torch.no_grad()
    def add(self, x):
        """Fold ``x`` into the mean; runs with gradient tracking disabled."""
        n_seen = self.counter + 1
        self.counter = n_seen
        # Rescale the old mean from n_seen - 1 samples to n_seen, then add
        # the new sample's share.
        rescaled_previous = self.value * (n_seen - 1) / n_seen
        self.value = rescaled_previous + x / n_seen

    def reset(self):
        """Forget everything accumulated so far."""
        self.counter = 0
        self.value = torch.tensor(0)

def collate_fn(device):
    """Build a collate callable that stacks ``(x, y)`` samples into batches.

    ``device`` is accepted but not used: the returned callable only stacks
    tensors and never moves them.

    The returned function takes a sequence of ``(x, y)`` tensor pairs and
    returns ``(batch_X, batch_Y)``, each stacked along a new leading axis.
    """
    def _collate(batch):
        # Transpose the list of pairs into one tuple of inputs and one of targets.
        inputs, targets = zip(*batch)
        return torch.stack(inputs, dim=0), torch.stack(targets, dim=0)
    return _collate

class SGDRegression():
    """Train a regression model with Adam on MSE plus a per-channel ridge penalty.

    The wrapped model must be callable on a batch and expose
    ``get_weights()`` returning a 2-D ``(channels, features)`` tensor.  The
    ridge term is the squared L2 norm of each row, multiplied (with
    broadcasting) by ``ridge_params``.

    After every epoch these files are written to ``model_save_loc``:
    ``last_model.pt``, ``best_model.pt`` (when the epoch loss improves),
    ``progress.png`` and ``checkpoint.pt``.

    Args:
        model: The ``nn.Module`` to train; moved to ``device``.
        ridge_params: Ridge strength(s), converted to a float32 tensor.
        device: Device for the model, penalties and batches.
    """

    def __init__(self, model, ridge_params, device="cuda"):
        self.model = model.to(device)
        self.ridge_params = torch.tensor(ridge_params, dtype=torch.float32, device=device)
        self.device = device

    def _path(self, filename):
        """Return the path of ``filename`` inside the current run directory."""
        return os.path.join(self.model_save_loc, filename)

    @staticmethod
    def _loader_options(num_workers, prefetch_factor=None):
        """Build DataLoader keyword arguments that are valid for ``num_workers``.

        ``DataLoader`` raises ``ValueError`` when ``persistent_workers`` or
        ``prefetch_factor`` is given together with ``num_workers == 0``, so
        both are passed only when worker processes are requested.
        """
        options = {"num_workers": num_workers, "pin_memory": True}
        if num_workers > 0:
            options["persistent_workers"] = True
            if prefetch_factor is not None:
                options["prefetch_factor"] = prefetch_factor
        return options

    def _make_train_loader(self, dataset, batch_size, num_workers, prefetch_factor=None):
        """Create the shuffled, batch-indexed training loader."""
        batch_sampler = MyBatchSampler(len(dataset), batch_size=batch_size, shuffle=True)
        return VectorizedDataLoader(
            dataset,
            batch_sampler=batch_sampler,
            **self._loader_options(num_workers, prefetch_factor),
        )

    def fit(self, dataset, max_epochs = 1000, lr= 0.001,
            batch_size = 3000, num_workers = 0, model_save_loc = "./runs/test_run", amsgrad = False):
        """Train, resuming from ``checkpoint.pt`` in ``model_save_loc`` if present."""
        if os.path.isfile(os.path.join(model_save_loc, "checkpoint.pt")):
            print("Loading From Checkpoint")
            self.fit_from_checkpoint(dataset, max_epochs, lr,
                                     batch_size, num_workers, model_save_loc, amsgrad=amsgrad)
        else:
            print("Training From Scratch")
            self.fit_from_scratch(dataset, max_epochs, lr,
                                  batch_size, num_workers, model_save_loc, amsgrad=amsgrad)

    def save_checkpoint(self):
        """Write epoch, model/optimizer state and loss history to ``checkpoint.pt``."""
        torch.save({
            'epoch': self.epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'best_loss': self.best_loss,
        }, self._path("checkpoint.pt"))

    def fit_from_scratch(self, dataset, max_epochs = 1000, lr= 0.001,
            batch_size = 3000, num_workers = 0, model_save_loc = "./runs/test_run", amsgrad=False):
        """Create the run directory, a fresh optimizer, and train from epoch 0."""
        os.makedirs(model_save_loc, exist_ok=True)

        self.model_save_loc = model_save_loc
        self.train_losses = []
        train_loader = self._make_train_loader(dataset, batch_size, num_workers)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr = lr, amsgrad=amsgrad)
        self.best_loss = torch.inf
        print(f"Batches per epoch: {len(dataset)/batch_size}")
        self.fit_initialized(train_loader, 0, max_epochs)

    def fit_from_checkpoint(self, dataset, max_epochs = 1000, lr= 0.001,
            batch_size = 3000, num_workers = 0, model_save_loc = "./runs/test_run", amsgrad=False):
        """Restore state from ``checkpoint.pt`` and continue with the next epoch.

        Returns immediately when the stored loss history already covers
        ``max_epochs`` epochs.
        """
        self.model_save_loc = model_save_loc
        # NOTE: weights_only=False unpickles arbitrary objects (the loss
        # history holds NumPy values), so only load checkpoints you trust.
        # map_location lets a checkpoint written on another device load here.
        checkpoint = torch.load(self._path("checkpoint.pt"),
                                map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.train_losses = checkpoint['train_losses']
        self.best_loss = checkpoint["best_loss"]
        if len(self.train_losses) >= max_epochs:
            print("Model Already Fit")
            return

        train_loader = self._make_train_loader(dataset, batch_size, num_workers,
                                               prefetch_factor=4)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr = lr, amsgrad=amsgrad)
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        print(f"Batches per epoch: {len(dataset)/batch_size}")
        self.fit_initialized(train_loader, checkpoint["epoch"] + 1, max_epochs)

    def fit_initialized(self, train_loader, start_epoch, max_epochs):
        """Run the epoch loop; expects optimizer, loss history and paths to be set.

        Each epoch appends the mean batch loss to ``self.train_losses``, saves
        the model (and the best model so far), refreshes the progress plot
        and writes a checkpoint.
        """
        train_loss = RunningAverage()
        for epoch in tqdm(range(start_epoch, max_epochs)):
            self.epoch = epoch
            for X_batch, Y_batch in train_loader:
                X_batch = X_batch.to(self.device, non_blocking=True)
                Y_batch = Y_batch.to(self.device, non_blocking=True)
                self.optimizer.zero_grad(set_to_none=True)
                Y_batch_hat = self.model(X_batch)
                # Per-channel penalty: squared row norm times that channel's strength.
                channelwise_l2_loss = torch.linalg.norm(self.model.get_weights(), dim=1)**2*self.ridge_params
                channelwise_mse_loss = torch.mean(torch.square(Y_batch - Y_batch_hat), dim=0)
                total_loss = torch.mean(channelwise_l2_loss+channelwise_mse_loss)
                total_loss.backward()
                self.optimizer.step()
                train_loss.add(total_loss.detach())
            self.train_losses.append(train_loss.value.detach().cpu().numpy())
            train_loss.reset()
            torch.save(self.model, self._path("last_model.pt"))
            if self.train_losses[-1] < self.best_loss:
                torch.save(self.model, self._path("best_model.pt"))
                self.best_loss = self.train_losses[-1]
            # Redraw the loss curve and close the figure so figures do not
            # accumulate across epochs.
            fig = plt.figure()
            plt.plot(self.train_losses)
            plt.savefig(self._path("progress.png"))
            plt.close(fig)
            self.save_checkpoint()

    @torch.no_grad()
    def predict(self, validation_dataset, batch_size):
        """Return ``(predictions, targets)`` for a dataset as NumPy arrays.

        Batches are visited in dataset order (no shuffling) and concatenated
        along axis 0.
        """
        batch_sampler = MyBatchSampler(len(validation_dataset), batch_size=batch_size, shuffle=False)
        loader = VectorizedDataLoader(validation_dataset, batch_sampler=batch_sampler)
        out_predict = []
        out_true = []
        for X_batch, Y_batch in loader:
            X_batch = X_batch.to(self.device, non_blocking=True)
            Y_batch = Y_batch.to(self.device, non_blocking=True)
            Y_batch_hat = self.model(X_batch)
            out_predict.append(Y_batch_hat.cpu().numpy())
            out_true.append(Y_batch.cpu().numpy())
        return np.concatenate(out_predict, axis=0), np.concatenate(out_true, axis=0)
    