import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm
import csv
import torch
import os
import wandb
from copy import deepcopy

from core.data import subset_by_indices

from core.sampler import SamplerBase

class StoppingCriterion:
    """Stop training once the best logged value of a metric crosses a threshold.

    The criterion is satisfied when the minimum (``min=True``) of
    ``logs[key]`` is strictly below ``value``, or, with ``min=False``, when
    the maximum is strictly above it.

    Args:
        key: Name of the metric list inside the ``logs`` dict.
        value: Threshold the extreme value has to cross (strict comparison).
        min: Compare the minimum (True) or the maximum (False) of the history.
    """

    def __init__(self, key, value, min=True):
        self.key = key
        self.value = value
        self.min = min

    def __call__(self, logs):
        """Return whether the criterion holds for ``logs``.

        An empty history never satisfies the criterion; ``np.min``/``np.max``
        would otherwise raise on a zero-size sequence (e.g. a metric that is
        only recorded after the first logging step).
        """
        history = logs[self.key]
        if len(history) == 0:
            return False
        if self.min:
            return np.min(history) < self.value
        return np.max(history) > self.value

    def __str__(self):
        return f"StoppingCriterion({self.key}, {self.value}, {self.min})"
    
class StoppingCriterionAveraged(StoppingCriterion):
    """Stop training once the recent average of a metric crosses a threshold.

    The criterion is satisfied when the mean of the last ``window_size``
    entries of ``logs[key]`` is strictly below ``value`` (``min=True``) or
    strictly above it (``min=False``). While fewer than ``window_size``
    values have been logged, the mean is taken over all of them.

    Args:
        key: Name of the metric list inside the ``logs`` dict.
        value: Threshold the moving average has to cross.
        min: Direction of the comparison, as in ``StoppingCriterion``.
        window_size: Number of most recent entries to average; must be >= 1.

    Raises:
        ValueError: If ``window_size`` is smaller than 1. ``history[-0:]``
            would silently average the entire history and negative values
            would drop the oldest entries instead.
    """

    def __init__(self, key, value, min=True, window_size=5):
        super().__init__(key, value, min)
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        self.window_size = window_size

    def __call__(self, logs):
        """Return whether the moving average satisfies the criterion.

        An empty history never satisfies it (avoids ``np.mean`` of an empty
        slice, which warns and yields ``nan``).
        """
        recent = logs[self.key][-self.window_size:]
        if len(recent) == 0:
            return False
        average = np.mean(recent)
        return average < self.value if self.min else average > self.value

    def __str__(self):
        return f"StoppingCriterionAveraged({self.key}, {self.value}, {self.min}, {self.window_size})"

