import numpy as np
import torch
import torch.nn.functional as F

class BaseFunction:
    """Abstract interface for an objective function used by the optimizers.

    Subclasses override ``value``, ``gradient`` and ``dim``; the base
    implementations only signal that the method is missing.
    """

    def value(self, point):
        """Return the objective value at ``point``."""
        raise NotImplementedError

    def gradient(self, point):
        """Return the (exact) gradient at ``point``."""
        raise NotImplementedError

    def dim(self):
        """Return the dimension of the parameter space."""
        raise NotImplementedError


def generate_random_vector(dim, seed):
    """Draw a length-``dim`` ``float32`` vector with entries uniform in [0, 1).

    ``seed`` is anything ``np.random.default_rng`` accepts (an int, ``None``,
    or an existing ``np.random.Generator``, which is used as-is).
    """
    rng = np.random.default_rng(seed)
    shape = (dim,)
    return rng.random(shape, dtype=np.float32)


def generate_random_nonegative_symmetric_matrix(dim, seed, reg=None):
    """Sample a random symmetric positive semi-definite ``float32`` matrix.

    A factor ``F`` with entries uniform in [-1, 1) is drawn and the Gram
    matrix ``F F^T`` is returned, shifted by ``reg * I`` when ``reg`` is given.

    Args:
        dim: matrix side length.
        seed: anything ``np.random.default_rng`` accepts (int, None, Generator).
        reg: optional diagonal shift.
    """
    rng = np.random.default_rng(seed)
    factor = rng.random((dim, dim), dtype=np.float32) * 2 - 1
    gram = np.dot(factor, factor.T)
    if reg is None:
        return gram
    return gram + reg * np.eye(dim, dtype=np.float32)


class QuadraticFunction(BaseFunction):
    '''
    Function f(x) = 1/2 x^T A x - b^T x

    ``A`` and ``b`` are copied into torch tensors at construction. Points are
    1-D tensors (``torch.mv`` / ``torch.dot`` are used). ``gradient`` returns
    ``A x - b``, which is the true gradient only when ``A`` is symmetric (as
    produced by ``create_random``).
    '''
    def __init__(self, A, b, check=False):
        """
        Args:
            A: square matrix (anything ``torch.tensor`` accepts).
            b: vector whose length matches ``A``.
            check: if True, verify with ``np.linalg.cholesky`` that ``A`` can
                be Cholesky-factorized (i.e. is positive definite).

        Raises:
            RuntimeError: if ``check`` is True and the factorization fails.
        """
        super(QuadraticFunction, self).__init__()
        if check:
            self._check(A)
        self._A = torch.tensor(A)
        self._b = torch.tensor(b)

    def value(self, point):
        """Return f(point) as a 0-d tensor."""
        with torch.no_grad():
            return ((1/2.) * torch.dot(point, torch.mv(self._A, point)) - torch.dot(self._b, point)).detach()

    def gradient(self, point):
        """Return ``A @ point - b`` (the gradient when ``A`` is symmetric)."""
        with torch.no_grad():
            return (torch.mv(self._A, point) - self._b).detach()

    def dim(self):
        """Return the number of columns of ``A`` (the point dimension)."""
        return self._A.shape[1]

    def _check(self, A):
        """Raise ``RuntimeError`` if ``np.linalg.cholesky`` rejects ``A``."""
        try:
            np.linalg.cholesky(A)
        except np.linalg.LinAlgError as err:
            # Chain explicitly so the underlying numerical failure is reported
            # as the direct cause.
            raise RuntimeError("Something wrong with matrix") from err

    @staticmethod
    def create_random(dim, seed=None, reg=None):
        """Build a random instance with a symmetric PSD ``A`` (+ ``reg * I``) and uniform ``b``.

        A single generator seeded by ``seed`` is shared by both draws, so the
        result is reproducible for a fixed ``seed``.
        """
        generator = np.random.default_rng(seed=seed)
        return QuadraticFunction(generate_random_nonegative_symmetric_matrix(dim, generator, reg),
                                 generate_random_vector(dim, generator))


