import csv
import itertools
import os
import time
from collections import defaultdict
import copy

import pathlib
from pathlib import Path
from datetime import date
import hydra
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.distributions import Categorical
import torch.nn.functional as F
import torch.optim as optim
from robustbench.model_zoo.enums import ThreatModel
from robustbench.utils import load_model
import wandb
import pandas as pd
from PIL import Image
from timm.utils import NativeScaler
from timm.optim import create_optimizer

import utils
from dataset import get_loaders
from engine import train_one_epoch, evaluate, layer_select
from EMSA import EMSA_optimizer
import random

# Ask PyTorch's CUDA caching allocator to use expandable segments (per the PyTorch
# CUDA memory-management notes this can reduce fragmentation). Set at import time,
# before any CUDA memory has been allocated by this script.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")
@torch.no_grad()
def test(model, loader, criterion, cfg):
    model.eval()
    all_test_corrects = []
    total_loss = 0.0
    for x, y in loader:
        x, y = x.to(cfg.args.device), y.to(cfg.args.device)
        logits = model(x)
        # for name, layer in model.named_children():
        #     if name == 'fc':
        #         x = torch.nn.functional.avg_pool2d(x, 8)
        #         x = x.view(-1, 640)
        #     x = layer(x)

        #     print(f"After {name}: {x.shape}")
        # logits = x
        loss = criterion(logits, y)
        all_test_corrects.append(torch.argmax(logits, dim=-1) == y)
        total_loss += loss
    acc = torch.cat(all_test_corrects).float().mean().detach().item()
    total_loss = total_loss / len(loader)
    total_loss = total_loss.detach().item()
    return acc, total_loss

def _grad_metrics(named_params, grads, auto_tune):
    """Reduce each non-BN parameter's gradient to one scalar score (see get_lr_weights)."""
    metrics = {}
    for (name, param), grad in zip(named_params, grads):
        if "bn" in name:
            continue
        if auto_tune == 'eb-criterion':
            score = (grad * grad) / (torch.var(grad, dim=0, keepdim=True) + 1e-8)
            metrics[name] = score.mean().item()
        else:
            metrics[name] = torch.norm(grad).item() / torch.norm(param).item()
    return metrics


def get_lr_weights(model, loader, cfg):
    """Score every parameter whose name does not contain ``"bn"`` by its gradient statistics.

    Uses at most the first 5 batches of ``loader``. For each batch the gradient of
    ``F.cross_entropy(model(x), y)`` w.r.t. all parameters is computed and reduced
    to one scalar per parameter:

    * ``cfg.args.auto_tune == 'eb-criterion'``: mean of ``g**2 / (var(g, dim=0) + 1e-8)``;
    * otherwise: relative gradient norm ``||g|| / ||p||``.

    Per-batch scores are then averaged.

    NOTE(review): the forward passes run in the model's current mode; if it is in
    train mode, BatchNorm running statistics are updated by these batches.

    Args:
        model: module to score.
        loader: iterable of ``(x, y)`` batches.
        cfg: config; reads ``cfg.args.device`` and ``cfg.args.auto_tune``.

    Returns:
        ``defaultdict(float)`` mapping parameter name -> averaged score, in
        ``model.named_parameters()`` order.
    """
    named_params = list(model.named_parameters())
    params = [p for _, p in named_params]
    per_batch = defaultdict(list)
    for x, y in itertools.islice(loader, 5):
        x, y = x.to(cfg.args.device), y.to(cfg.args.device)
        loss_xent = F.cross_entropy(model(x), y)
        # One grad call per forward pass, so the graph need not be retained; reducing
        # immediately avoids holding five full copies of the model's gradients.
        grads = torch.autograd.grad(outputs=loss_xent, inputs=params)
        for name, value in _grad_metrics(named_params, grads, cfg.args.auto_tune).items():
            per_batch[name].append(value)

    average_metrics = defaultdict(float)
    for name, values in per_batch.items():
        average_metrics[name] = np.array(values).mean(0)
    return average_metrics

