# -*- coding: utf-8 -*-
"""
Domain CIFAR-100 Reset
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import relu, avg_pool2d
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, RandomSampler
from torch.distributions import Categorical
import numpy as np
import random
import pickle

# --------------------
# Device
# --------------------
# Prefer the GPU whenever PyTorch reports one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# --------------------
# Transforms
# --------------------
# Training pipeline: light augmentation (random 32x32 crop after 4-pixel
# padding, random horizontal flip), then tensor conversion and per-channel
# normalisation.
# NOTE(review): the mean/std below look like the commonly quoted CIFAR-10
# channel statistics rather than CIFAR-100's -- confirm this is intentional.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2470, 0.2435, 0.2616))
])

# Evaluation pipeline: no augmentation, same normalisation constants as training.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2470, 0.2435, 0.2616))
])

# --------------------
# CIFAR-100 datasets
# --------------------
# Both splits are stored under ./data and downloaded on first use
# (download=True). Each split gets its own transform pipeline.
train_ds = datasets.CIFAR100(root="./data", train=True, download=True,
                             transform=train_transform)
test_ds  = datasets.CIFAR100(root="./data", train=False, download=True,
                             transform=test_transform)

# Fine labels of every sample as tensors, in dataset order. They are used to
# select the samples belonging to a set of fine classes when building loaders.
train_fine = torch.tensor(train_ds.targets)
test_fine  = torch.tensor(test_ds.targets)

# --------------------
# Superclass taxonomy
# --------------------
# Coarse (superclass) id -> the five fine class ids that belong to it.
superclass_to_fines = {
    0: [4, 30, 55, 72, 95],   # aquatic mammals
    1: [1, 32, 67, 73, 91],   # fish
    2: [54, 62, 70, 82, 92],  # flowers
    3: [9, 10, 16, 28, 61],   # food containers
    4: [0, 51, 53, 57, 83],   # fruit and vegetables
    5: [22, 39, 40, 86, 87],  # household electrical devices
    6: [5, 20, 25, 84, 94],   # household furniture
    7: [6, 7, 14, 18, 24],    # insects
    8: [3, 42, 43, 88, 97],   # large carnivores
    9: [12, 17, 37, 68, 76],  # large man-made outdoor things
    10:[23, 33, 49, 60, 71],  # large natural outdoor scenes
    11:[15, 19, 21, 31, 38],  # large omnivores and herbivores
    12:[34, 63, 64, 66, 75],  # medium-sized mammals
    13:[26, 45, 77, 79, 99],  # non-insect invertebrates
    14:[2, 11, 35, 46, 98],   # people
    15:[27, 29, 44, 78, 93],  # reptiles
    16:[36, 50, 65, 74, 80],  # small mammals
    17:[47, 52, 56, 59, 96],  # trees
    18:[8, 13, 48, 58, 90],   # vehicles 1
    19:[41, 69, 81, 85, 89],  # vehicles 2
}

# fine->coarse lookup vector: fine_to_coarse_vec[fine_id] == coarse_id.
# Start from -1 instead of uninitialised memory so that a fine id missing from
# the table above is detected here rather than silently mapping to garbage.
fine_to_coarse_vec = torch.full((100,), -1, dtype=torch.long)
for coarse, fines in superclass_to_fines.items():
    for f in fines:
        fine_to_coarse_vec[f] = coarse

if (fine_to_coarse_vec < 0).any():
    missing = torch.nonzero(fine_to_coarse_vec < 0).flatten().tolist()
    raise ValueError(f"Fine classes without a superclass: {missing}")

# --------------------
# Task / context construction
# --------------------
def build_three_tasks_and_accumulated():
    """Split the taxonomy into three tasks and their cumulative unions.

    Task ``t`` (t = 0, 1, 2) takes the ``t``-th fine class listed for each of
    the 20 superclasses in the module-level ``superclass_to_fines``, so every
    task holds 20 fine ids, one per superclass. The cumulative context ``k``
    is the sorted union of tasks ``0..k``.

    Both collections are printed as they are built.

    Returns:
        tuple: ``(tasks_fine, contexts_cumulative)``, each a list of three
        sorted lists of fine class ids.
    """
    n_tasks, n_superclasses = 3, 20

    # One fine class per superclass for each task.
    tasks_fine = [
        sorted(superclass_to_fines[c][t] for c in range(n_superclasses))
        for t in range(n_tasks)
    ]
    for k, fins in enumerate(tasks_fine, 1):
        print(f"[Per-task] Context {k} fine classes (20 ids): {fins}")

    # Running union of everything seen up to and including each task.
    contexts_cumulative = []
    running_union = set()
    for task in tasks_fine:
        running_union |= set(task)
        contexts_cumulative.append(sorted(running_union))
    for k, fins in enumerate(contexts_cumulative, 1):
        print(f"[Accumulated] Context {k} fine classes ({len(fins)} ids): {fins}")

    return tasks_fine, contexts_cumulative

# tasks_fine: the three per-task lists of fine ids; contexts: their cumulative
# unions (context k contains every fine id from tasks 1..k).
tasks_fine, contexts = build_three_tasks_and_accumulated()

# --------------------
# Model
# --------------------
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride (default 1).
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)

class BasicBlock(nn.Module):
    """Two-convolution residual block (conv-BN-ReLU, conv-BN, add, ReLU).

    The skip path is the identity unless the stride or the channel count
    changes, in which case it is a 1x1 convolution followed by batch norm.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion

        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        # A projection is only needed when the residual and the skip input
        # would otherwise disagree in spatial size or channel count.
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        skip = self.shortcut(x)
        residual += skip
        return relu(residual)