class StochasticQuadraticFunction(BaseFunction):
    '''
    Quadratic f(x) = 1/2 x^T A x - b^T x with a noisy gradient oracle.

    ``stochastic_gradient`` returns the exact gradient plus ``noise * z``,
    where ``z`` is ONE standard-normal sample drawn from the internal numpy
    generator per call. NOTE(review): that single scalar is added to every
    coordinate, so the perturbation always lies along the all-ones
    direction rather than being independent per coordinate. Confirm this
    is intended.
    '''
    def __init__(self, A, b, seed, noise=1.0):
        """
        Args:
            A, b: quadratic coefficients, forwarded to ``QuadraticFunction``.
            seed: anything ``np.random.default_rng`` accepts (an existing
                Generator is used as-is); drives the gradient noise.
            noise: scale applied to the standard-normal draw.
        """
        super(StochasticQuadraticFunction, self).__init__()
        self._generator = np.random.default_rng(seed)
        self._quadratic_function = QuadraticFunction(A, b)
        self._noise = noise

    def value(self, point):
        """Return the exact (noise-free) objective value at ``point``."""
        return self._quadratic_function.value(point)

    def dim(self):
        """Return the point dimension of the underlying quadratic."""
        return self._quadratic_function.dim()

    def stochastic_gradient(self, point):
        """Return the exact gradient at ``point`` plus scalar Gaussian noise."""
        noise = self._noise * self._generator.normal()
        stochastic_gradient = self._quadratic_function.gradient(point)
        return stochastic_gradient + noise

    def gradient(self, point):
        """Return the exact (noise-free) gradient at ``point``."""
        return self._quadratic_function.gradient(point)

    @staticmethod
    def create_random(dim, seed=None, reg=None, noise=1.0):
        """Build a random instance; the same generator also drives the gradient noise."""
        generator = np.random.default_rng(seed=seed)
        A = generate_random_nonegative_symmetric_matrix(dim, generator, reg)
        b = generate_random_vector(dim, generator)
        return StochasticQuadraticFunction(A, b, seed=generator, noise=noise)


