
import os

import torch
import wandb

from .Networks import CA_CNN_Convolutional
from .Automaton import CoarseWrapper, LifeLikeAutomaton2D

class DataLogger():
    """Send run configuration and metrics to W&B, local text files and stdout.

    Each sink is switched independently:

    * ``use_wandb``: log to Weights & Biases. A new run is started in
      ``wandbProject`` (named ``wandbRunName``) unless an existing run object
      is passed as ``wandb_run``.
    * ``use_local``: write to ``<localPath>/<localName>.json`` and
      ``<localPath>/<localName>_acc.json``. Despite the extension, each line is
      ``key: <python repr>`` text, not strict JSON.
    * ``console``: ``print`` the same information.
    """

    def __init__(self, use_wandb=True, use_local=True, wandbProject='PredictingCA2025', wandbRunName='Run', localPath='results/', localName='Run', console=True, wandb_run=None):
        self.use_wandb = use_wandb
        self.use_local = use_local

        if self.use_wandb:
            import wandb
            if wandb_run is not None:
                # Reuse a run the caller has already started.
                self.wandb = wandb_run
            else:
                self.wandb = wandb
                self.wandb.init(project=wandbProject)
                self.wandb.run.name = wandbRunName

        if self.use_local:
            self.localPath = localPath
            self.localName = localName
            # Create the output directory up front so the first write cannot
            # fail with FileNotFoundError.
            if localPath:
                os.makedirs(localPath, exist_ok=True)

        self.console = console

    def _local_file(self, suffix='.json'):
        """Return the path ``<localPath>/<localName><suffix>``."""
        # os.path.join works whether or not localPath ends with a separator.
        return os.path.join(self.localPath, self.localName + suffix)

    def init(self, network, automaton, num_iters, batchsize, lr, **kwargs):
        """Record the run configuration and start a fresh local log file.

        Args:
            network: object with a ``get_config()`` method (presumably
                returning a dict, since it is passed to ``wandb.config.update``).
            automaton: object with a ``get_config()`` method, same assumption.
            num_iters: number of training iterations.
            batchsize: training batch size.
            lr: learning rate.
            **kwargs: extra training-configuration entries (e.g. ``optim``).
        """
        self.net_config = network.get_config()
        self.auto_config = automaton.get_config()
        self.train_config = {'num_iters': num_iters, 'batchsize': batchsize, 'lr': lr, **kwargs}

        if self.use_wandb:
            self.wandb.config.update(self.net_config)
            self.wandb.config.update(self.auto_config)
            self.wandb.config.update(self.train_config)

        if self.use_local:
            # Mode 'w' truncates: each init() starts the run log from scratch.
            # Lines are "key: <repr>" text, not JSON.
            with open(self._local_file(), 'w', encoding='utf-8') as f:
                f.write(f'network: {self.net_config}\n')
                f.write(f'automaton: {self.auto_config}\n')
                f.write(f'training: {self.train_config}\n')

    def add_config_values(self, **kwargs):
        """Add extra configuration entries (W&B only; not written locally)."""
        if self.use_wandb:
            self.wandb.config.update(kwargs)

    def log(self, indx, **kwargs):
        """Log the metrics in ``kwargs`` for iteration ``indx`` to every enabled sink."""
        if self.use_wandb:
            self.wandb.log({'iteration': indx, **kwargs}, step=indx)

        if self.use_local:
            with open(self._local_file(), 'a', encoding='utf-8') as f:
                f.write(f'iteration{indx}:, {kwargs}\n')

        if self.console:
            print(f'iteration{indx}:, {kwargs}')

    def log_prediction(self, x, prediction, y):
        """Log one example's input, prediction and label as W&B images.

        Takes ``x[0][0]`` (presumably first sample, first channel),
        ``prediction[0]`` and ``y[0]``.
        """
        x = x[0][0]
        prediction = prediction[0]
        y = y[0]
        if self.console:
            # Shape diagnostic; shown only when console output is enabled.
            print(x.shape, prediction.shape, y.shape)
        if self.use_wandb:
            self.wandb.log({'images/start': wandb.Image(x), 'images/prediction': wandb.Image(prediction), 'images/labels': wandb.Image(y)})

    def log_acc_positions(self, x, prediction, y):
        """Log a heatmap of per-position accuracy, averaged over axis 0, to W&B.

        ``x`` is accepted for interface symmetry but not used.
        """
        if not self.use_wandb:
            # The heatmap has no other sink, so skip computing it.
            return
        import numpy as np
        corr_heatmap = np.mean(prediction == y, axis=0)
        self.wandb.log({'images/acc_positions': wandb.Image(corr_heatmap)})

    def log_accuracy_comparison(self, accuracies):
        """Log a ``{name: accuracy}`` dict.

        W&B keys are prefixed with ``accuracy/``. Locally the dict is appended
        to ``<localName>_acc.json``.
        """
        acc_wandb = {f'accuracy/{k}': v for k, v in accuracies.items()}
        if self.use_wandb:
            self.wandb.log(acc_wandb)
        if self.use_local:
            with open(self._local_file('_acc.json'), 'a', encoding='utf-8') as f:
                f.write(f'accuracy_comparison: {accuracies}\n')

        if self.console:
            print(f'accuracy_comparison: {accuracies}')

    def finish(self, final_ps):
        """Record the final perturbation sensitivity and close the W&B run."""
        if self.use_wandb:
            self.wandb.config.update({'final_perturbation_sensitivity': final_ps})
            self.wandb.finish()

        if self.use_local:
            with open(self._local_file(), 'a', encoding='utf-8') as f:
                f.write(f'final_perturbation_sensitivity: {final_ps}\n')

        if self.console:
            print(f'final_perturbation_sensitivity: {final_ps}')



