# -*- coding: utf-8 -*-
"""
Rotated MNIST Reset
"""

import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, RandomSampler, ConcatDataset
from torch.distributions import Categorical
import numpy as np
import random
import pickle

device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Load data
class RotatableMNIST(datasets.MNIST):
    """MNIST dataset that can hand out rotated copies of itself.

    Args:
        root: Directory that holds (or will receive) the MNIST files.
        train: Load the training split if True, otherwise the test split.
        download: Forwarded to ``datasets.MNIST``.
        transform: Transform applied to every image. It is also remembered
            so that :meth:`rot` can put a rotation in front of it.
    """

    def __init__(self, root, train=True, download=False, transform=None):
        super().__init__(root=root, train=train, download=download,
                         transform=transform)
        self._base_transform = transform
        # Remember the constructor flag under a private name: on torchvision's
        # MNIST, ``self.download`` is the download *method*, not this boolean.
        self._download = download

    def rot(self, angle):
        """Return a new dataset whose images are rotated before the base transform.

        Args:
            angle: Rotation angle handed to ``TF.rotate`` (degrees, per the
                torchvision documentation).

        Returns:
            A fresh ``RotatableMNIST`` on the same root/split whose transform
            is ``TF.rotate(img, angle)`` followed by this dataset's own
            transform(s). Calling ``rot`` on the result stacks rotations,
            because the composed transform becomes its base transform.
        """
        if isinstance(self._base_transform, transforms.Compose):
            base_ts = list(self._base_transform.transforms)
        elif self._base_transform is not None:
            base_ts = [self._base_transform]
        else:
            base_ts = []
        rotate = transforms.Lambda(lambda img: TF.rotate(img, angle))
        new_transform = transforms.Compose([rotate] + base_ts)

        return RotatableMNIST(
            root=self.root,
            train=self.train,
            # Pass the original boolean, not the bound ``download`` method.
            download=self._download,
            transform=new_transform
        )

# Shared preprocessing: convert the image to a tensor, then normalize with
# mean 0.1307 and std 0.3081 (presumably the usual MNIST statistics — confirm).
base_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
# Train split; download=True is passed, so ./data is populated if needed.
train_ds = RotatableMNIST(root="./data",train=True, download=True, transform=base_transform)
# Test split, same preprocessing as the train split.
test_ds = RotatableMNIST(root="./data", train=False, download=True, transform=base_transform)

## Network Model
# --- Layers ---
class GainModLayer(nn.Module):
    """Hidden layer: ReLU of a linear map whose outputs are scaled per unit.

    Computes ``relu((x @ W^T) * gains + bias)``. ``gains`` is a non-trainable
    parameter (one value per output unit, initialised to 1) that is updated
    by :meth:`decay` on every forward pass made in training mode.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output units.
        g0: Baseline the gains relax towards.
        gamma: Retention factor of the gains per decay step.
        eta: Coupling strength of the external drive ``H_drive``.
    """

    def __init__(self, input_dim, output_dim, g0=1.0, gamma=0.9, eta:float=0.0):
        super().__init__()
        initial_weight = torch.randn(output_dim, input_dim) * 0.01
        self.weight = nn.Parameter(initial_weight)
        self.bias = nn.Parameter(torch.zeros(output_dim))
        self.gains = nn.Parameter(torch.ones(output_dim), requires_grad=False)
        self.g0, self.gamma, self.eta = g0, gamma, eta

    def decay(self, H_drive: float = 0.0):
        """Update gains in place: ``g <- gamma*g + (1-gamma)*g0 + eta*H_drive``."""
        baseline_term = (1 - self.gamma) * self.g0
        drive_term = self.eta * H_drive
        with torch.no_grad():
            self.gains.mul_(self.gamma)
            self.gains.add_(baseline_term)
            self.gains.add_(drive_term)

    def forward(self, x, H_drive: float = 0.0):
        """Return the ReLU activations; gains are updated first when training."""
        if self.training:
            self.decay(H_drive)
        projected = torch.matmul(x, self.weight.t())
        pre_activation = projected * self.gains + self.bias
        return torch.relu(pre_activation)