class LogisticRegressionFunction(BaseFunction):
    """Multinomial logistic regression objective with an L2 penalty.

    A point is a sequence whose first element holds the weights. They are
    reshaped to ``(-1, num_classes)``, so ``d * num_classes`` entries are
    expected. The objective is
    ``mean cross_entropy(X @ W, y) + 0.5 * reg * sum(W ** 2)``.
    """
    def __init__(self, X, y, rng, num_classes=2, reg=0.0, batch_size=1, iid_batch=True):
        """
        Args:
            X: feature tensor of shape ``(n, d)``.
            y: class labels (converted with ``.long()``); the largest label
                must be below ``num_classes``.
            rng: object with a ``permutation`` method; used only when
                ``iid_batch`` is False (epoch-based batching).
            num_classes: number of classes.
            reg: L2 regularization strength.
            batch_size: mini-batch size for ``stochastic_gradient``.
            iid_batch: if True, each mini-batch is drawn with replacement via
                ``torch.randint`` (torch's global RNG, not ``rng``); otherwise
                shuffled epochs from ``get_next_batch_indexes`` are used.
        """
        super().__init__()
        self.X = X
        self.y = y.long()
        self.num_classes = num_classes
        assert int(torch.max(y)) + 1 <= num_classes
        self.reg = reg
        self.batch_size = batch_size
        self.n, self.d = X.shape

        self.iid_batch = iid_batch
        if iid_batch is False:
            # Epoch iterator state for get_next_batch_indexes; created lazily.
            self.data_indicies_iterator = None
        self.rng = rng

    def _value(self, point, X, y):
        """Regularized mean cross-entropy of weights ``point`` on ``(X, y)``."""
        point = point.reshape(-1, self.num_classes)
        logits = X.matmul(point)
        loss = F.cross_entropy(logits, y, reduction='mean')
        reg_term = 0.5 * self.reg * torch.sum(point ** 2)
        return loss + reg_term

    def value(self, point):
        """Return the full-data objective at ``point`` as a detached 0-d tensor."""
        # No gradient is needed here; avoid building an autograd graph.
        with torch.no_grad():
            return self._value(point[0], self.X, self.y).detach()

    def dim(self):
        """Return the number of weights, ``d * num_classes``."""
        return self.d * self.num_classes

    def stochastic_gradient(self, point):
        """Return ``[grad]``: the flattened mini-batch gradient at ``point[0]``."""
        if self.iid_batch:
            indices = torch.randint(0, self.n, (self.batch_size,))
        else:
            indices, self.data_indicies_iterator = get_next_batch_indexes(
                self.rng, self.data_indicies_iterator, self.n, self.batch_size
            )

        X_batch = self.X[indices]
        y_batch = self.y[indices]

        # Differentiate w.r.t. a fresh leaf copy so the caller's tensor is
        # untouched; this is PyTorch's recommended replacement for
        # torch.tensor(<tensor>, requires_grad=True).
        point_tensor = torch.as_tensor(point[0]).detach().clone().requires_grad_(True)
        value = self._value(point_tensor, X_batch, y_batch)
        value.backward()
        grad = point_tensor.grad.detach().reshape(-1)
        return [grad]

    def accuracy(self, point):
        """Return the fraction of samples whose argmax prediction matches ``y``."""
        with torch.no_grad():
            point = point[0].reshape(-1, self.num_classes)
            logits = self.X.matmul(point)
            predictions = torch.argmax(logits, dim=1)
            correct = (predictions == self.y).sum().item()
        return correct / self.n

def get_next_batch_indexes(rng, iterator, iter_len, batch_size):
    """Return the next batch of shuffled indices and the epoch iterator.

    Indices ``0 .. iter_len - 1`` are shuffled once per epoch with ``rng`` and
    split into consecutive batches of exactly ``batch_size``. When
    ``iter_len`` is not a multiple of ``batch_size``, the trailing
    ``iter_len % batch_size`` shuffled indices are dropped for that epoch, so
    every batch has the same size. When an epoch is exhausted, a fresh
    permutation starts the next one.

    Args:
        rng: object with a ``permutation`` method (e.g. ``np.random.Generator``
            or ``np.random.RandomState``).
        iterator: iterator returned by the previous call, or ``None`` to start.
        iter_len: number of items being indexed.
        batch_size: number of indices per batch.

    Returns:
        ``(batch, iterator)``: ``batch`` is a 1-D integer array of length
        ``batch_size``; pass ``iterator`` back on the next call.

    Raises:
        ValueError: if ``batch_size`` is not in ``1 .. iter_len``.
    """
    if batch_size <= 0 or batch_size > iter_len:
        raise ValueError(
            f"batch_size must be between 1 and iter_len ({iter_len}), got {batch_size}"
        )
    num_batches = iter_len // batch_size

    def new_epoch():
        # Shuffle all indices, then keep only the part that fills whole batches.
        perm = rng.permutation(np.arange(iter_len))
        return iter(perm[: num_batches * batch_size].reshape((num_batches, batch_size)))

    if iterator is None:
        iterator = new_epoch()
    batch = next(iterator, None)
    if batch is None:
        iterator = new_epoch()
        # num_batches >= 1, so a fresh epoch always yields a batch.
        batch = next(iterator)

    return batch, iterator


from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn

class ResnetClassificationFunction(BaseFunction):
    """Stochastic objective: cross-entropy loss of a torchvision ResNet-50 classifier.

    ``self.model`` holds the current parameters. A "point" is a list of
    tensors aligned with ``self.model.parameters()`` (see
    ``get_weight_point``). ``stochastic_gradient`` copies ``point`` into the
    model before differentiating. ``value`` and ``accuracy`` ignore their
    ``point`` argument and evaluate whatever parameters the model currently
    holds.
    """

    def __init__(self, train_dataloader, test_dataloader, batch_size=128, reg=0.0, rng=None, dataset='cifar10', device='cuda'):
        """
        Initialize ResNet50 classification function

        Args:
            train_dataloader: DataLoader for training data; iterated (and
                restarted when exhausted) by ``stochastic_gradient``
            test_dataloader: DataLoader for test/validation data; its length
                also caps how many batches ``value``/``accuracy`` evaluate
            batch_size: Stored only; not used by this class's methods (the
                dataloaders determine the batch size)
            reg: Regularization parameter. NOTE: stored but not currently
                added to the loss
            rng: Random number generator (a fresh ``np.random.RandomState``
                when None); stored only
            dataset: 'cifar10' builds a CIFAR-10 variant with 10 classes; any
                other value builds the ImageNet variant with 1000 classes
            device: Device to run the model on ('cuda' or 'cpu')
        """
        super().__init__()
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.batch_size = batch_size
        self.reg = reg
        self.rng = np.random.RandomState() if rng is None else rng
        self.dataset = dataset
        self.device = device

        # Counter of training batches drawn; reset to 0 when the loader restarts.
        self.steps = 0

        self.model = self._get_model(dataset).to(device)
        self.criterion = nn.CrossEntropyLoss()

        self.train_iterator = iter(self.train_dataloader)

    def _get_model(self, dataset='cifar10'):
        """Build a ResNet-50 (``ResNet50_Weights.DEFAULT``) adapted to ``dataset``."""
        model = resnet50(weights=ResNet50_Weights.DEFAULT)
        if dataset == 'cifar10':
            # CIFAR-10 images are small: use a 3x3 stride-1 stem and drop the
            # initial max-pool so resolution is not reduced too early.
            model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            model.maxpool = nn.Identity()
            num_classes = 10
        else:  # imagenet (any non-'cifar10' value lands here)
            num_classes = 1000

        # Replace the classification head with one sized for the dataset.
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        return model

    def _eval_loader(self, train):
        """Return the loader used for evaluation: train loader if ``train`` else test loader."""
        return self.train_dataloader if train else self.test_dataloader

    def value(self, point, train=False):
        """Return the mean per-batch loss of the model's current parameters.

        ``point`` is ignored. At most ``len(self.test_dataloader)`` batches are
        evaluated, even when ``train=True``, and the average is taken over the
        batches actually evaluated.

        NOTE(review): the model is never switched to eval mode here, so
        BatchNorm/Dropout run with training-mode behaviour. Confirm this is
        intended.
        """
        loader = self._eval_loader(train)
        max_batches = len(self.test_dataloader)
        total_loss = 0.0
        num_batches = 0
        with torch.no_grad():
            for inputs, targets in loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                total_loss += self.criterion(outputs, targets).item()
                num_batches += 1
                if num_batches >= max_batches:
                    break
        # Divide by the batches seen, not len(test_dataloader): the train
        # loader may be shorter than the test loader.
        return total_loss / num_batches

    def dim(self):
        """Return the dimension of the model parameters"""
        return sum(p.numel() for p in self.model.parameters())

    def _next_train_batch(self):
        """Fetch the next training batch, restarting the loader when exhausted."""
        try:
            batch = next(self.train_iterator)
            self.steps += 1
            print("self.steps: ", self.steps)
        except StopIteration:
            print("reset loader")
            self.steps = 0
            self.train_iterator = iter(self.train_dataloader)
            batch = next(self.train_iterator)
        return batch

    def stochastic_gradient(self, point):
        """Load ``point`` into the model and return the gradient on the next training batch.

        Args:
            point: sequence of tensors, one per ``self.model.parameters()``
                entry, in the same order.

        Returns:
            list of tensors: cloned gradients aligned with
            ``self.model.parameters()``.
        """
        self.model.train()
        self.model.zero_grad(set_to_none=False)
        self._set_model_params(point)

        inputs, targets = self._next_train_batch()
        inputs, targets = inputs.to(self.device), targets.to(self.device)

        outputs = self.model(inputs)
        loss = self.criterion(outputs, targets)
        # The graph is used exactly once, so it does not need to be retained.
        loss.backward()

        # A parameter with no gradient contributes an explicit zero tensor so the
        # returned list always lines up with model.parameters().
        return [
            param.grad.clone() if param.grad is not None else torch.zeros_like(param)
            for param in self.model.parameters()
        ]

    def accuracy(self, point, train=False):
        """Return top-1 accuracy of the model's current parameters.

        ``point`` is ignored. At most ``len(self.test_dataloader)`` batches are
        evaluated, even when ``train=True``. The model is not switched to eval
        mode here.
        """
        loader = self._eval_loader(train)
        max_batches = len(self.test_dataloader)
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(loader):
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                if batch_idx + 1 >= max_batches:
                    break
        return correct / total

    def _set_model_params(self, point):
        """Copy each tensor of ``point`` into the matching model parameter, in order.

        ``zip`` stops at the shorter sequence, so extra or missing entries are
        not reported.
        """
        for param, p in zip(self.model.parameters(), point):
            param.data.copy_(p)

    def get_weight_point(self):
        """Return a list of cloned parameter tensors, in ``model.parameters()`` order."""
        return [param.data.clone() for param in self.model.parameters()]