class Trainer():
    """Train ``network`` on batches from ``automaton`` and report via ``data_logger``."""

    def __init__(self, network, automaton, data_logger):
        self.network = network
        self.automaton = automaton
        self.data_logger = data_logger

    @staticmethod
    def _extract_patches(x, y, patch_size):
        """Cut every full ``patch_size x patch_size`` window out of channel 1 of ``x``.

        Args:
            x: tensor indexed ``x[batch, channel, row, col]`` (the layout the
                former ``rearrange('b d w h -> b w h d')`` implied); channel 1
                is used.
            y: tensor indexed ``y[batch, row, col]``.
            patch_size: side length of the square window.

        Returns:
            ``(patches, labels)`` numpy arrays of shape ``(N, patch_size**2)``
            and ``(N,)``. Each label is ``y`` at the window centre. Rows are
            ordered window-position-major (row, then column) and batch-minor.
            Every valid window position is included.
        """
        import numpy as np
        grid = x.detach().cpu()[:, 1]
        y = y.detach().cpu()
        _, width, height = grid.shape
        n_w = width - patch_size + 1
        n_h = height - patch_size + 1
        if n_w <= 0 or n_h <= 0:
            # Grid smaller than one patch: nothing to extract.
            return np.zeros((0, patch_size * patch_size)), np.zeros((0,))
        # windows[b, i, j, r, s] == grid[b, i + r, j + s]
        windows = grid.unfold(1, patch_size, 1).unfold(2, patch_size, 1)
        patches = windows.permute(1, 2, 0, 3, 4).reshape(-1, patch_size * patch_size)
        centre = patch_size // 2
        labels = y[:, centre:centre + n_w, centre:centre + n_h].permute(1, 2, 0).reshape(-1)
        return patches.numpy(), labels.numpy()

    def get_comparison(self, train_size=32, test_size=16):
        """Score two simple baselines on fresh automaton batches.

        * Majority voting: predict the dominant training label everywhere
          (1 if the training-label mean is > 0.5, else 0), scored on the test
          batch.
        * Logistic regression on each cell's ``patch_size x patch_size``
          neighbourhood, where ``patch_size = 1 + 2 * automaton.time_factor``.
          It is fitted on the training batch and scored on the test batch.

        Args:
            train_size: number of samples drawn for fitting (default 32).
            test_size: number of samples drawn for scoring (default 16).

        Returns:
            dict with float values ``"majority_voting"`` and
            ``"logistic_regression"``. If logistic regression cannot be fitted
            (sklearn raises ValueError, e.g. a single class or no patches),
            the majority-voting accuracy is reported in its place.
        """
        from sklearn.linear_model import LogisticRegression

        trainX, trainY = self.automaton.get_batch(train_size)
        testX, testY = self.automaton.get_batch(test_size)
        patch_size = 1 + 2 * self.automaton.time_factor

        # Both baselines are measured on the same held-out test batch.
        best_color = torch.mean(trainY.float()) > 0.5
        accuracy_majority = float(torch.mean((testY == best_color).float()))

        patches, labels = self._extract_patches(trainX, trainY, patch_size)
        test_patches, test_labels = self._extract_patches(testX, testY, patch_size)

        try:
            clf = LogisticRegression(max_iter=1000)
            clf.fit(patches, labels)
            test_accuracy = float(clf.score(test_patches, test_labels))
        except ValueError:
            # Degenerate data (one class / empty): fall back to the majority baseline.
            test_accuracy = accuracy_majority

        print(f"Majority voting accuracy: {accuracy_majority}")
        print(f"Logistic regression accuracy: {test_accuracy}")
        return {"majority_voting": accuracy_majority, "logistic_regression": test_accuracy}

    def train(self, num_iters, batchsize, lr, early_stopping=None, nat_its=0, device="cuda"):
        """Train with Adam, then log final predictions and baseline comparisons.

        Args:
            num_iters: maximum number of optimisation steps (must be >= 1).
            batchsize: samples per step.
            lr: Adam learning rate.
            early_stopping: if set, stop once the batch accuracy exceeds it.
            nat_its: forwarded to ``automaton.get_batch`` as
                ``naturalize_iterations``.
            device: device the batches are moved to (default ``"cuda"``).

        Raises:
            ValueError: if ``num_iters < 1``. At least one step is needed to
                produce the final prediction that gets logged.
        """
        if num_iters < 1:
            raise ValueError(f"num_iters must be >= 1, got {num_iters}")

        self.data_logger.init(self.network, self.automaton, num_iters, batchsize, lr, optim="Adam")
        self.data_logger.add_config_values(naturalize_iterations=nat_its)

        optimizer = torch.optim.Adam(self.network.parameters(), lr=lr, eps=1e-4)

        for i in range(num_iters):
            optimizer.zero_grad()
            batchX, batchY = self.automaton.get_batch(batchsize, naturalize_iterations=nat_its)
            batchX = batchX.to(device)
            batchY = batchY.to(device)

            prediction, loss, accuracy = self.network.forward(batchX, labels=batchY)

            loss.backward()
            optimizer.step()
            self.data_logger.log(i, loss=loss.item(), accuracy=accuracy)
            if early_stopping is not None and accuracy > early_stopping:
                print(f"Early stopping at iteration {i} with accuracy {accuracy}")
                break

        # Convert the last batch once; both loggers only read these arrays.
        x_np = batchX.detach().cpu().numpy()
        pred_np = torch.argmax(prediction, dim=1).detach().cpu().numpy()
        y_np = batchY.detach().cpu().numpy()
        self.data_logger.log_prediction(x_np, pred_np, y_np)
        self.data_logger.log_acc_positions(x_np, pred_np, y_np)

        final_ps = self.network.calc_perturbation_sensitivity()

        accuracies = self.get_comparison()
        accuracies["CNN"] = accuracy
        accuracies["CNN_diff_logistic"] = accuracy - accuracies["logistic_regression"]
        accuracies["CNN_diff_majority"] = accuracy - accuracies["majority_voting"]
        self.data_logger.log_accuracy_comparison(accuracies)

        self.data_logger.finish(final_ps)

        