class GainModOut(nn.Module):
    """Output layer: gain-scaled linear map returned without an activation.

    Computes ``(x @ W^T) * gains + bias``. ``gains`` is a non-trainable
    parameter (one value per output unit, initialised to 1) that is updated
    by :meth:`decay` on every forward pass made in training mode.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output units.
        g0: Baseline the gains relax towards.
        gamma: Retention factor of the gains per decay step.
        eta: Coupling strength of the external drive ``H_drive``.
    """

    def __init__(self, input_dim, output_dim, g0=1.0, gamma=0.9, eta:float=0.0):
        super().__init__()
        initial_weight = torch.randn(output_dim, input_dim) * 0.05
        self.weight = nn.Parameter(initial_weight)
        self.bias = nn.Parameter(torch.zeros(output_dim))
        self.gains = nn.Parameter(torch.ones(output_dim), requires_grad=False)
        self.g0, self.gamma, self.eta = g0, gamma, eta

    def decay(self, H_drive: float = 0.0):
        """Update gains in place: ``g <- gamma*g + (1-gamma)*g0 + eta*H_drive``."""
        baseline_term = (1 - self.gamma) * self.g0
        drive_term = self.eta * H_drive
        with torch.no_grad():
            self.gains.mul_(self.gamma)
            self.gains.add_(baseline_term)
            self.gains.add_(drive_term)

    def forward(self, x, H_drive: float = 0.0):
        """Return the raw output scores; gains are updated first when training."""
        if self.training:
            self.decay(H_drive)
        projected = torch.matmul(x, self.weight.t())
        return projected * self.gains + self.bias

# --- Network ---
class GainModNet(nn.Module):
    """Two gain-modulated hidden layers followed by a gain-modulated output layer.

    Args:
        input_dim: Size of the flattened input.
        hidden_dim: Width of both hidden layers.
        output_dim: Number of output scores.
        g0, gamma, eta: Gain settings shared by all three layers.
    """

    def __init__(self, input_dim=28*28, hidden_dim=400, output_dim=10, g0=1.0, gamma=0.9, eta:float=0.0):
        super().__init__()
        gain_args = (g0, gamma, eta)
        self.layer1 = GainModLayer(input_dim, hidden_dim, *gain_args)
        self.layer2 = GainModLayer(hidden_dim, hidden_dim, *gain_args)
        self.out = GainModOut(hidden_dim, output_dim, *gain_args)

    def forward(self, x, H_drive: float = 0.0):
        """Flatten ``x`` to (batch, -1) and pass it, with ``H_drive``, through all layers."""
        hidden = x.view(x.size(0), -1)
        for hidden_layer in (self.layer1, self.layer2):
            hidden = hidden_layer(hidden, H_drive)
        return self.out(hidden, H_drive)

## Continual evaluation
def _evaluate(model, loader, criterion, device, H_drive=0.0):
    """Score ``model`` on one full pass over ``loader`` in eval mode.

    Returns:
        ``(accuracy_percent, mean_loss)``. The model is put back into the
        mode (train/eval) it was in when the function was called.
    """
    was_training = model.training
    model.eval()
    total, correct, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for x_t, y_t in loader:
            x_t, y_t = x_t.to(device), y_t.to(device)
            logits = model(x_t.view(x_t.size(0), -1), H_drive)
            loss_ev = criterion(logits, y_t)

            # Weight each batch loss by its size so a smaller final batch
            # does not skew the average (assumes criterion returns a batch mean).
            loss_sum += loss_ev.item() * y_t.size(0)
            correct  += (logits.argmax(1) == y_t).sum().item()
            total    += y_t.size(0)
    model.train(was_training)
    return 100.0 * correct / total, loss_sum / total