class ResNet(nn.Module):
    """Four-stage residual feature extractor returning flattened features.

    The stem is a 3x3 convolution with ``nf`` channels; the four stages use
    ``nf``, ``2*nf``, ``4*nf`` and ``8*nf`` planes, the last three with
    stride 2.

    Args:
        block: Residual block class; must expose an ``expansion`` attribute
            and accept ``(in_planes, planes, stride)``.
        num_blocks: Number of blocks in each of the four stages.
        nf: Base number of feature maps.
        global_pooling: If true, apply a 4x4 average pool before flattening.
        input_size: Input shape; only its last three entries ``(C, H, W)``
            are used. Supported: (3,32,32), (3,84,84), (3,96,96).

    Attributes:
        feature_size: Length of the flattened feature vector produced by
            ``forward``. It is derived from ``nf`` and ``block.expansion``,
            so it stays correct for widths other than the default 20.

    Raises:
        ValueError: If ``input_size`` is not one of the supported shapes.
    """

    # Side length of the last stage's feature map for each supported input
    # (three stride-2 stages: 32 -> 4, 84 -> 11, 96 -> 12).
    _FINAL_SIDE = {(3, 32, 32): 4, (3, 84, 84): 11, (3, 96, 96): 12}
    # Kernel (and stride) of the optional average pool applied in forward().
    _POOL_KERNEL = 4

    def __init__(self, block, num_blocks, nf, global_pooling, input_size):
        super().__init__()
        self.global_pooling = global_pooling
        self.in_planes = nf
        self.conv1 = conv3x3(3, nf)
        self.bn1   = nn.BatchNorm2d(nf)
        self.layer1 = self._make_layer(block, nf,     num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2)

        input_size = tuple(input_size[-3:])
        if input_size not in self._FINAL_SIDE:
            raise ValueError(f"Input size err: {input_size}")

        side = self._FINAL_SIDE[input_size]
        if global_pooling:
            # avg_pool2d with kernel 4 (stride defaults to the kernel) floors.
            side //= self._POOL_KERNEL
        out_channels = nf * 8 * block.expansion
        # For nf=20 and expansion=1 this gives 160/2560, 640/19360 and
        # 1440/23040 for the three supported inputs.
        self.feature_size = out_channels * side * side

    def _make_layer(self, block, planes, n_blocks, stride):
        """Build one stage; only its first block uses ``stride``."""
        strides = [stride] + [1]*(n_blocks-1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            # The next block consumes what this one produced.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return features flattened to ``(batch, -1)``."""
        out = relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        if self.global_pooling:
            out = avg_pool2d(out, self._POOL_KERNEL)
        return out.view(out.size(0), -1)

def ResNet18feat(input_size, nf=20, global_pooling=True):
    """Build the feature extractor with two ``BasicBlock``s in each of its four stages.

    Args:
        input_size: Input shape forwarded to ``ResNet``.
        nf: Base number of feature maps.
        global_pooling: Whether ``ResNet`` average-pools before flattening.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage, nf, global_pooling, input_size)

class GainModLinear(nn.Module):
    """Linear layer whose output rows are rescaled by a non-learned gain vector.

    The effective weight is ``weight * gains[:, None]``. The gains are stored
    as a parameter with ``requires_grad=False`` and are updated in place by
    ``decay`` on every training-mode forward pass:

        gains <- gamma * gains + (1 - gamma) * g0 + eta * H_drive

    Args:
        in_dim: Input feature dimension.
        out_dim: Number of output units (one gain per unit).
        g0: Initial gain and the value the decay relaxes towards.
        gamma: Retention factor of the previous gain.
        eta: Scale applied to the external drive ``H_drive``.
    """

    def __init__(self, in_dim, out_dim, g0=1.0, gamma=0.9, eta=0.0):
        super().__init__()
        initial_weight = torch.randn(out_dim, in_dim) * 0.05
        self.weight = nn.Parameter(initial_weight)
        self.bias = nn.Parameter(torch.zeros(out_dim))
        initial_gains = torch.ones(out_dim) * g0
        self.gains = nn.Parameter(initial_gains, requires_grad=False)
        self.g0 = g0
        self.gamma = gamma
        self.eta = eta

    def decay(self, H_drive: float = 0.0):
        """Apply one in-place gain update driven by ``H_drive``."""
        with torch.no_grad():
            gains = self.gains
            gains.mul_(self.gamma)
            gains.add_((1 - self.gamma) * self.g0)
            gains.add_(self.eta * H_drive)

    def forward(self, x, H_drive: float = 0.0):
        """Return ``x @ (weight * gains[:, None]).T + bias``.

        In training mode the gains are updated first, so the output already
        reflects the new gains; in eval mode they are left untouched.
        """
        if self.training:
            self.decay(H_drive)
        row_gains = self.gains.view(-1, 1)
        return F.linear(x, self.weight * row_gains, self.bias)

class GainSlimResNet18(nn.Module):
    """Slim ResNet-18 feature extractor followed by a gain-modulated readout.

    Args:
        input_size: Input shape passed to the backbone.
        nf: Base number of feature maps of the backbone.
        output_dim: Number of output logits.
        g0, gamma, eta: Gain dynamics parameters forwarded to ``GainModLinear``.
    """

    def __init__(self, input_size=(3,32,32), nf=20, output_dim=20, g0=1.0, gamma=0.9, eta=0.0):
        super().__init__()
        # The backbone always uses global pooling here; the readout width
        # follows whatever feature size the backbone reports.
        backbone = ResNet18feat(input_size, nf, global_pooling=True)
        self.backbone = backbone
        self.fc = GainModLinear(backbone.feature_size, output_dim,
                                g0=g0, gamma=gamma, eta=eta)

    def forward(self, x, H_prev: float = 0.0):
        """Return logits; ``H_prev`` is the drive handed to the readout's gain update."""
        return self.fc(self.backbone(x), H_prev)

# --------------------
# Data loader
# --------------------
def make_loader_for_task(ds, fine_labels_tensor, task_fines, batch, ctx_iter=None, train=False):
    """Build a DataLoader over the samples whose fine label is in ``task_fines``.

    Args:
        ds: Indexable dataset; ``ds[i]`` must correspond to
            ``fine_labels_tensor[i]``.
        fine_labels_tensor: 1-D tensor with the fine label of every sample.
        task_fines: Iterable of fine class ids to keep.
        batch: Batch size.
        ctx_iter: Number of batches the training loader must yield. Required
            when ``train`` is true, ignored otherwise.
        train: If true, sample with replacement so that exactly
            ``ctx_iter`` batches of size ``batch`` are produced; otherwise
            iterate the subset once, in order.

    Returns:
        torch.utils.data.DataLoader over the selected subset.

    Raises:
        ValueError: If ``train`` is true and ``ctx_iter`` is None.
    """
    if train and ctx_iter is None:
        raise ValueError("ctx_iter is required when train=True")

    # Match the label dtype so an empty class list does not become a float tensor.
    wanted = torch.as_tensor(task_fines, dtype=fine_labels_tensor.dtype)
    mask = torch.isin(fine_labels_tensor, wanted)
    # flatten() (not squeeze()) keeps a 1-D result even for a single match,
    # and plain Python ints are the safest index type for arbitrary datasets.
    idx = torch.nonzero(mask, as_tuple=False).flatten().tolist()
    subset = Subset(ds, idx)

    if train:
        sampler = RandomSampler(subset, replacement=True, num_samples=ctx_iter * batch)
        return DataLoader(subset, batch_size=batch, sampler=sampler)
    return DataLoader(subset, batch_size=batch, shuffle=False)

# --------------------
# Continual training with periodic evaluation
# --------------------
def continual_train(model, optimizer, criterion,
                    contexts_fine, batch, ctx_iter, rho_eval,
                    train_ds, test_ds, device, mode, reset_optimizer=False):
    """Train ``model`` through a sequence of contexts and log metrics.

    For every entry of ``contexts_fine`` a training loader yielding
    ``ctx_iter`` batches is built and the model takes one optimizer step per
    batch. Targets are the coarse labels obtained from the fine labels via
    the module-level ``fine_to_coarse_vec``. Whenever the global step counter
    is a multiple of ``rho_eval`` the model is evaluated on the test samples
    of the first task (module-level ``tasks_fine[0]``).

    Args:
        model: Network called as ``model(inputs, H_prev)`` and exposing
            ``fc.weight`` and ``fc.gains``.
        optimizer: Optimizer over the model's trainable parameters.
        criterion: Loss taking ``(logits, coarse_targets)``.
        contexts_fine: Sequence of fine-id collections, one per context.
        batch: Batch size for the training and evaluation loaders.
        ctx_iter: Number of training steps per context.
        rho_eval: Evaluate every ``rho_eval`` steps.
        train_ds, test_ds: Datasets handed to ``make_loader_for_task``
            together with the module-level ``train_fine`` / ``test_fine``.
        device: Device the batches are moved to.
        mode: When equal to ``"gain-H"`` the mean predictive entropy of each
            training batch becomes the drive ``H_prev`` of the next forward
            pass; for any other value ``H_prev`` stays at 0.0.
        reset_optimizer: If true, clear the optimizer state at the start of
            every context except the first.

    Returns:
        dict of lists. ``acc_train`` (fraction in [0, 1]), ``loss_train``,
        ``w_out`` (mean absolute readout weight) and ``gain_out`` (mean gain)
        get one entry per training step; ``acc_test`` (percent) and
        ``loss_test`` get one entry per evaluation.
    """
    metric_names = ("acc_train", "loss_train",
                    "acc_test", "loss_test",
                    "w_out", "gain_out")
    hist = {name: [] for name in metric_names}

    # Evaluation always targets the first task, whatever context is training.
    eval_loader_task1 = make_loader_for_task(
        test_ds, test_fine, tasks_fine[0], batch, train=False)

    def evaluate_task1(drive):
        """Append sample-weighted test accuracy/loss on task 1 to ``hist``."""
        model.eval()
        n_seen, n_correct, loss_total = 0, 0, 0.0
        with torch.no_grad():
            for x_t, y_fine_t in eval_loader_task1:
                x_t = x_t.to(device)
                y_coarse_t = fine_to_coarse_vec[y_fine_t].to(device)
                logits = model(x_t, drive)
                n_items = y_coarse_t.size(0)
                loss_total += criterion(logits, y_coarse_t).item() * n_items
                n_correct += (logits.argmax(1) == y_coarse_t).sum().item()
                n_seen += n_items
        hist["acc_test"].append(100.0 * n_correct / n_seen)
        hist["loss_test"].append(loss_total / n_seen)
        model.train()

    step = 0
    H_prev = 0.0

    for task_id, task_fines_union in enumerate(contexts_fine):
        print(f"\n[Train] Context {task_id+1}/{len(contexts_fine)} with {len(task_fines_union)} fine ids.")

        # Oracle reset: drop accumulated optimizer state at a context switch.
        is_context_switch = task_id > 0
        if reset_optimizer and is_context_switch:
            optimizer.state.clear()

        train_loader = make_loader_for_task(
            train_ds, train_fine, task_fines_union, batch, ctx_iter, train=True)
        batches = iter(train_loader)

        model.train()
        for _ in range(ctx_iter):
            inputs, y_fine = next(batches)
            inputs = inputs.to(device)
            y_coarse = fine_to_coarse_vec[y_fine].to(device)

            optimizer.zero_grad()
            outputs = model(inputs, H_prev)
            loss = criterion(outputs, y_coarse)
            loss.backward()
            optimizer.step()

            # The entropy of this batch's predictions drives the next step.
            if mode == "gain-H":
                with torch.no_grad():
                    H_prev = Categorical(logits=outputs).entropy().mean().item()

            # Per-step training metrics, taken after the optimizer step.
            with torch.no_grad():
                hits = outputs.argmax(1) == y_coarse
                hist["acc_train"].append(hits.float().mean().item())
                hist["loss_train"].append(loss.item())
                hist["w_out"].append(model.fc.weight.data.abs().mean().item())
                hist["gain_out"].append(model.fc.gains.mean().item())

            if step % rho_eval == 0:
                evaluate_task1(H_prev)

            step += 1

    return hist

# --------------------
# Hyperparameters
# --------------------
# Mini-batch size used for both the training and the evaluation loaders.
batch      = 256
# Base learning rate; the Adam runs use 0.1x and the plain SGD run 10x this value.
lr_weights = 0.01
# Evaluate every rho_eval training steps (1 = after every step).
rho_eval   = 1
# Number of training steps spent in each context.
ctx_iter   = 800
# Loss applied to the coarse-label logits.
criterion  = nn.CrossEntropyLoss()

# --------------------
# Run all methods
# --------------------
def _baseline_param_groups(model, lr):
    """One param group with every parameter except the gains and ``fc.bias``."""
    params = [p for n, p in model.named_parameters()
              if 'gains' not in n and n != 'fc.bias']
    return [{'params': params, 'lr': lr}]


def _entropy_gain_param_groups(model):
    """Backbone with momentum 0.9 and ``fc.weight`` without momentum."""
    backbone = [p for n, p in model.named_parameters()
                if 'gains' not in n and not n.startswith('fc.')]
    readout = [p for n, p in model.named_parameters() if n == 'fc.weight']
    return [
        {'params': backbone, 'lr': lr_weights, 'momentum': 0.9},
        {'params': readout, 'lr': lr_weights, 'momentum': 0.0},
    ]


def _train_one_method(banner, gamma, eta, make_optimizer, mode, reset_optimizer):
    """Print the banner, build a fresh model and optimizer, train, return the history."""
    print(banner)
    model = GainSlimResNet18(output_dim=20, nf=20, g0=1.0, gamma=gamma, eta=eta).to(device)
    optimizer = make_optimizer(model)
    return continual_train(
        model, optimizer, criterion,
        contexts_fine=contexts, batch=batch, ctx_iter=ctx_iter,
        rho_eval=rho_eval, train_ds=train_ds, test_ds=test_ds,
        device=device, mode=mode, reset_optimizer=reset_optimizer
    )


def run_all_models_nSims(nSims, base_seed=0):
    """Run every training method ``nSims`` times with consecutive seeds.

    Simulation ``sim`` seeds torch, numpy and ``random`` (and CUDA when
    available) with ``base_seed + sim`` once, then trains the methods one
    after another in a fixed order, so they share a single random stream
    within a simulation. Each method gets a freshly constructed model.

    Args:
        nSims: Number of simulations.
        base_seed: Seed of the first simulation.

    Returns:
        list of dicts ``{"seed": int, "results": {method_name: history}}``,
        one per simulation, where ``history`` is the dict returned by
        ``continual_train``.
    """
    # (banner, result key, gamma, eta, optimizer factory, mode, reset flag).
    # The order of this list is the training order and must stay fixed:
    # it determines how the shared random stream is consumed.
    method_specs = [
        ("ENTROPY GAIN", "ENTROPY GAIN", 0.9, 0.1,
         lambda m: optim.SGD(_entropy_gain_param_groups(m)),
         "gain-H", False),
        ("MOMENTUM SGD", "MSGD", 0.0, 0.0,
         lambda m: optim.SGD(_baseline_param_groups(m, lr_weights), momentum=0.9),
         "no-gain", False),
        ("MOMENTUM SGD RESET", "MSGD reset", 0.0, 0.0,
         lambda m: optim.SGD(_baseline_param_groups(m, lr_weights), momentum=0.9),
         "no-gain", True),
        ("VANILLA ADAM", "ADAM", 0.0, 0.0,
         lambda m: optim.Adam(_baseline_param_groups(m, 0.1 * lr_weights)),
         "no-gain", False),
        ("VANILLA ADAM RESET", "ADAM reset", 0.0, 0.0,
         lambda m: optim.Adam(_baseline_param_groups(m, 0.1 * lr_weights)),
         "no-gain", True),
        ("SGD", "SGD", 0.0, 0.0,
         lambda m: optim.SGD(_baseline_param_groups(m, 10 * lr_weights), momentum=0.0),
         "no-gain", False),
    ]
    # Key order of the stored results dict (differs from the training order).
    result_key_order = ("ENTROPY GAIN", "ADAM", "ADAM reset",
                        "MSGD", "MSGD reset", "SGD")

    all_results = []

    for sim in range(nSims):
        seed = base_seed + sim
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

        print(f"\n=== Simulation {sim+1}/{nSims}, seed={seed} ===")

        histories = {}
        for banner, key, gamma, eta, make_optimizer, mode, reset in method_specs:
            histories[key] = _train_one_method(
                banner, gamma, eta, make_optimizer, mode, reset)

        results = {key: histories[key] for key in result_key_order}
        all_results.append({"seed": seed, "results": results})

    return all_results

# --------------------
# Run
# --------------------
# Five simulations; simulation i is seeded with 42 + i.
all_sims = run_all_models_nSims(nSims=5, base_seed=42)

# Save the list of per-simulation result dicts as a single pickle file.
out_pkl = "Rex5_domainCIFAR100_reset2.pkl"
with open(out_pkl, "wb") as f:
    pickle.dump(all_sims, f)
print(f"Done! Saved to {out_pkl}")