def train(model, loader, criterion, opt, cfg, orig_model=None):
    """Run one optimisation pass over ``loader``.

    Each batch is forwarded, its loss back-propagated and ``opt`` stepped. The
    model's train/eval mode is not changed here; the caller sets it.

    Args:
        model: module being trained.
        loader: iterable of ``(x, y)`` batches.
        criterion: callable ``criterion(logits, y)`` returning a scalar loss tensor.
        opt: optimizer over (a subset of) ``model``'s parameters.
        cfg: config object; only ``cfg.args.device`` is read.
        orig_model: unused; accepted for call-site compatibility.

    Returns:
        ``(acc, loss, magnitudes)``: training accuracy over all samples seen, the
        mean of the per-batch losses (measured before each step), and an empty
        ``defaultdict(float)`` placeholder that this function never fills.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    device = cfg.args.device
    correct_masks = []
    loss_sum = 0.0
    num_batches = 0
    magnitudes = defaultdict(float)  # never populated; kept for the return signature

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = criterion(logits, y)
        correct_masks.append(torch.argmax(logits, dim=-1) == y)
        # Detach before accumulating: adding the live ``loss`` would chain every
        # batch's autograd history into ``loss_sum`` for the whole epoch.
        loss_sum += loss.detach()
        num_batches += 1

        opt.zero_grad()
        loss.backward()
        opt.step()

    if num_batches == 0:
        raise ValueError("train() received an empty loader; nothing to train on")
    acc = torch.cat(correct_masks).float().mean().item()
    mean_loss = (loss_sum / num_batches).item()
    return acc, mean_loss, magnitudes

def tune_params_dict_func(cfg, model):
    """Map each tuning-option name to the parameter groups it re-trains.

    Args:
        cfg: config; reads ``cfg.data.dataset_name``.
        model: for ``"cifar10"``, a module with ``conv1``, ``block1``..``block3`` and
            ``fc``; for ``"imagenet-c"`` / ``"living17"``, a wrapper whose ``.model``
            has ``conv1``, ``layer1``..``layer4`` and ``fc``.

    Returns:
        dict of option name -> list of ``.parameters()`` generators. The generators
        are single-use, so call this again before selecting a different option.

    Raises:
        ValueError: for any other dataset name.
    """
    dataset_name = cfg.data.dataset_name
    if dataset_name == "cifar10":
        return {
            "all": [model.parameters()],
            "first_two_block": [
                model.conv1.parameters(),
                model.block1.parameters(),
            ],
            "second_block": [
                model.block2.parameters(),
            ],
            "third_block": [
                model.block3.parameters(),
            ],
            "last": [model.fc.parameters()],
        }
    # The original test `== "imagenet-c" or 'living17'` was always true, so unknown
    # datasets silently fell into the ImageNet layout; match names explicitly instead.
    if dataset_name in ("imagenet-c", "living17"):
        net = model.model
        return {
            "all": [net.parameters()],
            "first_second": [
                net.conv1.parameters(),
                net.layer1.parameters(),
                net.layer2.parameters(),
            ],
            "first_two_block": [
                net.conv1.parameters(),
                net.layer1.parameters(),
            ],
            "second_block": [
                net.layer2.parameters(),
            ],
            "third_block": [
                net.layer3.parameters(),
            ],
            "fourth_block": [
                net.layer4.parameters(),
            ],
            "last": [net.fc.parameters()],
        }
    raise ValueError(f"No tune-option layout defined for dataset {dataset_name!r}")

def _append_csv_row(csv_file_name, data, write_header, max_attempts=60, retry_delay=5):
    """Append ``data`` as one row of ``csv_file_name``, retrying on I/O errors.

    Only ``OSError`` is retried (presumably to ride out transient filesystem
    errors), at most ``max_attempts`` times ``retry_delay`` seconds apart; the
    last error is then re-raised instead of looping forever.
    """
    for attempt in range(1, max_attempts + 1):
        try:
            # newline="" is what the csv module expects for files it writes to.
            with open(csv_file_name, "a", newline="") as f:
                csv_writer = csv.DictWriter(f, fieldnames=data.keys(), restval=0.0)
                if write_header:
                    csv_writer.writeheader()
                csv_writer.writerow(data)
            return
        except OSError:
            if attempt == max_attempts:
                raise
            time.sleep(retry_delay)


# @hydra.main(version_base = None, config_path="config", config_name="cifar-10c")
@hydra.main(version_base = None, config_path="config", config_name="imagenet-c")
def main(cfg):
    """Fine-tune a RobustBench corruption model per corruption type and log results.

    For each corruption type, and each selected tune option, the pretrained model
    is reloaded. It is trained for ``cfg.args.epochs`` epochs with the configured
    strategy: a fixed block, auto-tuned per-layer learning rates (RGN / eb-criterion),
    the LIFT / LISA layer schedules, or the EMSA optimizer. Living17 and flipped
    CIFAR-10 use a single ``None`` corruption type. After every epoch the model is
    evaluated on the val and test splits. The test accuracy at the best-validation
    epoch is appended to ``<log_dir>/results_seed<seed>.csv``.
    """
    cfg.args.log_dir = os.path.join(
        pathlib.Path.cwd(), "results", cfg.data.dataset_name, date.today().strftime("%Y.%m.%d"), cfg.args.auto_tune
    )
    device = torch.device(cfg.args.device)
    print(f"Log dir: {cfg.args.log_dir}")
    os.makedirs(cfg.args.log_dir, exist_ok=True)

    # Blocks that can be re-trained; the ImageNet-style ResNet has a fourth block.
    if cfg.data.dataset_name == "imagenet-c" or cfg.data.dataset_name == "living17":
        tune_options = [
            "first_two_block",
            "second_block",
            "third_block",
            "fourth_block",
            "last",
            "all",
        ]
    else:
        tune_options = [
            "first_two_block",
            "second_block",
            "third_block",
            "last",
            "all",
        ]
    # Every option except "all": the per-block schedule walked by the LIFT strategies.
    tune_options_orig = tune_options[:-1]

    # Auto-tuning strategies and zero-epoch runs only use the "all" option.
    if cfg.args.auto_tune != 'none':
        tune_options = ["all"]
    if cfg.args.epochs == 0:
        tune_options = ['all']
    if cfg.data.dataset_name == "living17" or (cfg.data.dataset_name == "cifar10" and cfg.data.flip == True):
        corruption_types = [None]
    else:
        corruption_types = cfg.data.corruption_types
    for corruption_type in corruption_types:
        cfg.wandb.exp_name = f"{cfg.data.dataset_name}_{cfg.args.opt}_{cfg.args.auto_tune}"
        if cfg.wandb.use:
            utils.setup_wandb(cfg)
        utils.set_seed_everywhere(cfg.args.seed)
        loaders = get_loaders(cfg, corruption_type, cfg.data.severity)

        for num, tune_option in enumerate(tune_options):
            # Optionally restrict the sweep to the tune option at index cfg.args.layerwise.
            if cfg.args.layerwise != 'None' and  num != cfg.args.layerwise:
                continue
            tune_metrics = defaultdict(list)
            lr_wd_grid = [
                # (1e-1, 1e-4),
                # (1e-2, 1e-4),
                # (1e-3, 1e-4),
                (1e-4, 1e-4),
                # (1e-5, 1e-4),
            ]
            for lr, wd in lr_wd_grid:
                dataset_name = (
                    "imagenet"
                    if cfg.data.dataset_name == "imagenet-c" or cfg.data.dataset_name == "living17"
                    else cfg.data.dataset_name
                )
                # Reload the pretrained weights so every (lr, wd) run starts from the same model.
                model = load_model(
                    cfg.data.model_name,
                    # cfg.user.ckpt_dir,
                    './',
                    dataset_name,
                    ThreatModel.corruptions,
                )
                if cfg.data.dataset_name == "living17":
                    num_classes = 17
                    model.model.fc = torch.nn.Linear(model.model.fc.in_features, num_classes)
                    model.load_state_dict(torch.load('resnet50_living17_LP.pth'))
                model = model.to(device)

                orig_model = copy.deepcopy(model)
                tune_params_dict = tune_params_dict_func(cfg, model)
                params_list = list(itertools.chain(*tune_params_dict[tune_option]))
                N = sum(p.numel() for p in params_list if p.requires_grad)
                # Running per-layer sum of auto-tune weights (non-BN parameters only).
                layer_weights = [0 for layer, _ in model.named_parameters() if 'bn' not in layer]
                opt = optim.Adam(params_list, lr=lr, weight_decay=wd)

                if cfg.args.opt == 'EMSA':
                    # Split the network into sequential blocks for EMSA. block_id[num] is
                    # passed to train_one_epoch; presumably it selects which block of
                    # all_blocks to train -- TODO confirm against engine.train_one_epoch.
                    if dataset_name == "cifar10":
                        first_two_block = torch.nn.Sequential(model.conv1, model.block1)
                        all_blocks = [first_two_block, model.block2, model.block3, model.bn1, model.relu, model.fc]
                        all_blocks = torch.nn.Sequential(*all_blocks)
                        block_id = [0,1,2,5,-1]
                    elif dataset_name == "imagenet":
                        first_two_block = torch.nn.Sequential(model.model.conv1, model.model.bn1, model.model.relu, model.model.maxpool, model.model.layer1)
                        all_blocks = [model.normalize, first_two_block, model.model.layer2, model.model.layer3, model.model.layer4, model.model.avgpool, model.model.fc]
                        all_blocks = torch.nn.Sequential(*all_blocks)
                        block_id = [1,2,3,4,6,-1]
                    emsa_optimizer = EMSA_optimizer(args=cfg.args, all_blocks=all_blocks)
                    optimizer = None

                print(
                    f"\nTrain mode={cfg.args.train_mode}, using {cfg.args.train_n} corrupted images for training"
                )
                print(
                    f"Re-training {tune_option} ({N} params). lr={lr}, wd={wd}. Corruption {corruption_type}"
                )

                criterion = F.cross_entropy

                best_val_acc = 0.0
                for epoch in range(1, cfg.args.epochs + 1):
                    if cfg.args.train_mode == "train":
                        model.train()
                    # Non-EMSA auto-tuning rebuilds the optimizer at the start of every epoch.
                    if cfg.args.auto_tune != 'none' and cfg.args.opt != 'EMSA':
                        if cfg.args.auto_tune == 'RGN':
                            # Relative gradient norms, scaled so the largest layer gets the full lr.
                            weights = get_lr_weights(model, loaders["train"], cfg)
                            max_weight = max(weights.values())
                            for k, v in weights.items():
                                weights[k] = v / max_weight
                            layer_weights = [sum(x) for x in zip(layer_weights, weights.values())]
                            tune_metrics['layer_weights'] = layer_weights
                            params = {n: p for n, p in model.named_parameters() if "bn" not in n}
                            params_weights = []
                            for param, weight in weights.items():
                                params_weights.append({"params": params[param], "lr": weight*lr})
                            opt = optim.Adam(params_weights, lr=lr, weight_decay=wd)
                        elif cfg.args.auto_tune == 'eb-criterion':
                            # Go by individual layers: a layer is trained (lr) or frozen (lr 0).
                            weights = get_lr_weights(model, loaders["train"], cfg)
                            print(f"Epoch {epoch}, autotuning weights {min(weights.values()), max(weights.values())}")
                            tune_metrics['max_weight'].append(max(weights.values()))
                            tune_metrics['min_weight'].append(min(weights.values()))
                            print(weights.values())
                            for k, v in weights.items():
                                weights[k] = 0.0 if v < 0.95 else 1.0
                            print("weight values", weights.values())
                            layer_weights = [sum(x) for x in zip(layer_weights, weights.values())]
                            tune_metrics['layer_weights'] = layer_weights
                            params = {n: p for n, p in model.named_parameters() if "bn" not in n}
                            params_weights = []
                            for k, v in params.items():
                                if k in weights.keys():
                                    params_weights.append({"params": params[k], "lr": weights[k]*lr})
                                else:
                                    params_weights.append({"params": params[k], "lr": 0.0})
                            opt = optim.Adam(params_weights, lr=lr, weight_decay=wd)
                        elif cfg.args.auto_tune in ['LIFT1' ,  'LIFT2', 'LIFT3']:
                            # LIFT1 walks the blocks first-to-last over the epochs,
                            # LIFT2 last-to-first, LIFT3 picks a random block each epoch.
                            layer_id = int((epoch-1) /cfg.args.epochs* len(tune_options_orig))
                            if cfg.args.auto_tune == 'LIFT2':
                                layer_id = len(tune_options_orig) - 1-layer_id
                            elif cfg.args.auto_tune == 'LIFT3':
                                layer_id = random.randint(0, len(tune_options_orig)-1)
                            tune_option = tune_options_orig[layer_id]
                            tune_params_dict = tune_params_dict_func(cfg, model)
                            params_list = list(itertools.chain(*tune_params_dict[tune_option]))
                            opt = optim.Adam(params_list, lr=lr, weight_decay=wd)
                        elif cfg.args.auto_tune == 'LISA':
                            # Always train the stem and head; each middle block joins with prob. 0.3.
                            trainable_params = []
                            train_probability = 0.3
                            if cfg.args.model == 'resnet50':
                                trainable_params += list(model.model.conv1.parameters())
                                trainable_params += list(model.model.layer1.parameters())
                                trainable_params += list(model.model.fc.parameters())
                                for layer in [model.model.layer2, model.model.layer3, model.model.layer4]:
                                    if random.random() < train_probability:
                                        trainable_params += list(layer.parameters())
                            elif cfg.args.model == 'resnet26':
                                trainable_params += list(model.conv1.parameters())
                                trainable_params += list(model.block1.parameters())
                                trainable_params += list(model.fc.parameters())
                                for layer in [model.block2, model.block3]:
                                    if random.random() < train_probability:
                                        trainable_params += list(layer.parameters())
                            opt = optim.Adam(trainable_params, lr=lr, weight_decay=wd)
                        else:
                            # The old fallback read `params_weights`, which no path reaching
                            # this branch defines, so it always died with a NameError.
                            raise ValueError(f"Unsupported auto_tune option: {cfg.args.auto_tune!r}")

                    # For layer selection
                    # layer_select( model, criterion, loaders["train"], device, 0, optimizer, emsa_optimizer,wandb=wandb, LMSA = True)

                    if cfg.args.opt == 'EMSA':
                        acc_tr, loss_tr, grad_magnitudes = train_one_epoch(model, None, criterion, loaders["train"], device, epoch, None, optimizer, emsa_optimizer, block_id[num], wandb=False)
                    else:
                        acc_tr, loss_tr, grad_magnitudes = train(model, loaders["train"], criterion, opt, cfg, orig_model=orig_model)
                    acc_te, loss_te = test(model, loaders["test"], criterion, cfg)
                    acc_val, loss_val = test(model, loaders["val"], criterion, cfg)
                    tune_metrics["acc_train"].append(acc_tr)
                    tune_metrics["acc_val"].append(acc_val)
                    tune_metrics["acc_te"].append(acc_te)
                    log_dict = {
                        f"{tune_option}/train/acc": acc_tr,
                        f"{tune_option}/train/loss": loss_tr,
                        f"{tune_option}/val/acc": acc_val,
                        f"{tune_option}/val/loss": loss_val,
                        f"{tune_option}/test/acc": acc_te,
                        f"{tune_option}/test/loss": loss_te,
                        "epoch": epoch,
                    }
                    print(f"Epoch {epoch:2d} Train acc: {acc_tr:.4f}, Val acc: {acc_val:.4f}")

                    if cfg.wandb.use:
                        wandb.log(log_dict)

                    if acc_val > best_val_acc:
                        best_val_acc = acc_val
                        best_model = copy.deepcopy(model)  # only consumed by the commented-out save below

                tune_metrics["lr_tested"].append(lr)
                tune_metrics["wd_tested"].append(wd)
            # torch.save(best_model.state_dict(), 'resnet50_living17_LP-FT.pth')

            if not tune_metrics["acc_val"]:
                # No epoch ran (e.g. cfg.args.epochs == 0), so there is no best epoch;
                # np.argmax([]) and the `// cfg.args.epochs` below would otherwise crash.
                print(f"No epochs were run for {tune_option}; skipping result logging.")
                continue

            # Get test acc according to best val acc. acc_val is flat over (lr, wd) runs x
            # epochs, so dividing by the epoch count recovers the (lr, wd) index.
            best_run_idx = np.argmax(np.array(tune_metrics["acc_val"]))
            best_testacc = tune_metrics["acc_te"][best_run_idx]
            best_lr_wd = best_run_idx // (cfg.args.epochs)

            print(
                f"Best epoch: {best_run_idx % (cfg.args.epochs)+1}, Test Acc: {best_testacc}"
            )

            data = {
                "corruption_type": corruption_type,
                "train_mode": cfg.args.train_mode,
                "tune_option": tune_option,
                "auto_tune": cfg.args.auto_tune,
                "train_n": cfg.args.train_n,
                "seed": cfg.args.seed,
                "lr": tune_metrics["lr_tested"][best_lr_wd],
                "wd": tune_metrics["wd_tested"][best_lr_wd],
                "val_acc": tune_metrics["acc_val"][best_run_idx],
                "best_testacc": best_testacc,
            }

            csv_file_name = f"{cfg.args.log_dir}/results_seed{cfg.args.seed}.csv"
            write_header = not os.path.exists(csv_file_name)
            _append_csv_row(csv_file_name, data, write_header)


if __name__ == "__main__":
    # main() is called without arguments: the @hydra.main decorator on main loads the
    # "imagenet-c" config from ./config and passes it in as cfg.
    main()