def continual_train(model, optimizer, criterion, contexts, batch, ctx_iter, rho_eval, train_ds, test_ds, device, mode, reset_optimizer=False):
    """Train ``model`` through a sequence of contexts, evaluating as it goes.

    Args:
        model: Network called as ``model(inputs, H_drive)``; must expose
            ``layer1``, ``layer2`` and ``out``, each with ``weight`` and ``gains``.
        optimizer: Optimizer over the model's parameters.
        criterion: Loss called as ``criterion(outputs, targets)``.
        contexts: One list of rotation angles per context; training data for
            a context is the concatenation of ``train_ds.rot(angle)`` for
            each of its angles.
        batch: Batch size for both the train and the test loader.
        ctx_iter: Number of training steps per context.
        rho_eval: Evaluate whenever the global step is a multiple of this.
        train_ds, test_ds: Datasets providing a ``rot(angle)`` method.
        device: Device that inputs and targets are moved to.
        mode: ``"gain-H"`` feeds the mean output entropy of each step into
            the next forward call as ``H_drive``; any other value leaves
            the drive at 0.0.
        reset_optimizer: If True, clear ``optimizer.state`` at the start of
            every context after the first.

    Returns:
        Dict of lists. ``acc_train``, ``loss_train``, ``w1``, ``w2``,
        ``w_out``, ``gain1``, ``gain2`` and ``gain_out`` get one entry per
        training step; ``acc_test`` and ``loss_test`` get one entry per
        evaluation. NOTE: ``acc_train`` is a fraction in [0, 1] while
        ``acc_test`` is a percentage in [0, 100].
    """
    # --- DATA STORAGE ---
    hist = {k: [] for k in [
        "acc_train", "loss_train",
        "acc_test",  "loss_test",
        "w1", "w2", "w_out",
        "gain1", "gain2", "gain_out"
    ]}
    step = 0
    H_prev = 0.0

    # --------------------- TEST LOADER ---------------------------
    # Evaluation is always on the unrotated (0 degree) test set, so build the
    # loader once. Rebuilding it per context and scoring every copy gave the
    # same averaged metrics at a multiple of the cost.
    # NOTE(review): creating fewer DataLoader iterators is expected to shift
    # the global torch RNG stream, so seeded runs need not match results
    # produced before this change bit-for-bit.
    test_loader = DataLoader(test_ds.rot(0), batch_size=batch, shuffle=False)

    # --- CONTEXT LOOP ---
    for task_id, ctx in enumerate(contexts):
        print(f"\nContext {task_id+1}/{len(contexts)}: {ctx}")

        # --- ORACLE RESET ---
        if reset_optimizer and task_id > 0:
            optimizer.state.clear()

        # --------------------- TRAIN LOADER --------------------------
        train_subsets = ConcatDataset([train_ds.rot(angle) for angle in ctx])
        # Sampling with replacement yields exactly ctx_iter full batches.
        sampler = RandomSampler(train_subsets, replacement=True, num_samples=ctx_iter * batch)
        train_loader = DataLoader(train_subsets, batch_size=batch, sampler=sampler)
        train_iter = iter(train_loader)

        # --------------------- TRAINING ------------------------------
        model.train()
        for _ in range(ctx_iter):

            inputs, targets = next(train_iter)
            inputs, targets = inputs.to(device), targets.to(device)

            # --- forward + backward + update ---
            optimizer.zero_grad()
            outputs = model(inputs.view(inputs.size(0), -1), H_prev)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Mean output entropy of this batch drives the gains on the next step.
            if mode == "gain-H":
                with torch.no_grad():
                    H_prev = Categorical(logits=outputs).entropy().mean().item()

            # --- train metrics (from the pre-update forward pass) ---
            hist["acc_train"].append((outputs.argmax(1) == targets).float().mean().item())
            hist["loss_train"].append(loss.item())
            # --- weights and gains stats ---
            hist["w1"].append(model.layer1.weight.detach().abs().mean().item())
            hist["w2"].append(model.layer2.weight.detach().abs().mean().item())
            hist["w_out"].append(model.out.weight.detach().abs().mean().item())
            hist["gain1"].append(model.layer1.gains.mean().item())
            hist["gain2"].append(model.layer2.gains.mean().item())
            hist["gain_out"].append(model.out.gains.mean().item())

            # ---------------- CONTINUAL EVAL --------------------------
            if step % rho_eval == 0:
                acc_test, loss_test = _evaluate(model, test_loader, criterion, device, H_prev)
                hist["acc_test"].append(acc_test)
                hist["loss_test"].append(loss_test)

            step += 1

    return hist

# --- Hyperparameters ---
# Module-level settings read as globals by run_all_models_nSims.
# Mini-batch size, used for both the train and the test DataLoader.
batch = 128 
# Base learning rate for the 'weight' parameters; each optimizer scales it.
lr_weights = 0.01
# Evaluate on the test set every `rho_eval` training steps (1 = every step).
rho_eval = 1
# Number of training steps spent in each context.
ctx_iter = 400 
# One entry per context: the rotation angles whose rotated copies of the
# training set are mixed together while that context is active.
contexts = [[0], [0,80], [0,80,160]]
# Loss used for both training and evaluation.
criterion = nn.CrossEntropyLoss()