class CrossDeviceOptimizerBase:
    """Training loop that repeatedly trains a model on sampled cohorts of workers.

    Every epoch ``sampler()`` picks the indices of participating workers. The
    matching subsets of the training set are built from
    ``workers_dataset_indices``, and a freshly constructed ``cohort_optimizer``
    returns the updated model. Every ``log_every`` epochs the loop records
    train/test loss and accuracy (plus a gradient norm after epoch 0) in
    ``self.logs``. Depending on the configuration, it also forwards them to
    Weights & Biases, a live matplotlib figure, and a CSV file.
    """

    def __init__(self,
                 workers_count,
                 batch_size,
                 epochs,
                 seed,
                 sampler,
                 model,
                 loss,
                 cohort_optimizer,
                 cohort_optimizer_hparams,
                 local_epochs,
                 device,
                 plot=False,
                 tqdm=False,
                 log_every=5,
                 logs_filename=None,
                 process_count=1,
                 stopping_criterion=None,
                 wandb_params=None
                 ):
        """Store the configuration and, if requested, start a W&B run.

        Args:
            workers_count: Number of workers (stored; not used by this class).
            batch_size: Fallback evaluation batch size, used when
                ``cohort_optimizer_hparams`` has no ``'batch_size'`` entry.
            epochs: Number of rounds run by ``train``; each round samples one
                cohort.
            seed: Random seed (stored; not used by this class).
            sampler: Zero-argument callable returning the indices of the
                workers that take part in a round.
            model: Model being trained (used through the ``torch.nn.Module``
                API: ``eval``/``train``/``parameters``/``zero_grad``).
            loss: Callable ``loss(output, target)`` returning a scalar tensor.
            cohort_optimizer: Factory called as
                ``cohort_optimizer(hparams, device=..., process_count=...)``.
                The result is called as
                ``optimizer(model, datasets, loss, local_epochs)`` and must
                return the updated model.
            cohort_optimizer_hparams: Mapping passed to ``cohort_optimizer``;
                its ``'batch_size'`` entry is also used for evaluation.
            local_epochs: Forwarded to the cohort optimizer every round.
            device: Device that evaluation batches are moved to.
            plot: Draw live metric plots while training.
            tqdm: Show a progress bar during evaluation.
            log_every: Log metrics every ``log_every`` epochs.
            logs_filename: Optional CSV file that metrics are appended to. Its
                basename is also used as the W&B run name.
            process_count: ``num_workers`` for the evaluation data loaders;
                also forwarded to the cohort optimizer.
            stopping_criterion: Optional callable ``criterion(logs) -> bool``
                checked after every log step.
            wandb_params: Optional dict with ``'key'``, ``'project'`` and
                ``'entity'`` entries; enables Weights & Biases logging.
        """
        self.workers_count = workers_count
        self.batch_size = batch_size
        self.epochs = epochs
        self.seed = seed
        self.sampler = sampler
        self.model = model
        self.loss = loss
        self.cohort_optimizer = cohort_optimizer
        self.cohort_optimizer_hparams = cohort_optimizer_hparams
        self.device = device
        self.logs = {
            "train_loss": [],
            "train_accuracy": [],
            "test_loss": [],
            "test_accuracy": [],
            "grad_norm": [],
            "epoch": [],
        }
        self.plot = plot
        self.local_epochs = local_epochs
        self.tqdm = tqdm
        self.log_every = log_every
        self.logs_filename = logs_filename
        self.process_count = process_count
        self.stopping_criterion = stopping_criterion
        self.wandb_params = wandb_params

        # WandB
        if wandb_params is not None:
            wandb.login(key=wandb_params["key"])
            # logs_filename is optional; without it let W&B choose the run name.
            run_name = os.path.basename(logs_filename) if logs_filename is not None else None
            # Keep the API key out of the run config, which is stored with the run.
            wandb_config = {k: v for k, v in wandb_params.items() if k != "key"}
            wandb.init(project=wandb_params["project"], entity=wandb_params["entity"],
                       config=wandb_config, name=run_name)

    def train_routine(self, sample, train_dataset, workers_dataset_indices):
        """Run one round: train ``self.model`` on the datasets of the sampled workers.

        Args:
            sample: Iterable of worker indices returned by the sampler.
            train_dataset: Full training dataset.
            workers_dataset_indices: Sequence mapping worker index -> sample
                indices of that worker in ``train_dataset``.
        """
        sample_dataset_indices = [workers_dataset_indices[i] for i in sample]
        sample_datasets = [subset_by_indices(train_dataset, indices) for indices in sample_dataset_indices]
        optimizer = self.cohort_optimizer(self.cohort_optimizer_hparams,
                                          device=self.device,
                                          process_count=self.process_count)
        self.model = optimizer(self.model, sample_datasets, self.loss, self.local_epochs)

    def train(self, train_dataset, workers_dataset_indices, test_dataset=None):
        """Run ``self.epochs`` rounds of cohort training and return the model.

        Metrics are logged before the round of every epoch divisible by
        ``log_every``, so the state after the final round is not logged.

        Args:
            train_dataset: Dataset that the worker subsets are drawn from.
            workers_dataset_indices: Sequence mapping worker index -> sample
                indices of that worker in ``train_dataset``.
            test_dataset: Optional held-out dataset used for test metrics and
                the gradient norm.

        Returns:
            The model after the last round.
        """
        if self.plot:
            plt.ion()
        for i in range(self.epochs):
            # Logging
            if i % self.log_every == 0:
                self.log(train_dataset, test_dataset, i)
            sample = self.sampler()
            self.train_routine(sample, train_dataset, workers_dataset_indices)
        if self.plot:
            plt.ioff()
            plt.show()
        if self.wandb_params is not None:
            wandb.finish()
        return self.model

    def calculate_metrics(self, dataset):
        """Evaluate the average loss and accuracy (in %) of ``self.model`` on ``dataset``.

        Evaluation runs in eval mode under ``torch.no_grad()``. The model's
        previous train/eval mode is restored afterwards, so later training
        rounds are not silently run in eval mode. Batches are not shuffled,
        which makes the per-batch loss average deterministic.

        Args:
            dataset: Map-style dataset yielding ``(data, target)`` pairs.

        Returns:
            tuple: ``(average_loss, accuracy)``. ``average_loss`` is the mean
            of the per-batch loss values; ``accuracy`` is the percentage of
            argmax predictions equal to the target. Both are ``nan`` for an
            empty dataset.
        """
        hparams = self.cohort_optimizer_hparams
        batch_size = hparams['batch_size'] if 'batch_size' in hparams else self.batch_size
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=self.process_count)
        if self.tqdm:
            loader = tqdm(loader, desc='Evaluating train/test loss and accuracy')
        total_loss = 0.0
        batch_count = 0
        correct_predictions = 0
        total_predictions = 0
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for data, target in loader:
                    data, target = data.to(self.device), target.to(self.device)
                    output = self.model(data)
                    total_loss += self.loss(output, target).item()
                    batch_count += 1
                    # Accuracy: index of the largest output along dim 1 is the prediction.
                    _, predicted = torch.max(output, 1)
                    total_predictions += target.size(0)
                    correct_predictions += (predicted == target).sum().item()
        finally:
            self.model.train(was_training)
        if batch_count == 0 or total_predictions == 0:
            return float("nan"), float("nan")
        average_loss = total_loss / batch_count
        accuracy = 100.0 * correct_predictions / total_predictions
        return average_loss, accuracy

    # NOTE: the misspelled name is kept because callers/subclasses may use it.
    def calulate_grad_norm(self, test_dataset):
        """Return the L2 norm of the full-batch loss gradient on ``test_dataset``.

        The whole dataset is fed through the model as one batch and the loss
        is back-propagated. The Euclidean norm over all parameter gradients is
        returned. The forward pass runs in eval mode, so dataset statistics
        (e.g. BatchNorm running stats) are not updated from this data; the
        previous mode is restored afterwards. Gradients are cleared again
        before returning so they cannot leak into the next training round.

        Args:
            test_dataset: Map-style dataset yielding ``(data, target)`` pairs.

        Returns:
            float: the gradient norm, or ``nan`` for an empty dataset.
        """
        if len(test_dataset) == 0:
            return float("nan")
        loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False, num_workers=self.process_count)
        was_training = self.model.training
        self.model.eval()
        squared_norm = 0.0
        try:
            for data, target in loader:
                self.model.zero_grad()
                data, target = data.to(self.device), target.to(self.device)
                loss = self.loss(self.model(data), target)
                loss.backward()
                for p in self.model.parameters():
                    if p.grad is not None:
                        squared_norm += p.grad.detach().norm(2).item() ** 2
        finally:
            self.model.zero_grad()
            self.model.train(was_training)
        return float(np.sqrt(squared_norm))

    def log(self, train_dataset, test_dataset, i):
        """Compute, record and report the metrics for epoch ``i``.

        The metrics are appended to ``self.logs`` and sent to W&B, the live
        plot and the CSV file when those are enabled. The gradient norm is
        only computed when ``i > 0`` and a test set is given. If the stopping
        criterion is satisfied, the W&B run is finished and ``SystemExit`` is
        raised.

        Args:
            train_dataset: Dataset for the train metrics.
            test_dataset: Optional dataset for the test metrics and gradient
                norm; test metrics are recorded as ``None`` without it.
            i: Current epoch index.
        """
        train_loss, train_accuracy = self.calculate_metrics(train_dataset)
        if test_dataset is not None:
            test_loss, test_accuracy = self.calculate_metrics(test_dataset)
        else:
            test_loss, test_accuracy = None, None
        log_params = {
            "epoch": i,
            "train_loss": train_loss,
            "train_accuracy": train_accuracy,
            "test_loss": test_loss,
            "test_accuracy": test_accuracy,
        }
        # The gradient norm is not reported at epoch 0 and needs a test set.
        if i > 0 and test_dataset is not None:
            log_params["grad_norm"] = self.calulate_grad_norm(test_dataset)
        for log_key, log_value in log_params.items():
            self.logs[log_key].append(log_value)
        if self.wandb_params is not None:
            wandb.log(log_params)
        if self.plot:
            self.do_plots(["train_loss", "train_accuracy", "test_loss", "test_accuracy"], i)
        if self.logs_filename is not None:
            self.log_training_data(self.logs_filename, i, train_loss, train_accuracy, test_loss, test_accuracy)
        if self.stopping_criterion is not None and self.stopping_criterion(self.logs):
            print("Stopping criterion is satisfied")
            if self.wandb_params is not None:
                wandb.finish()
            # Same effect as the interactive-only ``exit()`` helper, without
            # depending on the ``site`` module.
            raise SystemExit

    def do_plots(self, names, j):
        """Redraw the live training-metrics figure (2x2 grid of subplots).

        A single named figure is reused and cleared on every call, instead of
        opening a new window per logging step.

        Args:
            names: Up to four keys of ``self.logs``, one per subplot.
            j: Current epoch. Kept for backward compatibility; the x axis is
                taken from ``self.logs["epoch"]`` so it always matches the
                plotted values.
        """
        fig = plt.figure("Training metrics", figsize=(12, 8))
        fig.clf()
        epochs = self.logs["epoch"]
        for idx, name in enumerate(names):
            plt.subplot(2, 2, idx + 1)
            # None entries (test metrics without a test set) become gaps.
            values = [np.nan if v is None else v for v in self.logs[name]]
            plt.plot(epochs, values)
            plt.title(name)
            plt.xlabel('Epoch')
            plt.ylabel(name)
            if 'loss' in name:
                plt.yscale('log')
            plt.grid(True)
        plt.tight_layout()
        plt.draw()
        plt.pause(0.001)

    def log_training_data(self, filename, epoch, train_loss, train_accuracy, test_loss, test_accuracy):
        """Append one row of metrics to the CSV file ``filename``.

        A header row is written first when the file does not exist yet or is
        empty. ``None`` values are written as empty cells by ``csv``.
        """
        write_header = not os.path.isfile(filename) or os.path.getsize(filename) == 0
        with open(filename, 'a', newline='') as file:
            writer = csv.writer(file)
            if write_header:
                writer.writerow(["Epoch", "Train Loss", "Train Accuracy", "Test Loss", "Test Accuracy"])
            writer.writerow([epoch, train_loss, train_accuracy, test_loss, test_accuracy])