import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel

class TransformerNextTokenPredictionFunction(BaseFunction):
    """Stochastic objective: loss of a small, randomly initialized GPT-2 language model.

    Each batch is a dict whose entries are passed as keyword arguments to
    ``GPT2LMHeadModel``; the loss is read from ``outputs.loss``.
    NOTE(review): this presumes every batch carries ``labels`` so the model
    returns a loss. Confirm against the dataloader.

    ``self.model`` holds the current parameters. ``stochastic_gradient``
    copies ``point`` (a list of tensors aligned with
    ``self.model.parameters()``) into the model. ``value`` and
    ``perplexity`` ignore ``point`` and evaluate the model as-is.
    """

    def __init__(self, train_dataloader, test_dataloader, batch_size=128, reg=0.0, rng=None, device='cuda'):
        """
        Args:
            train_dataloader: loader iterated (and restarted when exhausted)
                by ``stochastic_gradient``.
            test_dataloader: evaluation loader; its length also caps how many
                batches ``value``/``perplexity`` evaluate.
            batch_size: stored only; not used by this class's methods.
            reg: stored only; not added to the loss.
            rng: stored only (a fresh ``np.random.RandomState`` when None).
            device: device the model and batches are moved to.
        """
        super().__init__()
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.batch_size = batch_size
        self.reg = reg
        self.rng = np.random.RandomState() if rng is None else rng
        self.device = device

        # Counter of training batches drawn; reset to 0 when the loader restarts.
        self.steps = 0

        self.model = self._get_model().to(device)
        # Not used by the methods below: the loss comes from outputs.loss.
        self.criterion = nn.CrossEntropyLoss()

        self.train_iterator = iter(self.train_dataloader)

    def _get_model(self):
        """Build a small GPT-2 LM head model with random initialization."""
        config = GPT2Config(
            vocab_size=50257,  # Standard GPT-2 vocabulary size
            n_positions=128,   # Reduced context length: longer sequences are not supported
            n_embd=256,        # Reduced embedding dimension (default is 768)
            n_layer=4,         # Reduced number of layers (default is 12)
            n_head=4,          # Reduced number of attention heads (default is 12)
            activation_function='gelu_new'
        )

        model = GPT2LMHeadModel(config)
        model.to(self.device)
        return model

    def _eval_loader(self, train):
        """Return the train loader if ``train`` else the test loader."""
        return self.train_dataloader if train else self.test_dataloader

    def value(self, point, train=False):
        """Return the mean per-batch loss of the model's current parameters.

        ``point`` is ignored. Batches containing an empty tensor are skipped
        and do not count toward the average. Evaluation stops once the batch
        index reaches ``len(self.test_dataloader)``, even when
        ``train=True``. Returns 0.0 if no batch was evaluated, matching
        ``perplexity``.
        """
        loader = self._eval_loader(train)
        max_batches = len(self.test_dataloader)
        total_loss = 0.0
        num_evaluated = 0
        with torch.no_grad():
            for batch_idx, batch in enumerate(loader):
                if any(t.nelement() == 0 for t in batch.values()):
                    print("Skipping empty batch")
                    continue

                batch = {k: v.to(self.device) for k, v in batch.items()}
                outputs = self.model(**batch)
                total_loss += outputs.loss.item()
                num_evaluated += 1

                if batch_idx + 1 >= max_batches:
                    break
        # Average over batches actually evaluated, so skipped or missing
        # batches do not bias the loss toward zero.
        return total_loss / max(num_evaluated, 1)

    def dim(self):
        """Return the dimension of the model parameters"""
        return sum(p.numel() for p in self.model.parameters())

    def _next_train_batch(self):
        """Fetch the next training batch, restarting the loader when exhausted."""
        try:
            batch = next(self.train_iterator)
            self.steps += 1
            print("self.steps: ", self.steps)
        except StopIteration:
            print("reset loader")
            self.steps = 0
            self.train_iterator = iter(self.train_dataloader)
            batch = next(self.train_iterator)
        return batch

    def stochastic_gradient(self, point):
        """Load ``point`` into the model and return the gradient on the next training batch.

        Returns:
            list of tensors: cloned gradients aligned with
            ``self.model.parameters()``.
        """
        self.model.train()
        self.model.zero_grad(set_to_none=False)
        self._set_model_params(point)

        batch = self._next_train_batch()
        batch = {k: v.to(self.device) for k, v in batch.items()}

        outputs = self.model(**batch)
        loss = outputs.loss
        # The graph is used exactly once, so it does not need to be retained.
        loss.backward()

        # A parameter with no gradient contributes an explicit zero tensor so the
        # returned list always lines up with model.parameters().
        return [
            param.grad.clone() if param.grad is not None else torch.zeros_like(param)
            for param in self.model.parameters()
        ]

    def perplexity(self, point, train=False):
        """Return ``(mean_loss, exp(mean_loss))`` for the model's current parameters.

        ``point`` is ignored. Empty batches are skipped. If nothing is
        evaluated, the loss stays 0.0 and the perplexity is therefore 1.0.
        """
        loader = self._eval_loader(train)
        max_batches = len(self.test_dataloader)
        with torch.no_grad():
            eval_loss = torch.tensor(0.0, device=self.device)
            eval_steps = 0

            for batch_idx, batch in enumerate(loader):
                if any(t.nelement() == 0 for t in batch.values()):
                    print("Skipping empty batch")
                    continue

                batch = {k: v.to(self.device) for k, v in batch.items()}
                outputs = self.model(**batch)
                eval_loss += outputs.loss.item()
                eval_steps += 1

                if batch_idx + 1 >= max_batches:
                    break

            if eval_steps > 0:
                eval_loss = eval_loss / eval_steps

            eval_perplexity = torch.exp(eval_loss).item()

            return eval_loss.item(), eval_perplexity

    def _set_model_params(self, point):
        """Copy each tensor of ``point`` into the matching model parameter, in order.

        ``zip`` stops at the shorter sequence, so extra or missing entries are
        not reported.
        """
        for param, p in zip(self.model.parameters(), point):
            param.data.copy_(p)

    def get_weight_point(self):
        """Return a list of cloned parameter tensors, in ``model.parameters()`` order."""
        return [param.data.clone() for param in self.model.parameters()]