## Simulation
# One entry per model, listed in the order the models are trained.
#   key             name under which the history is stored in the results
#   banner          line printed before the model is trained
#   gamma, eta      gain settings passed to GainModNet (g0 is always 1.0)
#   make_optimizer  builds the optimizer from the parameter groups
#   lr_scale        multiplier applied to the global `lr_weights`
#   mode            `mode` argument passed to continual_train
#   reset_optimizer whether optimizer state is cleared at context switches
_MODEL_SPECS = (
    dict(key="ENTROPY GAIN", banner="ENTROPY GAIN MODEL", gamma=0.9, eta=0.5,
         make_optimizer=lambda groups: optim.SGD(groups, momentum=0.0),
         lr_scale=1, mode="gain-H", reset_optimizer=False),
    dict(key="MSGD", banner="MOMENTUM SGD", gamma=0.0, eta=0.0,
         make_optimizer=lambda groups: optim.SGD(groups, momentum=0.9),
         lr_scale=10, mode="no-gain", reset_optimizer=False),
    dict(key="MSGD reset", banner="MOMENTUM SGD RESET", gamma=0.0, eta=0.0,
         make_optimizer=lambda groups: optim.SGD(groups, momentum=0.9),
         lr_scale=10, mode="no-gain", reset_optimizer=True),
    dict(key="ADAM", banner="VANILLA ADAM", gamma=0.0, eta=0.0,
         make_optimizer=lambda groups: optim.Adam(groups),
         lr_scale=0.1, mode="no-gain", reset_optimizer=False),
    dict(key="ADAM reset", banner="VANILLA ADAM RESET", gamma=0.0, eta=0.0,
         make_optimizer=lambda groups: optim.Adam(groups),
         lr_scale=0.1, mode="no-gain", reset_optimizer=True),
    dict(key="SGD", banner="SGD", gamma=0.0, eta=0.0,
         make_optimizer=lambda groups: optim.SGD(groups, momentum=0.0),
         lr_scale=10, mode="no-gain", reset_optimizer=False),
)

# Key order of each simulation's results dict (differs from training order).
_RESULT_KEY_ORDER = ("ENTROPY GAIN", "ADAM", "ADAM reset", "MSGD", "MSGD reset", "SGD")


def _weight_only_param_groups(model, weight_lr):
    """Return two optimizer groups: 'weight' params at ``weight_lr``, the rest at lr 0."""
    weights, others = [], []
    for name, param in model.named_parameters():
        (weights if 'weight' in name else others).append(param)
    return [
        {'params': weights, 'lr': weight_lr},
        {'params': others, 'lr': 0},
    ]


def _train_one_model(spec):
    """Build the model and optimizer described by ``spec`` and return its training history."""
    print(spec["banner"])
    model = GainModNet(input_dim=28*28, hidden_dim=400, output_dim=10,
                       g0=1.0, gamma=spec["gamma"], eta=spec["eta"]).to(device)
    param_groups = _weight_only_param_groups(model, spec["lr_scale"] * lr_weights)
    optimizer = spec["make_optimizer"](param_groups)
    return continual_train(model, optimizer, criterion,
                           contexts, batch, ctx_iter, rho_eval,
                           train_ds, test_ds, device,
                           mode=spec["mode"], reset_optimizer=spec["reset_optimizer"])


def run_all_models_nSims(nSims, base_seed=0):
    """Train every model configuration for ``nSims`` independently seeded simulations.

    Reads the module-level settings ``lr_weights``, ``contexts``, ``batch``,
    ``ctx_iter``, ``rho_eval``, ``criterion``, ``train_ds``, ``test_ds`` and
    ``device``.

    Args:
        nSims: Number of simulations to run.
        base_seed: Simulation ``i`` is seeded with ``base_seed + i``.

    Returns:
        A list with one dict per simulation:
        ``{"seed": seed, "results": {model_name: history}}``, where each
        history is the dict returned by ``continual_train``.
    """
    all_results = []

    for sim in range(nSims):
        seed = base_seed + sim
        # --- Seeds ---
        # Seeded once per simulation: the models below are trained one after
        # another on the same RNG stream, so they do not share initial
        # weights or batch order with each other.
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

        print(f"\n=== Sim {sim+1}/{nSims}, seed={seed} ===")

        # Train in spec order (this fixes how the RNG stream is consumed) ...
        histories = {spec["key"]: _train_one_model(spec) for spec in _MODEL_SPECS}

        # ... then store under the established key order.
        results = {key: histories[key] for key in _RESULT_KEY_ORDER}
        all_results.append({"seed": seed, "results": results})

    return all_results

## Run
# Guarded so that importing this module (e.g. to reuse the model classes)
# does not start the full set of simulations; running the file as a script
# behaves as before and still leaves `all_sims` / `fname` as module globals.
if __name__ == "__main__":
    all_sims = run_all_models_nSims(nSims=5, base_seed=42)

    # Save
    fname = "Rex4_rotatedMNIST_reset.pkl"
    with open(fname, "wb") as f:
        pickle.dump(all_sims, f)
    print(f"Done! Saved to {fname}")