class SGDTrainer(Trainer):
    """Variant of ``Trainer`` that optimises with plain SGD.

    Its ``train`` runs a fixed number of steps, with no naturalisation, no
    early stopping, no per-position accuracy and no baseline comparison.
    """

    def train(self, num_iters, batchsize, lr, device="cuda"):
        """Train with SGD for ``num_iters`` steps, then log the final prediction.

        Args:
            num_iters: number of optimisation steps (must be >= 1).
            batchsize: samples per step.
            lr: SGD learning rate.
            device: device the batches are moved to (default ``"cuda"``).

        Raises:
            ValueError: if ``num_iters < 1``. At least one step is needed to
                produce the final prediction that gets logged.
        """
        if num_iters < 1:
            raise ValueError(f"num_iters must be >= 1, got {num_iters}")

        self.data_logger.init(self.network, self.automaton, num_iters, batchsize, lr, optim="SGD")

        optimizer = torch.optim.SGD(self.network.parameters(), lr=lr)

        for i in range(num_iters):
            optimizer.zero_grad()
            batchX, batchY = self.automaton.get_batch(batchsize)
            batchX = batchX.to(device)
            batchY = batchY.to(device)

            prediction, loss, accuracy = self.network.forward(batchX, labels=batchY)

            loss.backward()
            optimizer.step()
            self.data_logger.log(i, loss=loss.item(), accuracy=accuracy)

        self.data_logger.log_prediction(
            batchX.detach().cpu().numpy(),
            torch.argmax(prediction, dim=1).detach().cpu().numpy(),
            batchY.detach().cpu().numpy(),
        )

        final_ps = self.network.calc_perturbation_sensitivity()
        self.data_logger.finish(final_ps)
